""" Complex Reasoning Module for MiniMind Max2 Chain-of-Thought distillation from larger models (DeepSeek-R1, OpenAI o1). """ from dataclasses import dataclass, field from typing import List, Optional, Dict, Any, Tuple import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader import json import re @dataclass class ReasoningConfig: """Configuration for reasoning capabilities.""" # Special tokens for reasoning think_start_token: str = "" think_end_token: str = "" step_token: str = "" # Training settings max_reasoning_steps: int = 10 reasoning_temperature: float = 0.7 distillation_temperature: float = 2.0 alpha_reasoning: float = 0.5 # Weight for reasoning loss vs answer loss # Reasoning patterns enable_self_reflection: bool = True enable_step_verification: bool = True min_reasoning_tokens: int = 50 max_reasoning_tokens: int = 512 class ReasoningTokenizer: """Handles special reasoning tokens.""" SPECIAL_TOKENS = { "think_start": "", "think_end": "", "step": "", "verify": "", "reflect": "", "conclude": "", } @classmethod def wrap_reasoning(cls, reasoning_text: str) -> str: """Wrap reasoning in think tokens.""" return f"{cls.SPECIAL_TOKENS['think_start']}{reasoning_text}{cls.SPECIAL_TOKENS['think_end']}" @classmethod def extract_reasoning(cls, text: str) -> Tuple[str, str]: """Extract reasoning and answer from model output.""" pattern = rf"{re.escape(cls.SPECIAL_TOKENS['think_start'])}(.*?){re.escape(cls.SPECIAL_TOKENS['think_end'])}" match = re.search(pattern, text, re.DOTALL) if match: reasoning = match.group(1).strip() answer = text[match.end():].strip() return reasoning, answer return "", text @classmethod def format_cot_prompt(cls, question: str, reasoning_steps: List[str], answer: str) -> str: """Format a Chain-of-Thought training example.""" steps_text = f"\n{cls.SPECIAL_TOKENS['step']}".join(reasoning_steps) reasoning = f"{cls.SPECIAL_TOKENS['step']}{steps_text}" return f"{question}\n{cls.wrap_reasoning(reasoning)}\n{answer}" class ReasoningModule(nn.Module): """ Reasoning enhancement module for MiniMind Max2. Adds internal monologue capability for complex reasoning tasks. """ def __init__(self, config: ReasoningConfig, hidden_size: int): super().__init__() self.config = config self.hidden_size = hidden_size # Reasoning state classifier self.reasoning_gate = nn.Sequential( nn.Linear(hidden_size, hidden_size // 2), nn.GELU(), nn.Linear(hidden_size // 2, 3), # [continue_reasoning, stop_reasoning, output_answer] ) # Step quality predictor (for self-verification) self.step_verifier = nn.Sequential( nn.Linear(hidden_size, hidden_size // 4), nn.GELU(), nn.Linear(hidden_size // 4, 1), nn.Sigmoid(), ) # Reasoning depth adapter self.depth_adapter = nn.Linear(hidden_size, hidden_size) def forward( self, hidden_states: torch.Tensor, reasoning_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Process hidden states with reasoning awareness. Args: hidden_states: [batch, seq_len, hidden_size] reasoning_mask: Binary mask indicating reasoning tokens Returns: Enhanced hidden states and reasoning metrics """ batch_size, seq_len, _ = hidden_states.shape # Compute reasoning gate decisions gate_logits = self.reasoning_gate(hidden_states) gate_probs = F.softmax(gate_logits, dim=-1) # Verify step quality step_quality = self.step_verifier(hidden_states).squeeze(-1) # Apply depth adaptation for reasoning tokens if reasoning_mask is not None: adapted = self.depth_adapter(hidden_states) reasoning_mask_expanded = reasoning_mask.unsqueeze(-1).float() hidden_states = hidden_states + adapted * reasoning_mask_expanded metrics = { "gate_probs": gate_probs, "step_quality": step_quality, "reasoning_ratio": reasoning_mask.float().mean() if reasoning_mask is not None else torch.tensor(0.0), } return hidden_states, metrics def compute_reasoning_loss( self, hidden_states: torch.Tensor, reasoning_labels: torch.Tensor, step_boundaries: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Compute auxiliary loss for reasoning quality.""" # Gate prediction loss gate_logits = self.reasoning_gate(hidden_states) gate_loss = F.cross_entropy( gate_logits.view(-1, 3), reasoning_labels.view(-1), ignore_index=-100, ) # Step verification loss (if boundaries provided) if step_boundaries is not None: step_quality = self.step_verifier(hidden_states).squeeze(-1) verification_loss = F.binary_cross_entropy( step_quality, step_boundaries.float(), ) gate_loss = gate_loss + 0.1 * verification_loss return gate_loss class ChainOfThoughtDataset(Dataset): """Dataset for Chain-of-Thought training.""" def __init__( self, data_path: str, tokenizer, max_length: int = 2048, config: Optional[ReasoningConfig] = None, ): self.tokenizer = tokenizer self.max_length = max_length self.config = config or ReasoningConfig() self.examples = [] # Load data with open(data_path, 'r', encoding='utf-8') as f: for line in f: if line.strip(): example = json.loads(line) self.examples.append(example) def __len__(self) -> int: return len(self.examples) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: example = self.examples[idx] # Format: question, reasoning trace, answer question = example.get("question", example.get("prompt", "")) reasoning = example.get("reasoning", example.get("thinking", "")) answer = example.get("answer", example.get("response", "")) # Build full text with reasoning tokens full_text = ReasoningTokenizer.format_cot_prompt( question, reasoning.split("\n") if isinstance(reasoning, str) else reasoning, answer, ) # Tokenize encodings = self.tokenizer( full_text, max_length=self.max_length, truncation=True, padding="max_length", return_tensors="pt", ) input_ids = encodings["input_ids"].squeeze(0) attention_mask = encodings["attention_mask"].squeeze(0) # Create reasoning mask (tokens between and ) reasoning_mask = self._create_reasoning_mask(input_ids) return { "input_ids": input_ids, "attention_mask": attention_mask, "labels": input_ids.clone(), "reasoning_mask": reasoning_mask, } def _create_reasoning_mask(self, input_ids: torch.Tensor) -> torch.Tensor: """Create binary mask for reasoning tokens.""" # This is a simplified version - actual implementation would use token IDs mask = torch.zeros_like(input_ids) # In practice, find think_start and think_end token positions return mask class ChainOfThoughtTrainer: """ Trainer for Chain-of-Thought distillation. Distills reasoning capabilities from larger models. """ def __init__( self, student_model: nn.Module, teacher_model: Optional[nn.Module] = None, config: Optional[ReasoningConfig] = None, learning_rate: float = 1e-5, device: str = "cuda", ): self.student = student_model self.teacher = teacher_model self.config = config or ReasoningConfig() self.device = device # Add reasoning module to student if hasattr(student_model, 'config'): hidden_size = student_model.config.hidden_size else: hidden_size = 1024 # Default self.reasoning_module = ReasoningModule(self.config, hidden_size).to(device) # Optimizer params = list(student_model.parameters()) + list(self.reasoning_module.parameters()) self.optimizer = torch.optim.AdamW(params, lr=learning_rate) # Freeze teacher if provided if self.teacher is not None: self.teacher.eval() for param in self.teacher.parameters(): param.requires_grad = False def distillation_loss( self, student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 2.0, ) -> torch.Tensor: """Compute KL divergence distillation loss.""" student_probs = F.log_softmax(student_logits / temperature, dim=-1) teacher_probs = F.softmax(teacher_logits / temperature, dim=-1) loss = F.kl_div(student_probs, teacher_probs, reduction="batchmean") return loss * (temperature ** 2) def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]: """Single training step.""" self.student.train() self.reasoning_module.train() input_ids = batch["input_ids"].to(self.device) attention_mask = batch["attention_mask"].to(self.device) labels = batch["labels"].to(self.device) reasoning_mask = batch.get("reasoning_mask", None) if reasoning_mask is not None: reasoning_mask = reasoning_mask.to(self.device) # Student forward loss, student_logits, _, aux_loss = self.student( input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) total_loss = loss metrics = {"ce_loss": loss.item(), "aux_loss": aux_loss.item()} # Distillation from teacher if self.teacher is not None: with torch.no_grad(): _, teacher_logits, _, _ = self.teacher( input_ids=input_ids, attention_mask=attention_mask, ) distill_loss = self.distillation_loss( student_logits, teacher_logits, self.config.distillation_temperature, ) total_loss = (1 - self.config.alpha_reasoning) * loss + self.config.alpha_reasoning * distill_loss metrics["distill_loss"] = distill_loss.item() # Backward self.optimizer.zero_grad() total_loss.backward() torch.nn.utils.clip_grad_norm_(self.student.parameters(), 1.0) self.optimizer.step() metrics["total_loss"] = total_loss.item() return metrics def train( self, train_dataloader: DataLoader, num_epochs: int = 3, eval_dataloader: Optional[DataLoader] = None, ) -> Dict[str, List[float]]: """Full training loop.""" history = {"train_loss": [], "eval_loss": []} for epoch in range(num_epochs): epoch_losses = [] for batch in train_dataloader: metrics = self.train_step(batch) epoch_losses.append(metrics["total_loss"]) avg_loss = sum(epoch_losses) / len(epoch_losses) history["train_loss"].append(avg_loss) print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}") # Evaluation if eval_dataloader is not None: eval_loss = self.evaluate(eval_dataloader) history["eval_loss"].append(eval_loss) print(f" Eval Loss: {eval_loss:.4f}") return history def evaluate(self, dataloader: DataLoader) -> float: """Evaluate on validation set.""" self.student.eval() total_loss = 0.0 num_batches = 0 with torch.no_grad(): for batch in dataloader: input_ids = batch["input_ids"].to(self.device) attention_mask = batch["attention_mask"].to(self.device) labels = batch["labels"].to(self.device) loss, _, _, _ = self.student( input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) total_loss += loss.item() num_batches += 1 return total_loss / num_batches if num_batches > 0 else 0.0 def prepare_openr1_dataset( raw_data_path: str, output_path: str, config: Optional[ReasoningConfig] = None, ) -> int: """ Prepare OpenR1 or DeepSeek-R1 distillation data. Converts raw reasoning traces to training format. """ config = config or ReasoningConfig() processed = 0 with open(raw_data_path, 'r', encoding='utf-8') as fin, \ open(output_path, 'w', encoding='utf-8') as fout: for line in fin: if not line.strip(): continue data = json.loads(line) # Extract components (format varies by source) question = data.get("question", data.get("prompt", data.get("input", ""))) # Handle different reasoning formats if "thinking" in data: reasoning = data["thinking"] elif "reasoning" in data: reasoning = data["reasoning"] elif "chain_of_thought" in data: reasoning = data["chain_of_thought"] else: continue # Skip if no reasoning trace answer = data.get("answer", data.get("response", data.get("output", ""))) # Format for training processed_example = { "question": question, "reasoning": reasoning, "answer": answer, } fout.write(json.dumps(processed_example, ensure_ascii=False) + "\n") processed += 1 return processed