AxionLab-official commited on
Commit
6e14a07
·
verified ·
1 Parent(s): efbcbb7

Create model.py

Browse files
Files changed (1) hide show
  1. model.py +273 -0
model.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Nano Reasoning Model (NRM) - Main Architecture
3
+
4
+ ARCHITECTURE DESIGN PHILOSOPHY:
5
+ ================================
6
+ This model maximizes reasoning ability per parameter through several key innovations:
7
+
8
+ 1. SHARED LAYERS: The middle layers are shared (looped through multiple times).
9
+ This creates a form of "iterative refinement" - the model processes information
10
+ multiple passes, similar to how recurrent networks process sequences but applied
11
+ to depth instead. This is inspired by Universal Transformers and ALBERT.
12
+
13
+ WHY IT HELPS REASONING: Reasoning often requires iterative refinement of
14
+ intermediate representations. Shared layers let the model "think more" without
15
+ more parameters.
16
+
17
+ 2. THINKING TOKENS: Special <THINK> and </THINK> tokens create a "scratchpad"
18
+ where the model can show intermediate reasoning steps. The model is trained to
19
+ use <STEP> tokens for each logical step.
20
+
21
+ WHY IT HELPS: Decomposing complex problems into steps is THE key capability
22
+ for reasoning. Even large models benefit from chain-of-thought prompting.
23
+
24
+ 3. WEIGHT TYING: Input and output embeddings share the same weight matrix.
25
+ This halves the embedding parameter count and creates a natural link between
26
+ token understanding and token generation.
27
+
28
+ WHY IT HELPS CPU: Fewer parameters = faster forward/backward passes.
29
+
30
+ 4. LOW-RANK PROJECTIONS: All attention and MLP projections use LoRA-style
31
+ factored matrices, cutting parameter count by ~8x in linear layers.
32
+
33
+ 5. GROUPED QUERY ATTENTION: KV heads are shared across query heads,
34
+ reducing KV projection parameters and memory.
35
+
36
+ PARAMETER BUDGET (~10M):
37
+ Embedding: 2048 * 256 = 524K (shared with output head)
38
+ Per unique layer: ~200K
39
+ 4 unique + 2 shared (run 2x) = 6 effective layers
40
+ Total: ~2.1M (layers) + 524K (embed) ≈ 2.6M unique params
41
+ Effective computation: ~3.1M param equivalent
42
+ """
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+ from typing import Optional, Dict
48
+ from .components import TransformerBlock, RMSNorm
49
+
50
+
51
+ class NanoReasoningModel(nn.Module):
52
+ def __init__(self, config: dict):
53
+ super().__init__()
54
+ self.config = config
55
+
56
+ d_model = config['d_model']
57
+ n_heads = config['n_heads']
58
+ n_layers = config['n_layers']
59
+ n_shared = config.get('n_shared_layers', 2)
60
+ d_ff = config['d_ff']
61
+ vocab_size = config['vocab_size']
62
+ max_seq_len = config['max_seq_len']
63
+ dropout = config.get('dropout', 0.05)
64
+ rank = config.get('lora_rank', 16)
65
+ self.use_thinking = config.get('use_thinking_tokens', True)
66
+ self.n_thinking_steps = config.get('n_thinking_steps', 2)
67
+ n_kv_heads = config.get('n_kv_heads', n_heads // 2)
68
+
69
+ # Token embeddings (will be tied with output head)
70
+ self.token_embedding = nn.Embedding(vocab_size, d_model)
71
+ self.embedding_dropout = nn.Dropout(dropout)
72
+
73
+ # Entry layers (unique)
74
+ n_unique = n_layers - n_shared
75
+ self.entry_layers = nn.ModuleList([
76
+ TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
77
+ for _ in range(n_unique // 2)
78
+ ])
79
+
80
+ # Shared layers (looped)
81
+ self.shared_layers = nn.ModuleList([
82
+ TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
83
+ for _ in range(n_shared)
84
+ ])
85
+
86
+ # Exit layers (unique)
87
+ self.exit_layers = nn.ModuleList([
88
+ TransformerBlock(d_model, n_heads, d_ff, rank, dropout, max_seq_len, n_kv_heads)
89
+ for _ in range(n_unique - n_unique // 2)
90
+ ])
91
+
92
+ # Final norm
93
+ self.final_norm = RMSNorm(d_model)
94
+
95
+ # Output head (tied with embeddings)
96
+ self.output_head = nn.Linear(d_model, vocab_size, bias=False)
97
+
98
+ if config.get('weight_tying', True):
99
+ self.output_head.weight = self.token_embedding.weight
100
+
101
+ # Thinking step gate: learned scalar for blending thinking iterations
102
+ if self.use_thinking:
103
+ self.think_gate = nn.Parameter(torch.tensor(0.5))
104
+
105
+ # Initialize weights
106
+ self.apply(self._init_weights)
107
+
108
+ # Count parameters
109
+ self._count_parameters()
110
+
111
+ def _init_weights(self, module: nn.Module):
112
+ """Initialize weights with scaled initialization for stability."""
113
+ if isinstance(module, nn.Linear):
114
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
115
+ if module.bias is not None:
116
+ torch.nn.init.zeros_(module.bias)
117
+ elif isinstance(module, nn.Embedding):
118
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
119
+
120
+ def _count_parameters(self):
121
+ """Count and report parameters."""
122
+ total = sum(p.numel() for p in self.parameters())
123
+ trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
124
+
125
+ # Count unique parameters (shared layers counted once)
126
+ unique = sum(p.numel() for p in self.parameters())
127
+
128
+ self.total_params = total
129
+ self.trainable_params = trainable
130
+ print(f"\n{'='*50}")
131
+ print(f"NRM Model Configuration:")
132
+ print(f" d_model: {self.config['d_model']}")
133
+ print(f" n_heads: {self.config['n_heads']}")
134
+ print(f" n_layers: {self.config['n_layers']} "
135
+ f"({len(self.entry_layers)} entry + {len(self.shared_layers)} shared + {len(self.exit_layers)} exit)")
136
+ print(f" d_ff: {self.config['d_ff']}")
137
+ print(f" vocab_size: {self.config['vocab_size']}")
138
+ print(f" LoRA rank: {self.config.get('lora_rank', 16)}")
139
+ print(f" Thinking: {'enabled' if self.use_thinking else 'disabled'}")
140
+ print(f" Total parameters: {total:,}")
141
+ print(f" Trainable parameters: {trainable:,}")
142
+ print(f"{'='*50}\n")
143
+
144
+ def forward(self, input_ids: torch.Tensor,
145
+ attention_mask: Optional[torch.Tensor] = None,
146
+ labels: Optional[torch.Tensor] = None,
147
+ n_think_loops: int = 1) -> Dict[str, torch.Tensor]:
148
+ """
149
+ Forward pass with optional thinking loops.
150
+
151
+ n_think_loops: How many times to loop through shared layers.
152
+ During reasoning, we increase this to give the model more "thinking time".
153
+ """
154
+ B, T = input_ids.shape
155
+
156
+ # Embeddings
157
+ x = self.token_embedding(input_ids)
158
+ x = self.embedding_dropout(x)
159
+
160
+ # Padding mask
161
+ pad_mask = None
162
+ if attention_mask is not None:
163
+ pad_mask = (attention_mask == 0) # True where padded
164
+
165
+ # Entry layers
166
+ for layer in self.entry_layers:
167
+ x = layer(x, pad_mask)
168
+
169
+ # Shared layers with thinking loops
170
+ actual_loops = max(1, n_think_loops)
171
+ if self.use_thinking and actual_loops > 1:
172
+ # Store the "pre-think" state
173
+ x_original = x
174
+ for loop in range(actual_loops):
175
+ for layer in self.shared_layers:
176
+ x = layer(x, pad_mask)
177
+ if loop < actual_loops - 1:
178
+ # Blend with original (residual thinking)
179
+ gate = torch.sigmoid(self.think_gate)
180
+ x = gate * x + (1 - gate) * x_original
181
+ else:
182
+ for layer in self.shared_layers:
183
+ x = layer(x, pad_mask)
184
+
185
+ # Exit layers
186
+ for layer in self.exit_layers:
187
+ x = layer(x, pad_mask)
188
+
189
+ # Output
190
+ x = self.final_norm(x)
191
+ logits = self.output_head(x)
192
+
193
+ result = {"logits": logits}
194
+
195
+ if labels is not None:
196
+ # Shift for autoregressive loss
197
+ shift_logits = logits[:, :-1, :].contiguous()
198
+ shift_labels = labels[:, 1:].contiguous()
199
+
200
+ loss = F.cross_entropy(
201
+ shift_logits.view(-1, shift_logits.size(-1)),
202
+ shift_labels.view(-1),
203
+ ignore_index=0, # PAD token
204
+ label_smoothing=0.05 # Slight smoothing for better generalization
205
+ )
206
+ result["loss"] = loss
207
+
208
+ return result
209
+
210
+ @torch.no_grad()
211
+ def generate(self, input_ids: torch.Tensor, max_new_tokens: int = 100,
212
+ temperature: float = 0.7, top_k: int = 50, top_p: float = 0.9,
213
+ n_think_loops: int = 1, eos_token_id: int = 2) -> torch.Tensor:
214
+ """
215
+ Autoregressive generation with temperature, top-k, and top-p sampling.
216
+
217
+ Uses nucleus (top-p) sampling for diverse but coherent generation.
218
+ """
219
+ self.eval()
220
+ generated = input_ids.clone()
221
+
222
+ for _ in range(max_new_tokens):
223
+ # Truncate to max_seq_len
224
+ context = generated[:, -self.config['max_seq_len']:]
225
+
226
+ outputs = self.forward(context, n_think_loops=n_think_loops)
227
+ logits = outputs["logits"][:, -1, :] / max(temperature, 1e-5)
228
+
229
+ # Top-k filtering
230
+ if top_k > 0:
231
+ top_k_val = min(top_k, logits.size(-1))
232
+ indices_to_remove = logits < torch.topk(logits, top_k_val)[0][..., -1, None]
233
+ logits[indices_to_remove] = float('-inf')
234
+
235
+ # Top-p (nucleus) filtering
236
+ if top_p < 1.0:
237
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
238
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
239
+ sorted_indices_to_remove = cumulative_probs > top_p
240
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
241
+ sorted_indices_to_remove[..., 0] = 0
242
+ indices_to_remove = sorted_indices_to_remove.scatter(
243
+ 1, sorted_indices, sorted_indices_to_remove)
244
+ logits[indices_to_remove] = float('-inf')
245
+
246
+ probs = F.softmax(logits, dim=-1)
247
+ next_token = torch.multinomial(probs, num_samples=1)
248
+ generated = torch.cat([generated, next_token], dim=1)
249
+
250
+ if next_token.item() == eos_token_id:
251
+ break
252
+
253
+ return generated
254
+
255
+ def save(self, path: str):
256
+ """Save model state dict and config."""
257
+ import os, json
258
+ os.makedirs(path, exist_ok=True)
259
+ torch.save(self.state_dict(), os.path.join(path, "model.pt"))
260
+ with open(os.path.join(path, "config.json"), 'w') as f:
261
+ json.dump(self.config, f, indent=2)
262
+ print(f"Model saved to {path}")
263
+
264
+ @classmethod
265
+ def load(cls, path: str, device: str = 'cpu') -> 'NanoReasoningModel':
266
+ """Load model from saved state."""
267
+ import os, json
268
+ with open(os.path.join(path, "config.json"), 'r') as f:
269
+ config = json.load(f)
270
+ model = cls(config)
271
+ model.load_state_dict(torch.load(os.path.join(path, "model.pt"),
272
+ map_location=device))
273
+ return model