# TinyBuddy-80K / modeling_tinybuddy.py
# Eeppa's picture
# Upload 9 files
# 702689e verified
"""TinyBuddy 100K — 84K parameter Llama-style model for Transformers."""
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from .configuration_tinybuddy import TinyBuddyConfig
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        dim: Size of the last dimension; one scale weight is learned per element.
        eps: Constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Scale starts at one, so a fresh layer only normalises.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Return ``x`` divided by its RMS along the last axis, times ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight
def precompute_rope_cos_sin(head_dim, max_seq_len, theta=10000.0):
    """Build the cosine and sine lookup tables used by the rotary embedding.

    Args:
        head_dim: Per-head feature size; one frequency is produced for every
            second index in ``range(0, head_dim, 2)``.
        max_seq_len: Number of positions (rows) to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        A ``(cos, sin)`` pair of float32 tensors, each with ``max_seq_len`` rows
        and one column per frequency.
    """
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    # angles[p, i] = position p multiplied by frequency i.
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
def apply_rotary_emb(xq, xk, cos, sin):
    """Apply the position-dependent cos/sin mixing to query and key tensors.

    Args:
        xq: Query tensor whose last two dimensions are (seq_len, head_dim).
        xk: Key tensor; the same tables (sliced to ``xq``'s seq_len) are applied.
        cos: 2-D cosine table indexed by position, with one column per
            frequency; each column is duplicated to cover two features.
        sin: 2-D sine table with the same layout as ``cos``.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the same shapes as the inputs.
    """
    *_, seq_len, _ = xq.shape

    # Duplicate every frequency column (c0, c0, c1, c1, ...) and add two leading
    # singleton dimensions so the tables broadcast over the leading axes.
    cos_table = cos[:seq_len, :].repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0)
    sin_table = sin[:seq_len, :].repeat_interleave(2, dim=-1).unsqueeze(0).unsqueeze(0)

    def rotate(x):
        even_part = x[..., ::2]
        odd_part = x[..., 1::2]
        # NOTE(review): the partner vector is laid out as [-odd..., even...]
        # (two concatenated halves) while cos/sin above are interleaved, so for
        # head_dim > 2 this is not the standard pairwise RoPE rotation. It is
        # preserved exactly because changing it would alter the outputs of any
        # weights trained with this code — confirm intent before touching.
        partner = torch.cat([-odd_part, even_part], dim=-1)
        return x * cos_table + partner * sin_table

    return rotate(xq), rotate(xk)
