"""Full dense decoder-only transformer model for SAGE."""

from __future__ import annotations

import math
from typing import Optional

import torch
from torch import nn

from model.block import TransformerBlock
from model.config import ModelConfig
from model.rope import build_rope_cache
from model.rmsnorm import RMSNorm


class SageTransformer(nn.Module):
    """A dense Llama-style decoder-only transformer."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
        self.norm = RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
        cos, sin = build_rope_cache(
            seq_len=config.context_length,
            head_dim=config.head_dim,
            base_frequency=config.rope_base_frequency,
            scaling_factor=config.rope_scaling_factor,
        )
        self.register_buffer("rope_cos", cos, persistent=False)
        self.register_buffer("rope_sin", sin, persistent=False)
        self._reset_parameters()

    def _reset_parameters(self) -> None:
        """Initialize the embedding and all linear weights from scaled normal distributions."""
        embed_std = 1.0 / math.sqrt(self.config.d_model)
        nn.init.normal_(self.embed_tokens.weight, mean=0.0, std=embed_std)
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            # A tied lm_head shares the embedding weight initialized above;
            # re-initializing it here would overwrite the embedding.
            if module is self.lm_head and self.config.tie_word_embeddings:
                continue
            std = self.config.initializer_range
            # Linear layers whose output width is d_model get a smaller std,
            # scaled down by sqrt(2 * num_layers).
            if module.out_features == self.config.d_model:
                std = std / math.sqrt(2 * self.config.num_layers)
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[list[tuple[torch.Tensor, torch.Tensor]]] = None,
    ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor]]]:
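        """Return logits of shape (batch, seq_len, vocab_size) and the updated KV cache."""
        _, seq_len = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        past_key_values = past_key_values or [None] * self.config.num_layers
        # Offset of the first new token: the number of positions already cached.
        # This assumes cached keys keep the sequence axis at dim -2.
        start = 0
        if past_key_values[0] is not None:
            start = past_key_values[0][0].size(-2)
        # Slicing past the end of the cache would silently return fewer rows.
        if start + seq_len > self.rope_cos.size(0):
            raise ValueError(
                f"Positions {start}..{start + seq_len} exceed the RoPE cache length {self.rope_cos.size(0)}."
            )
        cos = self.rope_cos[start : start + seq_len].to(hidden_states.device)
        sin = self.rope_sin[start : start + seq_len].to(hidden_states.device)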
        presents: list[tuple[torch.Tensor, torch.Tensor]] = []
        for layer, past in zip(self.layers, past_key_values):
            hidden_states, present = layer(hidden_states, cos, sin, past_key_value=past)
            presents.append(present)
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)
        return logits, presents
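

# Illustrative sketch, not part of the original module.
#
# The two helpers below use hypothetical names and exist only to make two rules
# in SageTransformer easy to check by hand. They are a minimal sketch under the
# assumption that they follow the same conventions as the class above:
#   * sketch_linear_init_std mirrors the std chosen in _reset_parameters.
#   * sketch_rope_window mirrors the row range forward() slices from the RoPE
#     cache during incremental decoding. Its explicit bounds check is a choice
#     of this sketch.


def sketch_linear_init_std(
    out_features: int, d_model: int, num_layers: int, initializer_range: float
) -> float:
    """Return the std _reset_parameters would pick for one nn.Linear weight."""
    std = initializer_range
    # Layers that project to d_model are scaled down by sqrt(2 * num_layers).
    if out_features == d_model:
        std = std / (2 * num_layers) ** 0.5
    return std


def sketch_rope_window(past_len: int, seq_len: int, cache_len: int) -> tuple[int, int]:
    """Return the (start, stop) rows of the RoPE cache used for the new tokens."""
    stop = past_len + seq_len
    if stop > cache_len:
        raise ValueError(f"Positions {past_len}..{stop} exceed cache length {cache_len}.")
    return past_len, stop


# Usage, with made-up sizes:
#   sketch_linear_init_std(512, 512, 8, 0.02)   is about 0.005 (0.02 / sqrt(16))
#   sketch_linear_init_std(2048, 512, 8, 0.02)  stays at 0.02
#   sketch_rope_window(5, 1, 16)                gives (5, 6): one new token after
#                                               five cached positions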