# MiniMind / capabilities / reasoning.py
# feat: Add capabilities/reasoning.py (commit 0108694, fariasultana)
"""
Complex Reasoning Module for MiniMind Max2
Chain-of-Thought distillation from larger models (DeepSeek-R1, OpenAI o1).
"""
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Any, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import json
import re
@dataclass
class ReasoningConfig:
    """Configuration for reasoning capabilities.

    Bundles the special-token strings, training/distillation
    hyperparameters, and reasoning-pattern switches consumed by the
    reasoning components in this module. Only ``distillation_temperature``
    and ``alpha_reasoning`` are read by code in this file; the other
    fields are presumably consumed elsewhere — confirm against callers.
    """
    # Special tokens for reasoning
    think_start_token: str = "<think>"   # opens the internal reasoning span
    think_end_token: str = "</think>"    # closes the internal reasoning span
    step_token: str = "<step>"           # separates individual reasoning steps
    # Training settings
    max_reasoning_steps: int = 10        # cap on reasoning steps (not read in this file)
    reasoning_temperature: float = 0.7   # sampling temp for reasoning (not read in this file)
    distillation_temperature: float = 2.0  # softmax temp for teacher KL distillation
    alpha_reasoning: float = 0.5  # Weight for reasoning loss vs answer loss
    # Reasoning patterns
    enable_self_reflection: bool = True    # feature switch (not read in this file)
    enable_step_verification: bool = True  # feature switch (not read in this file)
    min_reasoning_tokens: int = 50         # lower bound on reasoning length
    max_reasoning_tokens: int = 512        # upper bound on reasoning length
class ReasoningTokenizer:
    """Utilities for inserting and parsing the special reasoning markers."""

    SPECIAL_TOKENS = {
        "think_start": "<think>",
        "think_end": "</think>",
        "step": "<step>",
        "verify": "<verify>",
        "reflect": "<reflect>",
        "conclude": "<conclude>",
    }

    @classmethod
    def wrap_reasoning(cls, reasoning_text: str) -> str:
        """Surround *reasoning_text* with the think start/end markers."""
        opener = cls.SPECIAL_TOKENS["think_start"]
        closer = cls.SPECIAL_TOKENS["think_end"]
        return opener + reasoning_text + closer

    @classmethod
    def extract_reasoning(cls, text: str) -> Tuple[str, str]:
        """Split model output into (reasoning, answer).

        Returns ("", text) unchanged when no think-span is present.
        """
        opener = re.escape(cls.SPECIAL_TOKENS["think_start"])
        closer = re.escape(cls.SPECIAL_TOKENS["think_end"])
        found = re.search(opener + r"(.*?)" + closer, text, re.DOTALL)
        if found is None:
            return "", text
        # Everything after the closing marker is treated as the answer.
        return found.group(1).strip(), text[found.end():].strip()

    @classmethod
    def format_cot_prompt(cls, question: str, reasoning_steps: List[str], answer: str) -> str:
        """Format a Chain-of-Thought training example.

        Each step is prefixed with the step token; the joined trace is
        wrapped in think markers and placed between question and answer.
        """
        step = cls.SPECIAL_TOKENS["step"]
        trace = step + ("\n" + step).join(reasoning_steps)
        return f"{question}\n{cls.wrap_reasoning(trace)}\n{answer}"
