File size: 8,830 Bytes
7f974df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""
model/model.py

SLLM — Small Language Model (decoder-only Transformer).

Full architecture:
    tokens  (B, T)
      -> Embedding       (vocab_size -> d_model)
      -> N x TransformerBlock   (attention + FFN)
      -> Final RMSNorm
      -> LM Head (Linear d_model -> vocab_size)   <- weight-TIED to embedding

Weight tying:
    The embedding matrix and the LM head output matrix share the same weights.
    - Halves memory for the embedding/output layers.
    - A standard practice since GPT-2 (Press & Wolf, 2016).

Weight initialization:
    - Embeddings: std=0.02  (GPT-2 convention)
    - Linear layers: std=0.02
    - Output projections (attn.o_proj, mlp.down): std = 0.02/sqrt(2*n_layers)
      - Scaled down per GPT-2/NanoGPT: at initialization, the residual
        stream grows as sqrt(n_layers), so we scale residual contributions down.

Forward:
    Returns (logits, loss): logits has shape (B, T, vocab_size); loss is the
    mean cross-entropy against `targets` when they are passed, otherwise None.
"""

import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

from model.config import ModelConfig
from model.norm   import RMSNorm
from model.block  import TransformerBlock


class SLLM(nn.Module):
    """
    Small decoder-only Transformer language model.

    Layout: token embedding -> ``config.n_layers`` x TransformerBlock ->
    final RMSNorm -> bias-free LM head whose weight is tied to the token
    embedding matrix.

    Args:
        config: hyper-parameters. This class reads ``vocab_size``,
                ``d_model``, ``n_layers`` and ``context_length``, and passes
                the whole object on to every TransformerBlock.
    """

    # Name suffixes (relative to ``self.blocks``) of the Linear layers that
    # write back into the residual stream; their init std is scaled down by
    # 1/sqrt(2 * n_layers) in ``_init_weights``.
    _RESIDUAL_SUFFIXES = (".attn.o_proj", ".mlp.down")

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        # ---- Token embedding --------------------------------------- #
        self.token_emb = nn.Embedding(config.vocab_size, config.d_model)

        # ---- Transformer blocks ------------------------------------ #
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layers)
        ])

        # ---- Final norm -------------------------------------------- #
        self.norm = RMSNorm(config.d_model)

        # ---- LM Head ----------------------------------------------- #
        # Linear: d_model -> vocab_size, no bias
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # ---- Weight tying ------------------------------------------ #
        # lm_head and token_emb now hold the very same Parameter object.
        self.lm_head.weight = self.token_emb.weight

        # ---- Gradient checkpointing flag --------------------------- #
        # Enabled via enable_gradient_checkpointing() to save VRAM
        self._gradient_checkpointing = False

        # ---- Initialize weights ------------------------------------ #
        # Residual projections must be tagged BEFORE the init pass, otherwise
        # _init_weights never sees `_is_residual` and the scaled init that the
        # module docstring promises silently never happens.
        self._tag_residual_projections()
        self.apply(self._init_weights)

    def _tag_residual_projections(self) -> int:
        """
        Set ``_is_residual = True`` on every Linear in ``self.blocks`` whose
        qualified name ends with one of ``_RESIDUAL_SUFFIXES``.

        Matching by name (rather than direct attribute access) means a block
        layout without these sub-modules simply tags nothing instead of
        raising during construction.

        Returns:
            Number of modules tagged.
        """
        tagged = 0
        for name, module in self.blocks.named_modules():
            if isinstance(module, nn.Linear) and name.endswith(self._RESIDUAL_SUFFIXES):
                module._is_residual = True
                tagged += 1
        return tagged

    def _init_weights(self, module: nn.Module) -> None:
        """
        Custom weight initialization, applied to every sub-module via apply().

        - nn.Linear   : Normal(0, 0.02); Normal(0, 0.02 / sqrt(2 * n_layers))
                        if tagged ``_is_residual``; bias (if any) zeroed.
        - nn.Embedding: Normal(0, 0.02).
        Other module types are left untouched.
        """
        if isinstance(module, nn.Linear):
            std = 0.02
            if getattr(module, "_is_residual", False):
                # Each block adds two residual contributions (attn + mlp), so
                # shrink them to keep the residual stream's scale in check.
                std /= math.sqrt(2 * self.config.n_layers)
            nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def _mark_residual_projections(self):
        """
        Re-tag the residual output projections and re-run the full weight init.

        ``__init__`` already does this. Calling it again re-initializes every
        Linear/Embedding weight in the model, discarding any trained values.
        """
        self._tag_residual_projections()
        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        targets: torch.Tensor = None,
    ):
        """
        Run the model over a batch of token IDs.

        Args:
            input_ids : (B, T) integer token IDs, T <= config.context_length
            targets   : optional (B, T) integer target IDs; may be a
                        non-contiguous view such as ``batch[:, 1:]``

        Returns:
            logits : (B, T, vocab_size)
            loss   : mean cross-entropy over all B*T positions if targets
                     are given, else None
        """
        _, T = input_ids.shape
        assert T <= self.config.context_length, (
            f"Sequence length {T} exceeds context_length {self.config.context_length}"
        )

        # ---- Embedding --------------------------------------------- #
        x = self.token_emb(input_ids)          # (B, T, d_model)

        # ---- Transformer blocks ------------------------------------ #
        for block in self.blocks:
            if self._gradient_checkpointing and self.training:
                # Recompute activations during backward to save VRAM
                # use_reentrant=False is the modern recommended API
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)

        # ---- Final norm -------------------------------------------- #
        x = self.norm(x)                       # (B, T, d_model)

        # ---- LM Head ----------------------------------------------- #
        logits = self.lm_head(x)               # (B, T, vocab_size)

        # ---- Loss -------------------------------------------------- #
        loss = None
        if targets is not None:
            # Flatten to (B*T, vocab_size) vs (B*T,). reshape() rather than
            # view(): targets sliced out of a larger tensor are often
            # non-contiguous, and view() raises on those.
            loss = nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
            )

        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: int = None,
    ) -> torch.Tensor:
        """
        Autoregressive text generation (greedy or sampled).

        Args:
            input_ids      : (B, T) prompt tokens
            max_new_tokens : number of tokens to generate
            temperature    : softmax temperature (1.0 = neutral, <1 = sharper).
                             0 selects greedy (argmax) decoding.
            top_k          : if set and > 0, sample from the top-k tokens only
                             (ignored in greedy mode)

        Returns:
            (B, T + max_new_tokens) token IDs

        Raises:
            ValueError: if temperature is negative.

        The model is put in eval mode while generating, and its previous
        train/eval mode is restored afterwards (even if an error occurs).
        """
        if temperature < 0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        was_training = self.training
        self.eval()
        try:
            window = self.config.context_length
            for _ in range(max_new_tokens):
                # Keep only the most recent `window` tokens as context;
                # slicing is a no-op when the sequence is already short enough.
                logits, _ = self(input_ids[:, -window:])
                last = logits[:, -1, :]                       # (B, vocab_size)

                if temperature == 0:
                    next_token = last.argmax(dim=-1, keepdim=True)      # (B, 1)
                else:
                    last = last / temperature
                    if top_k is not None and top_k > 0:
                        kth, _ = torch.topk(last, min(top_k, last.size(-1)))
                        last = last.masked_fill(last < kth[:, [-1]], float('-inf'))
                    probs = torch.softmax(last, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)  # (B, 1)

                input_ids = torch.cat([input_ids, next_token], dim=1)
        finally:
            # Don't leave a model that was mid-training stuck in eval mode.
            self.train(was_training)

        return input_ids

    def enable_gradient_checkpointing(self):
        """
        Enable gradient checkpointing to reduce activation memory.

        In training mode, each block's activations are recomputed during the
        backward pass instead of being stored, trading extra compute for lower
        peak VRAM. The actual savings depend on the model and batch shape.
        """
        self._gradient_checkpointing = True

    def count_params(self, non_embedding: bool = False) -> int:
        """
        Return the parameter count.

        The tied embedding / LM-head matrix is counted once, because
        ``nn.Module.parameters()`` de-duplicates shared Parameters.

        Args:
            non_embedding: if True, exclude the token-embedding parameters
                           (common in LLM reporting since embeddings scale
                           with vocab size and not model capacity)
        """
        total = sum(p.numel() for p in self.parameters())
        if non_embedding:
            total -= self.token_emb.weight.numel()
        return total


# ------------------------------------------------------------------ #
#  QUICK CHECK
# ------------------------------------------------------------------ #

if __name__ == "__main__":
    # Quick smoke test: parameter counts plus one forward pass per preset.
    from model.config import SLLM_100M, SLLM_150M

    for name, cfg in [("SLLM-100M", SLLM_100M), ("SLLM-150M", SLLM_150M)]:
        model = SLLM(cfg)

        total = model.count_params()
        non_emb = model.count_params(non_embedding=True)
        print(f"{name}")
        print(f"  total params           : {total/1e6:.1f}M")
        print(f"  non-embedding params   : {non_emb/1e6:.1f}M")
        print(f"  embedding params       : {(total-non_emb)/1e6:.1f}M")

        # Forward pass check
        B, T = 2, 64
        ids     = torch.randint(0, cfg.vocab_size, (B, T))
        targets = torch.randint(0, cfg.vocab_size, (B, T))

        # No backward pass here, so skip building the autograd graph; for a
        # 100M+ parameter model that graph is a large, pointless allocation.
        with torch.no_grad():
            logits, loss = model(ids, targets)
        print(f"  logits shape : {logits.shape}")
        print(f"  loss         : {loss.item():.4f}")
        # ln(V) is the cross-entropy of a uniform predictor over the vocab,
        # a useful reference point for the loss of a freshly initialized model.
        print(f"  ln(vocab)    : {math.log(cfg.vocab_size):.4f}")
        print()