# PebbleLM-117M — src/model/transformer.py
# (Hugging Face commit 27871e7, "Add model architecture code", by nameissakthi)
"""
Full Transformer model for SLM.
Implements the mandatory prefill/decode API for Qualcomm deployment.
"""
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union
from dataclasses import dataclass
from .config import SLMConfig
from .normalization import RMSNorm
from .decoder import DecoderBlock
from .attention import create_causal_mask
from .kv_cache import KVCache
@dataclass
class SLMOutput:
    """Output container for SLM forward passes.

    Covers both cached inference (``kv_cache``) and training
    (``loss``, populated only when labels are supplied to ``forward``).
    """
    # Next-token prediction scores: [batch, seq, vocab_size]
    logits: torch.Tensor
    # Updated cache when use_cache=True, else None
    kv_cache: Optional[KVCache] = None
    # Final (post-norm) hidden states: [batch, seq, hidden_size]
    hidden_states: Optional[torch.Tensor] = None
    # Cross-entropy LM loss; None unless labels were provided
    loss: Optional[torch.Tensor] = None
class SLMModel(nn.Module):
    """Decoder-only transformer backbone (no LM head).

    Pipeline: token embedding -> ``num_layers`` DecoderBlocks -> final RMSNorm.
    """

    def __init__(self, config: SLMConfig):
        """Build the embedding table, decoder stack, and output norm.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        # Token embedding table: vocab_size x hidden_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        # Stacked decoder blocks, each tagged with its layer index
        self.layers = nn.ModuleList(
            DecoderBlock(config, layer_idx=idx)
            for idx in range(config.num_layers)
        )
        # Normalization applied after the last decoder block
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[KVCache]]:
        """Run the decoder stack.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]; derived from the
                cache length when omitted
            attention_mask: Causal mask; built automatically when omitted
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache

        Returns:
            Tuple of (hidden_states, kv_cache)
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Number of tokens already held in the cache: nonzero only during
        # decode; zero for prefill or when no cache is passed.
        past_len = (
            kv_cache.seq_len
            if kv_cache is not None and kv_cache.seq_len > 0
            else 0
        )

        if position_ids is None:
            # Prefill yields 0..seq_len-1; decode continues from past_len.
            positions = torch.arange(past_len, past_len + seq_len, device=device)
            position_ids = positions.unsqueeze(0).expand(batch_size, -1)

        if attention_mask is None:
            # Queries attend over cached tokens plus the current chunk.
            attention_mask = create_causal_mask(
                seq_len=seq_len,
                kv_seq_len=past_len + seq_len,
                dtype=self.embed_tokens.weight.dtype,
                device=device,
            )

        x = self.embed_tokens(input_ids)
        for block in self.layers:
            x, kv_cache = block(
                hidden_states=x,
                position_ids=position_ids,
                attention_mask=attention_mask,
                kv_cache=kv_cache,
                use_cache=use_cache,
            )
        return self.norm(x), kv_cache
