| """ | |
| Llama-Style Transformer Model | |
| ============================= | |
| Modern transformer architecture with all Tier 1 and Tier 2 optimizations: | |
| Architecture (Tier 1): | |
| - RMSNorm (faster than LayerNorm, no mean calculation) | |
| - RoPE (Rotary Position Embedding, better length generalization) | |
| - SwiGLU activation (gated FFN, consistently outperforms GELU) | |
| - Pre-norm (apply norm before attention/FFN, more stable training) | |
| Optimizations (Tier 2): | |
| - GQA (Grouped Query Attention, fewer KV heads = faster + less memory) | |
| - Weight tying (share embedding and output projection) | |
| - Flash Attention via F.scaled_dot_product_attention | |
| - Gradient checkpointing support (trade compute for memory) | |
| Compatible with: | |
| - liger-kernel (fused RMSNorm, SwiGLU, RoPE, cross-entropy) | |
| - bf16/fp16 mixed precision training | |
| - torch.compile for additional speedups | |
| Model Sizes: | |
| - tiny: ~15M params (for testing) | |
| - small: ~125M params | |
| - medium: ~350M params | |
| - large: ~760M params | |
| - 1B: ~1.1B params (Llama 3.2 1B style) | |
| """ | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| # ============================================================================ | |
| # Model Configuration | |
| # ============================================================================ | |
@dataclass
class ModelConfig:
    """Configuration for a Llama-style transformer model.

    Derived fields (``d_ff`` default, ``n_kv_groups``, ``head_dim``) are
    computed in ``__post_init__``; call it again after mutating fields.

    NOTE: the ``@dataclass`` decorator was missing — without it keyword
    construction (``ModelConfig(d_model=256, ...)``) raises TypeError and
    ``__post_init__`` never runs.
    """

    # Model architecture
    vocab_size: int = 32000
    d_model: int = 2048          # Hidden dimension
    n_layers: int = 16           # Number of transformer blocks
    n_heads: int = 32            # Number of attention heads
    n_kv_heads: int = 8          # Number of KV heads (for GQA)
    d_ff: Optional[int] = None   # FFN intermediate dim (default: 8/3 * d_model)

    # Sequence
    max_seq_len: int = 2048      # Maximum sequence length

    # RoPE
    rope_theta: float = 500000.0  # RoPE base frequency

    # Regularization
    dropout: float = 0.0         # Dropout (0 for pretraining)

    # Options
    tie_weights: bool = True     # Tie embedding and output weights
    use_flash_attn: bool = True  # Use Flash Attention (SDPA)

    def __post_init__(self):
        # SwiGLU uses 8/3 * d_model for the FFN, rounded up to a multiple
        # of 256 for better GPU tile utilization.
        if self.d_ff is None:
            self.d_ff = int(8 / 3 * self.d_model)
            self.d_ff = ((self.d_ff + 255) // 256) * 256
        # Validate GQA configuration: each KV head must serve an integer
        # number of query heads.
        assert self.n_heads % self.n_kv_heads == 0, \
            f"n_heads ({self.n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        self.n_kv_groups = self.n_heads // self.n_kv_heads
        self.head_dim = self.d_model // self.n_heads
# Predefined model configurations, keyed by size name.
_PRESET_KWARGS = {
    "tiny": dict(d_model=256, n_layers=6, n_heads=8, n_kv_heads=4, max_seq_len=1024),
    "small": dict(d_model=768, n_layers=12, n_heads=12, n_kv_heads=4, max_seq_len=2048),
    "medium": dict(d_model=1024, n_layers=16, n_heads=16, n_kv_heads=4, max_seq_len=2048),
    "large": dict(d_model=1536, n_layers=20, n_heads=24, n_kv_heads=8, max_seq_len=2048),
    # Llama 3.2 1B uses a 4x FFN expansion, not the 8/3x default.
    "1B": dict(d_model=2048, n_layers=16, n_heads=32, n_kv_heads=8, d_ff=8192, max_seq_len=2048),
}

MODEL_CONFIGS = {name: ModelConfig(**kwargs) for name, kwargs in _PRESET_KWARGS.items()}
def get_model_config(size: str, **overrides) -> ModelConfig:
    """Get a predefined model configuration with optional overrides.

    Args:
        size: Preset name; one of the keys of ``MODEL_CONFIGS``.
        **overrides: Config attributes to override (must already exist
            on the config).

    Returns:
        A fresh ModelConfig. The shared preset in ``MODEL_CONFIGS`` is
        never mutated.

    Raises:
        ValueError: On an unknown size or an unknown override name.
    """
    import copy

    if size not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model size: {size}. Choose from: {list(MODEL_CONFIGS.keys())}")
    # Copy the preset so overrides don't leak into the shared object:
    # the previous version called setattr on MODEL_CONFIGS[size] itself,
    # so every later call inherited earlier overrides.
    config = copy.deepcopy(MODEL_CONFIGS[size])
    # Apply overrides
    for key, value in overrides.items():
        if hasattr(config, key):
            setattr(config, key, value)
        else:
            raise ValueError(f"Unknown config parameter: {key}")
    # Recompute derived values (n_kv_groups, head_dim; d_ff only if None)
    config.__post_init__()
    return config
| # ============================================================================ | |
| # RMSNorm (Tier 1) | |
| # ============================================================================ | |
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    A LayerNorm variant that skips mean subtraction: each vector is scaled
    by the reciprocal of its root-mean-square, then by a learned gain.
    Used in Llama, Mistral, and other modern LLMs. Can be swapped for
    liger_kernel.transformers.LigerRMSNorm for a fused-kernel speedup.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in fp32 for numerical stability, then cast back.
        normalized = self._norm(x.float()).type_as(x)
        return normalized * self.weight
| # ============================================================================ | |
| # Rotary Position Embedding (RoPE) (Tier 1) | |
| # ============================================================================ | |
def precompute_rope_freqs(
    dim: int,
    max_seq_len: int,
    theta: float = 10000.0,
    device: torch.device = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute cos/sin tables for rotary position embeddings.

    Args:
        dim: Head dimension (d_model // n_heads).
        max_seq_len: Maximum sequence length.
        theta: Base frequency (Llama 3 uses 500000).
        device: Target device.

    Returns:
        (cos, sin) tensors, each of shape (max_seq_len, dim). Each
        frequency column is duplicated so adjacent columns form pairs.
    """
    # Inverse frequency for each rotation pair: 1 / theta^(2i/dim).
    pair_indices = torch.arange(0, dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (pair_indices / dim))
    # Angle table via outer product: (seq_len,) x (dim/2,) -> (seq_len, dim/2).
    positions = torch.arange(max_seq_len, device=device)
    angles = torch.outer(positions, inv_freq)
    # Duplicate each column so the tables line up with interleaved pairs.
    cos_table = torch.cos(angles).repeat_interleave(2, dim=-1)
    sin_table = torch.sin(angles).repeat_interleave(2, dim=-1)
    return cos_table, sin_table
def apply_rotary_emb(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> torch.Tensor:
    """Apply rotary position embedding to a tensor of attention heads.

    Args:
        x: Input of shape (batch, n_heads, seq_len, head_dim).
        cos: Cosine table of shape (seq_len, head_dim).
        sin: Sine table of shape (seq_len, head_dim).

    Returns:
        Tensor of the same shape, with each (even, odd) feature pair
        rotated by its position-dependent angle.
    """
    seq_len = x.size(2)
    # Slice to the current length and add broadcast dims: (1, 1, T, D).
    cos_t = cos[:seq_len][None, None, :, :]
    sin_t = sin[:seq_len][None, None, :, :]
    # Build the 90-degree-rotated companion of x:
    # [x0, x1, x2, x3, ...] -> [-x1, x0, -x3, x2, ...]
    odds = x[..., 1::2]
    evens = x[..., 0::2]
    rotated = torch.stack((-odds, evens), dim=-1).reshape(x.shape)
    # Standard rotation: out_even = x_even*cos - x_odd*sin,
    #                    out_odd  = x_odd*cos + x_even*sin.
    return x * cos_t + rotated * sin_t
| # ============================================================================ | |
| # Grouped Query Attention (GQA) with Flash Attention (Tier 1 + Tier 2) | |
| # ============================================================================ | |
class Attention(nn.Module):
    """Multi-head attention with GQA and (optionally) Flash Attention.

    Grouped Query Attention uses fewer key/value heads than query heads:
    with 32 query heads and 8 KV heads, each KV head serves 4 query heads,
    cutting KV memory and bandwidth while maintaining quality.

    When ``use_flash_attn`` is set, attention runs through PyTorch's
    ``scaled_dot_product_attention`` (Flash Attention when available),
    giving O(N) memory instead of materializing the O(N^2) score matrix.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.n_kv_groups = config.n_kv_groups
        self.head_dim = config.head_dim
        # Queries get the full head count; keys/values use the reduced GQA
        # head count. No biases, per Llama convention.
        self.wq = nn.Linear(config.d_model, config.n_heads * config.head_dim, bias=False)
        self.wk = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.wv = nn.Linear(config.d_model, config.n_kv_heads * config.head_dim, bias=False)
        self.wo = nn.Linear(config.n_heads * config.head_dim, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)
        self.use_flash_attn = config.use_flash_attn

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        bsz, seqlen, _ = x.shape

        # Project and split into heads: (B, T, d) -> (B, heads, T, head_dim).
        queries = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim).transpose(1, 2)
        keys = self.wk(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)
        values = self.wv(x).view(bsz, seqlen, self.n_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding goes on queries and keys only.
        queries = apply_rotary_emb(queries, cos, sin)
        keys = apply_rotary_emb(keys, cos, sin)

        # GQA: replicate each KV head across its group of query heads,
        # (B, n_kv_heads, T, hd) -> (B, n_heads, T, hd).
        if self.n_kv_groups > 1:
            keys = keys.repeat_interleave(self.n_kv_groups, dim=1)
            values = values.repeat_interleave(self.n_kv_groups, dim=1)

        if self.use_flash_attn:
            # SDPA applies its own causal masking when no explicit mask
            # is supplied.
            attn_out = F.scaled_dot_product_attention(
                queries, keys, values,
                attn_mask=mask,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=mask is None,
            )
        else:
            # Reference implementation (debugging / SDPA-less builds).
            scale = 1.0 / math.sqrt(self.head_dim)
            scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale
            if mask is None:
                # Additive causal mask: -inf strictly above the diagonal.
                mask = torch.triu(
                    torch.full((seqlen, seqlen), float('-inf'), device=x.device),
                    diagonal=1,
                )
            scores = scores + mask
            probs = self.dropout(F.softmax(scores, dim=-1))
            attn_out = torch.matmul(probs, values)

        # Merge heads back: (B, heads, T, head_dim) -> (B, T, d_model).
        attn_out = attn_out.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(attn_out)
| # ============================================================================ | |
| # SwiGLU Feed-Forward Network (Tier 1) | |
| # ============================================================================ | |
class FeedForward(nn.Module):
    """SwiGLU feed-forward network.

    A gated MLP with three projections (gate, up, down) instead of the
    two-matrix GELU FFN:

        out = W_down( SiLU(W_gate x) * (W_up x) )

    Consistently outperforms GELU at matched compute. Replaceable with
    liger_kernel.transformers.LigerSwiGLUMLP for a fused kernel.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        inner = config.d_ff
        # Gate and up projections (fusable into a single matmul).
        self.w_gate = nn.Linear(config.d_model, inner, bias=False)
        self.w_up = nn.Linear(config.d_model, inner, bias=False)
        # Down projection back to the model width.
        self.w_down = nn.Linear(inner, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gated = F.silu(self.w_gate(x)) * self.w_up(x)
        return self.dropout(self.w_down(gated))
| # ============================================================================ | |
| # Transformer Block (Pre-norm) | |
| # ============================================================================ | |
class TransformerBlock(nn.Module):
    """One pre-norm transformer block.

    Normalization runs BEFORE attention/FFN (pre-norm), which keeps
    gradients stable at depth:

        x = x + Attention(RMSNorm(x))
        x = x + FFN(RMSNorm(x))
    """

    def __init__(self, config: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # One norm per sub-layer (pre-norm placement).
        self.attn_norm = RMSNorm(config.d_model)
        self.ffn_norm = RMSNorm(config.d_model)
        self.attn = Attention(config)
        self.ffn = FeedForward(config)

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Residual around pre-normed attention, then pre-normed FFN.
        attn_out = self.attn(self.attn_norm(x), cos, sin, mask)
        x = x + attn_out
        return x + self.ffn(self.ffn_norm(x))
| # ============================================================================ | |
| # Complete Llama Model | |
| # ============================================================================ | |
class LlamaModel(nn.Module):
    """Complete Llama-style transformer model for language modeling.

    Features:
    - RMSNorm, RoPE, SwiGLU, GQA (Tier 1)
    - Weight tying, Flash Attention (Tier 2)
    - Gradient checkpointing support
    - Compatible with liger-kernel fused ops

    Usage:
        config = get_model_config("1B", vocab_size=32000)
        model = LlamaModel(config)

        # Enable gradient checkpointing for memory savings
        model.gradient_checkpointing_enable()

        # Forward pass
        logits = model(input_ids)
        loss = model(input_ids, targets=targets)
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)

        # Transformer blocks
        self.layers = nn.ModuleList([
            TransformerBlock(config, layer_idx=i)
            for i in range(config.n_layers)
        ])

        # Final normalization
        self.norm = RMSNorm(config.d_model)

        # Output projection (language model head)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # Weight tying: share embedding and output weights
        if config.tie_weights:
            self.lm_head.weight = self.tok_emb.weight

        # RoPE tables: registered as zero placeholders and filled lazily on
        # the first forward pass so they land on the input's device.
        # Non-persistent: derived values, excluded from state_dict.
        self.register_buffer(
            "rope_cos",
            torch.zeros(config.max_seq_len, config.head_dim),
            persistent=False,
        )
        self.register_buffer(
            "rope_sin",
            torch.zeros(config.max_seq_len, config.head_dim),
            persistent=False,
        )
        # Explicit init flag instead of checking `rope_cos.sum() == 0` in
        # forward(): the sum check forced a device->host sync every step.
        self._rope_initialized = False

        # Gradient checkpointing flag
        self._gradient_checkpointing = False

        # Initialize weights
        self.apply(self._init_weights)
        # Scaled initialization for residual-path output projections
        self._init_output_weights()

    def _init_weights(self, module: nn.Module):
        """Initialize Linear/Embedding weights with N(0, 0.02), zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _init_output_weights(self):
        """Apply scaled initialization to output projections for stability.

        Residual-path projections (attention output, FFN down) are scaled
        by 1/sqrt(2 * n_layers) so the residual stream's variance stays
        bounded with depth.
        """
        scale = (2 * self.config.n_layers) ** -0.5
        for layer in self.layers:
            torch.nn.init.normal_(layer.attn.wo.weight, mean=0.0, std=0.02 * scale)
            torch.nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)

    def _init_rope(self, device: torch.device):
        """(Re)compute RoPE cos/sin tables on the given device."""
        cos, sin = precompute_rope_freqs(
            dim=self.config.head_dim,
            max_seq_len=self.config.max_seq_len,
            theta=self.config.rope_theta,
            device=device,
        )
        self.rope_cos = cos
        self.rope_sin = sin
        self._rope_initialized = True

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing for memory-efficient training."""
        self._gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing."""
        self._gradient_checkpointing = False

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass.

        Args:
            input_ids: Token IDs of shape (batch_size, seq_len)
            targets: Optional target IDs for loss computation
            mask: Optional attention mask

        Returns:
            If targets provided: scalar loss
            Otherwise: logits of shape (batch_size, seq_len, vocab_size)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Build RoPE tables on first use, or rebuild if the model has moved
        # to a different device since they were computed.
        if not self._rope_initialized or self.rope_cos.device != device:
            self._init_rope(device)

        # Token embeddings
        x = self.tok_emb(input_ids)

        # RoPE frequencies for this sequence length
        cos = self.rope_cos[:seq_len]
        sin = self.rope_sin[:seq_len]

        # Transformer blocks (checkpointed during training if enabled)
        for layer in self.layers:
            if self._gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(
                    layer, x, cos, sin, mask,
                    use_reentrant=False,
                )
            else:
                x = layer(x, cos, sin, mask)

        # Final norm
        x = self.norm(x)

        # Logits over the vocabulary
        logits = self.lm_head(x)

        if targets is not None:
            # NOTE: No shift here — the DataLoader already provides
            # pre-shifted targets (x = tokens[:-1], y = tokens[1:]),
            # so logits[k] should predict targets[k] directly.
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                targets.view(-1),
                ignore_index=-100,  # Ignore padding
            )
            return loss
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Generate tokens autoregressively.

        Runs under ``torch.no_grad()``: the previous version built an
        autograd graph across the entire generation loop, wasting memory.

        Args:
            input_ids: Starting token IDs (batch_size, seq_len)
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (1.0 = neutral; must be > 0,
                there is no greedy/argmax path)
            top_k: If set, only sample from top k tokens
            top_p: If set, use nucleus sampling with this probability mass

        Returns:
            Generated token IDs (batch_size, seq_len + max_new_tokens)
        """
        # NOTE: side effect — switches the module to eval mode and leaves
        # it there; callers that resume training must call .train().
        self.eval()
        for _ in range(max_new_tokens):
            # Crop the context window to max_seq_len if needed
            if input_ids.size(1) <= self.config.max_seq_len:
                idx_cond = input_ids
            else:
                idx_cond = input_ids[:, -self.config.max_seq_len:]

            # Logits for the last position, temperature-scaled
            logits = self(idx_cond)[:, -1, :] / temperature

            # Top-k filtering: mask everything below the k-th best logit
            if top_k is not None:
                kth = torch.topk(logits, min(top_k, logits.size(-1))).values
                logits[logits < kth[:, [-1]]] = float('-inf')

            # Top-p (nucleus) filtering: keep the smallest set of tokens
            # whose cumulative probability exceeds top_p
            if top_p is not None:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold
                # is still kept
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(
                    1, sorted_indices, sorted_indices_to_remove
                )
                logits[indices_to_remove] = float('-inf')

            # Sample the next token and append it
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
        return input_ids

    def count_parameters(self, trainable_only: bool = True) -> int:
        """Count model parameters (optionally only those requiring grad)."""
        if trainable_only:
            return sum(p.numel() for p in self.parameters() if p.requires_grad)
        return sum(p.numel() for p in self.parameters())

    def estimate_flops(self, seq_len: int, batch_size: int = 1) -> int:
        """
        Estimate FLOPs for a forward pass.

        Uses the approximation: FLOPs ≈ 2 * params * tokens
        (multiply-add counts as 2 ops)
        """
        params = self.count_parameters(trainable_only=False)
        tokens = batch_size * seq_len
        return 2 * params * tokens
| # ============================================================================ | |
| # Utility Functions | |
| # ============================================================================ | |
def create_model(
    size: str = "1B",
    vocab_size: int = 32000,
    max_seq_len: int = 2048,
    **kwargs
) -> LlamaModel:
    """Build a LlamaModel from a named preset plus config overrides.

    Args:
        size: Model size ("tiny", "small", "medium", "large", "1B").
        vocab_size: Vocabulary size.
        max_seq_len: Maximum sequence length.
        **kwargs: Additional config overrides.

    Returns:
        Initialized LlamaModel.
    """
    return LlamaModel(
        get_model_config(
            size,
            vocab_size=vocab_size,
            max_seq_len=max_seq_len,
            **kwargs,
        )
    )
def print_model_summary(model: LlamaModel):
    """Print a human-readable summary of the model architecture."""
    config = model.config
    params = model.count_parameters()
    rule = "=" * 60

    print("\n" + rule)
    print("LLAMA MODEL SUMMARY")
    print(rule)

    print(f"\nArchitecture:")
    print(f" Hidden dim: {config.d_model}")
    print(f" Layers: {config.n_layers}")
    print(f" Attention heads: {config.n_heads}")
    print(f" KV heads (GQA): {config.n_kv_heads}")
    print(f" Head dim: {config.head_dim}")
    print(f" FFN dim: {config.d_ff}")
    print(f" Vocab size: {config.vocab_size}")
    print(f" Max seq len: {config.max_seq_len}")

    print(f"\nOptimizations:")
    print(f" RMSNorm: Yes")
    print(f" RoPE: Yes (theta={config.rope_theta})")
    print(f" SwiGLU: Yes")
    print(f" GQA: Yes ({config.n_heads}/{config.n_kv_heads} = {config.n_kv_groups}x)")
    print(f" Weight tying: {config.tie_weights}")
    print(f" Flash Attention: {config.use_flash_attn}")

    print(f"\nParameters:")
    print(f" Total: {params:,}")
    print(f" Size: ~{params / 1e9:.2f}B" if params > 1e9 else f" Size: ~{params / 1e6:.0f}M")
    # Rough parameter-memory footprint (4 bytes/param in fp32).
    param_bytes = params * 4
    print(f" FP32 memory: ~{param_bytes / 1e9:.2f} GB")
    print(f" BF16 memory: ~{param_bytes / 2 / 1e9:.2f} GB")
    print(rule + "\n")
| # ============================================================================ | |
| # Main (for testing) | |
| # ============================================================================ | |
| if __name__ == "__main__": | |
| # Test model creation | |
| print("Testing Llama model creation...\n") | |
| for size in ["tiny", "small", "medium", "large", "1B"]: | |
| model = create_model(size) | |
| params = model.count_parameters() | |
| print(f"{size:8s}: {params:>12,} parameters ({params/1e6:>7.1f}M)") | |
| print("\n" + "-" * 60) | |
| # Detailed summary for 1B | |
| model = create_model("1B") | |
| print_model_summary(model) | |
| # Test forward pass | |
| print("Testing forward pass...") | |
| device = "cuda" if torch.cuda.is_available() else "cpu" | |
| model = model.to(device) | |
| batch_size = 2 | |
| seq_len = 128 | |
| input_ids = torch.randint(0, 32000, (batch_size, seq_len), device=device) | |
| # Forward without targets (returns logits) | |
| logits = model(input_ids) | |
| print(f"Logits shape: {logits.shape}") | |
| # Forward with targets (returns loss) | |
| targets = torch.randint(0, 32000, (batch_size, seq_len), device=device) | |
| loss = model(input_ids, targets=targets) | |
| print(f"Loss: {loss.item():.4f}") | |
| # Test gradient checkpointing | |
| print("\nTesting gradient checkpointing...") | |
| model.gradient_checkpointing_enable() | |
| loss = model(input_ids, targets=targets) | |
| loss.backward() | |
| print(f"Gradient checkpointing loss: {loss.item():.4f}") | |
| print("\nAll tests passed!") | |