| | """ |
| | DeepSeek Model Architecture for Children's Stories |
| | Implements advanced features: |
| | - Multihead Latent Attention (MLA) |
| | - Mixture of Experts (MoE) |
| | - Multi-token prediction |
| | - Quantization support |
| | - Rotary Positional Encodings (RoPE) |
| | - Optimized for children's story generation |
| | """ |
| |
|
| | import math |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from typing import Optional, Tuple, List |
| | from dataclasses import dataclass |
| |
|
| |
|
@dataclass
class DeepSeekConfig:
    """Configuration for DeepSeek model optimized for children's stories"""
    # --- Core transformer dimensions ---
    vocab_size: int = 50257      # GPT-2 BPE vocabulary size
    n_layer: int = 6             # number of DeepSeekBlock layers
    n_head: int = 8              # query attention heads per block
    n_embd: int = 512            # embedding / hidden width
    block_size: int = 1024       # maximum sequence length (positional table size)
    dropout: float = 0.1
    bias: bool = True            # use bias in Linear / LayerNorm layers

    # --- Multihead Latent Attention (MLA) ---
    use_mla: bool = True         # False falls back to nn.MultiheadAttention
    mla_kv_heads: int = 4        # shared key/value heads (grouped-query style)
    mla_q_lora_rank: int = 32    # low-rank bottleneck for the query projection
    mla_kv_lora_rank: int = 16   # low-rank bottleneck for the key/value projection

    # --- Mixture of Experts (MoE) ---
    moe_num_experts: int = 4
    moe_top_k: int = 2           # experts routed per token
    # NOTE(review): declared but never enforced by MixtureOfExperts — confirm
    # whether capacity limiting was intended.
    moe_expert_capacity: float = 1.25
    moe_aux_loss_coeff: float = 0.01   # weight of the load-balancing aux loss

    # --- Multi-token prediction ---
    multi_token_predict: int = 2  # number of future-token heads (0 disables)

    # --- Quantization ---
    use_quantization: bool = False   # _setup_quantization is currently a no-op
    quantization_bits: int = 8
