import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization (RMSNorm).

    Normalizes the input across the last dimension using RMS normalization,
    which scales the input without subtracting the mean. Commonly used as a
    lighter alternative to LayerNorm in transformer models.

    Args:
        cfg: A configuration object containing:
            - lm_hidden_dim (int): The dimensionality of the model hidden states.
            - lm_rms_eps (float): A small constant to avoid division by zero.
    """
    def __init__(self, cfg):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(cfg.lm_hidden_dim))
        self.eps = cfg.lm_rms_eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for RMSNorm.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, sequence_length, lm_hidden_dim).

        Returns:
            torch.Tensor: Normalized tensor of the same shape as input.
        """
        irms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        x = x * irms * self.weight

        return x
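

# Illustrative sketch (not part of the original module): a minimal smoke test
# for RMSNorm. The config here is a stand-in with made-up values; any object
# exposing the two fields works.
def _demo_rmsnorm():
    from types import SimpleNamespace
    cfg = SimpleNamespace(lm_hidden_dim=8, lm_rms_eps=1e-6)
    norm = RMSNorm(cfg)
    x = torch.randn(2, 4, 8)
    y = norm(x)
    assert y.shape == x.shape
    # With the default unit weight, each vector is simply divided by its RMS.
    ref = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
    assert torch.allclose(y, ref, atol=1e-5)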


class RotaryEmbedding(nn.Module):
    """
    Rotary positional embedding (RoPE).

    Rotates pairs of channels in the query and key vectors by position-dependent
    angles, injecting positional information without any additional trainable
    parameters and encoding the relative distance between token position ids.

    Args:
        cfg: Configuration object containing:
            - lm_hidden_dim (int): Hidden dimension size.
            - lm_n_heads (int): Number of attention heads.
            - lm_re_base (float): Base for rotary embedding frequencies.
            - lm_max_position_embeddings (int): Max sequence length supported for rotary embedding.
            - lm_attn_scaling (float): Attention scaling factor.
    """

    def __init__(self, cfg):
        super().__init__()
        assert cfg.lm_hidden_dim % cfg.lm_n_heads == 0, "Hidden dimension must be divisible by number of heads"

        self.dim = cfg.lm_hidden_dim // cfg.lm_n_heads
        self.base = cfg.lm_re_base
        self.max_seq_len = cfg.lm_max_position_embeddings

        # Inverse frequencies, one per channel pair: 1 / base^(2i / dim).
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq)
        self.original_max_seq_len = cfg.lm_max_position_embeddings
        self.attention_scaling = cfg.lm_attn_scaling

    @torch.no_grad()
    def forward(self, position_ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute rotary positional embeddings (cosine and sine components).

        Args:
            position_ids (torch.Tensor): Tensor of shape (batch_size, seq_len) containing position indices.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Tuple of two tensors (cos, sin), each of shape
                (batch_size, seq_len, head_dim), representing rotary embeddings.
        """
        batch_size, seq_len = position_ids.shape

        # Dynamic scaling: if positions exceed the original context length,
        # stretch the frequencies so the rotations still cover the longer range.
        max_seq = position_ids.max() + 1
        if max_seq > self.original_max_seq_len:
            scale = max_seq / self.original_max_seq_len
            inv_freq = self.inv_freq / scale
        else:
            inv_freq = self.inv_freq

        flat_position_ids = position_ids.reshape(-1).float()

        # One angle per (position, frequency) pair via an outer product.
        freqs = flat_position_ids.unsqueeze(-1) * inv_freq.unsqueeze(0)

        freqs = freqs.reshape(batch_size, seq_len, -1)

        # Duplicate the angles so cos/sin span the full head dimension.
        emb = torch.cat([freqs, freqs], dim=-1)

        cos = torch.cos(emb) * self.attention_scaling
        sin = torch.sin(emb) * self.attention_scaling

        return cos, sin
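

# Illustrative sketch (not part of the original module): a shape check for the
# rotary tables. Config values are made up; head_dim = 64 // 4 = 16.
def _demo_rotary_embedding():
    from types import SimpleNamespace
    cfg = SimpleNamespace(lm_hidden_dim=64, lm_n_heads=4, lm_re_base=10000.0,
                          lm_max_position_embeddings=128, lm_attn_scaling=1.0)
    rope = RotaryEmbedding(cfg)
    position_ids = torch.arange(10).unsqueeze(0)  # (batch=1, seq_len=10)
    cos, sin = rope(position_ids)
    assert cos.shape == (1, 10, 16) and sin.shape == (1, 10, 16)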


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """
    Splits the last dimension into two halves, then swaps them and negates the
    half that moves to the front: [x1, x2] -> [-x2, x1].
    """
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
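

# Illustrative sketch (not part of the original module): with four channels,
# rotate_half maps [x1, x2, x3, x4] to [-x3, -x4, x1, x2].
def _demo_rotate_half():
    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    assert torch.equal(rotate_half(x), torch.tensor([-3.0, -4.0, 1.0, 2.0]))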


def apply_rotary_pos_embd(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, unsqueeze_dim: int = 1) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Applies rotary positional embeddings to query and key tensors in attention mechanisms.

    Rotary positional embeddings inject position-dependent rotations into query and key vectors,
    enabling transformers to encode positional information effectively without explicit positional encoding.

    Args:
        q (torch.Tensor): Query tensor with shape [batch_size, num_heads, seq_len, head_dim].
        k (torch.Tensor): Key tensor with shape [batch_size, num_heads, seq_len, head_dim].
        cos (torch.Tensor): Precomputed cosine positional embeddings with shape [batch_size, seq_len, head_dim].
        sin (torch.Tensor): Precomputed sine positional embeddings with shape [batch_size, seq_len, head_dim].
        unsqueeze_dim (int, optional): Dimension index to unsqueeze `cos` and `sin` to enable broadcasting.
            Defaults to 1 (typically the heads dimension).

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The rotated query and key tensors (`q_embed`, `k_embed`),
            each with the same shape as the input tensors.

    How it works:
        - `cos` and `sin` tensors are unsqueezed at `unsqueeze_dim` to broadcast across attention heads.
        - Rotary embeddings apply a complex-number rotation in the embedding space using:
            rotated = (original * cos) + (rotate_half(original) * sin)
        - `rotate_half` performs a specific half-dimension rotation on the input tensor.
        - This operation encodes relative position information in q and k without adding explicit positional vectors.

    Example:
        q_embed, k_embed = apply_rotary_pos_embd(q, k, cos, sin)
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)

    return q_embed, k_embed
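

# Illustrative sketch (not part of the original module): applying rotary tables
# to query/key tensors. Sizes are made up: 2 heads, head_dim 8.
def _demo_apply_rotary():
    B, H, T, D = 1, 2, 5, 8
    q = torch.randn(B, H, T, D)
    k = torch.randn(B, H, T, D)
    # cos/sin come in as (B, T, D); unsqueeze_dim=1 broadcasts over the heads.
    angles = torch.randn(B, T, D)
    cos, sin = torch.cos(angles), torch.sin(angles)
    q_embed, k_embed = apply_rotary_pos_embd(q, k, cos, sin)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape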


class LanguageModelGroupedQueryAttention(nn.Module):
    """
    Implements Grouped Query Attention (GQA) as used in some transformer-based language models.

    GQA reduces computation by using fewer key-value heads than query heads,
    grouping multiple query heads to share the same key-value heads.

    Args:
        cfg: Configuration object containing:
            - lm_n_heads (int): Number of query heads.
            - lm_n_kv_heads (int): Number of key-value heads.
            - lm_hidden_dim (int): Hidden embedding dimension.
            - lm_dropout (float): Dropout rate.
    """
    def __init__(self, cfg):
        super().__init__()

        self.n_heads = cfg.lm_n_heads
        self.n_kv_heads = cfg.lm_n_kv_heads
        self.embd_dim = cfg.lm_hidden_dim
        self.dropout = cfg.lm_dropout

        assert self.n_heads % self.n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        assert self.embd_dim % self.n_heads == 0, "embd_dim must be divisible by n_heads"

        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.head_dim = self.embd_dim // self.n_heads

        # K and V project to n_kv_heads * head_dim, which is smaller than
        # embd_dim whenever n_kv_heads < n_heads.
        self.q_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)
        self.k_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.v_proj = nn.Linear(self.embd_dim, self.head_dim * self.n_kv_heads, bias=False)
        self.out_proj = nn.Linear(self.embd_dim, self.embd_dim, bias=False)

        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

        self.sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.sdpa:
            print("Warning: scaled dot product attention not available, using standard attention in LM.")

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attention_mask=None, block_kv_cache=None) -> tuple[torch.Tensor, dict]:
        """
        Forward pass for grouped query attention.

        Args:
            x (Tensor): Input tensor of shape (B, T_curr, C), where
                B = batch size,
                T_curr = current sequence length,
                C = embedding dimension.
            cos (Tensor): Rotary embedding cosines, shape compatible with q and k.
            sin (Tensor): Rotary embedding sines, shape compatible with q and k.
            attention_mask (Tensor, optional): Attention mask tensor of shape (B, total_kv_length),
                with 1 for tokens to attend to and 0 for padding.
            block_kv_cache (dict, optional): Cache dict with 'key' and 'value' tensors for autoregressive decoding.

        Returns:
            tuple[Tensor, dict]:
                - Output tensor after attention and projection, shape (B, T_curr, C).
                - Updated block_kv_cache dict for caching key-value states.
        """
        is_prefill = block_kv_cache is None

        B, T_curr, C = x.size()

        # Project and reshape to (B, n_heads, T_curr, head_dim).
        q_curr = self.q_proj(x).view(B, T_curr, self.n_heads, self.head_dim).transpose(1, 2)
        k_curr = self.k_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v_curr = self.v_proj(x).view(B, T_curr, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings to the current queries and keys.
        q, k_rotated = apply_rotary_pos_embd(q_curr, k_curr, cos, sin)

        if not is_prefill and block_kv_cache['key'] is not None:
            # Decode step: append the new key/value to the cache.
            k = block_kv_cache['key']
            v = block_kv_cache['value']
            k = torch.cat([k, k_rotated], dim=2)
            v = torch.cat([v, v_curr], dim=2)
            block_kv_cache['key'] = k
            block_kv_cache['value'] = v
        else:
            # Prefill: start a fresh cache from the current keys/values.
            k = k_rotated
            v = v_curr
            block_kv_cache = {'key': k, 'value': v}

        # GQA: repeat each KV head so every query head has a matching KV head.
        k_exp = k.repeat_interleave(self.n_kv_groups, dim=1)
        v_exp = v.repeat_interleave(self.n_kv_groups, dim=1)

        T_kv = k_exp.size(2)

        # Convert the 0/1 padding mask over keys into an additive mask
        # (0 where attended, a large negative number where masked).
        additive_attn_mask = None
        if attention_mask is not None:
            mask_for_keys = attention_mask[:, :T_kv]
            additive_attn_mask = (1.0 - mask_for_keys.unsqueeze(1).unsqueeze(2).float()) * torch.finfo(q.dtype).min

        if self.sdpa and x.device.type != 'mps':
            # Causal masking is only needed during prefill (square attention);
            # single-token decode steps attend to the whole cache.
            is_causal = (T_curr == T_kv and T_curr > 1)
            if is_causal and additive_attn_mask is not None:
                # SDPA rejects attn_mask together with is_causal, so fold the
                # causal constraint into the additive mask instead.
                causal = torch.tril(torch.ones(T_curr, T_kv, device=x.device, dtype=torch.bool)).view(1, 1, T_curr, T_kv)
                additive_attn_mask = additive_attn_mask.expand(B, 1, T_curr, T_kv).masked_fill(~causal, torch.finfo(q.dtype).min)
                is_causal = False
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k_exp, v_exp,
                attn_mask=additive_attn_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=is_causal
            )
        else:
            # Fallback: explicit scaled dot-product attention.
            attn = torch.matmul(q, k_exp.transpose(2, 3)) / math.sqrt(self.head_dim)

            if T_curr == T_kv and T_curr > 1:
                causal_mask_val = torch.tril(torch.ones(T_curr, T_curr, device=x.device, dtype=torch.bool)).view(1, 1, T_curr, T_curr)
                attn = attn.masked_fill(~causal_mask_val, float('-inf'))

            if additive_attn_mask is not None:
                attn = attn + additive_attn_mask

            attn = F.softmax(attn, dim=-1)
            attn = self.attn_dropout(attn)
            y = attn @ v_exp

        # Merge the heads back together and project out.
        y = y.transpose(1, 2).contiguous().view(B, T_curr, C)
        y = self.out_proj(y)
        y = self.resid_dropout(y)

        return y, block_kv_cache
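

# Illustrative sketch (not part of the original module): the KV expansion at
# the heart of GQA. With 4 query heads and 2 KV heads, each KV head is
# repeated n_kv_groups = 2 times so every query head has a partner.
def _demo_gqa_expansion():
    B, n_kv_heads, T, D = 1, 2, 3, 4
    n_heads = 4
    n_kv_groups = n_heads // n_kv_heads
    k = torch.randn(B, n_kv_heads, T, D)
    k_exp = k.repeat_interleave(n_kv_groups, dim=1)
    assert k_exp.shape == (B, n_heads, T, D)
    # Query heads 0 and 1 share KV head 0; heads 2 and 3 share KV head 1.
    assert torch.equal(k_exp[:, 0], k_exp[:, 1])
    assert torch.equal(k_exp[:, 2], k_exp[:, 3])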


class LanguageModelMLP(nn.Module):
    """
    Implements the feed-forward network (MLP) block used in transformer-based language models.

    This MLP uses a gated activation mechanism (SwiGLU-style) where two separate
    linear projections are applied to the input: one passed through an activation
    function (gate_proj), and the other as is (up_proj). Their element-wise
    product is then projected back to the embedding dimension (down_proj).

    Args:
        cfg: Configuration object containing:
            - lm_hidden_dim (int): The embedding dimension size.
            - lm_inter_dim (int): The intermediate dimension size for the MLP.

    Attributes:
        activation_fn (Callable): The activation function used (SiLU).
        gate_proj (nn.Linear): Linear projection for gating pathway.
        up_proj (nn.Linear): Linear projection for upscaling pathway.
        down_proj (nn.Linear): Linear projection for downscaling back to embedding dim.
    """

    def __init__(self, cfg):
        super().__init__()
        self.embd_dim = cfg.lm_hidden_dim
        self.inter_dim = cfg.lm_inter_dim

        self.activation_fn = F.silu
        self.gate_proj = nn.Linear(self.embd_dim, self.inter_dim, bias=False)
        self.up_proj = nn.Linear(self.embd_dim, self.inter_dim, bias=False)
        self.down_proj = nn.Linear(self.inter_dim, self.embd_dim, bias=False)

    def forward(self, x):
        """
        Forward pass through the gated MLP block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_length, embd_dim).

        Returns:
            Tensor: Output tensor of shape (batch_size, seq_length, embd_dim),
                after gated MLP transformation.
        """
        gate = self.activation_fn(self.gate_proj(x))
        x = self.up_proj(x)
        x = self.down_proj(gate * x)

        return x
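

# Illustrative sketch (not part of the original module): the gated MLP computes
# down_proj(silu(gate_proj(x)) * up_proj(x)). Config values are made up.
def _demo_mlp():
    from types import SimpleNamespace
    cfg = SimpleNamespace(lm_hidden_dim=8, lm_inter_dim=16)
    mlp = LanguageModelMLP(cfg)
    x = torch.randn(2, 4, 8)
    y = mlp(x)
    assert y.shape == x.shape
    ref = mlp.down_proj(F.silu(mlp.gate_proj(x)) * mlp.up_proj(x))
    assert torch.allclose(y, ref)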


class LanguageModelBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.mlp = LanguageModelMLP(cfg)
        self.attn = LanguageModelGroupedQueryAttention(cfg)
        self.norm1 = RMSNorm(cfg)
        self.norm2 = RMSNorm(cfg)

    def forward(self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, attention_mask: torch.Tensor = None, block_kv_cache: dict = None):
        """
        Forward pass of the Transformer block.

        Args:
            x (Tensor): Input tensor of shape (batch_size, seq_len, hidden_dim).
            cos (Tensor): Cosine positional embeddings for rotary embedding, shape
                matching sequence length and head dimension.
            sin (Tensor): Sine positional embeddings for rotary embedding, same shape as cos.
            attention_mask (Tensor, optional): Attention mask of shape (batch_size, total_kv_length),
                with 1 indicating tokens to attend to and 0 for padding tokens.
            block_kv_cache (dict, optional): Key-value cache dict for cached keys and values
                during decoding. If None, no cache is used.

        Returns:
            Tuple[Tensor, dict]: Output tensor after the block (same shape as input),
                and the updated key-value cache dictionary.
        """
        # Pre-norm residual attention sub-block.
        res = x
        x = self.norm1(x)
        x, block_kv_cache = self.attn(x, cos, sin, attention_mask, block_kv_cache)
        x = res + x

        # Pre-norm residual MLP sub-block.
        res = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = res + x

        return x, block_kv_cache


class LanguageModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.lm_use_tokens = cfg.lm_use_tokens
        self.lm_tie_weights = cfg.lm_tie_weights

        self.token_embedding = nn.Embedding(cfg.lm_vocab_size, cfg.lm_hidden_dim)
        self.rotary_embd = RotaryEmbedding(cfg)
        self.blocks = nn.ModuleList([
            LanguageModelBlock(cfg) for _ in range(cfg.lm_n_blocks)
        ])
        self.norm = RMSNorm(cfg)
        self.head = nn.Linear(cfg.lm_hidden_dim, cfg.lm_vocab_size, bias=False)
        if self.lm_tie_weights:
            self.head.weight = self.token_embedding.weight

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, kv_cache: list[dict] = None, start_pos: int = 0):
        """
        Performs a forward pass through the language model.

        Args:
            x (Tensor): Input tensor. If `lm_use_tokens` is True, this should be
                token indices with shape (batch_size, sequence_length).
                If False, it should be embeddings of shape (batch_size, sequence_length, hidden_dim).
            attention_mask (Tensor, optional): Mask tensor for attention to
                specify which tokens to attend to, typically of shape
                (batch_size, sequence_length). Default is None.
            kv_cache (list[dict], optional): List of key-value caches for each transformer
                block to enable efficient autoregressive decoding.
                If None, no cache is used and new ones are created. Default is None.
            start_pos (int, optional): The starting position index for the current input
                sequence. Used to compute rotary positional embeddings correctly,
                especially for cached sequences during generation. Default is 0.

        Returns:
            Tuple:
                - Tensor: Output logits with shape (batch_size, sequence_length, vocab_size)
                  if `lm_use_tokens` is True, otherwise the hidden state embeddings
                  (batch_size, sequence_length, hidden_dim).
                - list: Updated list of key-value caches, one for each transformer block,
                  useful for autoregressive decoding and incremental generation.

        Behavior:
            - If `lm_use_tokens` is True, the input token indices are first embedded.
            - Rotary positional embeddings are generated for the current input positions,
              which are passed along to each transformer block.
            - For each transformer block, the input is processed along with
              rotary embeddings, attention mask, and optional cached key-values.
            - After processing all blocks, a final RMS normalization is applied.
            - If tokens are used, the normalized hidden states are projected to logits
              over the vocabulary.
            - The method returns the logits or embeddings along with the updated
              cache for efficient decoding.
        """
        if self.lm_use_tokens:
            x = self.token_embedding(x)

        B, T_curr, _ = x.size()

        # Positions are offset by start_pos so cached decoding keeps the
        # correct rotary angles.
        current_position_ids = torch.arange(start_pos, start_pos + T_curr, device=x.device).unsqueeze(0).expand(B, -1)
        cos, sin = self.rotary_embd(current_position_ids)

        if kv_cache is None:
            kv_cache = [None] * len(self.blocks)

        for i, block in enumerate(self.blocks):
            x, kv_cache[i] = block(x, cos, sin, attention_mask, kv_cache[i])

        x = self.norm(x)

        # Project back to the vocabulary only when operating on tokens.
        if self.lm_use_tokens:
            x = self.head(x)

        return x, kv_cache

    @torch.inference_mode()
    def generate(self, inputs: torch.Tensor, max_new_tokens: int = 20):
        """
        Generate tokens autoregressively from a given input sequence.

        Args:
            inputs (torch.Tensor): Input tensor containing token indices or embeddings.
                Shape: (batch_size, sequence_length) or (sequence_length,) for a single sequence.
            max_new_tokens (int): Number of new tokens to generate after the input sequence.

        Returns:
            torch.Tensor: The generated sequence, including the original inputs and newly generated tokens.
                Shape: (batch_size, sequence_length + max_new_tokens)
        """
        if inputs.dim() == 1:
            inputs = inputs.unsqueeze(0)
        generated_outputs = inputs.clone()

        # Prefill: process the whole prompt once and build the KV cache.
        prompt_output, kv_cache_list = self.forward(
            generated_outputs,
            attention_mask=None,
            kv_cache=None,
            start_pos=0
        )
        last_output = prompt_output[:, -1, :]

        for i in range(max_new_tokens):
            if self.lm_use_tokens:
                # Greedy decoding: pick the highest-probability token.
                next_output = torch.argmax(last_output, dim=-1, keepdim=True)
            else:
                next_output = last_output.unsqueeze(1)

            generated_outputs = torch.cat((generated_outputs, next_output), dim=1)

            # The newly appended token sits at position (current length - 1).
            current_token_start_pos = generated_outputs.size(1) - 1

            if i == max_new_tokens - 1:
                break

            # Decode step: feed only the new token, reusing the KV cache.
            decode_step_output, kv_cache_list = self.forward(
                next_output,
                attention_mask=None,
                kv_cache=kv_cache_list,
                start_pos=current_token_start_pos
            )
            last_output = decode_step_output[:, -1, :]

        return generated_outputs

    @classmethod
    def from_pretrained(cls, cfg):
        from transformers import AutoConfig
        from huggingface_hub import hf_hub_download
        import safetensors
        import torch.nn.init as init
        import json
        from huggingface_hub.utils import EntryNotFoundError

        hf_config = AutoConfig.from_pretrained(cfg.lm_model_type)

        original_vocab_size = hf_config.vocab_size

        # Adopt the architecture hyperparameters of the pretrained model.
        cfg.lm_hidden_dim = hf_config.hidden_size
        cfg.lm_inter_dim = hf_config.intermediate_size
        cfg.lm_rms_eps = hf_config.rms_norm_eps
        cfg.lm_re_base = hf_config.rope_theta
        cfg.lm_max_position_embeddings = hf_config.max_position_embeddings

        if hasattr(cfg, 'lm_vocab_size'):
            if cfg.lm_vocab_size < original_vocab_size:
                raise ValueError(f"Config vocab size ({cfg.lm_vocab_size}) is smaller than pretrained model vocab size ({original_vocab_size})")
        else:
            cfg.lm_vocab_size = original_vocab_size

        cfg.lm_n_heads = hf_config.num_attention_heads
        cfg.lm_n_kv_heads = hf_config.num_key_value_heads
        cfg.lm_dropout = hf_config.attention_dropout
        cfg.lm_n_blocks = hf_config.num_hidden_layers

        model = cls(cfg)

        # Sharded checkpoints ship an index file; fall back to the single
        # "model.safetensors" when it is absent.
        try:
            index_path = hf_hub_download(repo_id=cfg.lm_model_type, filename="model.safetensors.index.json")
            with open(index_path, 'r') as f:
                index = json.load(f)

            safetensors_filenames = sorted(set(index['weight_map'].values()))

            safetensors_files = [hf_hub_download(repo_id=cfg.lm_model_type, filename=fn) for fn in safetensors_filenames]
        except EntryNotFoundError:
            safetensors_files = [hf_hub_download(repo_id=cfg.lm_model_type, filename="model.safetensors")]

        sd = model.state_dict()

        # Map Hugging Face parameter names onto this module's names.
        mapping = {
            'model.embed_tokens.weight': 'token_embedding.weight',
            'model.norm.weight': 'norm.weight'
        }

        for i in range(cfg.lm_n_blocks):
            layer_prefix = f'model.layers.{i}.'
            block_prefix = f'blocks.{i}.'

            mapping.update({
                f"{layer_prefix}self_attn.q_proj.weight": f"{block_prefix}attn.q_proj.weight",
                f"{layer_prefix}self_attn.k_proj.weight": f"{block_prefix}attn.k_proj.weight",
                f"{layer_prefix}self_attn.v_proj.weight": f"{block_prefix}attn.v_proj.weight",
                f"{layer_prefix}self_attn.o_proj.weight": f"{block_prefix}attn.out_proj.weight",
                f"{layer_prefix}mlp.gate_proj.weight": f"{block_prefix}mlp.gate_proj.weight",
                f"{layer_prefix}mlp.up_proj.weight": f"{block_prefix}mlp.up_proj.weight",
                f"{layer_prefix}mlp.down_proj.weight": f"{block_prefix}mlp.down_proj.weight",
                f"{layer_prefix}input_layernorm.weight": f"{block_prefix}norm1.weight",
                f"{layer_prefix}post_attention_layernorm.weight": f"{block_prefix}norm2.weight"
            })

        has_extended_embeddings = False
        loaded_keys = set()

        for safetensors_file in safetensors_files:
            with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
                for hf_key, our_key in mapping.items():
                    if our_key in loaded_keys:
                        continue

                    if hf_key in f.keys() and our_key in sd:
                        tensor = f.get_tensor(hf_key)

                        # If our vocab is larger than the checkpoint's, copy the
                        # pretrained rows and randomly initialize the new ones.
                        if hf_key == 'model.embed_tokens.weight' and tensor.shape[0] != sd[our_key].shape[0]:
                            has_extended_embeddings = True
                            print(f"Extending token embeddings from {tensor.shape} to {sd[our_key].shape}")

                            sd[our_key][:tensor.shape[0]].copy_(tensor)

                            std = 0.02
                            init.normal_(sd[our_key][tensor.shape[0]:], mean=0.0, std=std)

                            print(f"Initialized {sd[our_key].shape[0] - tensor.shape[0]} new token embeddings")
                            sd['head.weight'].copy_(sd[our_key])
                        elif tensor.shape == sd[our_key].shape:
                            sd[our_key].copy_(tensor)
                        else:
                            print(f"Shape mismatch for {hf_key} -> {our_key}: {tensor.shape} vs {sd[our_key].shape}")

                        loaded_keys.add(our_key)

        for hf_key, our_key in mapping.items():
            if our_key not in loaded_keys and our_key in sd:
                print(f"Warning: Key {our_key} not found in any safetensors file (HF key: {hf_key})")

        model.load_state_dict(sd)

        # If embeddings were extended and the checkpoint carries an untied LM
        # head, load and extend it the same way.
        if has_extended_embeddings and hasattr(model, 'head') and 'head.weight' in sd:
            lm_head_loaded = False
            for safetensors_file in safetensors_files:
                with safetensors.safe_open(filename=safetensors_file, framework="pt", device="cpu") as f:
                    if 'lm_head.weight' in f.keys():
                        lm_head = f.get_tensor('lm_head.weight')
                        if lm_head.shape[0] != sd['head.weight'].shape[0]:
                            print(f"Extending LM head from {lm_head.shape} to {sd['head.weight'].shape}")

                            sd['head.weight'][:lm_head.shape[0]].copy_(lm_head)

                            std = 0.02
                            init.normal_(sd['head.weight'][lm_head.shape[0]:], mean=0.0, std=std)

                            model.load_state_dict(sd)
                            lm_head_loaded = True
                            break

        # Re-tie weights after loading, since load_state_dict breaks the tie.
        if cfg.lm_tie_weights and hasattr(model, 'head') and hasattr(model, 'token_embedding'):
            model.head.weight = model.token_embedding.weight

        print(f"Successfully loaded {cfg.lm_model_type} weights from safetensors. Model has {sum(p.numel() for p in model.parameters()):,} parameters.")
        return model
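

if __name__ == "__main__":
    # Illustrative end-to-end smoke test (not part of the original module).
    # The config is a stand-in with small, made-up values; real configs come
    # from elsewhere in the project.
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        lm_hidden_dim=64, lm_inter_dim=128, lm_rms_eps=1e-6, lm_re_base=10000.0,
        lm_max_position_embeddings=256, lm_attn_scaling=1.0, lm_vocab_size=100,
        lm_n_heads=4, lm_n_kv_heads=2, lm_dropout=0.0, lm_n_blocks=2,
        lm_use_tokens=True, lm_tie_weights=True,
    )
    model = LanguageModel(cfg)
    tokens = torch.randint(0, cfg.lm_vocab_size, (1, 8))
    logits, _ = model(tokens)
    assert logits.shape == (1, 8, cfg.lm_vocab_size)
    generated = model.generate(tokens, max_new_tokens=4)
    assert generated.shape == (1, 12)
    print("smoke test passed:", tuple(generated.shape))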