"""
model.py
========
Complete SmolLM2-135M model implementation

Architecture:
- 30 transformer blocks
- 576 hidden dimensions
- 9 query heads, 3 KV heads (Grouped Query Attention)
- SwiGLU feed-forward network
- RoPE position embeddings
- RMSNorm layer normalization
- Weight tying (embeddings = lm_head)

Total parameters: 134,515,008 (~135M)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from components import RMSNorm, TransformerBlock
from transformers import AutoConfig


class SmolLM2Model(nn.Module):
    """
    SmolLM2-135M Language Model
    
    A decoder-only transformer based on Llama architecture with:
    - Grouped Query Attention (memory efficient)
    - SwiGLU FFN (improved expressiveness)
    - RoPE position embeddings (length extrapolation)
    - RMSNorm (faster than LayerNorm)
    
    Model configuration:
    - Layers: 30
    - Hidden size: 576
    - Attention heads: 9 (Q) / 3 (KV)
    - FFN size: 1536
    - Vocab size: 49,152
    - Context length: 2048
    """
    
    def __init__(self, config):
        """
        Initialize SmolLM2 model
        
        Args:
            config: Model configuration object with attributes:
                - vocab_size: Size of vocabulary (49152)
                - hidden_size: Model dimension (576)
                - num_hidden_layers: Number of transformer blocks (30)
                - tie_word_embeddings: Whether to tie input/output embeddings
                - rms_norm_eps: Epsilon for RMSNorm
        """
        super().__init__()
        self.config = config
        
        # Token embeddings
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        
        # Transformer blocks (30 layers)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_hidden_layers)
        ])
        
        # Final layer normalization
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        
        # Language modeling head (output projection)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        
        # Weight tying: share embeddings with output projection
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight
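            # lm_head.weight and embed_tokens.weight are now the same tensor,
            # so the 49,152 x 576 output matrix (~28.3M params) is not duplicated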
        
        print(f"βœ… Model initialized with {config.num_hidden_layers} transformer blocks")
        print(f"βœ… Weight tying: {config.tie_word_embeddings}")
    
    def forward(self, input_ids, attention_mask=None, position_ids=None):
        """
        Forward pass through the model
        
        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            attention_mask (torch.Tensor, optional): Attention mask
            position_ids (torch.Tensor, optional): Position indices
            
        Returns:
            torch.Tensor: Logits over vocabulary [batch, seq_len, vocab_size]
        """
        batch_size, seq_len = input_ids.shape
        
        # Create position IDs if not provided
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=input_ids.device)
        
        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)
        
        # Pass through all transformer blocks
        for layer in self.layers:
            hidden_states = layer(hidden_states, attention_mask, position_ids)
        
        # Final normalization
        hidden_states = self.norm(hidden_states)
        
        # Project to vocabulary
        logits = self.lm_head(hidden_states)
        
        return logits
    
    def generate(
        self,
        input_ids,
        max_new_tokens=50,
        temperature=1.0,
        top_p=0.9,
        top_k=None,
        do_sample=True
    ):
        """
        Generate text autoregressively
        
        Supports multiple sampling strategies:
        - Greedy decoding (temperature=0)
        - Temperature sampling
        - Nucleus (top-p) sampling
        - Top-k sampling
        
        Args:
            input_ids (torch.Tensor): Input token IDs [batch, seq_len]
            max_new_tokens (int): Number of tokens to generate
            temperature (float): Sampling temperature (0 = greedy, >1 = more random)
            top_p (float): Nucleus sampling threshold (0-1)
            top_k (int, optional): Top-k sampling threshold
            do_sample (bool): Whether to sample or use greedy decoding
            
        Returns:
            torch.Tensor: Generated token IDs [batch, seq_len + max_new_tokens]
        """
        self.eval()
        
        for _ in range(max_new_tokens):
            with torch.no_grad():
                # Forward pass
                logits = self(input_ids)
                
                # Get next token logits
                next_token_logits = logits[:, -1, :]
                
                # Apply temperature scaling (temperature == 0 skips the
                # division and falls through to the greedy branch below)
                if temperature > 0:
                    next_token_logits = next_token_logits / temperature
                
                # Greedy decoding
                if not do_sample or temperature == 0:
                    next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                else:
                    # Top-k sampling
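                    # e.g. top_k=2 on logits [4.0, 3.0, 1.0] keeps the top two
                    # and masks the rest: [4.0, 3.0, -inf]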
                    if top_k is not None:
                        top_k = min(top_k, next_token_logits.size(-1))
                        indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
                        next_token_logits[indices_to_remove] = float('-inf')
                    
                    # Nucleus (top-p) sampling
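                    # e.g. top_p=0.9 on probs [0.5, 0.3, 0.15, 0.05]: cumulative
                    # sums are [0.5, 0.8, 0.95, 1.0]; after the one-position
                    # shift below, the first three tokens survive and the last
                    # is masked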
                    if top_p < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                        
                        # Remove tokens with cumulative probability above threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Keep at least one token
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = False
                        
                        # Scatter to original indexing
                        indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                        next_token_logits[indices_to_remove] = float('-inf')
                    
                    # Sample from distribution
                    probs = F.softmax(next_token_logits, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                
                # Append to sequence
                input_ids = torch.cat([input_ids, next_token], dim=1)
        
        return input_ids
    
    def get_num_params(self, non_embedding=False):
        """
        Count model parameters
        
        Args:
            non_embedding (bool): If True, exclude embedding parameters
            
        Returns:
            int: Number of parameters
        """
        n_params = sum(p.numel() for p in self.parameters())
        
        if non_embedding:
            n_params -= self.embed_tokens.weight.numel()
            # When weights are tied, lm_head shares the embedding tensor and
            # parameters() already counted it once above, so subtract the
            # lm_head matrix only when the weights are untied
            if not self.config.tie_word_embeddings:
                n_params -= self.lm_head.weight.numel()
        
        return n_params


def initialize_weights(model, config):
    """
    Initialize model weights using GPT-style initialization
    
    Strategy:
    - All weights: Normal(0, 0.02)
    - Residual projections: Scaled by 1/sqrt(2 * num_layers)
    - RMSNorm: Initialized to 1.0 (PyTorch default)
    
    The residual scaling prevents variance explosion in deep networks.
    
    Args:
        model (SmolLM2Model): Model to initialize
        config: Model configuration
    """
    std = 0.02
    num_layers = config.num_hidden_layers
    # Residual scaling factor: 1/sqrt(2 * num_layers)
    residual_scaling = 1.0 / math.sqrt(2 * num_layers)
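    # e.g. with 30 layers: 1 / sqrt(60) ≈ 0.1291, so residual-output weights
    # get std ≈ 0.02 * 0.1291 ≈ 0.00258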
    
    print(f"Initializing weights with std={std}, residual_scaling={residual_scaling:.6f}")
    
    # Initialize embeddings
    nn.init.normal_(model.embed_tokens.weight, mean=0.0, std=std)
    
    # Initialize each transformer block
    for layer in model.layers:
        # Attention projections
        nn.init.normal_(layer.self_attn.q_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.k_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.self_attn.v_proj.weight, mean=0.0, std=std)
        # Output projection with residual scaling
        nn.init.normal_(layer.self_attn.o_proj.weight, mean=0.0, std=std * residual_scaling)
        
        # FFN projections
        nn.init.normal_(layer.mlp.gate_proj.weight, mean=0.0, std=std)
        nn.init.normal_(layer.mlp.up_proj.weight, mean=0.0, std=std)
        # Output projection with residual scaling
        nn.init.normal_(layer.mlp.down_proj.weight, mean=0.0, std=std * residual_scaling)
    
    # RMSNorm weights are initialized to 1.0 by default (PyTorch)
    
    print(f"βœ… Initialized {sum(1 for _ in model.parameters())} weight tensors")


def load_pretrained_weights(our_model, official_model, device='cuda'):
    """
    Load weights from HuggingFace official model
    
    Maps weight names from official model to our implementation:
    - model.embed_tokens.weight -> embed_tokens.weight
    - model.layers.{i}.* -> layers[i].*
    - model.norm.weight -> norm.weight
    - lm_head.weight (tied with embeddings)
    
    Args:
        our_model (SmolLM2Model): Our model to load weights into
        official_model: HuggingFace official model
        device (str): Device to load weights to
        
    Returns:
        int: Number of weight tensors loaded
    """
    print("=" * 70)
    print("LOADING PRETRAINED WEIGHTS")
    print("=" * 70)
    
    official_state = official_model.state_dict()
    loaded_count = 0
    
    # 1. Load token embeddings
    our_model.embed_tokens.weight.data = official_state['model.embed_tokens.weight'].clone().to(device)
    loaded_count += 1
    
    # 2. Load all transformer blocks
    num_layers = our_model.config.num_hidden_layers
    for layer_idx in range(num_layers):
        prefix = f'model.layers.{layer_idx}'
        
        # Layer norms
        our_model.layers[layer_idx].input_layernorm.weight.data = \
            official_state[f'{prefix}.input_layernorm.weight'].clone().to(device)
        our_model.layers[layer_idx].post_attention_layernorm.weight.data = \
            official_state[f'{prefix}.post_attention_layernorm.weight'].clone().to(device)
        
        # Attention projections
        our_model.layers[layer_idx].self_attn.q_proj.weight.data = \
            official_state[f'{prefix}.self_attn.q_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.k_proj.weight.data = \
            official_state[f'{prefix}.self_attn.k_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.v_proj.weight.data = \
            official_state[f'{prefix}.self_attn.v_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].self_attn.o_proj.weight.data = \
            official_state[f'{prefix}.self_attn.o_proj.weight'].clone().to(device)
        
        # FFN projections
        our_model.layers[layer_idx].mlp.gate_proj.weight.data = \
            official_state[f'{prefix}.mlp.gate_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.up_proj.weight.data = \
            official_state[f'{prefix}.mlp.up_proj.weight'].clone().to(device)
        our_model.layers[layer_idx].mlp.down_proj.weight.data = \
            official_state[f'{prefix}.mlp.down_proj.weight'].clone().to(device)
        
        loaded_count += 9  # 2 norms + 4 attn + 3 ffn
    
    # 3. Load final norm
    our_model.norm.weight.data = official_state['model.norm.weight'].clone().to(device)
    loaded_count += 1
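
    # 4. lm_head: tied to embed_tokens in __init__, so no separate copy is
    #    needed (holds for SmolLM2-135M, where tie_word_embeddings is True)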
    
    print(f"\nβœ… Loaded {num_layers} transformer blocks")
    print(f"βœ… Total loaded: {loaded_count} weight tensors")
    print("=" * 70)
    
    return loaded_count
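

def verify_against_official(our_model, official_model, config, atol=1e-4):
    """
    Hypothetical helper (not part of the original file): a minimal sanity
    check that our forward pass reproduces the official model's logits after
    load_pretrained_weights. Assumes both models sit on the same device and
    that official_model is a HuggingFace causal LM whose forward output
    exposes `.logits`.
    """
    device = next(our_model.parameters()).device
    test_ids = torch.randint(0, config.vocab_size, (1, 8), device=device)
    our_model.eval()
    official_model.eval()
    with torch.no_grad():
        ours = our_model(test_ids)
        theirs = official_model(test_ids).logits
    max_diff = (ours - theirs).abs().max().item()
    print(f"Max logit difference: {max_diff:.6e}")
    return max_diff < atol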


if __name__ == "__main__":
    """Test model creation and parameter count"""    
    # Load config
    config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    
    # Create model
    model = SmolLM2Model(config)
    
    # Count parameters
    total_params = model.get_num_params()
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Expected: 134,515,008")
    print(f"Match: {total_params == 134_515_008}")
    
    # Test forward pass
    test_input = torch.randint(0, config.vocab_size, (1, 10))
    output = model(test_input)
    print(f"\nForward pass test:")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Output shape: {output.shape}")
    print(f"  Expected: torch.Size([1, 10, 49152])")
    
    # Test generation
    generated = model.generate(test_input, max_new_tokens=5)
    print(f"\nGeneration test:")
    print(f"  Generated shape: {generated.shape}")
    print(f"  Expected: torch.Size([1, 15])")