# Mango-Metrics-NLM
# feat: Phi-3.5-MoE multi-agent model repository
# c8b77b5
"""
Curiosity Training Module for MangoMAS Local
This module implements specialized training for curiosity and exploration capabilities,
adapted from the AWS backup system for local training.
"""
import json
import logging
import os
import random
from datetime import datetime, timezone
from typing import Any, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
logger = logging.getLogger(__name__)
class CuriosityDataset(Dataset):
    """Dataset for training curiosity and exploration capabilities.

    Each example is one JSONL record carrying a ``scenario`` string plus
    ``curiosity_questions`` and ``exploration_directions`` lists; records
    missing any of these fields (or that are not valid JSON) are skipped.
    """

    # Fields every usable record must carry.
    _REQUIRED_FIELDS = ("scenario", "curiosity_questions", "exploration_directions")

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the curiosity dataset.

        Args:
            data_path: Path to the curiosity data file (JSONL)
            tokenizer: Tokenizer for text processing (HF-style callable)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)
        logging.getLogger(__name__).info(
            f"Loaded curiosity dataset with {len(self.data)} examples"
        )

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load curiosity training data, skipping blank or malformed lines."""
        data = []
        log = logging.getLogger(__name__)
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (common at the end of JSONL files) are not
                    # data errors, so skip them without warning.
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    log.warning(f"Skipping invalid curiosity data: {e}")
                    continue
                # Keep only records that carry every required field.
                if all(field in item for field in self._REQUIRED_FIELDS):
                    data.append(item)
        return data

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one tokenized training example.

        Note: besides the tensors, the returned dict also carries the raw
        scenario/questions/directions; batching variable-length lists through
        a DataLoader requires a custom collate_fn.
        """
        item = self.data[idx]
        scenario = item["scenario"]
        curiosity_questions = item["curiosity_questions"]
        exploration_directions = item["exploration_directions"]

        # Render the example as a single prompt/target text.
        parts = [f"Scenario: {scenario}", "", "Curiosity Questions:"]
        parts += [f"{i + 1}. {q}" for i, q in enumerate(curiosity_questions)]
        parts += ["", "Exploration Directions:"]
        parts += [f"{i + 1}. {d}" for i, d in enumerate(exploration_directions)]
        text = "\n".join(parts) + "\n"

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )
        # squeeze(0) drops only the batch dim added by return_tensors="pt";
        # a bare squeeze() would also collapse the sequence dim when
        # max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        return {
            "input_ids": input_ids,
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Standard causal-LM setup: labels are a copy of the inputs.
            "labels": input_ids.clone(),
            "scenario": scenario,
            "curiosity_questions": curiosity_questions,
            "exploration_directions": exploration_directions,
        }
