from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class DeepSeekModelConfig:
num_attention_heads: int = 8
input_dim: int = 1024
embed_dim: int = 1024
bias: bool = False
dropout: float = 0.1
kv_heads: int = 4 # number of key-value heads for grouped query attention
# configs needed for MLA
mla_kv_heads: int = (
4 # number of groups of attention heads that share the same K and V matrices
)
use_mla: bool = False
    num_gpus: int = 1  # number of GPUs
    # n_local_heads: relevant when computation is sharded across GPUs;
    # I'll have to read more about the distributed case
    q_latent_dim: int = 4  # dimension of the latent used to build queries
    kv_latent_dim: int = 4  # dimension of the latent used to build keys and values
    # The official implementation has separate configs for the RoPE and no-RoPE
    # attention head dimensions; I keep both equal to head_dim. Since the no-RoPE
    # and RoPE queries/keys are concatenated, those dimensions add up and are
    # later used to scale the attention scores.
max_batch_size: int = 8
max_token_len: int = 1024
num_shared_experts: int = 8
num_routed_experts: int = 16
moe_top_k: int = 2
expert_intermediate_dim: int = 8192
eta: float = 0.05
num_dense_ffn: int = 2
num_moe_ffn: int = 4
mtp_depth: int = 3
vocab_size: int = 50257
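
# A hedged sanity-check sketch (my addition, not in the original): the attention
# modules below implicitly assume these divisibility relationships, so it can be
# handy to assert them up front. `_check_config` is a hypothetical helper name.
def _check_config(config: DeepSeekModelConfig) -> None:
    # head_dim is derived as embed_dim // num_attention_heads
    assert config.embed_dim % config.num_attention_heads == 0
    # GQA repeats each kv head num_attention_heads // kv_heads times
    assert config.num_attention_heads % config.kv_heads == 0
    # MLA shares each latent-derived K/V across num_attention_heads // mla_kv_heads heads
    assert config.num_attention_heads % config.mla_kv_heads == 0
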
class Expert(nn.Module):
def __init__(self, input_dim: int, intermediate_dim: int, dropout: float):
super().__init__()
self.w1 = nn.Linear(input_dim, intermediate_dim)
self.w11 = nn.Linear(input_dim, intermediate_dim)
self.w2 = nn.Linear(intermediate_dim, input_dim)
self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        # SwiGLU-style FFN: silu(w1(x)) gates w11(x), then w2 projects back down
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w11(x)))
class MoE(nn.Module):
def __init__(self, config: DeepSeekModelConfig):
super().__init__()
self.num_shared_experts = config.num_shared_experts
self.num_routed_experts = config.num_routed_experts
self.num_local_experts = config.num_routed_experts // config.num_gpus
self.top_k = config.moe_top_k
self.expert_selector = nn.Linear(
config.input_dim, self.num_routed_experts, bias=False
)
self.routed_experts = nn.ModuleList(
[
Expert(config.input_dim, config.expert_intermediate_dim, config.dropout)
for _ in range(self.num_routed_experts)
]
)
self.shared_experts = Expert(
config.input_dim,
config.expert_intermediate_dim * self.num_shared_experts,
config.dropout,
)
self.eta = config.eta
self.register_buffer("expert_bias", torch.zeros(self.num_routed_experts))
def forward(self, x):
batch_size, num_tokens, input_dim = x.shape
gate_output, topk_indices = self.topk_routing(x, self.expert_bias)
x = x.view(
batch_size * num_tokens, input_dim
) # so now it is like a list of tokens
gate_output = gate_output.view(batch_size * num_tokens, -1)
topk_indices = topk_indices.view(batch_size * num_tokens, -1)
# --- cache routing info for interpretability ---
self.last_topk_indices = (
topk_indices.view(batch_size, num_tokens, -1).detach().cpu()
)
self.last_gate_output = (
gate_output.view(batch_size, num_tokens, -1).detach().cpu()
)
expert_counts = torch.bincount(
topk_indices.flatten(), minlength=self.num_routed_experts
)
with torch.no_grad():
avg = expert_counts.float().mean()
err = expert_counts.float() - avg
self.expert_bias += -self.eta * err.sign()
# Save for logging
if hasattr(self, "expert_usage"):
self.expert_usage.append(expert_counts.detach().cpu())
else:
self.expert_usage = [expert_counts.detach().cpu()]
        y = torch.zeros_like(x)
        counts = expert_counts.tolist()
for i in range(self.num_routed_experts):
if counts[i] == 0:
continue
expert = self.routed_experts[i]
idx, expert_rank = torch.where(topk_indices == i)
y[idx] += expert(x[idx]) * gate_output[idx, expert_rank, None]
z = self.shared_experts(x)
return (y + z).view(batch_size, num_tokens, input_dim)
    def topk_routing(self, x, bias=None):
        expert_logits = self.expert_selector(x)  # B, T, num_routed_experts
        # The load-balancing bias only influences which experts get selected;
        # the gate values come from the unbiased logits (as in DeepSeek-V3's
        # auxiliary-loss-free balancing).
        selection_logits = expert_logits if bias is None else expert_logits + bias
        _, topk_indices = torch.topk(selection_logits, k=self.top_k, dim=-1)
        topk_logits = expert_logits.gather(dim=-1, index=topk_indices)
        zeros = torch.full_like(expert_logits, float("-inf"))
        sparse_logits = zeros.scatter(dim=-1, index=topk_indices, src=topk_logits)
        gate_output = sparse_logits.softmax(dim=-1)
        return gate_output, topk_indices
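
# A hedged usage sketch of the router (my addition; uses a shrunken
# expert_intermediate_dim just to keep the demo light). gate_output is a dense
# [B, T, num_routed_experts] tensor that is zero outside the selected top-k
# experts, and the selected gates sum to 1.
def _demo_topk_routing():
    config = DeepSeekModelConfig(expert_intermediate_dim=32)
    moe = MoE(config)
    x = torch.rand(2, 5, config.input_dim)
    gate_output, topk_indices = moe.topk_routing(x, moe.expert_bias)
    assert gate_output.shape == (2, 5, config.num_routed_experts)
    assert topk_indices.shape == (2, 5, config.moe_top_k)
    assert torch.allclose(gate_output.sum(dim=-1), torch.ones(2, 5))
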
class RoPE(nn.Module):
def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_seq_len = max_seq_len
self.base = base
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self._cached_cos = None
self._cached_sin = None
self._cached_seq_len = 0
def _compute_cos_sin(self, seq_len: int, device: torch.device):
if seq_len > self._cached_seq_len or self._cached_cos is None:
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
cos_vals = torch.cos(freqs)
sin_vals = torch.sin(freqs)
self._cached_cos = cos_vals
self._cached_sin = sin_vals
self._cached_seq_len = seq_len
return self._cached_cos[:seq_len], self._cached_sin[:seq_len]
    def apply_rope(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """Apply RoPE to an input of shape [B, T, n_heads, head_dim]."""
        batch_size, num_tokens, n_heads, head_dim = x.shape
        if position_ids is not None:
            # make sure the cos/sin cache covers the largest requested position
            cos, sin = self._compute_cos_sin(int(position_ids.max()) + 1, x.device)
            cos = cos[position_ids]  # [T, head_dim//2]
            sin = sin[position_ids]
        else:
            cos, sin = self._compute_cos_sin(num_tokens, x.device)
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, T, 1, head_dim//2]
        sin = sin.unsqueeze(0).unsqueeze(2)
        x1 = x[..., ::2]  # even-indexed features
        x2 = x[..., 1::2]  # odd-indexed features
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos
        # re-interleave the rotated halves back into the original layout
        rotated_x = torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
        return rotated_x
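
# A small numerical sketch (my addition): RoPE is a pure per-pair rotation, so
# it preserves norms, and roped query/key dot products depend only on the
# relative offset between positions.
def _demo_rope_properties():
    rope = RoPE(dim=8)
    x = torch.randn(1, 16, 1, 8)  # [B, T, n_heads, head_dim]
    roped = rope.apply_rope(x)
    assert torch.allclose(x.norm(dim=-1), roped.norm(dim=-1), atol=1e-5)
    # the same offset (3) at two different absolute positions gives the same score
    q = torch.randn(8).view(1, 1, 1, 8).expand(1, 16, 1, 8)
    k = torch.randn(8).view(1, 1, 1, 8).expand(1, 16, 1, 8)
    q_roped, k_roped = rope.apply_rope(q), rope.apply_rope(k)
    s_a = q_roped[0, 7, 0] @ k_roped[0, 4, 0]
    s_b = q_roped[0, 12, 0] @ k_roped[0, 9, 0]
    assert torch.allclose(s_a, s_b, atol=1e-4)
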
class MultiHeadAttention(nn.Module):
def __init__(self, config: DeepSeekModelConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.input_dim = config.input_dim
self.embed_dim = config.embed_dim
self.head_dim = self.embed_dim // self.num_heads
self.Wq = nn.Linear(self.input_dim, self.embed_dim, bias=False)
self.Wk = nn.Linear(self.input_dim, self.embed_dim, bias=False)
self.Wv = nn.Linear(self.input_dim, self.embed_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.input_dim, bias=config.bias)
def forward(self, x):
# x is B, T, input_dim
batch_size, num_tokens, input_dim = x.shape
Q = (
self.Wq(x)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2)
) # becomes B, num_heads, T, head_dim
K = (
self.Wk(x)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2)
) # becomes B, num_heads, T, head_dim
V = (
self.Wv(x)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2)
) # becomes B, num_heads, T, head_dim
        attention_scores = Q @ K.transpose(2, 3)
        attention_scores = attention_scores / (self.head_dim**0.5)
        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
        )
        attention_scores = attention_scores.masked_fill(
            causal_mask.bool(), float("-inf")
        )
        attention_weights = torch.softmax(
            attention_scores, dim=-1
        )  # B, num_heads, T, T
        context = attention_weights @ V  # B, num_heads, T, head_dim
        context = context.transpose(1, 2)  # B, T, num_heads, head_dim
        context = context.contiguous().view(batch_size, num_tokens, self.embed_dim)
        out = self.out_proj(context)  # B, T, input_dim
        return out
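
# A minimal shape check (my addition): attention should map
# [B, T, input_dim] -> [B, T, input_dim].
def _demo_mha_shapes():
    config = DeepSeekModelConfig()
    mha = MultiHeadAttention(config)
    x = torch.rand(2, 10, config.input_dim)
    assert mha(x).shape == (2, 10, config.input_dim)
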
class MultiQueryAttention(nn.Module):
def __init__(self, config: DeepSeekModelConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.input_dim = config.input_dim
self.embed_dim = config.embed_dim
self.head_dim = self.embed_dim // self.num_heads
self.Wq = nn.Linear(self.input_dim, self.embed_dim, bias=False)
self.Wk = nn.Linear(self.input_dim, self.head_dim, bias=False)
self.Wv = nn.Linear(self.input_dim, self.head_dim, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.input_dim, bias=config.bias)
def forward(self, x):
# x is B, T, input_dim
batch_size, num_tokens, input_dim = x.shape
Q = (
self.Wq(x)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2)
) # becomes B, num_heads, T, head_dim
        K = self.Wk(x)  # B, T, head_dim
        V = self.Wv(x)  # B, T, head_dim
        # one shared K/V head, broadcast across all query heads
        K = K.unsqueeze(1).expand(
            -1, self.num_heads, -1, -1
        )  # B, num_heads, T, head_dim
        V = V.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
        attention_scores = Q @ K.transpose(2, 3)
        attention_scores = attention_scores / (self.head_dim**0.5)
        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
        )
        attention_scores = attention_scores.masked_fill(
            causal_mask.bool(), float("-inf")
        )
        attention_weights = torch.softmax(
            attention_scores, dim=-1
        )  # B, num_heads, T, T
        context = attention_weights @ V  # B, num_heads, T, head_dim
        context = context.transpose(1, 2)  # B, T, num_heads, head_dim
        context = context.contiguous().view(batch_size, num_tokens, self.embed_dim)
        out = self.out_proj(context)  # B, T, input_dim
        return out
class GroupedQueryAttention(nn.Module):
    def __init__(self, config: DeepSeekModelConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.input_dim = config.input_dim
self.embed_dim = config.embed_dim
self.head_dim = self.embed_dim // self.num_heads
self.kv_heads = config.kv_heads
self.Wq = nn.Linear(self.input_dim, self.embed_dim, bias=False)
self.Wk = nn.Linear(self.input_dim, self.head_dim * config.kv_heads, bias=False)
self.Wv = nn.Linear(self.input_dim, self.head_dim * config.kv_heads, bias=False)
self.out_proj = nn.Linear(self.embed_dim, self.input_dim, bias=config.bias)
def forward(self, x):
batch_size, num_tokens, input_dim = x.shape
Q = (
self.Wq(x)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2)
) # becomes B, num_heads, T, head_dim
        K = self.Wk(x)  # B, T, head_dim * kv_heads
        V = self.Wv(x)  # B, T, head_dim * kv_heads
        K = K.view(batch_size, num_tokens, self.kv_heads, self.head_dim)
        V = V.view(batch_size, num_tokens, self.kv_heads, self.head_dim)
        # if kv_heads is 3 and num_heads is 6,
        # I want k = [k1, k1, k2, k2, k3, k3] (and the same for v)
        K = K.repeat_interleave(
            self.num_heads // self.kv_heads, dim=2
        )  # B, T, num_heads, head_dim
        V = V.repeat_interleave(
            self.num_heads // self.kv_heads, dim=2
        )  # B, T, num_heads, head_dim
        K = K.transpose(1, 2)  # B, num_heads, T, head_dim
        V = V.transpose(1, 2)
        attention_scores = Q @ K.transpose(2, 3)
        attention_scores = attention_scores / (self.head_dim**0.5)
        causal_mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device), diagonal=1
        )
        attention_scores = attention_scores.masked_fill(
            causal_mask.bool(), float("-inf")
        )
        attention_weights = torch.softmax(
            attention_scores, dim=-1
        )  # B, num_heads, T, T
        context = attention_weights @ V  # B, num_heads, T, head_dim
        context = context.transpose(1, 2)  # B, T, num_heads, head_dim
        context = context.contiguous().view(batch_size, num_tokens, self.embed_dim)
        out = self.out_proj(context)  # B, T, input_dim
        return out
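
# A tiny sketch (my addition) of the repeat_interleave grouping used above:
# with kv_heads = 3 and num_heads = 6, each kv head is duplicated for the two
# query heads in its group, giving [k1, k1, k2, k2, k3, k3].
def _demo_kv_head_grouping():
    kv = torch.tensor([[1.0], [2.0], [3.0]])  # 3 kv heads, head_dim = 1
    grouped = kv.repeat_interleave(6 // 3, dim=0)
    print(grouped.squeeze(-1).tolist())  # [1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
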
# RMSNorm is copied directly from the DeepSeek-V3 repo
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor):
return F.rms_norm(x, (self.dim,), self.weight, self.eps)
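
# A numeric check (my addition) that F.rms_norm (PyTorch >= 2.4) matches the
# usual formula: x * rsqrt(mean(x^2, dim=-1) + eps) * weight.
def _demo_rms_norm():
    norm = RMSNorm(dim=16)
    x = torch.randn(2, 4, 16)
    manual = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + norm.eps)
    assert torch.allclose(norm(x), manual * norm.weight, atol=1e-5)
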
# TODO:
# 1. Try a grouped-query-attention-style MLA, where each kv head has its own latent cache
# 2. Try sliding window attention (I read about this in the Gemma paper)
class MultiHeadLatentAttention(nn.Module):
def __init__(self, config: DeepSeekModelConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.input_dim = config.input_dim
self.embed_dim = config.embed_dim
self.n_local_heads = config.num_attention_heads // config.num_gpus
self.head_dim = self.embed_dim // self.num_heads
self.mla_kv_heads = config.mla_kv_heads
self.kv_latent_dim = config.kv_latent_dim
self.q_latent_dim = config.q_latent_dim
self.dropout = nn.Dropout(config.dropout)
self.rope = RoPE(dim=self.head_dim)
self.out_proj = nn.Linear(
self.num_heads * self.head_dim, self.input_dim, bias=False
)
if self.q_latent_dim == 0:
self.Wq = nn.Linear(
self.input_dim, self.num_heads * self.head_dim, bias=False
)
else:
            # -------------------(decoupled from RoPE)-----------------------------
            # Query path - this feels like LoRA on Q, because instead of
            # Wq (input_dim, input_dim) we now have Wdq (input_dim, q_latent_dim)
            # and Wuq (q_latent_dim, input_dim)
self.Wdq = nn.Linear(self.input_dim, self.q_latent_dim, bias=False)
self.q_norm = RMSNorm(self.q_latent_dim)
self.Wuq = nn.Linear(
self.q_latent_dim, self.num_heads * self.head_dim, bias=False
)
# this will build KV latent and also construct K and V from it
self.Wdkv = nn.Linear(self.input_dim, self.kv_latent_dim, bias=False)
self.kv_norm = RMSNorm(self.kv_latent_dim)
            self.Wuk = nn.Linear(
                self.kv_latent_dim, self.head_dim, bias=False
            )  # no num_heads here: a single shared K up-projection, absorbed into Q in forward
self.Wuv = nn.Linear(
self.kv_latent_dim, self.mla_kv_heads * self.head_dim, bias=False
)
# cache the kv latent and the roped keys
self.register_buffer(
"kv_latent_cache",
torch.zeros(
config.max_batch_size, config.max_token_len, self.kv_latent_dim
),
persistent=False, # I won't store on disk
)
self.register_buffer(
"keys_roped",
torch.zeros(
config.max_batch_size,
config.max_token_len,
self.mla_kv_heads,
                    # without these heads the roped key would be shared by every
                    # attention head; with them it is shared within each group of
                    # attention heads that fall under one kv head
self.head_dim,
),
persistent=False,
)
# --------------------------------------------------------------------
# -------------RoPE path----------------------------------------------
self.Wkr = nn.Linear(
self.input_dim, self.mla_kv_heads * self.head_dim, bias=False
)
self.Wqr = nn.Linear(self.q_latent_dim, self.embed_dim, bias=False)
    def forward(self, x, start_pos=0):
        batch_size, num_tokens, input_dim = x.shape
        end_pos = start_pos + num_tokens
        S = end_pos  # total cached sequence length
        # absolute positions of the current tokens (needed so RoPE stays
        # consistent when decoding incrementally with start_pos > 0)
        position_ids = torch.arange(start_pos, end_pos, device=x.device)
# ----- Queries -----
if self.q_latent_dim == 0:
Q = (
self.Wq(x)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2)
) # [B, num_heads, T, head_dim]
else:
query_latent = self.Wdq(x)
query_latent = self.q_norm(query_latent)
Q = (
self.Wuq(query_latent)
.view(batch_size, num_tokens, self.num_heads, self.head_dim)
.transpose(1, 2) # [B, num_heads, T, head_dim]
)
        # ----- RoPE path (queries) -----
        if self.q_latent_dim == 0:
            # Q is [B, num_heads, T, head_dim]; apply_rope expects heads on dim 2
            Qr = self.rope.apply_rope(
                Q.transpose(1, 2), position_ids=position_ids
            ).transpose(1, 2)
        else:
            Qr = self.rope.apply_rope(
                self.Wqr(query_latent).view(
                    batch_size, num_tokens, self.num_heads, self.head_dim
                ),
                position_ids=position_ids,
            ).transpose(1, 2)
        # ---------------------
        # ----- KV latent -----
        kv_latent = self.kv_norm(self.Wdkv(x))  # [B, T, kv_latent_dim]
        # cache detached values, but keep the current chunk attached so
        # gradients can flow through Wdkv / kv_norm during training
        self.kv_latent_cache[:batch_size, start_pos:end_pos] = kv_latent.detach()
        kv_latent_all = torch.cat(
            [self.kv_latent_cache[:batch_size, :start_pos], kv_latent], dim=1
        )  # [B, S, kv_latent_dim]
        # [B, num_heads, T, head_dim] x [head_dim, kv_latent_dim]
        Q_absorbed = Q @ self.Wuk.weight  # B, num_heads, T, kv_latent_dim
        V = self.Wuv(kv_latent_all).view(
            batch_size, S, self.mla_kv_heads, self.head_dim
        )  # [B, S, mla_kv_heads, head_dim]
        # expand V to match num_heads
        V = V.repeat_interleave(
            self.num_heads // self.mla_kv_heads, dim=2
        )  # [B, S, num_heads, head_dim]
        V = V.transpose(1, 2)  # [B, num_heads, S, head_dim]
        # ----- RoPE path (keys) -----
        K_pos_encoding = self.rope.apply_rope(
            self.Wkr(x).view(
                batch_size, num_tokens, self.mla_kv_heads, self.head_dim
            ),
            position_ids=position_ids,
        )  # [B, T, mla_kv_heads, head_dim]
        self.keys_roped[:batch_size, start_pos:end_pos] = K_pos_encoding.detach()
        keys_roped_all = torch.cat(
            [self.keys_roped[:batch_size, :start_pos], K_pos_encoding], dim=1
        )  # [B, S, mla_kv_heads, head_dim]
        Kr = keys_roped_all.repeat_interleave(
            self.num_heads // self.mla_kv_heads, dim=2
        ).transpose(1, 2)  # [B, num_heads, S, head_dim]
        # ----- Attention scores -----
        # unsqueeze to account for heads, since there is one kv latent cache
        # (not one per head)
        attention_scores_1 = Q_absorbed @ kv_latent_all.unsqueeze(1).transpose(2, 3)
        attention_scores_2 = Qr @ Kr.transpose(-2, -1)  # [B, num_heads, T, S]
        attention_scores = (attention_scores_1 + attention_scores_2) / (
            2 * self.head_dim
        ) ** 0.5
        # causal mask: keep the last num_tokens rows so each query masks
        # keys beyond its absolute position
        causal_mask = torch.triu(
            torch.ones(end_pos, end_pos, device=x.device), diagonal=1
        )
        attention_scores = attention_scores.masked_fill(
            causal_mask.bool()[-num_tokens:, :], float("-inf")
        )
attention_weights = torch.softmax(attention_scores, dim=-1)
self.last_attention = attention_weights.detach()
attention_weights = self.dropout(attention_weights)
# ----- Context -----
context = attention_weights @ V # [B, H, T, D]
context = (
context.transpose(1, 2)
.contiguous()
.view(batch_size, num_tokens, self.embed_dim)
)
out = self.out_proj(context)
return out
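
# A hedged usage sketch (my addition) of incremental decoding with the latent
# cache: prefill the prompt at start_pos=0, then feed one token at a time with
# start_pos pointing at the next free cache slot.
def _demo_mla_decoding():
    config = DeepSeekModelConfig()
    mla = MultiHeadLatentAttention(config).eval()
    prompt = torch.rand(1, 6, config.input_dim)
    with torch.no_grad():
        _ = mla(prompt, start_pos=0)  # prefill positions 0..5
        next_token = torch.rand(1, 1, config.input_dim)
        out = mla(next_token, start_pos=6)  # decode position 6
    assert out.shape == (1, 1, config.input_dim)
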
# Note: I might not use this in training; I'll do normal single-token prediction only
class BasicMultiTokenPrediction(nn.Module):
def __init__(self, config: DeepSeekModelConfig):
super().__init__()
# If k is mtp_depth, and current token position is i
# this module predicts next k tokens, so from
# (i+1) to (i+k)
self.k = config.mtp_depth
self.vocab_size = config.vocab_size
self.rms_norm = RMSNorm(config.input_dim)
self.embed = nn.Embedding(self.vocab_size, config.input_dim)
self.unembed = nn.Linear(config.input_dim, self.vocab_size, bias=False)
self.unembed.weight = self.embed.weight
self.projections = nn.ModuleList(
[nn.Linear(2 * config.input_dim, config.input_dim) for _ in range(self.k)]
)
self.transformers = nn.ModuleList(
[
nn.TransformerEncoderLayer(config.input_dim, config.num_attention_heads)
for _ in range(self.k)
]
)
    def forward(self, x):
        # x is the final hidden state for all tokens, i.e. the output of all
        # transformer blocks just before the final un-embedding layer
batch_size, num_tokens, input_size = x.shape
        # if num_tokens is 6 (i = 0..5) and k = 3, the furthest position
        # I can predict from is i = 2 (since 2 + 3 = 5), so i runs from
        # 0 to num_tokens - k - 1, i.e. up to x[:, num_tokens - k - 1, :]
logits = []
for ith_token_pos in range(0, num_tokens - self.k):
hidden_state_ith_token = x[:, ith_token_pos, :]
logits_k = []
for k in range(self.k):
future_position = ith_token_pos + k + 1
                token_embedding = x[
                    :, future_position, :
                ]  # note: this is the future token's hidden state, not its raw embedding
_h = self.rms_norm(hidden_state_ith_token)
_e = self.rms_norm(token_embedding)
merged = torch.cat([_h, _e], dim=1)
proj = self.projections[k](merged).unsqueeze(0)
out = self.transformers[k](proj)
hidden_state_current = out.squeeze(0)
_logits = self.unembed(hidden_state_current)
logits_k.append(_logits)
hidden_state_ith_token = hidden_state_current
            logits_k = torch.stack(logits_k, dim=1)  # [B, k, vocab_size]
            logits.append(logits_k)
        logits = torch.stack(logits, dim=1)  # [B, num_tokens - k, k, vocab_size]
        return logits
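
# A hedged shape sketch (my addition) for the multi-token-prediction head:
# for T input tokens and depth k it returns logits for the first T - k anchor
# positions, each predicting its next k tokens.
def _demo_mtp_shapes():
    config = DeepSeekModelConfig(mtp_depth=2)
    mtp = BasicMultiTokenPrediction(config)
    hidden = torch.rand(2, 6, config.input_dim)  # final hidden states
    logits = mtp(hidden)
    assert logits.shape == (2, 6 - 2, 2, config.vocab_size)
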
class TransformerBlock(nn.Module):
def __init__(self, config: DeepSeekModelConfig, moe: bool = True):
super().__init__()
self.rms_norm_1 = RMSNorm(config.input_dim)
self.mhla = MultiHeadLatentAttention(config)
self.rms_norm_2 = RMSNorm(config.input_dim)
if moe:
self.ffn = MoE(config)
else:
self.ffn = Expert(
config.input_dim, config.expert_intermediate_dim, config.dropout
)
def forward(self, x):
x = x + self.mhla(self.rms_norm_1(x))
x = x + self.ffn(self.rms_norm_2(x))
return x
class DeepseekInspiredModel(nn.Module):
def __init__(self, config: DeepSeekModelConfig):
super().__init__()
self.config = config
self.token_embedding = nn.Embedding(config.vocab_size, config.input_dim)
self.position_embedding = nn.Embedding(config.max_token_len, config.input_dim)
_blocks = [
TransformerBlock(config, moe=False) for _ in range(config.num_dense_ffn)
]
_blocks.extend(
[TransformerBlock(config, moe=True) for _ in range(config.num_moe_ffn)]
)
self.transformer_blocks = nn.ModuleList(_blocks)
self.ln_f = RMSNorm(config.input_dim)
self.head = nn.Linear(config.input_dim, config.vocab_size, bias=False)
self.head.weight = self.token_embedding.weight
def forward(self, x):
batch_size, num_tokens = x.shape
token_embeddings = self.token_embedding(x)
position_ids = torch.arange(0, num_tokens, device=x.device).unsqueeze(0)
position_embeddings = self.position_embedding(position_ids)
h = token_embeddings + position_embeddings
for block in self.transformer_blocks:
h = block(h)
h = self.ln_f(h)
logits = self.head(h)
return logits
if __name__ == "__main__":
    config = DeepSeekModelConfig()
    model = DeepseekInspiredModel(config)
    x = torch.randint(0, config.vocab_size, (1, 10))  # embeddings expect token ids
    logits = model(x)
    print(f"Output logits shape: {tuple(logits.shape)}")  # (1, 10, vocab_size)
    num_params = sum(p.numel() for p in model.parameters())
    print(f"Number of parameters (in millions): {num_params / 1_000_000:.2f}")
    print(f"Parameter memory at fp32: {num_params * 4 / 1024**3:.2f} GB")