|
|
""" |
|
|
Curiosity Training Module for MangoMAS Local |
|
|
|
|
|
This module implements specialized training for curiosity and exploration capabilities, |
|
|
adapted from the AWS backup system for local training. |
|
|
""" |
|
|
|
|
|
import json
import logging
import os
import random
from datetime import datetime, timezone
from typing import Any, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class CuriosityDataset(Dataset):
    """Dataset for training curiosity and exploration capabilities.

    Reads a JSON-lines file where each record carries at least the keys
    ``scenario``, ``curiosity_questions`` and ``exploration_directions``.
    Blank lines are skipped silently; malformed JSON lines are skipped with
    a warning; records missing a required key are dropped.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the curiosity dataset.

        Args:
            data_path: Path to the curiosity data file (JSON-lines format)
            tokenizer: Tokenizer for text processing (HF-style callable
                returning ``input_ids`` / ``attention_mask`` tensors)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded curiosity dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load curiosity training data, skipping blank or malformed lines."""
        data: List[Dict] = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Blank lines (common at end of file) are not an error;
                # skip them without a spurious JSON-decode warning.
                if not line:
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    # Only json.loads can raise here; the key checks below
                    # cannot raise KeyError, so no need to catch it.
                    logger.warning(f"Skipping invalid curiosity data: {e}")
                    continue
                if (
                    "scenario" in item
                    and "curiosity_questions" in item
                    and "exploration_directions" in item
                ):
                    data.append(item)
        return data

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a training example.

        Returns a dict with the tokenized tensors (``input_ids``,
        ``attention_mask``, ``labels``) plus the raw text fields.  The raw
        list fields are variable-length, so a custom collate function is
        required if this dataset is batched by a default DataLoader.
        """
        item = self.data[idx]

        scenario = item["scenario"]
        curiosity_questions = item["curiosity_questions"]
        exploration_directions = item["exploration_directions"]

        # Build the prompt in one pass: scenario, then numbered questions,
        # then numbered exploration directions (joined to avoid quadratic
        # string concatenation).
        parts = [f"Scenario: {scenario}\n", "Curiosity Questions:"]
        parts.extend(
            f"{i}. {question}"
            for i, question in enumerate(curiosity_questions, start=1)
        )
        parts.append("")
        parts.append("Exploration Directions:")
        parts.extend(
            f"{i}. {direction}"
            for i, direction in enumerate(exploration_directions, start=1)
        )
        text = "\n".join(parts) + "\n"

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            # Standard causal-LM setup: labels mirror the inputs; the
            # consumer shifts them for next-token prediction.
            "labels": encoding["input_ids"].squeeze().clone(),
            "scenario": scenario,
            "curiosity_questions": curiosity_questions,
            "exploration_directions": exploration_directions,
        }
|
|
|
|
|
|
|
|
class CuriosityEvaluator:
    """Evaluator for curiosity and exploration capabilities.

    Scores generated text with simple surface/keyword heuristics along four
    axes: question diversity, exploration breadth, uncertainty identification
    and assumption challenging.  Per-example scores (each clamped to [0, 1])
    are accumulated and averaged over the evaluated examples.
    """

    def __init__(self, tokenizer):
        """
        Initialize the curiosity evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        # Running accumulators; reset at the start of each evaluate() call.
        self.metrics = {
            "question_diversity": 0.0,
            "exploration_breadth": 0.0,
            "uncertainty_identification": 0.0,
            "assumption_challenging": 0.0,
        }

        # Interrogative words used to measure how varied the generated
        # questions are.
        self.question_starters = [
            "what",
            "how",
            "why",
            "when",
            "where",
            "who",
            "which",
            "could",
            "would",
            "is",
            "are",
            "do",
            "does",
            "have",
            "has",
        ]

        # Markers suggesting the text explores multiple angles.
        self.exploration_markers = [
            "alternative",
            "perspective",
            "consider",
            "explore",
            "investigate",
            "possibility",
            "approach",
            "angle",
            "viewpoint",
            "scenario",
        ]

        # Phrases signalling that uncertainty is acknowledged.
        self.uncertainty_phrases = [
            "unclear",
            "unknown",
            "uncertain",
            "not sure",
            "ambiguous",
            "might be",
            "could be",
            "possibly",
            "perhaps",
            "may",
        ]

        # Phrases signalling that assumptions are surfaced or challenged.
        self.assumption_phrases = [
            "assuming",
            "assumption",
            "presuppose",
            "presupposition",
            "take for granted",
            "implicit",
            "unstated",
            "underlying",
        ]

    def evaluate(self, model, eval_dataset: "CuriosityDataset") -> Dict[str, float]:
        """
        Evaluate curiosity capabilities on the provided dataset.

        Args:
            model: The model to evaluate (must expose ``generate``)
            eval_dataset: Dataset of curiosity examples

        Returns:
            Dictionary of evaluation metrics, averaged over at most the
            first 50 examples.  All metrics are 0.0 for an empty dataset.
        """
        model.eval()
        device = next(model.parameters()).device

        # Reset accumulators from any previous run.
        for key in self.metrics:
            self.metrics[key] = 0.0

        # Cap evaluation cost at 50 examples.
        total_examples = min(len(eval_dataset), 50)
        if total_examples == 0:
            # Nothing to score; avoid dividing by zero below.
            return self.metrics

        with torch.no_grad():
            for idx in range(total_examples):
                example = eval_dataset[idx]
                scenario = example["scenario"]

                prompt = f"Scenario: {scenario}\n\nGenerate curious questions to explore this further:"
                input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(
                    device
                )

                # do_sample=True is required for the temperature to take
                # effect; without it HF models decode greedily and ignore it.
                generated_ids = model.generate(
                    input_ids,
                    # NOTE(review): assumes prompts tokenize to < 256 tokens;
                    # otherwise generate() has no room for new tokens — confirm.
                    max_length=256,
                    temperature=0.8,
                    do_sample=True,
                    num_return_sequences=1,
                )

                generated_text = self.tokenizer.decode(
                    generated_ids[0], skip_special_tokens=True
                )

                self._evaluate_curiosity(
                    scenario=scenario,
                    expected_questions=example["curiosity_questions"],
                    expected_directions=example["exploration_directions"],
                    generated_text=generated_text,
                )

        # Average accumulated per-example scores.
        for key in self.metrics:
            self.metrics[key] /= total_examples

        return self.metrics

    def _evaluate_curiosity(
        self,
        scenario: str,
        expected_questions: List[str],
        expected_directions: List[str],
        generated_text: str,
    ) -> None:
        """
        Score one example and add its per-axis scores (each in [0, 1]) to
        ``self.metrics``.

        Args:
            scenario: The scenario to explore (not used by the heuristics;
                kept for interface stability)
            expected_questions: Expected curiosity questions (unused —
                scoring is reference-free)
            expected_directions: Expected exploration directions (unused)
            generated_text: The text generated by the model
        """
        # Lowercase once for all substring checks below.
        lowered = generated_text.lower()

        # Extract generated questions: lines ending in '?'.
        generated_questions = [
            line.strip()
            for line in generated_text.split("\n")
            if line.strip().endswith("?")
        ]

        if not generated_questions:
            # Fallback: numbered lines containing a question mark anywhere.
            for line in generated_text.split("\n"):
                if any(f"{i}." in line for i in range(1, 10)) and "?" in line:
                    generated_questions.append(line.strip())

        # Question diversity: count distinct interrogative starters seen
        # either at the start of a question or as a standalone word.
        starter_counts = {starter: 0 for starter in self.question_starters}
        for question in generated_questions:
            q_lower = question.lower()
            for starter in self.question_starters:
                if q_lower.startswith(starter) or f" {starter} " in q_lower:
                    starter_counts[starter] += 1

        unique_starters = sum(1 for count in starter_counts.values() if count > 0)
        # Full credit at 5+ distinct starters.
        self.metrics["question_diversity"] += min(1.0, unique_starters / 5)

        # Exploration breadth: full credit at 3+ distinct markers.
        exploration_marker_count = sum(
            1 for marker in self.exploration_markers if marker in lowered
        )
        self.metrics["exploration_breadth"] += min(1.0, exploration_marker_count / 3)

        # Uncertainty identification: full credit at 2+ phrases.
        uncertainty_phrase_count = sum(
            1 for phrase in self.uncertainty_phrases if phrase in lowered
        )
        self.metrics["uncertainty_identification"] += min(
            1.0, uncertainty_phrase_count / 2
        )

        # Assumption challenging: full credit at a single phrase.
        assumption_phrase_count = sum(
            1 for phrase in self.assumption_phrases if phrase in lowered
        )
        self.metrics["assumption_challenging"] += min(
            1.0, float(assumption_phrase_count)
        )
