# sage/model/model.py
# sage002's picture
# feat: rewrite SAGE 1B architecture and replace legacy repo contents
# ef18673 verified
"""Full dense decoder-only transformer model for SAGE."""
from __future__ import annotations
import math
from typing import Optional
import torch
from torch import nn
from model.block import TransformerBlock
from model.config import ModelConfig
from model.rope import build_rope_cache
from model.rmsnorm import RMSNorm
class SageTransformer(nn.Module):
    """A dense Llama-style decoder-only transformer.

    The model embeds token ids, runs them through ``config.num_layers``
    ``TransformerBlock`` modules, applies a final ``RMSNorm`` and projects the
    result to vocabulary logits with a bias-free linear head. Rotary position
    tables are built once at construction time and sliced per call.
    """

    def __init__(self, config: ModelConfig):
        """Build all sub-modules, the rotary tables, and initialize weights.

        Args:
            config: Model hyper-parameters. The fields read here are
                ``vocab_size``, ``d_model``, ``num_layers``, ``rms_norm_eps``,
                ``tie_word_embeddings``, ``context_length``, ``head_dim``,
                ``rope_base_frequency`` and ``rope_scaling_factor``.
        """
        super().__init__()
        self.config = config

        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList(
            [TransformerBlock(config) for _ in range(config.num_layers)]
        )
        self.norm = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            # Share one parameter between the input embedding and output head.
            self.lm_head.weight = self.embed_tokens.weight

        cos, sin = build_rope_cache(
            seq_len=config.context_length,
            head_dim=config.head_dim,
            base_frequency=config.rope_base_frequency,
            scaling_factor=config.rope_scaling_factor,
        )
        # Non-persistent: the tables are rebuilt from the config on
        # construction, so they are kept out of the state dict.
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)

        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Apply scaled normal initialization to the embedding and linears.

        * The token embedding is drawn from ``N(0, 1 / sqrt(d_model))``.
        * Every ``nn.Linear`` weight is drawn from
          ``N(0, config.initializer_range)``; linears whose ``out_features``
          equals ``d_model`` use a std further divided by
          ``sqrt(2 * num_layers)``.
        * When embeddings are tied, ``lm_head`` is skipped so the embedding
          initialization above is not overwritten.

        Only ``weight`` tensors are touched; biases keep their defaults.
        """
        cfg = self.config
        nn.init.normal_(
            self.embed_tokens.weight, mean=0.0, std=1.0 / math.sqrt(cfg.d_model)
        )

        # Precompute the down-scaled std once; it is loop-invariant.
        scaled_std = cfg.initializer_range / math.sqrt(2 * cfg.num_layers)

        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            if module is self.lm_head and cfg.tie_word_embeddings:
                # lm_head.weight *is* the embedding weight; leave it alone.
                continue
            # NOTE(review): ``out_features == d_model`` is used as the marker
            # for the down-scaled init. It matches every linear with that
            # output width, not only residual output projections -- confirm
            # against TransformerBlock that this is the intended set.
            if module.out_features == cfg.d_model:
                std = scaled_std
            else:
                std = cfg.initializer_range
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[
            list[Optional[tuple[torch.Tensor, torch.Tensor]]]
        ] = None,
    ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
        """Return logits and the updated KV cache.

        Args:
            input_ids: 2-D tensor of token ids; the second dimension is the
                number of new positions processed in this call.
            past_key_values: Optional per-layer cache, one entry per layer.
                ``None`` or an empty sequence means "no cache". The number of
                already-cached positions is read from the second-to-last
                dimension of the first layer's first tensor (assumes the
                cached key is laid out ``(..., seq, head_dim)`` -- TODO
                confirm against TransformerBlock).

        Returns:
            A ``(logits, presents)`` pair where ``presents`` holds the value
            each layer returned as its updated cache entry, in layer order.

        Raises:
            ValueError: If ``past_key_values`` does not have exactly one entry
                per layer, or if the requested positions run past the end of
                the precomputed rotary tables.
        """
        _, seq_len = input_ids.shape
        num_layers = len(self.layers)

        if not past_key_values:
            past_key_values = [None] * num_layers
        elif len(past_key_values) != num_layers:
            # zip() below would otherwise silently skip the trailing layers.
            raise ValueError(
                f"past_key_values has {len(past_key_values)} entries but the "
                f"model has {num_layers} layers"
            )

        # Guard the index so a zero-layer model does not raise IndexError.
        first_past = past_key_values[0] if past_key_values else None
        start = 0 if first_past is None else first_past[0].size(-2)
        end = start + seq_len

        max_positions = self.rope_cos.size(0)
        if end > max_positions:
            # Slicing past the table would return a short (or empty) tensor
            # that may broadcast silently or fail obscurely inside attention.
            raise ValueError(
                f"positions {start}..{end} exceed the rotary cache length "
                f"{max_positions}"
            )

        hidden_states = self.embed_tokens(input_ids)
        cos = self.rope_cos[start:end].to(hidden_states.device)
        sin = self.rope_sin[start:end].to(hidden_states.device)

        presents: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer, past in zip(self.layers, past_key_values):
            hidden_states, present = layer(
                hidden_states, cos, sin, past_key_value=past
            )
            presents.append(present)

        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits, presents