| |
|
| |
|
class RoPEPositionalEncoding(nn.Module):
    """Rotary Positional Encoding (RoPE) for better position understanding.

    Rotates consecutive (even, odd) channel pairs of the input by a
    position-dependent angle. cos/sin tables are cached and reused; the
    cache is keyed on both sequence length and device.
    """

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000.0):
        """
        Args:
            dim: per-head channel count; must be even (channels rotate in pairs).
            max_seq_len: advisory maximum sequence length.
            base: frequency base of the rotary embedding.
        """
        super().__init__()
        if dim % 2 != 0:
            # Odd dims would make x[..., ::2] and x[..., 1::2] mismatch below.
            raise ValueError(f"RoPE dim must be even, got {dim}")
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Per-pair inverse frequencies: base^(-2i/dim) for i in [0, dim/2).
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # Lazily-built cos/sin cache (grows with seq_len, tracks device).
        self._cached_cos = None
        self._cached_sin = None
        self._cached_seq_len = 0
        self._cached_device = None

    def _compute_cos_sin(self, seq_len: int, device: torch.device):
        """Return (cos, sin) tables of shape (seq_len, dim/2) on `device`."""
        rebuild = (
            self._cached_cos is None
            or seq_len > self._cached_seq_len
            or self._cached_device != device  # fix: don't serve a stale-device cache
        )
        if rebuild:
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq.to(device))
            self._cached_cos = torch.cos(freqs)
            self._cached_sin = torch.sin(freqs)
            self._cached_seq_len = seq_len
            self._cached_device = device

        return self._cached_cos[:seq_len], self._cached_sin[:seq_len]

    def apply_rope(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """Apply RoPE to `x` of shape (batch, seq, heads, head_dim).

        Args:
            x: input tensor; head_dim must equal `self.dim`.
            position_ids: optional (batch, seq) long tensor of absolute
                positions; defaults to 0..seq_len-1.

        Returns:
            Tensor of the same shape as `x` with rotary encoding applied.
        """
        batch_size, seq_len, n_heads, head_dim = x.shape

        cos, sin = self._compute_cos_sin(seq_len, x.device)

        if position_ids is not None:
            # Gathered tables are (batch, seq, dim/2); insert only the head
            # axis (fix: the old code also prepended a spurious batch axis).
            cos = cos[position_ids].unsqueeze(2)
            sin = sin[position_ids].unsqueeze(2)
        else:
            # (seq, dim/2) -> (1, seq, 1, dim/2) to broadcast over batch/heads.
            cos = cos.unsqueeze(0).unsqueeze(2)
            sin = sin.unsqueeze(0).unsqueeze(2)

        # Split channels into interleaved (even, odd) pairs.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]

        # 2D rotation of each pair by the position angle.
        rotated_x1 = x1 * cos - x2 * sin
        rotated_x2 = x1 * sin + x2 * cos

        # Re-interleave pairs back into the original channel order.
        return torch.stack([rotated_x1, rotated_x2], dim=-1).flatten(-2)
| |
|
| |
|
class MultiheadLatentAttention(nn.Module):
    """
    Multihead Latent Attention (MLA) - DeepSeek's efficient attention mechanism.

    Queries and keys/values pass through low-rank (LoRA-style) bottlenecks,
    and only `mla_kv_heads` key/value heads are materialized; they are
    repeated to cover all `n_head` query heads (grouped-query attention).
    RoPE is applied to queries and keys before scoring.
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        # Fail fast on configs the shape arithmetic below silently requires.
        if config.n_embd % config.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")
        if config.n_head % config.mla_kv_heads != 0:
            raise ValueError("n_head must be divisible by mla_kv_heads")
        self.config = config
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.kv_heads = config.mla_kv_heads
        self.kv_head_dim = self.head_dim

        # Query path: n_embd -> rank -> n_embd.
        self.q_a_proj = nn.Linear(config.n_embd, config.mla_q_lora_rank, bias=False)
        self.q_b_proj = nn.Linear(config.mla_q_lora_rank, config.n_embd, bias=False)

        # Key/value path: one shared latent expanded to K and V (the *2).
        self.kv_a_proj = nn.Linear(config.n_embd, config.mla_kv_lora_rank, bias=False)
        self.kv_b_proj = nn.Linear(config.mla_kv_lora_rank, self.kv_heads * self.head_dim * 2, bias=False)

        self.out_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        self.rope = RoPEPositionalEncoding(self.head_dim)

        self.dropout = nn.Dropout(config.dropout)

        # Standard 1/sqrt(head_dim) attention scaling.
        self.scale = self.head_dim ** -0.5

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """Compute self-attention over `x` of shape (batch, seq, n_embd).

        Args:
            x: input activations.
            attention_mask: optional additive mask broadcastable to
                (batch, n_head, seq, seq). When None, a causal mask is used;
                when given, it REPLACES the causal mask entirely.

        Returns:
            Tensor of shape (batch, seq, n_embd).
        """
        batch_size, seq_len, _ = x.shape

        # Queries through the low-rank bottleneck: (B, S, H, D).
        q_latent = self.q_a_proj(x)
        q = self.q_b_proj(q_latent)
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim)

        # Keys and values from the shared latent: (B, S, kv_heads, D) each.
        kv_latent = self.kv_a_proj(x)
        kv = self.kv_b_proj(kv_latent)
        kv = kv.view(batch_size, seq_len, self.kv_heads, self.head_dim, 2)
        k, v = kv.unbind(dim=-1)

        # Rotary position encoding on queries and keys (not values).
        q = self.rope.apply_rope(q)
        k = self.rope.apply_rope(k)

        # Repeat K/V heads so each of the n_head query heads has a partner.
        k = k.repeat_interleave(self.n_head // self.kv_heads, dim=2)
        v = v.repeat_interleave(self.n_head // self.kv_heads, dim=2)

        # (B, H, S, D) layout for batched matmul.
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        if attention_mask is None:
            # Default autoregressive (causal) masking.
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            attn_scores.masked_fill_(causal_mask, float('-inf'))
        else:
            attn_scores = attn_scores + attention_mask

        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        out = torch.matmul(attn_weights, v)
        # Merge heads back to (B, S, n_embd).
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.n_embd)

        out = self.out_proj(out)

        return out
| |
|
| |
|
class MoEExpert(nn.Module):
    """Single feed-forward expert: a 4x-expansion MLP with GELU and dropout."""

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        hidden = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, hidden, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(hidden, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor):
        """Project up, apply GELU, project back down, then apply dropout."""
        h = self.c_fc(x)
        h = self.gelu(h)
        h = self.c_proj(h)
        return self.dropout(h)
| |
|
| |
|
class MixtureOfExperts(nn.Module):
    """Mixture of Experts (MoE) for increased model capacity.

    Each token is routed to its top-k experts by a linear router; expert
    outputs are combined with the renormalized routing weights, passed
    through a LayerNorm, and returned together with the raw router logits
    so the caller can add the load-balancing auxiliary loss.
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.config = config
        self.num_experts = config.moe_num_experts
        self.top_k = config.moe_top_k
        # NOTE(review): expert_capacity is stored but never used below — no
        # capacity limiting / token dropping is applied; confirm intent.
        self.expert_capacity = config.moe_expert_capacity

        # Linear router: one logit per expert for every token.
        self.router = nn.Linear(config.n_embd, config.moe_num_experts, bias=False)

        self.experts = nn.ModuleList([MoEExpert(config) for _ in range(config.moe_num_experts)])

        # LayerNorm applied to the combined expert output before returning.
        self.ln = nn.LayerNorm(config.n_embd, bias=config.bias)

    def forward(self, x: torch.Tensor):
        """Route (batch, seq, n_embd) tokens; return (output, router_logits)."""
        batch_size, seq_len, hidden_dim = x.shape

        router_logits = self.router(x)  # (batch, seq, num_experts)

        # Keep only the top-k experts per token; softmax renormalizes their
        # weights over just those k logits.
        top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)
        top_k_probs = F.softmax(top_k_logits, dim=-1)

        output = torch.zeros_like(x)

        # Loop over experts; each processes only the tokens routed to it.
        for expert_idx in range(self.num_experts):
            # Tokens whose top-k selection includes this expert.
            expert_mask = (top_k_indices == expert_idx).any(dim=-1)

            if expert_mask.any():
                expert_tokens = x[expert_mask]

                # Gather each selected token's gating weight for this expert.
                # The boolean indexing flattens to exactly one weight per
                # token because topk indices are unique within a token.
                expert_weights = top_k_probs[expert_mask]
                expert_weights = expert_weights[top_k_indices[expert_mask] == expert_idx]

                expert_output = self.experts[expert_idx](expert_tokens)

                # Scale each token's expert output by its gating weight.
                weighted_output = expert_output * expert_weights.unsqueeze(-1)

                # Accumulate: a token receives contributions from k experts.
                output[expert_mask] += weighted_output

        output = self.ln(output)

        return output, router_logits

    def _compute_aux_loss(self, router_logits: torch.Tensor):
        """Compute auxiliary loss for load balancing.

        Penalizes squared deviation of the mean routing probability per
        expert from the uniform target 1/num_experts.
        """
        router_probs = F.softmax(router_logits, dim=-1)
        mean_expert_usage = router_probs.mean(dim=[0, 1])
        target_usage = 1.0 / self.num_experts

        aux_loss = torch.sum((mean_expert_usage - target_usage) ** 2)
        return aux_loss
