# MAP-NEO Mini Model Architecture
# Scaled-down version of MAP-NEO (300M parameters) with RMSNorm, RoPE, and Flash Attention

import json
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (same as MAP-NEO)"""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # RMS normalization: x / sqrt(mean(x^2) + eps), then a learned per-channel scale
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight


class RotaryPositionalEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) - same as MAP-NEO"""

    def __init__(self, dim: int, max_len: int = 8192, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_len = max_len
        self.theta = theta
        # Precompute frequencies: theta^(-2i/dim) for each channel pair
        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("freqs", freqs, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # x is only used to pick the device; cos/sin depend on seq_len alone
        device = x.device
        positions = torch.arange(seq_len, device=device).float()
        # Compute angles: position m times frequency theta_i
        angles = positions.unsqueeze(1) * self.freqs.unsqueeze(0)  # [seq_len, dim//2]
        cos = torch.cos(angles)  # [seq_len, dim//2]
        sin = torch.sin(angles)  # [seq_len, dim//2]
        return cos, sin


def apply_rotary_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply rotary embedding to query/key tensors"""
    # x: [batch, seq_len, n_heads, head_dim]
    # Split channels into even/odd pairs; each pair (x1_i, x2_i) is rotated by the
    # position-dependent angle for frequency i.
    x1, x2 = x[..., ::2], x[..., 1::2]  # Even and odd indices
    # Apply rotation. The rotated halves are concatenated rather than re-interleaved;
    # queries and keys use the same layout, so q.k still depends only on the relative
    # offset between positions.
    rotated = torch.cat([
        x1 * cos.unsqueeze(0).unsqueeze(-2) - x2 * sin.unsqueeze(0).unsqueeze(-2),
        x1 * sin.unsqueeze(0).unsqueeze(-2) + x2 * cos.unsqueeze(0).unsqueeze(-2)
    ], dim=-1)
    return rotated

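
# Illustrative sanity check (a minimal sketch; the helper name and the specific
# positions/offsets below are arbitrary choices, not part of the model): with RoPE,
# the dot product of a rotated query and key should depend only on their relative
# offset, so identical content at positions (1, 3) and (4, 6) should score the same.
def _rope_relative_position_check(head_dim: int = 64, seq_len: int = 8, atol: float = 1e-5) -> bool:
    rope = RotaryPositionalEmbedding(head_dim)
    q_vec = torch.randn(head_dim)
    k_vec = torch.randn(head_dim)
    # Repeat the same content vector at every position: [batch=1, seq_len, n_heads=1, head_dim]
    q = q_vec.view(1, 1, 1, head_dim).expand(1, seq_len, 1, head_dim)
    k = k_vec.view(1, 1, 1, head_dim).expand(1, seq_len, 1, head_dim)
    cos, sin = rope(q, seq_len)
    q_rot = apply_rotary_emb(q, cos, sin)
    k_rot = apply_rotary_emb(k, cos, sin)
    # Both pairs share the offset -2, so the scores should match up to numerical tolerance
    score_1_3 = (q_rot[0, 1, 0] * k_rot[0, 3, 0]).sum()
    score_4_6 = (q_rot[0, 4, 0] * k_rot[0, 6, 0]).sum()
    return torch.allclose(score_1_3, score_4_6, atol=atol)
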

class MultiHeadAttention(nn.Module):
    """Multi-head attention with RoPE and optional Flash Attention"""

    def __init__(self, dim: int, n_heads: int, dropout: float = 0.0):
        super().__init__()
        assert dim % n_heads == 0
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.scale = self.head_dim ** -0.5

        # Linear projections
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.o_proj = nn.Linear(dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

        # RoPE
        self.rotary_emb = RotaryPositionalEmbedding(self.head_dim)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape

        # Project to Q, K, V
        q = self.q_proj(x)  # [batch, seq_len, dim]
        k = self.k_proj(x)  # [batch, seq_len, dim]
        v = self.v_proj(x)  # [batch, seq_len, dim]

        # Reshape for multi-head attention: [batch, n_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K (apply_rotary_emb expects [batch, seq_len, n_heads, head_dim])
        cos, sin = self.rotary_emb(q, seq_len)
        q = apply_rotary_emb(q.transpose(1, 2), cos, sin).transpose(1, 2)
        k = apply_rotary_emb(k.transpose(1, 2), cos, sin).transpose(1, 2)

        # Use the fused kernel when possible; it applies the causal mask internally but
        # does not honour an explicit padding mask here, so fall back to manual attention
        # whenever one is provided.
        use_fused = attention_mask is None and hasattr(F, "scaled_dot_product_attention")
        if use_fused:
            # scaled_dot_product_attention dispatches to Flash Attention /
            # memory-efficient kernels when available.
            # Inputs: [batch, n_heads, seq_len, head_dim].
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout.p if self.training else 0.0,
                is_causal=True,
            )
            out = out.transpose(1, 2)  # [batch, seq_len, n_heads, head_dim]
        else:
            # Manual attention fallback
            scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

            # Apply causal mask
            causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=x.device), diagonal=1).bool()
            scores = scores.masked_fill(causal_mask, float('-inf'))

            # Apply padding mask if provided: attention_mask is [batch, seq_len], 1/True for real tokens
            if attention_mask is not None:
                scores = scores.masked_fill(~attention_mask.bool().unsqueeze(1).unsqueeze(1), float('-inf'))

            attn_weights = F.softmax(scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            out = torch.matmul(attn_weights, v)  # [batch, n_heads, seq_len, head_dim]
            out = out.transpose(1, 2)  # [batch, seq_len, n_heads, head_dim]

        # Concat heads and project
        out = out.contiguous().view(batch_size, seq_len, dim)
        out = self.o_proj(out)
        return out


class FeedForward(nn.Module):
    """SwiGLU Feed-Forward Network (same as MAP-NEO)"""

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.up_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SwiGLU activation: swish(gate) * up
        gate = F.silu(self.gate_proj(x))  # SiLU = Swish
        up = self.up_proj(x)
        hidden = gate * up
        hidden = self.dropout(hidden)
        return self.down_proj(hidden)


class TransformerBlock(nn.Module):
    """Transformer block with pre-norm (RMSNorm)"""

    def __init__(self, dim: int, n_heads: int, hidden_dim: int, dropout: float = 0.0):
        super().__init__()
        self.attention_norm = RMSNorm(dim)
        self.attention = MultiHeadAttention(dim, n_heads, dropout)
        self.ffn_norm = RMSNorm(dim)
        self.ffn = FeedForward(dim, hidden_dim, dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-norm attention
        h = x + self.attention(self.attention_norm(x), attention_mask)
        # Pre-norm FFN
        h = h + self.ffn(self.ffn_norm(h))
        return h


class NeoMiniConfig:
    """Configuration for MAP-NEO Mini (300M parameters)"""

    def __init__(self):
        # Model architecture
        self.vocab_size = 50257  # GPT-2 tokenizer vocab size (will update for MAP-NEO tokenizer)
        self.max_seq_len = 2048
        self.dim = 1024          # Hidden dimension
        self.n_layers = 16       # Number of transformer layers
        self.n_heads = 16        # Number of attention heads
        self.hidden_dim = 2736   # FFN hidden dimension (~8/3 x dim, the usual SwiGLU sizing)

        # Training
        self.dropout = 0.0  # No dropout for pretraining

        # Approximate parameter count for this config: ~250M with tied embeddings
        # (see NeoMini.get_num_params for the exact figure)

    def to_dict(self):
        return {k: v for k, v in self.__dict__.items() if not k.startswith('_')}

    @classmethod
    def from_dict(cls, config_dict):
        config = cls()
        for k, v in config_dict.items():
            setattr(config, k, v)
        return config

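
# Back-of-the-envelope parameter estimate for a NeoMiniConfig (a rough sketch; the
# helper name `estimate_params` is arbitrary, and the exact count comes from
# NeoMini.get_num_params below). With the default config this lands around 253M.
def estimate_params(config: NeoMiniConfig) -> int:
    embed = config.vocab_size * config.dim    # token embedding, tied with lm_head so counted once
    attn = 4 * config.dim * config.dim        # q/k/v/o projections, no bias
    ffn = 3 * config.dim * config.hidden_dim  # gate/up/down projections
    norms = 2 * config.dim                    # two RMSNorm weights per block
    per_layer = attn + ffn + norms
    return embed + config.n_layers * per_layer + config.dim  # + final RMSNorm
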

class NeoMini(nn.Module):
    """MAP-NEO Mini Language Model (300M parameters)"""

    def __init__(self, config: NeoMiniConfig):
        super().__init__()
        self.config = config

        # Embeddings
        self.token_embedding = nn.Embedding(config.vocab_size, config.dim)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(
                dim=config.dim,
                n_heads=config.n_heads,
                hidden_dim=config.hidden_dim,
                dropout=config.dropout
            )
            for _ in range(config.n_layers)
        ])

        # Output
        self.ln_f = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # Tie weights (common in modern LLMs)
        self.lm_head.weight = self.token_embedding.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Token embeddings
        x = self.token_embedding(input_ids)

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, attention_mask)

        # Final layer norm and projection
        x = self.ln_f(x)
        logits = self.lm_head(x)
        return logits

    def get_num_params(self):
        """Count model parameters (tied weights are counted once)"""
        return sum(p.numel() for p in self.parameters())

    def save_config(self, path: str):
        """Save model configuration"""
        with open(path, 'w') as f:
            json.dump(self.config.to_dict(), f, indent=2)

    @classmethod
    def from_config(cls, config_path: str):
        """Load model from configuration"""
        with open(config_path, 'r') as f:
            config_dict = json.load(f)
        config = NeoMiniConfig.from_dict(config_dict)
        return cls(config)


def create_model():
    """Create a MAP-NEO Mini model"""
    config = NeoMiniConfig()
    model = NeoMini(config)
    print(f"Created MAP-NEO Mini with {model.get_num_params():,} parameters")
    print(f"Config: {config.n_layers} layers, {config.dim} dim, {config.n_heads} heads")
    return model, config


if __name__ == "__main__":
    # Test model creation
    model, config = create_model()

    # Test forward pass
    batch_size, seq_len = 2, 512
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        logits = model(input_ids)

    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {logits.shape}")
    print("Model test passed!")
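    # Illustrative next-token loss on the random batch above (a minimal sketch; real
    # training would use the MAP-NEO tokenizer and a proper dataloader). For an
    # untrained model the loss should be close to ln(vocab_size).
    shift_logits = logits[:, :-1, :]
    targets = input_ids[:, 1:]
    loss = F.cross_entropy(
        shift_logits.reshape(-1, config.vocab_size),
        targets.reshape(-1),
    )
    print(f"Random-init cross-entropy: {loss.item():.2f} (ln(vocab_size) = {math.log(config.vocab_size):.2f})")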