| | """ |
| | Full Transformer model for SLM. |
| | Implements the mandatory prefill/decode API for Qualcomm deployment. |
| | """ |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from typing import Optional, Tuple, Union |
| | from dataclasses import dataclass |
| |
|
| | from .config import SLMConfig |
| | from .normalization import RMSNorm |
| | from .decoder import DecoderBlock |
| | from .attention import create_causal_mask |
| | from .kv_cache import KVCache |
| |
|
| |
|
@dataclass
class SLMOutput:
    """Output from SLM forward pass.

    A lightweight result container returned by SLMForCausalLM.forward;
    fields are plain attributes so callers can grab what they need cheaply
    on every decode step.
    """

    # Next-token prediction scores from the LM head, [batch, seq, vocab_size].
    logits: torch.Tensor
    # Updated KV cache when the pass ran with use_cache=True; otherwise
    # whatever was passed in (possibly None).
    kv_cache: Optional[KVCache] = None
    # Final hidden states from the transformer backbone (post-RMSNorm),
    # [batch, seq, hidden_size].
    hidden_states: Optional[torch.Tensor] = None
| |
|
| |
|
class SLMModel(nn.Module):
    """Core transformer model (without LM head).

    The decoder stack, in order:
    - Token embedding
    - N decoder blocks
    - Final RMSNorm
    """

    def __init__(self, config: SLMConfig):
        """Build the embedding, decoder layers, and final norm.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [DecoderBlock(config, layer_idx=idx) for idx in range(config.num_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Forward pass through the transformer stack.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]; derived from the
                cache length when omitted.
            attention_mask: Causal mask; built automatically when omitted.
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache

        Returns:
            Tuple of (hidden_states, kv_cache)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Number of tokens already held in the cache (0 when no cache /
        # empty cache); positions and the mask both continue from here.
        past_len = 0
        if kv_cache is not None:
            past_len = kv_cache.seq_len

        if position_ids is None:
            position_ids = (
                torch.arange(past_len, past_len + seq_len, device=device)
                .unsqueeze(0)
                .expand(batch_size, -1)
            )

        if attention_mask is None:
            attention_mask = create_causal_mask(
                seq_len=seq_len,
                kv_seq_len=past_len + seq_len,
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )

        hidden_states = self.embed_tokens(input_ids)

        # Each block threads the (possibly updated) cache to the next.
        for block in self.layers:
            hidden_states, kv_cache = block(
                hidden_states=hidden_states,
                position_ids=position_ids,
                attention_mask=attention_mask,
                kv_cache=kv_cache,
                use_cache=use_cache,
            )

        return self.norm(hidden_states), kv_cache
| |
|
| |
|
class SLMForCausalLM(nn.Module):
    """SLM with language modeling head.

    This is the full model with:
    - Transformer backbone
    - LM head (tied with embeddings when config.tie_word_embeddings)
    - Prefill/Decode API for Qualcomm deployment
    """

    def __init__(self, config: SLMConfig):
        """Initialize causal LM.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config

        self.model = SLMModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: the output projection shares the embedding matrix.
        # _init_weights below touches the shared tensor twice (once via the
        # Embedding branch, once via the Linear branch) but both draws use
        # the same scheme and the tie survives: both attributes reference
        # the same Parameter object.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialize weights: normal(0, 0.02) for Linear/Embedding, zero biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
        labels: Optional[torch.Tensor] = None,
    ) -> SLMOutput:
        """Forward pass for causal LM.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache
            labels: Accepted for trainer-API compatibility but currently
                IGNORED — no loss is computed here and SLMOutput carries no
                loss field. Compute the loss externally from `logits`.

        Returns:
            SLMOutput with logits, kv_cache, and hidden_states.
        """
        hidden_states, kv_cache = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        logits = self.lm_head(hidden_states)

        return SLMOutput(
            logits=logits,
            kv_cache=kv_cache,
            hidden_states=hidden_states,
        )

    def prefill(
        self,
        input_ids: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Prefill: Process full prompt and populate KV cache.

        This is Graph 1 for Qualcomm deployment.

        Args:
            input_ids: Token IDs [batch, seq]
            kv_cache: Empty or existing KV cache; a fresh one is allocated
                when None.

        Returns:
            Tuple of (logits [batch, seq, vocab], populated_kv_cache)
        """
        batch_size = input_ids.shape[0]

        # Delegate to create_empty_cache instead of duplicating the
        # KVCache.create call so both paths stay in sync.
        if kv_cache is None:
            kv_cache = self.create_empty_cache(
                batch_size=batch_size,
                device=input_ids.device,
            )

        output = self.forward(
            input_ids=input_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )

        return output.logits, output.kv_cache

    def decode(
        self,
        input_id: torch.Tensor,
        kv_cache: KVCache,
        position: Optional[int] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Decode: Generate single token using KV cache.

        This is Graph 2 for Qualcomm deployment.

        Args:
            input_id: Single token ID [batch, 1]
            kv_cache: Populated KV cache from prefill or previous decode
            position: Position index (defaults to kv_cache.seq_len, i.e.
                the slot immediately after the cached tokens)

        Returns:
            Tuple of (logits [batch, 1, vocab], updated_kv_cache)
        """
        batch_size = input_id.shape[0]

        if position is None:
            position = kv_cache.seq_len

        # One position per batch row; expand shares storage (no copy).
        position_ids = torch.tensor(
            [[position]], device=input_id.device
        ).expand(batch_size, -1)

        output = self.forward(
            input_ids=input_id,
            position_ids=position_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )

        return output.logits, output.kv_cache

    def create_empty_cache(
        self,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> KVCache:
        """Create an empty KV cache for inference.

        Args:
            batch_size: Batch size
            device: Device for cache tensors (defaults to the model's device)

        Returns:
            Empty KVCache ready for prefill
        """
        if device is None:
            device = self.model.embed_tokens.weight.device

        # NOTE(review): cache is sized with config.num_heads — if the
        # attention layers use grouped-query KV heads, this should be the
        # KV-head count instead; confirm against DecoderBlock/attention.
        return KVCache.create(
            num_layers=self.config.num_layers,
            batch_size=batch_size,
            num_heads=self.config.num_heads,
            max_seq_len=self.config.max_position_embeddings,
            head_dim=self.config.head_dim,
            dtype=self.model.embed_tokens.weight.dtype,
            device=device,
        )

    @property
    def num_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def device(self) -> torch.device:
        """Get model device (taken from the embedding weight)."""
        return self.model.embed_tokens.weight.device
| |
|