|
|
""" |
|
|
Complex Reasoning Module for MiniMind Max2 |
|
|
Chain-of-Thought distillation from larger models (DeepSeek-R1, OpenAI o1). |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import List, Optional, Dict, Any, Tuple |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import json |
|
|
import re |
|
|
|
|
|
|
|
|
@dataclass
class ReasoningConfig:
    """Configuration for reasoning capabilities.

    Groups the special-token strings, distillation hyper-parameters, and
    reasoning-behavior switches used by the other classes in this module.
    """

    # Marker strings wrapped around a reasoning trace.
    # NOTE(review): ReasoningTokenizer keeps its own copies of these in
    # SPECIAL_TOKENS; keep the two definitions in sync.
    think_start_token: str = "<think>"
    think_end_token: str = "</think>"
    step_token: str = "<step>"

    # NOTE(review): not read anywhere in this module -- presumably a cap on
    # generated reasoning steps; confirm against the generation code.
    max_reasoning_steps: int = 10
    # NOTE(review): unused here; presumably the sampling temperature used
    # when generating reasoning -- verify against the caller.
    reasoning_temperature: float = 0.7
    # Softmax temperature for the teacher/student KL loss
    # (see ChainOfThoughtTrainer.distillation_loss).
    distillation_temperature: float = 2.0
    # Mixing weight in train_step:
    # total = (1 - alpha_reasoning) * CE + alpha_reasoning * distill loss.
    alpha_reasoning: float = 0.5

    # Feature switches and reasoning-token budgets.
    # NOTE(review): none of these four are read in this module; verify
    # their consumers before changing defaults.
    enable_self_reflection: bool = True
    enable_step_verification: bool = True
    min_reasoning_tokens: int = 50
    max_reasoning_tokens: int = 512
|
|
|
|
|
|
|
|
class ReasoningTokenizer:
    """Utilities for marking up and parsing reasoning segments.

    All methods are classmethods that operate on the SPECIAL_TOKENS table;
    the class carries no instance state.
    """

    SPECIAL_TOKENS = {
        "think_start": "<think>",
        "think_end": "</think>",
        "step": "<step>",
        "verify": "<verify>",
        "reflect": "<reflect>",
        "conclude": "<conclude>",
    }

    @classmethod
    def wrap_reasoning(cls, reasoning_text: str) -> str:
        """Surround *reasoning_text* with the think start/end markers."""
        opener = cls.SPECIAL_TOKENS["think_start"]
        closer = cls.SPECIAL_TOKENS["think_end"]
        return "".join((opener, reasoning_text, closer))

    @classmethod
    def extract_reasoning(cls, text: str) -> Tuple[str, str]:
        """Split model output into ``(reasoning, answer)``.

        Returns ``("", text)`` when no complete think span is present.
        """
        opener = re.escape(cls.SPECIAL_TOKENS["think_start"])
        closer = re.escape(cls.SPECIAL_TOKENS["think_end"])
        found = re.search(rf"{opener}(.*?){closer}", text, re.DOTALL)

        if found is None:
            return "", text

        # Reasoning is the span between the markers; the answer is
        # everything after the closing marker.
        return found.group(1).strip(), text[found.end():].strip()

    @classmethod
    def format_cot_prompt(cls, question: str, reasoning_steps: List[str], answer: str) -> str:
        """Format a Chain-of-Thought training example.

        Each step is prefixed with the step marker, the whole trace is
        wrapped in think markers, and question/trace/answer are joined
        with newlines.
        """
        marker = cls.SPECIAL_TOKENS["step"]
        trace = marker + f"\n{marker}".join(reasoning_steps)
        return "\n".join((question, cls.wrap_reasoning(trace), answer))
|
|
|
|
|
|
|
|
class ReasoningModule(nn.Module): |
|
|
""" |
|
|
Reasoning enhancement module for MiniMind Max2. |
|
|
Adds internal monologue capability for complex reasoning tasks. |
|
|
""" |
|
|
|
|
|
def __init__(self, config: ReasoningConfig, hidden_size: int): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = hidden_size |
|
|
|
|
|
|
|
|
self.reasoning_gate = nn.Sequential( |
|
|
nn.Linear(hidden_size, hidden_size // 2), |
|
|
nn.GELU(), |
|
|
nn.Linear(hidden_size // 2, 3), |
|
|
) |
|
|
|
|
|
|
|
|
self.step_verifier = nn.Sequential( |
|
|
nn.Linear(hidden_size, hidden_size // 4), |
|
|
nn.GELU(), |
|
|
nn.Linear(hidden_size // 4, 1), |
|
|
nn.Sigmoid(), |
|
|
) |
|
|
|
|
|
|
|
|
self.depth_adapter = nn.Linear(hidden_size, hidden_size) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
reasoning_mask: Optional[torch.Tensor] = None, |
|
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: |
|
|
""" |
|
|
Process hidden states with reasoning awareness. |
|
|
|
|
|
Args: |
|
|
hidden_states: [batch, seq_len, hidden_size] |
|
|
reasoning_mask: Binary mask indicating reasoning tokens |
|
|
|
|
|
Returns: |
|
|
Enhanced hidden states and reasoning metrics |
|
|
""" |
|
|
batch_size, seq_len, _ = hidden_states.shape |
|
|
|
|
|
|
|
|
gate_logits = self.reasoning_gate(hidden_states) |
|
|
gate_probs = F.softmax(gate_logits, dim=-1) |
|
|
|
|
|
|
|
|
step_quality = self.step_verifier(hidden_states).squeeze(-1) |
|
|
|
|
|
|
|
|
if reasoning_mask is not None: |
|
|
adapted = self.depth_adapter(hidden_states) |
|
|
reasoning_mask_expanded = reasoning_mask.unsqueeze(-1).float() |
|
|
hidden_states = hidden_states + adapted * reasoning_mask_expanded |
|
|
|
|
|
metrics = { |
|
|
"gate_probs": gate_probs, |
|
|
"step_quality": step_quality, |
|
|
"reasoning_ratio": reasoning_mask.float().mean() if reasoning_mask is not None else torch.tensor(0.0), |
|
|
} |
|
|
|
|
|
return hidden_states, metrics |
|
|
|
|
|
def compute_reasoning_loss( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
reasoning_labels: torch.Tensor, |
|
|
step_boundaries: Optional[torch.Tensor] = None, |
|
|
) -> torch.Tensor: |
|
|
"""Compute auxiliary loss for reasoning quality.""" |
|
|
|
|
|
gate_logits = self.reasoning_gate(hidden_states) |
|
|
gate_loss = F.cross_entropy( |
|
|
gate_logits.view(-1, 3), |
|
|
reasoning_labels.view(-1), |
|
|
ignore_index=-100, |
|
|
) |
|
|
|
|
|
|
|
|
if step_boundaries is not None: |
|
|
step_quality = self.step_verifier(hidden_states).squeeze(-1) |
|
|
verification_loss = F.binary_cross_entropy( |
|
|
step_quality, |
|
|
step_boundaries.float(), |
|
|
) |
|
|
gate_loss = gate_loss + 0.1 * verification_loss |
|
|
|
|
|
return gate_loss |
|
|
|
|
|
|
|
|
class ChainOfThoughtDataset(Dataset):
    """Dataset for Chain-of-Thought training.

    Reads a JSONL file where every record carries a question, a reasoning
    trace, and an answer (several alternative key names are accepted),
    formats each record with ReasoningTokenizer, and tokenizes it to a
    fixed length.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        max_length: int = 2048,
        config: Optional["ReasoningConfig"] = None,
    ):
        """
        Args:
            data_path: Path to a JSONL file, one example per line.
            tokenizer: HuggingFace-style callable tokenizer; assumed to
                return "input_ids"/"attention_mask" tensors -- TODO confirm.
            max_length: Fixed tokenized length (pad/truncate target).
            config: Reasoning hyper-parameters; defaults to ReasoningConfig().
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.config = config or ReasoningConfig()
        self.examples = []

        # Load the whole JSONL file up front; blank lines are skipped.
        with open(data_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    self.examples.append(json.loads(line))

    def __len__(self) -> int:
        return len(self.examples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        example = self.examples[idx]

        # Accept both naming conventions seen in distillation dumps.
        question = example.get("question", example.get("prompt", ""))
        reasoning = example.get("reasoning", example.get("thinking", ""))
        answer = example.get("answer", example.get("response", ""))

        # A string trace is split into one step per line.
        full_text = ReasoningTokenizer.format_cot_prompt(
            question,
            reasoning.split("\n") if isinstance(reasoning, str) else reasoning,
            answer,
        )

        encodings = self.tokenizer(
            full_text,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encodings["input_ids"].squeeze(0)
        attention_mask = encodings["attention_mask"].squeeze(0)

        # FIX: do not train on padding -- previously labels were a plain
        # copy of input_ids, so the LM loss included padded positions.
        # -100 is the conventional cross-entropy ignore_index.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        reasoning_mask = self._create_reasoning_mask(input_ids)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "reasoning_mask": reasoning_mask,
        }

    def _create_reasoning_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Mark tokens inside <think>...</think> with 1, all others 0.

        FIX: this was a stub that always returned zeros.  It now marks the
        span between the think markers when the tokenizer maps each marker
        to a single token id; otherwise it falls back to the previous
        all-zero behavior rather than guessing.
        """
        mask = torch.zeros_like(input_ids)

        convert = getattr(self.tokenizer, "convert_tokens_to_ids", None)
        if convert is None:
            return mask

        start_id = convert(self.config.think_start_token)
        end_id = convert(self.config.think_end_token)
        unk_id = getattr(self.tokenizer, "unk_token_id", None)
        # Markers not registered as single tokens: keep the zero mask.
        if start_id is None or end_id is None or start_id == unk_id or end_id == unk_id:
            return mask

        inside = False
        for pos, token_id in enumerate(input_ids.tolist()):
            if token_id == start_id:
                inside = True
                mask[pos] = 1
            elif token_id == end_id:
                mask[pos] = 1
                inside = False
            elif inside:
                mask[pos] = 1
        return mask
