| | |
| |
|
| | """Full definition of a decoder-only transformer-based language model, all of it in this single file. |
| | |
| | Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT and |
| | https://github.com/EleutherAI/gpt-neox/tree/main/megatron/model. |
| | """ |
| |
|
| | import math |
| | from typing import Any, Optional, Tuple |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from typing_extensions import Self |
| | from litgpt.config import Config |
| |
|
| |
|
class GPT(nn.Module):
    """Multimodal decoder-only language model over 8 parallel token streams.

    `forward` embeds 8 token streams with one shared embedding table and
    optionally overwrites slices of streams 0-6 with adapter-projected audio
    (`whisper_adapter`) and vision (`vision_adapter`) features. It then
    averages the streams, runs the transformer stack, and returns 7
    audio-logit slices plus the text logits.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config
        # Submodules are created in the original order so parameter
        # registration order and RNG consumption during init are unchanged.
        if self.config.asr_adapter == "mlp":
            print("Using MLP adapter for ASR feature")
            self.whisper_adapter = nn.Linear(config.whisper_adapter_dim, config.n_embd)
        elif self.config.asr_adapter == "llamamlp":
            print("using LLAMA MLP adapter for ASR feature")
            self.whisper_adapter = whisperMLP(config=config)
        else:
            raise ValueError("asr_adapter should be mlp or llamamlp")
        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )

        self.vision_adapter = visionMLP(config=config)
        if config.post_adapter:
            # Extra blocks plus a dedicated norm/head produce the audio logits.
            self.transformer = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                    h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                    post_adapter=nn.ModuleList(
                        Block(config) for _ in range(config.post_adapter_layers)
                    ),
                    ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
                    post_adapter_audio_ln=config.norm_class(
                        config.n_embd, eps=config.norm_eps
                    ),
                    post_adapter_audio_lm_head=nn.Linear(
                        config.n_embd, config.cat_audio_vocab_size, bias=config.lm_head_bias
                    ),
                )
            )
        else:
            self.transformer = nn.ModuleDict(
                dict(
                    wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                    h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                    ln_f=config.norm_class(config.n_embd, eps=config.norm_eps),
                )
            )
        # Goes through the property setter, which also builds the RoPE buffers.
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None
        if config.tie_word_embeddings:
            self.lm_head.weight = self.transformer.wte.weight

    @property
    def max_seq_length(self) -> int:
        """Sequence length the RoPE cache (and KV/mask caches) are sized for."""
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """Set the working sequence length and keep the RoPE cache in sync.

        At inference the sequences used may be shorter than the model's
        context length; a smaller value avoids allocating unused memory.

        Raises:
            ValueError: if ``value`` exceeds ``config.block_size``.
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # First assignment (from __init__): register non-persistent buffers.
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            # Length changed: rebuild on the device the buffers already live on.
            self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def reset_parameters(self) -> None:
        """Recompute the RoPE cos/sin buffers on their current device."""
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def _init_weights(self, module: nn.Module) -> None:
        """Meant to be used with `gpt.apply(gpt._init_weights)`."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def concat_feat(self, audio_feature, clip_feature, input_ids, T, task):
        """Overwrite slices of embedding streams 0-6 with projected features, per sample.

        Args:
            audio_feature: per-sample projected audio features; sample ``j``
                contributes ``audio_feature[j][:T[j]]``.
            clip_feature: per-sample projected vision features, or None when
                no sample needs them.
            input_ids: list of 8 embedding tensors indexed ``[i][j, pos, :]``.
                Streams 0-6 are modified IN PLACE; stream 7 is never touched.
            T: per-sample audio lengths.
            task: per-sample task names.

        Returns:
            The same ``input_ids`` list that was passed in.
        """
        for j in range(len(T)):
            sample_task = task[j]
            if sample_task not in (
                "T1T2", "T1A2", "ImageQA_T", "ImageCAP", "ImageQA_A", "ImageQA_AT"
            ):
                # Audio prompt goes right after position 0.
                for i in range(7):
                    input_ids[i][j, 1 : T[j] + 1, :] = audio_feature[j][: T[j]].clone()
                assert sample_task != "ImageQ", "ImageQ should be concat with audio feature"
            elif sample_task in ("ImageQA_A", "ImageQA_AT"):
                print("concat ImageQA_A feature")
                # NOTE(review): assumes clip_feature[j] has 50 rows (positions
                # 1..50) and audio starts at 52, leaving position 51 as the
                # token embedding — confirm against the data pipeline.
                for i in range(7):
                    input_ids[i][j, 1:51, :] = clip_feature[j].clone()
                    input_ids[i][j, 52 : 52 + T[j], :] = audio_feature[j][: T[j]].clone()
            elif sample_task in ("ImageQA_T", "ImageCAP"):
                for i in range(7):
                    input_ids[i][j, 1:51, :] = clip_feature[j].clone()
            # "T1T2" / "T1A2": text-only input, embeddings left unchanged.
        return input_ids

    def _embed_streams(self, input_ids) -> list:
        """Embed each of the 8 token streams with the shared `wte` table."""
        streams = list(input_ids)
        if len(streams) != 8:
            raise ValueError(f"expected 8 token streams, got {len(streams)}")
        return [self.transformer.wte(ids) for ids in streams]

    def forward(
        self,
        audio_features: torch.Tensor,
        input_ids: torch.Tensor,
        clip_features: torch.Tensor,
        input_pos: Optional[torch.Tensor] = None,
        whisper_lens: Optional[list] = None,
        task: Optional[list] = None,
    ) -> Tuple[list, torch.Tensor]:
        """Run the model and return per-codebook audio logits and text logits.

        Args:
            audio_features: raw audio features fed to `whisper_adapter`, or
                None to skip all feature splicing.
            input_ids: 8 token-id streams; ``input_ids[0].size(1)`` is taken
                as the sequence length (presumably each is (batch, seq)).
            clip_features: raw vision features fed to `vision_adapter`, or
                None. Only used when ``audio_features`` is not None.
            input_pos: positions for incremental decoding; requires a prior
                call to `set_kv_cache`.
            whisper_lens: per-sample audio lengths passed to `concat_feat`.
            task: per-sample task names passed to `concat_feat`.

        Returns:
            ``(xa, xt)``: ``xa`` is a list of 7 tensors, each holding
            ``audio_vocab_size`` logits; ``xt`` holds the first
            ``text_vocab_size`` logits of `lm_head`.

        Raises:
            ValueError: if the sequence is longer than `max_seq_length`.
            TypeError: if ``input_pos`` is given before `set_kv_cache`.
        """
        T = input_ids[0].size(1)
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is not None:
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        if audio_features is not None:
            x_a = self.whisper_adapter(audio_features)
            x_v = self.vision_adapter(clip_features) if clip_features is not None else None
            embeddings = self._embed_streams(input_ids)
            embeddings = self.concat_feat(x_a, x_v, embeddings, whisper_lens, task)
        else:
            # NOTE(review): clip_features is ignored on this path.
            embeddings = self._embed_streams(input_ids)

        # Left-to-right sum, matching the original (x0 + x1 + ... + x7) / 8.
        x = embeddings[0]
        for emb in embeddings[1:]:
            x = x + emb
        x = x / len(embeddings)

        if self.config.scale_embeddings:
            x = x * (self.config.n_embd**0.5)

        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)

        text_vocab_size = self.config.text_vocab_size
        audio_vocab_size = self.config.audio_vocab_size

        logits = self.lm_head(self.transformer.ln_f(x))
        xt = logits[..., :text_vocab_size]

        if self.config.post_adapter:
            # Audio logits come from the extra post-adapter blocks and head.
            for block in self.transformer.post_adapter:
                x = block(x, cos, sin, mask, input_pos)
            x = self.transformer.post_adapter_audio_ln(x)
            audio_logits = self.transformer.post_adapter_audio_lm_head(x)
            offset = 0
        else:
            # Audio logits sit right after the text vocabulary in lm_head.
            audio_logits = logits
            offset = text_vocab_size
        xa = [
            audio_logits[..., offset + audio_vocab_size * i : offset + audio_vocab_size * (i + 1)]
            for i in range(7)
        ]
        return xa, xt

    @classmethod
    def from_name(cls, name: str, **kwargs: Any) -> Self:
        """Build a model from a named `Config` preset."""
        return cls(Config.from_name(name, **kwargs))

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build RoPE cos/sin tables for `max_seq_length` positions."""
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def _attention_blocks(self):
        """Yield every block that owns a KV cache: main stack, then post-adapter."""
        yield from self.transformer.h
        if self.config.post_adapter:
            yield from self.transformer.post_adapter

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate KV caches for all attention blocks and (re)build the causal mask."""
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        for block in self._attention_blocks():
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        # Rebuild the mask only when missing or sized for another length.
        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        """Drop the mask cache and the KV caches of every block `set_kv_cache` filled.

        This includes the post-adapter blocks, whose caches were previously
        left allocated.
        """
        self.mask_cache = None
        for block in self._attention_blocks():
            block.attn.kv_cache = None
| |
|
| |
|
class visionMLP(nn.Module):
    """SwiGLU projection from vision features (`vision_adapter_dim`) to `n_embd`."""

    def __init__(self, config: "Config") -> None:
        super().__init__()
        in_dim = config.vision_adapter_dim
        hidden = config.intermediate_size
        self.fc_1 = nn.Linear(in_dim, hidden, bias=config.bias)
        self.fc_2 = nn.Linear(in_dim, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(silu(fc_1(x)) * fc_2(x))``."""
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
| |
|
| |
|
class Block(nn.Module):
    """One transformer layer: causal self-attention plus MLP with residuals.

    The residual wiring is chosen by `config.parallel_residual` and
    `config.shared_attention_norm`; see `forward`.
    """

    def __init__(self, config: Config) -> None:
        super().__init__()
        if not config.parallel_residual and config.shared_attention_norm:
            raise NotImplementedError(
                "No checkpoint amongst the ones we support uses this configuration"
                " (non-parallel residual and shared attention norm)."
            )

        self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        # A shared attention norm means the MLP reuses norm_1's output.
        if config.shared_attention_norm:
            self.norm_2 = None
        else:
            self.norm_2 = config.norm_class(config.n_embd, eps=config.norm_eps)
        self.mlp = config.mlp_class(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Non-parallel residual          Parallel residual
           ┌─ x                          ┌─ x ────────────┐   Note: if `shared_attention_norm` is True,
           │  ↓                          │  ↓             ↓         the output from `norm_1` is reused
           │  norm_1                     │  norm_1  ───►  norm_2
           │  ↓                          │  ↓             ↓
           │  attn                       │  attn          mlp
           │  ↓                          │  ↓             │
        ┌─ └► +                          └► + ◄───────────┘
        │     norm_2
        │     ↓
        │     mlp
        │     ↓
        └───► +
        """
        h = self.norm_1(x)
        attn_out = self.attn(h, cos, sin, mask, input_pos)

        if not self.config.parallel_residual:
            x = attn_out + x
            return self.mlp(self.norm_2(x)) + x

        mlp_in = h if self.config.shared_attention_norm else self.norm_2(x)
        return self.mlp(mlp_in) + attn_out + x
| |
|
| |
|
| | class CausalSelfAttention(nn.Module): |
| | def __init__(self, config: Config) -> None: |
| | super().__init__() |
| | shape = (config.n_head + 2 * config.n_query_groups) * config.head_size |
| | |
| | self.attn = nn.Linear(config.n_embd, shape, bias=config.add_qkv_bias) |
| | |
| | |
| | self.proj = nn.Linear( |
| | config.head_size * config.n_head, config.n_embd, bias=config.bias |
| | ) |
| | |
| | self.kv_cache: Optional[KVCache] = None |
| |
|
| | self.config = config |
| |
|
| | def forward( |
| | self, |
| | x: torch.Tensor, |
| | cos: torch.Tensor, |
| | sin: torch.Tensor, |
| | mask: Optional[torch.Tensor] = None, |
| | input_pos: Optional[torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | B, T, C = ( |
| | x.size() |
| | ) |
| |
|
| | qkv = self.attn(x) |
| |
|
| | |
| | q_per_kv = self.config.n_head // self.config.n_query_groups |
| | total_qkv = q_per_kv + 2 |
| | qkv = qkv.view( |
| | B, T, self.config.n_query_groups, total_qkv, self.config.head_size |
| | ) |
| | qkv = qkv.permute(0, 2, 3, 1, 4) |
| |
|
| | |
| | q, k, v = qkv.split((q_per_kv, 1, 1), dim=2) |
| |
|
| | |
| | |
| | |
| | if self.config.n_query_groups != self.config.n_head and ( |
| | input_pos is None or self.config.n_query_groups != 1 |
| | ): |
| | k = k.expand( |
| | B, self.config.n_query_groups, q_per_kv, T, self.config.head_size |
| | ) |
| | v = v.expand( |
| | B, self.config.n_query_groups, q_per_kv, T, self.config.head_size |
| | ) |
| |
|
| | q = q.reshape(B, -1, T, self.config.head_size) |
| | k = k.reshape(B, -1, T, self.config.head_size) |
| | v = v.reshape(B, -1, T, self.config.head_size) |
| |
|
| | q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) |
| | k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) |
| | q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) |
| | k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) |
| |
|
| | if input_pos is not None: |
| | if not isinstance(self.kv_cache, KVCache): |
| | raise TypeError("You need to call `gpt.set_kv_cache()`") |
| | k, v = self.kv_cache(input_pos, k, v) |
| |
|
| | y = self.scaled_dot_product_attention(q, k, v, mask) |
| |
|
| | y = y.reshape( |
| | B, T, self.config.head_size * self.config.n_head |
| | ) |
| |
|
| | |
| | return self.proj(y) |
| |
|
| | def scaled_dot_product_attention( |
| | self, |
| | q: torch.Tensor, |
| | k: torch.Tensor, |
| | v: torch.Tensor, |
| | mask: Optional[torch.Tensor] = None, |
| | ) -> torch.Tensor: |
| | scale = 1.0 / math.sqrt(self.config.head_size) |
| | y = torch.nn.functional.scaled_dot_product_attention( |
| | q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None |
| | ) |
| | return y.transpose(1, 2) |
| |
|
| | def build_kv_cache( |
| | self, |
| | batch_size: int, |
| | max_seq_length: int, |
| | rope_cache_length: Optional[int] = None, |
| | device: Optional[torch.device] = None, |
| | dtype: Optional[torch.dtype] = None, |
| | ) -> "KVCache": |
| | heads = 1 if self.config.n_query_groups == 1 else self.config.n_head |
| | v_shape = (batch_size, heads, max_seq_length, self.config.head_size) |
| | if rope_cache_length is None: |
| | if self.config.rotary_percentage != 1.0: |
| | raise TypeError( |
| | "Please pass the `rope_cache_length=gpt.cos.size(-1)` value" |
| | ) |
| | k_shape = v_shape |
| | else: |
| | k_shape = ( |
| | batch_size, |
| | heads, |
| | max_seq_length, |
| | rope_cache_length + self.config.head_size - self.config.rope_n_elem, |
| | ) |
| | return KVCache(k_shape, v_shape, device=device, dtype=dtype) |
| |
|
| |
|
class GptNeoxMLP(nn.Module):
    """Two-layer GELU MLP: ``proj(gelu(fc(x)))``."""

    def __init__(self, config: "Config") -> None:
        super().__init__()
        width, hidden = config.n_embd, config.intermediate_size
        self.fc = nn.Linear(width, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, width, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP; the GELU variant comes from `config.gelu_approximate`."""
        hidden = torch.nn.functional.gelu(
            self.fc(x), approximate=self.config.gelu_approximate
        )
        return self.proj(hidden)
| |
|
| |
|
class LLaMAMLP(nn.Module):
    """SwiGLU feed-forward block: ``proj(silu(fc_1(x)) * fc_2(x))``."""

    def __init__(self, config: "Config") -> None:
        super().__init__()
        width, hidden = config.n_embd, config.intermediate_size
        self.fc_1 = nn.Linear(width, hidden, bias=config.bias)
        self.fc_2 = nn.Linear(width, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, width, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate `fc_2(x)` with `silu(fc_1(x))` and project back to `n_embd`."""
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
| |
|
| |
|
class whisperMLP(nn.Module):
    """SwiGLU projection from audio features (`whisper_adapter_dim`) to `n_embd`."""

    def __init__(self, config: "Config") -> None:
        super().__init__()
        in_dim = config.whisper_adapter_dim
        hidden = config.intermediate_size
        self.fc_1 = nn.Linear(in_dim, hidden, bias=config.bias)
        self.fc_2 = nn.Linear(in_dim, hidden, bias=config.bias)
        self.proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(silu(fc_1(x)) * fc_2(x))``."""
        gate = torch.nn.functional.silu(self.fc_1(x))
        return self.proj(gate * self.fc_2(x))
| |
|
| |
|
class GemmaMLP(LLaMAMLP):
    """LLaMAMLP variant whose gate uses GELU (per `config.gelu_approximate`) instead of SiLU."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``proj(gelu(fc_1(x)) * fc_2(x))``."""
        gate = torch.nn.functional.gelu(
            self.fc_1(x), approximate=self.config.gelu_approximate
        )
        return self.proj(gate * self.fc_2(x))
| |
|
| |
|
class LLaMAMoE(nn.Module):
    """Sparse mixture of `n_expert` LLaMAMLP experts with top-k routing per token."""

    def __init__(self, config: Config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False)
        self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert))
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Derived from: https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
        See also figure 1 in https://arxiv.org/abs/2211.15841
        """
        B, T, C = x.size()
        tokens = x.view(-1, C)  # (B*T, C)
        top_weights, top_experts = torch.topk(
            self.gate(tokens), self.config.n_expert_per_token
        )
        # Softmax over only the selected experts, computed in float32.
        top_weights = top_weights.softmax(dim=1, dtype=torch.float).to(dtype=tokens.dtype)
        expert_ids = torch.arange(self.config.n_expert, device=tokens.device)
        # One boolean (tokens, k) routing mask per expert.
        per_expert = (top_experts.unsqueeze(-1) == expert_ids).permute(2, 0, 1)
        out = torch.zeros_like(tokens)
        for expert_mask, expert in zip(per_expert, self.experts):
            rows, slots = torch.where(expert_mask)
            out[rows] += top_weights[rows, slots, None] * expert(tokens[rows])
        return out.view(B, T, C)
| |
|
| |
|
def build_rope_cache(
    seq_len: int,
    n_elem: int,
    device: Optional[torch.device] = None,
    base: int = 10000,
    condense_ratio: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enhanced Transformer with Rotary Position Embedding.

    Returns ``(cos, sin)`` tables of shape (seq_len, n_elem) for positions
    ``0..seq_len-1`` divided by ``condense_ratio``.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    exponents = torch.arange(0, n_elem, 2, device=device).float() / n_elem
    inv_freq = 1.0 / (base ** exponents)
    positions = torch.arange(seq_len, device=device) / condense_ratio
    # Duplicate the frequencies so each half of the rotated channels shares them.
    angles = torch.outer(positions, inv_freq).repeat(1, 2)
    return torch.cos(angles), torch.sin(angles)
| |
|
| |
|
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate ``x`` by (cos, sin) using the half-split RoPE layout; keeps ``x``'s dtype."""
    half = x.size(-1) // 2
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat((-second, first), dim=-1)
    return (x * cos + rotated * sin).to(dtype=x.dtype)
