synth-gpt-110m / model.py
ethanthoma's picture
Upload model.py with huggingface_hub
2b5f127 verified
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
try:
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func # pyright: ignore[reportMissingImports]
FLASH_ATTN_AVAILABLE = True
except ImportError:
FLASH_ATTN_AVAILABLE = False
flash_attn_func = None
flash_attn_varlen_func = None
class Rotary(nn.Module):
    """Precomputed cosine/sine tables for rotary position embeddings.

    The tables cover ``max_seq_len`` positions and one column per rotary
    frequency. They are stored in bfloat16 as non-persistent buffers, so they
    follow the module across devices but are not written to ``state_dict``.
    """

    cos_cached: Tensor
    sin_cached: Tensor

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        """Build the tables.

        Args:
            dim: Per-head feature size; frequencies are generated for every
                second index in ``range(0, dim)``.
            max_seq_len: Number of positions to precompute.
            base: Base of the geometric frequency progression.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        # Build positions directly in float32 so the outer product does not
        # depend on int64/float32 type promotion.
        positions = torch.arange(max_seq_len, dtype=torch.float32)
        freqs = torch.outer(positions, inv_freq)
        self.register_buffer("cos_cached", freqs.cos().bfloat16(), persistent=False)
        self.register_buffer("sin_cached", freqs.sin().bfloat16(), persistent=False)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` sliced to the sequence length of ``x``.

        Args:
            x: Tensor whose dimension 1 is the sequence length.

        Returns:
            Two views of shape ``(1, seq_len, 1, n_freqs)`` ready to broadcast
            against a ``[batch, seq_len, n_heads, head_dim // 2]`` tensor.

        Raises:
            ValueError: If the sequence is longer than the precomputed tables.
        """
        seq_len = x.shape[1]
        cached_len = self.cos_cached.shape[0]
        if seq_len > cached_len:
            # Slicing past the end would silently return a shorter table and
            # only fail later with an opaque broadcasting error.
            raise ValueError(
                f"sequence length {seq_len} exceeds rotary cache length {cached_len}"
            )
        return (
            self.cos_cached[None, :seq_len, None, :],
            self.sin_cached[None, :seq_len, None, :],
        )
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last dimension of ``x`` by ``cos``/``sin``.

    Args:
        x: 4-D tensor laid out as ``[batch, seq_len, n_heads, head_dim]``.
        cos: Cosine table broadcastable against one half of ``x``.
        sin: Sine table broadcastable against one half of ``x``.

    Returns:
        The rotated tensor, cast back to the dtype of ``x``.
    """
    assert x.ndim == 4  # [batch, seq_len, n_heads, head_dim]
    half: int = x.shape[3] // 2
    lower, upper = x[..., :half], x[..., half:]
    rotated = torch.cat(
        [
            lower * cos + upper * sin,
            lower * (-sin) + upper * cos,
        ],
        3,
    )
    # The arithmetic may have been promoted by cos/sin; restore x's dtype.
    return rotated.type_as(x)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RMS-normalised, rotary-encoded q/k.

    FlashAttention kernels are used when they were importable and the input is
    on a CUDA device; otherwise ``F.scaled_dot_product_attention`` is used.
    The output projection starts at zero.
    """

    def __init__(self, config: "GPTConfig") -> None:
        """Create the q/k/v/output projections and the rotary table.

        Args:
            config: Supplies ``n_head``, ``n_embd`` and ``sequence_length``.
        """
        super().__init__()
        width, heads = config.n_embd, config.n_head
        self.n_head: int = heads
        self.n_embd: int = width
        self.head_dim: int = width // heads
        assert width % heads == 0
        # Projections are created in q, k, v, proj order.
        self.c_q = nn.Linear(width, width, bias=False)
        self.c_k = nn.Linear(width, width, bias=False)
        self.c_v = nn.Linear(width, width, bias=False)
        self.c_proj = nn.Linear(width, width, bias=False)
        with torch.no_grad():
            self.c_proj.weight.zero_()
        self.rotary = Rotary(self.head_dim, max_seq_len=config.sequence_length)

    def forward(self, x: Tensor, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None) -> Tensor:
        """Apply causal self-attention to ``x``.

        Args:
            x: Input of shape ``(B, T, n_embd)``.
            cu_seqlens: Optional cumulative sequence boundaries. Only the
                variable-length flash kernel consults them; the other two
                paths ignore this argument.
            max_seqlen: Optional precomputed longest segment length for the
                variable-length kernel.

        Returns:
            Tensor with the same shape as ``x``.
        """
        assert x.ndim == 3, f"x must be 3D, got shape {x.shape}"
        B, T, C = x.size()
        assert C == self.n_embd, f"hidden dim mismatch: {C} != {self.n_embd}"
        assert B > 0 and T > 0, f"batch and seq length must be > 0: B={B}, T={T}"

        heads_shape = (B, T, self.n_head, self.head_dim)
        q = self.c_q(x).view(heads_shape)
        k = self.c_k(x).view(heads_shape)
        v = self.c_v(x).view(heads_shape)
        assert q.shape == heads_shape, f"q shape mismatch: {q.shape}"

        # QK-norm first, then rotary; values are left untouched.
        cos, sin = self.rotary(q)
        norm_shape = (self.head_dim,)
        q = apply_rotary_emb(F.rms_norm(q, norm_shape), cos, sin)
        k = apply_rotary_emb(F.rms_norm(k, norm_shape), cos, sin)

        flash_ok = FLASH_ATTN_AVAILABLE and x.is_cuda
        if flash_ok and cu_seqlens is not None and flash_attn_varlen_func is not None:
            # Use pre-computed max_seqlen from dataloader (avoids .item() graph break)
            if max_seqlen is not None:
                seqlen: int = max_seqlen
            else:
                seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
            packed_shape = (-1, self.n_head, self.head_dim)
            packed_out = flash_attn_varlen_func(
                q.reshape(packed_shape),
                k.reshape(packed_shape),
                v.reshape(packed_shape),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=seqlen,
                max_seqlen_k=seqlen,
                causal=True,
            )
            y = packed_out.reshape(B, T, C)
        elif flash_ok and flash_attn_func is not None:
            y = flash_attn_func(q, k, v, causal=True).contiguous().view_as(x)
        else:
            # SDPA expects (B, n_head, T, head_dim).
            attended = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=True,
            )
            y = attended.transpose(1, 2).contiguous().view_as(x)

        return self.c_proj(y)