|
|
|
|
|
|
|
|
class ChainOfThoughtTrainer:
    """
    Trainer for Chain-of-Thought distillation.

    Distills reasoning capabilities from larger models.  Both student and
    teacher are assumed to return a 4-tuple
    ``(loss, logits, hidden, aux_loss)`` when called with
    ``input_ids`` / ``attention_mask`` (/ ``labels``) -- TODO confirm
    against the MiniMind model signature.
    """

    def __init__(
        self,
        student_model: nn.Module,
        teacher_model: Optional[nn.Module] = None,
        config: Optional["ReasoningConfig"] = None,
        learning_rate: float = 1e-5,
        device: str = "cuda",
    ):
        """
        Args:
            student_model: Model being trained.
            teacher_model: Optional frozen teacher providing soft targets.
            config: Reasoning hyper-parameters; defaults to ReasoningConfig().
            learning_rate: AdamW learning rate for all optimized params.
            device: Device for the reasoning module and batches.
        """
        self.student = student_model
        self.teacher = teacher_model
        self.config = config or ReasoningConfig()
        self.device = device

        # Fall back to a fixed width when the student exposes no config.
        if hasattr(student_model, 'config'):
            hidden_size = student_model.config.hidden_size
        else:
            hidden_size = 1024

        # NOTE(review): the reasoning module is optimized below but is not
        # wired into train_step's loss yet -- its gradients stay zero until
        # forward/compute_reasoning_loss are integrated into the step.
        self.reasoning_module = ReasoningModule(self.config, hidden_size).to(device)

        # One optimizer over everything that should learn.
        params = list(student_model.parameters()) + list(self.reasoning_module.parameters())
        self.optimizer = torch.optim.AdamW(params, lr=learning_rate)

        # Freeze the teacher entirely; it only provides soft targets.
        if self.teacher is not None:
            self.teacher.eval()
            for param in self.teacher.parameters():
                param.requires_grad = False

    def distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        temperature: float = 2.0,
    ) -> torch.Tensor:
        """Compute the KL-divergence distillation loss.

        Logits may be ``[batch, vocab]`` or ``[batch, seq, vocab]``.

        FIX: logits are flattened to ``[tokens, vocab]`` before
        ``reduction="batchmean"``; previously 3-D logits were divided by
        the batch size only, silently scaling the loss by seq_len.
        """
        vocab_size = student_logits.size(-1)
        student_log_probs = F.log_softmax(
            student_logits.reshape(-1, vocab_size) / temperature, dim=-1
        )
        teacher_probs = F.softmax(
            teacher_logits.reshape(-1, vocab_size) / temperature, dim=-1
        )

        loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        # Standard Hinton scaling restores gradient magnitude at T > 1.
        return loss * (temperature ** 2)

    def train_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Run a single optimization step; returns scalar metrics."""
        self.student.train()
        self.reasoning_module.train()

        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch["attention_mask"].to(self.device)
        labels = batch["labels"].to(self.device)
        reasoning_mask = batch.get("reasoning_mask", None)
        if reasoning_mask is not None:
            reasoning_mask = reasoning_mask.to(self.device)

        loss, student_logits, _, aux_loss = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

        total_loss = loss
        metrics = {"ce_loss": loss.item(), "aux_loss": aux_loss.item()}

        # Blend in the teacher's soft targets when a teacher is present.
        if self.teacher is not None:
            with torch.no_grad():
                _, teacher_logits, _, _ = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                )

            distill_loss = self.distillation_loss(
                student_logits,
                teacher_logits,
                self.config.distillation_temperature,
            )
            total_loss = (1 - self.config.alpha_reasoning) * loss + self.config.alpha_reasoning * distill_loss
            metrics["distill_loss"] = distill_loss.item()

        self.optimizer.zero_grad()
        total_loss.backward()
        # FIX: clip every parameter the optimizer updates (student AND
        # reasoning module); previously only student params were clipped.
        torch.nn.utils.clip_grad_norm_(
            list(self.student.parameters()) + list(self.reasoning_module.parameters()),
            1.0,
        )
        self.optimizer.step()

        metrics["total_loss"] = total_loss.item()
        return metrics

    def train(
        self,
        train_dataloader: DataLoader,
        num_epochs: int = 3,
        eval_dataloader: Optional[DataLoader] = None,
    ) -> Dict[str, List[float]]:
        """Full training loop; returns per-epoch loss history."""
        history = {"train_loss": [], "eval_loss": []}

        for epoch in range(num_epochs):
            epoch_losses = []

            for batch in train_dataloader:
                metrics = self.train_step(batch)
                epoch_losses.append(metrics["total_loss"])

            avg_loss = sum(epoch_losses) / len(epoch_losses)
            history["train_loss"].append(avg_loss)
            print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_loss:.4f}")

            if eval_dataloader is not None:
                eval_loss = self.evaluate(eval_dataloader)
                history["eval_loss"].append(eval_loss)
                print(f"  Eval Loss: {eval_loss:.4f}")

        return history

    def evaluate(self, dataloader: DataLoader) -> float:
        """Evaluate on a validation set; returns the mean batch loss."""
        self.student.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                loss, _, _, _ = self.student(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                )
                total_loss += loss.item()
                num_batches += 1

        # Guard against an empty dataloader.
        return total_loss / num_batches if num_batches > 0 else 0.0
|
|
|
|
|
|
|
|
def prepare_openr1_dataset(
    raw_data_path: str,
    output_path: str,
    config: Optional["ReasoningConfig"] = None,
) -> int:
    """
    Prepare OpenR1 or DeepSeek-R1 distillation data.

    Reads JSONL reasoning traces from *raw_data_path*, normalizes the
    field names to question/reasoning/answer, and writes the result as
    JSONL to *output_path*.  Records without any recognized reasoning
    field are dropped.  Returns the number of examples written.
    """
    config = config or ReasoningConfig()
    written = 0
    missing = object()  # sentinel: distinguishes "key absent" from a null value

    with open(raw_data_path, 'r', encoding='utf-8') as fin, \
         open(output_path, 'w', encoding='utf-8') as fout:

        for raw_line in fin:
            if not raw_line.strip():
                continue

            record = json.loads(raw_line)

            # Accept several naming conventions for the question field.
            question = record.get(
                "question", record.get("prompt", record.get("input", ""))
            )

            # First matching reasoning key wins; skip records with none.
            reasoning = missing
            for key in ("thinking", "reasoning", "chain_of_thought"):
                if key in record:
                    reasoning = record[key]
                    break
            if reasoning is missing:
                continue

            answer = record.get(
                "answer", record.get("response", record.get("output", ""))
            )

            fout.write(
                json.dumps(
                    {"question": question, "reasoning": reasoning, "answer": answer},
                    ensure_ascii=False,
                )
                + "\n"
            )
            written += 1

    return written
|
|
|