""" Core Training Framework for MangoMAS Local This module provides the foundation for specialized training modules, allowing for modular training of different cognitive capabilities. """ import logging from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Optional import torch import yaml from .lora_trainer import LoRADistillationTrainer logger = logging.getLogger(__name__) @dataclass class TrainingModuleConfig: """Configuration for a specialized training module.""" name: str module_type: str enabled: bool = True loss_weight: float = 1.0 learning_rate: Optional[float] = None batch_size: Optional[int] = None data_path: Optional[str] = None module_config: Dict[str, Any] = field(default_factory=dict) class SpecializedTrainingModule(ABC): """ Abstract base class for specialized training modules. Each cognitive capability (reasoning, memory, etc.) should implement this interface. """ def __init__(self, config: TrainingModuleConfig, tokenizer): """ Initialize the specialized training module. Args: config: Module configuration tokenizer: Tokenizer for text processing """ self.config = config self.tokenizer = tokenizer self.name = config.name self.enabled = config.enabled self.loss_weight = config.loss_weight self.device = torch.device( "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" ) logger.info(f"Initialized {self.name} training module") logger.info(f"Module config: {self.config}") @abstractmethod def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Prepare a batch of data for this specific training module. Args: batch: The input batch from the dataloader Returns: Processed batch ready for the module """ pass @abstractmethod def compute_loss( self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor] ) -> torch.Tensor: """ Compute the specialized loss for this module. Args: student_outputs: Outputs from the student model teacher_outputs: Outputs from the teacher model batch: The processed input batch Returns: Loss tensor for this module """ pass @abstractmethod def get_metrics(self) -> Dict[str, float]: """ Get metrics specific to this training module. Returns: Dictionary of metric names and values """ pass class ModularTrainingManager: """ Training manager that orchestrates multiple specialized training modules. """ def __init__(self, config_path: str): """ Initialize the modular training manager. Args: config_path: Path to the training configuration file """ with open(config_path, "r") as f: self.config = yaml.safe_load(f) # Set up core components self.base_trainer = LoRADistillationTrainer(config_path) self.tokenizer = self.base_trainer.tokenizer self.student_model = self.base_trainer.student_model self.teacher_model = self.base_trainer.teacher_manager.model # Initialize modules self.modules = self._initialize_modules() logger.info( f"Initialized ModularTrainingManager with {len(self.modules)} modules" ) def _initialize_modules(self) -> List[SpecializedTrainingModule]: """ Initialize all specialized training modules based on configuration. Returns: List of initialized training modules """ modules = [] module_configs = self.config.get("specialized_modules", []) for module_config in module_configs: if not module_config.get("enabled", True): logger.info(f"Skipping disabled module: {module_config.get('name')}") continue try: # Convert to proper config object config_obj = TrainingModuleConfig(**module_config) # Import the module dynamically module_type = config_obj.module_type module_class = self._import_module_class(module_type) # Initialize the module module = module_class(config_obj, self.tokenizer) modules.append(module) logger.info(f"Successfully loaded module: {config_obj.name}") except Exception as e: logger.error( f"Failed to load module {module_config.get('name')}: {str(e)}" ) return modules def _import_module_class(self, module_type: str) -> type: """ Dynamically import a module class based on its type. Args: module_type: The module type identifier Returns: The module class """ if module_type == "reasoning": from .specialized.reasoning_module import ReasoningTrainingModule return ReasoningTrainingModule elif module_type == "memory": from .specialized.memory_module import MemoryTrainingModule return MemoryTrainingModule elif module_type == "ethics": from .specialized.ethics_module import EthicsTrainingModule return EthicsTrainingModule elif module_type == "empathy": from .specialized.empathy_module import EmpathyTrainingModule return EmpathyTrainingModule elif module_type == "curiosity": from .specialized.curiosity_module import CuriosityTrainingModule return CuriosityTrainingModule else: raise ValueError(f"Unknown module type: {module_type}") def train(self, agent_type: str = None) -> Dict[str, Any]: """ Train the model using all enabled specialized modules. Args: agent_type: Optional agent type for specialized training Returns: Training metrics and results """ # Delegate to base trainer for core training functionality # but integrate specialized module losses logger.info(f"Starting modular training for agent: {agent_type or 'all'}") # TODO: Implement the full training loop integrating all modules # This is a placeholder until we implement the full integration return self.base_trainer.train(agent_type) def evaluate(self, agent_type: str = None) -> Dict[str, Any]: """ Evaluate the model using all enabled specialized modules. Args: agent_type: Optional agent type for specialized evaluation Returns: Evaluation metrics and results """ # TODO: Implement evaluation using specialized modules # This is a placeholder until we implement the full integration return self.base_trainer.evaluate(agent_type)