fariasultana committed
Commit 0108694 · verified · 1 Parent(s): 51efa41

feat: Add capabilities/reasoning.py

Files changed (1)
  1. capabilities/reasoning.py +432 -0
capabilities/reasoning.py ADDED
@@ -0,0 +1,432 @@
"""
Complex Reasoning Module for MiniMind Max2

Chain-of-Thought distillation from larger models (DeepSeek-R1, OpenAI o1).
"""

import json
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset


@dataclass
class ReasoningConfig:
    """Configuration for reasoning capabilities."""

    # Special tokens for reasoning
    think_start_token: str = "<think>"
    think_end_token: str = "</think>"
    step_token: str = "<step>"

    # Training settings
    max_reasoning_steps: int = 10
    reasoning_temperature: float = 0.7
    distillation_temperature: float = 2.0
    alpha_reasoning: float = 0.5  # Weight of the distillation loss vs. the standard LM loss

    # Reasoning patterns
    enable_self_reflection: bool = True
    enable_step_verification: bool = True
    min_reasoning_tokens: int = 50
    max_reasoning_tokens: int = 512


class ReasoningTokenizer:
    """Handles special reasoning tokens."""

    SPECIAL_TOKENS = {
        "think_start": "<think>",
        "think_end": "</think>",
        "step": "<step>",
        "verify": "<verify>",
        "reflect": "<reflect>",
        "conclude": "<conclude>",
    }

    @classmethod
    def wrap_reasoning(cls, reasoning_text: str) -> str:
        """Wrap reasoning in think tokens."""
        return f"{cls.SPECIAL_TOKENS['think_start']}{reasoning_text}{cls.SPECIAL_TOKENS['think_end']}"

    @classmethod
    def extract_reasoning(cls, text: str) -> Tuple[str, str]:
        """Extract reasoning and answer from model output."""
        pattern = rf"{re.escape(cls.SPECIAL_TOKENS['think_start'])}(.*?){re.escape(cls.SPECIAL_TOKENS['think_end'])}"
        match = re.search(pattern, text, re.DOTALL)

        if match:
            reasoning = match.group(1).strip()
            answer = text[match.end():].strip()
            return reasoning, answer
        return "", text

    @classmethod
    def format_cot_prompt(cls, question: str, reasoning_steps: List[str], answer: str) -> str:
        """Format a Chain-of-Thought training example."""
        steps_text = f"\n{cls.SPECIAL_TOKENS['step']}".join(reasoning_steps)
        reasoning = f"{cls.SPECIAL_TOKENS['step']}{steps_text}"
        return f"{question}\n{cls.wrap_reasoning(reasoning)}\n{answer}"


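# Illustrative round-trip through the helpers above; the question, steps, and
# answer are made-up values, and any strings behave the same way.
def _demo_reasoning_tokenizer() -> None:
    text = ReasoningTokenizer.format_cot_prompt(
        question="What is 17 * 3?",
        reasoning_steps=["17 * 3 = 17 * 2 + 17", "34 + 17 = 51"],
        answer="51",
    )
    # format_cot_prompt wraps the <step>-prefixed trace in <think>...</think>,
    # so extract_reasoning can split the trace from the final answer again.
    reasoning, answer = ReasoningTokenizer.extract_reasoning(text)
    assert "<step>" in reasoning and answer == "51"

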
class ReasoningModule(nn.Module):
    """
    Reasoning enhancement module for MiniMind Max2.
    Adds internal monologue capability for complex reasoning tasks.
    """

    def __init__(self, config: ReasoningConfig, hidden_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size

        # Reasoning state classifier
        self.reasoning_gate = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.GELU(),
            nn.Linear(hidden_size // 2, 3),  # [continue_reasoning, stop_reasoning, output_answer]
        )

        # Step quality predictor (for self-verification)
        self.step_verifier = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.GELU(),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid(),
        )

        # Reasoning depth adapter
        self.depth_adapter = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        reasoning_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Process hidden states with reasoning awareness.

        Args:
            hidden_states: [batch, seq_len, hidden_size]
            reasoning_mask: Binary mask indicating reasoning tokens

        Returns:
            Enhanced hidden states and reasoning metrics
        """
        # Compute reasoning gate decisions
        gate_logits = self.reasoning_gate(hidden_states)
        gate_probs = F.softmax(gate_logits, dim=-1)

        # Verify step quality
        step_quality = self.step_verifier(hidden_states).squeeze(-1)

        # Apply depth adaptation for reasoning tokens only
        if reasoning_mask is not None:
            adapted = self.depth_adapter(hidden_states)
            reasoning_mask_expanded = reasoning_mask.unsqueeze(-1).float()
            hidden_states = hidden_states + adapted * reasoning_mask_expanded

        metrics = {
            "gate_probs": gate_probs,
            "step_quality": step_quality,
            "reasoning_ratio": reasoning_mask.float().mean() if reasoning_mask is not None else torch.tensor(0.0),
        }

        return hidden_states, metrics

    def compute_reasoning_loss(
        self,
        hidden_states: torch.Tensor,
        reasoning_labels: torch.Tensor,
        step_boundaries: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute auxiliary loss for reasoning quality."""
        # Gate prediction loss
        gate_logits = self.reasoning_gate(hidden_states)
        gate_loss = F.cross_entropy(
            gate_logits.view(-1, 3),
            reasoning_labels.view(-1),
            ignore_index=-100,
        )

        # Step verification loss (if boundaries provided)
        if step_boundaries is not None:
            step_quality = self.step_verifier(hidden_states).squeeze(-1)
            verification_loss = F.binary_cross_entropy(
                step_quality,
                step_boundaries.float(),
            )
            gate_loss = gate_loss + 0.1 * verification_loss

        return gate_loss


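# Minimal shape check for ReasoningModule; the sizes are illustrative
# (hidden 64, batch 2, sequence 16) and the mask pretends positions 4..11
# sit inside a <think>...</think> span.
def _demo_reasoning_module() -> None:
    module = ReasoningModule(ReasoningConfig(), hidden_size=64)
    hidden = torch.randn(2, 16, 64)
    mask = torch.zeros(2, 16, dtype=torch.long)
    mask[:, 4:12] = 1
    out, metrics = module(hidden, reasoning_mask=mask)
    assert out.shape == hidden.shape
    assert metrics["gate_probs"].shape == (2, 16, 3)
    assert metrics["step_quality"].shape == (2, 16)

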
class ChainOfThoughtDataset(Dataset):
    """Dataset for Chain-of-Thought training."""

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int = 2048,
        config: Optional[ReasoningConfig] = None,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.config = config or ReasoningConfig()
        self.examples = []

        # Load data (one JSON object per line)
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.examples.append(json.loads(line))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]

        # Format: question, reasoning trace, answer (with fallback keys)
        question = example.get("question", example.get("prompt", ""))
        reasoning = example.get("reasoning", example.get("thinking", ""))
        answer = example.get("answer", example.get("response", ""))

        # Build full text with reasoning tokens
        full_text = ReasoningTokenizer.format_cot_prompt(
            question,
            reasoning.split("\n") if isinstance(reasoning, str) else reasoning,
            answer,
        )

        # Tokenize
        encodings = self.tokenizer(
            full_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)

        # Mask out padding positions so they don't contribute to the LM loss
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        # Create reasoning mask (tokens between <think> and </think>)
        reasoning_mask = self._create_reasoning_mask(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "reasoning_mask": reasoning_mask,
        }

    def _create_reasoning_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create a binary mask over the tokens between <think> and </think>."""
        # Assumes both markers were added to the tokenizer as special tokens,
        # so each maps to a single id; otherwise the mask falls back to zeros.
        start_id = self.tokenizer.convert_tokens_to_ids(self.config.think_start_token)
        end_id = self.tokenizer.convert_tokens_to_ids(self.config.think_end_token)
        if start_id is None or end_id is None or start_id == end_id:
            return torch.zeros_like(input_ids)
        inside = (input_ids == start_id).cumsum(0) - (input_ids == end_id).cumsum(0)
        return ((inside > 0) & (input_ids != start_id)).long()


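# Writes one illustrative record in the JSONL layout the dataset expects
# (the alternate keys "prompt"/"thinking"/"response" are also accepted);
# the file name and field values are placeholders.
def _demo_write_cot_jsonl(path: str = "cot_demo.jsonl") -> None:
    record = {
        "question": "What is 17 * 3?",
        "reasoning": "17 * 3 = 17 * 2 + 17\n34 + 17 = 51",
        "answer": "51",
    }
    with open(path, "w", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")

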
class ChainOfThoughtTrainer:
    """
    Trainer for Chain-of-Thought distillation.
    Distills reasoning capabilities from larger models.
    """

    def __init__(
        self,
        student_model: nn.Module,
        teacher_model: Optional[nn.Module] = None,
        config: Optional[ReasoningConfig] = None,
        learning_rate: float = 1e-5,
        device: str = "cuda",
    ):
        self.student = student_model
        self.teacher = teacher_model
        self.config = config or ReasoningConfig()
        self.device = device

        # Add reasoning module to student
        if hasattr(student_model, 'config'):
            hidden_size = student_model.config.hidden_size
        else:
            hidden_size = 1024  # Fallback when the model exposes no config

        self.reasoning_module = ReasoningModule(self.config, hidden_size).to(device)

        # Optimize the student and the reasoning module jointly
        params = list(student_model.parameters()) + list(self.reasoning_module.parameters())
        self.optimizer = torch.optim.AdamW(params, lr=learning_rate)

        # Freeze teacher if provided
        if self.teacher is not None:
            self.teacher.eval()
            for param in self.teacher.parameters():
                param.requires_grad = False

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        temperature: float = 2.0,
    ) -> torch.Tensor:
        """Compute the KL-divergence distillation loss on soft targets."""
        vocab_size = student_logits.size(-1)
        # Flatten to [tokens, vocab] so "batchmean" averages over tokens
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1).view(-1, vocab_size)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).view(-1, vocab_size)

        loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        # Scale by T^2 so gradient magnitudes stay comparable across temperatures
        return loss * (temperature ** 2)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single training step."""
        self.student.train()
        self.reasoning_module.train()

        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)
        reasoning_mask = batch.get("reasoning_mask", None)
        if reasoning_mask is not None:
            reasoning_mask = reasoning_mask.to(self.device)

        # Student forward (expects a MiniMind-style return of
        # (loss, logits, past_key_values, aux_loss))
        loss, student_logits, _, aux_loss = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        total_loss = loss
        metrics = {"ce_loss": loss.item(), "aux_loss": aux_loss.item()}

        # Distillation from teacher
        if self.teacher is not None:
            with torch.no_grad():
                _, teacher_logits, _, _ = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )

            distill_loss = self.distillation_loss(
                student_logits,
                teacher_logits,
                self.config.distillation_temperature,
            )
            # alpha_reasoning weights the distillation term against the CE loss
            total_loss = (1 - self.config.alpha_reasoning) * loss + self.config.alpha_reasoning * distill_loss
            metrics["distill_loss"] = distill_loss.item()

        # Backward
        self.optimizer.zero_grad()
        total_loss.backward()
        # Clip everything the optimizer updates, not just the student
        torch.nn.utils.clip_grad_norm_(
            list(self.student.parameters()) + list(self.reasoning_module.parameters()),
            1.0,
        )
        self.optimizer.step()

        metrics["total_loss"] = total_loss.item()
        return metrics

    def train(
        self,
        train_dataloader: DataLoader,
        num_epochs: int = 3,
        eval_dataloader: Optional[DataLoader] = None,
    ) -> Dict[str, List[float]]:
        """Full training loop."""
        history = {"train_loss": [], "eval_loss": []}

        for epoch in range(num_epochs):
            epoch_losses = []

            for batch in train_dataloader:
                metrics = self.train_step(batch)
                epoch_losses.append(metrics["total_loss"])

            avg_loss = sum(epoch_losses) / len(epoch_losses)
            history["train_loss"].append(avg_loss)
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}")

            # Evaluation
            if eval_dataloader is not None:
                eval_loss = self.evaluate(eval_dataloader)
                history["eval_loss"].append(eval_loss)
                print(f"  Eval Loss: {eval_loss:.4f}")

        return history

    def evaluate(self, dataloader: DataLoader) -> float:
        """Evaluate on validation set."""
        self.student.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                loss, _, _, _ = self.student(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                total_loss += loss.item()
                num_batches += 1

        return total_loss / num_batches if num_batches > 0 else 0.0


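# Sketch of wiring the trainer together. `student` and `teacher` are stand-ins
# for real checkpoints (e.g. a MiniMind Max2 student and an R1-style teacher);
# any nn.Module whose forward returns (loss, logits, past, aux_loss) fits the
# signature train_step expects.
def _demo_build_trainer(student: nn.Module, teacher: Optional[nn.Module] = None) -> ChainOfThoughtTrainer:
    config = ReasoningConfig(distillation_temperature=2.0, alpha_reasoning=0.5)
    return ChainOfThoughtTrainer(
        student_model=student,
        teacher_model=teacher,  # omit to train on the CE loss alone
        config=config,
        learning_rate=1e-5,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

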
def prepare_openr1_dataset(
    raw_data_path: str,
    output_path: str,
    config: Optional[ReasoningConfig] = None,
) -> int:
    """
    Prepare OpenR1 or DeepSeek-R1 distillation data.
    Converts raw reasoning traces to training format.
    """
    config = config or ReasoningConfig()
    processed = 0

    with open(raw_data_path, 'r', encoding='utf-8') as fin, \
         open(output_path, 'w', encoding='utf-8') as fout:

        for line in fin:
            if not line.strip():
                continue

            data = json.loads(line)

            # Extract components (format varies by source)
            question = data.get("question", data.get("prompt", data.get("input", "")))

            # Handle different reasoning formats
            if "thinking" in data:
                reasoning = data["thinking"]
            elif "reasoning" in data:
                reasoning = data["reasoning"]
            elif "chain_of_thought" in data:
                reasoning = data["chain_of_thought"]
            else:
                continue  # Skip if no reasoning trace

            answer = data.get("answer", data.get("response", data.get("output", "")))

            # Format for training
            processed_example = {
                "question": question,
                "reasoning": reasoning,
                "answer": answer,
            }

            fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n")
            processed += 1

    return processed


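if __name__ == "__main__":
    # Smoke-test the pure-Python pieces end to end (no model weights needed);
    # the file names are placeholders.
    _demo_reasoning_tokenizer()
    _demo_reasoning_module()
    _demo_write_cot_jsonl("cot_demo.jsonl")
    n = prepare_openr1_dataset("cot_demo.jsonl", "cot_prepared.jsonl")
    print(f"Prepared {n} example(s).")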