import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Tuple, List


@dataclass
class DeepSeekConfig:
    """
    Configuration for DeepSeek model (scaled down to ~135M params).
    """
    block_size: int = 2048
    vocab_size: int = 49152
    n_layer: int = 12              # 12 layers (pruned from 30)
    n_head: int = 9                # Trained with 9
    n_embd: int = 576              # Trained with 576
    n_kv_head: int = 3             # Trained with 3
    intermediate_size: int = 1536  # Trained with 1536
    rms_norm_eps: float = 1e-5
    rope_theta: float = 100000.0
    # MLHA params
    q_lora_rank: int = 192         # Trained with 192
    kv_lora_rank: int = 128        # Trained with 128
    # MoE params
    n_routed_experts: int = 8      # Trained with 8
    n_shared_experts: int = 2      # Trained with 2
    n_activated_experts: int = 2
    moe_intermediate_size: int = 1536  # Trained with 1536 (fixed mismatch)


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
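
# For reference (note added here, not part of the original file): RMSNorm computes
#     y = x / sqrt(mean(x^2, dim=-1) + eps) * weight
# i.e. LayerNorm without mean-centering or a bias term, as used in
# Llama/DeepSeek-style models; the norm is computed in float32 and cast back.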


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """Precompute complex exponentials for Rotary Positional Embeddings (RoPE)."""
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reshape frequency tensor for broadcasting with input tensor."""
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply Rotary Positional Embeddings to query and key tensors."""
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)
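
# Shape example (sketch added here, using the config defaults above):
# with head_dim = n_embd // n_head = 64 and block_size * 2 = 4096 positions,
# precompute_freqs_cis(64, 4096) returns a complex tensor of shape (4096, 32).
# apply_rotary_emb takes xq of shape (B, T, n_head, 64) and xk of shape
# (B, T, n_kv_head, 64), rotates consecutive dimension pairs in the complex
# plane, and returns tensors with the same shapes and dtypes as its inputs.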


class MultiHeadLatentAttention(nn.Module):
    """
    Multi-Head Latent Attention (MLHA) from DeepSeek.
    Key innovation: low-rank compression of the KV cache to reduce memory.
    - Queries: compressed via low-rank projection (q_lora_rank)
    - Keys/Values: compressed via low-rank projection (kv_lora_rank)
    - Significantly reduces KV cache size during inference
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.n_rep = self.n_head // self.n_kv_head
        # MLHA: low-rank compression
        self.kv_lora_rank = config.kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        # Query projection with low-rank compression:
        # q = W_q_up @ (W_q_down @ x)
        self.q_down_proj = nn.Linear(config.n_embd, config.q_lora_rank, bias=False)
        self.q_up_proj = nn.Linear(config.q_lora_rank, config.n_head * self.head_dim, bias=False)
        # KV projection with low-rank compression
        self.kv_down_proj = nn.Linear(config.n_embd, config.kv_lora_rank, bias=False)
        self.kv_up_proj = nn.Linear(config.kv_lora_rank, 2 * config.n_kv_head * self.head_dim, bias=False)
        # Output projection
        self.o_proj = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()
        # MLHA: low-rank query projection
        q_compressed = self.q_down_proj(x)    # (B, T, q_lora_rank)
        xq = self.q_up_proj(q_compressed)     # (B, T, n_head * head_dim)
        # MLHA: low-rank KV projection
        kv_compressed = self.kv_down_proj(x)  # (B, T, kv_lora_rank)
        kv = self.kv_up_proj(kv_compressed)   # (B, T, 2 * n_kv_head * head_dim)
        # Split KV
        xk, xv = kv.chunk(2, dim=-1)
        # Reshape for multi-head attention
        xq = xq.view(B, T, self.n_head, self.head_dim)
        xk = xk.view(B, T, self.n_kv_head, self.head_dim)
        xv = xv.view(B, T, self.n_kv_head, self.head_dim)
        # Apply RoPE
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
        # GQA: repeat KV heads to match query heads
        if self.n_rep > 1:
            xk = torch.repeat_interleave(xk, self.n_rep, dim=2)
            xv = torch.repeat_interleave(xv, self.n_rep, dim=2)
        # Transpose for attention: (B, n_head, T, head_dim)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)
        # Flash Attention (via PyTorch scaled_dot_product_attention)
        output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)
        # Reshape and project
        output = output.transpose(1, 2).contiguous().view(B, T, C)
        return self.o_proj(output)
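
# Sizing note (sketch added here, using the config defaults above): the per-token
# KV latent produced by kv_down_proj has kv_lora_rank = 128 dims, versus
# 2 * n_kv_head * head_dim = 2 * 3 * 64 = 384 dims for the uncompressed K/V pair,
# i.e. roughly a 3x reduction if the latent were cached at inference time.
# This forward pass is training-oriented and does not maintain a KV cache.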


class Expert(nn.Module):
    """Single expert in the MoE layer."""

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.gate_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=False)
        self.up_proj = nn.Linear(config.n_embd, config.intermediate_size, bias=False)
        self.down_proj = nn.Linear(config.intermediate_size, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU activation
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class MoELayer(nn.Module):
    """
    Mixture of Experts (MoE) layer.
    Key features:
    - Shared experts: always activated for all tokens
    - Routed experts: top-k selection per token
    - No auxiliary balancing loss; router logits are returned so expert load
      can be monitored externally
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.n_routed_experts = config.n_routed_experts
        self.n_shared_experts = config.n_shared_experts
        self.n_activated_experts = config.n_activated_experts
        self.n_embd = config.n_embd
        # Router: maps input to expert scores
        self.router = nn.Linear(config.n_embd, config.n_routed_experts, bias=False)
        # Routed experts
        self.routed_experts = nn.ModuleList([
            Expert(config) for _ in range(config.n_routed_experts)
        ])
        # Shared experts (always active)
        self.shared_experts = nn.ModuleList([
            Expert(config) for _ in range(config.n_shared_experts)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Memory-optimized forward pass.
        Returns:
            output: MoE output
            router_logits: for monitoring load balance (not used in the loss)
        """
        B, T, C = x.size()
        x_flat = x.view(-1, C)  # (B*T, C)
        # 1. Shared experts (always active) - memory efficient
        if self.n_shared_experts > 0:
            shared_output = torch.zeros_like(x_flat)
            for expert in self.shared_experts:
                shared_output.add_(expert(x_flat))
            shared_output.div_(self.n_shared_experts)
        else:
            shared_output = torch.zeros_like(x_flat)
        # 2. Routed experts (top-k selection) - optimized routing
        router_logits = self.router(x_flat)  # (B*T, n_routed_experts)
        routing_weights = F.softmax(router_logits, dim=-1)
        # Select top-k experts
        top_k_weights, top_k_indices = torch.topk(
            routing_weights,
            k=self.n_activated_experts,
            dim=-1
        )  # (B*T, k)
        # Normalize top-k weights
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8)
        # Memory-efficient expert routing
        routed_output = torch.zeros_like(x_flat)
        # Process each expert efficiently
        for expert_id in range(self.n_routed_experts):
            # Find tokens routed to this expert
            expert_mask = (top_k_indices == expert_id).any(dim=1)
            if expert_mask.any():
                # Get indices and weights for this expert
                token_indices = expert_mask.nonzero(as_tuple=True)[0]
                expert_input = x_flat[token_indices]
                # Compute expert output
                expert_out = self.routed_experts[expert_id](expert_input)
                # Get weights for these tokens
                weights = torch.zeros(token_indices.size(0), 1, device=x_flat.device)
                for k in range(self.n_activated_experts):
                    mask = (top_k_indices[token_indices, k] == expert_id)
                    weights[mask] = top_k_weights[token_indices[mask], k:k+1]
                # Add weighted output
                routed_output[token_indices] += weights * expert_out
        # Combine shared and routed outputs
        output = shared_output + routed_output
        output = output.view(B, T, C)
        return output, router_logits
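

# Monitoring helper (sketch added here, not part of the original model): one way to
# turn the router logits returned by MoELayer.forward into a per-expert load estimate.
# The function name and the counting scheme are illustrative assumptions.
def router_load_fractions(router_logits: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Fraction of top-k routing slots assigned to each expert, shape (n_experts,)."""
    n_experts = router_logits.size(-1)
    top_k_indices = torch.topk(router_logits, k=k, dim=-1).indices           # (N, k)
    counts = torch.bincount(top_k_indices.reshape(-1), minlength=n_experts)  # (n_experts,)
    return counts.float() / counts.sum().clamp(min=1)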


class DeepSeekBlock(nn.Module):
    """
    DeepSeek Transformer Block with MLHA and MoE.
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.attention = MultiHeadLatentAttention(config)
        self.moe = MoELayer(config)
        self.input_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self-attention with residual
        h = x + self.attention(self.input_layernorm(x), freqs_cis)
        # MoE with residual
        moe_output, router_logits = self.moe(self.post_attention_layernorm(h))
        output = h + moe_output
        return output, router_logits


class DeepSeek(nn.Module):
    """
    DeepSeek Model with MLHA and MoE for Causal Language Modeling.
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.config = config
        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.n_embd)
        # Transformer blocks
        self.layers = nn.ModuleList([
            DeepSeekBlock(config) for _ in range(config.n_layer)
        ])
        # Final layer norm
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        # Language modeling head
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # Weight tying
        self.embed_tokens.weight = self.lm_head.weight
        # Precompute RoPE frequencies
        self.freqs_cis = precompute_freqs_cis(
            config.n_embd // config.n_head,
            config.block_size * 2,
            config.rope_theta
        )
        print(f"DeepSeek Model initialized with {self.count_parameters():,} parameters")

    def count_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        B, T = idx.size()
        assert T <= self.config.block_size, f"Sequence length {T} exceeds block size {self.config.block_size}"
        # Move freqs_cis to device if needed
        if self.freqs_cis.device != idx.device:
            self.freqs_cis = self.freqs_cis.to(idx.device)
        freqs_cis = self.freqs_cis[:T]
        # Embeddings
        x = self.embed_tokens(idx)
        # Transformer blocks
        all_router_logits = []
        for layer in self.layers:
            x, router_logits = layer(x, freqs_cis)
            all_router_logits.append(router_logits)
        # Final norm
        x = self.norm(x)
        # Language modeling head
        if targets is not None:
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            return logits, loss, all_router_logits
        else:
            # Generation mode: only compute last token
            logits = self.lm_head(x[:, [-1], :])
            return logits

    @classmethod
    def from_pretrained(cls, model_name: str = "HuggingFaceTB/SmolLM2-135M"):
        """
        Initialize from SmolLM2 pretrained weights (where possible).
        Note: MLHA and MoE layers will be randomly initialized.
        """
        from transformers import AutoModelForCausalLM
        print(f"Loading base weights from {model_name}")
        hf_model = AutoModelForCausalLM.from_pretrained(model_name)
        hf_sd = hf_model.state_dict()
        config = DeepSeekConfig()
        model = cls(config)
        sd = model.state_dict()
        # Only load embeddings and LM head (architecture is different)
        keys_to_load = ["embed_tokens.weight", "lm_head.weight"]
        for k in keys_to_load:
            hf_key = f"model.{k}" if "embed" in k else k
            if hf_key in hf_sd:
                with torch.no_grad():
                    sd[k].copy_(hf_sd[hf_key])
                print(f"Loaded: {k}")
        print("Note: MLHA and MoE layers initialized randomly (architecture mismatch)")
        return model
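

# Usage sketch (added here, not part of the original file): a minimal smoke test of
# the model above with random token ids. Batch size, sequence length, and the
# __main__ guard are illustrative assumptions; shapes follow the DeepSeekConfig defaults.
if __name__ == "__main__":
    config = DeepSeekConfig()
    model = DeepSeek(config)
    B, T = 2, 64
    idx = torch.randint(0, config.vocab_size, (B, T))
    targets = torch.randint(0, config.vocab_size, (B, T))
    # Training-style call: returns logits, cross-entropy loss, and per-layer router logits
    logits, loss, all_router_logits = model(idx, targets)
    print(logits.shape, loss.item(), len(all_router_logits))  # (2, 64, 49152), scalar, 12
    # Generation-style call: logits for the last position only
    last_logits = model(idx)
    print(last_logits.shape)  # (2, 1, 49152)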