""" Reasoning Training Module for MangoMAS Local This module implements specialized training for reasoning capabilities, adapted from the AWS backup system for local training. """ import json import logging import os import random import re from typing import Any, Dict, List import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig logger = logging.getLogger(__name__) class ReasoningDataset(Dataset): """Dataset for training reasoning capabilities.""" def __init__(self, data_path: str, tokenizer, max_length: int = 512): """ Initialize the reasoning dataset. Args: data_path: Path to the reasoning data file tokenizer: Tokenizer for text processing max_length: Maximum sequence length """ self.tokenizer = tokenizer self.max_length = max_length self.data = self._load_data(data_path) logger.info(f"Loaded reasoning dataset with {len(self.data)} examples") def _load_data(self, data_path: str) -> List[Dict]: """Load reasoning training data.""" data = [] with open(data_path, "r", encoding="utf-8") as f: for line in f: try: item = json.loads(line.strip()) # Validate required fields if "question" in item and "reasoning" in item and "answer" in item: data.append(item) except json.JSONDecodeError: continue return data def __len__(self): return len(self.data) def __getitem__(self, idx): item = self.data[idx] # Format the reasoning prompt prompt = f"Question: {item['question']}\nReasoning: {item['reasoning']}\nAnswer: {item['answer']}" # Tokenize encoding = self.tokenizer( prompt, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt", ) return { "input_ids": encoding["input_ids"].squeeze(), "attention_mask": encoding["attention_mask"].squeeze(), "labels": encoding["input_ids"].squeeze(), } class ReasoningEvaluator: """Evaluator for reasoning capabilities.""" def __init__(self, tokenizer): """ Initialize the reasoning evaluator. Args: tokenizer: Tokenizer for text processing """ self.tokenizer = tokenizer self.metrics = { "logical_consistency": 0.0, "premise_relevance": 0.0, "conclusion_validity": 0.0, "steps_coherence": 0.0, } def evaluate(self, model, eval_dataset: ReasoningDataset) -> Dict[str, float]: """ Evaluate reasoning capabilities on the provided dataset. Args: model: The model to evaluate eval_dataset: Dataset of reasoning examples Returns: Dictionary of evaluation metrics """ model.eval() device = next(model.parameters()).device # Reset metrics for key in self.metrics: self.metrics[key] = 0.0 total_examples = min( len(eval_dataset), 100 ) # Limit to 100 examples for efficiency with torch.no_grad(): for idx in range(total_examples): example = eval_dataset[idx] premise = example["premise"] # Generate reasoning and conclusion from premise prompt = f"Premise: {premise}\nReasoning:" input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( device ) generated_ids = model.generate( input_ids, max_length=512, temperature=0.7, num_return_sequences=1 ) generated_text = self.tokenizer.decode( generated_ids[0], skip_special_tokens=True ) # Extract reasoning and conclusion from generated text try: generated_reasoning = re.search( r"Reasoning:(.*?)(?:Conclusion:|$)", generated_text, re.DOTALL ) generated_conclusion = re.search( r"Conclusion:(.*?)$", generated_text, re.DOTALL ) if generated_reasoning: gen_reasoning = generated_reasoning.group(1).strip() else: gen_reasoning = "" if generated_conclusion: gen_conclusion = generated_conclusion.group(1).strip() else: gen_conclusion = "" # Evaluate reasoning quality self._update_metrics( premise=premise, expected_reasoning=example["reasoning"], expected_conclusion=example["conclusion"], generated_reasoning=gen_reasoning, generated_conclusion=gen_conclusion, ) except Exception as e: logger.error(f"Error evaluating reasoning: {e}") # Calculate averages for key in self.metrics: self.metrics[key] /= total_examples return self.metrics def _update_metrics( self, premise: str, expected_reasoning: str, expected_conclusion: str, generated_reasoning: str, generated_conclusion: str, ) -> None: """ Update reasoning metrics based on a single example. Args: premise: Input premise expected_reasoning: Expected reasoning steps expected_conclusion: Expected conclusion generated_reasoning: Generated reasoning steps generated_conclusion: Generated conclusion """ # Very simplified evaluation - in a real system, this would use more sophisticated # semantic similarity and logical consistency checking # Logical consistency - check if reasoning follows from premise self.metrics["logical_consistency"] += 0.5 # Simplified placeholder # Premise relevance - check if reasoning references key terms from premise premise_terms = set(premise.lower().split()) reasoning_terms = set(generated_reasoning.lower().split()) term_overlap = len(premise_terms.intersection(reasoning_terms)) / max( len(premise_terms), 1 ) self.metrics["premise_relevance"] += term_overlap # Conclusion validity - check if conclusion follows from reasoning if generated_conclusion and "therefore" in generated_conclusion.lower(): self.metrics["conclusion_validity"] += 0.7 # Simplified placeholder else: self.metrics["conclusion_validity"] += 0.3 # Steps coherence - check for logical flow markers flow_markers = [ "first", "second", "third", "then", "next", "finally", "because", "thus", "hence", ] marker_count = sum( 1 for marker in flow_markers if marker in generated_reasoning.lower() ) self.metrics["steps_coherence"] += min(1.0, marker_count / 3) class ReasoningTrainingModule(SpecializedTrainingModule): """Specialized training module for reasoning capabilities.""" def __init__(self, config: TrainingModuleConfig, tokenizer): """ Initialize the reasoning training module. Args: config: Module configuration tokenizer: Tokenizer for text processing """ super().__init__(config, tokenizer) # Initialize reasoning-specific components self.reasoning_loss = nn.CrossEntropyLoss(ignore_index=-100) self.metrics = { "reasoning_loss": 0.0, "reasoning_accuracy": 0.0, "reasoning_perplexity": 0.0, } logger.info("Initialized ReasoningTrainingModule") def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Prepare a batch of data for reasoning training. Args: batch: The input batch from the dataloader Returns: Processed batch ready for reasoning training """ # Move batch to device prepared_batch = {} for key, value in batch.items(): if isinstance(value, torch.Tensor): prepared_batch[key] = value.to(self.device) else: prepared_batch[key] = value return prepared_batch def compute_loss( self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor] ) -> torch.Tensor: """ Compute the reasoning-specific loss. Args: student_outputs: Outputs from the student model teacher_outputs: Outputs from the teacher model batch: The processed input batch Returns: Loss tensor for reasoning training """ try: # Extract logits from model outputs if hasattr(student_outputs, "logits"): student_logits = student_outputs.logits else: student_logits = student_outputs if hasattr(teacher_outputs, "logits"): teacher_logits = teacher_outputs.logits else: teacher_logits = teacher_outputs # Get labels from batch labels = batch.get("labels", batch.get("input_ids")) # Compute cross entropy loss for reasoning shift_logits = student_logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() reasoning_loss = self.reasoning_loss( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) # Add KL divergence loss between student and teacher if teacher_logits is not None: kl_loss = F.kl_div( F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean", ) total_loss = reasoning_loss + 0.1 * kl_loss else: total_loss = reasoning_loss # Update metrics self.metrics["reasoning_loss"] = reasoning_loss.item() return total_loss * self.loss_weight except Exception as e: logger.error(f"Error computing reasoning loss: {e}") # Return a small loss to avoid training failure return torch.tensor(0.01, requires_grad=True) def get_metrics(self) -> Dict[str, float]: """ Get metrics specific to reasoning training. Returns: Dictionary of reasoning metrics """ return self.metrics.copy() def generate_synthetic_reasoning_data( self, output_path: str, num_samples: int = 1000 ) -> None: """ Generate synthetic reasoning data for training. Args: output_path: Path to save the generated data num_samples: Number of samples to generate """ # This is a simplified implementation based on the AWS backup's synthetic_generator # In a full implementation, this would be much more sophisticated templates = [ { "premise": "If it rains, the ground gets wet. It is raining now.", "reasoning": "Since it is raining, and rain makes the ground wet, we can conclude that the ground is getting wet.", "conclusion": "Therefore, the ground is wet.", }, { "premise": "All mammals are warm-blooded. Whales are mammals.", "reasoning": "Whales are classified as mammals. All mammals are warm-blooded animals. Therefore, as a mammal, a whale must be warm-blooded.", "conclusion": "Therefore, whales are warm-blooded.", }, { "premise": "If you study hard, you will pass the exam. You studied hard.", "reasoning": "The premise states a conditional relationship between studying hard and passing the exam. Since you studied hard, the condition is met.", "conclusion": "Therefore, you will pass the exam.", }, ] # Generate variations of the templates output_data = [] for _ in range(num_samples): template = random.choice(templates) # Create a variation (very simplified) variation = { "premise": template["premise"], "reasoning": template["reasoning"], "conclusion": template["conclusion"], "metadata": { "generated": True, "timestamp": str( torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU" ), }, } output_data.append(variation) # Save to file os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: for item in output_data: f.write(json.dumps(item) + "\n") logger.info( f"Generated {len(output_data)} synthetic reasoning examples at {output_path}" )