|
|
""" |
|
|
Core Training Framework for MangoMAS Local |
|
|
|
|
|
This module provides the foundation for specialized training modules, |
|
|
allowing for modular training of different cognitive capabilities. |
|
|
""" |
|
|
|
|
|
import importlib
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import torch
import yaml

from .lora_trainer import LoRADistillationTrainer
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass
class TrainingModuleConfig:
    """Configuration for a specialized training module.

    Instances are typically constructed from one entry of the
    ``specialized_modules`` list in the YAML training config
    (see ModularTrainingManager._initialize_modules, which does
    ``TrainingModuleConfig(**module_config)``).
    """

    # Human-readable module name, used for logging.
    name: str
    # Identifier used to resolve the implementation class
    # (e.g. "reasoning", "memory", "ethics", "empathy", "curiosity").
    module_type: str
    # Disabled modules are skipped entirely during initialization.
    enabled: bool = True
    # Relative weight of this module's loss in the combined objective.
    loss_weight: float = 1.0
    # Optional per-module overrides; None means "use the global setting".
    learning_rate: Optional[float] = None
    batch_size: Optional[int] = None
    data_path: Optional[str] = None
    # Free-form module-specific options, passed through untouched.
    module_config: Dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
|
|
class SpecializedTrainingModule(ABC):
    """
    Abstract base class for specialized training modules.

    Each cognitive capability (reasoning, memory, etc.) should implement this interface.
    """

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the specialized training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        self.config = config
        self.tokenizer = tokenizer

        # Mirror frequently-used config fields as attributes for convenience.
        self.name = config.name
        self.enabled = config.enabled
        self.loss_weight = config.loss_weight

        # Device preference order: CUDA, then Apple MPS, then CPU fallback.
        if torch.cuda.is_available():
            backend = "cuda"
        elif torch.backends.mps.is_available():
            backend = "mps"
        else:
            backend = "cpu"
        self.device = torch.device(backend)

        logger.info(f"Initialized {self.name} training module")
        logger.info(f"Module config: {self.config}")

    @abstractmethod
    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for this specific training module.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for the module
        """
        ...

    @abstractmethod
    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the specialized loss for this module.

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch

        Returns:
            Loss tensor for this module
        """
        ...

    @abstractmethod
    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to this training module.

        Returns:
            Dictionary of metric names and values
        """
        ...
|
|
|
|
|
|
|
|
class ModularTrainingManager:
    """
    Training manager that orchestrates multiple specialized training modules.

    Wraps a LoRADistillationTrainer (which owns the tokenizer and models) and
    instantiates one SpecializedTrainingModule per enabled entry of the
    ``specialized_modules`` list in the YAML config.
    """

    # Registry mapping a module_type identifier to the (relative module path,
    # class name) that implements it. Extend this table to add new module types
    # instead of editing _import_module_class.
    _MODULE_REGISTRY: Dict[str, Any] = {
        "reasoning": (".specialized.reasoning_module", "ReasoningTrainingModule"),
        "memory": (".specialized.memory_module", "MemoryTrainingModule"),
        "ethics": (".specialized.ethics_module", "EthicsTrainingModule"),
        "empathy": (".specialized.empathy_module", "EmpathyTrainingModule"),
        "curiosity": (".specialized.curiosity_module", "CuriosityTrainingModule"),
    }

    def __init__(self, config_path: str):
        """
        Initialize the modular training manager.

        Args:
            config_path: Path to the training configuration file (YAML)
        """
        with open(config_path, "r") as f:
            self.config = yaml.safe_load(f)

        # The base trainer owns the tokenizer and both models; modules share them.
        self.base_trainer = LoRADistillationTrainer(config_path)
        self.tokenizer = self.base_trainer.tokenizer
        self.student_model = self.base_trainer.student_model
        self.teacher_model = self.base_trainer.teacher_manager.model

        self.modules = self._initialize_modules()

        logger.info(
            "Initialized ModularTrainingManager with %d modules", len(self.modules)
        )

    def _initialize_modules(self) -> List["SpecializedTrainingModule"]:
        """
        Initialize all specialized training modules based on configuration.

        Modules that are disabled are skipped; modules that fail to load are
        logged (with traceback) and skipped so one broken module does not
        abort the rest.

        Returns:
            List of initialized training modules
        """
        modules: List["SpecializedTrainingModule"] = []
        for module_config in self.config.get("specialized_modules", []):
            if not module_config.get("enabled", True):
                logger.info("Skipping disabled module: %s", module_config.get("name"))
                continue

            try:
                config_obj = TrainingModuleConfig(**module_config)
                module_class = self._import_module_class(config_obj.module_type)
                modules.append(module_class(config_obj, self.tokenizer))
                logger.info("Successfully loaded module: %s", config_obj.name)
            except Exception:
                # Best-effort loading: log the full traceback (logger.exception)
                # rather than just str(e), then continue with the next module.
                logger.exception(
                    "Failed to load module %s", module_config.get("name")
                )

        return modules

    def _import_module_class(self, module_type: str) -> type:
        """
        Dynamically import a module class based on its type.

        Args:
            module_type: The module type identifier (a key of _MODULE_REGISTRY)

        Returns:
            The module class

        Raises:
            ValueError: If module_type is not a registered type.
        """
        try:
            module_path, class_name = self._MODULE_REGISTRY[module_type]
        except KeyError:
            raise ValueError(f"Unknown module type: {module_type}") from None

        # Relative import resolved against this package.
        module = importlib.import_module(module_path, package=__package__)
        return getattr(module, class_name)

    def train(self, agent_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Train the model using all enabled specialized modules.

        Args:
            agent_type: Optional agent type for specialized training;
                None trains all agents.

        Returns:
            Training metrics and results
        """
        logger.info("Starting modular training for agent: %s", agent_type or "all")

        # Delegates to the base trainer; specialized-module loss mixing is
        # handled by the trainer's pipeline.
        return self.base_trainer.train(agent_type)

    def evaluate(self, agent_type: Optional[str] = None) -> Dict[str, Any]:
        """
        Evaluate the model using all enabled specialized modules.

        Args:
            agent_type: Optional agent type for specialized evaluation;
                None evaluates all agents.

        Returns:
            Evaluation metrics and results
        """
        return self.base_trainer.evaluate(agent_type)
|
|
|