|
|
from dataclasses import dataclass |
|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch import Tensor |
|
|
|
|
|
try: |
|
|
from flash_attn_interface import flash_attn_func, flash_attn_varlen_func |
|
|
|
|
|
FLASH_ATTN_AVAILABLE = True |
|
|
except ImportError: |
|
|
FLASH_ATTN_AVAILABLE = False |
|
|
flash_attn_func = None |
|
|
flash_attn_varlen_func = None |
|
|
|
|
|
|
|
|
class Rotary(nn.Module):
    """Precomputed rotary-embedding cos/sin tables, sliced to the input length.

    The tables have one row per position ``0 .. max_seq_len - 1`` and one column
    per frequency ``base ** (-2i / dim)``, and are stored in bfloat16.
    """

    cos_cached: Tensor
    sin_cached: Tensor

    def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        inv_freq = 1.0 / (base ** exponents)
        positions = torch.arange(max_seq_len)
        angles = torch.outer(positions, inv_freq)
        # Non-persistent buffers: they follow .to()/device moves but are left
        # out of state_dict, since they are rebuilt from (dim, max_seq_len, base).
        for buffer_name, table in (("cos_cached", angles.cos()), ("sin_cached", angles.sin())):
            self.register_buffer(buffer_name, table.bfloat16(), persistent=False)

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Return ``(cos, sin)`` for the first ``x.shape[1]`` positions.

        Only ``x.shape[1]`` is read. Each result has shape
        ``(1, T, 1, dim // 2)`` (for even ``dim``) so it broadcasts over batch
        and head axes. If ``T`` exceeds ``max_seq_len`` the slice is silently
        shorter than ``T``.
        """
        length = x.shape[1]
        window = (None, slice(None, length), None, slice(None))
        return self.cos_cached[window], self.sin_cached[window]
|
|
|
|
|
|
|
|
def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
    """Rotate the two halves of the last axis of a 4-D tensor by given angles.

    The last axis of ``x`` is split into a first half ``a`` and second half
    ``b``. The result is ``cat([a*cos + b*sin, b*cos - a*sin])`` along that
    axis, cast back to ``x``'s dtype. ``cos``/``sin`` must broadcast against
    ``a`` (e.g. the tables returned by ``Rotary`` for ``x`` of shape
    ``(B, T, H, D)``).
    """
    assert x.ndim == 4
    half: int = x.shape[3] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = (
        first * cos + second * sin,
        second * cos - first * sin,
    )
    return torch.cat(rotated, 3).type_as(x)
|
|
|
|
|
|
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with RMS-normed, rotary-embedded q/k.

    Backend selection:
      * FlashAttention (``flash_attn_interface``) when it imported successfully,
        ``x`` is on CUDA, and q/k/v share an fp16/bf16 dtype.
      * ``F.scaled_dot_product_attention`` otherwise.

    When ``cu_seqlens`` is given, the ``B * T`` tokens are treated as one packed
    stream whose document boundaries are the offsets in ``cu_seqlens``. Both
    backends then restrict each token to earlier tokens of its own document.

    NOTE(review): rotary positions are ``0 .. T-1`` per batch row and do not
    restart at document boundaries inside ``cu_seqlens`` — confirm intended.
    """

    # FlashAttention kernels only accept half-precision inputs.
    _FLASH_DTYPES = (torch.float16, torch.bfloat16)

    def __init__(self, config: "GPTConfig") -> None:
        super().__init__()
        self.n_head: int = config.n_head
        self.n_embd: int = config.n_embd
        self.head_dim: int = self.n_embd // self.n_head
        assert self.n_embd % self.n_head == 0

        self.c_q = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_k = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_v = nn.Linear(self.n_embd, self.n_embd, bias=False)

        # Zero-initialised output projection: the layer outputs zeros at init.
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.c_proj.weight.data.zero_()

        self.rotary = Rotary(self.head_dim, max_seq_len=config.sequence_length)

    def forward(self, x: Tensor, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None) -> Tensor:
        """Attend over ``x`` of shape (B, T, n_embd); returns the same shape.

        Args:
            x: input activations, (B, T, n_embd).
            cu_seqlens: optional 1-D cumulative document offsets into the
                flattened ``B * T`` token stream (packed batches).
            max_seqlen: optional longest document length in ``cu_seqlens``.
                Only the FlashAttention varlen kernel uses it; when omitted it
                is derived from ``cu_seqlens`` with a device sync.
        """
        assert x.ndim == 3, f"x must be 3D, got shape {x.shape}"
        B, T, C = x.size()
        assert C == self.n_embd, f"hidden dim mismatch: {C} != {self.n_embd}"
        assert B > 0 and T > 0, f"batch and seq length must be > 0: B={B}, T={T}"

        q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
        k = self.c_k(x).view(B, T, self.n_head, self.head_dim)
        v = self.c_v(x).view(B, T, self.n_head, self.head_dim)
        assert q.shape == (B, T, self.n_head, self.head_dim), f"q shape mismatch: {q.shape}"

        cos, sin = self.rotary(q)
        q = F.rms_norm(q, (q.size(-1),))
        k = F.rms_norm(k, (k.size(-1),))
        q = apply_rotary_emb(q, cos, sin)
        k = apply_rotary_emb(k, cos, sin)

        use_flash = (
            FLASH_ATTN_AVAILABLE
            and x.is_cuda
            and q.dtype in self._FLASH_DTYPES
            and q.dtype == k.dtype == v.dtype
        )

        if cu_seqlens is not None:
            if use_flash and flash_attn_varlen_func is not None:
                y = self._flash_varlen(q, k, v, cu_seqlens, max_seqlen)
            else:
                # Plain causal SDPA here would let tokens attend across
                # document boundaries; mask them out explicitly instead.
                y = self._packed_sdpa(q, k, v, cu_seqlens)
            y = y.reshape(B, T, C)
        elif use_flash and flash_attn_func is not None:
            y = flash_attn_func(q, k, v, causal=True)
            y = y.contiguous().view_as(x)
        else:
            # SDPA expects (B, H, T, D).
            y = F.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True)
            y = y.transpose(1, 2).contiguous().view_as(x)

        return self.c_proj(y)

    @staticmethod
    def _flash_varlen(q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor, max_seqlen: Optional[int]) -> Tensor:
        """Run the FlashAttention varlen kernel on the flattened stream; returns (B*T, H, D)."""
        n_head, head_dim = q.size(2), q.size(3)
        if max_seqlen is None:
            # .item() synchronises with the device; callers may pass max_seqlen to skip it.
            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())
        return flash_attn_varlen_func(
            q.reshape(-1, n_head, head_dim),
            k.reshape(-1, n_head, head_dim),
            v.reshape(-1, n_head, head_dim),
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            causal=True,
        )

    @staticmethod
    def _packed_sdpa(q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor) -> Tensor:
        """Document-masked causal SDPA over the flattened stream; returns (B*T, H, D).

        Builds a dense (B*T, B*T) boolean mask, so this is a correctness
        fallback rather than a fast path. Tokens at or beyond
        ``cu_seqlens[-1]`` form one extra group that attends only within itself.
        """
        B, T, H, D = q.shape
        n_tokens = B * T
        positions = torch.arange(n_tokens, device=q.device, dtype=cu_seqlens.dtype)
        # Index of the document each flattened token falls into (offset by one,
        # which is irrelevant since only equality is used).
        doc_ids = torch.searchsorted(cu_seqlens.to(q.device).contiguous(), positions, right=True)
        causal = torch.ones(n_tokens, n_tokens, dtype=torch.bool, device=q.device).tril()
        # True = may attend; the diagonal is always True, so no row is fully masked.
        mask = causal & (doc_ids[:, None] == doc_ids[None, :])
        q_s, k_s, v_s = (t.reshape(1, n_tokens, H, D).transpose(1, 2) for t in (q, k, v))
        y = F.scaled_dot_product_attention(q_s, k_s, v_s, attn_mask=mask)
        return y.transpose(1, 2).reshape(n_tokens, H, D)