| |
|
| |
|
class KVCache(nn.Module):
    """Preallocated key/value buffers (non-persistent) written in place per step."""

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        for name, shape in (("k", k_shape), ("v", v_shape)):
            self.register_buffer(
                name, torch.zeros(shape, device=device, dtype=dtype), persistent=False
            )

    def forward(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write ``k``/``v`` at ``input_pos`` along dim 2 and return the full buffers.

        The buffers are first cast to the incoming dtypes.
        """
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        return self.k.index_copy_(2, input_pos, k), self.v.index_copy_(2, input_pos, v)

    def reset_parameters(self) -> None:
        """Zero both buffers in place."""
        for buf in (self.k, self.v):
            torch.nn.init.zeros_(buf)
| |
|
| |
|
def build_mask_cache(
    max_seq_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Return a (1, 1, L, L) boolean lower-triangular (causal) mask."""
    full = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
    return full.tril()[None, None]
| |
|
| |
|
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(
        self, size: int, dim: int = -1, eps: float = 1e-6, add_unit_offset: bool = False
    ) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim
        self.add_unit_offset = add_unit_offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize in float32, cast back to the input dtype, then scale.

        With ``add_unit_offset`` the scale is ``1 + weight``; otherwise it is
        ``weight``.
        """
        orig_dtype = x.dtype
        x32 = x.float()
        mean_sq = torch.mean(x32 * x32, dim=self.dim, keepdim=True)
        normed = (x32 * torch.rsqrt(mean_sq + self.eps)).to(dtype=orig_dtype)
        scale = (1 + self.weight) if self.add_unit_offset else self.weight
        return normed * scale

    def reset_parameters(self) -> None:
        """Reset the scale to all ones."""
        torch.nn.init.ones_(self.weight)
| |
|