class SLMForCausalLM(nn.Module):
    """SLM with language modeling head.

    This is the full model with:
    - Transformer backbone
    - LM head (tied with embeddings)
    - Prefill/Decode API for Qualcomm deployment
    """

    def __init__(self, config: SLMConfig):
        """Initialize causal LM.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        # Transformer backbone
        self.model = SLMModel(config)
        # LM head projecting hidden states to vocabulary logits
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Tie weights if configured: head and embedding share one tensor
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        # Initialize weights. Note: with tied weights, the shared tensor is
        # re-sampled once per visit — harmless, final init is still N(0, std).
        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        """Initialize Linear/Embedding weights to N(0, 0.02); zero biases."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        kv_cache: Optional[KVCache] = None,
        use_cache: bool = False,
        labels: Optional[torch.Tensor] = None,
    ) -> SLMOutput:
        """Forward pass for causal LM.

        Args:
            input_ids: Token IDs [batch, seq]
            position_ids: Position indices [batch, seq]
            attention_mask: Causal mask
            kv_cache: Optional KV cache
            use_cache: Whether to use/update cache
            labels: Optional target token IDs [batch, seq] for loss
                computation (positions with value -100 are ignored)

        Returns:
            SLMOutput with logits and, when labels are given, a `loss`
            attribute holding the shifted cross-entropy loss
        """
        # Get hidden states from transformer
        hidden_states, kv_cache = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            kv_cache=kv_cache,
            use_cache=use_cache,
        )
        # Compute logits
        logits = self.lm_head(hidden_states)
        output = SLMOutput(
            logits=logits,
            kv_cache=kv_cache,
            hidden_states=hidden_states,
        )
        if labels is not None:
            # FIX: `labels` was previously accepted but silently ignored.
            # Set as a plain attribute on the (non-frozen, non-slots)
            # dataclass instance so this works regardless of declared fields.
            output.loss = self._compute_loss(logits, labels)
        return output

    @staticmethod
    def _compute_loss(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Shifted next-token cross-entropy: position t predicts token t+1.

        Uses ignore_index=-100 (the common HF masking convention —
        NOTE(review): confirm callers mask padding with -100).
        """
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        return nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
        )

    # =========================================================================
    # MANDATORY KV CACHE API (from architecture.txt)
    # =========================================================================
    def prefill(
        self,
        input_ids: torch.Tensor,
        kv_cache: Optional[KVCache] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Prefill: Process full prompt and populate KV cache.

        This is Graph 1 for Qualcomm deployment.

        Args:
            input_ids: Token IDs [batch, seq]
            kv_cache: Empty or existing KV cache; a fresh cache sized to
                max_position_embeddings is allocated when omitted

        Returns:
            Tuple of (logits [batch, seq, vocab], populated_kv_cache)
        """
        batch_size = input_ids.shape[0]
        # Create empty cache if not provided
        if kv_cache is None:
            kv_cache = KVCache.create(
                num_layers=self.config.num_layers,
                batch_size=batch_size,
                num_heads=self.config.num_heads,
                max_seq_len=self.config.max_position_embeddings,
                head_dim=self.config.head_dim,
                dtype=self.model.embed_tokens.weight.dtype,
                device=input_ids.device,
            )
        # Forward pass with cache
        output = self.forward(
            input_ids=input_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )
        return output.logits, output.kv_cache

    def decode(
        self,
        input_id: torch.Tensor,
        kv_cache: KVCache,
        position: Optional[int] = None,
    ) -> Tuple[torch.Tensor, KVCache]:
        """Decode: Generate single token using KV cache.

        This is Graph 2 for Qualcomm deployment.

        Args:
            input_id: Single token ID [batch, 1]
            kv_cache: Populated KV cache from prefill or previous decode
            position: Position index (defaults to cache.seq_len)

        Returns:
            Tuple of (logits [batch, 1, vocab], updated_kv_cache)
        """
        batch_size = input_id.shape[0]
        # Get position from cache if not provided
        if position is None:
            position = kv_cache.seq_len
        # Create position IDs broadcast across the batch: [batch, 1]
        position_ids = torch.tensor(
            [[position]], device=input_id.device
        ).expand(batch_size, -1)
        # Forward pass with cache
        output = self.forward(
            input_ids=input_id,
            position_ids=position_ids,
            kv_cache=kv_cache,
            use_cache=True,
        )
        return output.logits, output.kv_cache

    def create_empty_cache(
        self,
        batch_size: int = 1,
        device: Optional[torch.device] = None,
    ) -> KVCache:
        """Create an empty KV cache for inference.

        Args:
            batch_size: Batch size
            device: Device for cache tensors (defaults to the model's device)

        Returns:
            Empty KVCache ready for prefill
        """
        if device is None:
            device = self.model.embed_tokens.weight.device
        return KVCache.create(
            num_layers=self.config.num_layers,
            batch_size=batch_size,
            num_heads=self.config.num_heads,
            max_seq_len=self.config.max_position_embeddings,
            head_dim=self.config.head_dim,
            dtype=self.model.embed_tokens.weight.dtype,
            device=device,
        )

    @property
    def num_parameters(self) -> int:
        """Count total trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    @property
    def device(self) -> torch.device:
        """Get model device (taken from the embedding weight)."""
        return self.model.embed_tokens.weight.device