"""TinyBuddy 100K — 84K parameter Llama-style model for Transformers."""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_tinybuddy import TinyBuddyConfig


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to 1 so the layer starts as a pure RMS rescale.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x / sqrt(mean(x**2, -1) + eps) * weight``."""
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * inv_rms * self.weight


def precompute_rope_cos_sin(head_dim, max_seq_len, theta=10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the rotary-embedding cosine and sine tables.

    Args:
        head_dim: Size of one attention head; one frequency is produced per
            even index in ``range(0, head_dim, 2)``.
        max_seq_len: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        ``(cos, sin)``, each of shape ``(max_seq_len, len(range(0, head_dim, 2)))``
        in float32.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2).float() / head_dim))
    t = torch.arange(max_seq_len, dtype=torch.float32)
    # Outer product: row p holds position p multiplied by every frequency.
    freqs = torch.outer(t, freqs)
    return freqs.cos(), freqs.sin()


def apply_rotary_emb(xq, xk, cos, sin):
    """Apply the position-dependent rotary transform to queries and keys.

    The last two dimensions of ``xq`` are read as ``(seq_len, head_dim)``; the
    same ``seq_len`` slice of the tables is used for ``xk`` as well.

    Args:
        xq: Query tensor.
        xk: Key tensor.
        cos: Cosine table from ``precompute_rope_cos_sin``.
        sin: Sine table from ``precompute_rope_cos_sin``.

    Returns:
        ``(xq_rotated, xk_rotated)`` with the same shapes as the inputs.
    """
    *_, seq_len, head_dim = xq.shape
    # Slice to the current length and add two leading broadcast dimensions.
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)
    # Duplicate each frequency so the tables span the full head dimension:
    # (c0, c0, c1, c1, ...).
    cos = cos.repeat_interleave(2, dim=-1)
    sin = sin.repeat_interleave(2, dim=-1)

    def rotate(x):
        x1, x2 = x[..., ::2], x[..., 1::2]
        # NOTE(review): cos/sin above are laid out pairwise, whereas this
        # operand is laid out as [-odd elements | even elements]. The mix is
        # kept exactly as written because existing checkpoints were presumably
        # trained with it — confirm before changing the layout.
        return x * cos + torch.cat([-x2, x1], dim=-1) * sin

    return rotate(xq), rotate(xk)


class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary embeddings and shared KV heads."""

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.n_heads
        # Number of query heads that share each key/value head.
        self.n_rep = self.n_heads // self.n_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, config.hidden_size, bias=False)

        # True strictly above the diagonal, i.e. at positions a query must not see.
        mask = torch.triu(torch.ones(config.block_size, config.block_size), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)

        cos, sin = precompute_rope_cos_sin(self.head_dim, config.block_size, config.rope_theta)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, x):
        """Attend causally over ``x`` of shape ``(B, T, C)`` and return the same shape.

        The mask and rotary tables are built for ``config.block_size`` positions,
        so ``T`` must not exceed that value.
        """
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q, k = apply_rotary_emb(q, k, self.rope_cos, self.rope_sin)

        # Expand the shared key/value heads so they line up with the query heads.
        if self.n_rep > 1:
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.causal_mask[:T, :T], float("-inf"))
        att = F.softmax(att, dim=-1)

        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(y)


class FeedForward(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, x):
        """Project up through the SiLU gate, then back down to ``hidden_size``."""
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention followed by the gated feed-forward."""

    def __init__(self, config):
        super().__init__()
        self.attn_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.attn = CausalSelfAttention(config)
        self.ffn_norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.ffn = FeedForward(config)

    def forward(self, x):
        """Apply both residual sub-layers, normalising the input of each."""
        x = x + self.attn(self.attn_norm(x))
        x = x + self.ffn(self.ffn_norm(x))
        return x


class TinyBuddyForCausalLM(PreTrainedModel):
    """Causal language model: embedding, a stack of ``TransformerBlock``s, LM head.

    The model keeps no key/value cache; every forward pass recomputes attention
    over the whole input sequence.
    """

    config_class = TinyBuddyConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # NOTE(review): this tie is unconditional, whereas ``_tie_weights`` checks
        # ``config.tie_word_embeddings``. Left as is because existing checkpoints
        # may rely on it — confirm against the published config.
        self.lm_head.weight = self.token_embedding.weight
        self.post_init()

    def _tie_weights(self):
        """Share the embedding matrix with the LM head when the config asks for it."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.token_embedding.weight

    def _init_weights(self, module):
        """Initialise linear and embedding weights from N(0, 0.02**2)."""
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)

    def get_input_embeddings(self):
        return self.token_embedding

    def set_input_embeddings(self, value):
        self.token_embedding = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Build the ``forward`` inputs for one generation step.

        ``forward`` neither consumes nor returns a key/value cache, so the full
        prefix has to be fed at every step. Trimming to the last token when
        ``past_key_values`` is set (as this method used to) leaves the model
        with a single token of context whenever the caller supplies a cache
        object.

        Args:
            input_ids: Token ids generated so far, shape ``(batch, seq_len)``.
            past_key_values: Accepted for interface compatibility; ignored.
            attention_mask: Passed through (cropped in step with ``input_ids``).
            **kwargs: Ignored.

        Returns:
            Dict with ``input_ids`` and ``attention_mask``.
        """
        # The causal mask and rotary tables only cover ``block_size`` positions;
        # keep the most recent window so long generations do not fail on shape.
        block_size = self.config.block_size
        if input_ids.shape[1] > block_size:
            input_ids = input_ids[:, -block_size:]
            if attention_mask is not None:
                attention_mask = attention_mask[:, -block_size:]
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Return ``past_key_values`` unchanged; there is no cache to reorder."""
        return past_key_values

    # NOTE(review): exposed as a property, so it is read as ``model.num_parameters``.
    # The Transformers base class appears to define a method of the same name;
    # code that calls ``model.num_parameters()`` would fail here — verify callers.
    @property
    def num_parameters(self):
        """Total element count over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, if ``labels`` is given, the LM loss.

        Args:
            input_ids: Token ids, shape ``(batch, seq_len)``.
            attention_mask: Accepted but not used; padding is not masked.
            labels: Optional target ids aligned with ``input_ids``; they are
                shifted by one position inside this method.
            **kwargs: Ignored.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` (or ``None``) and ``logits``.
        """
        x = self.token_embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(loss=loss, logits=logits)