| |
|
| |
|
class DeepSeekBlock(nn.Module):
    """DeepSeek transformer block: pre-norm attention (MLA or standard
    nn.MultiheadAttention fallback) followed by a pre-norm MoE feed-forward,
    each with a residual connection. Returns (x, router_logits).
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.config = config

        self.ln1 = nn.LayerNorm(config.n_embd, bias=config.bias)
        self.ln2 = nn.LayerNorm(config.n_embd, bias=config.bias)

        if config.use_mla:
            self.attn = MultiheadLatentAttention(config)
        else:
            # Standard attention fallback when MLA is disabled.
            self.attn = nn.MultiheadAttention(
                config.n_embd,
                config.n_head,
                dropout=config.dropout,
                bias=config.bias,
                batch_first=True
            )

        self.moe = MixtureOfExperts(config)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
        """Apply attention then MoE, both pre-norm residual."""
        # Hoisted: the old code re-computed self.ln1(x) three times in the
        # fallback branch; LayerNorm is deterministic so this is identical.
        normed = self.ln1(x)
        if self.config.use_mla:
            x = x + self.attn(normed, attention_mask)
        else:
            # NOTE(review): when attention_mask is None this fallback runs
            # fully bidirectional (no causal mask), unlike the MLA path —
            # confirm that is intended for autoregressive training.
            attn_out, _ = self.attn(normed, normed, normed, attn_mask=attention_mask)
            x = x + attn_out

        moe_output, router_logits = self.moe(self.ln2(x))
        x = x + moe_output

        return x, router_logits
| |
|
| |
|
class MultiTokenPredictor(nn.Module):
    """Multi-token prediction head for improved training efficiency.

    Holds one vocabulary projection per predicted future-token offset and
    returns logits of shape (batch, multi_token_predict, vocab_size).
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        self.config = config
        self.num_tokens = config.multi_token_predict

        # One independent linear head per predicted token.
        self.predictors = nn.ModuleList([
            nn.Linear(config.n_embd, config.vocab_size, bias=False)
            for _ in range(config.multi_token_predict)
        ])

    def forward(self, hidden_states: torch.Tensor):
        """Forward pass for multi-token prediction.

        Returns (batch, num_tokens, vocab_size) logits.

        NOTE(review): head i reads only the single hidden state at absolute
        position i+1, so the output covers fixed positions 1..num_tokens
        rather than a per-position sliding window, and downstream loss code
        flattens it against full-length targets — verify the intended
        alignment before relying on this head.
        """
        batch_size, seq_len, hidden_dim = hidden_states.shape

        logits = []
        for i, predictor in enumerate(self.predictors):
            if i + 1 < seq_len:
                # Project the hidden state at position i+1 to vocab logits.
                token_logits = predictor(hidden_states[:, i+1:i+2, :])
                logits.append(token_logits)
            else:
                # Sequence too short for this offset: pad with zero logits.
                token_logits = torch.zeros(batch_size, 1, self.config.vocab_size,
                                           device=hidden_states.device)
                logits.append(token_logits)

        return torch.cat(logits, dim=1)
