abi96062 committed
Commit 49d2fa1 · verified · 1 Parent(s): 9b57eb5

Create model.py

Files changed (1)
  1. model.py +354 -0
model.py ADDED
@@ -0,0 +1,354 @@
+ """
+ model.py
+ ========
+ Complete SmolLM2-135M model implementation
+
+ Architecture:
+ - 30 transformer blocks
+ - 576 hidden dimensions
+ - 9 query heads, 3 KV heads (Grouped Query Attention)
+ - SwiGLU feed-forward network
+ - RoPE position embeddings
+ - RMSNorm layer normalization
+ - Weight tying (embeddings = lm_head)
+
+ Total parameters: 134,515,008 (~135M)
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from components import RMSNorm, TransformerBlock
+ from transformers import AutoConfig
+
+
+ class SmolLM2Model(nn.Module):
+     """
+     SmolLM2-135M Language Model
+
+     A decoder-only transformer based on Llama architecture with:
+     - Grouped Query Attention (memory efficient)
+     - SwiGLU FFN (improved expressiveness)
+     - RoPE position embeddings (length extrapolation)
+     - RMSNorm (faster than LayerNorm)
+
+     Model configuration:
+     - Layers: 30
+     - Hidden size: 576
+     - Attention heads: 9 (Q) / 3 (KV)
+     - FFN size: 1536
+     - Vocab size: 49,152
+     - Context length: 2048
+     """
+
+     def __init__(self, config):
+         """
+         Initialize SmolLM2 model
+
+         Args:
+             config: Model configuration object with attributes:
+                 - vocab_size: Size of vocabulary (49152)
+                 - hidden_size: Model dimension (576)
+                 - num_hidden_layers: Number of transformer blocks (30)
+                 - tie_word_embeddings: Whether to tie input/output embeddings
+                 - rms_norm_eps: Epsilon for RMSNorm
+         """
+         super().__init__()
+         self.config = config
+
+         # Token embeddings
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+
+         # Transformer blocks (30 layers)
+         self.layers = nn.ModuleList([
+             TransformerBlock(config) for _ in range(config.num_hidden_layers)
+         ])
+
+         # Final layer normalization
+         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+         # Language modeling head (output projection)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         # Weight tying: share embeddings with output projection
+         if config.tie_word_embeddings:
+             self.lm_head.weight = self.embed_tokens.weight
+
+         print(f"✅ Model initialized with {config.num_hidden_layers} transformer blocks")
+         print(f"✅ Weight tying: {config.tie_word_embeddings}")
+
+     def forward(self, input_ids, attention_mask=None, position_ids=None):
+         """
+         Forward pass through the model
+
+         Args:
+             input_ids (torch.Tensor): Input token IDs [batch, seq_len]
+             attention_mask (torch.Tensor, optional): Attention mask
+             position_ids (torch.Tensor, optional): Position indices
+
+         Returns:
+             torch.Tensor: Logits over vocabulary [batch, seq_len, vocab_size]
+         """
+         batch_size, seq_len = input_ids.shape
+
+         # Create position IDs if not provided
+         if position_ids is None:
+             position_ids = torch.arange(seq_len, device=input_ids.device)
+
+         # Embed tokens
+         hidden_states = self.embed_tokens(input_ids)
+
+         # Pass through all transformer blocks
+         for layer in self.layers:
+             hidden_states = layer(hidden_states, attention_mask, position_ids)
+
+         # Final normalization
+         hidden_states = self.norm(hidden_states)
+
+         # Project to vocabulary
+         logits = self.lm_head(hidden_states)
+
+         return logits
+
+     def generate(
+         self,
+         input_ids,
+         max_new_tokens=50,
+         temperature=1.0,
+         top_p=0.9,
+         top_k=None,
+         do_sample=True
+     ):
+         """
+         Generate text autoregressively
+
+         Supports multiple sampling strategies:
+         - Greedy decoding (temperature=0)
+         - Temperature sampling
+         - Nucleus (top-p) sampling
+         - Top-k sampling
+
+         Args:
+             input_ids (torch.Tensor): Input token IDs [batch, seq_len]
+             max_new_tokens (int): Number of tokens to generate
+             temperature (float): Sampling temperature (0 = greedy, >1 = more random)
+             top_p (float): Nucleus sampling threshold (0-1)
+             top_k (int, optional): Top-k sampling threshold
+             do_sample (bool): Whether to sample or use greedy decoding
+
+         Returns:
+             torch.Tensor: Generated token IDs [batch, seq_len + max_new_tokens]
+         """
+         self.eval()
+
+         for _ in range(max_new_tokens):
+             with torch.no_grad():
+                 # Forward pass
+                 logits = self(input_ids)
+
+                 # Get next token logits
+                 next_token_logits = logits[:, -1, :]
+
+                 # Apply temperature
+                 if temperature > 0:
+                     next_token_logits = next_token_logits / temperature
+
+                 # Greedy decoding
+                 if not do_sample or temperature == 0:
+                     next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+                 else:
+                     # Top-k sampling
+                     if top_k is not None:
+                         top_k = min(top_k, next_token_logits.size(-1))
+                         indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None]
+                         next_token_logits[indices_to_remove] = float('-inf')
+
+                     # Nucleus (top-p) sampling
+                     if top_p < 1.0:
+                         sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
+                         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+                         # Remove tokens with cumulative probability above threshold
+                         sorted_indices_to_remove = cumulative_probs > top_p
+                         # Keep at least one token
+                         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+                         sorted_indices_to_remove[..., 0] = False
+
+                         # Scatter to original indexing
+                         indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+                         next_token_logits[indices_to_remove] = float('-inf')
+
+                     # Sample from distribution
+                     probs = F.softmax(next_token_logits, dim=-1)
+                     next_token = torch.multinomial(probs, num_samples=1)
+
+                 # Append to sequence
+                 input_ids = torch.cat([input_ids, next_token], dim=1)
+
+         return input_ids
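+
+     # Example call (a rough sketch; it assumes the matching tokenizer, e.g.
+     # AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M") from transformers):
+     #   ids = tokenizer("Once upon a time", return_tensors="pt").input_ids
+     #   out = model.generate(ids, max_new_tokens=30, temperature=0.7, top_p=0.9)
+     #   print(tokenizer.decode(out[0]))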
+
+     def get_num_params(self, non_embedding=False):
+         """
+         Count model parameters
+
+         Args:
+             non_embedding (bool): If True, exclude embedding parameters
+
+         Returns:
+             int: Number of parameters
+         """
+         n_params = sum(p.numel() for p in self.parameters())
+
+         if non_embedding:
+             n_params -= self.embed_tokens.weight.numel()
+             # If weights are tied, don't double-count
+             if not self.config.tie_word_embeddings:
+                 n_params -= self.lm_head.weight.numel()
+
+         return n_params
+
+
+ def initialize_weights(model, config):
+     """
+     Initialize model weights using GPT-style initialization
+
+     Strategy:
+     - All weights: Normal(0, 0.02)
+     - Residual projections: Scaled by 1/sqrt(2 * num_layers)
+     - RMSNorm: Initialized to 1.0 (PyTorch default)
+
+     The residual scaling prevents variance explosion in deep networks.
+
+     Args:
+         model (SmolLM2Model): Model to initialize
+         config: Model configuration
+     """
+     std = 0.02
+     num_layers = config.num_hidden_layers
+     # Residual scaling factor: 1/sqrt(2 * num_layers)
+     residual_scaling = 1.0 / math.sqrt(2 * num_layers)
+
+     print(f"Initializing weights with std={std}, residual_scaling={residual_scaling:.6f}")
+
+     # Initialize embeddings
+     nn.init.normal_(model.embed_tokens.weight, mean=0.0, std=std)
+
+     # Initialize each transformer block
+     for layer in model.layers:
+         # Attention projections
+         nn.init.normal_(layer.self_attn.q_proj.weight, mean=0.0, std=std)
+         nn.init.normal_(layer.self_attn.k_proj.weight, mean=0.0, std=std)
+         nn.init.normal_(layer.self_attn.v_proj.weight, mean=0.0, std=std)
+         # Output projection with residual scaling
+         nn.init.normal_(layer.self_attn.o_proj.weight, mean=0.0, std=std * residual_scaling)
+
+         # FFN projections
+         nn.init.normal_(layer.mlp.gate_proj.weight, mean=0.0, std=std)
+         nn.init.normal_(layer.mlp.up_proj.weight, mean=0.0, std=std)
+         # Output projection with residual scaling
+         nn.init.normal_(layer.mlp.down_proj.weight, mean=0.0, std=std * residual_scaling)
+
+     # RMSNorm weights are initialized to 1.0 by default (PyTorch)
+
+     print(f"✅ Initialized {sum(1 for _ in model.parameters())} weight tensors")
+
+
+ def load_pretrained_weights(our_model, official_model, device='cuda'):
+     """
+     Load weights from HuggingFace official model
+
+     Maps weight names from official model to our implementation:
+     - model.embed_tokens.weight -> embed_tokens.weight
+     - model.layers.{i}.* -> layers[i].*
+     - model.norm.weight -> norm.weight
+     - lm_head.weight (tied with embeddings)
+
+     Args:
+         our_model (SmolLM2Model): Our model to load weights into
+         official_model: HuggingFace official model
+         device (str): Device to load weights to
+
+     Returns:
+         int: Number of weight tensors loaded
+     """
+     print("=" * 70)
+     print("LOADING PRETRAINED WEIGHTS")
+     print("=" * 70)
+
+     official_state = official_model.state_dict()
+     loaded_count = 0
+
+     # 1. Load token embeddings
+     our_model.embed_tokens.weight.data = official_state['model.embed_tokens.weight'].clone().to(device)
+     loaded_count += 1
+
+     # 2. Load all transformer blocks
+     num_layers = our_model.config.num_hidden_layers
+     for layer_idx in range(num_layers):
+         prefix = f'model.layers.{layer_idx}'
+
+         # Layer norms
+         our_model.layers[layer_idx].input_layernorm.weight.data = \
+             official_state[f'{prefix}.input_layernorm.weight'].clone().to(device)
+         our_model.layers[layer_idx].post_attention_layernorm.weight.data = \
+             official_state[f'{prefix}.post_attention_layernorm.weight'].clone().to(device)
+
+         # Attention projections
+         our_model.layers[layer_idx].self_attn.q_proj.weight.data = \
+             official_state[f'{prefix}.self_attn.q_proj.weight'].clone().to(device)
+         our_model.layers[layer_idx].self_attn.k_proj.weight.data = \
+             official_state[f'{prefix}.self_attn.k_proj.weight'].clone().to(device)
+         our_model.layers[layer_idx].self_attn.v_proj.weight.data = \
+             official_state[f'{prefix}.self_attn.v_proj.weight'].clone().to(device)
+         our_model.layers[layer_idx].self_attn.o_proj.weight.data = \
+             official_state[f'{prefix}.self_attn.o_proj.weight'].clone().to(device)
+
+         # FFN projections
+         our_model.layers[layer_idx].mlp.gate_proj.weight.data = \
+             official_state[f'{prefix}.mlp.gate_proj.weight'].clone().to(device)
+         our_model.layers[layer_idx].mlp.up_proj.weight.data = \
+             official_state[f'{prefix}.mlp.up_proj.weight'].clone().to(device)
+         our_model.layers[layer_idx].mlp.down_proj.weight.data = \
+             official_state[f'{prefix}.mlp.down_proj.weight'].clone().to(device)
+
+         loaded_count += 9 # 2 norms + 4 attn + 3 ffn
+
+     # 3. Load final norm
+     our_model.norm.weight.data = official_state['model.norm.weight'].clone().to(device)
+     loaded_count += 1
+
+     print(f"\n✅ Loaded {num_layers} transformer blocks")
+     print(f"✅ Total loaded: {loaded_count} weight tensors")
+     print("=" * 70)
+
+     return loaded_count
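+
+ # Example usage (a rough sketch; it assumes the official checkpoint is loaded with
+ # transformers.AutoModelForCausalLM and that both models live on the same device):
+ #   official = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+ #   ours = SmolLM2Model(official.config)
+ #   load_pretrained_weights(ours, official, device='cpu')
+ #   ids = torch.randint(0, official.config.vocab_size, (1, 8))
+ #   print(torch.allclose(ours(ids), official(ids).logits, atol=1e-4))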
+
+
+ if __name__ == "__main__":
+     """Test model creation and parameter count"""
+     # Load config
+     config = AutoConfig.from_pretrained("HuggingFaceTB/SmolLM2-135M")
+
+     # Create model
+     model = SmolLM2Model(config)
+
+     # Count parameters
+     total_params = model.get_num_params()
+     print(f"\nTotal parameters: {total_params:,}")
+     print(f"Expected: 134,515,008")
+     print(f"Match: {total_params == 134_515_008}")
+
+     # Test forward pass
+     test_input = torch.randint(0, config.vocab_size, (1, 10))
+     output = model(test_input)
+     print(f"\nForward pass test:")
+     print(f" Input shape: {test_input.shape}")
+     print(f" Output shape: {output.shape}")
+     print(f" Expected: torch.Size([1, 10, 49152])")
+
+     # Test generation
+     generated = model.generate(test_input, max_new_tokens=5)
+     print(f"\nGeneration test:")
+     print(f" Generated shape: {generated.shape}")
+     print(f" Expected: torch.Size([1, 15])")