""" Full Transformer model for SLM. Implements the mandatory prefill/decode API for Qualcomm deployment. """ import torch import torch.nn as nn from typing import Optional, Tuple, Union from dataclasses import dataclass from .config import SLMConfig from .normalization import RMSNorm from .decoder import DecoderBlock from .attention import create_causal_mask from .kv_cache import KVCache @dataclass class SLMOutput: """Output from SLM forward pass.""" logits: torch.Tensor # [batch, seq, vocab_size] kv_cache: Optional[KVCache] = None hidden_states: Optional[torch.Tensor] = None class SLMModel(nn.Module): """Core transformer model (without LM head). This is the decoder stack: - Token embedding - N decoder blocks - Final RMSNorm """ def __init__(self, config: SLMConfig): """Initialize transformer model. Args: config: Model configuration """ super().__init__() self.config = config # Token embeddings self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) # Decoder layers self.layers = nn.ModuleList([ DecoderBlock(config, layer_idx=i) for i in range(config.num_layers) ]) # Final normalization self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, kv_cache: Optional[KVCache] = None, use_cache: bool = False, ) -> Tuple[torch.Tensor, Optional[KVCache]]: """Forward pass through transformer. Args: input_ids: Token IDs [batch, seq] position_ids: Position indices [batch, seq] attention_mask: Causal mask kv_cache: Optional KV cache use_cache: Whether to use/update cache Returns: Tuple of (hidden_states, kv_cache) """ batch_size, seq_len = input_ids.shape # Create position IDs if not provided if position_ids is None: if kv_cache is not None and kv_cache.seq_len > 0: # For decode: position is the current cache length position_ids = torch.arange( kv_cache.seq_len, kv_cache.seq_len + seq_len, device=input_ids.device ).unsqueeze(0).expand(batch_size, -1) else: # For prefill: positions are 0..seq_len-1 position_ids = torch.arange( seq_len, device=input_ids.device ).unsqueeze(0).expand(batch_size, -1) # Create attention mask if not provided if attention_mask is None: kv_seq_len = seq_len if kv_cache is not None and kv_cache.seq_len > 0: kv_seq_len = kv_cache.seq_len + seq_len attention_mask = create_causal_mask( seq_len=seq_len, kv_seq_len=kv_seq_len, dtype=self.embed_tokens.weight.dtype, device=input_ids.device, ) # Token embeddings hidden_states = self.embed_tokens(input_ids) # Pass through decoder layers for layer in self.layers: hidden_states, kv_cache = layer( hidden_states=hidden_states, position_ids=position_ids, attention_mask=attention_mask, kv_cache=kv_cache, use_cache=use_cache, ) # Final normalization hidden_states = self.norm(hidden_states) return hidden_states, kv_cache class SLMForCausalLM(nn.Module): """SLM with language modeling head. This is the full model with: - Transformer backbone - LM head (tied with embeddings) - Prefill/Decode API for Qualcomm deployment """ def __init__(self, config: SLMConfig): """Initialize causal LM. Args: config: Model configuration """ super().__init__() self.config = config # Transformer backbone self.model = SLMModel(config) # LM head self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # Tie weights if configured if config.tie_word_embeddings: self.lm_head.weight = self.model.embed_tokens.weight # Initialize weights self.apply(self._init_weights) def _init_weights(self, module: nn.Module): """Initialize model weights.""" std = 0.02 if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) def forward( self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, kv_cache: Optional[KVCache] = None, use_cache: bool = False, labels: Optional[torch.Tensor] = None, ) -> SLMOutput: """Forward pass for causal LM. Args: input_ids: Token IDs [batch, seq] position_ids: Position indices [batch, seq] attention_mask: Causal mask kv_cache: Optional KV cache use_cache: Whether to use/update cache labels: Optional labels for loss computation Returns: SLMOutput with logits and optional loss """ # Get hidden states from transformer hidden_states, kv_cache = self.model( input_ids=input_ids, position_ids=position_ids, attention_mask=attention_mask, kv_cache=kv_cache, use_cache=use_cache, ) # Compute logits logits = self.lm_head(hidden_states) return SLMOutput( logits=logits, kv_cache=kv_cache, hidden_states=hidden_states, ) # ========================================================================= # MANDATORY KV CACHE API (from architecture.txt) # ========================================================================= def prefill( self, input_ids: torch.Tensor, kv_cache: Optional[KVCache] = None, ) -> Tuple[torch.Tensor, KVCache]: """Prefill: Process full prompt and populate KV cache. This is Graph 1 for Qualcomm deployment. Args: input_ids: Token IDs [batch, seq] kv_cache: Empty or existing KV cache Returns: Tuple of (logits [batch, seq, vocab], populated_kv_cache) """ batch_size = input_ids.shape[0] # Create empty cache if not provided if kv_cache is None: kv_cache = KVCache.create( num_layers=self.config.num_layers, batch_size=batch_size, num_heads=self.config.num_heads, max_seq_len=self.config.max_position_embeddings, head_dim=self.config.head_dim, dtype=self.model.embed_tokens.weight.dtype, device=input_ids.device, ) # Forward pass with cache output = self.forward( input_ids=input_ids, kv_cache=kv_cache, use_cache=True, ) return output.logits, output.kv_cache def decode( self, input_id: torch.Tensor, kv_cache: KVCache, position: Optional[int] = None, ) -> Tuple[torch.Tensor, KVCache]: """Decode: Generate single token using KV cache. This is Graph 2 for Qualcomm deployment. Args: input_id: Single token ID [batch, 1] kv_cache: Populated KV cache from prefill or previous decode position: Position index (defaults to cache.seq_len) Returns: Tuple of (logits [batch, 1, vocab], updated_kv_cache) """ batch_size = input_id.shape[0] # Get position from cache if not provided if position is None: position = kv_cache.seq_len # Create position IDs position_ids = torch.tensor( [[position]], device=input_id.device ).expand(batch_size, -1) # Forward pass with cache output = self.forward( input_ids=input_id, position_ids=position_ids, kv_cache=kv_cache, use_cache=True, ) return output.logits, output.kv_cache def create_empty_cache( self, batch_size: int = 1, device: torch.device = None, ) -> KVCache: """Create an empty KV cache for inference. Args: batch_size: Batch size device: Device for cache tensors Returns: Empty KVCache ready for prefill """ if device is None: device = self.model.embed_tokens.weight.device return KVCache.create( num_layers=self.config.num_layers, batch_size=batch_size, num_heads=self.config.num_heads, max_seq_len=self.config.max_position_embeddings, head_dim=self.config.head_dim, dtype=self.model.embed_tokens.weight.dtype, device=device, ) @property def num_parameters(self) -> int: """Count total trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad) @property def device(self) -> torch.device: """Get model device.""" return self.model.embed_tokens.weight.device