"""Full dense decoder-only transformer model for SAGE.""" from __future__ import annotations import math from typing import Optional import torch from torch import nn from model.block import TransformerBlock from model.config import ModelConfig from model.rope import build_rope_cache from model.rmsnorm import RMSNorm class SageTransformer(nn.Module): """A dense Llama-style decoder-only transformer.""" def __init__(self, config: ModelConfig): super().__init__() self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model) self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)]) self.norm = RMSNorm(config.d_model, eps=config.rms_norm_eps) self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) if config.tie_word_embeddings: self.lm_head.weight = self.embed_tokens.weight cos, sin = build_rope_cache( seq_len=config.context_length, head_dim=config.head_dim, base_frequency=config.rope_base_frequency, scaling_factor=config.rope_scaling_factor, ) self.register_buffer("rope_cos", cos, persistent=False) self.register_buffer("rope_sin", sin, persistent=False) self._reset_parameters() def _reset_parameters(self) -> None: """Apply scaled initialization to the model.""" embed_std = 1.0 / math.sqrt(self.config.d_model) nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=embed_std) for module in self.modules(): if not isinstance(module, nn.Linear): continue std = self.config.initializer_range if module is self.lm_head and self.config.tie_word_embeddings: continue if module.out_features == self.config.d_model: std = std / math.sqrt(2 * self.config.num_layers) nn.init.normal_(module.weight, mean=0.0, std=std) def forward( self, input_ids: torch.Tensor, past_key_values: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None, ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]: """Return logits and the updated KV cache.""" batch_size, seq_len = input_ids.shape hidden_states = self.embed_tokens(input_ids) past_key_values = past_key_values or [None] * self.config.num_layers start = 0 if past_key_values[0] is not None: start = past_key_values[0][0].size(-2) cos = self.rope_cos[start : start + seq_len].to(hidden_states.device) sin = self.rope_sin[start : start + seq_len].to(hidden_states.device) presents: list[tuple[torch.Tensor, torch.Tensor]] = [] for layer, past in zip(self.layers, past_key_values): hidden_states, present = layer(hidden_states, cos, sin, past_key_value=past) presents.append(present) hidden_states = self.norm(hidden_states) logits = self.lm_head(hidden_states) return logits, presents