# ChatSPE/model.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import tiktoken
__all__ = [
'Rope',
'DeepSeek_MLA',
'DeepSeek_MoE',
'DeepSeek_MTP',
'DeepSeek_V3_Block',
'DeepSeek_V3_Encoder',
'DeepSeek_V3_Model',
'generate_text',
'clean_response'
]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Define model configurations
class Config:
hidden_size = 128 # Embedding dimension (D)
    latent_dim = hidden_size // 2 # Latent dimension, half of D (an arbitrary choice)
    num_heads = 16 # Number of attention heads (must evenly divide hidden_size)
pos_dim = 24 # Positional encoding dimension
pad_token_id = 50256 # Padding token ID (matches <|endoftext|> in GPT-2 vocab)
num_shared_experts = 4
num_routed_experts = 8
    top_k = 8 # Kr, number of experts selected per token (equal to num_routed_experts here, so routing is dense)
bias_update_speed = 0.01
balance_alpha = 0.01
    lambda_mtp = 0.5 # λ, weight of the MTP loss
num_depths = 3 # D, number of prediction depths
    vocab_size = tiktoken.get_encoding("gpt2").n_vocab # Vocab size of tiktoken’s GPT-2 vocab (50257)
layer_norm_eps = 1e-5 # Small epsilon value for numerical stability in layer normalization
num_blocks = 12 # Number of transformer blocks to stack in the model
batch_size = 64 # Number of sequences per batch
context_length = 60 # Number of tokens per sequence
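# Derived sizes implied by this config (for orientation; these are computed by
# the modules below rather than set directly):
#   head_dim = hidden_size // num_heads = 128 // 16 = 8
#   per-head Q/K width in MLA training mode = head_dim + pos_dim = 8 + 24 = 32
#   per-head Q/K width in MLA inference mode = latent_dim + pos_dim = 64 + 24 = 88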
class Rope(nn.Module):
"""
Rotary Position Embedding (RoPE) module.
Applies rotary position encoding to an input tensor of shape (B, H, S, D),
"""
def __init__(self, dim, max_seq_len = 4096):
super().__init__()
# Safety check: RoPE requires even dimensionality (for splitting into pairs)
assert dim % 2 == 0, f"RoPE dim must be even, got {dim}"
self.dim = dim
self.max_seq_len = max_seq_len
# Step 2: Compute rotation frequencies for sinusoidal positions
# inv_freq[i] = 1 / (10000^(2i/dim)), where i = 0, 1, ..., dim/2 - 1
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
# Store as non-trainable buffer
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Precompute and cache cos/sin values up to max_seq_len
self._build_cache(max_seq_len)
def _build_cache(self, seq_len):
"""
Precompute cosine and sine embeddings for all positions up to seq_len.
This avoids recomputing trig functions during every forward pass.
"""
# Positions: [0, 1, 2, ..., seq_len-1]
t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device)
# Step 3: Compute rotation angles (per position and dimension pair)
# Each row is t * inv_freq[i], giving angular frequency per dimension
freqs = torch.outer(t, self.inv_freq)
# Duplicate for concatenation of sin and cos values, shape: (seq_len, dim)
emb = torch.cat((freqs, freqs), dim=-1)
# Step 4: Construct rotation matrix elements (cos and sin)
# Register as buffers with shape (1, 1, seq_len, dim)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
# Track how many positions we have cached
self.max_seq_len = seq_len
def forward(self, x, seq_len, position_offset = 0):
"""
Apply RoPE to input tensor.
Args:
x: Input tensor of shape (B, H, S, D)
seq_len: Actual sequence length to encode
position_offset: Offset for decoding continuation (default = 0)
Returns:
Tensor with RoPE applied, same shape as x.
"""
device = x.device
# Ensure input matches expected dimensionality
assert x.shape[-1] == self.dim, (
f"RoPE input dim mismatch: expected {self.dim}, got {x.shape[-1]}"
)
seq_len_x = x.size(-2) # sequence length from input tensor
if (position_offset + seq_len) > self.max_seq_len:
# Rebuild cache with doubled size for efficiency
self._build_cache(max(position_offset + seq_len, self.max_seq_len * 2))
# Select only the needed positions
cos = self.cos_cached[:, :, position_offset:position_offset + seq_len, :].to(device)
sin = self.sin_cached[:, :, position_offset:position_offset + seq_len, :].to(device)
# Ensure cache slice matches actual input sequence length
assert cos.shape[2] == seq_len_x, (
f"RoPE seq_len mismatch: expected {seq_len_x}, got {cos.shape[2]}"
)
# Step 1: Split Q/K into 2D subspaces (pairs of dimensions)
# Split last dimension into pairs: (x1, x2)
x1, x2 = x.chunk(2, dim=-1)
# Step 5 and Step 6: Apply rotation to each 2D subspace
# Rotate pairs: (x1, x2) → (-x2, x1)
rotated = torch.cat((-x2, x1), dim=-1)
# Apply rotary transformation: elementwise (x*cos + rotated*sin)
result = x*cos + rotated*sin
return result
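# Minimal usage sketch (illustrative, not part of the module's public API):
# shows that RoPE preserves the (B, H, S, D) shape and that position_offset
# supports decode-time continuation. The sizes are assumptions matching Config.
def _rope_demo():
    rope = Rope(dim=Config.pos_dim)                            # dim must be even
    x = torch.randn(2, Config.num_heads, 60, Config.pos_dim)   # (B, H, S, D)
    full = rope(x, seq_len=60)                                 # positions 0..59
    shifted = rope(x, seq_len=60, position_offset=5)           # positions 5..64
    assert full.shape == x.shape == shifted.shape
    return full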
class DeepSeek_MLA(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size # Embedding dimension
self.num_heads = config.num_heads
self.head_dim = self.hidden_size // self.num_heads
self.latent_dim = config.latent_dim
self.pos_dim = config.pos_dim
        self.max_seq_len = getattr(config, 'max_seq_len', 512) # Maximum sequence length for the precomputed causal mask
        self.pad_token_id = getattr(config, 'pad_token_id', 50256) # Default pad token ID 50256, the <|endoftext|> token in tiktoken’s GPT-2 vocab
assert self.hidden_size % self.num_heads == 0, f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})"
# Ensure pos_dim is even for RoPE
assert self.pos_dim % 2 == 0, f"pos_dim ({self.pos_dim}) must be even for RoPE"
# Latent compression projections
self.W_DKV = nn.Linear(self.hidden_size, self.latent_dim, bias=False) # KV compression
self.W_DQ = nn.Linear(self.hidden_size, self.latent_dim, bias=False) # Q compression
# Content projection from latent to multi-head space
self.W_UK = nn.Linear(self.latent_dim, self.hidden_size, bias=False) # K content
self.W_UV = nn.Linear(self.latent_dim, self.hidden_size, bias=False) # V content
self.W_UQ = nn.Linear(self.latent_dim, self.hidden_size, bias=False) # Q content
# Positional projections (RoPE pathway)
self.W_KR = nn.Linear(self.hidden_size, self.pos_dim, bias=False) # K positional
self.W_QR = nn.Linear(self.latent_dim, self.num_heads * self.pos_dim, bias=False) # Q positional
# Output projection
self.W_O = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
# RoPE initialization
self.rope_k = Rope(self.pos_dim)
self.rope_q = Rope(self.pos_dim)
# ---- Precomputed causal mask ----
# Create upper triangular mask with ones above diagonal and convert to boolean
self.register_buffer("causal_mask", torch.triu(torch.ones(self.max_seq_len, self.max_seq_len), diagonal=1).bool())
def forward(self, hidden_states, input_tokens=None, mode="train", use_cache=False, past_key_values=None, attention_mask=False):
batch_size, seq_len, hidden_size = hidden_states.shape
assert hidden_size == self.hidden_size, f"hidden_size mismatch: got {hidden_size}, expected {self.hidden_size}"
# ---- Latent compressions ----
c_KV = self.W_DKV(hidden_states) # (batch_size, seq_len, latent_dim)
c_Q = self.W_DQ(hidden_states) # (batch_size, seq_len, latent_dim)
# ---- Content projections (per-head) ----
k_C = self.W_UK(c_KV).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, H, seq_len, head_dim)
v_C = self.W_UV(c_KV).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, H, seq_len, head_dim)
q_C = self.W_UQ(c_Q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, H, seq_len, head_dim)
# ---- Positional projections ----
k_R = self.W_KR(hidden_states) # (batch_size, seq_len, pos_dim)
q_R = self.W_QR(c_Q).view(batch_size, seq_len, self.num_heads, self.pos_dim).transpose(1, 2) # (batch_size, H, seq_len, pos_dim)
# ---- Determine past length for RoPE position_offset ----
past_len = 0 if past_key_values is None else past_key_values[0].size(2)
# ---- Apply RoPE (position offset = past_len) ----
k_R = self.rope_k(k_R.unsqueeze(1).expand(-1, self.num_heads, -1, -1), seq_len=seq_len, position_offset=past_len) # (batch_size, H, seq_len, pos_dim)
q_R = self.rope_q(q_R, seq_len=seq_len, position_offset=past_len) # (batch_size, H, seq_len, pos_dim)
######### TRAINING MODE #########
if mode == "train":
k = torch.cat([k_C, k_R], dim=-1) # (batch_size, H, seq_len, head_dim + pos_dim)
q = torch.cat([q_C, q_R], dim=-1) # (batch_size, H, seq_len, head_dim + pos_dim)
scale = 1.0 / math.sqrt(q.shape[-1]) # same as scale = 1.0 / math.sqrt(head_dim + pos_dim)
attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (batch_size, H, seq_len, seq_len)
# ---- Apply mask (causal + padding) ----
if attention_mask:
# Mask truncated to the number of tokens and converted to boolean
mask_bool = self.causal_mask[:seq_len, :seq_len]
# Convert boolean mask to -inf format for attention
causal_mask = mask_bool.float().masked_fill(mask_bool, float('-inf'))
                # Create padding mask from the input token ids
                padding_mask = (input_tokens == self.pad_token_id) # (B, S) - True at padding positions
# Expand padding mask to match attention scores shape
padding_mask = padding_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, S)
padding_mask = padding_mask.expand(-1, self.num_heads, seq_len, -1) # (B, H, S, S)
padding_mask = padding_mask.float().masked_fill(padding_mask, float('-inf'))
# Combine causal and padding masks
full_mask = causal_mask.unsqueeze(0).unsqueeze(0) + padding_mask
attn_scores = attn_scores + full_mask
attn_probs = F.softmax(attn_scores, dim=-1)
o_heads = torch.matmul(attn_probs, v_C) # (batch_size, H, seq_len, head_dim)
kv_cache = None # training returns no cache
######### INFERENCE MODE #########
elif mode == "inference":
# Concatenate past and current per-head keys/values/pos if provided
if past_key_values is None:
k_C_total = k_C # (batch_size, H, seq_len, head_dim)
v_C_total = v_C
k_R_total = k_R
q_R_total = q_R
c_KV_total = c_KV.unsqueeze(1).expand(-1, self.num_heads, -1, -1) # (batch_size, H, seq_len, latent_dim)
total_len = seq_len
else:
                # past_key_values: (past_k_cache, past_v_cache, past_kR_cache, past_qR_cache, past_cKV_cache)
                past_k_cache, past_v_cache, past_k_R_cache, past_q_R_cache, past_c_KV_total = past_key_values
# Append along sequence dim (dim=2 for per-head)
k_C_total = torch.cat([past_k_cache, k_C], dim=2) # (batch_size, H, past_len+seq_len, head_dim)
v_C_total = torch.cat([past_v_cache, v_C], dim=2)
k_R_total = torch.cat([past_k_R_cache, k_R], dim=2) # (batch_size, H, past_len+seq_len, pos_dim)
q_R_total = torch.cat([past_q_R_cache, q_R], dim=2)
c_KV_total = torch.cat([past_c_KV_total, c_KV.unsqueeze(1).expand(-1, self.num_heads, -1, -1)], dim=2) # (batch_size, H, total_len, latent_dim)
total_len = k_C_total.size(2)
            # ---- Weight absorption: fold W_UK into the query ----
            # Project q_C into the latent space so attention can be computed
            # directly against the cached latents (no key expansion needed).
            W_UK_heads = self.W_UK.weight.view(self.num_heads, self.head_dim, self.latent_dim)
            q_latent = torch.matmul(q_C, W_UK_heads) # (batch_size, H, seq_len, latent_dim)
k_hat = torch.cat([c_KV_total, k_R_total], dim=-1)
q_hat = torch.cat([q_latent, q_R], dim=-1) # (batch_size, H, seq_len, head_dim+pos_dim)
# Attention
scale = 1.0 / math.sqrt(k_hat.shape[-1])
attn_scores = torch.matmul(q_hat, k_hat.transpose(-2, -1)) * scale # (batch_size, H, seq_len, total_len)
# ---- Apply mask (causal + padding, cache-aware) ----
if attention_mask:
mask_bool = self.causal_mask[:total_len, :total_len]
causal_mask_base = mask_bool.float().masked_fill(mask_bool, float('-inf'))
offset = total_len - seq_len
                causal_mask = causal_mask_base[offset:offset + seq_len, :total_len].unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, total_len)
                # Create padding mask from the input token ids
                padding_mask = (input_tokens == self.pad_token_id) # (B, S) - True at padding positions
# For inference with cache, we need to handle the full sequence length
padding_mask_full = torch.zeros(batch_size, total_len, device=hidden_states.device, dtype=torch.bool)
padding_mask_full[:, -seq_len:] = padding_mask # Only the current tokens have padding
padding_mask_expanded = padding_mask_full.unsqueeze(1).unsqueeze(2).expand(-1, self.num_heads, seq_len, -1)
padding_mask_expanded = padding_mask_expanded.float().masked_fill(padding_mask_expanded, float('-inf'))
full_mask = causal_mask + padding_mask_expanded
attn_scores = attn_scores + full_mask
attn_probs = F.softmax(attn_scores, dim=-1)
o_hat = torch.matmul(attn_probs, c_KV_total) # (batch_size, H, seq_len, latent_dim)
# 2. Apply per-head W_UV projection (Absorb step)
W_UV_heads = self.W_UV.weight.view(self.num_heads, self.head_dim, self.latent_dim) # [H, head_dim, latent_dim]
            o_heads = torch.matmul(o_hat, W_UV_heads.transpose(1, 2)) # [batch_size, H, seq_len, head_dim]
# Prepare kv_cache tuple to return (present caches covering full sequence)
if use_cache:
kv_cache = (
k_C_total.detach(),
v_C_total.detach(),
k_R_total.detach(),
q_R_total.detach(),
c_KV_total.detach()
)
else:
kv_cache = None
else:
raise ValueError("mode must be 'train' or 'inference'")
# ---- Final projection ----
o = o_heads.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim) # (batch_size, seq_len, hidden_size)
attn_output = self.W_O(o) # (batch_size, seq_len, hidden_size)
return attn_output, kv_cache
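# Minimal train-mode sketch for DeepSeek_MLA (illustrative; token ids are kept
# below 1000 so none collide with the pad id 50256):
def _mla_demo():
    cfg = Config()
    mla = DeepSeek_MLA(cfg)
    tokens = torch.randint(0, 1000, (2, 10))             # (B, S)
    h = torch.randn(2, 10, cfg.hidden_size)              # (B, S, D)
    out, cache = mla(h, input_tokens=tokens, mode="train", attention_mask=True)
    assert out.shape == h.shape and cache is None        # train mode returns no cache
    return out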
class DeepSeek_MoE(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size # Embedding dimension
self.latent_dim = config.latent_dim
self.num_shared_experts = config.num_shared_experts
self.num_routed_experts = config.num_routed_experts
self.top_k = config.top_k # Kr
self.bias_update_speed = config.bias_update_speed
self.balance_alpha = config.balance_alpha
assert self.top_k <= self.num_routed_experts, f"top_k: ({self.top_k}) exceeds available experts: ({self.num_routed_experts})"
# Expert centroids for affinity scores
self.expert_centroids = nn.Parameter(
torch.empty(self.num_routed_experts, self.hidden_size)
)
# Bias terms for load balancing
self.register_buffer("expert_biases", torch.zeros(self.num_routed_experts))
# Shared experts
self.shared_experts = nn.ModuleList([
nn.Sequential(
nn.Linear(self.hidden_size, self.latent_dim),
nn.SiLU(),
nn.Linear(self.latent_dim, self.hidden_size)
) for _ in range(self.num_shared_experts)
])
# Routed experts
self.routed_experts = nn.ModuleList([
nn.Sequential(
nn.Linear(self.hidden_size, self.latent_dim),
nn.SiLU(),
nn.Linear(self.latent_dim, self.hidden_size)
) for _ in range(self.num_routed_experts)
])
# Initialize centroids
nn.init.xavier_uniform_(self.expert_centroids)
def forward(self, hidden_states, training=True):
batch_size, seq_len, hidden_dim = hidden_states.shape
assert hidden_dim == self.hidden_size, f"Input hidden size mismatch: got {hidden_dim}, expected {self.hidden_size}."
total_tokens = batch_size * seq_len
# ========== Compute affinity scores ==========
# Equation: s_i,t = Sigmoid(u_t^T e_i)
flat_input = hidden_states.view(-1, hidden_dim)
affinity_scores = torch.sigmoid(
F.linear(flat_input, self.expert_centroids) # u_t^T e_i
).view(batch_size, seq_len, self.num_routed_experts)
# ========== Top-K routing with bias ==========
# Equation: Use biased scores s_i,t + b_i for routing selection
biased_scores = affinity_scores + self.expert_biases
# Get top-K experts using biased scores
topk_values, topk_indices = torch.topk(biased_scores, self.top_k, dim=-1)
# Create mask for selected experts
expert_mask = torch.zeros_like(affinity_scores)
expert_mask.scatter_(-1, topk_indices, 1.0)
# ========== Compute gating values ==========
# Equation: g'_i,t = s_i,t if selected, 0 otherwise
selected_scores = affinity_scores * expert_mask
# Equation: g_i,t = g'_i,t / sum_j(g'_j,t) - normalization
gating_values = selected_scores / (selected_scores.sum(dim=-1, keepdim=True) + 1e-8)
# ========== Shared experts computation ==========
# Equation: ∑_{i=1}^{N_s} FFN_i^{(s)}(u_t)
shared_output = sum(expert(hidden_states) for expert in self.shared_experts)
# ========== Routed experts computation ==========
# Equation: ∑_{i=1}^{N_r} g_i,t FFN_i^{(r)}(u_t)
flat_gating = gating_values.view(-1, self.num_routed_experts)
flat_indices = topk_indices.view(-1, self.top_k)
# Precompute all expert outputs: FFN_i^{(r)}(u_t) for all experts
all_expert_outputs = torch.stack([
expert(flat_input) for expert in self.routed_experts
], dim=1) # [total_tokens, num_routed_experts, hidden_size]
# Gather outputs for selected experts and apply gating
expanded_indices = flat_indices.unsqueeze(-1).expand(-1, -1, hidden_dim)
selected_outputs = all_expert_outputs.gather(1, expanded_indices) # Get FFN outputs for top-k experts
gating_weights = flat_gating.gather(1, flat_indices).unsqueeze(-1) # Get g_i,t for selected experts
routed_output_flat = (selected_outputs * gating_weights).sum(dim=1) # ∑ g_i,t * FFN_i^{(r)}(u_t)
routed_output = routed_output_flat.view(batch_size, seq_len, hidden_dim)
# ========== Load balancing updates ==========
aux_loss = torch.tensor(0.0, device=hidden_states.device)
if training:
# ========== Bias Update ==========
# Count how many times each expert is selected (or the number of tokens routed to that expert)
expert_counts = torch.bincount(
topk_indices.view(-1),
minlength=self.num_routed_experts
).float()
expert_loads = expert_counts / total_tokens # Load proportion for each expert
target_load = torch.ones_like(expert_loads) / self.num_routed_experts # Ideal balanced load
load_diff = expert_loads - target_load # Positive = overloaded, Negative = underloaded
# Update: decrease bias for overloaded experts, increase for underloaded
self.expert_biases -= self.bias_update_speed * load_diff
# ========== Sequence-wise Auxiliary Loss ==========
# Equation: f_i = (N_r / (K_r * T)) * ∑_t 𝟙(s_i,t ∈ TopK)
f_i = expert_mask.view(-1, self.num_routed_experts).sum(dim=0) # Count selections per expert
f_i = f_i * (self.num_routed_experts / (self.top_k * seq_len)) # Normalize by sequence length
f_i = f_i / batch_size # Average over batch
# Equation: P_i = (1/T) ∑_t s'_i,t where s'_i,t = s_i,t / ∑_j s_j,t
s_prime = affinity_scores / (affinity_scores.sum(dim=-1, keepdim=True) + 1e-8) # Normalized affinities
P_i = s_prime.view(-1, self.num_routed_experts).mean(dim=0) # Average over all tokens
# Equation: ℒ_Bal = α * ∑_{i=1}^{N_r} f_i * P_i
aux_loss = self.balance_alpha * (f_i * P_i).sum()
# ========== Final output ==========
# Equation: O_t = X_t + shared_experts + routed_experts
output = hidden_states + shared_output + routed_output
return output, aux_loss
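# Minimal sketch for DeepSeek_MoE (illustrative): the output keeps the input
# shape and the auxiliary loss is a scalar (the α-scaled balance term).
def _moe_demo():
    moe = DeepSeek_MoE(Config())
    h = torch.randn(2, 10, Config.hidden_size)
    out, aux = moe(h, training=True)
    assert out.shape == h.shape and aux.dim() == 0
    return out, aux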
class DeepSeek_MTP(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size # Embedding dimension
self.vocab_size = config.vocab_size
        self.num_depths = config.num_depths # D, number of prediction depths (distinct from the embedding dimension D)
self.lambda_mtp = config.lambda_mtp # λ
        self.max_seq_len = getattr(config, 'max_seq_len', 512) # Maximum sequence length for the precomputed causal mask
        self.pad_token_id = getattr(config, 'pad_token_id', 50256) # Default pad token ID 50256, the <|endoftext|> token in tiktoken’s GPT-2 vocab
assert self.hidden_size % config.num_heads == 0,f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({config.num_heads})"
# ===== Shared layers =====
self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) # shared Emb(·)
self.output_head = nn.Linear(self.hidden_size, self.vocab_size) # shared OutHead(·)
# ---- Create D Transformer blocks TRM_k ----
self.trm_blocks = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=self.hidden_size,
nhead=config.num_heads,
dim_feedforward=config.latent_dim,
activation="gelu",
batch_first=True,
)
for _ in range(self.num_depths)
])
# ---- Projection matrices M_k ∈ ℝ^{d×2d} ----
self.proj_matrices = nn.ParameterList([
nn.Parameter(torch.randn(self.hidden_size, 2 * self.hidden_size))
for _ in range(self.num_depths)
])
# ---- RMSNorm layers ----
self.rmsnorm_h = nn.RMSNorm(self.hidden_size)
self.rmsnorm_e = nn.RMSNorm(self.hidden_size)
        # ---- Precomputed causal mask ----
        # Strictly upper-triangular boolean mask: True marks future positions
        # that must not be attended to (diagonal=1 keeps self-attention allowed).
        self.register_buffer("causal_mask", torch.triu(torch.ones(self.max_seq_len, self.max_seq_len), diagonal=1).bool())
def forward(self, hidden_states, input_tokens=None, mode="train", attention_mask=True):
batch_size, seq_len, hidden_size = hidden_states.shape
assert hidden_size == self.hidden_size, f"hidden_states last dim {hidden_size} != expected hidden_size {self.hidden_size}"
if mode == "train":
assert input_tokens is not None, "input_tokens required in training mode"
            assert input_tokens.shape == (batch_size, seq_len), f"input_tokens shape {tuple(input_tokens.shape)} must match (batch_size, seq_len) = {(batch_size, seq_len)}"
mtp_losses = []
# Use separate variable to prevent in-place overwriting
h_current = hidden_states
# ===== MTP depths loop =====
for k in range(1, self.num_depths + 1):
current_seq_len = h_current.shape[1] # Use current sequence length
                if current_seq_len - k <= 1:
                    break # no position left with a valid next-token target
# ---- h'_i^k = M_k [RMSNorm(h_i^{k−1}); RMSNorm(Emb(t_{i+k}))] ----
h_prev = h_current[:, :current_seq_len - k, :] # h_i^{k−1}
emb_shifted = self.embedding(input_tokens[:, k:]) # Emb(t_{i+k})
h_prev_norm = self.rmsnorm_h(h_prev)
emb_norm = self.rmsnorm_e(emb_shifted)
concat = torch.cat([h_prev_norm, emb_norm], dim=-1) # concat [h; e]
h_prime_k = torch.matmul(concat, self.proj_matrices[k - 1].T)
# ---- causal + padding attention mask ----
causal_mask = None
padding_mask = None
if attention_mask:
# Get the actual sequence length for this depth
L = current_seq_len - k
# Original mask truncated to the number of tokens and converted to boolean
causal_mask = self.causal_mask[:L, :L]
# Create padding mask from input tokens (also boolean)
padding_mask = (input_tokens[:, k:current_seq_len] == self.pad_token_id) # (B, L)
# ---- Transformer block TRM_k(h'_i^k) ----
h_k = self.trm_blocks[k - 1](h_prime_k, src_mask=causal_mask, src_key_padding_mask=padding_mask)
# ---- logits = OutHead(h_i^k) ----
mtp_logits = self.output_head(h_k)
                # ---- Cross-entropy loss ----
                # At depth k, position i already receives Emb(t_{i+k}) as input,
                # so the prediction target is the next token t_{i+k+1}
                # (the last logit has no target and is dropped).
                target_k = input_tokens[:, k + 1:current_seq_len] # (B, L-1)
                loss_k = F.cross_entropy(
                    mtp_logits[:, :-1, :].reshape(-1, self.vocab_size),
                    target_k.reshape(-1),
                    reduction="mean",
                    ignore_index=self.pad_token_id
                )
mtp_losses.append(loss_k)
# Update h_current for next depth (maintain causal chain)
h_current = torch.cat([h_k, h_current[:, current_seq_len - k:, :]], dim=1)
assert mtp_losses, "No valid MTP losses computed"
mtp_loss = self.lambda_mtp * torch.stack(mtp_losses).mean()
return mtp_loss, mtp_logits
elif mode == "inference":
# completely skip MTP path — just run the shared output head
logits = self.output_head(hidden_states) # [B, S, V]
predicted_ids = torch.argmax(logits, dim=-1) # [B, S]
return predicted_ids, logits
else:
raise ValueError(f"Invalid mode '{mode}', must be 'train' or 'inference'")
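# Minimal sketch for DeepSeek_MTP (illustrative): train mode returns the
# λ-weighted loss averaged over depths; inference mode bypasses the MTP chain
# and only runs the shared output head.
def _mtp_demo():
    cfg = Config()
    mtp = DeepSeek_MTP(cfg)
    tokens = torch.randint(0, 1000, (2, 12))
    h = torch.randn(2, 12, cfg.hidden_size)
    loss, _ = mtp(h, input_tokens=tokens, mode="train")
    ids, logits = mtp(h, mode="inference")   # ids: (B, S)
    return loss, ids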
class DeepSeek_V3_Block(nn.Module):
"""
Single-Block Transformer.
"""
def __init__(self, config):
super().__init__()
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps # Small epsilon value for numerical stability in layer normalization
# --- Layers ---
        # Input normalization (nn.LayerNorm, despite the rms_ prefix in the attribute names)
        self.rms_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
# DeepSeek_MLA
self.attention = DeepSeek_MLA(config)
# Post-attention normalization
self.rms_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
# DeepSeek_MoE
self.moe = DeepSeek_MoE(config)
# Final normalization
self.rms_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
# Linear Output
self.linear_output = nn.Linear(self.hidden_size, self.hidden_size)
    def forward(self, hidden_states, input_tokens=None, mode="train", use_cache=False, past_key_values=None, attention_mask=False):
assert hidden_states.dim() == 3, (f"hidden_states must have shape [batch, seq_len, hidden_size], got {hidden_states.shape}.")
assert hidden_states.size(-1) == self.hidden_size, (f"Last dim mismatch: expected {self.hidden_size}, got {hidden_states.size(-1)}.")
# Input normalization
normed_states = self.rms_norm1(hidden_states)
# Multi-Head Latent Attention
attn_output, kv_cache = self.attention(
hidden_states = normed_states,
input_tokens = input_tokens,
mode= mode,
use_cache= use_cache,
past_key_values=past_key_values,
attention_mask=attention_mask
)
assert attn_output.shape == hidden_states.shape, (f"attn_output shape {attn_output.shape} != hidden_states {hidden_states.shape}.")
# Residual connection
hidden_states = hidden_states + attn_output
# Post-attention normalization
normed_states = self.rms_norm2(hidden_states)
        # DeepSeekMoE (bias updates and balance loss only apply in training)
        moe_output, aux_loss = self.moe(normed_states, training=(mode == "train"))
# Residual connection
hidden_states = hidden_states + moe_output
# Final normalization
hidden_states = self.rms_norm3(hidden_states)
# Final Output
hidden_states = self.linear_output(hidden_states)
return hidden_states, kv_cache, aux_loss
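# Minimal sketch for a single DeepSeek_V3_Block (illustrative): one
# attention + MoE pass with residual connections, shape-preserving.
def _block_demo():
    cfg = Config()
    block = DeepSeek_V3_Block(cfg)
    tokens = torch.randint(0, 1000, (2, 10))
    h = torch.randn(2, 10, cfg.hidden_size)
    out, cache, aux = block(h, input_tokens=tokens, mode="train", attention_mask=True)
    assert out.shape == h.shape
    return out, aux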
class DeepSeek_V3_Encoder(nn.Module):
"""
Multi-Block Transformer.
"""
def __init__(self, config):
super().__init__()
self.num_blocks = config.num_blocks # Number of transformer blocks to stack in the model
self.hidden_size = config.hidden_size
self.layer_norm_eps = config.layer_norm_eps
self.vocab_size = config.vocab_size
# Stack of transformer blocks
self.blocks = nn.ModuleList([
DeepSeek_V3_Block(config)
for _ in range(self.num_blocks)
])
# Final normalization
self.final_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
# ---- Final output ----
self.output = nn.Linear(self.hidden_size, self.vocab_size, bias=False)
# MTP head (Multi-Token Prediction)
self.mtp = DeepSeek_MTP(config)
    def forward(self, hidden_states, input_tokens=None, mode="train", past_key_values=None, use_cache=False, attention_mask=False):
assert hidden_states.dim() == 3, (f"hidden_states must have shape [batch, seq_len, hidden_size], got {hidden_states.shape}.")
if past_key_values is None:
past_key_values = [None] * self.num_blocks
        new_past_key_values = [] if use_cache else None
        total_aux_loss = hidden_states.new_zeros(()) # accumulate the MoE balance loss over all blocks
        # Forward through stacked transformer blocks
        for i, block in enumerate(self.blocks):
            hidden_states, kv_cache, aux_loss = block(
                hidden_states=hidden_states,
                input_tokens=input_tokens,
                mode=mode,
                use_cache=use_cache,
                past_key_values=past_key_values[i],
                attention_mask=attention_mask,
            )
            total_aux_loss = total_aux_loss + aux_loss
            if use_cache:
                new_past_key_values.append(kv_cache)
# Final normalization
hidden_states = self.final_norm(hidden_states)
# Output
logits = self.output(hidden_states) # [B, S, V]
# MTP output handling
if mode == "train" and input_tokens is not None:
mtp_loss, mtp_logits = self.mtp(
hidden_states=hidden_states,
input_tokens=input_tokens,
mode = "train",
attention_mask=attention_mask
)
            return logits, mtp_loss, mtp_logits, total_aux_loss
        else: # inference (or train called without input_tokens)
predicted_ids, mtp_logits = self.mtp(
hidden_states,
input_tokens=input_tokens,
mode = "inference",
attention_mask=attention_mask
)
return predicted_ids, logits, new_past_key_values
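# Minimal sketch for the encoder (illustrative): note the two return
# signatures — (logits, mtp_loss, mtp_logits, aux_loss) in training versus
# (predicted_ids, logits, kv_caches) in inference.
def _encoder_demo():
    cfg = Config()
    enc = DeepSeek_V3_Encoder(cfg)
    tokens = torch.randint(0, 1000, (2, 10))
    h = torch.randn(2, 10, cfg.hidden_size)
    logits, mtp_loss, mtp_logits, aux = enc(h, input_tokens=tokens, mode="train", attention_mask=True)
    ids, logits, caches = enc(h, input_tokens=tokens, mode="inference", use_cache=True, attention_mask=True)
    assert len(caches) == cfg.num_blocks
    return logits, caches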
class DeepSeek_V3_Model(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
# Embedding Layer
self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
# Core model
self.model = DeepSeek_V3_Encoder(config)
# Loss functions
self.ce_loss = nn.CrossEntropyLoss(ignore_index=config.pad_token_id)
    def forward(self, input_tokens=None, mode="train", use_cache=False, past_key_values=None, attention_mask=False):
# Generate embeddings from input tokens
hidden_states = self.embedding(input_tokens)
batch_size, seq_len = input_tokens.shape
# Core model forward
outputs = self.model(
hidden_states,
mode=mode,
input_tokens=input_tokens,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache
)
if mode == "train":
logits, mtp_loss, mtp_logits, aux_loss = outputs
# Shift for next-token prediction
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = input_tokens[..., 1:].contiguous()
main_loss = self.ce_loss(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1)
)
            # Combine losses (next-token CE + MTP; the MoE balance term is already α-scaled)
            total_loss = main_loss + mtp_loss + aux_loss
return total_loss, main_loss, mtp_loss, aux_loss, logits
else: # mode == "inference"
predicted_ids, logits, new_cache = outputs
return predicted_ids, logits, new_cache
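# Minimal end-to-end training sketch (illustrative): a forward pass on random
# tokens followed by backward() to sanity-check that all three loss terms
# participate in the graph.
def _model_demo():
    model = DeepSeek_V3_Model(Config())
    tokens = torch.randint(0, 1000, (2, 10))
    total, main, mtp, aux, logits = model(input_tokens=tokens, mode="train", attention_mask=True)
    total.backward()
    return total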
# Code adapted from Sebastian Raschka
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0, top_k=20, eos_id=None, device=None):
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()
# Encode prompt
input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device)
generated = input_ids.clone()
past_key_values = None # cache for inference
for _ in range(max_length):
if past_key_values is None:
idx_cond = generated # full prompt (first step)
else:
idx_cond = generated[:, -1:] # only last token
with torch.no_grad():
# Use mode="inference" and cache past keys/values
predicted_ids, logits, past_key_values = model(
input_tokens=idx_cond.to(device),
mode="inference",
use_cache=True,
past_key_values=past_key_values,
attention_mask=True
)
logits = logits[:, -1, :] # last token logits
# Top-k filtering
if top_k is not None:
top_logits, _ = torch.topk(logits, top_k)
min_val = top_logits[:, -1].unsqueeze(-1)
logits = torch.where(logits < min_val, torch.tensor(float("-inf"), device=device), logits)
# Temperature + sampling or greedy
if temperature > 0:
logits = logits / temperature
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
else:
next_token = torch.argmax(logits, dim=-1, keepdim=True)
# Stop if EOS token generated
if eos_id is not None and next_token.item() == eos_id:
break
# Append generated token
generated = torch.cat((generated, next_token.to(device)), dim=1)
# Decode full sequence back to text
return tokenizer.decode(generated[0].tolist())
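# Usage sketch (assumptions: a trained DeepSeek_V3_Model checkpoint and the
# GPT-2 tiktoken tokenizer; the "Response:" framing matches clean_response below):
#   tokenizer = tiktoken.get_encoding("gpt2")
#   model = DeepSeek_V3_Model(Config())  # load trained weights in practice
#   raw = generate_text(model, tokenizer, "Question: ...\nResponse:",
#                       max_length=50, temperature=0.8, top_k=20,
#                       eos_id=tokenizer.eot_token)
#   print(clean_response(raw))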
def clean_response(generated_text):
if not generated_text:
return "Sorry, I couldn't generate a response."
text = str(generated_text)
# Print the prompt part
if "Response:" in text:
prompt_part = text.split("Response:", 1)[0] + "Response:"
else:
prompt_part = ""
print("=======================================")
print(f"{prompt_part.strip()}")
# Extract response
if "Response:" in text:
text = text.split("Response:", 1)[1]
# Truncate at <|endoftext|>
if "<|endoftext|>" in text:
text = text.split("<|endoftext|>", 1)[0]
# Remove non-printable characters
text = ''.join(c for c in text if c.isprintable() or c.isspace())
# If text is empty after cleaning, return a default message
if not text.strip():
return "I'm not sure how to answer that. Could you ask in a different way?"
return text.strip()
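# Example (illustrative):
#   clean_response("Question: Hi\nResponse: Hello!<|endoftext|>garbage")
# prints the prompt up to "Response:" and returns "Hello!".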
if __name__ == "__main__":
raise RuntimeError("This module is not intended to be executed directly.")