""" Model Wrapper ============= This module provides a wrapper for neural network models to integrate with the gradient descent training system, including support for LoRA adapters and the MangoMAS agent system. """ import logging import torch import torch.nn as nn from typing import Dict, List, Optional, Any from pathlib import Path import json logger = logging.getLogger(__name__) class ModelWrapper: """ Wrapper for neural network models to integrate with gradient descent training Provides a unified interface for different model types and handles LoRA adapter integration for the MangoMAS system. """ def __init__(self, model: nn.Module, model_type: str = 'transformer', lora_config: Optional[Dict[str, Any]] = None): self.model = model self.model_type = model_type self.lora_config = lora_config or {} self.lora_params = [] # Initialize LoRA if configured if lora_config: self._setup_lora() logger.info(f"Initialized ModelWrapper for {model_type} model") def _setup_lora(self): """Setup LoRA adapters for the model""" if not self.lora_config: return # Extract LoRA parameters for name, param in self.model.named_parameters(): if 'lora' in name.lower() or 'adapter' in name.lower(): self.lora_params.append(name) param.requires_grad = True else: param.requires_grad = False logger.info(f"Setup LoRA with {len(self.lora_params)} adapter parameters") def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor: """ Forward pass through the model Args: inputs: Input tensor **kwargs: Additional arguments Returns: Model output tensor """ return self.model(inputs, **kwargs) def get_trainable_parameters(self) -> List[torch.Tensor]: """ Get list of trainable parameters Returns: List of trainable parameter tensors """ if self.lora_params: # Return only LoRA parameters return [param for name, param in self.model.named_parameters() if name in self.lora_params and param.requires_grad] else: # Return all trainable parameters return [param for param in self.model.parameters() if param.requires_grad] def get_parameter_info(self) -> Dict[str, Any]: """ Get information about model parameters Returns: Dictionary of parameter information """ info = { 'total_parameters': sum(p.numel() for p in self.model.parameters()), 'trainable_parameters': sum(p.numel() for p in self.get_trainable_parameters()), 'lora_parameters': len(self.lora_params), 'parameter_details': {} } for name, param in self.model.named_parameters(): info['parameter_details'][name] = { 'shape': list(param.shape), 'numel': param.numel(), 'requires_grad': param.requires_grad, 'is_lora': name in self.lora_params } return info def save_model(self, save_path: str, metadata: Optional[Dict[str, Any]] = None): """ Save the model and metadata Args: save_path: Path to save the model metadata: Additional metadata to save """ save_path = Path(save_path) save_path.mkdir(parents=True, exist_ok=True) # Save model state model_path = save_path / 'model.pt' torch.save(self.model.state_dict(), model_path) # Save metadata if metadata is None: metadata = {} metadata.update({ 'model_type': self.model_type, 'lora_config': self.lora_config, 'lora_params': self.lora_params, 'parameter_info': self.get_parameter_info() }) metadata_path = save_path / 'metadata.json' with open(metadata_path, 'w') as f: json.dump(metadata, f, indent=2) logger.info(f"Model saved to {save_path}") def load_model(self, load_path: str): """ Load the model from saved files Args: load_path: Path to load the model from """ load_path = Path(load_path) # Load model state model_path = load_path / 'model.pt' if model_path.exists(): state_dict = torch.load(model_path, map_location='cpu') self.model.load_state_dict(state_dict) logger.info(f"Model loaded from {model_path}") # Load metadata metadata_path = load_path / 'metadata.json' if metadata_path.exists(): with open(metadata_path, 'r') as f: metadata = json.load(f) self.model_type = metadata.get('model_type', self.model_type) self.lora_config = metadata.get('lora_config', self.lora_config) self.lora_params = metadata.get('lora_params', self.lora_params) logger.info(f"Metadata loaded from {metadata_path}") def to(self, device: torch.device): """Move model to device""" self.model.to(device) return self def train(self): """Set model to training mode""" self.model.train() return self def eval(self): """Set model to evaluation mode""" self.model.eval() return self def __call__(self, *args, **kwargs): """Call the model""" return self.forward(*args, **kwargs) class LoRAModelWrapper(ModelWrapper): """ Specialized wrapper for LoRA (Low-Rank Adaptation) models Provides enhanced functionality for LoRA adapter management and integration with the MangoMAS system. """ def __init__(self, base_model: nn.Module, lora_config: Dict[str, Any]): super().__init__(base_model, 'lora_transformer', lora_config) self.base_model = base_model self.adapters = {} # Initialize LoRA adapters self._initialize_lora_adapters() def _initialize_lora_adapters(self): """Initialize LoRA adapters based on configuration""" rank = self.lora_config.get('rank', 16) alpha = self.lora_config.get('alpha', 32) dropout = self.lora_config.get('dropout', 0.1) target_modules = self.lora_config.get('target_modules', ['c_attn', 'c_proj']) # Add LoRA adapters to target modules for name, module in self.base_model.named_modules(): if any(target in name for target in target_modules): if isinstance(module, (nn.Linear, nn.Conv2d)): # Add LoRA adapter adapter = LoRAAdapter(module, rank, alpha, dropout) self.adapters[name] = adapter # Replace original module self._replace_module(name, adapter) logger.info(f"Initialized {len(self.adapters)} LoRA adapters") def _replace_module(self, module_name: str, new_module: nn.Module): """Replace a module in the model""" parts = module_name.split('.') parent = self.base_model for part in parts[:-1]: parent = getattr(parent, part) setattr(parent, parts[-1], new_module) def get_lora_parameters(self) -> List[torch.Tensor]: """Get LoRA adapter parameters""" lora_params = [] for adapter in self.adapters.values(): lora_params.extend(adapter.parameters()) return lora_params def merge_adapters(self): """Merge LoRA adapters into base model""" for adapter in self.adapters.values(): adapter.merge() logger.info("LoRA adapters merged into base model") def unmerge_adapters(self): """Unmerge LoRA adapters from base model""" for adapter in self.adapters.values(): adapter.unmerge() logger.info("LoRA adapters unmerged from base model") class LoRAAdapter(nn.Module): """ LoRA (Low-Rank Adaptation) adapter module Implements the LoRA technique for efficient fine-tuning of large models. """ def __init__(self, original_module: nn.Module, rank: int = 16, alpha: float = 32, dropout: float = 0.1): super().__init__() self.original_module = original_module self.rank = rank self.alpha = alpha self.dropout = dropout # Get original module dimensions if isinstance(original_module, nn.Linear): in_features = original_module.in_features out_features = original_module.out_features elif isinstance(original_module, nn.Conv2d): in_features = original_module.in_channels out_features = original_module.out_channels else: raise ValueError(f"Unsupported module type: {type(original_module)}") # Initialize LoRA matrices self.lora_A = nn.Linear(in_features, rank, bias=False) self.lora_B = nn.Linear(rank, out_features, bias=False) self.dropout_layer = nn.Dropout(dropout) # Initialize weights nn.init.kaiming_uniform_(self.lora_A.weight) nn.init.zeros_(self.lora_B.weight) # Store original weights self.original_weight = original_module.weight.data.clone() self.merged = False def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through LoRA adapter""" if self.merged: # Use merged weights return F.linear(x, self.original_weight, self.original_module.bias) else: # Use LoRA adaptation lora_output = self.lora_B(self.dropout_layer(self.lora_A(x))) original_output = F.linear(x, self.original_weight, self.original_module.bias) return original_output + (self.alpha / self.rank) * lora_output def merge(self): """Merge LoRA weights into original weights""" if not self.merged: lora_weight = (self.alpha / self.rank) * torch.mm( self.lora_B.weight, self.lora_A.weight ) self.original_weight += lora_weight self.merged = True def unmerge(self): """Unmerge LoRA weights from original weights""" if self.merged: lora_weight = (self.alpha / self.rank) * torch.mm( self.lora_B.weight, self.lora_A.weight ) self.original_weight -= lora_weight self.merged = False class ModelFactory: """Factory class for creating model wrappers""" @staticmethod def create_model_wrapper(model_type: str, model: nn.Module, **kwargs) -> ModelWrapper: """Create a model wrapper instance""" if model_type.lower() == 'lora': return LoRAModelWrapper(model, kwargs.get('lora_config', {})) else: return ModelWrapper(model, model_type, kwargs.get('lora_config')) @staticmethod def get_default_lora_config() -> Dict[str, Any]: """Get default LoRA configuration""" return { 'rank': 16, 'alpha': 32, 'dropout': 0.1, 'target_modules': ['c_attn', 'c_proj'] }