|
|
|
|
|
|
|
|
class MLP(nn.Module):
    """Bias-free feed-forward layer: n_embd -> 4*n_embd -> n_embd with ReLU² activation."""

    def __init__(self, config: "GPTConfig") -> None:
        super().__init__()
        width = config.n_embd
        hidden_width = 4 * width
        self.c_fc = nn.Linear(width, hidden_width, bias=False)
        self.c_proj = nn.Linear(hidden_width, width, bias=False)
        # Zeroed output projection: the layer outputs zeros at initialisation.
        self.c_proj.weight.data.zero_()

    def forward(self, x: Tensor) -> Tensor:
        """Apply up-projection, squared ReLU, then down-projection."""
        activated = F.relu(self.c_fc(x)).square()
        return self.c_proj(activated)
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Pre-norm residual block: x += attn(rms_norm(x)); x += mlp(rms_norm(x))."""

    def __init__(self, config: "GPTConfig") -> None:
        super().__init__()
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)

    def forward(self, x: Tensor, cu_seqlens: Optional[Tensor] = None, max_seqlen: Optional[int] = None) -> Tensor:
        """Apply both residual sublayers; packed-sequence metadata goes to attention only."""
        norm_shape = (x.size(-1),)
        attn_out = self.attn(F.rms_norm(x, norm_shape), cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        x = x + attn_out
        mlp_out = self.mlp(F.rms_norm(x, norm_shape))
        return x + mlp_out
|
|
|
|
|
|
|
|
@dataclass
class GPTConfig:
    """Hyperparameters consumed by GPT, Transformer, Block, attention and MLP.

    Field order is this dataclass's positional-constructor order.
    """

    # Number of token ids; sizes both the embedding table and the LM head.
    vocab_size: int = 32256
    # Number of Blocks stacked in Transformer.h.
    n_layer: int = 12
    # Attention heads per layer; n_embd must be divisible by n_head
    # (asserted in CausalSelfAttention).
    n_head: int = 12
    # Embedding width, also the width of every Block's input and output.
    n_embd: int = 768
    # Length of the rotary cos/sin tables (Rotary max_seq_len); inputs longer
    # than this are not covered by those tables.
    sequence_length: int = 1024
|
|
|
|
|
|
|
|
class Transformer(nn.Module):
    """Model trunk: token embedding table plus the ordered stack of Blocks.

    Defines no forward of its own; GPT.forward applies ``wte`` and then each
    entry of ``h`` in order.
    """

    wte: nn.Embedding  # (vocab_size, n_embd) token embedding
    h: nn.ModuleList  # n_layer Blocks

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.h = nn.ModuleList()
        for _ in range(config.n_layer):
            self.h.append(Block(config))
|
|
|
|
|
|
|
|
class GPT(nn.Module):
    """Decoder-only LM: embedding, Block stack, final RMS-norm, tied LM head."""

    config: GPTConfig
    transformer: Transformer
    lm_head: nn.Linear

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.transformer = Transformer(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying: embedding table and output projection share one Parameter.
        self.transformer.wte.weight = self.lm_head.weight

    def forward(
        self,
        idx: Tensor,
        targets: Optional[Tensor] = None,
        return_logits: bool = True,
        return_hidden: bool = False,
        cu_seqlens: Optional[Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> tuple[Optional[Tensor], Optional[Tensor]] | tuple[Optional[Tensor], Optional[Tensor], Tensor]:
        """Run the model on token ids ``idx`` of shape (B, T).

        Args:
            idx: (B, T) integer token ids.
            targets: optional (B, T) target ids; entries equal to -1 are
                ignored by the cross-entropy loss.
            return_logits: when False the returned logits are None (a loss,
                if targets are given, is still computed from full logits).
            return_hidden: also return the final RMS-normed hidden states.
            cu_seqlens, max_seqlen: packed-sequence metadata passed to every
                attention layer.

        Returns:
            ``(logits, loss)``, or ``(logits, loss, hidden)`` with
            ``return_hidden``. ``logits`` is (B, T, vocab_size) or None;
            ``loss`` is a scalar tensor or None.
        """
        assert idx.ndim == 2, f"idx must be 2D, got shape {idx.shape}"
        B, T = idx.shape
        assert B > 0 and T > 0, f"batch and seq length must be > 0: B={B}, T={T}"
        if targets is not None:
            assert targets.shape == idx.shape, f"targets shape {targets.shape} != idx shape {idx.shape}"

        if cu_seqlens is not None and max_seqlen is None and cu_seqlens.numel() > 1:
            # Resolve once here: otherwise the flash varlen path derives it
            # with an .item() host sync in every layer.
            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

        x = self.transformer.wte(idx)
        assert x.shape == (B, T, self.config.n_embd), f"embedding output shape mismatch: {x.shape}"

        for block in self.transformer.h:
            x = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
            assert x.shape == (B, T, self.config.n_embd), f"block output shape mismatch: {x.shape}"

        hidden = F.rms_norm(x, (x.size(-1),))
        assert hidden.shape == x.shape, f"rms_norm shape mismatch: {hidden.shape}"

        logits: Optional[Tensor] = None
        loss: Optional[Tensor] = None
        if targets is not None:
            logits = self.lm_head(hidden)
            assert logits.shape == (B, T, self.config.vocab_size), f"logits shape mismatch: {logits.shape}"
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
            assert loss.ndim == 0, f"loss must be scalar, got shape {loss.shape}"
        elif return_logits:
            logits = self.lm_head(hidden)
        # Without targets and without return_logits nothing would be returned
        # from the LM head, so it is not run at all.

        if not return_logits:
            logits = None

        if return_hidden:
            return logits, loss, hidden
        return logits, loss

    def get_num_params(self) -> int:
        """Return the parameter count; the tied embedding/LM-head weight counts once."""
        return sum(p.numel() for p in self.parameters())
|
|
|
|
|
|
|
|
class StackedGPT(nn.Module):
    """Two-GPT ensemble sharing model1's embedding and LM head; logits are averaged."""

    def __init__(self, model1: GPT, model2: GPT) -> None:
        super().__init__()
        self.model1 = model1
        self.model2 = model2

        assert model1.config.vocab_size == model2.config.vocab_size
        assert model1.config.n_embd == model2.config.n_embd

        # model2 reuses model1's (tied) embedding and LM head modules.
        self.model2.transformer.wte = self.model1.transformer.wte
        self.model2.lm_head = self.model1.lm_head

    def forward(
        self,
        idx: Tensor,
        targets: Optional[Tensor] = None,
        return_logits: bool = True,
        cu_seqlens: Optional[Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> tuple[Optional[Tensor], Optional[Tensor]]:
        """Average both models' full logits; compute loss on the average if targets are given.

        Args:
            idx: (B, T) token ids.
            targets: optional (B, T) targets; -1 entries are ignored by the loss.
            return_logits: when False the returned logits are None.
            cu_seqlens, max_seqlen: packed-sequence metadata forwarded to both models.

        Returns:
            ``(logits, loss)``; either may be None.
        """
        if targets is not None:
            assert targets.shape == idx.shape, f"targets shape {targets.shape} != idx shape {idx.shape}"
        if cu_seqlens is not None and max_seqlen is None and cu_seqlens.numel() > 1:
            # Resolve once for both sub-models instead of per layer inside each.
            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

        # Sub-models run without targets: their individual losses would be
        # discarded, and return_logits=True still yields full (B, T, V) logits.
        logits1, _ = self.model1(idx, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
        logits2, _ = self.model2(idx, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

        logits = (logits1 + logits2) / 2.0

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            loss = None

        if not return_logits:
            logits = None

        return logits, loss
|
|
|
|
|
|
|
|
def generate_expert_vectors(n_experts: int, embed_dim: int, seed: int = 42) -> torch.Tensor:
    """Return ``n_experts`` random unit-L2-norm vectors of size ``embed_dim``.

    Draws from a private, seeded CPU generator, so the result is reproducible
    for a given ``seed``. The global RNG state is left untouched, so model
    initialisation and data sampling elsewhere are not reseeded.

    Returns:
        Tensor of shape (n_experts, embed_dim) whose rows each have L2 norm 1.
    """
    generator = torch.Generator().manual_seed(seed)
    vectors = torch.randn(n_experts, embed_dim, generator=generator)
    return F.normalize(vectors, p=2, dim=1)
|
|
|
|
|
|
|
|
class MoEGPT(nn.Module):
    """Token-routed mixture of full GPT experts.

    Routing: each token's embedding (from the first expert's ``wte``) is
    L2-normalised and compared to the fixed ``expert_vectors`` by cosine
    similarity. The similarities are scaled by ``temperature``, and the top-k
    are softmaxed into per-token mixing weights. Every expert runs on the whole
    sequence, and its logits are added per token with that token's weight
    (zero if the expert is not in the token's top-k).
    """

    expert_vectors: nn.Parameter
    expert_models: list[GPT]
    top_k: int
    temperature: float
    config: GPTConfig

    def __init__(self, expert_vectors: Tensor, *models: GPT, top_k: int = 2, temperature: float = 20.0):
        super().__init__()
        assert len(models) > 0, "MoEGPT needs at least one expert model"
        # One routing vector per expert: extra rows would route weight to a
        # non-existent expert (silently dropped), missing rows leave experts unused.
        assert expert_vectors.ndim == 2 and expert_vectors.size(0) == len(models), (
            f"expert_vectors must have one row per expert: got shape {tuple(expert_vectors.shape)} "
            f"for {len(models)} models"
        )
        assert 1 <= top_k <= len(models), f"top_k must be in [1, {len(models)}], got {top_k}"

        self.expert_vectors = nn.Parameter(expert_vectors, requires_grad=False)
        # Plain list for iteration; the ModuleList registers the experts as submodules.
        self.expert_models = list(models)
        self.models = nn.ModuleList(models)
        self.top_k = top_k
        self.temperature = temperature
        self.config = models[0].config

        assert expert_vectors.size(1) == self.config.n_embd, (
            f"expert_vectors width {expert_vectors.size(1)} != n_embd {self.config.n_embd}"
        )
        for model in models[1:]:
            assert model.config.vocab_size == self.config.vocab_size
            assert model.config.n_embd == self.config.n_embd

    def forward(
        self,
        idx: Tensor,
        targets: Optional[Tensor] = None,
        return_logits: bool = True,
        cu_seqlens: Optional[Tensor] = None,
        max_seqlen: Optional[int] = None,
    ) -> tuple[Optional[Tensor], Optional[Tensor]]:
        """Return ``(mixed_logits, loss)`` for token ids ``idx`` of shape (B, T).

        ``loss`` is cross-entropy of the mixed logits (-1 targets ignored) or
        None without targets; logits are None when ``return_logits`` is False.
        """
        B, T = idx.size()
        vocab_size = self.config.vocab_size

        if cu_seqlens is not None and max_seqlen is None and cu_seqlens.numel() > 1:
            # Resolve once instead of once per layer inside every expert.
            max_seqlen = int((cu_seqlens[1:] - cu_seqlens[:-1]).max().item())

        token_embeds = self.expert_models[0].transformer.wte(idx)
        token_embeds_norm = F.normalize(token_embeds.reshape(-1, self.config.n_embd), p=2, dim=1)

        # matmul does not type-promote, so compute routing in the wider of the
        # two dtypes (e.g. bf16 embeddings with fp32 expert vectors).
        routing_dtype = torch.promote_types(token_embeds_norm.dtype, self.expert_vectors.dtype)
        cosine_similarities = token_embeds_norm.to(routing_dtype) @ self.expert_vectors.to(routing_dtype).T

        top_k_similarities, top_k_indices = torch.topk(cosine_similarities * self.temperature, self.top_k, dim=-1)
        top_k_weights = F.softmax(top_k_similarities, dim=-1)

        output: Optional[Tensor] = None
        for expert_id, expert_model in enumerate(self.expert_models):
            logits, _ = expert_model(
                idx, targets=None, return_logits=True, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen
            )
            if output is None:
                output = torch.zeros_like(logits)

            # topk indices are distinct per token, so at most one of the k
            # slots selects this expert; summing gives its per-token weight.
            selected = (top_k_indices == expert_id).float()
            gate = (top_k_weights * selected).sum(dim=-1).view(B, T, 1)
            output = output + logits * gate

            del logits

        assert output is not None, "No experts to process"

        if targets is not None:
            loss = F.cross_entropy(output.view(-1, vocab_size), targets.view(-1), ignore_index=-1)
        else:
            loss = None

        if not return_logits:
            return None, loss

        return output, loss
|
|
|