# Author: algorythmtechnologies — commit f89b6c1 (verified): "Update supernova/model.py"
import math
from dataclasses import dataclass
from typing import Optional, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import ModelConfig
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - used in LLaMA, GPT-NeoX.

    Precomputes cos/sin tables of shape ``(positions, dim)``. The frequency
    vector is duplicated across both halves of the last axis
    (``cat(freqs, freqs)``), which matches the "rotate-half" application.

    Args:
        dim: Rotary dimension (the per-head size); expected to be even.
        max_seq_len: Number of positions precomputed at construction. The
            tables are grown on demand when a longer span is requested.
        base: Frequency base; ``inv_freq[i] = base ** (-2i / dim)``.
    """

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base
        # Inverse frequencies, one per rotated pair of channels.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Non-persistent: fully derived from (dim, base), so not in state_dict.
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """(Re)compute cos/sin tables for positions ``0 .. seq_len - 1``.

        Angles are always computed in float32 and only then cast to the
        buffer dtype. After ``module.half()``, positions above 2048 are not
        exactly representable in fp16, so building the table in the buffer
        dtype would silently corrupt the rotation angles.
        """
        device = self.inv_freq.device
        dtype = self.inv_freq.dtype
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)
        )
        t = torch.arange(seq_len, device=device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)  # (seq_len, dim / 2)
        emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
        # Re-registering an existing buffer name replaces it.
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
        self.cached_seq_len = seq_len

    def forward(self, seq_len: int, offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cos/sin tables for positions ``offset .. offset + seq_len - 1``.

        Args:
            seq_len: Number of positions (rows) to return.
            offset: Absolute position of the first row. Pass the KV-cache
                length when processing new tokens incrementally. The default
                of 0 keeps the original behaviour.

        Returns:
            ``(cos, sin)``, each of shape ``(seq_len, dim)``.

        Raises:
            ValueError: If ``offset`` is negative.
        """
        if offset < 0:
            raise ValueError(f"offset must be non-negative, got {offset}")
        end = offset + seq_len
        if end > self.cached_seq_len:
            self._build_cache(end)
        return self.cos_cached[offset:end], self.sin_cached[offset:end]
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to queries and keys.

    Uses the "rotate-half" formulation ``x * cos + rotate_half(x) * sin`` with
    ``rotate_half([x1, x2]) = [-x2, x1]``. Per channel pair this equals
    ``[x1*cos - x2*sin, x2*cos + x1*sin]``. It expects cos/sin tables whose
    two halves are identical (built as ``cat(freqs, freqs)``), so they are
    used at full ``d_head`` width against the full-width q/k.

    Args:
        q: (B, n_heads, T, d_head)
        k: (B, n_heads, T, d_head)
        cos: (T, d_head), or already broadcastable to q/k
        sin: (T, d_head), or already broadcastable to q/k

    Returns:
        Rotated ``(q, k)`` with the same shapes and dtypes as the inputs.
    """
    # Broadcast (T, d_head) tables over the batch and head axes.
    if cos.dim() == 2:
        cos = cos.unsqueeze(0).unsqueeze(0)  # (1, 1, T, d_head)
        sin = sin.unsqueeze(0).unsqueeze(0)

    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # [x1, x2] -> [-x2, x1] along the last axis.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # Cast the tables to each input's dtype. Otherwise fp32 tables would
    # promote bf16/fp16 q/k, leaving q/k/v with mismatched dtypes.
    q_cos, q_sin = cos.to(q.dtype), sin.to(q.dtype)
    k_cos, k_sin = cos.to(k.dtype), sin.to(k.dtype)
    q_rot = q * q_cos + _rotate_half(q) * q_sin
    k_rot = k * k_cos + _rotate_half(k) * k_sin
    return q_rot, k_rot
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional RoPE and KV caching.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability for the attention weights and for the
            projected output.
        max_seq_len: Positions precomputed for RoPE and for the boolean
            causal-mask buffer of the non-flash path.
        use_rope: Apply rotary position embeddings to queries and keys.
        use_flash: Use ``F.scaled_dot_product_attention`` when this PyTorch
            build provides it; otherwise compute attention explicitly.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float,
        max_seq_len: int = 8192,
        use_rope: bool = True,
        use_flash: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.use_rope = use_rope
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')

        # Fused Q/K/V projection and output projection.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        if use_rope:
            self.rotary_emb = RotaryEmbedding(self.d_head, max_seq_len)

        # Precomputed lower-triangular mask for the non-flash path. Longer
        # spans are still supported: _causal_mask builds a mask on the fly.
        if not self.use_flash:
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)),
                persistent=False
            )

    def _causal_mask(self, t_q: int, t_k: int, device: torch.device) -> torch.Tensor:
        """Boolean ``(t_q, t_k)`` mask, True where attention is allowed.

        The mask is aligned to the bottom-right. The cached keys come first,
        so query ``i`` sits at absolute position ``t_k - t_q + i`` and may
        attend to keys ``0 .. t_k - t_q + i``.
        """
        offset = t_k - t_q
        buf = getattr(self, "causal_mask", None)
        if buf is not None and t_k <= buf.size(0):
            # Rows offset..t_k-1 of tril(ones) == tril(ones(t_q, t_k), offset).
            return buf[offset:t_k, :t_k]
        return torch.ones(t_q, t_k, dtype=torch.bool, device=device).tril(diagonal=offset)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attend from the new tokens ``x`` to the cached and new keys.

        Args:
            x: (B, T, d_model) hidden states of the new tokens.
            attn_mask: Optional additive float mask, broadcastable to
                (B, n_heads, T, T_past + T). It is added to the scores after
                causal masking.
            past_kv: Optional cached ``(k, v)``, each (B, n_heads, T_past, d_head).
            use_cache: If True, return the concatenated ``(k, v)``.

        Returns:
            ``(y, present_kv)``. ``y`` is (B, T, d_model); ``present_kv`` is
            None unless ``use_cache``.
        """
        B, T, C = x.size()

        q, k, v = self.qkv(x).split(self.d_model, dim=-1)  # each (B, T, C)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        # Number of cached tokens == absolute position of x[:, 0].
        past_len = past_kv[0].size(2) if past_kv is not None else 0

        if self.use_rope:
            # Rotate the new tokens by their absolute positions (not 0..T-1),
            # so incremental decoding matches a full forward pass.
            cos, sin = self.rotary_emb(past_len + T)
            cos = cos[past_len:].to(q.dtype)
            sin = sin[past_len:].to(q.dtype)
            q, k = apply_rotary_pos_emb(q, k, cos, sin)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None

        T_k = k.size(2)
        if self.use_flash:
            dropout_p = self.attn_dropout.p if self.training else 0.0
            if attn_mask is None and past_len == 0:
                # Square causal case: let SDPA use its fused causal kernel.
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=None, dropout_p=dropout_p, is_causal=True
                )
            else:
                # SDPA's is_causal mask is top-left aligned, which is wrong when
                # T_q != T_k (cached decoding). Pass an explicit bottom-right
                # mask, merged with the caller's additive mask if there is one.
                causal = self._causal_mask(T, T_k, q.device)
                if attn_mask is None:
                    mask = causal
                else:
                    mask = (
                        torch.zeros(T, T_k, dtype=q.dtype, device=q.device)
                        .masked_fill(~causal, float("-inf"))
                        + attn_mask
                    ).to(q.dtype)
                y = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=mask, dropout_p=dropout_p, is_causal=False
                )
        else:
            # Fallback: explicit scaled dot-product attention.
            att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)
            causal = self._causal_mask(T, T_k, q.device)
            att = att.masked_fill(~causal, float("-inf"))
            if attn_mask is not None:
                att = att + attn_mask
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, n_heads, T, d_head)

        # Merge heads and project back to d_model.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)
        y = self.resid_dropout(y)
        return y, present_kv
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then a GELU MLP.

    Each sub-layer normalizes its input (``ln1`` / ``ln2``) and is wrapped in a
    residual connection.

    Args:
        d_model: Model width.
        n_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``d_model``.
        dropout: Dropout probability for attention and the MLP output.
        max_seq_len: Forwarded to the attention layer.
        use_rope: Forwarded to the attention layer.
        use_flash: Forwarded to the attention layer.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        dropout: float,
        max_seq_len: int = 8192,
        use_rope: bool = True,
        use_flash: bool = True
    ):
        super().__init__()
        hidden_width = mlp_ratio * d_model
        # Submodules are created in a fixed order (ln1, attn, ln2, mlp); this
        # fixes both parameter registration and weight-init RNG consumption.
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(
            d_model,
            n_heads,
            dropout,
            max_seq_len=max_seq_len,
            use_rope=use_rope,
            use_flash=use_flash,
        )
        self.ln2 = nn.LayerNorm(d_model)
        # Position-wise feed-forward. The layer order determines the
        # state_dict keys (mlp.0 / mlp.2), so it must stay as is.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_width, bias=True),
            nn.GELU(),
            nn.Linear(hidden_width, d_model, bias=True),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run attention and MLP sub-layers with residual connections.

        Returns:
            ``(hidden_states, present_kv)``, where ``present_kv`` comes
            straight from the attention layer.
        """
        normed = self.ln1(x)
        attn_out, present_kv = self.attn(normed, attn_mask, past_kv, use_cache)
        residual = x + attn_out
        out = residual + self.mlp(self.ln2(residual))
        return out, present_kv
class SupernovaModel(nn.Module):
    """
    Optimized Transformer Language Model with:
    - Flash Attention support
    - Rotary Position Embeddings (RoPE), or learned positions when RoPE is off
    - KV caching for efficient generation
    - Gradient checkpointing support
    - Mixed precision training compatibility

    The output projection reuses the token-embedding matrix (tied weights).
    """

    def __init__(self, cfg: "ModelConfig"):
        super().__init__()
        self.cfg = cfg
        d = cfg.d_model
        V = cfg.vocab_size

        # Token embeddings; also used as the output projection in forward().
        self.tok_emb = nn.Embedding(V, d)

        # Learned positional embeddings only when RoPE is disabled.
        use_rope = getattr(cfg, 'use_rope', True)
        if not use_rope and cfg.use_positional_embedding:
            self.pos_emb = nn.Embedding(cfg.n_positions, d)
        else:
            self.pos_emb = None

        self.drop = nn.Dropout(cfg.dropout)

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d,
                cfg.n_heads,
                cfg.mlp_ratio,
                cfg.dropout,
                max_seq_len=getattr(cfg, 'n_positions', 8192),
                use_rope=use_rope,
                use_flash=getattr(cfg, 'use_flash', True)
            )
            for _ in range(cfg.n_layers)
        ])

        self.ln_f = nn.LayerNorm(d) if cfg.final_layer_norm else nn.Identity()

        # Toggled externally; trades recomputation for activation memory.
        self.gradient_checkpointing = False

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialize weights following the GPT-2/3 scheme.

        Linear and Embedding weights are drawn from N(0, 0.02), biases are
        zeroed, and LayerNorm is reset to identity (weight 1, bias 0).
        """
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            # Affine params can be absent (elementwise_affine=False / bias=False).
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass with optional KV caching for efficient generation.

        Args:
            input_ids: (B, T) input token indices
            targets: (B, T) token indices. The loss compares logits at
                position t with targets at position t + 1; the shift is
                applied here. Entries equal to -100 are ignored.
            past_key_values: List of (k, v) tuples, one per layer (for caching)
            use_cache: Whether to return present key values

        Returns:
            logits: (B, T, V) output logits
            loss: Optional loss value
            present_key_values: Optional list of present (k, v) per layer. It
                holds None placeholders on the gradient-checkpointing path.
        """
        B, T = input_ids.shape
        device = input_ids.device

        tok = self.tok_emb(input_ids)  # (B, T, d)

        # Learned positions (RoPE models skip this; attention rotates q/k).
        if self.pos_emb is not None:
            if past_key_values is not None:
                # New tokens continue numbering after the cached positions.
                pos_offset = past_key_values[0][0].size(2)
                pos = torch.arange(pos_offset, pos_offset + T, device=device)
            else:
                pos = torch.arange(0, T, device=device)
            assert pos.max() < self.cfg.n_positions, f"Position {pos.max()} exceeds n_positions {self.cfg.n_positions}"
            x = tok + self.pos_emb(pos)[None, :, :]  # (1, T, d) broadcast over B
        else:
            x = tok
        x = self.drop(x)

        present_key_values = [] if use_cache else None
        use_checkpoint = self.gradient_checkpointing and self.training
        if use_checkpoint:
            # Explicit import: do not rely on torch.utils.checkpoint having
            # been imported as a side effect elsewhere.
            from torch.utils.checkpoint import checkpoint

        for i, block in enumerate(self.blocks):
            past_kv = past_key_values[i] if past_key_values is not None else None
            if use_checkpoint:
                # Positional args map to block(x, attn_mask=None, past_kv, use_cache=False);
                # no cache is built here, so a None placeholder keeps indices aligned.
                x, _ = checkpoint(block, x, None, past_kv, False, use_reentrant=False)
                present_kv = None
            else:
                x, present_kv = block(x, attn_mask=None, past_kv=past_kv, use_cache=use_cache)
            if use_cache:
                present_key_values.append(present_kv)

        x = self.ln_f(x)

        # Tied output projection.
        logits = x @ self.tok_emb.weight.T  # (B, T, V)

        loss = None
        if targets is not None:
            # Next-token prediction: logits[:, t] predicts targets[:, t + 1].
            logits_ = logits[:, :-1, :].contiguous()
            targets_ = targets[:, 1:].contiguous()
            loss = F.cross_entropy(
                logits_.view(-1, logits_.size(-1)),
                targets_.view(-1),
                ignore_index=-100,
            )

        return logits, loss, present_key_values

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        repetition_penalty: float = 1.0,
        use_cache: bool = True
    ) -> torch.Tensor:
        """
        Generate text autoregressively with various sampling strategies.

        Args:
            idx: (B, T) input token indices
            max_new_tokens: Number of tokens to generate
            temperature: Sampling temperature (higher = more random). A value
                <= 0 selects greedy (argmax) decoding.
            top_k: Keep only the top k logits (None or <= 0 = disabled)
            top_p: Nucleus sampling threshold (None = disabled)
            repetition_penalty: CTRL-style penalty for tokens already in
                ``idx`` (1.0 = no penalty). Positive logits are divided and
                negative logits multiplied, so a value > 1 always makes a
                seen token less likely.
            use_cache: Use KV caching for faster generation. The cache is
                dropped and the context recomputed (cropped to n_positions)
                once the sequence outgrows n_positions.

        Returns:
            (B, T + max_new_tokens) generated token indices
        """
        max_len = getattr(self.cfg, 'n_positions', 8192)
        past_key_values = None
        for _ in range(max_new_tokens):
            cache_ok = (
                use_cache
                and past_key_values is not None
                and idx.size(1) <= max_len
                and all(kv is not None for kv in past_key_values)
            )
            if cache_ok:
                # Cache already covers idx[:, :-1]; feed only the newest token.
                idx_cond = idx[:, -1:]
            else:
                # Full recompute of the (cropped) context.
                idx_cond = idx[:, -max_len:]
                past_key_values = None

            logits, _, past_key_values = self(
                idx_cond,
                past_key_values=past_key_values,
                use_cache=use_cache
            )
            logits = logits[:, -1, :]  # (B, V)

            if repetition_penalty != 1.0:
                seen = torch.gather(logits, 1, idx)
                seen = torch.where(seen < 0, seen * repetition_penalty, seen / repetition_penalty)
                # Duplicate ids write identical values, so scatter is well defined.
                logits = logits.scatter(1, idx, seen)

            if temperature <= 0:
                # Greedy decoding; dividing by zero would yield inf/NaN.
                idx_next = torch.argmax(logits, dim=-1, keepdim=True)
            else:
                logits = logits / temperature

                if top_k is not None and top_k > 0:
                    kth = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth, float('-inf'))

                if top_p is not None:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    # Shift right so the token that crosses top_p is kept, and
                    # always keep the most likely token.
                    sorted_remove = cumulative_probs > top_p
                    sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()
                    sorted_remove[:, 0] = False
                    remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
                    logits = logits.masked_fill(remove, float('-inf'))

                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)

            idx = torch.cat([idx, idx_next], dim=1)
        return idx

    def num_parameters(self, only_trainable: bool = True) -> int:
        """
        Count model parameters.

        Args:
            only_trainable: If True, count only parameters with requires_grad

        Returns:
            Total number of parameters
        """
        if only_trainable:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def parameter_breakdown(self) -> dict:
        """
        Get detailed parameter count by component.

        Returns:
            Dictionary with parameter counts for each component, plus
            "total" (sum of the components) and "total_trainable".
        """
        breakdown = {
            "token_embeddings": sum(p.numel() for p in self.tok_emb.parameters()),
            "positional_embeddings": (
                sum(p.numel() for p in self.pos_emb.parameters()) if self.pos_emb is not None else 0
            ),
            "attention": sum(
                p.numel()
                for block in self.blocks
                for p in block.attn.parameters()
            ),
            "mlp": sum(
                p.numel()
                for block in self.blocks
                for p in block.mlp.parameters()
            ),
            # Iterate the LayerNorms' parameters, not the modules themselves
            # (modules have no numel()). An Identity ln_f contributes 0.
            "layer_norm": sum(
                p.numel()
                for block in self.blocks
                for ln in (block.ln1, block.ln2)
                for p in ln.parameters()
            ) + sum(p.numel() for p in self.ln_f.parameters()),
        }
        breakdown["total"] = sum(breakdown.values())
        breakdown["total_trainable"] = self.num_parameters(only_trainable=True)
        return breakdown

    def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float:
        """
        Estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS.

        Args:
            fwdbwd_per_iter: Number of forward-backward passes per iteration
            dt: Time taken for iteration (seconds)

        Returns:
            MFU as a percentage (0-100)
        """
        N = self.num_parameters()
        cfg = self.cfg
        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.d_model // cfg.n_heads, cfg.n_positions
        # Training FLOPs per token (PaLM appendix B): 6N already covers the
        # forward (2N) and backward (4N) passes; 12*L*H*Q*T adds attention.
        flops_per_token = 6 * N + 12 * L * H * Q * T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        flops_achieved = flops_per_iter / dt
        flops_promised = 312e12  # A100 GPU bfloat16 peak
        return flops_achieved / flops_promised * 100

    def configure_optimizers(
        self,
        weight_decay: float,
        learning_rate: float,
        betas: Tuple[float, float],
        device_type: str
    ):
        """
        Configure AdamW with weight decay only on Linear weights.

        Biases, LayerNorm weights and Embedding weights get no decay.

        Args:
            weight_decay: Decoupled weight-decay coefficient
            learning_rate: Learning rate
            betas: Adam beta parameters
            device_type: 'cuda' or 'cpu'

        Returns:
            Configured AdamW optimizer
        """
        import inspect

        decay = set()
        no_decay = set()
        whitelist_weight_modules = (nn.Linear,)
        blacklist_weight_modules = (nn.LayerNorm, nn.Embedding)
        for mn, m in self.named_modules():
            # recurse=False: classify each parameter once, by its owning module.
            for pn, p in m.named_parameters(recurse=False):
                fpn = f'{mn}.{pn}' if mn else pn
                if pn.endswith('bias'):
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    no_decay.add(fpn)

        # Every parameter must land in exactly one group.
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, f"Parameters in both decay/no_decay: {inter_params}"
        assert len(param_dict.keys() - union_params) == 0, f"Missing parameters: {param_dict.keys() - union_params}"

        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(decay)], "weight_decay": weight_decay},
            {"params": [param_dict[pn] for pn in sorted(no_decay)], "weight_decay": 0.0},
        ]

        # Older PyTorch AdamW has no `fused` kwarg; only pass it when supported.
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        extra_args = {"fused": True} if (fused_available and device_type == 'cuda') else {}
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        return optimizer