import math import torch import torch.nn as nn import torch.nn.functional as F import tiktoken __all__ = [ 'Rope', 'DeepSeek_MLA', 'DeepSeek_MoE', 'DeepSeek_MTP', 'DeepSeek_V3_Block', 'DeepSeek_V3_Encoder', 'DeepSeek_V3_Model', 'generate_text', 'clean_response' ] device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Define model configurations class Config: hidden_size = 128 # Embedding dimension (D) latent_dim = hidden_size // 2 # Latent dimension, half of D (a random choice) num_heads = 16 # Number of attention heads (should divide hidden_size) pos_dim = 24 # Positional encoding dimension pad_token_id = 50256 # Padding token ID (matches <|endoftext|> in GPT-2 vocab) num_shared_experts = 4 num_routed_experts = 8 top_k = 8 # Kr, number of experts selected per token bias_update_speed = 0.01 balance_alpha = 0.01 lambda_mtp = 0.5 # λ, weighting num_depths = 3 # D, number of prediction depths vocab_size= tiktoken.get_encoding("gpt2").n_vocab # Vocab size of tiktoken’s GPT-2 vocab (50257) layer_norm_eps = 1e-5 # Small epsilon value for numerical stability in layer normalization num_blocks = 12 # Number of transformer blocks to stack in the model batch_size = 64 # Number of sequences per batch context_length = 60 # Number of tokens per sequence class Rope(nn.Module): """ Rotary Position Embedding (RoPE) module. Applies rotary position encoding to an input tensor of shape (B, H, S, D), """ def __init__(self, dim, max_seq_len = 4096): super().__init__() # Safety check: RoPE requires even dimensionality (for splitting into pairs) assert dim % 2 == 0, f"RoPE dim must be even, got {dim}" self.dim = dim self.max_seq_len = max_seq_len # Step 2: Compute rotation frequencies for sinusoidal positions # inv_freq[i] = 1 / (10000^(2i/dim)), where i = 0, 1, ..., dim/2 - 1 inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) # Store as non-trainable buffer self.register_buffer("inv_freq", inv_freq, persistent=False) # Precompute and cache cos/sin values up to max_seq_len self._build_cache(max_seq_len) def _build_cache(self, seq_len): """ Precompute cosine and sine embeddings for all positions up to seq_len. This avoids recomputing trig functions during every forward pass. """ # Positions: [0, 1, 2, ..., seq_len-1] t = torch.arange(seq_len, dtype=self.inv_freq.dtype, device=self.inv_freq.device) # Step 3: Compute rotation angles (per position and dimension pair) # Each row is t * inv_freq[i], giving angular frequency per dimension freqs = torch.outer(t, self.inv_freq) # Duplicate for concatenation of sin and cos values, shape: (seq_len, dim) emb = torch.cat((freqs, freqs), dim=-1) # Step 4: Construct rotation matrix elements (cos and sin) # Register as buffers with shape (1, 1, seq_len, dim) self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False) self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False) # Track how many positions we have cached self.max_seq_len = seq_len def forward(self, x, seq_len, position_offset = 0): """ Apply RoPE to input tensor. Args: x: Input tensor of shape (B, H, S, D) seq_len: Actual sequence length to encode position_offset: Offset for decoding continuation (default = 0) Returns: Tensor with RoPE applied, same shape as x. """ device = x.device # Ensure input matches expected dimensionality assert x.shape[-1] == self.dim, ( f"RoPE input dim mismatch: expected {self.dim}, got {x.shape[-1]}" ) seq_len_x = x.size(-2) # sequence length from input tensor if (position_offset + seq_len) > self.max_seq_len: # Rebuild cache with doubled size for efficiency self._build_cache(max(position_offset + seq_len, self.max_seq_len * 2)) # Select only the needed positions cos = self.cos_cached[:, :, position_offset:position_offset + seq_len, :].to(device) sin = self.sin_cached[:, :, position_offset:position_offset + seq_len, :].to(device) # Ensure cache slice matches actual input sequence length assert cos.shape[2] == seq_len_x, ( f"RoPE seq_len mismatch: expected {seq_len_x}, got {cos.shape[2]}" ) # Step 1: Split Q/K into 2D subspaces (pairs of dimensions) # Split last dimension into pairs: (x1, x2) x1, x2 = x.chunk(2, dim=-1) # Step 5 and Step 6: Apply rotation to each 2D subspace # Rotate pairs: (x1, x2) → (-x2, x1) rotated = torch.cat((-x2, x1), dim=-1) # Apply rotary transformation: elementwise (x*cos + rotated*sin) result = x*cos + rotated*sin return result class DeepSeek_MLA(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size # Embedding dimension self.num_heads = config.num_heads self.head_dim = self.hidden_size // self.num_heads self.latent_dim = config.latent_dim self.pos_dim = config.pos_dim self.max_seq_len = getattr(config, 'max_seq_len', 512) # Add max sequence length self.pad_token_id = getattr(config, 'pad_token_id', 50256) # Default pad token ID to 50256 of tiktoken’s GPT-2 vocab, same as <|endoftext|> in the tiktoken’s GPT-2 vocab assert self.hidden_size % self.num_heads == 0, f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({self.num_heads})" # Ensure pos_dim is even for RoPE assert self.pos_dim % 2 == 0, f"pos_dim ({self.pos_dim}) must be even for RoPE" # Latent compression projections self.W_DKV = nn.Linear(self.hidden_size, self.latent_dim, bias=False) # KV compression self.W_DQ = nn.Linear(self.hidden_size, self.latent_dim, bias=False) # Q compression # Content projection from latent to multi-head space self.W_UK = nn.Linear(self.latent_dim, self.hidden_size, bias=False) # K content self.W_UV = nn.Linear(self.latent_dim, self.hidden_size, bias=False) # V content self.W_UQ = nn.Linear(self.latent_dim, self.hidden_size, bias=False) # Q content # Positional projections (RoPE pathway) self.W_KR = nn.Linear(self.hidden_size, self.pos_dim, bias=False) # K positional self.W_QR = nn.Linear(self.latent_dim, self.num_heads * self.pos_dim, bias=False) # Q positional # Output projection self.W_O = nn.Linear(self.hidden_size, self.hidden_size, bias=False) # RoPE initialization self.rope_k = Rope(self.pos_dim) self.rope_q = Rope(self.pos_dim) # ---- Precomputed causal mask ---- # Create upper triangular mask with ones above diagonal and convert to boolean self.register_buffer("causal_mask", torch.triu(torch.ones(self.max_seq_len, self.max_seq_len), diagonal=1).bool()) def forward(self, hidden_states, input_tokens=None, mode="train", use_cache=False, past_key_values=None, attention_mask=False): batch_size, seq_len, hidden_size = hidden_states.shape assert hidden_size == self.hidden_size, f"hidden_size mismatch: got {hidden_size}, expected {self.hidden_size}" # ---- Latent compressions ---- c_KV = self.W_DKV(hidden_states) # (batch_size, seq_len, latent_dim) c_Q = self.W_DQ(hidden_states) # (batch_size, seq_len, latent_dim) # ---- Content projections (per-head) ---- k_C = self.W_UK(c_KV).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, H, seq_len, head_dim) v_C = self.W_UV(c_KV).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, H, seq_len, head_dim) q_C = self.W_UQ(c_Q).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (batch_size, H, seq_len, head_dim) # ---- Positional projections ---- k_R = self.W_KR(hidden_states) # (batch_size, seq_len, pos_dim) q_R = self.W_QR(c_Q).view(batch_size, seq_len, self.num_heads, self.pos_dim).transpose(1, 2) # (batch_size, H, seq_len, pos_dim) # ---- Determine past length for RoPE position_offset ---- past_len = 0 if past_key_values is None else past_key_values[0].size(2) # ---- Apply RoPE (position offset = past_len) ---- k_R = self.rope_k(k_R.unsqueeze(1).expand(-1, self.num_heads, -1, -1), seq_len=seq_len, position_offset=past_len) # (batch_size, H, seq_len, pos_dim) q_R = self.rope_q(q_R, seq_len=seq_len, position_offset=past_len) # (batch_size, H, seq_len, pos_dim) ######### TRAINING MODE ######### if mode == "train": k = torch.cat([k_C, k_R], dim=-1) # (batch_size, H, seq_len, head_dim + pos_dim) q = torch.cat([q_C, q_R], dim=-1) # (batch_size, H, seq_len, head_dim + pos_dim) scale = 1.0 / math.sqrt(q.shape[-1]) # same as scale = 1.0 / math.sqrt(head_dim + pos_dim) attn_scores = torch.matmul(q, k.transpose(-2, -1)) * scale # (batch_size, H, seq_len, seq_len) # ---- Apply mask (causal + padding) ---- if attention_mask: # Mask truncated to the number of tokens and converted to boolean mask_bool = self.causal_mask[:seq_len, :seq_len] # Convert boolean mask to -inf format for attention causal_mask = mask_bool.float().masked_fill(mask_bool, float('-inf')) # Create padding mask from hidden states padding_mask = (input_tokens == self.pad_token_id) #.all(dim=-1) # (B, S) - True where all features are 50256 # Expand padding mask to match attention scores shape padding_mask = padding_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1, S) padding_mask = padding_mask.expand(-1, self.num_heads, seq_len, -1) # (B, H, S, S) padding_mask = padding_mask.float().masked_fill(padding_mask, float('-inf')) # Combine causal and padding masks full_mask = causal_mask.unsqueeze(0).unsqueeze(0) + padding_mask attn_scores = attn_scores + full_mask attn_probs = F.softmax(attn_scores, dim=-1) o_heads = torch.matmul(attn_probs, v_C) # (batch_size, H, seq_len, head_dim) kv_cache = None # training returns no cache ######### INFERENCE MODE ######### elif mode == "inference": # Concatenate past and current per-head keys/values/pos if provided if past_key_values is None: k_C_total = k_C # (batch_size, H, seq_len, head_dim) v_C_total = v_C k_R_total = k_R q_R_total = q_R c_KV_total = c_KV.unsqueeze(1).expand(-1, self.num_heads, -1, -1) # (batch_size, H, seq_len, latent_dim) total_len = seq_len else: # past_key_values: (past_k_cache, past_v_cache, past_kR_cache, past_qR_cache) past_k_cache, past_v_cache, past_k_R_cache, past_q_R_cache, past_c_KV_total = past_key_values # Append along sequence dim (dim=2 for per-head) k_C_total = torch.cat([past_k_cache, k_C], dim=2) # (batch_size, H, past_len+seq_len, head_dim) v_C_total = torch.cat([past_v_cache, v_C], dim=2) k_R_total = torch.cat([past_k_R_cache, k_R], dim=2) # (batch_size, H, past_len+seq_len, pos_dim) q_R_total = torch.cat([past_q_R_cache, q_R], dim=2) c_KV_total = torch.cat([past_c_KV_total, c_KV.unsqueeze(1).expand(-1, self.num_heads, -1, -1)], dim=2) # (batch_size, H, total_len, latent_dim) total_len = k_C_total.size(2) # q_latent computation W_UK_heads = self.W_UK.weight.view(self.num_heads, self.head_dim, self.latent_dim) q_latent = torch.matmul(q_C, W_UK_heads) # (batch, heads, seq_len, latent_dim) k_hat = torch.cat([c_KV_total, k_R_total], dim=-1) q_hat = torch.cat([q_latent, q_R], dim=-1) # (batch_size, H, seq_len, head_dim+pos_dim) # Attention scale = 1.0 / math.sqrt(k_hat.shape[-1]) attn_scores = torch.matmul(q_hat, k_hat.transpose(-2, -1)) * scale # (batch_size, H, seq_len, total_len) # ---- Apply mask (causal + padding, cache-aware) ---- if attention_mask: mask_bool = self.causal_mask[:total_len, :total_len] causal_mask_base = mask_bool.float().masked_fill(mask_bool, float('-inf')) offset = total_len - seq_len causal_mask = causal_mask_base[offset:offset+seq_len, :total_len].unsqueeze(0).unsqueeze(0) causal_mask = causal_mask.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, total_len) # Create padding mask from hidden states padding_mask = (input_tokens == self.pad_token_id) #.all(dim=-1) # (B, S) - True where all features are 50256 # For inference with cache, we need to handle the full sequence length padding_mask_full = torch.zeros(batch_size, total_len, device=hidden_states.device, dtype=torch.bool) padding_mask_full[:, -seq_len:] = padding_mask # Only the current tokens have padding padding_mask_expanded = padding_mask_full.unsqueeze(1).unsqueeze(2).expand(-1, self.num_heads, seq_len, -1) padding_mask_expanded = padding_mask_expanded.float().masked_fill(padding_mask_expanded, float('-inf')) full_mask = causal_mask + padding_mask_expanded attn_scores = attn_scores + full_mask attn_probs = F.softmax(attn_scores, dim=-1) o_hat = torch.matmul(attn_probs, c_KV_total) # (batch_size, H, seq_len, latent_dim) # 2. Apply per-head W_UV projection (Absorb step) W_UV_heads = self.W_UV.weight.view(self.num_heads, self.head_dim, self.latent_dim) # [H, head_dim, latent_dim] o_heads = torch.matmul(o_hat, W_UV_heads.transpose(1, 2)) # [batch_size, H, seq_len, head] # Prepare kv_cache tuple to return (present caches covering full sequence) if use_cache: kv_cache = ( k_C_total.detach(), v_C_total.detach(), k_R_total.detach(), q_R_total.detach(), c_KV_total.detach() ) else: kv_cache = None else: raise ValueError("mode must be 'train' or 'inference'") # ---- Final projection ---- o = o_heads.transpose(1, 2).reshape(batch_size, seq_len, self.num_heads * self.head_dim) # (batch_size, seq_len, hidden_size) attn_output = self.W_O(o) # (batch_size, seq_len, hidden_size) return attn_output, kv_cache class DeepSeek_MoE(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size # Embedding dimension self.latent_dim = config.latent_dim self.num_shared_experts = config.num_shared_experts self.num_routed_experts = config.num_routed_experts self.top_k = config.top_k # Kr self.bias_update_speed = config.bias_update_speed self.balance_alpha = config.balance_alpha assert self.top_k <= self.num_routed_experts, f"top_k: ({self.top_k}) exceeds available experts: ({self.num_routed_experts})" # Expert centroids for affinity scores self.expert_centroids = nn.Parameter( torch.empty(self.num_routed_experts, self.hidden_size) ) # Bias terms for load balancing self.register_buffer("expert_biases", torch.zeros(self.num_routed_experts)) # Shared experts self.shared_experts = nn.ModuleList([ nn.Sequential( nn.Linear(self.hidden_size, self.latent_dim), nn.SiLU(), nn.Linear(self.latent_dim, self.hidden_size) ) for _ in range(self.num_shared_experts) ]) # Routed experts self.routed_experts = nn.ModuleList([ nn.Sequential( nn.Linear(self.hidden_size, self.latent_dim), nn.SiLU(), nn.Linear(self.latent_dim, self.hidden_size) ) for _ in range(self.num_routed_experts) ]) # Initialize centroids nn.init.xavier_uniform_(self.expert_centroids) def forward(self, hidden_states, training=True): batch_size, seq_len, hidden_dim = hidden_states.shape assert hidden_dim == self.hidden_size, f"Input hidden size mismatch: got {hidden_dim}, expected {self.hidden_size}." total_tokens = batch_size * seq_len # ========== Compute affinity scores ========== # Equation: s_i,t = Sigmoid(u_t^T e_i) flat_input = hidden_states.view(-1, hidden_dim) affinity_scores = torch.sigmoid( F.linear(flat_input, self.expert_centroids) # u_t^T e_i ).view(batch_size, seq_len, self.num_routed_experts) # ========== Top-K routing with bias ========== # Equation: Use biased scores s_i,t + b_i for routing selection biased_scores = affinity_scores + self.expert_biases # Get top-K experts using biased scores topk_values, topk_indices = torch.topk(biased_scores, self.top_k, dim=-1) # Create mask for selected experts expert_mask = torch.zeros_like(affinity_scores) expert_mask.scatter_(-1, topk_indices, 1.0) # ========== Compute gating values ========== # Equation: g'_i,t = s_i,t if selected, 0 otherwise selected_scores = affinity_scores * expert_mask # Equation: g_i,t = g'_i,t / sum_j(g'_j,t) - normalization gating_values = selected_scores / (selected_scores.sum(dim=-1, keepdim=True) + 1e-8) # ========== Shared experts computation ========== # Equation: ∑_{i=1}^{N_s} FFN_i^{(s)}(u_t) shared_output = sum(expert(hidden_states) for expert in self.shared_experts) # ========== Routed experts computation ========== # Equation: ∑_{i=1}^{N_r} g_i,t FFN_i^{(r)}(u_t) flat_gating = gating_values.view(-1, self.num_routed_experts) flat_indices = topk_indices.view(-1, self.top_k) # Precompute all expert outputs: FFN_i^{(r)}(u_t) for all experts all_expert_outputs = torch.stack([ expert(flat_input) for expert in self.routed_experts ], dim=1) # [total_tokens, num_routed_experts, hidden_size] # Gather outputs for selected experts and apply gating expanded_indices = flat_indices.unsqueeze(-1).expand(-1, -1, hidden_dim) selected_outputs = all_expert_outputs.gather(1, expanded_indices) # Get FFN outputs for top-k experts gating_weights = flat_gating.gather(1, flat_indices).unsqueeze(-1) # Get g_i,t for selected experts routed_output_flat = (selected_outputs * gating_weights).sum(dim=1) # ∑ g_i,t * FFN_i^{(r)}(u_t) routed_output = routed_output_flat.view(batch_size, seq_len, hidden_dim) # ========== Load balancing updates ========== aux_loss = torch.tensor(0.0, device=hidden_states.device) if training: # ========== Bias Update ========== # Count how many times each expert is selected (or the number of tokens routed to that expert) expert_counts = torch.bincount( topk_indices.view(-1), minlength=self.num_routed_experts ).float() expert_loads = expert_counts / total_tokens # Load proportion for each expert target_load = torch.ones_like(expert_loads) / self.num_routed_experts # Ideal balanced load load_diff = expert_loads - target_load # Positive = overloaded, Negative = underloaded # Update: decrease bias for overloaded experts, increase for underloaded self.expert_biases -= self.bias_update_speed * load_diff # ========== Sequence-wise Auxiliary Loss ========== # Equation: f_i = (N_r / (K_r * T)) * ∑_t 𝟙(s_i,t ∈ TopK) f_i = expert_mask.view(-1, self.num_routed_experts).sum(dim=0) # Count selections per expert f_i = f_i * (self.num_routed_experts / (self.top_k * seq_len)) # Normalize by sequence length f_i = f_i / batch_size # Average over batch # Equation: P_i = (1/T) ∑_t s'_i,t where s'_i,t = s_i,t / ∑_j s_j,t s_prime = affinity_scores / (affinity_scores.sum(dim=-1, keepdim=True) + 1e-8) # Normalized affinities P_i = s_prime.view(-1, self.num_routed_experts).mean(dim=0) # Average over all tokens # Equation: ℒ_Bal = α * ∑_{i=1}^{N_r} f_i * P_i aux_loss = self.balance_alpha * (f_i * P_i).sum() # ========== Final output ========== # Equation: O_t = X_t + shared_experts + routed_experts output = hidden_states + shared_output + routed_output return output, aux_loss class DeepSeek_MTP(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size # Embedding dimension self.vocab_size = config.vocab_size self.num_depths = config.num_depths # D (Please note that this D is different from Embedding dimension D; feel free to replace it with another notation) self.lambda_mtp = config.lambda_mtp # λ self.max_seq_len = getattr(config, 'max_seq_len', 512) # Add max sequence length self.pad_token_id = getattr(config, 'pad_token_id', 50256) # Default pad token ID to 50256 of tiktoken’s GPT-2 vocab, same as <|endoftext|> in the tiktoken’s GPT-2 vocab assert self.hidden_size % config.num_heads == 0,f"hidden_size ({self.hidden_size}) must be divisible by num_heads ({config.num_heads})" # ===== Shared layers ===== self.embedding = nn.Embedding(self.vocab_size, self.hidden_size) # shared Emb(·) self.output_head = nn.Linear(self.hidden_size, self.vocab_size) # shared OutHead(·) # ---- Create D Transformer blocks TRM_k ---- self.trm_blocks = nn.ModuleList([ nn.TransformerEncoderLayer( d_model=self.hidden_size, nhead=config.num_heads, dim_feedforward=config.latent_dim, activation="gelu", batch_first=True, ) for _ in range(self.num_depths) ]) # ---- Projection matrices M_k ∈ ℝ^{d×2d} ---- self.proj_matrices = nn.ParameterList([ nn.Parameter(torch.randn(self.hidden_size, 2 * self.hidden_size)) for _ in range(self.num_depths) ]) # ---- RMSNorm layers ---- self.rmsnorm_h = nn.RMSNorm(self.hidden_size) self.rmsnorm_e = nn.RMSNorm(self.hidden_size) # ---- Precomputed causal mask ---- # Create upper triangular mask with ones above diagonal and convert to boolean #self.register_buffer("causal_mask", torch.triu(torch.ones(self.max_seq_len, self.max_seq_len), diagonal=1).bool()) self.register_buffer("causal_mask", torch.triu(torch.ones(self.max_seq_len, self.max_seq_len, device=device)).bool()) def forward(self, hidden_states, input_tokens=None, mode="train", attention_mask=True): batch_size, seq_len, hidden_size = hidden_states.shape assert hidden_size == self.hidden_size, f"hidden_states last dim {hidden_size} != expected hidden_size {self.hidden_size}" if mode == "train": assert input_tokens is not None, "input_tokens required in training mode" assert input_tokens.shape == (batch_size, seq_len), f"input_tokens {(input_tokens.shape)} must match batch & seq length of hidden_states= {[batch_size, seq_len]}" mtp_losses = [] # Use separate variable to prevent in-place overwriting h_current = hidden_states # ===== MTP depths loop ===== for k in range(1, self.num_depths + 1): current_seq_len = h_current.shape[1] # Use current sequence length if current_seq_len - k <= 0: break # nothing left to predict # ---- h'_i^k = M_k [RMSNorm(h_i^{k−1}); RMSNorm(Emb(t_{i+k}))] ---- h_prev = h_current[:, :current_seq_len - k, :] # h_i^{k−1} emb_shifted = self.embedding(input_tokens[:, k:]) # Emb(t_{i+k}) h_prev_norm = self.rmsnorm_h(h_prev) emb_norm = self.rmsnorm_e(emb_shifted) concat = torch.cat([h_prev_norm, emb_norm], dim=-1) # concat [h; e] h_prime_k = torch.matmul(concat, self.proj_matrices[k - 1].T) # ---- causal + padding attention mask ---- causal_mask = None padding_mask = None if attention_mask: # Get the actual sequence length for this depth L = current_seq_len - k # Original mask truncated to the number of tokens and converted to boolean causal_mask = self.causal_mask[:L, :L] # Create padding mask from input tokens (also boolean) padding_mask = (input_tokens[:, k:current_seq_len] == self.pad_token_id) # (B, L) # ---- Transformer block TRM_k(h'_i^k) ---- h_k = self.trm_blocks[k - 1](h_prime_k, src_mask=causal_mask, src_key_padding_mask=padding_mask) # ---- logits = OutHead(h_i^k) ---- mtp_logits = self.output_head(h_k) # ---- Cross-entropy loss ---- target_k = input_tokens[:, k:current_seq_len] # shift targets by +k, match current length loss_k = F.cross_entropy( mtp_logits.reshape(-1, self.vocab_size), target_k.reshape(-1), reduction="mean", ignore_index=self.pad_token_id ) mtp_losses.append(loss_k) # Update h_current for next depth (maintain causal chain) h_current = torch.cat([h_k, h_current[:, current_seq_len - k:, :]], dim=1) assert mtp_losses, "No valid MTP losses computed" mtp_loss = self.lambda_mtp * torch.stack(mtp_losses).mean() return mtp_loss, mtp_logits elif mode == "inference": # completely skip MTP path — just run the shared output head logits = self.output_head(hidden_states) # [B, S, V] predicted_ids = torch.argmax(logits, dim=-1) # [B, S] return predicted_ids, logits else: raise ValueError(f"Invalid mode '{mode}', must be 'train' or 'inference'") class DeepSeek_V3_Block(nn.Module): """ Single-Block Transformer. """ def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.layer_norm_eps = config.layer_norm_eps # Small epsilon value for numerical stability in layer normalization # --- Layers --- # Input normalization self.rms_norm1 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) # DeepSeek_MLA self.attention = DeepSeek_MLA(config) # Post-attention normalization self.rms_norm2 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) # DeepSeek_MoE self.moe = DeepSeek_MoE(config) # Final normalization self.rms_norm3 = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) # Linear Output self.linear_output = nn.Linear(self.hidden_size, self.hidden_size) def forward(self, hidden_states,input_tokens=None, mode="train", use_cache=False, past_key_values=None, attention_mask=False): assert hidden_states.dim() == 3, (f"hidden_states must have shape [batch, seq_len, hidden_size], got {hidden_states.shape}.") assert hidden_states.size(-1) == self.hidden_size, (f"Last dim mismatch: expected {self.hidden_size}, got {hidden_states.size(-1)}.") # Input normalization normed_states = self.rms_norm1(hidden_states) # Multi-Head Latent Attention attn_output, kv_cache = self.attention( hidden_states = normed_states, input_tokens = input_tokens, mode= mode, use_cache= use_cache, past_key_values=past_key_values, attention_mask=attention_mask ) assert attn_output.shape == hidden_states.shape, (f"attn_output shape {attn_output.shape} != hidden_states {hidden_states.shape}.") # Residual connection hidden_states = hidden_states + attn_output # Post-attention normalization normed_states = self.rms_norm2(hidden_states) # DeepSeekMoE moe_output, aux_loss = self.moe(normed_states) # Residual connection hidden_states = hidden_states + moe_output # Final normalization hidden_states = self.rms_norm3(hidden_states) # Final Output hidden_states = self.linear_output(hidden_states) return hidden_states, kv_cache, aux_loss class DeepSeek_V3_Encoder(nn.Module): """ Multi-Block Transformer. """ def __init__(self, config): super().__init__() self.num_blocks = config.num_blocks # Number of transformer blocks to stack in the model self.hidden_size = config.hidden_size self.layer_norm_eps = config.layer_norm_eps self.vocab_size = config.vocab_size # Stack of transformer blocks self.blocks = nn.ModuleList([ DeepSeek_V3_Block(config) for _ in range(self.num_blocks) ]) # Final normalization self.final_norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) # ---- Final output ---- self.output = nn.Linear(self.hidden_size, self.vocab_size, bias=False) # MTP head (Multi-Token Prediction) self.mtp = DeepSeek_MTP(config) def forward(self,hidden_states, input_tokens=None, mode="train", past_key_values=None, use_cache=False, attention_mask=False): assert hidden_states.dim() == 3, (f"hidden_states must have shape [batch, seq_len, hidden_size], got {hidden_states.shape}.") if past_key_values is None: past_key_values = [None] * self.num_blocks new_past_key_values = [] if use_cache else None # Forward through stacked transformer blocks for i, block in enumerate(self.blocks): hidden_states, kv_cache, aux_loss = block( hidden_states=hidden_states, input_tokens=input_tokens, mode=mode, use_cache=use_cache, past_key_values=past_key_values[i], attention_mask=attention_mask, ) if use_cache: new_past_key_values.append(kv_cache) # Final normalization hidden_states = self.final_norm(hidden_states) # Output logits = self.output(hidden_states) # [B, S, V] # MTP output handling if mode == "train" and input_tokens is not None: mtp_loss, mtp_logits = self.mtp( hidden_states=hidden_states, input_tokens=input_tokens, mode = "train", attention_mask=attention_mask ) return logits, mtp_loss, mtp_logits, aux_loss else: # mode == "inference" predicted_ids, mtp_logits = self.mtp( hidden_states, input_tokens=input_tokens, mode = "inference", attention_mask=attention_mask ) return predicted_ids, logits, new_past_key_values class DeepSeek_V3_Model(nn.Module): def __init__(self, config): super().__init__() self.config = config # Embedding Layer self.embedding = nn.Embedding(config.vocab_size, config.hidden_size) # Core model self.model = DeepSeek_V3_Encoder(config) # Loss functions self.ce_loss = nn.CrossEntropyLoss(ignore_index=config.pad_token_id) def forward(self, input_tokens = None, mode="train", use_cache=False, past_key_values=None,attention_mask=False): # Generate embeddings from input tokens hidden_states = self.embedding(input_tokens) batch_size, seq_len = input_tokens.shape # Core model forward outputs = self.model( hidden_states, mode=mode, input_tokens=input_tokens, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache ) if mode == "train": logits, mtp_loss, mtp_logits, aux_loss = outputs # Shift for next-token prediction shift_logits = logits[..., :-1, :].contiguous() shift_labels = input_tokens[..., 1:].contiguous() main_loss = self.ce_loss( shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1) ) # Combine losses total_loss = main_loss + mtp_loss return total_loss, main_loss, mtp_loss, aux_loss, logits else: # mode == "inference" predicted_ids, logits, new_cache = outputs return predicted_ids, logits, new_cache # Code adapted from Sebastian Raschka def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0, top_k=20, eos_id=None, device=None): if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) model.eval() # Encode prompt input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long, device=device) generated = input_ids.clone() past_key_values = None # cache for inference for _ in range(max_length): if past_key_values is None: idx_cond = generated # full prompt (first step) else: idx_cond = generated[:, -1:] # only last token with torch.no_grad(): # Use mode="inference" and cache past keys/values predicted_ids, logits, past_key_values = model( input_tokens=idx_cond.to(device), mode="inference", use_cache=True, past_key_values=past_key_values, attention_mask=True ) logits = logits[:, -1, :] # last token logits # Top-k filtering if top_k is not None: top_logits, _ = torch.topk(logits, top_k) min_val = top_logits[:, -1].unsqueeze(-1) logits = torch.where(logits < min_val, torch.tensor(float("-inf"), device=device), logits) # Temperature + sampling or greedy if temperature > 0: logits = logits / temperature probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) else: next_token = torch.argmax(logits, dim=-1, keepdim=True) # Stop if EOS token generated if eos_id is not None and next_token.item() == eos_id: break # Append generated token generated = torch.cat((generated, next_token.to(device)), dim=1) # Decode full sequence back to text return tokenizer.decode(generated[0].tolist()) def clean_response(generated_text): if not generated_text: return "Sorry, I couldn't generate a response." text = str(generated_text) # Print the prompt part if "Response:" in text: prompt_part = text.split("Response:", 1)[0] + "Response:" else: prompt_part = "" print("=======================================") print(f"{prompt_part.strip()}") # Extract response if "Response:" in text: text = text.split("Response:", 1)[1] # Truncate at <|endoftext|> if "<|endoftext|>" in text: text = text.split("<|endoftext|>", 1)[0] # Remove non-printable characters text = ''.join(c for c in text if c.isprintable() or c.isspace()) # If text is empty after cleaning, return a default message if not text.strip(): return "I'm not sure how to answer that. Could you ask in a different way?" return text.strip() if __name__ == "__main__": raise RuntimeError("This module is not intended to be executed directly.")