| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import PreTrainedModel |
| from transformers.cache_utils import DynamicCache |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from .configuration_seed import SeedConfig |
|
|
|
|
| class RMSNorm(nn.Module): |
|
|
| def __init__(self, dim, eps=1e-6): |
| super().__init__() |
| self.epsilon = eps |
| self.weight = nn.Parameter(torch.ones(dim)) |
| |
| def forward(self, x): |
| x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.epsilon) * self.weight |
| return x |
|
|
|
|
class RoPEEmbedding(nn.Module):
    """Rotary position embedding with optional YaRN / NTK-aware context scaling.

    Precomputes per-channel inverse frequencies; forward() turns position ids
    into cos/sin tables that apply_rotary_pos_emb consumes.
    """

    def __init__(self, config, device=None):
        super().__init__()
        self.config = config
        assert config.n_embd % config.n_head == 0
        self.head_dim = config.head_dim
        self.rope_scaling_type = str(getattr(config, "rope_scaling_type", "none"))
        self.rope_scaling_factor = float(getattr(config, "rope_scaling_factor", 1.0))

        base = float(config.rope_theta)
        self.position_scale = 1.0
        self.attention_scaling = 1.0

        scaling, factor = self.rope_scaling_type, self.rope_scaling_factor
        if scaling != "none" and factor != 1.0:
            if scaling not in ("yarn", "ntk"):
                raise ValueError(f"Unknown rope_scaling_type={self.rope_scaling_type!r}")
            # Both YaRN and NTK-aware scaling stretch the frequency base.
            base = base * (factor ** (self.head_dim / (self.head_dim - 2.0)))
            if scaling == "yarn":
                # YaRN additionally rescales attention magnitudes (mscale).
                self.attention_scaling = 0.1 * math.log(factor) + 1.0

        self.base = base

        # inv_freq[i] = base^(-2i/head_dim), one entry per rotated channel pair.
        exponents = torch.arange(0, self.head_dim, 2, dtype=torch.float32, device=device)
        inv_freq = 1.0 / (self.base ** (exponents / float(self.head_dim)))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, x, position_ids):
        """Return (cos, sin) tables of shape (*position_ids.shape, head_dim)."""
        out_dtype = x.dtype
        angles = position_ids.float().unsqueeze(-1) * self.position_scale
        angles = angles * self.inv_freq.unsqueeze(0).unsqueeze(0)
        # Duplicate so the table covers both halves rotated by rotate_half.
        angles = torch.cat((angles, angles), dim=-1)
        cos = (angles.cos() * self.attention_scaling).to(out_dtype)
        sin = (angles.sin() * self.attention_scaling).to(out_dtype)
        return cos, sin
|
|
|
|
def rotate_half(x):
    """Map the (x1, x2) halves of the last dimension to (-x2, x1)."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat((second_half.neg(), first_half), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate query/key tensors by the given cos/sin tables.

    unsqueeze_dim inserts a broadcast axis (the head axis by default) so
    one table serves every head.
    """
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    q_rotated = q * cos + rotate_half(q) * sin
    k_rotated = k * cos + rotate_half(k) * sin
    return q_rotated, k_rotated