class CausalSelfAttention(nn.Module):
    """Causal multi-head self-attention with shared key/value heads and rotary embedding.

    Queries use ``config.num_attention_heads`` heads; keys and values use
    ``config.num_key_value_heads`` heads that are repeated to match the query
    heads. A fixed upper-triangular mask and the rotary cos/sin tables are
    precomputed for ``config.block_size`` positions and stored as buffers.

    Raises:
        ValueError: If ``num_attention_heads`` is not a multiple of
            ``num_key_value_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.n_heads = config.num_attention_heads
        self.n_kv_heads = config.num_key_value_heads
        if self.n_heads % self.n_kv_heads != 0:
            # Otherwise the repeated K/V heads can never line up with the query
            # heads and forward() would fail with a shape error.
            raise ValueError(
                f"num_attention_heads ({self.n_heads}) must be a multiple of "
                f"num_key_value_heads ({self.n_kv_heads})"
            )
        self.head_dim = config.hidden_size // self.n_heads
        self.n_rep = self.n_heads // self.n_kv_heads
        self.block_size = config.block_size

        self.q_proj = nn.Linear(config.hidden_size, self.n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.n_heads * self.head_dim, config.hidden_size, bias=False)

        # True above the diagonal marks the future positions to be masked out.
        mask = torch.triu(torch.ones(config.block_size, config.block_size), diagonal=1).bool()
        self.register_buffer("causal_mask", mask)

        cos, sin = precompute_rope_cos_sin(self.head_dim, config.block_size, config.rope_theta)
        self.register_buffer("rope_cos", cos)
        self.register_buffer("rope_sin", sin)

    def forward(self, x):
        """Attend causally over ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch, seq_len, hidden_size).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``block_size`` the mask and
                rotary tables were built for.
        """
        B, T, _ = x.shape
        if T > self.block_size:
            # The buffers only cover block_size positions; without this check
            # the failure is an opaque broadcast error in the rotary step.
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        q, k = apply_rotary_emb(q, k, self.rope_cos, self.rope_sin)

        if self.n_rep > 1:
            # Expand each K/V head so every query head has a matching partner.
            k = k.repeat_interleave(self.n_rep, dim=1)
            v = v.repeat_interleave(self.n_rep, dim=1)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        att = att.masked_fill(self.causal_mask[:T, :T], float("-inf"))
        att = F.softmax(att, dim=-1)

        # Merge heads using the width o_proj was built for; it differs from
        # hidden_size when hidden_size is not divisible by the head count.
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)
        return self.o_proj(y)
class FeedForward(nn.Module):
    """Gated feed-forward network: ``down_proj(silu(gate_proj(x)) * up_proj(x))``.

    All three projections are bias-free linear layers sized from
    ``config.hidden_size`` and ``config.intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        hidden, inner = config.hidden_size, config.intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        """Apply the gated projection to ``x`` and map back to the hidden size."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention then feed-forward, each with a residual.

    Each sub-layer receives an ``RMSNorm``-normalised copy of the stream and its
    output is added back to the un-normalised stream.
    """

    def __init__(self, config):
        super().__init__()
        hidden, eps = config.hidden_size, config.rms_norm_eps
        self.attn_norm = RMSNorm(hidden, eps)
        self.attn = CausalSelfAttention(config)
        self.ffn_norm = RMSNorm(hidden, eps)
        self.ffn = FeedForward(config)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attn_out = self.attn(self.attn_norm(x))
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
class TinyBuddyForCausalLM(PreTrainedModel, GenerationMixin):
    """TinyBuddy decoder-only language model with an LM head tied to the embedding.

    Token ids are embedded, passed through ``config.num_hidden_layers``
    ``TransformerBlock`` layers and a final ``RMSNorm``, then projected to
    vocabulary logits by ``lm_head``.

    The model keeps no key/value cache: ``forward`` neither accepts nor returns
    ``past_key_values``, so every call recomputes attention over the whole
    ``input_ids`` it is given.

    ``GenerationMixin`` is inherited explicitly (after ``PreTrainedModel``)
    because recent Transformers releases no longer provide ``generate()``
    through ``PreTrainedModel`` alone.
    """

    config_class = TinyBuddyConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = False
    _no_split_modules = ["TransformerBlock"]

    def __init__(self, config):
        super().__init__(config)
        self.token_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # NOTE(review): this tie is unconditional, whereas _tie_weights() below
        # honours config.tie_word_embeddings. Left as-is so existing checkpoints
        # keep loading the same way — confirm which behaviour is intended.
        self.lm_head.weight = self.token_embedding.weight
        self.post_init()

    def _tie_weights(self):
        """Share the embedding matrix with ``lm_head`` when the config asks for it."""
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.token_embedding.weight

    def _init_weights(self, module):
        """Initialise linear and embedding weights from N(0, 0.02)."""
        std = 0.02
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=std)

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.token_embedding

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.token_embedding = value

    def get_output_embeddings(self):
        """Return the LM head module."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head module."""
        self.lm_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Build the keyword arguments for one generation step.

        ``past_key_values`` is accepted for API compatibility but ignored.
        ``forward`` has no KV cache, so the full ``input_ids`` must be fed on
        every step; truncating to the last token (as a cached model would)
        would make the model lose the entire preceding context whenever the
        generation loop supplies a cache object.

        Returns:
            Dict with ``input_ids`` and ``attention_mask`` passed through unchanged.
        """
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past_key_values, beam_idx):
        """Return ``past_key_values`` unchanged; there is no cache to reorder."""
        return past_key_values

    # NOTE(review): defining this as a property shadows the callable
    # ``num_parameters(...)`` that PreTrainedModel normally provides, so code
    # that calls ``model.num_parameters()`` will fail on the returned int.
    # Kept as a property to stay compatible with existing attribute-style
    # users — verify against callers before changing.
    @property
    def num_parameters(self):
        """Total element count over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Token id tensor fed to the embedding layer.
            attention_mask: Accepted for API compatibility; not used here.
            labels: Optional target ids. Logits at position ``t`` are scored
                against the label at ``t + 1`` with cross-entropy.
            **kwargs: Extra arguments are accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast`` carrying ``loss`` (or ``None``) and ``logits``.
        """
        x = self.token_embedding(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Shift so each position predicts the following token.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)