""" |
|
|
Model Wrapper |
|
|
============= |
|
|
|
|
|
This module provides a wrapper for neural network models to integrate |
|
|
with the gradient descent training system, including support for LoRA |
|
|
adapters and the MangoMAS agent system. |
|
|
""" |
|
|
|
|
|

import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class ModelWrapper:
    """
    Wrapper for neural network models to integrate with gradient descent training.

    Provides a unified interface for different model types and handles
    LoRA adapter integration for the MangoMAS system.
    """

    def __init__(self, model: nn.Module, model_type: str = 'transformer',
                 lora_config: Optional[Dict[str, Any]] = None):
        self.model = model
        self.model_type = model_type
        self.lora_config = lora_config or {}
        self.lora_params: List[str] = []  # names of LoRA/adapter parameters

        if lora_config:
            self._setup_lora()

        logger.info(f"Initialized ModelWrapper for {model_type} model")

    def _setup_lora(self):
        """Set up LoRA training by freezing all non-adapter parameters."""
        if not self.lora_config:
            return

        # Train only parameters whose names mark them as LoRA/adapter weights;
        # freeze everything else.
        for name, param in self.model.named_parameters():
            if 'lora' in name.lower() or 'adapter' in name.lower():
                self.lora_params.append(name)
                param.requires_grad = True
            else:
                param.requires_grad = False

        logger.info(f"Set up LoRA with {len(self.lora_params)} adapter parameters")

    def forward(self, inputs: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Forward pass through the model.

        Args:
            inputs: Input tensor
            **kwargs: Additional arguments passed to the underlying model

        Returns:
            Model output tensor
        """
        return self.model(inputs, **kwargs)

    def get_trainable_parameters(self) -> List[torch.Tensor]:
        """
        Get the list of trainable parameters.

        Returns:
            List of trainable parameter tensors
        """
        if self.lora_params:
            # LoRA mode: only the registered adapter parameters are trainable.
            return [param for name, param in self.model.named_parameters()
                    if name in self.lora_params and param.requires_grad]
        else:
            # Full fine-tuning: every parameter that requires gradients.
            return [param for param in self.model.parameters() if param.requires_grad]

    def get_parameter_info(self) -> Dict[str, Any]:
        """
        Get information about model parameters.

        Returns:
            Dictionary of parameter information
        """
        info = {
            'total_parameters': sum(p.numel() for p in self.model.parameters()),
            'trainable_parameters': sum(p.numel() for p in self.get_trainable_parameters()),
            'lora_parameters': len(self.lora_params),
            'parameter_details': {}
        }

        for name, param in self.model.named_parameters():
            info['parameter_details'][name] = {
                'shape': list(param.shape),
                'numel': param.numel(),
                'requires_grad': param.requires_grad,
                'is_lora': name in self.lora_params
            }

        return info

    def save_model(self, save_path: str, metadata: Optional[Dict[str, Any]] = None):
        """
        Save the model weights and metadata.

        Args:
            save_path: Directory to save the model into
            metadata: Additional metadata to save
        """
        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        # Save the model weights.
        model_path = save_path / 'model.pt'
        torch.save(self.model.state_dict(), model_path)

        # Save metadata (wrapper configuration plus a parameter summary)
        # alongside the weights.
        if metadata is None:
            metadata = {}

        metadata.update({
            'model_type': self.model_type,
            'lora_config': self.lora_config,
            'lora_params': self.lora_params,
            'parameter_info': self.get_parameter_info()
        })

        metadata_path = save_path / 'metadata.json'
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        logger.info(f"Model saved to {save_path}")

    def load_model(self, load_path: str):
        """
        Load the model from saved files.

        Args:
            load_path: Directory to load the model from
        """
        load_path = Path(load_path)

        # Restore the model weights, if present.
        model_path = load_path / 'model.pt'
        if model_path.exists():
            state_dict = torch.load(model_path, map_location='cpu')
            self.model.load_state_dict(state_dict)
            logger.info(f"Model loaded from {model_path}")

        # Restore wrapper configuration from the saved metadata, if present.
        metadata_path = load_path / 'metadata.json'
        if metadata_path.exists():
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)

            self.model_type = metadata.get('model_type', self.model_type)
            self.lora_config = metadata.get('lora_config', self.lora_config)
            self.lora_params = metadata.get('lora_params', self.lora_params)

            logger.info(f"Metadata loaded from {metadata_path}")

    def to(self, device: torch.device):
        """Move the model to the given device."""
        self.model.to(device)
        return self

    def train(self):
        """Set the model to training mode."""
        self.model.train()
        return self

    def eval(self):
        """Set the model to evaluation mode."""
        self.model.eval()
        return self

    def __call__(self, *args, **kwargs):
        """Call the model."""
        return self.forward(*args, **kwargs)
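

# Example usage of ModelWrapper (illustrative sketch; the toy MLP below is
# hypothetical and not part of this module):
#
#     >>> mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
#     >>> wrapper = ModelWrapper(mlp, model_type='mlp')
#     >>> optimizer = torch.optim.SGD(wrapper.get_trainable_parameters(), lr=1e-3)
#     >>> output = wrapper(torch.randn(8, 16))
#     >>> output.shape
#     torch.Size([8, 4])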


class LoRAModelWrapper(ModelWrapper):
    """
    Specialized wrapper for LoRA (Low-Rank Adaptation) models.

    Provides enhanced functionality for LoRA adapter management
    and integration with the MangoMAS system.
    """

    def __init__(self, base_model: nn.Module, lora_config: Dict[str, Any]):
        super().__init__(base_model, 'lora_transformer', lora_config)
        self.base_model = base_model  # alias of self.model, kept for readability
        self.adapters: Dict[str, 'LoRAAdapter'] = {}

        self._initialize_lora_adapters()

    def _initialize_lora_adapters(self):
        """Initialize LoRA adapters based on the configuration."""
        rank = self.lora_config.get('rank', 16)
        alpha = self.lora_config.get('alpha', 32)
        dropout = self.lora_config.get('dropout', 0.1)
        target_modules = self.lora_config.get('target_modules', ['c_attn', 'c_proj'])

        # Snapshot the module list first: _replace_module mutates the module
        # tree, and mutating it while iterating named_modules() is unsafe.
        for name, module in list(self.base_model.named_modules()):
            if any(target in name for target in target_modules):
                if isinstance(module, nn.Linear):
                    # Wrap the target layer in a LoRA adapter and swap it in.
                    adapter = LoRAAdapter(module, rank, alpha, dropout)
                    self.adapters[name] = adapter
                    self._replace_module(name, adapter)

        logger.info(f"Initialized {len(self.adapters)} LoRA adapters")

    def _replace_module(self, module_name: str, new_module: nn.Module):
        """Replace a module in the model by its dotted path name."""
        parts = module_name.split('.')
        parent = self.base_model

        for part in parts[:-1]:
            parent = getattr(parent, part)

        setattr(parent, parts[-1], new_module)

    def get_lora_parameters(self) -> List[torch.Tensor]:
        """Get the LoRA adapter parameters (low-rank factors only)."""
        lora_params = []
        for adapter in self.adapters.values():
            # Collect only the low-rank factors; adapter.parameters() would
            # also yield the frozen original module's weights.
            lora_params.extend(adapter.lora_A.parameters())
            lora_params.extend(adapter.lora_B.parameters())
        return lora_params

    def merge_adapters(self):
        """Merge LoRA adapters into the base model."""
        for adapter in self.adapters.values():
            adapter.merge()
        logger.info("LoRA adapters merged into base model")

    def unmerge_adapters(self):
        """Unmerge LoRA adapters from the base model."""
        for adapter in self.adapters.values():
            adapter.unmerge()
        logger.info("LoRA adapters unmerged from base model")
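

# A typical LoRA fine-tuning flow with this wrapper (illustrative sketch;
# `base_model` is assumed to exist, and the configured target_modules must
# match its submodule names):
#
#     >>> config = ModelFactory.get_default_lora_config()
#     >>> wrapper = LoRAModelWrapper(base_model, config)
#     >>> optimizer = torch.optim.AdamW(wrapper.get_lora_parameters(), lr=1e-4)
#     >>> # ... training loop ...
#     >>> wrapper.eval()
#     >>> wrapper.merge_adapters()  # fold adapters into base weights for inference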


class LoRAAdapter(nn.Module):
    """
    LoRA (Low-Rank Adaptation) adapter module.

    Implements the LoRA technique for efficient fine-tuning of large models:
    the frozen weight W is augmented with a trainable low-rank update
    (alpha / rank) * B A.
    """

    def __init__(self, original_module: nn.Module, rank: int = 16,
                 alpha: float = 32, dropout: float = 0.1):
        super().__init__()
        self.original_module = original_module
        self.rank = rank
        self.alpha = alpha
        self.dropout = dropout

        # Only nn.Linear is supported: the forward pass and merge math below
        # operate on a 2-D weight matrix via F.linear.
        if isinstance(original_module, nn.Linear):
            in_features = original_module.in_features
            out_features = original_module.out_features
        else:
            raise ValueError(f"Unsupported module type: {type(original_module)}")

        # Low-rank factors: A projects down to `rank`, B projects back up.
        self.lora_A = nn.Linear(in_features, rank, bias=False)
        self.lora_B = nn.Linear(rank, out_features, bias=False)
        self.dropout_layer = nn.Dropout(dropout)

        # Standard LoRA init: A random, B zero, so the adapter starts as a no-op.
        nn.init.kaiming_uniform_(self.lora_A.weight)
        nn.init.zeros_(self.lora_B.weight)

        # Keep a working copy of the original weight as a buffer so it follows
        # .to() and state_dict() like the rest of the module.
        self.register_buffer('original_weight', original_module.weight.data.clone())
        self.merged = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the LoRA adapter."""
        if self.merged:
            # The LoRA update is already folded into the weight copy.
            return F.linear(x, self.original_weight, self.original_module.bias)
        else:
            # y = W x + (alpha / rank) * B(dropout(A(x)))
            lora_output = self.lora_B(self.dropout_layer(self.lora_A(x)))
            original_output = F.linear(x, self.original_weight, self.original_module.bias)
            return original_output + (self.alpha / self.rank) * lora_output

    def merge(self):
        """Merge LoRA weights into the original weights: W <- W + (alpha/rank) * B A."""
        if not self.merged:
            lora_weight = (self.alpha / self.rank) * torch.mm(
                self.lora_B.weight, self.lora_A.weight
            )
            self.original_weight += lora_weight
            self.merged = True

    def unmerge(self):
        """Undo merge() by subtracting the LoRA update from the weights."""
        if self.merged:
            lora_weight = (self.alpha / self.rank) * torch.mm(
                self.lora_B.weight, self.lora_A.weight
            )
            self.original_weight -= lora_weight
            self.merged = False
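

# Note: merge()/unmerge() reproduce the unmerged forward pass only when
# dropout is inactive (eval mode). A quick round-trip check (illustrative
# sketch; the 8x16 shapes are arbitrary):
#
#     >>> adapter = LoRAAdapter(nn.Linear(16, 32)).eval()
#     >>> x = torch.randn(8, 16)
#     >>> y_unmerged = adapter(x)
#     >>> adapter.merge()
#     >>> torch.allclose(adapter(x), y_unmerged, atol=1e-6)
#     True
#     >>> adapter.unmerge()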


class ModelFactory:
    """Factory class for creating model wrappers."""

    @staticmethod
    def create_model_wrapper(model_type: str, model: nn.Module,
                             **kwargs) -> ModelWrapper:
        """Create a model wrapper instance."""
        if model_type.lower() == 'lora':
            return LoRAModelWrapper(model, kwargs.get('lora_config', {}))
        else:
            return ModelWrapper(model, model_type, kwargs.get('lora_config'))

    @staticmethod
    def get_default_lora_config() -> Dict[str, Any]:
        """Get the default LoRA configuration."""
        return {
            'rank': 16,
            'alpha': 32,
            'dropout': 0.1,
            'target_modules': ['c_attn', 'c_proj']  # GPT-2-style layer names
        }
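

if __name__ == "__main__":
    # Smoke-test sketch, assuming a hypothetical two-layer toy network whose
    # submodule names ('fc1', 'fc2') are chosen to match target_modules below.
    class TinyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(16, 32)
            self.fc2 = nn.Linear(32, 4)

        def forward(self, x):
            return self.fc2(torch.relu(self.fc1(x)))

    config = ModelFactory.get_default_lora_config()
    config['target_modules'] = ['fc1', 'fc2']
    wrapper = ModelFactory.create_model_wrapper('lora', TinyNet(), lora_config=config)

    out = wrapper(torch.randn(8, 16))
    print(f"output shape: {tuple(out.shape)}")                    # (8, 4)
    print(f"lora tensors: {len(wrapper.get_lora_parameters())}")  # 4: one A and one B per adapter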