| |
|
| |
|
class DeepSeek(nn.Module):
    """DeepSeek model for children's story generation.

    GPT-style autoregressive decoder: token + learned positional embeddings,
    `n_layer` DeepSeekBlock layers (attention + MoE), a final LayerNorm, and
    a weight-tied LM head. When config.multi_token_predict > 0, the training
    loss is computed through a MultiTokenPredictor head instead.
    """

    def __init__(self, config: DeepSeekConfig):
        super().__init__()
        assert isinstance(config, DeepSeekConfig), "config must be an instance of DeepSeekConfig"
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte=nn.Embedding(config.vocab_size, config.n_embd),   # token embeddings
            wpe=nn.Embedding(config.block_size, config.n_embd),   # learned positions
            drop=nn.Dropout(config.dropout),
            h=nn.ModuleList([DeepSeekBlock(config) for _ in range(config.n_layer)]),
            ln_f=nn.LayerNorm(config.n_embd, bias=config.bias),
        ))

        # Output projection to vocabulary logits.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Optional multi-token prediction head (used only when targets given).
        if config.multi_token_predict > 0:
            self.multi_token_predictor = MultiTokenPredictor(config)
        else:
            self.multi_token_predictor = None

        # Weight tying: embedding matrix and LM head share one tensor, so the
        # init pass below touches the shared weight (harmlessly) twice.
        self.transformer.wte.weight = self.lm_head.weight

        self.apply(self._init_weights)

        if config.use_quantization:
            self._setup_quantization()

    def _init_weights(self, module):
        """Initialize weights: normal(0, 0.02) for Linear/Embedding, unit
        scale and zero bias for LayerNorm."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def _setup_quantization(self):
        """Setup quantization for the model (currently a no-op placeholder)."""
        pass

    def forward(self, input_ids: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Forward pass.

        Args:
            input_ids: (batch, seq) token ids; seq must be <= block_size.
            targets: optional (batch, seq) ids for training; -1 is ignored.

        Returns:
            (logits, loss). With targets, loss includes the MoE auxiliary
            load-balancing term. Without targets, only the last position's
            logits are computed (generation fast path) and loss is None.
        """
        device = input_ids.device
        batch_size, seq_len = input_ids.size()
        # NOTE(review): assert is stripped under `python -O`; an explicit
        # raise would be sturdier input validation.
        assert seq_len <= self.config.block_size

        pos = torch.arange(0, seq_len, dtype=torch.long, device=device)

        tok_emb = self.transformer.wte(input_ids)
        pos_emb = self.transformer.wpe(pos)

        x = self.transformer.drop(tok_emb + pos_emb)

        # Collect router logits per block for the auxiliary loss.
        router_logits_list = []
        for block in self.transformer.h:
            x, router_logits = block(x)
            router_logits_list.append(router_logits)

        x = self.transformer.ln_f(x)

        if targets is not None:
            if self.multi_token_predictor is not None:
                # NOTE(review): multi_logits is (batch, multi_token_predict,
                # vocab) while targets is (batch, seq_len); the flattened
                # cross-entropy in _compute_multi_token_loss only lines up
                # when seq_len == multi_token_predict — verify alignment.
                multi_logits = self.multi_token_predictor(x)
                loss = self._compute_multi_token_loss(multi_logits, targets)
            else:
                # Standard next-token cross-entropy over all positions.
                logits = self.lm_head(x)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)),
                                       targets.view(-1), ignore_index=-1)

            # Add the load-balancing auxiliary loss from every block's router.
            if router_logits_list:
                aux_loss = sum(self.transformer.h[i].moe._compute_aux_loss(router_logits_list[i])
                               for i in range(len(router_logits_list)))
                loss += self.config.moe_aux_loss_coeff * aux_loss

            return logits if self.multi_token_predictor is None else multi_logits, loss
        else:
            # Inference: project only the final position to save compute.
            logits = self.lm_head(x[:, [-1], :])
            return logits, None

    def _compute_multi_token_loss(self, logits: torch.Tensor, targets: torch.Tensor):
        """Compute loss for multi-token prediction.

        NOTE(review): flattens (batch*num_tokens, vocab) logits against
        (batch*seq_len,) targets; these disagree in length whenever
        seq_len != num_tokens — see the caller-side note in forward().
        """
        batch_size, num_tokens, vocab_size = logits.shape

        logits_flat = logits.view(-1, vocab_size)
        targets_flat = targets.view(-1)

        loss = F.cross_entropy(logits_flat, targets_flat, ignore_index=-1)

        return loss

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
                 temperature: float = 1.0, top_k: Optional[int] = None):
        """Autoregressively sample up to max_new_tokens continuation tokens.

        Crops the context to block_size, scales logits by 1/temperature,
        optionally applies top-k filtering, then samples from the softmax.
        Returns the input with sampled tokens appended.
        """
        for _ in range(max_new_tokens):
            # Keep only the most recent block_size tokens as context.
            idx_cond = input_ids if input_ids.size(1) <= self.config.block_size else input_ids[:, -self.config.block_size:]

            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                # Zero out everything below the k-th largest logit.
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat((input_ids, idx_next), dim=1)

        return input_ids

    @classmethod
    def from_pretrained(cls, model_type: str, override_args: Optional[dict] = None):
        """Build a model from the default config plus overrides.

        NOTE(review): `model_type` is currently ignored and no pretrained
        weights are loaded — the returned model is freshly initialized.
        """
        config = DeepSeekConfig()
        if override_args:
            for key, value in override_args.items():
                setattr(config, key, value)
        return cls(config)