# TinyBuddy-500K / modeling_tinybuddy.py
# Uploaded by Eeppa ("Upload 12 files", commit de58358, verified)
"""
TinyBuddy-500K: Educational ~500K parameter Llama-style model
MIT License
"""
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class TinyBuddyConfig(PretrainedConfig):
    """Configuration for the TinyBuddy decoder-only model.

    Every hyperparameter is an explicit keyword argument stored as an instance
    attribute, so ``to_dict()`` / ``save_pretrained()`` write it to config.json.
    Extra keyword arguments are forwarded to ``PretrainedConfig``, which stores
    them as attributes as well.

    Args:
        vocab_size: number of token ids (embedding rows / LM-head outputs).
        hidden_size: model width.
        num_hidden_layers: number of decoder layers.
        num_attention_heads: number of query heads.
        num_key_value_heads: number of shared key/value heads (GQA); must divide
            ``num_attention_heads``.
        intermediate_size: inner width of the gated MLP.
        max_position_embeddings: maximum sequence length recorded in the config.
            NOTE(review): not read by the modeling code visible in this file.
        rms_norm_eps: epsilon added to the mean square inside RMSNorm.
        tie_word_embeddings: share the input embedding matrix with the LM head.
        bos_token_id: beginning-of-sequence token id.
        eos_token_id: end-of-sequence token id.
    """

    model_type = "tinybuddy"

    def __init__(
        self,
        vocab_size: int = 2048,
        hidden_size: int = 96,
        num_hidden_layers: int = 2,
        num_attention_heads: int = 4,
        num_key_value_heads: int = 2,
        intermediate_size: int = 384,
        max_position_embeddings: int = 512,
        rms_norm_eps: float = 1e-6,
        tie_word_embeddings: bool = True,
        bos_token_id: int = 2,
        eos_token_id: int = 2,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        # PretrainedConfig.__init__ assigns these attributes itself, using its
        # own defaults (None for the token ids) when they are not passed. Pass
        # ours explicitly so the TinyBuddy defaults are not silently replaced.
        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean-centering, no bias).

    ``y = weight * x / sqrt(mean(x**2, dim=-1) + eps)``

    The statistic is computed in at least float32. Squaring half-precision
    inputs in their own dtype overflows: the fp16 max is 65504, so any |x|
    above ~256 becomes inf. The normalized tensor is cast back to the input
    dtype before the learned scale is applied.

    Args:
        hidden_size: size of the last dimension (length of ``weight``).
        eps: added to the mean square for numerical stability.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        # fp16/bf16 -> fp32; fp32 and fp64 keep their own precision.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        h = x.to(compute_dtype)
        variance = h.pow(2).mean(-1, keepdim=True)
        h = h * torch.rsqrt(variance + self.eps)
        return self.weight * h.to(input_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"
class GroupedQueryAttention(nn.Module):
    """Causal multi-head self-attention with grouped key/value heads (GQA).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads. Each K/V head is repeated ``num_attention_heads // num_key_value_heads``
    times so it lines up with its group of query heads.

    NOTE(review): no positional encoding (e.g. RoPE) is applied to q/k here.
    Order information comes only from the causal mask.

    Args:
        config: object exposing ``hidden_size``, ``num_attention_heads`` and
            ``num_key_value_heads``.

    Raises:
        ValueError: if ``num_attention_heads`` is not a multiple of
            ``num_key_value_heads`` (the K/V heads could not be repeated to
            match the query heads).
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        if self.num_kv_heads <= 0 or self.num_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be a positive "
                f"multiple of num_key_value_heads ({self.num_kv_heads})"
            )
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, hidden_size); returns the same shape."""
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Expand K/V heads to the query-head count: (B, H_kv, T, D) -> (B, H, T, D).
        k = k.repeat_interleave(self.num_kv_groups, dim=1)
        v = v.repeat_interleave(self.num_kv_groups, dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # Causal mask: query i may only attend to keys j <= i. Without it every
        # position sees future tokens, which defeats next-token prediction.
        future = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))
        # Softmax in float32 for stability under half precision.
        attn = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(B, T, self.num_heads * self.head_dim)
        return self.o_proj(out)
