# deepseek-mlha-moe / model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class DeepSeekConfig:
"""
    Configuration for the DeepSeek-style model (scaled down to ~135M parameters).
"""
block_size: int = 2048
vocab_size: int = 49152
n_layer: int = 12 # 12 layers (Pruned from 30)
n_head: int = 9 # Trained with 9
n_embd: int = 576 # Trained with 576
n_kv_head: int = 3 # Trained with 3
intermediate_size: int = 1536 # Trained with 1536
rms_norm_eps: float = 1e-5
rope_theta: float = 100000.0
# MLHA params
q_lora_rank: int = 192 # Trained with 192
kv_lora_rank: int = 128 # Trained with 128
# MoE params
n_routed_experts: int = 8 # Trained with 8
n_shared_experts: int = 2 # Trained with 2
n_activated_experts: int = 2
    moe_intermediate_size: int = 1536 # Trained with 1536 (note: Expert currently uses intermediate_size, so this field is unused)
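# Derived sizes for the defaults above, for reference: head_dim = n_embd // n_head
# = 576 // 9 = 64, GQA repetition factor n_rep = n_head // n_kv_head = 9 // 3 = 3,
# and the RoPE table precomputed in DeepSeek.__init__ covers block_size * 2 = 4096
# positions.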
class RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x: torch.Tensor) -> torch.Tensor:
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x: torch.Tensor) -> torch.Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
"""Precompute complex exponentials for Rotary Positional Embeddings (RoPE)."""
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
"""Reshape frequency tensor for broadcasting with input tensor."""
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply Rotary Positional Embeddings to query and key tensors."""
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
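# Usage sketch for the RoPE helpers (illustrative only, default config sizes;
# not executed anywhere in this file):
#   freqs_cis = precompute_freqs_cis(dim=64, end=2048, theta=100000.0)  # complex, shape (2048, 32)
#   # xq: (B, T, n_head, 64), xk: (B, T, n_kv_head, 64)
#   xq, xk = apply_rotary_emb(xq, xk, freqs_cis[:T])  # shapes unchanged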
class MultiHeadLatentAttention(nn.Module):
"""
Multi-Head Latent Attention (MLHA) from DeepSeek.
Key innovation: Low-rank compression of KV cache to reduce memory.
- Queries: Compressed via low-rank projection (q_lora_rank)
- Keys/Values: Compressed via low-rank projection (kv_lora_rank)
- Significantly reduces KV cache size during inference
"""
def __init__(self, config: DeepSeekConfig):
super().__init__()
self.n_head = config.n_head
self.n_kv_head = config.n_kv_head
self.n_embd = config.n_embd
self.head_dim = config.n_embd // config.n_head
self.n_rep = self.n_head // self.n_kv_head
# MLHA: Low-rank compression
self.kv_lora_rank = config.kv_lora_rank
self.q_lora_rank = config.q_lora_rank
# Query projection with low-rank compression
        # q = W_q_up @ (W_q_down @ x): down-project to q_lora_rank, then up-project
self.q_down_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=False)
self.q_up_proj = nn.Linear(config.q_lora_rank, config.n_head * self.head_dim, bias=False)
# KV projection with low-rank compression
self.kv_down_proj = nn.Linear(config.n_embd, config.kv_lora_rank, bias=False)
self.kv_up_proj = nn.Linear(config.kv_lora_rank, 2 * config.n_kv_head * self.head_dim, bias=False)
# Output projection
self.o_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
B, T, C = x.size()
# MLHA: Low-rank query projection
q_compressed = self.q_down_proj(x) # (B, T, q_lora_rank)
xq = self.q_up_proj(q_compressed) # (B, T, n_head * head_dim)
# MLHA: Low-rank KV projection
kv_compressed = self.kv_down_proj(x) # (B, T, kv_lora_rank)
kv = self.kv_up_proj(kv_compressed) # (B, T, 2 * n_kv_head * head_dim)
# Split KV
xk, xv = kv.chunk(2, dim=-1)
# Reshape for multi-head attention
xq = xq.view(B, T, self.n_head, self.head_dim)
xk = xk.view(B, T, self.n_kv_head, self.head_dim)
xv = xv.view(B, T, self.n_kv_head, self.head_dim)
# Apply RoPE
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
# GQA: Repeat KV heads to match query heads
if self.n_rep > 1:
xk = torch.repeat_interleave(xk, self.n_rep, dim=2)
xv = torch.repeat_interleave(xv, self.n_rep, dim=2)
# Transpose for attention: (B, n_head, T, head_dim)
xq = xq.transpose(1, 2)
xk = xk.transpose(1, 2)
xv = xv.transpose(1, 2)
        # Scaled dot-product attention (uses fused/Flash kernels when available)
output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
# Reshape and project
output = output.transpose(1, 2).contiguous().view(B, T, C)
return self.o_proj(output)
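# Shape sanity check for the attention block (a minimal sketch with the default
# config; kept as a comment so nothing runs on import):
#   cfg = DeepSeekConfig()
#   attn = MultiHeadLatentAttention(cfg)
#   x = torch.randn(2, 16, cfg.n_embd)
#   freqs = precompute_freqs_cis(cfg.n_embd // cfg.n_head, cfg.block_size, cfg.rope_theta)
#   out = attn(x, freqs[:16])  # -> (2, 16, 576)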
class Expert(nn.Module):
"""Single expert in the MoE layer."""
def __init__(self, config: DeepSeekConfig):
super().__init__()
self.gate_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.n_embd, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# SwiGLU activation
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class MoELayer(nn.Module):
"""
Mixture of Experts (MoE) with lossless load balancing.
Key features:
- Shared experts: Always activated for all tokens
- Routed experts: Top-k selection per token
- Lossless load balancing: No auxiliary loss, uses expert capacity
"""
def __init__(self, config: DeepSeekConfig):
super().__init__()
self.n_routed_experts = config.n_routed_experts
self.n_shared_experts = config.n_shared_experts
self.n_activated_experts = config.n_activated_experts
self.n_embd = config.n_embd
# Router: Maps input to expert scores
self.router = nn.Linear(config.n_embd, config.n_routed_experts, bias=False)
# Routed experts
self.routed_experts = nn.ModuleList([
Expert(config) for _ in range(config.n_routed_experts)
])
# Shared experts (always active)
self.shared_experts = nn.ModuleList([
Expert(config) for _ in range(config.n_shared_experts)
])
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Memory-optimized forward pass with lossless load balancing.
Returns:
output: MoE output
router_logits: For monitoring load balance (not used in loss)
"""
B, T, C = x.size()
x_flat = x.view(-1, C) # (B*T, C)
# 1. Shared experts (always active) - memory efficient
if self.n_shared_experts > 0:
shared_output = torch.zeros_like(x_flat)
for expert in self.shared_experts:
shared_output.add_(expert(x_flat))
shared_output.div_(self.n_shared_experts)
else:
shared_output = torch.zeros_like(x_flat)
# 2. Routed experts (top-k selection) - optimized routing
router_logits = self.router(x_flat) # (B*T, n_routed_experts)
routing_weights = F.softmax(router_logits, dim=-1)
# Select top-k experts
top_k_weights, top_k_indices = torch.topk(
routing_weights,
k=self.n_activated_experts,
dim=-1
) # (B*T, k)
# Normalize top-k weights
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8)
# Memory-efficient expert routing
routed_output = torch.zeros_like(x_flat)
# Process each expert efficiently
for expert_id in range(self.n_routed_experts):
# Find tokens routed to this expert
expert_mask = (top_k_indices == expert_id).any(dim=1)
if expert_mask.any():
# Get indices and weights for this expert
token_indices = expert_mask.nonzero(as_tuple=True)[0]
expert_input = x_flat[token_indices]
# Compute expert output
expert_out = self.routed_experts[expert_id](expert_input)
# Get weights for these tokens
weights = torch.zeros(token_indices.size(0), 1, device=x_flat.device)
for k in range(self.n_activated_experts):
mask = (top_k_indices[token_indices, k] == expert_id)
weights[mask] = top_k_weights[token_indices[mask], k:k+1]
# Add weighted output
routed_output[token_indices] += weights * expert_out
# Combine shared and routed outputs
output = shared_output + routed_output
output = output.view(B, T, C)
return output, router_logits
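# Routing arithmetic for the default config, for reference: with 8 routed experts
# and top-2 selection, each token activates 2 routed experts plus the 2 shared
# experts, i.e. 4 of the 10 expert MLPs, and the renormalized top-2 weights sum
# to 1 per token.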
class DeepSeekBlock(nn.Module):
"""
DeepSeek Transformer Block with MLHA and MoE.
"""
def __init__(self, config: DeepSeekConfig):
super().__init__()
self.attention = MultiHeadLatentAttention(config)
self.moe = MoELayer(config)
self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# Self-attention with residual
h = x + self.attention(self.input_layernorm(x), freqs_cis)
# MoE with residual
moe_output, router_logits = self.moe(self.post_attention_layernorm(h))
output = h + moe_output
return output, router_logits
class DeepSeek(nn.Module):
"""
DeepSeek Model with MLHA and MoE for Causal Language Modeling.
"""
def __init__(self, config: DeepSeekConfig):
super().__init__()
self.config = config
# Token embeddings
self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
# Transformer blocks
self.layers = nn.ModuleList([
DeepSeekBlock(config) for _ in range(config.n_layer)
])
# Final layer norm
self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
# Language modeling head
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Weight tying
self.embed_tokens.weight = self.lm_head.weight
# Precompute RoPE frequencies
self.freqs_cis = precompute_freqs_cis(
config.n_embd // config.n_head,
config.block_size * 2,
config.rope_theta
)
print(f"DeepSeek Model initialized with {self.count_parameters():,} parameters")
def count_parameters(self) -> int:
"""Count total trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
B, T = idx.size()
assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
# Move freqs_cis to device if needed
if self.freqs_cis.device != idx.device:
self.freqs_cis = self.freqs_cis.to(idx.device)
freqs_cis = self.freqs_cis[:T]
# Embeddings
x = self.embed_tokens(idx)
# Transformer blocks
all_router_logits = []
for layer in self.layers:
x, router_logits = layer(x, freqs_cis)
all_router_logits.append(router_logits)
# Final norm
x = self.norm(x)
# Language modeling head
if targets is not None:
logits = self.lm_head(x)
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
return logits, loss, all_router_logits
else:
# Generation mode: only compute last token
logits = self.lm_head(x[:, [-1], :])
return logits
@classmethod
def from_pretrained(cls, model_name: str = "HuggingFaceTB/SmolLM2-135M"):
"""
Initialize from SmolLM2 pretrained weights (where possible).
Note: MLHA and MoE layers will be randomly initialized.
"""
from transformers import AutoModelForCausalLM
print(f"Loading base weights from {model_name}")
hf_model = AutoModelForCausalLM.from_pretrained(model_name)
hf_sd = hf_model.state_dict()
config = DeepSeekConfig()
model = cls(config)
sd = model.state_dict()
# Only load embeddings and LM head (architecture is different)
keys_to_load = ["embed_tokens.weight", "lm_head.weight"]
for k in keys_to_load:
hf_key = f"model.{k}" if "embed" in k else k
if hf_key in hf_sd:
with torch.no_grad():
sd[k].copy_(hf_sd[hf_key])
print(f"Loaded: {k}")
print("Note: MLHA and MoE layers initialized randomly (architecture mismatch)")
return model
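if __name__ == "__main__":
    # Minimal smoke test (a sketch added for illustration, not part of the
    # original training setup): instantiate the default model and run a single
    # forward pass on random token IDs.
    cfg = DeepSeekConfig()
    model = DeepSeek(cfg)
    idx = torch.randint(0, cfg.vocab_size, (2, 32))
    targets = torch.randint(0, cfg.vocab_size, (2, 32))
    logits, loss, all_router_logits = model(idx, targets)
    print(f"logits: {tuple(logits.shape)}, loss: {loss.item():.4f}, "
          f"router logits per layer: {tuple(all_router_logits[0].shape)}")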