sathishphdai's picture
Upload folder using huggingface_hub
7d0c7c5 verified
#!/usr/bin/env python3
"""
model.py — Role SLM Transformer (~1B params) with RoPE + Gradient Checkpointing
================================================================================
Supports context lengths up to 1M tokens via:
* RoPE (no fixed position embedding table)
* RMSNorm (more efficient than LayerNorm)
* SwiGLU activation (better training dynamics)
* Flash Attention via PyTorch scaled_dot_product_attention
* Gradient checkpointing for memory-efficient training on 24GB
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
from typing import Optional, Tuple
from config import cfg
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    The statistics are computed in float32; the normalised tensor is cast back
    to the input dtype before being multiplied by ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        """Create a gain vector of ``dim`` ones; ``eps`` is added under the root."""
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x / sqrt(mean(x**2, dim=-1) + eps)`` scaled by ``weight``."""
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        normalised = (x_fp32 * inv_rms).type_as(x)
        return normalised * self.weight
def precompute_rope_freqs(dim, max_seq_len, theta=10000.0, device=None):
    """Build the cosine and sine tables used by rotary position embeddings.

    Args:
        dim: Rotary dimension; one frequency is produced per even index in
            ``range(0, dim, 2)``.
        max_seq_len: Number of positions (rows) to tabulate.
        theta: Base of the geometric frequency progression.
        device: Device on which the tables are created.

    Returns:
        ``(cos, sin)``, each of shape ``(max_seq_len, len(range(0, dim, 2)))``
        in float32, where entry ``[p, i]`` is the cos/sin of
        ``p / theta ** (2 * i / dim)``.
    """
    even_index = torch.arange(0, dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (even_index / dim))
    positions = torch.arange(max_seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)
    return angles.cos(), angles.sin()
