import torch import torch.nn as nn import torch.nn.functional as F from transformers import PreTrainedModel, GenerationMixin from .configuration_dwarf import DwarfConfig class RMSNorm(nn.Module): def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps self.scale = nn.Parameter(torch.ones(dim)) def forward(self, x): rms = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt() return (x.float() * rms).to(x.dtype) * self.scale class RotaryEmbedding(nn.Module): def __init__(self, head_dim, max_seq_len, theta=10000.0): super().__init__() assert head_dim % 2 == 0 self.head_dim = head_dim self.max_seq_len = max_seq_len self.theta = theta self.cos_cache = None self.sin_cache = None self._max = 0 def _build_cache(self, seq_len, device): inv_freq = 1.0 / (self.theta ** (torch.arange(0, self.head_dim, 2, device=device).float() / self.head_dim)) t = torch.arange(seq_len, device=device).float() freqs = torch.outer(t, inv_freq) emb = torch.cat([freqs, freqs], dim=-1) self.cos_cache = emb.cos()[None, None] self.sin_cache = emb.sin()[None, None] self._max = seq_len @staticmethod def _rotate_half(x): x1, x2 = x.chunk(2, dim=-1) return torch.cat([-x2, x1], dim=-1) def forward(self, q, k): T = q.size(2) if self.cos_cache is None or T > self._max or self.cos_cache.device != q.device: self._build_cache(max(T, self.max_seq_len), q.device) cos = self.cos_cache[:, :, :T, :] sin = self.sin_cache[:, :, :T, :] q = q * cos + self._rotate_half(q) * sin k = k * cos + self._rotate_half(k) * sin return q, k class GQAAttention(nn.Module): def __init__(self, config): super().__init__() self.n_heads = config.n_heads self.n_kv_heads = config.n_kv_heads self.n_groups = config.n_heads // config.n_kv_heads self.head_dim = config.head_dim self.q_proj = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=True) self.k_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=True) self.v_proj = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=True) self.o_proj = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False) self.rope = RotaryEmbedding(config.head_dim, config.max_seq_len, config.rope_theta) def forward(self, x): B, T, _ = x.shape q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2) q, k = self.rope(q, k) if self.n_groups > 1: k = k.repeat_interleave(self.n_groups, dim=1) v = v.repeat_interleave(self.n_groups, dim=1) out = F.scaled_dot_product_attention(q, k, v, is_causal=True) out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim) return self.o_proj(out) class SwiGLUFFN(nn.Module): def __init__(self, config): super().__init__() self.gate_proj = nn.Linear(config.d_model, config.d_ff, bias=False) self.up_proj = nn.Linear(config.d_model, config.d_ff, bias=False) self.down_proj = nn.Linear(config.d_ff, config.d_model, bias=False) def forward(self, x): return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x)) class DwarfBlock(nn.Module): def __init__(self, config): super().__init__() self.norm_attn = RMSNorm(config.d_model, config.norm_eps) self.attn = GQAAttention(config) self.norm_ffn = RMSNorm(config.d_model, config.norm_eps) self.ffn = SwiGLUFFN(config) def forward(self, x): x = x + self.attn(self.norm_attn(x)) x = x + self.ffn(self.norm_ffn(x)) return x class DwarfForCausalLM(PreTrainedModel, GenerationMixin): config_class = DwarfConfig _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.layers = nn.ModuleList([DwarfBlock(config) for _ in range(config.n_layers)]) self.norm = RMSNorm(config.d_model, config.norm_eps) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) self.post_init() def tie_weights(self, **kwargs): self.lm_head.weight = self.embed_tokens.weight def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def forward(self, input_ids, attention_mask=None, labels=None, **kwargs): x = self.embed_tokens(input_ids) for layer in self.layers: x = layer(x) logits = self.lm_head(self.norm(x)) loss = None if labels is not None: loss = F.cross_entropy( logits[:, :-1].contiguous().view(-1, logits.size(-1)), labels[:, 1:].contiguous().view(-1), ignore_index=-100) from transformers.modeling_outputs import CausalLMOutput return CausalLMOutput(loss=loss, logits=logits) def prepare_inputs_for_generation(self, input_ids, **kwargs): return {"input_ids": input_ids}