|
|
""" |
|
|
Reasoning Training Module for MangoMAS Local |
|
|
|
|
|
This module implements specialized training for reasoning capabilities, |
|
|
adapted from the AWS backup system for local training. |
|
|
""" |
|
|
|
|
|
import json
import logging
import os
import random
import re
from datetime import datetime, timezone
from typing import Any, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class ReasoningDataset(Dataset):
    """Dataset for training reasoning capabilities.

    Reads one JSON object per line (JSON-lines). Two record schemas are
    accepted:

    - ``question`` / ``reasoning`` / ``answer`` (original format)
    - ``premise`` / ``reasoning`` / ``conclusion`` (the format emitted by
      the synthetic data generator in this module); these are normalized
      on load so ``__getitem__`` can always rely on ``question`` and
      ``answer``.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """
        Initialize the reasoning dataset.

        Args:
            data_path: Path to the reasoning data file (JSON lines)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded reasoning dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load reasoning training data, skipping malformed lines.

        Returns:
            List of record dicts, each guaranteed to contain the keys
            ``question``, ``reasoning`` and ``answer``.
        """
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError:
                    # Tolerate corrupt lines rather than aborting the load.
                    continue
                if not isinstance(item, dict) or "reasoning" not in item:
                    continue
                if "question" in item and "answer" in item:
                    data.append(item)
                elif "premise" in item and "conclusion" in item:
                    # Normalize the synthetic-generator schema to the
                    # training schema without discarding the original keys.
                    item.setdefault("question", item["premise"])
                    item.setdefault("answer", item["conclusion"])
                    data.append(item)
        return data

    def __len__(self):
        """Return the number of usable examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize one example into model inputs.

        Returns:
            Dict with ``input_ids``, ``attention_mask`` and ``labels``
            tensors; labels mirror the inputs (causal-LM objective).
        """
        item = self.data[idx]

        prompt = f"Question: {item['question']}\nReasoning: {item['reasoning']}\nAnswer: {item['answer']}"

        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": encoding["input_ids"].squeeze(),
        }
|
|
|
|
|
|
|
|
class ReasoningEvaluator:
    """Evaluator for reasoning capabilities."""

    def __init__(self, tokenizer):
        """
        Initialize the reasoning evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        # Running sums during evaluate(); averaged before returning.
        self.metrics = {
            "logical_consistency": 0.0,
            "premise_relevance": 0.0,
            "conclusion_validity": 0.0,
            "steps_coherence": 0.0,
        }

    def evaluate(self, model, eval_dataset: "ReasoningDataset") -> Dict[str, float]:
        """
        Evaluate reasoning capabilities on the provided dataset.

        Args:
            model: The model to evaluate (must provide ``generate``)
            eval_dataset: Dataset of reasoning examples

        Returns:
            Dictionary of evaluation metrics, averaged over at most 100
            examples.
        """
        model.eval()
        device = next(model.parameters()).device

        for key in self.metrics:
            self.metrics[key] = 0.0

        # Cap the evaluation cost on large datasets.
        total_examples = min(len(eval_dataset), 100)
        if total_examples == 0:
            # Avoid ZeroDivisionError on an empty dataset.
            return self.metrics

        with torch.no_grad():
            for idx in range(total_examples):
                # Bug fix: index the raw records, not eval_dataset[idx] --
                # ReasoningDataset.__getitem__ returns tokenized tensors,
                # not the text fields needed here.
                item = eval_dataset.data[idx]
                premise = item.get("premise") or item.get("question", "")

                prompt = f"Premise: {premise}\nReasoning:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )

                # NOTE(review): temperature has no effect unless the model
                # actually samples (e.g. do_sample=True for HF models) --
                # confirm the intended decoding strategy.
                generated_ids = model.generate(
                    input_ids, max_length=512, temperature=0.7, num_return_sequences=1
                )

                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )

                try:
                    # Split the generation into its reasoning and
                    # conclusion sections.
                    reasoning_match = re.search(
                        r"Reasoning:(.*?)(?:Conclusion:|$)", generated_text, re.DOTALL
                    )
                    conclusion_match = re.search(
                        r"Conclusion:(.*?)$", generated_text, re.DOTALL
                    )

                    gen_reasoning = (
                        reasoning_match.group(1).strip() if reasoning_match else ""
                    )
                    gen_conclusion = (
                        conclusion_match.group(1).strip() if conclusion_match else ""
                    )

                    self._update_metrics(
                        premise=premise,
                        expected_reasoning=item.get("reasoning", ""),
                        expected_conclusion=item.get("conclusion")
                        or item.get("answer", ""),
                        generated_reasoning=gen_reasoning,
                        generated_conclusion=gen_conclusion,
                    )
                except Exception as e:
                    logger.error(f"Error evaluating reasoning: {e}")

        # Convert running sums to per-example averages.
        for key in self.metrics:
            self.metrics[key] /= total_examples

        return self.metrics

    def _update_metrics(
        self,
        premise: str,
        expected_reasoning: str,
        expected_conclusion: str,
        generated_reasoning: str,
        generated_conclusion: str,
    ) -> None:
        """
        Update reasoning metrics based on a single example.

        Args:
            premise: Input premise
            expected_reasoning: Expected reasoning steps (currently unused
                by the heuristics below)
            expected_conclusion: Expected conclusion (currently unused)
            generated_reasoning: Generated reasoning steps
            generated_conclusion: Generated conclusion
        """
        # Placeholder: no real logical-consistency check is implemented,
        # so every example contributes a constant 0.5.
        self.metrics["logical_consistency"] += 0.5

        # Premise relevance: lexical overlap between premise and the
        # generated reasoning (max(..., 1) guards empty premises).
        premise_terms = set(premise.lower().split())
        reasoning_terms = set(generated_reasoning.lower().split())
        term_overlap = len(premise_terms.intersection(reasoning_terms)) / max(
            len(premise_terms), 1
        )
        self.metrics["premise_relevance"] += term_overlap

        # Conclusion validity: crude keyword heuristic.
        if generated_conclusion and "therefore" in generated_conclusion.lower():
            self.metrics["conclusion_validity"] += 0.7
        else:
            self.metrics["conclusion_validity"] += 0.3

        # Steps coherence: count discourse markers, saturating at 3.
        flow_markers = [
            "first",
            "second",
            "third",
            "then",
            "next",
            "finally",
            "because",
            "thus",
            "hence",
        ]
        marker_count = sum(
            1 for marker in flow_markers if marker in generated_reasoning.lower()
        )
        self.metrics["steps_coherence"] += min(1.0, marker_count / 3)
