|
|
import math |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Tuple, List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .config import ModelConfig |
|
|
|
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - used in LLaMA, GPT-NeoX"""

    def __init__(self, dim: int, max_seq_len: int = 8192, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Warm the cos/sin tables up to the configured maximum length.
        self._build_cache(max_seq_len)

    def _build_cache(self, seq_len: int):
        """Precompute cos/sin for given sequence length"""
        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        angles = torch.outer(positions, self.inv_freq)
        # Duplicate along the feature axis so the table spans the full head dim.
        table = torch.cat((angles, angles), dim=-1)
        self.register_buffer("cos_cached", table.cos(), persistent=False)
        self.register_buffer("sin_cached", table.sin(), persistent=False)
        self.cached_seq_len = seq_len

    def forward(self, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return cos and sin tables for the first `seq_len` positions."""
        # Grow the cache lazily if a longer sequence shows up at runtime.
        if seq_len > self.cached_seq_len:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
|
|
|
|
|
|
|
|
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary position embedding to queries and keys.

    Args:
        q: (B, n_heads, T, d_head)
        k: (B, n_heads, T, d_head)
        cos: (T, d_head) - first and second halves are identical copies
        sin: (T, d_head)

    Returns:
        (q_rot, k_rot), same shapes as the inputs.
    """
    # Broadcast the (T, d_head) tables over the batch and head dimensions.
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    def _rotate_half(x: torch.Tensor) -> torch.Tensor:
        # (x1, x2) -> (-x2, x1): the 90-degree companion of each feature pair.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    # BUG FIX: the previous version multiplied half-width tensors
    # (..., d_head/2) against the full-width (..., d_head) cos/sin tables,
    # which cannot broadcast and raised a RuntimeError for any d_head > 2.
    # The standard x*cos + rotate_half(x)*sin formulation keeps every operand
    # full-width and computes the intended rotation (cos/sin halves are
    # identical copies, so per-pair angles line up).
    q_rot = q * cos + _rotate_half(q) * sin
    k_rot = k * cos + _rotate_half(k) * sin

    return q_rot, k_rot
|
|
|
|
|
|
|
|
class MultiHeadSelfAttention(nn.Module):
    """Causal multi-head self-attention with optional RoPE, Flash Attention,
    and KV caching.

    Fixes over the previous revision:
      * RoPE is applied at each token's absolute position (offset by the
        cached key length) instead of restarting at position 0 on every
        cached decoding step.
      * The causal mask is aligned to the bottom-right of the score matrix
        when query and key lengths differ (KV-cache decoding); previously a
        cached single-token query could attend only to the very first key.
      * The flash path honours `attn_mask` and KV caching by building an
        explicit additive bias instead of always passing is_causal=True
        (whose mask is top-left aligned and wrong when T_q != T_k).
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        dropout: float,
        max_seq_len: int = 8192,
        use_rope: bool = True,
        use_flash: bool = True
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.use_rope = use_rope
        # Fall back to the manual implementation on torch builds without SDPA.
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')

        # Fused QKV projection and output projection.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=True)
        self.out_proj = nn.Linear(d_model, d_model, bias=True)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        if use_rope:
            self.rotary_emb = RotaryEmbedding(self.d_head, max_seq_len)

        if not self.use_flash:
            # Precomputed lower-triangular mask for the manual attention path.
            self.register_buffer(
                "causal_mask",
                torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)),
                persistent=False
            )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run causal self-attention over `x`.

        Args:
            x: (B, T, d_model) input activations.
            attn_mask: optional additive float mask broadcastable to the
                (T_q, T_k) score matrix (-inf at disallowed positions).
            past_kv: optional cached (k, v), each (B, n_heads, T_past, d_head).
                Cached keys are stored post-rotation.
            use_cache: if True, return the updated (k, v) for reuse.

        Returns:
            (output, present_kv) where output is (B, T, d_model) and
            present_kv is None unless use_cache.
        """
        B, T, C = x.size()

        qkv = self.qkv(x)
        q, k, v = qkv.split(self.d_model, dim=-1)

        # (B, T, C) -> (B, n_heads, T, d_head)
        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        past_len = past_kv[0].size(2) if past_kv is not None else 0

        if self.use_rope:
            # Rotate at absolute positions: the new tokens occupy
            # [past_len, past_len + T). The cache holds already-rotated keys,
            # so only the new slice needs rotation.
            cos, sin = self.rotary_emb(past_len + T)
            q, k = apply_rotary_pos_emb(
                q, k,
                cos[past_len:past_len + T],
                sin[past_len:past_len + T]
            )

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)

        present_kv = (k, v) if use_cache else None

        T_q, T_k = q.size(2), k.size(2)
        dropout_p = self.attn_dropout.p if self.training else 0.0

        if self.use_flash:
            if attn_mask is None and T_q == T_k:
                # Plain causal case: let the fused kernel build the mask.
                y = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=None,
                    dropout_p=dropout_p,
                    is_causal=True
                )
            else:
                # Bottom-right-aligned causal bias: query i may attend keys
                # at global positions <= past_len + i.
                causal = torch.ones(T_q, T_k, dtype=torch.bool, device=q.device)
                causal = causal.tril_(diagonal=T_k - T_q)
                bias = torch.zeros(T_q, T_k, dtype=q.dtype, device=q.device)
                bias = bias.masked_fill(~causal, float("-inf"))
                if attn_mask is not None:
                    bias = bias + attn_mask
                y = F.scaled_dot_product_attention(
                    q, k, v,
                    attn_mask=bias,
                    dropout_p=dropout_p,
                    is_causal=False
                )
        else:
            att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)

            # Bottom-right alignment: the row for query i corresponds to
            # global position past_len + i, i.e. rows [T_k - T_q, T_k).
            causal = self.causal_mask[T_k - T_q:T_k, :T_k]
            att = att.masked_fill(~causal, float("-inf"))

            if attn_mask is not None:
                att = att + attn_mask

            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v

        # (B, n_heads, T, d_head) -> (B, T, d_model)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)
        y = self.resid_dropout(y)

        return y, present_kv
|
|
|
|
|
|
|
|
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention sub-layer, then an MLP sub-layer,
    each wrapped in a residual connection."""

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        mlp_ratio: int,
        dropout: float,
        max_seq_len: int = 8192,
        use_rope: bool = True,
        use_flash: bool = True
    ):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(
            d_model, n_heads, dropout, max_seq_len, use_rope, use_flash
        )
        self.ln2 = nn.LayerNorm(d_model)

        # Position-wise feed-forward network with a widened hidden layer.
        hidden_dim = mlp_ratio * d_model
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden_dim, bias=True),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model, bias=True),
            nn.Dropout(dropout),
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Apply attention then MLP; pass KV-cache arguments straight through."""
        attn_out, present_kv = self.attn(self.ln1(x), attn_mask, past_kv, use_cache)
        x = x + attn_out
        return x + self.mlp(self.ln2(x)), present_kv
|
|
|
|
|
|
|
|
class SupernovaModel(nn.Module): |
|
|
""" |
|
|
Optimized Transformer Language Model with: |
|
|
- Flash Attention support |
|
|
- Rotary Position Embeddings (RoPE) |
|
|
- KV caching for efficient generation |
|
|
- Gradient checkpointing support |
|
|
- Mixed precision training compatibility |
|
|
""" |
|
|
|
|
|
def __init__(self, cfg: ModelConfig): |
|
|
super().__init__() |
|
|
self.cfg = cfg |
|
|
d = cfg.d_model |
|
|
V = cfg.vocab_size |
|
|
|
|
|
|
|
|
self.tok_emb = nn.Embedding(V, d) |
|
|
|
|
|
|
|
|
use_rope = getattr(cfg, 'use_rope', True) |
|
|
if not use_rope and cfg.use_positional_embedding: |
|
|
self.pos_emb = nn.Embedding(cfg.n_positions, d) |
|
|
else: |
|
|
self.pos_emb = None |
|
|
|
|
|
|
|
|
self.drop = nn.Dropout(cfg.dropout) |
|
|
|
|
|
|
|
|
self.blocks = nn.ModuleList([ |
|
|
TransformerBlock( |
|
|
d, |
|
|
cfg.n_heads, |
|
|
cfg.mlp_ratio, |
|
|
cfg.dropout, |
|
|
max_seq_len=getattr(cfg, 'n_positions', 8192), |
|
|
use_rope=use_rope, |
|
|
use_flash=getattr(cfg, 'use_flash', True) |
|
|
) |
|
|
for _ in range(cfg.n_layers) |
|
|
]) |
|
|
|
|
|
|
|
|
self.ln_f = nn.LayerNorm(d) if cfg.final_layer_norm else nn.Identity() |
|
|
|
|
|
|
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
|
|
|
self.apply(self._init_weights) |
|
|
|
|
|
def _init_weights(self, module): |
|
|
"""Initialize weights following GPT-2/3 initialization scheme""" |
|
|
if isinstance(module, nn.Linear): |
|
|
|
|
|
nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
elif isinstance(module, nn.Embedding): |
|
|
nn.init.normal_(module.weight, mean=0.0, std=0.02) |
|
|
elif isinstance(module, nn.LayerNorm): |
|
|
nn.init.ones_(module.weight) |
|
|
nn.init.zeros_(module.bias) |
|
|
|
|
|
    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """
        Forward pass with optional KV caching for efficient generation.

        Args:
            input_ids: (B, T) input token indices
            targets: (B, T) target token indices for loss computation.
                NOTE: the loss shifts internally (logits[:, :-1] against
                targets[:, 1:]), so pass the SAME unshifted sequence as
                input_ids; positions equal to -100 are ignored.
            past_key_values: List of (k, v) tuples for each layer (for caching)
            use_cache: Whether to return present key values

        Returns:
            logits: (B, T, V) output logits
            loss: Optional loss value
            present_key_values: Optional list of present (k, v) for caching
        """
        B, T = input_ids.shape
        device = input_ids.device

        tok = self.tok_emb(input_ids)

        # Learned absolute positions; only present when RoPE is disabled.
        if self.pos_emb is not None:
            if past_key_values is not None:
                # Continue numbering after the cached prefix so incremental
                # decoding sees correct absolute positions.
                pos_offset = past_key_values[0][0].size(2)
                pos = torch.arange(pos_offset, pos_offset + T, device=device)
            else:
                pos = torch.arange(0, T, device=device)

            # NOTE(review): assert is stripped under `python -O`; this relies
            # on callers (e.g. generate()) cropping to n_positions beforehand.
            assert pos.max() < self.cfg.n_positions, f"Position {pos.max()} exceeds n_positions {self.cfg.n_positions}"
            pos_emb = self.pos_emb(pos)[None, :, :]
            x = tok + pos_emb
        else:
            x = tok

        x = self.drop(x)

        present_key_values = [] if use_cache else None
        for i, block in enumerate(self.blocks):
            past_kv = past_key_values[i] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:
                # Recompute activations during backward to save memory.
                # use_cache is forced off: checkpointing would recompute and
                # discard the kv tensors anyway.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, use_cache=False)
                    return custom_forward

                x, _ = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x,
                    None,
                    past_kv,
                    use_reentrant=False
                )
                if use_cache:
                    # Keep the list aligned with layer indices even though no
                    # kv is produced under checkpointing.
                    present_key_values.append(None)
            else:
                x, present_kv = block(x, attn_mask=None, past_kv=past_kv, use_cache=use_cache)
                if use_cache:
                    present_key_values.append(present_kv)

        x = self.ln_f(x)

        # Weight-tied output head: project back through the token embedding.
        logits = x @ self.tok_emb.weight.T

        loss = None
        if targets is not None:
            # Shift so position t predicts token t+1.
            logits_ = logits[:, :-1, :].contiguous()
            targets_ = targets[:, 1:].contiguous()
            loss = F.cross_entropy(
                logits_.view(-1, logits_.size(-1)),
                targets_.view(-1),
                ignore_index=-100,
            )

        return logits, loss, present_key_values
|
|
|
|
|
@torch.no_grad() |
|
|
def generate( |
|
|
self, |
|
|
idx: torch.Tensor, |
|
|
max_new_tokens: int, |
|
|
temperature: float = 1.0, |
|
|
top_k: Optional[int] = None, |
|
|
top_p: Optional[float] = None, |
|
|
repetition_penalty: float = 1.0, |
|
|
use_cache: bool = True |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Generate text autoregressively with various sampling strategies. |
|
|
|
|
|
Args: |
|
|
idx: (B, T) input token indices |
|
|
max_new_tokens: Number of tokens to generate |
|
|
temperature: Sampling temperature (higher = more random) |
|
|
top_k: Keep only top k logits (None = disabled) |
|
|
top_p: Nucleus sampling threshold (None = disabled) |
|
|
repetition_penalty: Penalty for repeated tokens (1.0 = no penalty) |
|
|
use_cache: Use KV caching for faster generation |
|
|
|
|
|
Returns: |
|
|
(B, T + max_new_tokens) generated token indices |
|
|
""" |
|
|
past_key_values = None |
|
|
|
|
|
for _ in range(max_new_tokens): |
|
|
|
|
|
if not use_cache or past_key_values is None: |
|
|
max_len = getattr(self.cfg, 'n_positions', 8192) |
|
|
idx_cond = idx if idx.size(1) <= max_len else idx[:, -max_len:] |
|
|
else: |
|
|
|
|
|
idx_cond = idx[:, -1:] |
|
|
|
|
|
|
|
|
logits, _, past_key_values = self( |
|
|
idx_cond, |
|
|
use_cache=use_cache |
|
|
) |
|
|
logits = logits[:, -1, :] |
|
|
|
|
|
|
|
|
if repetition_penalty != 1.0: |
|
|
for i in range(idx.size(0)): |
|
|
for token_id in set(idx[i].tolist()): |
|
|
logits[i, token_id] /= repetition_penalty |
|
|
|
|
|
|
|
|
logits = logits / temperature |
|
|
|
|
|
|
|
|
if top_k is not None: |
|
|
v, _ = torch.topk(logits, min(top_k, logits.size(-1))) |
|
|
logits[logits < v[:, [-1]]] = float('-inf') |
|
|
|
|
|
|
|
|
if top_p is not None: |
|
|
sorted_logits, sorted_indices = torch.sort(logits, descending=True) |
|
|
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) |
|
|
|
|
|
|
|
|
sorted_indices_to_remove = cumulative_probs > top_p |
|
|
sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() |
|
|
sorted_indices_to_remove[:, 0] = 0 |
|
|
|
|
|
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) |
|
|
logits[indices_to_remove] = float('-inf') |
|
|
|
|
|
|
|
|
probs = F.softmax(logits, dim=-1) |
|
|
idx_next = torch.multinomial(probs, num_samples=1) |
|
|
|
|
|
|
|
|
idx = torch.cat([idx, idx_next], dim=1) |
|
|
|
|
|
return idx |
|
|
|
|
|
def num_parameters(self, only_trainable: bool = True) -> int: |
|
|
""" |
|
|
Count model parameters. |
|
|
|
|
|
Args: |
|
|
only_trainable: If True, count only trainable parameters |
|
|
|
|
|
Returns: |
|
|
Total number of parameters |
|
|
""" |
|
|
if only_trainable: |
|
|
return sum(p.numel() for p in self.parameters() if p.requires_grad) |
|
|
return sum(p.numel() for p in self.parameters()) |
|
|
|
|
|
def parameter_breakdown(self) -> dict: |
|
|
""" |
|
|
Get detailed parameter count by component. |
|
|
|
|
|
Returns: |
|
|
Dictionary with parameter counts for each component |
|
|
""" |
|
|
breakdown = { |
|
|
"token_embeddings": sum(p.numel() for p in self.tok_emb.parameters()), |
|
|
"positional_embeddings": sum(p.numel() for p in self.pos_emb.parameters()) if self.pos_emb else 0, |
|
|
"attention": sum( |
|
|
p.numel() |
|
|
for block in self.blocks |
|
|
for p in block.attn.parameters() |
|
|
), |
|
|
"mlp": sum( |
|
|
p.numel() |
|
|
for block in self.blocks |
|
|
for p in block.mlp.parameters() |
|
|
), |
|
|
"layer_norm": sum( |
|
|
p.numel() |
|
|
for block in self.blocks |
|
|
for p in [block.ln1, block.ln2] |
|
|
) + (sum(p.numel() for p in self.ln_f.parameters()) if self.cfg.final_layer_norm else 0), |
|
|
} |
|
|
breakdown["total"] = sum(breakdown.values()) |
|
|
breakdown["total_trainable"] = self.num_parameters(only_trainable=True) |
|
|
|
|
|
return breakdown |
|
|
|
|
|
def estimate_mfu(self, fwdbwd_per_iter: int, dt: float) -> float: |
|
|
""" |
|
|
Estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS. |
|
|
|
|
|
Args: |
|
|
fwdbwd_per_iter: Number of forward-backward passes per iteration |
|
|
dt: Time taken for iteration (seconds) |
|
|
|
|
|
Returns: |
|
|
MFU as a percentage (0-100) |
|
|
""" |
|
|
N = self.num_parameters() |
|
|
cfg = self.cfg |
|
|
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.d_model // cfg.n_heads, cfg.n_positions |
|
|
|
|
|
|
|
|
|
|
|
flops_per_token = 6 * N + 12 * L * H * Q * T |
|
|
flops_per_fwdbwd = flops_per_token * T * fwdbwd_per_iter * 3 |
|
|
flops_per_iter = flops_per_fwdbwd |
|
|
|
|
|
|
|
|
flops_achieved = flops_per_iter / dt |
|
|
flops_promised = 312e12 |
|
|
|
|
|
mfu = flops_achieved / flops_promised * 100 |
|
|
return mfu |
|
|
|
|
|
def configure_optimizers( |
|
|
self, |
|
|
weight_decay: float, |
|
|
learning_rate: float, |
|
|
betas: Tuple[float, float], |
|
|
device_type: str |
|
|
): |
|
|
""" |
|
|
Configure optimizer with weight decay only on specific parameters. |
|
|
|
|
|
Args: |
|
|
weight_decay: L2 regularization coefficient |
|
|
learning_rate: Learning rate |
|
|
betas: Adam beta parameters |
|
|
device_type: 'cuda' or 'cpu' |
|
|
|
|
|
Returns: |
|
|
Configured AdamW optimizer |
|
|
""" |
|
|
|
|
|
decay = set() |
|
|
no_decay = set() |
|
|
|
|
|
whitelist_weight_modules = (nn.Linear,) |
|
|
blacklist_weight_modules = (nn.LayerNorm, nn.Embedding) |
|
|
|
|
|
for mn, m in self.named_modules(): |
|
|
for pn, p in m.named_parameters(): |
|
|
fpn = f'{mn}.{pn}' if mn else pn |
|
|
|
|
|
if pn.endswith('bias'): |
|
|
no_decay.add(fpn) |
|
|
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): |
|
|
decay.add(fpn) |
|
|
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): |
|
|
no_decay.add(fpn) |
|
|
|
|
|
|
|
|
param_dict = {pn: p for pn, p in self.named_parameters()} |
|
|
inter_params = decay & no_decay |
|
|
union_params = decay | no_decay |
|
|
assert len(inter_params) == 0, f"Parameters in both decay/no_decay: {inter_params}" |
|
|
assert len(param_dict.keys() - union_params) == 0, f"Missing parameters: {param_dict.keys() - union_params}" |
|
|
|
|
|
|
|
|
optim_groups = [ |
|
|
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, |
|
|
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, |
|
|
] |
|
|
|
|
|
|
|
|
use_fused = device_type == 'cuda' |
|
|
optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, fused=use_fused) |
|
|
|
|
|
return optimizer |