class ReasoningModule(nn.Module):
    """
    Reasoning enhancement module for MiniMind Max2.

    Wraps three small heads around the backbone's hidden states: a 3-way
    gate over reasoning states, a per-token step-quality scorer in (0, 1),
    and a residual adapter that is applied only on reasoning tokens.
    """

    def __init__(self, config: ReasoningConfig, hidden_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        half = hidden_size // 2
        quarter = hidden_size // 4
        # 3-way classifier: [continue_reasoning, stop_reasoning, output_answer]
        self.reasoning_gate = nn.Sequential(
            nn.Linear(hidden_size, half),
            nn.GELU(),
            nn.Linear(half, 3),
        )
        # Per-token step-quality predictor for self-verification.
        self.step_verifier = nn.Sequential(
            nn.Linear(hidden_size, quarter),
            nn.GELU(),
            nn.Linear(quarter, 1),
            nn.Sigmoid(),
        )
        # Residual projection applied to reasoning tokens only.
        self.depth_adapter = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        reasoning_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Process hidden states with reasoning awareness.

        Args:
            hidden_states: [batch, seq_len, hidden_size] activations.
            reasoning_mask: optional binary mask flagging reasoning tokens.

        Returns:
            Tuple of (possibly adapted hidden states, metrics dict with
            "gate_probs", "step_quality", and "reasoning_ratio").
        """
        gate_probs = F.softmax(self.reasoning_gate(hidden_states), dim=-1)
        step_quality = self.step_verifier(hidden_states).squeeze(-1)
        if reasoning_mask is None:
            ratio = torch.tensor(0.0)
        else:
            # Residual depth adaptation restricted to masked positions.
            mask_f = reasoning_mask.unsqueeze(-1).float()
            hidden_states = hidden_states + self.depth_adapter(hidden_states) * mask_f
            ratio = reasoning_mask.float().mean()
        return hidden_states, {
            "gate_probs": gate_probs,
            "step_quality": step_quality,
            "reasoning_ratio": ratio,
        }

    def compute_reasoning_loss(
        self,
        hidden_states: torch.Tensor,
        reasoning_labels: torch.Tensor,
        step_boundaries: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Auxiliary loss: gate cross-entropy plus optional step BCE.

        Args:
            hidden_states: [batch, seq_len, hidden_size] activations.
            reasoning_labels: per-token class ids in {0,1,2}; -100 ignored.
            step_boundaries: optional binary targets for the step verifier.

        Returns:
            Scalar loss tensor.
        """
        logits = self.reasoning_gate(hidden_states)
        loss = F.cross_entropy(
            logits.view(-1, 3),
            reasoning_labels.view(-1),
            ignore_index=-100,
        )
        if step_boundaries is None:
            return loss
        # step_verifier ends in Sigmoid, so plain BCE (not with-logits) fits.
        quality = self.step_verifier(hidden_states).squeeze(-1)
        verification = F.binary_cross_entropy(quality, step_boundaries.float())
        return loss + 0.1 * verification
class ChainOfThoughtDataset(Dataset):
    """Dataset for Chain-of-Thought training.

    Reads a JSONL file of {question, reasoning, answer} records (with
    {prompt, thinking, response} accepted as aliases) and emits tokenized
    examples wrapped in the special reasoning tokens.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int = 2048,
        config: Optional[ReasoningConfig] = None,
    ):
        """
        Args:
            data_path: path to a JSONL file, one example per line.
            tokenizer: HF-style tokenizer — assumed callable returning a
                dict with "input_ids"/"attention_mask"; confirm interface.
            max_length: tokenization truncation/padding length.
            config: reasoning configuration; defaults to ReasoningConfig().
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.config = config or ReasoningConfig()
        self.examples = []
        # Load data eagerly; blank lines are skipped.
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.examples.append(json.loads(line))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]
        # Accept both naming conventions for each field.
        question = example.get("question", example.get("prompt", ""))
        reasoning = example.get("reasoning", example.get("thinking", ""))
        answer = example.get("answer", example.get("response", ""))
        # Build full text with reasoning tokens; a string trace is split
        # into one step per line, a list is used as-is.
        full_text = ReasoningTokenizer.format_cot_prompt(
            question,
            reasoning.split("\n") if isinstance(reasoning, str) else reasoning,
            answer,
        )
        encodings = self.tokenizer(
            full_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)
        # FIX: mask padded positions out of the LM loss. Previously labels
        # were a plain clone of input_ids, so with padding="max_length" the
        # model was trained to predict pad tokens; cross-entropy skips -100.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        # Create reasoning mask (tokens between <think> and </think>).
        reasoning_mask = self._create_reasoning_mask(input_ids)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "reasoning_mask": reasoning_mask,
        }

    def _create_reasoning_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Binary mask marking tokens inside the think span.

        Placeholder: returns all zeros. A real implementation would locate
        the think_start/think_end token ids in ``input_ids`` and set the
        positions between them to 1.
        """
        return torch.zeros_like(input_ids)
class ChainOfThoughtTrainer:
    """
    Trainer for Chain-of-Thought distillation.

    Distills reasoning capabilities from a larger, frozen teacher model
    into a student, while jointly training an auxiliary ReasoningModule.
    Both models are expected to return a (loss, logits, hidden_states,
    aux_loss) 4-tuple from forward() — confirm against the model class.
    """

    def __init__(
        self,
        student_model: nn.Module,
        teacher_model: Optional[nn.Module] = None,
        config: Optional[ReasoningConfig] = None,
        learning_rate: float = 1e-5,
        device: str = "cuda",
    ):
        """
        Args:
            student_model: model being trained.
            teacher_model: optional frozen teacher for KL distillation.
            config: reasoning settings; defaults to ReasoningConfig().
            learning_rate: AdamW learning rate for all trainable params.
            device: device for input tensors and the reasoning module.
        """
        self.student = student_model
        self.teacher = teacher_model
        self.config = config or ReasoningConfig()
        self.device = device
        # Attach a reasoning module sized to the student's hidden dim.
        if hasattr(student_model, 'config'):
            hidden_size = student_model.config.hidden_size
        else:
            hidden_size = 1024  # Fallback when the student exposes no config.
        self.reasoning_module = ReasoningModule(self.config, hidden_size).to(device)
        # One optimizer over student + reasoning-module parameters. Keep
        # the list so gradient clipping covers exactly what is optimized.
        self._trainable_params = (
            list(student_model.parameters()) + list(self.reasoning_module.parameters())
        )
        self.optimizer = torch.optim.AdamW(self._trainable_params, lr=learning_rate)
        # Freeze teacher if provided: inference only, no gradients.
        if self.teacher is not None:
            self.teacher.eval()
            for param in self.teacher.parameters():
                param.requires_grad = False

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        temperature: float = 2.0,
    ) -> torch.Tensor:
        """KL(teacher || student) distillation loss at a temperature.

        kl_div expects log-probabilities as input and probabilities as
        target; the T^2 factor keeps gradient magnitudes comparable
        across temperatures (Hinton et al., 2015).
        """
        # Renamed from "student_probs": log_softmax yields log-probs.
        student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
        loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        return loss * (temperature ** 2)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Run one optimization step and return scalar loss metrics."""
        self.student.train()
        self.reasoning_module.train()
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)
        reasoning_mask = batch.get("reasoning_mask", None)
        if reasoning_mask is not None:
            reasoning_mask = reasoning_mask.to(self.device)
        # Student forward: (loss, logits, hidden_states, aux_loss).
        loss, student_logits, _, aux_loss = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        total_loss = loss
        metrics = {"ce_loss": loss.item(), "aux_loss": aux_loss.item()}
        # Blend in distillation loss when a teacher is available.
        if self.teacher is not None:
            with torch.no_grad():
                _, teacher_logits, _, _ = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )
            distill_loss = self.distillation_loss(
                student_logits,
                teacher_logits,
                self.config.distillation_temperature,
            )
            total_loss = (
                (1 - self.config.alpha_reasoning) * loss
                + self.config.alpha_reasoning * distill_loss
            )
            metrics["distill_loss"] = distill_loss.item()
        # Backward + step.
        self.optimizer.zero_grad()
        total_loss.backward()
        # FIX: clip every optimized parameter (student AND reasoning
        # module); previously only student params were clipped, leaving
        # the reasoning module's gradients unbounded.
        torch.nn.utils.clip_grad_norm_(self._trainable_params, 1.0)
        self.optimizer.step()
        metrics["total_loss"] = total_loss.item()
        return metrics

    def train(
        self,
        train_dataloader: DataLoader,
        num_epochs: int = 3,
        eval_dataloader: Optional[DataLoader] = None,
    ) -> Dict[str, List[float]]:
        """Full training loop; returns per-epoch loss history."""
        history = {"train_loss": [], "eval_loss": []}
        for epoch in range(num_epochs):
            epoch_losses = []
            for batch in train_dataloader:
                metrics = self.train_step(batch)
                epoch_losses.append(metrics["total_loss"])
            # Guard against an empty dataloader (was a ZeroDivisionError).
            avg_loss = sum(epoch_losses) / max(len(epoch_losses), 1)
            history["train_loss"].append(avg_loss)
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}")
            # Evaluation
            if eval_dataloader is not None:
                eval_loss = self.evaluate(eval_dataloader)
                history["eval_loss"].append(eval_loss)
                print(f"  Eval Loss: {eval_loss:.4f}")
        return history

    def evaluate(self, dataloader: DataLoader) -> float:
        """Return mean student loss over the validation set (0.0 if empty)."""
        self.student.eval()
        self.reasoning_module.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)
                loss, _, _, _ = self.student(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                total_loss += loss.item()
                num_batches += 1
        return total_loss / num_batches if num_batches > 0 else 0.0
def prepare_openr1_dataset(
    raw_data_path: str,
    output_path: str,
    config: Optional[ReasoningConfig] = None,
) -> int:
    """
    Prepare OpenR1 or DeepSeek-R1 distillation data.

    Reads raw JSONL reasoning traces, normalizes the varying field names
    into {question, reasoning, answer} records, and writes them back out
    as JSONL. Records lacking any reasoning trace are skipped.

    Returns:
        Number of records written.
    """
    config = config or ReasoningConfig()
    count = 0
    with open(raw_data_path, 'r', encoding='utf-8') as src, \
            open(output_path, 'w', encoding='utf-8') as dst:
        for raw_line in src:
            if not raw_line.strip():
                continue
            record = json.loads(raw_line)
            # Sources disagree on field names; first present alias wins.
            question = ""
            for key in ("question", "prompt", "input"):
                if key in record:
                    question = record[key]
                    break
            reasoning, has_reasoning = None, False
            for key in ("thinking", "reasoning", "chain_of_thought"):
                if key in record:
                    reasoning, has_reasoning = record[key], True
                    break
            if not has_reasoning:
                continue  # nothing to distill from this record
            answer = ""
            for key in ("answer", "response", "output"):
                if key in record:
                    answer = record[key]
                    break
            normalized = {
                "question": question,
                "reasoning": reasoning,
                "answer": answer,
            }
            dst.write(json.dumps(normalized, ensure_ascii=False) + "\n")
            count += 1
    return count