class CuriosityEvaluator:
    """Evaluator for curiosity and exploration capabilities.

    Scores generated text on four heuristic axes (question diversity,
    exploration breadth, uncertainty identification, assumption challenging).
    Per-example scores are accumulated in ``self.metrics`` and averaged over
    the evaluation set by :meth:`evaluate`.
    """

    def __init__(self, tokenizer):
        """
        Initialize the curiosity evaluator.

        Args:
            tokenizer: Tokenizer for text processing (HF-style; must support
                ``__call__`` returning input_ids and ``decode``)
        """
        self.tokenizer = tokenizer
        self.metrics = {
            "question_diversity": 0.0,
            "exploration_breadth": 0.0,
            "uncertainty_identification": 0.0,
            "assumption_challenging": 0.0,
        }
        # Question starters used to measure variety of question types.
        self.question_starters = [
            "what",
            "how",
            "why",
            "when",
            "where",
            "who",
            "which",
            "could",
            "would",
            "is",
            "are",
            "do",
            "does",
            "have",
            "has",
        ]
        # Markers signalling breadth of exploration.
        self.exploration_markers = [
            "alternative",
            "perspective",
            "consider",
            "explore",
            "investigate",
            "possibility",
            "approach",
            "angle",
            "viewpoint",
            "scenario",
        ]
        # Phrases signalling that uncertainty was identified.
        self.uncertainty_phrases = [
            "unclear",
            "unknown",
            "uncertain",
            "not sure",
            "ambiguous",
            "might be",
            "could be",
            "possibly",
            "perhaps",
            "may",
        ]
        # Phrases signalling that assumptions are being challenged.
        self.assumption_phrases = [
            "assuming",
            "assumption",
            "presuppose",
            "presupposition",
            "take for granted",
            "implicit",
            "unstated",
            "underlying",
        ]

    def evaluate(self, model, eval_dataset: "CuriosityDataset") -> Dict[str, float]:
        """
        Evaluate curiosity capabilities on the provided dataset.

        Args:
            model: The model to evaluate (must support ``generate``)
            eval_dataset: Dataset of curiosity examples

        Returns:
            Dictionary of averaged evaluation metrics; all zeros when the
            dataset is empty.
        """
        model.eval()
        device = next(model.parameters()).device
        # Reset accumulators from any previous evaluation run.
        for key in self.metrics:
            self.metrics[key] = 0.0
        # Cap at 50 examples for efficiency.
        total_examples = min(len(eval_dataset), 50)
        with torch.no_grad():
            for idx in range(total_examples):
                example = eval_dataset[idx]
                scenario = example["scenario"]
                # Ask the model to generate questions for the scenario.
                prompt = f"Scenario: {scenario}\n\nGenerate curious questions to explore this further:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )
                # NOTE(review): HF models ignore `temperature` unless
                # do_sample=True is also passed — confirm intended decoding.
                generated_ids = model.generate(
                    input_ids,
                    max_length=256,
                    temperature=0.8,  # Slightly higher temperature for creativity
                    num_return_sequences=1,
                )
                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )
                # Accumulate per-example curiosity scores.
                self._evaluate_curiosity(
                    scenario=scenario,
                    expected_questions=example["curiosity_questions"],
                    expected_directions=example["exploration_directions"],
                    generated_text=generated_text,
                )
        # Average the accumulated scores.  Guard against an empty dataset —
        # the original code divided by zero here.
        if total_examples > 0:
            for key in self.metrics:
                self.metrics[key] /= total_examples
        # Return a copy so a later evaluation cannot mutate the caller's dict.
        return dict(self.metrics)

    def _evaluate_curiosity(
        self,
        scenario: str,
        expected_questions: List[str],
        expected_directions: List[str],
        generated_text: str,
    ) -> None:
        """
        Score one generated response and add the scores to ``self.metrics``.

        Args:
            scenario: The scenario that was explored (currently unused by the
                scoring heuristics)
            expected_questions: Expected curiosity questions (currently unused)
            expected_directions: Expected exploration directions (currently
                unused)
            generated_text: The text generated by the model
        """
        text_lower = generated_text.lower()
        # Extract questions: any line ending in "?" ...
        generated_questions = [
            line.strip()
            for line in generated_text.split("\n")
            if line.strip().endswith("?")
        ]
        if not generated_questions:
            # ... or, failing that, numbered lines that contain a "?".
            for line in generated_text.split("\n"):
                if any(f"{i}." in line for i in range(1, 10)) and "?" in line:
                    generated_questions.append(line.strip())
        # 1. Question diversity — number of distinct starter words used,
        #    normalized so that 5+ unique types scores 1.0.
        starter_counts = {starter: 0 for starter in self.question_starters}
        for question in generated_questions:
            q_lower = question.lower()
            for starter in self.question_starters:
                if q_lower.startswith(starter) or f" {starter} " in q_lower:
                    starter_counts[starter] += 1
        unique_starters = sum(1 for count in starter_counts.values() if count > 0)
        self.metrics["question_diversity"] += min(1.0, unique_starters / 5)
        # 2. Exploration breadth — 3+ distinct markers scores 1.0.
        exploration_marker_count = sum(
            1 for marker in self.exploration_markers if marker in text_lower
        )
        self.metrics["exploration_breadth"] += min(1.0, exploration_marker_count / 3)
        # 3. Uncertainty identification — 2+ distinct phrases scores 1.0.
        uncertainty_phrase_count = sum(
            1 for phrase in self.uncertainty_phrases if phrase in text_lower
        )
        self.metrics["uncertainty_identification"] += min(
            1.0, uncertainty_phrase_count / 2
        )
        # 4. Assumption challenging — any single phrase scores 1.0.
        assumption_phrase_count = sum(
            1 for phrase in self.assumption_phrases if phrase in text_lower
        )
        self.metrics["assumption_challenging"] += min(1.0, assumption_phrase_count / 1)
