| """ | |
| Qwen3 Model Implementation | |
| This file contains the complete Qwen3 model architecture and helper functions | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| # ============================================================================ | |
| # Helper Functions for Text Generation | |
| # ============================================================================ | |
| def text_to_token_ids(text, tokenizer): | |
| """ | |
| Convert text to token IDs using the tokenizer | |
| Parameters: | |
| ----------- | |
| text : str | |
| Input text to tokenize | |
| tokenizer : tiktoken tokenizer | |
| The tokenizer to use (e.g., tiktoken.get_encoding("gpt2")) | |
| Returns: | |
| -------- | |
| torch.Tensor : Token IDs as a tensor with shape [1, num_tokens] | |
| """ | |
| encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'}) | |
| encoded_tensor = torch.tensor(encoded).unsqueeze(0) # Add batch dimension | |
| return encoded_tensor | |
| def token_ids_to_text(token_ids, tokenizer): | |
| """ | |
| Convert token IDs back to text | |
| Parameters: | |
| ----------- | |
| token_ids : torch.Tensor | |
| Token IDs with shape [batch_size, num_tokens] | |
| tokenizer : tiktoken tokenizer | |
| The tokenizer to use | |
| Returns: | |
| -------- | |
| str : Decoded text | |
| """ | |
| flat = token_ids.squeeze(0) # Remove batch dimension | |
| return tokenizer.decode(flat.tolist()) | |
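

# A minimal round-trip sketch of the two helpers above, assuming the `tiktoken`
# package is installed; the "gpt2" encoding is only an illustrative stand-in for
# whatever tokenizer is actually paired with the model.
def _example_tokenizer_round_trip():
    import tiktoken

    tokenizer = tiktoken.get_encoding("gpt2")
    ids = text_to_token_ids("Hello, world!", tokenizer)  # shape [1, num_tokens]
    text = token_ids_to_text(ids, tokenizer)             # back to a plain string
    return ids.shape, text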


def generate_text_simple(model, idx, max_new_tokens, context_size, temperature=1.0, top_k=None):
    """
    Generate text using the model

    This function generates text one token at a time by:
    1. Getting the model's predictions for the next token
    2. Applying temperature to control randomness
    3. Optionally using top-k sampling to limit choices
    4. Sampling the next token and adding it to the sequence

    Parameters:
    -----------
    model : Qwen3Model
        The trained Qwen3 model
    idx : torch.Tensor
        Starting token IDs with shape [batch_size, sequence_length]
    max_new_tokens : int
        How many new tokens to generate
    context_size : int
        Maximum context length the model can handle
    temperature : float
        Controls randomness (lower = more predictable, higher = more random)
        - temperature < 1.0: More focused/deterministic
        - temperature = 1.0: Normal sampling
        - temperature > 1.0: More random/creative
    top_k : int or None
        If set, only sample from the top k most likely tokens

    Returns:
    --------
    torch.Tensor : Token IDs including both input and generated tokens
    """
    model.eval()  # Set model to evaluation mode

    # Generate tokens one at a time
    for _ in range(max_new_tokens):
        # Crop context if it exceeds the model's maximum context size
        idx_cond = idx if idx.size(1) <= context_size else idx[:, -context_size:]

        # Get model predictions
        with torch.no_grad():
            logits, _ = model(idx_cond)

        # Focus only on the last time step (the next token prediction)
        logits = logits[:, -1, :]  # Shape: [batch_size, vocab_size]

        # Apply temperature scaling
        # Lower temperature makes the model more confident in top choices
        # Higher temperature makes the distribution more uniform (more random)
        logits = logits / temperature

        # Optional: Apply top-k filtering
        # This limits sampling to only the k most likely tokens
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')

        # Convert logits to probabilities
        probs = F.softmax(logits, dim=-1)

        # Sample the next token
        idx_next = torch.multinomial(probs, num_samples=1)

        # Append sampled token to the sequence
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
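

# A small, self-contained sketch of the temperature / top-k steps above, applied
# to made-up logits rather than real model output, so the filtering is easy to see.
def _example_temperature_and_top_k(temperature=0.5, top_k=2):
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])      # dummy [batch, vocab] scores
    logits = logits / temperature                        # sharpen (or flatten) the distribution
    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
    logits = logits.masked_fill(logits < v[:, [-1]], float('-inf'))  # keep only the top-k entries
    return F.softmax(logits, dim=-1)                     # probability mass on 2 tokens only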


# ============================================================================
# Model Architecture Components
# ============================================================================

class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization

    RMSNorm is simpler and more efficient than LayerNorm.
    Instead of normalizing using mean and variance, it only uses the root mean square.
    """
    def __init__(self, emb_dim, eps=1e-6, bias=False, qwen3_compatible=True):
        super().__init__()
        self.eps = eps
        self.qwen3_compatible = qwen3_compatible
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim)) if bias else None

    def forward(self, x):
        input_dtype = x.dtype

        if self.qwen3_compatible:
            x = x.to(torch.float32)

        # Calculate variance using mean of squares
        variance = x.pow(2).mean(dim=-1, keepdim=True)

        # Normalize
        norm_x = x * torch.rsqrt(variance + self.eps)
        norm_x = norm_x * self.scale

        if self.shift is not None:
            norm_x = norm_x + self.shift

        return norm_x.to(input_dtype)
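

# A quick numerical sketch of what RMSNorm computes: x / sqrt(mean(x^2) + eps),
# followed by an elementwise learned scale. With the default scale of ones, the
# module output should match the manual formula.
def _example_rmsnorm_matches_formula():
    x = torch.randn(2, 4, 8)                      # [batch, tokens, emb_dim]
    norm = RMSNorm(emb_dim=8)
    manual = x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + norm.eps)
    return torch.allclose(norm(x), manual, atol=1e-5)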


def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, dtype=torch.float32):
    """
    Compute Rotary Position Embedding (RoPE) parameters

    RoPE encodes position by rotating token embeddings.
    This allows the model to understand relative positions between tokens.
    """
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Compute the inverse frequencies
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))

    # Generate position indices
    positions = torch.arange(context_length, dtype=dtype)

    # Compute the angles
    angles = positions[:, None] * inv_freq[None, :]

    # Expand angles to match the head_dim
    angles = torch.cat([angles, angles], dim=1)

    # Precompute sine and cosine
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    return cos, sin


def apply_rope(x, cos, sin):
    """
    Apply Rotary Position Embedding to input tensor

    This rotates the embeddings based on their position in the sequence.
    """
    batch_size, num_heads, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, "Head dimension must be even"

    # Split x into first half and second half
    x1 = x[..., : head_dim // 2]
    x2 = x[..., head_dim // 2 :]

    # Adjust sin and cos shapes
    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)
    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)

    # Apply the rotary transformation
    rotated = torch.cat((-x2, x1), dim=-1)
    x_rotated = (x * cos) + (rotated * sin)

    return x_rotated.to(dtype=x.dtype)
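

# A shape-level sketch of the two RoPE helpers above, using toy sizes: cos/sin
# come out as [context_length, head_dim], and apply_rope leaves the
# [batch, heads, tokens, head_dim] layout of queries/keys unchanged.
def _example_rope_shapes():
    cos, sin = compute_rope_params(head_dim=16, context_length=32)
    q = torch.randn(1, 2, 8, 16)                  # [batch, num_heads, seq_len, head_dim]
    q_rot = apply_rope(q, cos, sin)
    return cos.shape, q_rot.shape                 # torch.Size([32, 16]), torch.Size([1, 2, 8, 16])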


class GroupedQueryAttention(nn.Module):
    """
    Grouped Query Attention (GQA)

    GQA is more efficient than standard multi-head attention.
    It shares Key and Value projections across multiple Query heads,
    reducing the number of parameters while maintaining performance.
    """
    def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None, qk_norm=False, dtype=None):
        super().__init__()
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, "`d_in` must be divisible by `num_heads` if `head_dim` is not set"
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_query = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        if qk_norm:
            self.q_norm = RMSNorm(head_dim, eps=1e-6)
            self.k_norm = RMSNorm(head_dim, eps=1e-6)
        else:
            self.q_norm = self.k_norm = None

    def forward(self, x, mask, cos, sin):
        b, num_tokens, _ = x.shape

        # Apply projections
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Reshape
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        # Optional normalization
        if self.q_norm is not None:
            queries = self.q_norm(queries)
        if self.k_norm is not None:
            keys = self.k_norm(keys)

        # Apply RoPE
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand K and V to match number of heads
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Attention
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)
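

# A toy sketch of the K/V sharing in GQA: with 8 query heads but only 2 KV groups,
# the key/value projections are 4x smaller than the query projection. The sizes
# below are made up purely for illustration.
def _example_gqa_projection_sizes():
    attn = GroupedQueryAttention(d_in=64, num_heads=8, num_kv_groups=2)
    q_params = attn.W_query.weight.numel()        # 64 * (8 * 8) = 4096
    kv_params = attn.W_key.weight.numel()         # 64 * (2 * 8) = 1024
    return q_params, kv_params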


class FeedForward(nn.Module):
    """
    Gated feed-forward network (SwiGLU-style) used in transformer blocks

    Two parallel projections map the input up to the hidden dimension; one is
    passed through SiLU and used to gate the other, and a third projection maps
    the result back down. The hidden dimension is typically larger than the
    embedding dimension, allowing the model to learn complex patterns.
    """
    def __init__(self, cfg):
        super().__init__()
        self.fc1 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc2 = nn.Linear(cfg["emb_dim"], cfg["hidden_dim"], dtype=cfg["dtype"], bias=False)
        self.fc3 = nn.Linear(cfg["hidden_dim"], cfg["emb_dim"], dtype=cfg["dtype"], bias=False)

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        x = nn.functional.silu(x_fc1) * x_fc2
        return self.fc3(x)
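

# A minimal sketch of the gated activation above on a toy config; the config
# values here are placeholders, not real Qwen3 hyperparameters.
def _example_feedforward_gating():
    cfg = {"emb_dim": 8, "hidden_dim": 16, "dtype": torch.float32}
    ff = FeedForward(cfg)
    x = torch.randn(1, 4, 8)                        # [batch, tokens, emb_dim]
    manual = ff.fc3(F.silu(ff.fc1(x)) * ff.fc2(x))  # same computation, written out
    return torch.allclose(ff(x), manual)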


class TransformerBlock(nn.Module):
    """
    A single Transformer Block

    Each block consists of:
    1. Grouped Query Attention for processing relationships between tokens
    2. Feed-Forward Network for processing each token independently
    3. Residual connections and normalization for stable training
    """
    def __init__(self, cfg):
        super().__init__()
        self.att = GroupedQueryAttention(
            d_in=cfg["emb_dim"],
            num_heads=cfg["n_heads"],
            head_dim=cfg["head_dim"],
            num_kv_groups=cfg["n_kv_groups"],
            qk_norm=cfg["qk_norm"],
            dtype=cfg["dtype"]
        )
        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg["emb_dim"], eps=1e-6)
        self.norm2 = RMSNorm(cfg["emb_dim"], eps=1e-6)

    def forward(self, x, mask, cos, sin):
        # Attention block with residual connection
        shortcut = x
        x = self.norm1(x)
        x = self.att(x, mask, cos, sin)
        x = x + shortcut

        # Feed-forward block with residual connection
        shortcut = x
        x = self.norm2(x)
        x = self.ff(x)
        x = x + shortcut

        return x


class Qwen3Model(nn.Module):
    """
    Complete Qwen3 Language Model

    This model can:
    1. Take token IDs as input
    2. Process them through multiple transformer layers
    3. Output predictions for the next token
    4. Generate new text autoregressively
    """
    def __init__(self, cfg):
        super().__init__()

        # Token embedding layer
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

        # Stack of transformer blocks
        self.trf_blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        )

        # Final normalization and output projection
        self.final_norm = RMSNorm(cfg["emb_dim"])
        self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

        # Precompute RoPE parameters
        if cfg["head_dim"] is None:
            head_dim = cfg["emb_dim"] // cfg["n_heads"]
        else:
            head_dim = cfg["head_dim"]

        cos, sin = compute_rope_params(
            head_dim=head_dim,
            theta_base=cfg["rope_base"],
            context_length=cfg["context_length"]
        )
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)
        self.cfg = cfg

    def forward(self, in_idx, targets=None):
        """
        Forward pass through the model

        Parameters:
        -----------
        in_idx : torch.Tensor
            Input token IDs with shape [batch_size, sequence_length]
        targets : torch.Tensor or None
            Target token IDs for computing loss (used during training)

        Returns:
        --------
        logits : torch.Tensor
            Predictions for next tokens with shape [batch_size, sequence_length, vocab_size]
        loss : torch.Tensor or None
            Cross-entropy loss if targets are provided, otherwise None
        """
        # Get token embeddings
        tok_embeds = self.tok_emb(in_idx)
        x = tok_embeds

        # Create causal mask (prevents looking at future tokens)
        num_tokens = x.shape[1]
        mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

        # Pass through all transformer blocks
        for block in self.trf_blocks:
            x = block(x, mask, self.cos, self.sin)

        # Final normalization and projection to vocabulary
        x = self.final_norm(x)
        logits = self.out_head(x.to(self.cfg["dtype"]))

        # Compute loss if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

        return logits, loss

    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Generate new tokens autoregressively

        This is a convenience method that wraps the generation logic.
        For more details, see the generate_text_simple function.
        """
        for _ in range(max_new_tokens):
            ctx_len = self.cfg["context_length"]
            idx_cond = idx if idx.size(1) <= ctx_len else idx[:, -ctx_len:]

            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float("-inf")

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
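

# A minimal end-to-end sketch with a deliberately tiny, made-up config (these are
# NOT real Qwen3 hyperparameters): build the model, run a forward pass on random
# token IDs, and sample a few continuation tokens as a smoke test.
if __name__ == "__main__":
    toy_cfg = {
        "vocab_size": 100,        # placeholder values for a quick smoke test
        "emb_dim": 32,
        "hidden_dim": 64,
        "n_heads": 4,
        "n_kv_groups": 2,
        "head_dim": None,         # derived as emb_dim // n_heads
        "qk_norm": True,
        "n_layers": 2,
        "rope_base": 10_000,
        "context_length": 64,
        "dtype": torch.float32,
    }
    model = Qwen3Model(toy_cfg)

    tokens = torch.randint(0, toy_cfg["vocab_size"], (1, 8))  # random [batch, seq] token IDs
    logits, loss = model(tokens, targets=tokens)               # targets=inputs only to exercise the loss path
    print("logits shape:", logits.shape)                       # [1, 8, vocab_size]
    print("loss:", loss.item())

    generated = model.generate(tokens, max_new_tokens=5, temperature=1.0, top_k=10)
    print("generated shape:", generated.shape)                 # [1, 13]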