def apply_rope(x, cos, sin):
    """Apply rotary position embeddings to a batch of attention heads.

    The last dimension of ``x`` is split into two halves ``(x1, x2)`` which
    are rotated as ``(x1*cos - x2*sin, x2*cos + x1*sin)``.

    Args:
        x: Tensor whose dim 2 is the sequence length and dim 3 the head
            dimension (the caller passes ``(batch, heads, seq, head_dim)``).
        cos: Cosine table; the first ``seq_len`` rows are used and broadcast
            over the two leading dims of ``x``.
        sin: Sine table, same layout as ``cos``.

    Returns:
        The rotated tensor, same shape **and dtype** as ``x``.
    """
    seq_len = x.shape[2]
    head_dim = x.shape[3]
    cos = cos[:seq_len].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len].unsqueeze(0).unsqueeze(0)
    x1 = x[..., :head_dim // 2]
    x2 = x[..., head_dim // 2:]
    rotated = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)
    # The tables are float32, so a half-precision ``x`` is promoted by the
    # arithmetic above. Cast back so rotated q/k keep the same dtype as v;
    # this is a no-op for float32/float64 inputs.
    return rotated.to(x.dtype)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    All sizes and the dropout rate are read from the module-level ``cfg``
    (``n_embd``, ``n_head``, ``dropout``). Uses PyTorch's fused
    ``scaled_dot_product_attention`` when it exists and an explicit
    softmax(QK^T)V implementation otherwise.
    """

    def __init__(self):
        super().__init__()
        # The embedding width must split evenly across heads.
        assert cfg.n_embd % cfg.n_head == 0
        self.n_head = cfg.n_head
        self.head_dim = cfg.n_embd // cfg.n_head
        self.q_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.k_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.v_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.out_proj = nn.Linear(cfg.n_embd, cfg.n_embd, bias=False)
        self.resid_drop = nn.Dropout(cfg.dropout)

    def forward(self, x, rope_cos, rope_sin):
        """Attend causally over ``x``.

        Args:
            x: Input of shape ``(B, T, C)``.
            rope_cos: Rotary cosine table forwarded to ``apply_rope``.
            rope_sin: Rotary sine table forwarded to ``apply_rope``.

        Returns:
            Tensor of shape ``(B, T, C)`` after the output projection and
            residual dropout.
        """
        B, T, C = x.shape
        # Project, then reshape to (B, n_head, T, head_dim).
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        # Only queries and keys carry positional information.
        q = apply_rope(q, rope_cos, rope_sin)
        k = apply_rope(k, rope_cos, rope_sin)
        if hasattr(F, 'scaled_dot_product_attention'):
            y = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=cfg.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scale = 1.0 / math.sqrt(self.head_dim)
            att = (q @ k.transpose(-2, -1)) * scale
            # True above the diagonal marks future positions to be hidden.
            mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
            att = att.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            att = F.softmax(att, dim=-1)
            # Match the fused path, which applies dropout to the attention
            # weights while training; previously this branch skipped it.
            att = F.dropout(att, p=cfg.dropout, training=self.training)
            y = att @ v
        # Merge heads back into (B, T, C).
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_drop(self.out_proj(y))
class SwiGLUFFN(nn.Module):
    """Gated feed-forward layer: ``down(silu(gate(x)) * up(x))`` plus dropout.

    The hidden width is ``cfg.n_embd`` times ``cfg.ffn_multiplier`` (2.667 when
    the attribute is absent), truncated to an int and rounded up to the next
    multiple of 64.
    """

    def __init__(self):
        super().__init__()
        multiplier = getattr(cfg, 'ffn_multiplier', 2.667)
        hidden_dim = int(cfg.n_embd * multiplier)
        # Round up to a multiple of 64.
        remainder = hidden_dim % 64
        if remainder:
            hidden_dim += 64 - remainder
        self.gate_proj = nn.Linear(cfg.n_embd, hidden_dim, bias=False)
        self.up_proj = nn.Linear(cfg.n_embd, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, cfg.n_embd, bias=False)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        """Return the gated projection of ``x``; the last dim stays ``cfg.n_embd``."""
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        projected = self.down_proj(gate * value)
        return self.dropout(projected)
class TransformerBlock(nn.Module):
    """Pre-norm residual block: attention sub-layer followed by an FFN sub-layer."""

    def __init__(self):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.n_embd)
        self.attn = CausalSelfAttention()
        self.ffn_norm = RMSNorm(cfg.n_embd)
        self.ffn = SwiGLUFFN()

    def forward(self, x, rope_cos, rope_sin):
        """Add the attention output, then the FFN output, to the residual stream."""
        attn_out = self.attn(self.attn_norm(x), rope_cos, rope_sin)
        x = x + attn_out
        ffn_out = self.ffn(self.ffn_norm(x))
        return x + ffn_out
class RoleSLM(nn.Module):
    """Role-Based Small Language Model — ~1B params, LLaMA-style with gradient checkpointing.

    A stack of ``cfg.n_layer`` ``TransformerBlock`` modules between a token
    embedding and a final ``RMSNorm`` + linear head. The embedding and the
    head share one weight tensor. Rotary tables are kept as non-persistent
    buffers and grown on demand when a longer sequence is seen.
    """

    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.n_embd)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([TransformerBlock() for _ in range(cfg.n_layer)])
        self.norm = RMSNorm(cfg.n_embd)
        self.lm_head = nn.Linear(cfg.n_embd, cfg.vocab_size, bias=False)
        self.tok_emb.weight = self.lm_head.weight  # Weight tying
        self.use_checkpointing = getattr(cfg, 'gradient_checkpointing', True)

        # Rotary tables: start with at most 2x the block size and extend
        # lazily in ``_extend_rope`` rather than tabulating ``max_pos`` rows.
        head_dim = cfg.n_embd // cfg.n_head
        max_pos = getattr(cfg, 'max_position_embeddings', 1_000_000)
        rope_theta = getattr(cfg, 'rope_theta', 10000.0)
        precompute_len = min(max_pos, cfg.block_size * 2)
        cos, sin = precompute_rope_freqs(head_dim, precompute_len, theta=rope_theta)
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self._rope_max_len = precompute_len
        self._rope_theta = rope_theta
        self._head_dim = head_dim

        self.apply(self._init_weights)

        n_params = self.count_parameters()
        print(f"{cfg.domain_name}-SLM initialized: {n_params/1e6:.2f}M parameters ({n_params/1e9:.3f}B)")
        print(f" Architecture: {cfg.n_layer}L / {cfg.n_head}H / {cfg.n_embd}D")
        print(f" Gradient checkpointing: {self.use_checkpointing}")
        print(f" Max context: {max_pos:,} tokens (via RoPE)")
        print(f" Estimated model size: {n_params * 4 / 1e9:.2f} GB (fp32)")

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02) and zero any bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _extend_rope(self, seq_len, device):
        """Regenerate the rotary tables when ``seq_len`` exceeds their length.

        The new length is at least double the old one so repeated small
        increases do not trigger a rebuild each time.
        """
        if seq_len > self._rope_max_len:
            new_len = max(seq_len, self._rope_max_len * 2)
            cos, sin = precompute_rope_freqs(self._head_dim, new_len,
                                             theta=self._rope_theta, device=device)
            self.rope_cos = cos
            self.rope_sin = sin
            self._rope_max_len = new_len

    def _block_forward(self, block, x, rope_cos, rope_sin):
        """Wrapper for gradient checkpointing."""
        return block(x, rope_cos, rope_sin)

    def forward(self, idx, targets=None):
        """Run the model on a batch of token ids.

        Args:
            idx: Integer tensor of shape ``(B, T)``.
            targets: Optional tensor with the same number of elements as
                ``idx``; positions equal to ``-1`` are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape
            ``(B, T, cfg.vocab_size)`` and ``loss`` is the cross-entropy
            against ``targets`` or ``None`` when no targets are given.
        """
        B, T = idx.shape
        device = idx.device
        self._extend_rope(T, device)
        x = self.drop(self.tok_emb(idx))
        rope_cos = self.rope_cos[:T].to(device)
        rope_sin = self.rope_sin[:T].to(device)
        for block in self.blocks:
            # Checkpointing only pays off when gradients are being computed.
            if self.use_checkpointing and self.training:
                x = grad_checkpoint(self._block_forward, block, x, rope_cos, rope_sin,
                                    use_reentrant=False)
            else:
                x = block(x, rope_cos, rope_sin)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=50, top_p=0.9,
                 eos_token_id=3):
        """Autoregressively append up to ``max_new_tokens`` tokens to ``idx``.

        Args:
            idx: Prompt token ids of shape ``(B, T)``.
            max_new_tokens: Upper bound on the number of generated tokens.
            temperature: Logit divisor; ``0`` selects greedy argmax decoding.
            top_k: Keep only the ``top_k`` highest logits; ``0`` or ``None``
                disables the filter.
            top_p: Nucleus threshold; ``>= 1.0`` or ``None`` disables it.
            eos_token_id: Token id that ends a sequence; ``None`` disables
                early stopping.

        Returns:
            ``idx`` with the generated tokens concatenated along dim 1.
            Generation stops early once every row has produced
            ``eos_token_id``; rows that finished earlier are padded with it.

        Note:
            The train/eval mode of the module is not changed here, so dropout
            stays active if the model is in training mode.
        """
        # Remember the caller's setting so it can be restored exactly, even
        # if sampling raises.
        previous_checkpointing = self.use_checkpointing
        self.use_checkpointing = False  # No checkpointing during generation
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        try:
            for _ in range(max_new_tokens):
                # Feed at most the last ``cfg.block_size`` tokens.
                idx_cond = idx if idx.size(1) <= cfg.block_size else idx[:, -cfg.block_size:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature == 0:
                    idx_next = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float('-inf')
                    if top_p is not None and top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift right so the token that crosses the threshold
                        # is kept, and always keep the most likely token.
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices_to_remove.scatter(
                            1, sorted_indices, sorted_indices_to_remove)
                        logits[indices_to_remove] = float('-inf')
                    probs = F.softmax(logits, dim=-1)
                    idx_next = torch.multinomial(probs, num_samples=1)
                if eos_token_id is not None:
                    # Rows that already ended keep emitting EOS as padding.
                    idx_next = idx_next.masked_fill(finished.unsqueeze(1), eos_token_id)
                    finished = finished | (idx_next.squeeze(1) == eos_token_id)
                idx = torch.cat([idx, idx_next], dim=1)
                # Per-row tracking replaces ``idx_next.item()``, which raised
                # for any batch size other than one.
                if eos_token_id is not None and bool(finished.all()):
                    break
        finally:
            self.use_checkpointing = previous_checkpointing
        return idx

    def count_parameters(self):
        """Return the total number of elements across ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())
if __name__ == "__main__":
    # Smoke test: one forward pass on random ids, using the input as its own target.
    model = RoleSLM()
    sample_ids = torch.randint(0, cfg.vocab_size, (1, 32))
    logits, loss = model(sample_ids, sample_ids)
    print(f"Test forward: logits={logits.shape}, loss={loss.item():.4f}")