class MLP(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``, no biases.

    Args:
        config: object exposing ``hidden_size`` and ``intermediate_size``.
    """

    def __init__(self, config):
        super().__init__()
        d_model, d_inner = config.hidden_size, config.intermediate_size
        # Creation order (gate, up, down) is kept so seeded init is unchanged.
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        hidden = gate * self.up_proj(x)
        return self.down_proj(hidden)
class DecoderLayer(nn.Module):
    """Pre-norm transformer block with two residual branches.

    ``h = x + attn(input_norm(x))`` followed by ``h + mlp(post_attn_norm(h))``.

    Args:
        config: model config consumed by the attention, MLP and RMSNorm
            sub-modules (reads ``hidden_size`` and ``rms_norm_eps`` directly).
    """

    def __init__(self, config):
        super().__init__()
        # Sub-module creation order is kept so seeded initialization matches.
        self.self_attn = GroupedQueryAttention(config)
        self.mlp = MLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, x):
        h = x + self.self_attn(self.input_layernorm(x))
        return h + self.mlp(self.post_attention_layernorm(h))
class TinyBuddyForCausalLM(PreTrainedModel):
    """TinyBuddy decoder-only language model.

    Pipeline: token embedding -> ``num_hidden_layers`` pre-norm decoder layers
    -> final RMSNorm -> linear projection to vocabulary logits. When
    ``config.tie_word_embeddings`` is true, the LM head shares its weight
    Parameter with the input embedding.
    """

    config_class = TinyBuddyConfig
    base_model_prefix = "tinybuddy"
    # Lets save/load treat a missing lm_head.weight as expected when it is tied
    # to embed_tokens.weight (the two tensors share storage).
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)  # also stores self.config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        self.post_init()

    # Accessors used by PreTrainedModel for weight tying and resizing.
    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute vocabulary logits and, optionally, the cross-entropy loss.

        Args:
            input_ids: integer token ids of shape (batch, seq_len). The
                attention layers unpack a 3-D (B, T, C) hidden state, so a 2-D
                input is required.
            labels: optional integer targets with as many elements as
                ``input_ids``. Entries equal to -100 are ignored (the
                ``F.cross_entropy`` default ``ignore_index``).
            **kwargs: accepted for HF API compatibility and ignored. This
                includes ``attention_mask``, so padding tokens are NOT masked.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits`` of shape
            (batch, seq_len, vocab_size) and ``loss`` (None without labels).
        """
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        logits = self.lm_head(self.norm(x))
        loss = None
        if labels is not None:
            # NOTE(review): labels are NOT shifted here; logits[:, t] is scored
            # against labels[:, t]. HF-style causal LMs usually shift
            # internally, so callers must pass next-token targets themselves.
            # Confirm against the training script.
            # reshape (not view) tolerates non-contiguous label tensors.
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits)

    @torch.no_grad()
    def generate(self, input_ids, max_new_tokens=50, temperature=0.8, top_k=50, **kwargs):
        """Extend ``input_ids`` one token at a time.

        The full sequence is re-encoded every step (there is no KV cache), and
        generation never stops early on EOS.

        Args:
            input_ids: (batch, seq_len) prompt token ids.
            max_new_tokens: number of tokens appended to every row.
            temperature: softmax temperature for sampling. Values <= 0 switch to
                greedy (argmax) decoding instead of dividing by zero.
            top_k: keep only the ``top_k`` largest logits before sampling.
                ``None`` or values <= 0 disable the filter.
            **kwargs: ignored.

        Returns:
            Tensor of shape (batch, seq_len + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            logits = self(input_ids).logits[:, -1, :]
            if temperature <= 0:
                next_token = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None and top_k > 0:
                    kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth, float("-inf"))
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids
# Register the custom model class with AutoModelForCausalLM so it can be loaded
# through the Auto API (custom-code path; presumably with trust_remote_code=True,
# confirm against the hub repo's config.json "auto_map").
TinyBuddyForCausalLM.register_for_auto_class("AutoModelForCausalLM")