|
|
|
|
|
|
|
|
class ReasoningTrainingModule(SpecializedTrainingModule):
    """Specialized training module for reasoning capabilities."""

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the reasoning training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # ignore_index=-100 lets compute_loss() mask out padding positions.
        self.reasoning_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics = {
            "reasoning_loss": 0.0,
            "reasoning_accuracy": 0.0,
            "reasoning_perplexity": 0.0,
        }

        logger.info("Initialized ReasoningTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for reasoning training.

        Moves every tensor to the module's device and passes other values
        through unchanged.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for reasoning training
        """
        prepared_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                # self.device is assumed to be provided by the base
                # SpecializedTrainingModule -- confirm in core_framework.
                prepared_batch[key] = value.to(self.device)
            else:
                prepared_batch[key] = value

        return prepared_batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the reasoning-specific loss.

        Combines a shifted causal-LM cross-entropy with an optional KL
        distillation term against the teacher (fixed 0.1 weight).

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch

        Returns:
            Loss tensor for reasoning training (scaled by self.loss_weight)
        """
        try:
            # Accept either a HF-style output object or raw logits.
            if hasattr(student_outputs, "logits"):
                student_logits = student_outputs.logits
            else:
                student_logits = student_outputs

            if hasattr(teacher_outputs, "logits"):
                teacher_logits = teacher_outputs.logits
            else:
                teacher_logits = teacher_outputs

            labels = batch.get("labels", batch.get("input_ids"))

            # Bug fix: mask padding so CrossEntropyLoss(ignore_index=-100)
            # actually skips it -- previously labels never contained -100,
            # so padding tokens contributed to the loss.
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                labels = labels.masked_fill(attention_mask == 0, -100)

            # Shift for next-token prediction.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            reasoning_loss = self.reasoning_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            if teacher_logits is not None:
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = reasoning_loss + 0.1 * kl_loss
            else:
                total_loss = reasoning_loss

            self.metrics["reasoning_loss"] = reasoning_loss.item()
            # Previously declared but never updated: token accuracy and
            # perplexity over the non-masked positions.
            with torch.no_grad():
                valid = shift_labels != -100
                if valid.any():
                    predictions = shift_logits.argmax(dim=-1)
                    self.metrics["reasoning_accuracy"] = (
                        (predictions == shift_labels)[valid].float().mean().item()
                    )
                self.metrics["reasoning_perplexity"] = torch.exp(
                    reasoning_loss.detach()
                ).item()

            # self.loss_weight is assumed to come from the base module.
            return total_loss * self.loss_weight

        except Exception as e:
            logger.error(f"Error computing reasoning loss: {e}")
            # Small constant keeps the training step alive after a failure;
            # it carries no gradient into the model parameters.
            return torch.tensor(0.01, requires_grad=True)

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to reasoning training.

        Returns:
            Dictionary of reasoning metrics (copy; safe for the caller to
            mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_reasoning_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic reasoning data for training.

        Writes JSON-lines records sampled from a small set of syllogism
        templates. Each record carries both the premise/conclusion keys and
        question/answer aliases so ReasoningDataset._load_data (which
        requires question/reasoning/answer) can consume the file directly.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        templates = [
            {
                "premise": "If it rains, the ground gets wet. It is raining now.",
                "reasoning": "Since it is raining, and rain makes the ground wet, we can conclude that the ground is getting wet.",
                "conclusion": "Therefore, the ground is wet.",
            },
            {
                "premise": "All mammals are warm-blooded. Whales are mammals.",
                "reasoning": "Whales are classified as mammals. All mammals are warm-blooded animals. Therefore, as a mammal, a whale must be warm-blooded.",
                "conclusion": "Therefore, whales are warm-blooded.",
            },
            {
                "premise": "If you study hard, you will pass the exam. You studied hard.",
                "reasoning": "The premise states a conditional relationship between studying hard and passing the exam. Since you studied hard, the condition is met.",
                "conclusion": "Therefore, you will pass the exam.",
            },
        ]

        output_data = []
        for _ in range(num_samples):
            template = random.choice(templates)

            variation = {
                "premise": template["premise"],
                "reasoning": template["reasoning"],
                "conclusion": template["conclusion"],
                # Aliases for the training-schema loader.
                "question": template["premise"],
                "answer": template["conclusion"],
                "metadata": {
                    "generated": True,
                    # Bug fix: this field previously stored the CUDA device
                    # name instead of an actual timestamp.
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                    "device": (
                        torch.cuda.get_device_name(0)
                        if torch.cuda.is_available()
                        else "CPU"
                    ),
                },
            }

            output_data.append(variation)

        # makedirs("") raises, so only create the parent when one exists.
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic reasoning examples at {output_path}"
        )
|
|
|