class MLP(nn.Module):
    """Bias-free feed-forward block: expand 4x, squared ReLU, project back.

    The output projection is zero-initialised, so a freshly built block
    returns zeros.
    """

    def __init__(self, config: "GPTConfig") -> None:
        """Create both projections from ``config.n_embd``."""
        super().__init__()
        width = config.n_embd
        hidden = 4 * width
        self.c_fc = nn.Linear(width, hidden, bias=False)
        self.c_proj = nn.Linear(hidden, width, bias=False)
        nn.init.zeros_(self.c_proj.weight)

    def forward(self, x: Tensor) -> Tensor:
        """Return ``c_proj(relu(c_fc(x)) ** 2)``."""
        activated = torch.square(F.relu(self.c_fc(x)))
        return self.c_proj(activated)
class Block(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual."""

    def __init__(self, config: "GPTConfig") -> None:
        """Build the attention and MLP sub-modules from ``config``."""
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: Tensor, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None) -> Tensor:
        """Run one block.

        Args:
            x: Residual-stream input.
            cu_seqlens: Forwarded unchanged to the attention module.
            max_seqlen: Forwarded unchanged to the attention module.

        Returns:
            The updated residual stream, same shape as ``x``.
        """
        feature_shape = (x.size(-1),)
        attn_out = self.attn(
            F.rms_norm(x, feature_shape),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        x = x + attn_out
        mlp_out = self.mlp(F.rms_norm(x, feature_shape))
        return x + mlp_out
@dataclass
class GPTConfig:
    """Hyper-parameters read by the model classes in this file."""

    # Rows of the token embedding and width of the lm_head output.
    vocab_size: int = 32256
    # Number of transformer blocks stacked in the model.
    n_layer: int = 12
    # Attention heads per block; must divide n_embd evenly.
    n_head: int = 12
    # Width of the residual stream.
    n_embd: int = 768
    # Number of positions precomputed for the rotary tables.
    sequence_length: int = 1024
class Transformer(nn.Module):
    """Container for the token embedding (``wte``) and the block stack (``h``).

    It defines no ``forward``; the owning model drives the sub-modules.
    """

    wte: nn.Embedding
    h: nn.ModuleList

    def __init__(self, config: GPTConfig):
        """Create the embedding first, then ``config.n_layer`` blocks."""
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        blocks = [Block(config) for _ in range(config.n_layer)]
        self.h = nn.ModuleList(blocks)
class GPT(nn.Module):
    """Decoder-only language model with tied input/output embeddings."""

    config: GPTConfig
    transformer: Transformer
    lm_head: nn.Linear

    def __init__(self, config: GPTConfig):
        """Build the transformer stack and the output head.

        Args:
            config: Model hyper-parameters.
        """
        super().__init__()
        self.config = config
        self.transformer = Transformer(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: the embedding reuses the lm_head parameter.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(
        self,
        idx: Tensor,
        targets: Optional[Tensor] = None,
        return_logits: bool = True,
        return_hidden: bool = False,
        cu_seqlens: Optional[Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> tuple[Optional[Tensor], Optional[Tensor]] | tuple[Optional[Tensor], Optional[Tensor], Tensor]:
        """Run the model.

        Args:
            idx: Token ids of shape ``(B, T)``.
            targets: Optional target ids with the same shape as ``idx``;
                positions equal to ``-1`` are ignored by the loss.
            return_logits: When False the returned logits are ``None``.
            return_hidden: When True the final normalised hidden states are
                appended to the result.
            cu_seqlens: Forwarded to every block.
            max_seqlen: Forwarded to every block.

        Returns:
            ``(logits, loss)`` or, with ``return_hidden``, ``(logits, loss,
            hidden)``. ``loss`` is ``None`` when ``targets`` is ``None``.
        """
        assert idx.ndim == 2, f"idx must be 2D, got shape {idx.shape}"
        B, T = idx.shape
        assert B > 0 and T > 0, f"batch and seq length must be > 0: B={B}, T={T}"
        if targets is not None:
            assert targets.shape == idx.shape, f"targets shape {targets.shape} != idx shape {idx.shape}"

        x = self.transformer.wte(idx)
        assert x.shape == (B, T, self.config.n_embd), f"embedding output shape mismatch: {x.shape}"
        for block in self.transformer.h:
            x = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            assert x.shape == (B, T, self.config.n_embd), f"block output shape mismatch: {x.shape}"

        hidden = F.rms_norm(x, (x.size(-1),))
        assert hidden.shape == x.shape, f"rms_norm shape mismatch: {hidden.shape}"

        logits: Optional[Tensor] = None
        loss: Optional[Tensor] = None
        if targets is not None:
            logits = self.lm_head(hidden)
            assert logits.shape == (B, T, self.config.vocab_size), f"logits shape mismatch: {logits.shape}"
            # reshape (not view) so non-contiguous targets such as a
            # ``tokens[:, 1:]`` slice are accepted.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )
            assert loss.ndim == 0, f"loss must be scalar, got shape {loss.shape}"
        elif return_logits:
            logits = self.lm_head(hidden)
        # With neither targets nor return_logits the head is skipped entirely:
        # its output would be discarded below anyway.

        if not return_logits:
            logits = None
        if return_hidden:
            return logits, loss, hidden
        return logits, loss

    def get_num_params(self) -> int:
        """Return the total element count over ``self.parameters()``."""
        return sum(p.numel() for p in self.parameters())
class StackedGPT(nn.Module):
    """Two-model ensemble that averages logits and shares embedding and head."""

    def __init__(self, model1: GPT, model2: GPT) -> None:
        """Wrap two models and make ``model2`` reuse ``model1``'s embedding/head.

        Args:
            model1: First model; owner of the shared ``wte`` and ``lm_head``.
            model2: Second model; its ``wte`` and ``lm_head`` are replaced.
        """
        super().__init__()
        self.model1 = model1
        self.model2 = model2
        assert model1.config.vocab_size == model2.config.vocab_size
        assert model1.config.n_embd == model2.config.n_embd
        self.model2.transformer.wte = self.model1.transformer.wte
        self.model2.lm_head = self.model1.lm_head

    def forward(
        self,
        idx: Tensor,
        targets: Optional[Tensor] = None,
        return_logits: bool = True,
        cu_seqlens: Optional[Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> tuple[Optional[Tensor], Optional[Tensor]]:
        """Average both models' logits and optionally compute the loss.

        Args:
            idx: Token ids of shape ``(B, T)``.
            targets: Optional target ids with the same shape as ``idx``;
                positions equal to ``-1`` are ignored by the loss.
            return_logits: When False the returned logits are ``None``.
            cu_seqlens: Forwarded to both models.
            max_seqlen: Forwarded to both models.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` without ``targets``.
        """
        if targets is not None:
            assert targets.shape == idx.shape, f"targets shape {targets.shape} != idx shape {idx.shape}"
        # The sub-models are called without targets: their individual losses
        # were never used, only the averaged logits are.
        logits1, _ = self.model1(
            idx, targets=None, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
        )
        logits2, _ = self.model2(
            idx, targets=None, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
        )
        logits = (logits1 + logits2) / 2.0

        loss: Optional[Tensor] = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )
        if not return_logits:
            logits = None
        return logits, loss
def generate_expert_vectors(n_experts: int, embed_dim: int, seed: int = 42) -> torch.Tensor:
    """Return ``n_experts`` reproducible unit-length random vectors.

    A private generator seeded with ``seed`` is used, so the global RNG state
    is left untouched.

    Args:
        n_experts: Number of vectors (rows).
        embed_dim: Length of each vector (columns).
        seed: Seed for the private generator.

    Returns:
        Tensor of shape ``(n_experts, embed_dim)`` with L2-normalised rows.
    """
    generator = torch.Generator()
    generator.manual_seed(seed)
    vectors = torch.randn(n_experts, embed_dim, generator=generator)
    return F.normalize(vectors, p=2, dim=1)
class MoEGPT(nn.Module):
    """Mixture of full GPT experts, routed per token by embedding similarity.

    Each token's embedding (from the first expert's ``wte``) is compared by
    cosine similarity with one fixed vector per expert. The ``top_k`` best
    experts receive softmax weights and their logits are blended.
    """

    expert_vectors: nn.Parameter
    expert_models: list[GPT]
    top_k: int
    temperature: float
    config: GPTConfig

    def __init__(self, expert_vectors: Tensor, *models: GPT, top_k: int = 2, temperature: float = 20.0):
        """Wrap the experts.

        Args:
            expert_vectors: Routing vectors, one row per expert, each of
                length ``n_embd``. Stored as a frozen parameter.
            *models: The expert models; the first one supplies ``config`` and
                the routing embedding.
            top_k: Number of experts blended per token.
            temperature: Multiplier applied to the similarities before
                top-k selection and softmax.

        Raises:
            ValueError: If no model is given, ``top_k`` is outside
                ``[1, len(models)]``, or ``expert_vectors`` does not have
                shape ``(len(models), n_embd)``.
        """
        super().__init__()
        if not models:
            raise ValueError("MoEGPT requires at least one expert model")
        if not 1 <= top_k <= len(models):
            raise ValueError(f"top_k must be in [1, {len(models)}], got {top_k}")
        expected_shape = (len(models), models[0].config.n_embd)
        if tuple(expert_vectors.shape) != expected_shape:
            # A row-count mismatch would otherwise silently route weight to
            # experts that do not exist (or never use some that do).
            raise ValueError(
                f"expert_vectors must have shape {expected_shape}, got {tuple(expert_vectors.shape)}"
            )
        self.expert_vectors = nn.Parameter(expert_vectors, requires_grad=False)
        self.expert_models = list(models)
        self.models = nn.ModuleList(models)  # For state_dict compatibility
        self.top_k = top_k
        self.temperature = temperature
        self.config = models[0].config
        for model in models[1:]:
            assert model.config.vocab_size == self.config.vocab_size
            assert model.config.n_embd == self.config.n_embd

    def forward(
        self,
        idx: Tensor,
        targets: Optional[Tensor] = None,
        return_logits: bool = True,
        cu_seqlens: Optional[Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> tuple[Optional[Tensor], Optional[Tensor]]:
        """Blend expert logits according to per-token routing weights.

        Args:
            idx: Token ids of shape ``(B, T)``.
            targets: Optional target ids; positions equal to ``-1`` are
                ignored by the loss.
            return_logits: When False the returned logits are ``None``.
            cu_seqlens: Forwarded to every expert.
            max_seqlen: Forwarded to every expert.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` without ``targets``.
        """
        B, T = idx.size()
        vocab_size = self.config.vocab_size

        token_embeds = self.expert_models[0].transformer.wte(idx)
        token_embeds_flat = token_embeds.reshape(-1, self.config.n_embd)
        token_embeds_norm = F.normalize(token_embeds_flat, p=2, dim=1)
        cosine_similarities = torch.matmul(token_embeds_norm, self.expert_vectors.T)

        # Apply temperature scaling before topk/softmax
        scaled_similarities = cosine_similarities * self.temperature
        top_k_similarities, top_k_indices = torch.topk(scaled_similarities, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_similarities, dim=-1)

        # Process experts sequentially to avoid OOM
        output: Optional[Tensor] = None
        for expert_id, expert_model in enumerate(self.expert_models):
            logits, _ = expert_model(
                idx, targets=None, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            )
            if output is None:
                output = torch.zeros_like(logits)
            # top-k indices are distinct per token, so at most one slot
            # matches this expert; summing over slots yields that slot's
            # weight (or 0). One pass over the logits instead of top_k.
            selected = (top_k_indices == expert_id).float()
            expert_weight = (top_k_weights * selected).sum(dim=-1).view(B, T, 1)
            output = output + logits * expert_weight
            del logits

        assert output is not None, "No experts to process"

        loss: Optional[Tensor] = None
        if targets is not None:
            loss = F.cross_entropy(
                output.reshape(-1, vocab_size),
                targets.reshape(-1),
                ignore_index=-1,
            )
        if not return_logits:
            return None, loss
        return output, loss