|
|
|
|
class GQA(nn.Module):
    """Grouped-query attention: n_head query heads share n_kv_head K/V heads.

    K/V heads are repeated to match the query heads before SDPA. Supports an
    HF cache object (past_key_values) for incremental decoding.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        # Index used to address this layer's slot in the shared cache object.
        self.layer_idx = int(layer_idx)
        self.n_head = config.n_head
        # Defaults to full multi-head attention when n_kv_head is absent.
        self.n_kv_head = int(getattr(config, "n_kv_head", config.n_head))
        self.n_embd = config.n_embd
        # NOTE(review): stored but not read anywhere in this class —
        # presumably the max context length; confirm against callers.
        self.block_size = int(config.block_size)
        assert 1 <= self.n_kv_head <= self.n_head
        assert self.n_head % self.n_kv_head == 0

        self.head_dim = config.head_dim
        q_dim = self.n_head * self.head_dim
        kv_dim = self.n_kv_head * self.head_dim

        self.q_proj = nn.Linear(self.n_embd, q_dim, bias=False)
        self.k_proj = nn.Linear(self.n_embd, kv_dim, bias=False)
        self.v_proj = nn.Linear(self.n_embd, kv_dim, bias=False)
        self.o_proj = nn.Linear(q_dim, self.n_embd, bias=False)

        # Per-head QK normalization, applied before RoPE.
        self.q_norm = RMSNorm(self.head_dim)
        self.k_norm = RMSNorm(self.head_dim)

    def forward(self, x, cos, sin, past_key_values=None):
        """Attend over x of shape (B, T, C) using rotary tables cos/sin.

        Returns a (B, T, C) tensor. When past_key_values is given, the new
        K/V states are appended to it in place via its update() method.
        """
        B, T, C = x.shape
        # Project and split into heads: (B, n_head(.kv), T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_head, self.head_dim).transpose(1, 2)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Cache length BEFORE this call — needed to offset the mask below.
        past_len = 0
        if past_key_values is not None:
            past_len = past_key_values.get_seq_length(self.layer_idx)
            # update() returns the full (past + current) K/V sequence.
            k, v = past_key_values.update(k, v, self.layer_idx)

        # Repeat K/V heads so each query head has a matching K/V head.
        if self.n_kv_head != self.n_head:
            repeat_factor = self.n_head // self.n_kv_head
            k = k.repeat_interleave(repeat_factor, dim=1)
            v = v.repeat_interleave(repeat_factor, dim=1)

        if past_len == 0:
            # Prefill (no prior cache): plain causal attention suffices.
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)
        else:
            # Decode with a non-empty cache: queries sit at absolute positions
            # past_len..past_len+T-1, so build an additive mask that lets
            # query i attend to keys 0..past_len+i.
            Tk = int(k.size(2))
            query_pos = past_len + torch.arange(T, device=x.device)
            key_pos = torch.arange(Tk, device=x.device)
            causal_mask = key_pos.unsqueeze(0) <= query_pos.unsqueeze(1)
            attn_mask = torch.zeros((1, 1, T, Tk), device=x.device, dtype=q.dtype)
            attn_mask = attn_mask.masked_fill(~causal_mask.view(1, 1, T, Tk), torch.finfo(q.dtype).min)
            y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)

        # Merge heads back to (B, T, C) and project out.
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = self.o_proj(y)
        return y
|
|
|
|
class SwiGLU(nn.Module):
    """Gated feed-forward block: down(silu(gate(x)) * up(x))."""

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        inner_dim = getattr(config, "mlp_hidden_dim", None)
        if inner_dim is None:
            # Default to 2/3 of the usual 4x expansion (compensates for the
            # extra gate matrix), rounded up to the next multiple of 256.
            inner_dim = int(4 * self.n_embd * 2 / 3)
            inner_dim = -(-inner_dim // 256) * 256

        self.gate_proj = nn.Linear(self.n_embd, inner_dim, bias=config.bias)
        self.up_proj = nn.Linear(self.n_embd, inner_dim, bias=config.bias)
        self.down_proj = nn.Linear(inner_dim, self.n_embd, bias=config.bias)

    def forward(self, x):
        """Apply the gated MLP; preserves the trailing n_embd dimension."""
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
|
|
|
class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.input_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.post_attn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.attn = GQA(config, layer_idx=layer_idx)
        self.mlp = SwiGLU(config)

    def forward(self, x, cos, sin, past_key_values=None):
        """Run one decoder block; cos/sin are the rotary tables for x's positions."""
        # Attention sub-block (normalize, attend, add back).
        x = x + self.attn(self.input_norm(x), cos, sin, past_key_values=past_key_values)
        # Feed-forward sub-block (normalize, transform, add back).
        x = x + self.mlp(self.post_attn_norm(x))
        return x
|
|
|
|
class SeedPreTrainedModel(PreTrainedModel):
    """Base class wiring this model family into the HF PreTrainedModel API."""

    config_class = SeedConfig
    base_model_prefix = "model"
    # Keep each decoder layer on a single device when sharding.
    _no_split_modules = ["DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    # Attention uses F.scaled_dot_product_attention.
    _supports_sdpa = True
|
|
|
|
class SeedForCausalLM(SeedPreTrainedModel, GenerationMixin):
    """Decoder-only causal language model built from DecoderLayer blocks."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([DecoderLayer(config, layer_idx=i) for i in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.rope = RoPEEmbedding(config)
        # HF weight init + (if configured) embedding/lm_head weight tying.
        self.post_init()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, value):
        self.wte = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        labels=None,
        use_cache=None,
        token_type_ids=None,
        **kwargs
    ):
        """Run the decoder stack and optionally compute the LM loss.

        Labels are shifted internally: the logit at position t is scored
        against labels[t+1], so callers pass labels aligned with input_ids.
        NOTE(review): attention_mask and token_type_ids are accepted but not
        applied inside the layers, so padded positions still attend normally —
        confirm callers only send unpadded (or right-padded) batches.
        """
        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)

        B, T = inputs_embeds.shape[:2]

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if position_ids is None:
            # Continue numbering after whatever the cache has already seen.
            past_seen = past_key_values.get_seq_length() if past_key_values is not None else 0
            position_ids = torch.arange(past_seen, past_seen + T, device=inputs_embeds.device).unsqueeze(0).expand(B, T)

        # One shared rotary table for all layers.
        cos, sin = self.rope(inputs_embeds, position_ids)

        x = inputs_embeds
        for layer in self.layers:
            x = layer(x, cos, sin, past_key_values=past_key_values)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so each token predicts its successor; ignore_index=-100
            # (F.cross_entropy default) masks padded label positions.
            loss = F.cross_entropy(
                logits[:, :-1].contiguous().view(-1, logits.size(-1)),
                labels[:, 1:].contiguous().view(-1)
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values if use_cache else None
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Build the model inputs for one generation step.

        Fix: trim input_ids by the exact cached length instead of keeping only
        the last token. The old `input_ids[:, -1:]` silently dropped tokens
        whenever the cache was more than one token behind input_ids (e.g.
        assisted/speculative decoding, or a rolled-back cache).
        """
        past_length = 0
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
        if past_length > 0:
            # Keep every token the cache has not processed yet (usually one).
            input_ids = input_ids[:, past_length:]

        # inputs_embeds are only usable on the first forward pass; afterwards
        # generation continues from token ids.
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update({
            "past_key_values": past_key_values,
            "use_cache": True,
            "attention_mask": attention_mask,
        })
        return model_inputs
|
|