class CuriosityTrainingModule(SpecializedTrainingModule):
    """Specialized training module for curiosity and exploration capabilities.

    Combines a temperature-scaled distillation (KL) loss with a standard
    cross-entropy loss, weighting the KL term higher to encourage diverse,
    exploratory generations.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the curiosity training module.

        Args:
            config: Module configuration; reads ``config.data_path``
                (optional) and ``config.module_config["temperature"]``
                (default 1.5)
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)
        # Fall back to the default processed-data location when the config
        # does not specify a path.
        self.data_path = config.data_path or "data/processed/curiosity_train.jsonl"
        self.evaluator = CuriosityEvaluator(tokenizer)
        # Higher-than-usual distillation temperature softens the teacher
        # distribution and encourages more diverse outputs.
        self.curiosity_temp = config.module_config.get("temperature", 1.5)
        self.curiosity_loss = nn.CrossEntropyLoss(ignore_index=-100)
        # Training metrics.  NOTE(review): question_generation_score and
        # exploration_score are never updated anywhere in this module.
        self.metrics = {
            "curiosity_loss": 0.0,
            "question_generation_score": 0.0,
            "exploration_score": 0.0,
        }
        logger.info(
            f"Initialized curiosity training module with temperature: {self.curiosity_temp}"
        )

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for curiosity training.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for curiosity training (currently passed
            through unchanged for both curiosity-specific and general batches)
        """
        curiosity_keys = ("scenario", "curiosity_questions", "exploration_directions")
        if all(key in batch for key in curiosity_keys):
            # Already a curiosity-specific batch.
            return batch
        # For general conversation batches we could extract potential
        # exploration scenarios; this is a simplified placeholder.
        return batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the curiosity-specific loss.

        Args:
            student_outputs: Outputs from the student model (object with a
                ``logits`` attribute, or a raw logits tensor)
            teacher_outputs: Outputs from the teacher model (same convention)
            batch: The processed input batch; ``batch["labels"]`` is required

        Returns:
            Scalar loss tensor: 0.4 * cross-entropy + 0.6 * temperature-scaled
            KL divergence (KL weighted higher to encourage exploration)
        """
        # Unwrap HF-style output objects into raw logits tensors.
        student_logits = (
            student_outputs.logits
            if hasattr(student_outputs, "logits")
            else student_outputs
        )
        teacher_logits = (
            teacher_outputs.logits
            if hasattr(teacher_outputs, "logits")
            else teacher_outputs
        )
        # Standard causal-LM shift: predict token t+1 from position t.
        student_logits = student_logits[:, :-1, :].contiguous()
        teacher_logits = teacher_logits[:, :-1, :].contiguous()
        target_ids = batch["labels"][:, 1:].contiguous()
        # Higher temperature flattens both distributions to encourage more
        # diverse outputs; the T^2 factor keeps gradient magnitudes comparable
        # across temperatures (standard distillation scaling).
        temperature = self.curiosity_temp
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature**2)
        # Cross-entropy against labels.  NOTE(review): labels are a copy of
        # input_ids, so padding positions are not masked to -100 — confirm
        # whether pad tokens should be excluded from the CE loss.
        ce_loss = self.curiosity_loss(
            student_logits.view(-1, student_logits.size(-1)), target_ids.view(-1)
        )
        # Combined loss with curiosity focus (KL weighted higher).
        loss = 0.4 * ce_loss + 0.6 * kl_loss
        self.metrics["curiosity_loss"] = loss.item()
        return loss

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to curiosity training.

        Returns:
            Dictionary of metric names and values
        """
        return self.metrics

    def generate_synthetic_curiosity_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic curiosity training data (JSONL).

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # Simplified implementation based on the AWS backup; a full
        # implementation would generate far more varied scenarios.
        curiosity_templates = [
            {
                "scenario": "A company is developing a new voice assistant technology.",
                "curiosity_questions": [
                    "How might this technology affect people's privacy in their homes?",
                    "What unexpected ways might users interact with this technology?",
                    "How could this technology evolve over the next five years?",
                    "What ethical considerations might arise from widespread adoption?",
                    "How might this technology affect different demographic groups differently?",
                ],
                "exploration_directions": [
                    "Consider alternative interaction models beyond voice commands",
                    "Explore potential integration with other smart home systems",
                    "Investigate privacy-preserving design approaches",
                    "Consider accessibility implications for diverse user groups",
                    "Examine potential unintended consequences of ambient listening",
                ],
            },
            {
                "scenario": "Scientists have discovered a new species of deep-sea organism that can survive extreme pressure and temperature.",
                "curiosity_questions": [
                    "What adaptations allow this organism to survive such extreme conditions?",
                    "Could these adaptations be applied to human technology or medicine?",
                    "What might this discovery tell us about the possibility of life on other planets?",
                    "How might climate change affect deep-sea ecosystems and this organism?",
                    "What other undiscovered species might exist in similar environments?",
                ],
                "exploration_directions": [
                    "Examine evolutionary pathways for extreme environment adaptation",
                    "Consider biomimicry applications in engineering and materials science",
                    "Explore implications for astrobiology and extraterrestrial life",
                    "Investigate ecological relationships in extreme environments",
                    "Consider ethical dimensions of deep-sea exploration and bioprospecting",
                ],
            },
            {
                "scenario": "A small town is experiencing rapid population growth due to remote workers relocating from urban areas.",
                "curiosity_questions": [
                    "How might this demographic shift affect the town's culture and community?",
                    "What infrastructure challenges might arise from rapid population growth?",
                    "How could this trend impact local housing prices and affordability?",
                    "What economic opportunities and challenges might emerge?",
                    "How might long-term residents and newcomers develop different perspectives?",
                ],
                "exploration_directions": [
                    "Investigate similar historical population shifts and their outcomes",
                    "Consider varying perspectives from different stakeholder groups",
                    "Explore potential policy approaches to manage growth sustainably",
                    "Examine social integration mechanisms between established and new residents",
                    "Consider environmental impacts of changing land use patterns",
                ],
            },
        ]
        # Generate variations.
        output_data = []
        generated_at = datetime.now(timezone.utc).isoformat()
        for _ in range(num_samples):
            template = random.choice(curiosity_templates)
            # JSON round-trip gives a cheap deep copy so the nested question/
            # direction lists are not shared between output records (a shallow
            # .copy() would alias them).
            variation = json.loads(json.dumps(template))
            variation["metadata"] = {
                "generated": True,
                # Bug fix: this field previously stored the CUDA device name
                # (or "CPU"), not a timestamp.
                "timestamp": generated_at,
                "requires_exploration": True,
            }
            output_data.append(variation)
        # Save to file.  dirname is "" for a bare filename, and makedirs("")
        # would raise, so only create directories when one is present.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")
        logger.info(
            f"Generated {len(output_data)} synthetic curiosity examples at {output_path}"
        )