|
|
|
|
|
|
|
|
class CuriosityTrainingModule(SpecializedTrainingModule):
    """Specialized training module for curiosity and exploration capabilities.

    Distills a teacher model into a student with a high-temperature KL term
    (to encourage broad, exploratory output distributions) combined with a
    standard cross-entropy term against the labels.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the curiosity training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Default to the processed curiosity corpus when none is configured.
        self.data_path = config.data_path or "data/processed/curiosity_train.jsonl"
        self.evaluator = CuriosityEvaluator(tokenizer)

        # Higher distillation temperature softens the teacher distribution,
        # spreading probability mass over more candidate tokens.
        self.curiosity_temp = config.module_config.get("temperature", 1.5)
        self.curiosity_loss = nn.CrossEntropyLoss(ignore_index=-100)

        # question_generation_score / exploration_score are placeholders
        # populated elsewhere; only curiosity_loss is updated here.
        self.metrics = {
            "curiosity_loss": 0.0,
            "question_generation_score": 0.0,
            "exploration_score": 0.0,
        }

        logger.info(
            f"Initialized curiosity training module with temperature: {self.curiosity_temp}"
        )

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for curiosity training.

        Currently a pass-through: batches from CuriosityDataset already carry
        everything compute_loss needs.  Kept as an override hook for future
        curiosity-specific preprocessing.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for curiosity training
        """
        # The previous implementation branched on the presence of the
        # scenario/question keys but returned the batch unchanged on both
        # paths, so the check was dead code.
        return batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the curiosity-specific loss: 0.6 * temperature-scaled KL
        divergence against the teacher + 0.4 * cross-entropy against the
        shifted labels.

        Args:
            student_outputs: Outputs from the student model (a raw logits
                tensor, or an object exposing ``.logits``)
            teacher_outputs: Outputs from the teacher model (same convention)
            batch: The processed input batch; must contain ``labels``

        Returns:
            Curiosity-specific loss tensor
        """
        # Accept either raw logit tensors or HF-style output objects.
        student_logits = (
            student_outputs.logits
            if hasattr(student_outputs, "logits")
            else student_outputs
        )
        teacher_logits = (
            teacher_outputs.logits
            if hasattr(teacher_outputs, "logits")
            else teacher_outputs
        )

        # Shift for next-token prediction: position t predicts token t+1.
        student_logits = student_logits[:, :-1, :].contiguous()
        teacher_logits = teacher_logits[:, :-1, :].contiguous()
        target_ids = batch["labels"][:, 1:].contiguous()

        # Temperature-scaled distillation KL; the T^2 factor keeps gradient
        # magnitudes comparable across temperatures.
        temperature = self.curiosity_temp
        kl_loss = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.softmax(teacher_logits / temperature, dim=-1),
            reduction="batchmean",
        ) * (temperature**2)

        # NOTE(review): labels arrive as a raw clone of input_ids with no
        # -100 masking upstream, so padding positions contribute to this CE
        # term and ignore_index never triggers — confirm whether pad tokens
        # should be masked before training on padded batches.
        ce_loss = self.curiosity_loss(
            student_logits.view(-1, student_logits.size(-1)), target_ids.view(-1)
        )

        # Weight distillation slightly above the hard-label CE term.
        loss = 0.4 * ce_loss + 0.6 * kl_loss

        self.metrics["curiosity_loss"] = loss.item()

        return loss

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to curiosity training.

        Returns:
            Dictionary of metric names and values
        """
        return self.metrics

    def generate_synthetic_curiosity_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic curiosity training data by sampling from a small
        set of hand-written scenario templates and writing them as JSON lines.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        curiosity_templates = [
            {
                "scenario": "A company is developing a new voice assistant technology.",
                "curiosity_questions": [
                    "How might this technology affect people's privacy in their homes?",
                    "What unexpected ways might users interact with this technology?",
                    "How could this technology evolve over the next five years?",
                    "What ethical considerations might arise from widespread adoption?",
                    "How might this technology affect different demographic groups differently?",
                ],
                "exploration_directions": [
                    "Consider alternative interaction models beyond voice commands",
                    "Explore potential integration with other smart home systems",
                    "Investigate privacy-preserving design approaches",
                    "Consider accessibility implications for diverse user groups",
                    "Examine potential unintended consequences of ambient listening",
                ],
            },
            {
                "scenario": "Scientists have discovered a new species of deep-sea organism that can survive extreme pressure and temperature.",
                "curiosity_questions": [
                    "What adaptations allow this organism to survive such extreme conditions?",
                    "Could these adaptations be applied to human technology or medicine?",
                    "What might this discovery tell us about the possibility of life on other planets?",
                    "How might climate change affect deep-sea ecosystems and this organism?",
                    "What other undiscovered species might exist in similar environments?",
                ],
                "exploration_directions": [
                    "Examine evolutionary pathways for extreme environment adaptation",
                    "Consider biomimicry applications in engineering and materials science",
                    "Explore implications for astrobiology and extraterrestrial life",
                    "Investigate ecological relationships in extreme environments",
                    "Consider ethical dimensions of deep-sea exploration and bioprospecting",
                ],
            },
            {
                "scenario": "A small town is experiencing rapid population growth due to remote workers relocating from urban areas.",
                "curiosity_questions": [
                    "How might this demographic shift affect the town's culture and community?",
                    "What infrastructure challenges might arise from rapid population growth?",
                    "How could this trend impact local housing prices and affordability?",
                    "What economic opportunities and challenges might emerge?",
                    "How might long-term residents and newcomers develop different perspectives?",
                ],
                "exploration_directions": [
                    "Investigate similar historical population shifts and their outcomes",
                    "Consider varying perspectives from different stakeholder groups",
                    "Explore potential policy approaches to manage growth sustainably",
                    "Examine social integration mechanisms between established and new residents",
                    "Consider environmental impacts of changing land use patterns",
                ],
            },
        ]

        output_data = []
        for _ in range(num_samples):
            template = random.choice(curiosity_templates)

            # A shallow copy suffices: we only add a new top-level
            # "metadata" key and never mutate the nested lists.
            variation = template.copy()

            variation["metadata"] = {
                "generated": True,
                # Record when the sample was produced.  (The previous code
                # stored the CUDA device name under "timestamp", which is
                # not a timestamp at all.)
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "requires_exploration": True,
            }

            output_data.append(variation)

        # Only create the parent directory when the path actually has one;
        # os.makedirs("") raises FileNotFoundError for bare filenames.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item, ensure_ascii=False) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic curiosity examples at {output_path}"
        )
|
|
|