|
|
""" |
|
|
Full Transformer model for SLM. |
|
|
Implements the mandatory prefill/decode API for Qualcomm deployment. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional, Tuple, Union |
|
|
from dataclasses import dataclass |
|
|
|
|
|
from .config import SLMConfig |
|
|
from .normalization import RMSNorm |
|
|
from .decoder import DecoderBlock |
|
|
from .attention import create_causal_mask |
|
|
from .kv_cache import KVCache |
|
|
|
|
|
|
|
|
@dataclass
class SLMOutput:
    """Output from SLM forward pass.

    Bundles the LM logits with the (optionally updated) KV cache and the
    final-layer hidden states so callers can pick what they need.
    """

    # Vocabulary logits, shape [batch, seq, vocab_size].
    logits: torch.Tensor
    # Updated KV cache when use_cache=True was requested; otherwise None.
    kv_cache: Optional[KVCache] = None
    # Post-norm hidden states from the backbone, shape [batch, seq, hidden].
    hidden_states: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
class SLMModel(nn.Module):
    """Core transformer backbone (no LM head).

    Stack layout:
    - Token embedding
    - N decoder blocks
    - Final RMSNorm
    """

    def __init__(self, config: SLMConfig):
        """Build the decoder stack from ``config``.

        Args:
            config: Model configuration (vocab/hidden sizes, layer count,
                norm epsilon).
        """
        super().__init__()
        self.config = config

        # Token embedding table: [vocab_size, hidden_size].
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # Decoder blocks, each tagged with its layer index.
        self.layers = nn.ModuleList(
            DecoderBlock(config, layer_idx=idx)
            for idx in range(config.num_layers)
        )

        # Normalization applied once, after the last block.
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run the decoder stack.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]; derived from the
                cache offset when omitted
            attention_mask: Causal mask; built automatically when omitted
            kv_cache: Optional KV cache threaded through every layer
            use_cache: Whether to use/update the cache

        Returns:
            Tuple of (hidden_states, kv_cache)
        """
        batch, new_len = input_ids.shape
        device = input_ids.device

        # Tokens already present in the cache (0 when there is no cache).
        past_len = kv_cache.seq_len if kv_cache is not None else 0

        if position_ids is None:
            # New tokens occupy positions [past_len, past_len + new_len).
            position_ids = (
                torch.arange(past_len, past_len + new_len, device=device)
                .unsqueeze(0)
                .expand(batch, -1)
            )

        if attention_mask is None:
            # Queries attend over cached tokens plus the new ones.
            attention_mask = create_causal_mask(
                seq_len=new_len,
                kv_seq_len=past_len + new_len,
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )

        hidden = self.embed_tokens(input_ids)

        # Thread the (possibly updated) cache through every block in order.
        for block in self.layers:
            hidden, kv_cache = block(
                hidden_states=hidden,
                position_ids=position_ids,
                attention_mask=attention_mask,
                kv_cache=kv_cache,
                use_cache=use_cache,
            )

        return self.norm(hidden), kv_cache
|
|
|
|
|
|
|
|
class SLMForCausalLM(nn.Module):
    """SLM with language modeling head.

    This is the full model with:
    - Transformer backbone (``SLMModel``)
    - LM head (optionally weight-tied with the token embeddings)
    - Prefill/Decode API for Qualcomm deployment
    """

    def __init__(self, config: SLMConfig):
        """Initialize causal LM.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config

        # Decoder backbone: embedding + decoder blocks + final norm.
        self.model = SLMModel(config)

        # Projects hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: lm_head shares the embedding matrix.
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

        # Safe with tied weights: the embedding and lm_head inits both
        # write normal(0, 0.02) into the same shared tensor.
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialize weights: normal(0, 0.02) for Linear/Embedding, zero bias."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
        labels: Optional[torch.Tensor] = None,
    ) -> SLMOutput:
        """Forward pass for causal LM.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache
            labels: Accepted for trainer-API compatibility but currently
                ignored — no loss is computed here and ``SLMOutput`` has no
                loss field.  TODO(review): implement shifted cross-entropy
                loss or remove the parameter.

        Returns:
            SLMOutput with logits, kv_cache and final hidden states
        """
        hidden_states, kv_cache = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )

        logits = self.lm_head(hidden_states)

        return SLMOutput(
            logits=logits,
            kv_cache=kv_cache,
            hidden_states=hidden_states,
        )

    def prefill(
        self,
        input_ids: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Prefill: Process full prompt and populate KV cache.

        This is Graph 1 for Qualcomm deployment.

        Args:
            input_ids: Token IDs [batch, seq]
            kv_cache: Empty or existing KV cache; a fresh one is created
                when omitted

        Returns:
            Tuple of (logits [batch, seq, vocab], populated_kv_cache)
        """
        batch_size = input_ids.shape[0]

        # Reuse the canonical cache constructor instead of duplicating
        # its argument list here (keeps the two code paths in sync).
        if kv_cache is None:
            kv_cache = self.create_empty_cache(
                batch_size=batch_size,
                device=input_ids.device,
            )

        output = self.forward(
            input_ids=input_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )

        return output.logits, output.kv_cache

    def decode(
        self,
        input_id: torch.Tensor,
        kv_cache: KVCache,
        position: Optional[int] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Decode: Generate single token using KV cache.

        This is Graph 2 for Qualcomm deployment.

        Args:
            input_id: Single token ID [batch, 1]
            kv_cache: Populated KV cache from prefill or previous decode
            position: Position index (defaults to cache.seq_len)

        Returns:
            Tuple of (logits [batch, 1, vocab], updated_kv_cache)
        """
        batch_size = input_id.shape[0]

        # The next token sits right after everything already cached.
        if position is None:
            position = kv_cache.seq_len

        position_ids = torch.tensor(
            [[position]], device=input_id.device
        ).expand(batch_size, -1)

        output = self.forward(
            input_ids=input_id,
            position_ids=position_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )

        return output.logits, output.kv_cache

    def create_empty_cache(
        self,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> KVCache:
        """Create an empty KV cache for inference.

        Args:
            batch_size: Batch size
            device: Device for cache tensors (defaults to the model's device)

        Returns:
            Empty KVCache ready for prefill
        """
        if device is None:
            device = self.model.embed_tokens.weight.device

        return KVCache.create(
            num_layers=self.config.num_layers,
            batch_size=batch_size,
            num_heads=self.config.num_heads,
            max_seq_len=self.config.max_position_embeddings,
            head_dim=self.config.head_dim,
            dtype=self.model.embed_tokens.weight.dtype,
            device=device,
        )

    @property
    def num_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def device(self) -> torch.device:
        """Get model device."""
        return self.model.embed_tokens.weight.device
|
|
|