"""
Curiosity Training Module for MangoMAS Local

This module implements specialized training for curiosity and exploration
capabilities, adapted from the AWS backup system for local training.
"""

import json
import logging
import os
import random
from datetime import datetime, timezone
from typing import Any, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig

logger = logging.getLogger(__name__)

# Label value skipped by nn.CrossEntropyLoss(ignore_index=-100); used to keep
# padding positions out of every loss term.
_IGNORE_INDEX = -100


class CuriosityDataset(Dataset):
    """Dataset for training curiosity and exploration capabilities.

    Each line of the source file is expected to be a JSON object with the keys
    ``scenario``, ``curiosity_questions`` and ``exploration_directions``; lines
    that do not satisfy this are skipped with a warning.
    """

    # Keys every curiosity record must provide.
    _REQUIRED_FIELDS = ("scenario", "curiosity_questions", "exploration_directions")

    def __init__(self, data_path: str, tokenizer, max_length: int = 768):
        """
        Initialize the curiosity dataset.

        Args:
            data_path: Path to the curiosity data file (one JSON object per line)
            tokenizer: Tokenizer for text processing
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)
        logger.info("Loaded curiosity dataset with %d examples", len(self.data))

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load and validate curiosity training data.

        Blank lines are ignored. Lines that are not valid JSON, are not JSON
        objects, or lack one of the required fields are skipped with a warning.

        Args:
            data_path: Path to the JSON Lines file to read

        Returns:
            List of validated record dicts

        Raises:
            OSError: If the file cannot be opened
        """
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    # A trailing newline or spacer line is not a data error.
                    continue
                try:
                    item = json.loads(line)
                except json.JSONDecodeError as e:
                    logger.warning(
                        "Skipping invalid curiosity data on line %d: %s", line_no, e
                    )
                    continue
                # `"scenario" in item` on a JSON string/list would do substring or
                # element matching, so require an object explicitly.
                if not isinstance(item, dict):
                    logger.warning(
                        "Skipping invalid curiosity data on line %d: expected a JSON object",
                        line_no,
                    )
                    continue
                missing = [key for key in self._REQUIRED_FIELDS if key not in item]
                if missing:
                    logger.warning(
                        "Skipping curiosity data on line %d: missing fields %s",
                        line_no,
                        ", ".join(missing),
                    )
                    continue
                data.append(item)
        return data

    def __len__(self) -> int:
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a training example.

        Returns a dict with 1-D ``input_ids``, ``attention_mask`` and ``labels``
        tensors (labels equal the input ids with padding positions set to -100)
        plus the raw ``scenario``, ``curiosity_questions`` and
        ``exploration_directions`` values from the record.
        """
        item = self.data[idx]

        scenario = item["scenario"]
        curiosity_questions = item["curiosity_questions"]
        exploration_directions = item["exploration_directions"]

        # Render the record as a numbered plain-text document.
        parts = [f"Scenario: {scenario}\n\n", "Curiosity Questions:\n"]
        parts.extend(
            f"{i}. {question}\n"
            for i, question in enumerate(curiosity_questions, start=1)
        )
        parts.append("\n")
        parts.append("Exploration Directions:\n")
        parts.extend(
            f"{i}. {direction}\n"
            for i, direction in enumerate(exploration_directions, start=1)
        )
        text = "".join(parts)

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Take row 0 instead of squeeze() so a length-1 sequence stays 1-D.
        input_ids = encoding["input_ids"][0]
        attention_mask = encoding["attention_mask"][0]

        # Padding must not contribute to the language-modelling loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = _IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "scenario": scenario,
            "curiosity_questions": curiosity_questions,
            "exploration_directions": exploration_directions,
        }


class CuriosityEvaluator:
    """Evaluator for curiosity and exploration capabilities.

    Scores are keyword heuristics computed on model generations and averaged
    over the evaluated examples; each per-example score is capped at 1.0.
    """

    def __init__(self, tokenizer):
        """
        Initialize the curiosity evaluator.

        Args:
            tokenizer: Tokenizer for text processing
        """
        self.tokenizer = tokenizer
        self.metrics = {
            "question_diversity": 0.0,
            "exploration_breadth": 0.0,
            "uncertainty_identification": 0.0,
            "assumption_challenging": 0.0,
        }

        # Question starters for evaluating diversity
        self.question_starters = [
            "what",
            "how",
            "why",
            "when",
            "where",
            "who",
            "which",
            "could",
            "would",
            "is",
            "are",
            "do",
            "does",
            "have",
            "has",
        ]

        # Exploration markers
        self.exploration_markers = [
            "alternative",
            "perspective",
            "consider",
            "explore",
            "investigate",
            "possibility",
            "approach",
            "angle",
            "viewpoint",
            "scenario",
        ]

        # Uncertainty phrases
        self.uncertainty_phrases = [
            "unclear",
            "unknown",
            "uncertain",
            "not sure",
            "ambiguous",
            "might be",
            "could be",
            "possibly",
            "perhaps",
            "may",
        ]

        # Assumption challenging phrases
        self.assumption_phrases = [
            "assuming",
            "assumption",
            "presuppose",
            "presupposition",
            "take for granted",
            "implicit",
            "unstated",
            "underlying",
        ]

    def evaluate(
        self, model, eval_dataset: CuriosityDataset, max_examples: int = 50
    ) -> Dict[str, float]:
        """
        Evaluate curiosity capabilities on the provided dataset.

        Args:
            model: The model to evaluate (must expose ``parameters()`` and ``generate()``)
            eval_dataset: Dataset of curiosity examples
            max_examples: Evaluate at most this many examples (for efficiency)

        Returns:
            Dictionary of evaluation metrics averaged over the evaluated
            examples (a fresh copy; all zeros when there is nothing to evaluate)
        """
        device = next(model.parameters()).device

        # Reset metrics
        for key in self.metrics:
            self.metrics[key] = 0.0

        total_examples = min(len(eval_dataset), max_examples)
        if total_examples <= 0:
            logger.warning("Curiosity evaluation skipped: no examples to evaluate")
            return dict(self.metrics)

        # Decoder-only models echo the prompt in generate()'s output; only the
        # continuation should be scored, otherwise words in the prompt itself
        # ("Scenario", "explore") count as exploration markers.
        model_config = getattr(model, "config", None)
        is_encoder_decoder = bool(getattr(model_config, "is_encoder_decoder", False))

        # Restore the caller's train/eval mode afterwards so evaluation run in
        # the middle of training does not silently disable dropout.
        was_training = getattr(model, "training", False)
        model.eval()
        try:
            with torch.no_grad():
                for idx in range(total_examples):
                    example = eval_dataset[idx]
                    scenario = example["scenario"]

                    # Generate questions for the scenario
                    prompt = f"Scenario: {scenario}\n\nGenerate curious questions to explore this further:"
                    encoded = self.tokenizer(prompt, return_tensors="pt")
                    input_ids = encoded["input_ids"].to(device)
                    attention_mask = encoded.get("attention_mask")
                    if attention_mask is not None:
                        attention_mask = attention_mask.to(device)

                    generated_ids = model.generate(
                        input_ids,
                        attention_mask=attention_mask,
                        # For decoder-only models this length includes the prompt.
                        max_length=256,
                        # temperature only takes effect when sampling is enabled.
                        do_sample=True,
                        temperature=0.8,  # Slightly higher temperature for creativity
                        num_return_sequences=1,
                    )

                    if is_encoder_decoder:
                        output_ids = generated_ids[0]
                    else:
                        output_ids = generated_ids[0, input_ids.shape[-1]:]
                    generated_text = self.tokenizer.decode(
                        output_ids, skip_special_tokens=True
                    )

                    # Evaluate curiosity quality
                    self._evaluate_curiosity(
                        scenario=scenario,
                        expected_questions=example["curiosity_questions"],
                        expected_directions=example["exploration_directions"],
                        generated_text=generated_text,
                    )
        finally:
            model.train(was_training)

        # Calculate averages
        for key in self.metrics:
            self.metrics[key] /= total_examples

        # Return a copy so a later evaluate() call cannot mutate this result.
        return dict(self.metrics)

    def _evaluate_curiosity(
        self,
        scenario: str,
        expected_questions: List[str],
        expected_directions: List[str],
        generated_text: str,
    ) -> None:
        """
        Accumulate heuristic curiosity scores for one generated text.

        Each of the four scores is added to ``self.metrics``; ``evaluate``
        divides the totals by the number of examples afterwards.

        Args:
            scenario: The scenario to explore (currently unused by the heuristics)
            expected_questions: Expected curiosity questions (currently unused)
            expected_directions: Expected exploration directions (currently unused)
            generated_text: The text generated by the model
        """
        text_lower = generated_text.lower()
        stripped_lines = [line.strip() for line in generated_text.split("\n")]

        # Extract questions from generated text (simple approach): lines ending
        # in "?", falling back to numbered lines that contain a "?".
        generated_questions = [line for line in stripped_lines if line.endswith("?")]
        if not generated_questions:
            generated_questions = [
                line
                for line in stripped_lines
                if "?" in line and any(f"{i}." in line for i in range(1, 10))
            ]

        # 1. Question diversity - variety of question types
        starters_used = set()
        for question in generated_questions:
            question_lower = question.lower()
            for starter in self.question_starters:
                if (
                    question_lower.startswith(starter)
                    or f" {starter} " in question_lower
                ):
                    starters_used.add(starter)
        self.metrics["question_diversity"] += min(
            1.0, len(starters_used) / 5
        )  # Normalize to 5 unique types

        # 2. Exploration breadth - check for exploration markers
        exploration_marker_count = sum(
            1 for marker in self.exploration_markers if marker in text_lower
        )
        self.metrics["exploration_breadth"] += min(1.0, exploration_marker_count / 3)

        # 3. Uncertainty identification - check for uncertainty phrases
        uncertainty_phrase_count = sum(
            1 for phrase in self.uncertainty_phrases if phrase in text_lower
        )
        self.metrics["uncertainty_identification"] += min(
            1.0, uncertainty_phrase_count / 2
        )

        # 4. Assumption challenging - check for phrases that challenge assumptions
        assumption_phrase_count = sum(
            1 for phrase in self.assumption_phrases if phrase in text_lower
        )
        self.metrics["assumption_challenging"] += min(1.0, assumption_phrase_count / 1)


class CuriosityTrainingModule(SpecializedTrainingModule):
    """Specialized training module for curiosity and exploration capabilities."""

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the curiosity training module.

        Recognised ``config.module_config`` keys: ``temperature`` (distillation
        temperature, default 1.5, must be > 0), ``ce_weight`` (default 0.4) and
        ``kl_weight`` (default 0.6).

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing

        Raises:
            ValueError: If the configured temperature is not positive
        """
        super().__init__(config, tokenizer)

        # Initialize curiosity-specific components
        self.data_path = config.data_path or "data/processed/curiosity_train.jsonl"
        self.evaluator = CuriosityEvaluator(tokenizer)

        module_config = config.module_config or {}

        # Curiosity-specific loss with higher temperature
        self.curiosity_temp = float(module_config.get("temperature", 1.5))
        if self.curiosity_temp <= 0:
            raise ValueError(
                f"Curiosity temperature must be positive, got {self.curiosity_temp}"
            )
        # Weighting of the combined loss; KL is weighted higher by default to
        # encourage exploration.
        self.ce_weight = float(module_config.get("ce_weight", 0.4))
        self.kl_weight = float(module_config.get("kl_weight", 0.6))
        self.curiosity_loss = nn.CrossEntropyLoss(ignore_index=_IGNORE_INDEX)

        # Training metrics
        self.metrics = {
            "curiosity_loss": 0.0,
            "question_generation_score": 0.0,
            "exploration_score": 0.0,
        }

        logger.info(
            "Initialized curiosity training module with temperature: %s",
            self.curiosity_temp,
        )

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for curiosity training.

        Curiosity batches (those carrying ``scenario``/``curiosity_questions``/
        ``exploration_directions``) and general conversation batches are both
        passed through unchanged; extracting exploration scenarios from general
        batches is not implemented.

        Args:
            batch: The input batch from the dataloader

        Returns:
            The same batch object
        """
        return batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the curiosity-specific loss.

        The loss is ``ce_weight * CE + kl_weight * KL`` where CE is the
        student's next-token cross-entropy against ``batch["labels"]`` and KL is
        the temperature-scaled teacher->student divergence. Both terms are
        averaged over the same set of valid (non-ignored, non-padding) tokens so
        their weights are comparable.

        Args:
            student_outputs: Outputs from the student model (object with
                ``.logits`` or the logits tensor itself)
            teacher_outputs: Outputs from the teacher model (same convention)
            batch: The processed input batch; must contain ``labels`` and may
                contain ``attention_mask``

        Returns:
            Curiosity-specific loss tensor
        """
        # Get logits from outputs
        student_logits = getattr(student_outputs, "logits", student_outputs)
        teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

        # Shift so position t is scored against token t + 1. The teacher is a
        # fixed target, so no gradient flows into it.
        student_logits = student_logits[:, :-1, :].contiguous()
        teacher_logits = teacher_logits[:, :-1, :].detach().contiguous()
        target_ids = batch["labels"][:, 1:].contiguous()

        # Valid positions: not already ignored, and not padding.
        valid = target_ids != _IGNORE_INDEX
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            valid = valid & attention_mask[:, 1:].to(valid.device).bool()
        target_ids = target_ids.masked_fill(~valid, _IGNORE_INDEX)

        if not bool(valid.any()):
            # Fully padded batch: CE would be NaN (mean over zero tokens).
            # Return a graph-connected zero instead.
            self.metrics["curiosity_loss"] = 0.0
            return student_logits.sum() * 0.0

        # For curiosity, we use a higher temperature to encourage more diverse outputs
        temperature = self.curiosity_temp
        # Per-token KL summed over the vocabulary, then averaged over valid
        # tokens. (reduction="batchmean" on a 3-D tensor divides by batch size
        # only, which sums over sequence length and includes padding.)
        token_kl = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=-1),
            F.log_softmax(teacher_logits / temperature, dim=-1),
            reduction="none",
            log_target=True,
        ).sum(dim=-1)
        kl_loss = token_kl[valid].mean() * (temperature**2)

        # Cross-entropy loss against labels (ignored positions are -100)
        ce_loss = self.curiosity_loss(
            student_logits.view(-1, student_logits.size(-1)), target_ids.view(-1)
        )

        # Combined loss with curiosity focus
        loss = self.ce_weight * ce_loss + self.kl_weight * kl_loss

        # Update metrics
        self.metrics["curiosity_loss"] = loss.item()

        return loss

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to curiosity training.

        Returns:
            A copy of the metric names and values (mutating it does not affect
            the module's internal state)
        """
        return dict(self.metrics)

    def generate_synthetic_curiosity_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic curiosity training data.

        Writes ``num_samples`` JSON Lines records, each drawn at random (via the
        module-level ``random`` generator) from a small fixed set of templates
        and tagged with a ``metadata`` object holding a UTC ISO-8601 timestamp
        and the device name.

        Args:
            output_path: Path to save the generated data; parent directories
                are created when the path has a directory component
            num_samples: Number of samples to generate
        """
        # This is a simplified implementation based on the AWS backup
        # In a full implementation, this would be much more sophisticated
        curiosity_templates = [
            {
                "scenario": "A company is developing a new voice assistant technology.",
                "curiosity_questions": [
                    "How might this technology affect people's privacy in their homes?",
                    "What unexpected ways might users interact with this technology?",
                    "How could this technology evolve over the next five years?",
                    "What ethical considerations might arise from widespread adoption?",
                    "How might this technology affect different demographic groups differently?",
                ],
                "exploration_directions": [
                    "Consider alternative interaction models beyond voice commands",
                    "Explore potential integration with other smart home systems",
                    "Investigate privacy-preserving design approaches",
                    "Consider accessibility implications for diverse user groups",
                    "Examine potential unintended consequences of ambient listening",
                ],
            },
            {
                "scenario": "Scientists have discovered a new species of deep-sea organism that can survive extreme pressure and temperature.",
                "curiosity_questions": [
                    "What adaptations allow this organism to survive such extreme conditions?",
                    "Could these adaptations be applied to human technology or medicine?",
                    "What might this discovery tell us about the possibility of life on other planets?",
                    "How might climate change affect deep-sea ecosystems and this organism?",
                    "What other undiscovered species might exist in similar environments?",
                ],
                "exploration_directions": [
                    "Examine evolutionary pathways for extreme environment adaptation",
                    "Consider biomimicry applications in engineering and materials science",
                    "Explore implications for astrobiology and extraterrestrial life",
                    "Investigate ecological relationships in extreme environments",
                    "Consider ethical dimensions of deep-sea exploration and bioprospecting",
                ],
            },
            {
                "scenario": "A small town is experiencing rapid population growth due to remote workers relocating from urban areas.",
                "curiosity_questions": [
                    "How might this demographic shift affect the town's culture and community?",
                    "What infrastructure challenges might arise from rapid population growth?",
                    "How could this trend impact local housing prices and affordability?",
                    "What economic opportunities and challenges might emerge?",
                    "How might long-term residents and newcomers develop different perspectives?",
                ],
                "exploration_directions": [
                    "Investigate similar historical population shifts and their outcomes",
                    "Consider varying perspectives from different stakeholder groups",
                    "Explore potential policy approaches to manage growth sustainably",
                    "Examine social integration mechanisms between established and new residents",
                    "Consider environmental impacts of changing land use patterns",
                ],
            },
        ]

        # The device does not change between samples; look it up once.
        device_name = (
            torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU"
        )

        output_data = []
        for _ in range(num_samples):
            template = random.choice(curiosity_templates)
            # Copy the lists so records never alias the shared template lists.
            record = {
                "scenario": template["scenario"],
                "curiosity_questions": list(template["curiosity_questions"]),
                "exploration_directions": list(template["exploration_directions"]),
                "metadata": {
                    "generated": True,
                    "timestamp": datetime.now(timezone.utc).isoformat(),
                    "device": device_name,
                    "requires_exploration": True,
                },
            }
            output_data.append(record)

        # os.makedirs("") raises, so only create a directory when there is one.
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            "Generated %d synthetic curiosity examples at %s",
            len(output_data),
            output_path,
        )