|
|
|
|
|
""" |
|
|
CPU-Optimized Multi-Agent Trainer |
|
|
|
|
|
This module provides comprehensive multi-agent training capabilities optimized for CPU execution, |
|
|
including LoRA fine-tuning, agent-specific conditioning, and integration with existing training infrastructure. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import math |
|
|
import random |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple, Any, Union |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
import torch |
|
|
import yaml |
|
|
from datasets import DatasetDict, Dataset |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
TrainingArguments, |
|
|
Trainer, |
|
|
DataCollatorForLanguageModeling |
|
|
) |
|
|
from trl import SFTTrainer |
|
|
from peft import LoraConfig, get_peft_model, TaskType |
|
|
from huggingface_hub import HfApi, create_repo |
|
|
|
|
|
from ..multi_agent_datasets.multi_agent_loader import MultiAgentDatasetLoader, MultiAgentDatasetConfig |
|
|
from ..multi_agent_tokenization.agent_tokenizer import AgentTokenManager, AgentTokenConfig, AgentTokenizer |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@dataclass
class MultiAgentTrainingConfig:
    """Configuration for multi-agent training.

    A flat bag of settings consumed by CPUOptimizedMultiAgentTrainer:
    base-model selection, CPU-friendly training hyperparameters, LoRA
    adapter settings, agent-token formatting, logging/checkpointing, and
    optional Hugging Face Hub upload credentials.
    """

    # --- Base model -----------------------------------------------------
    base_model: str = "microsoft/Phi-3.5-MoE-instruct"
    # Optional Hugging Face cache directory for model/tokenizer downloads.
    model_cache_dir: Optional[str] = None
    # Required for models (e.g. Phi MoE) that ship custom modeling code.
    trust_remote_code: bool = True

    # --- Training schedule ----------------------------------------------
    output_dir: str = "./outputs"
    # max_steps takes precedence over num_train_epochs when both are set.
    max_steps: int = 50
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 1
    per_device_eval_batch_size: int = 1
    # Effective batch size = train_batch_size * gradient_accumulation_steps.
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-5
    lr_scheduler_type: str = "cosine"
    warmup_steps: int = 0

    # --- LoRA adapter ---------------------------------------------------
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # "all-linear" is the PEFT shorthand that targets every linear layer.
    lora_target_modules: str = "all-linear"
    lora_bias: str = "none"

    # --- CPU execution --------------------------------------------------
    use_cpu: bool = True
    # Mixed precision disabled: CPU training here runs in float32.
    bf16: bool = False
    fp16: bool = False
    gradient_checkpointing: bool = True
    # 0 workers = load data in the main process (avoids fork overhead on CPU).
    dataloader_num_workers: int = 0
    remove_unused_columns: bool = False

    # --- Agent conditioning tokens --------------------------------------
    # Tokens are rendered as f"{agent_prefix}{agent_name}{agent_suffix}".
    agent_prefix: str = "<|agent:"
    agent_suffix: str = "|>"
    # Optional per-agent sample balancing, capped at balance_cap samples.
    balance_agents: bool = False
    balance_cap: Optional[int] = None

    # --- Logging / checkpointing ----------------------------------------
    logging_steps: int = 5
    save_steps: int = 50
    # eval_steps > 0 enables step-based evaluation; <= 0 disables it.
    eval_steps: int = 25
    save_total_limit: int = 1
    logging_dir: str = "./logs"
    # "none" disables third-party trackers (wandb, tensorboard, ...).
    report_to: str = "none"

    # --- Hugging Face Hub -----------------------------------------------
    hub_repo_id: Optional[str] = None
    push_to_hub: bool = False
    hub_token: Optional[str] = None

    # --- Dataset --------------------------------------------------------
    # When None, a default MultiAgentDatasetConfig is built from the
    # agent_prefix/suffix and balancing fields above.
    dataset_config: Optional[MultiAgentDatasetConfig] = None
|
|
|
|
|
class CPUOptimizedMultiAgentTrainer:
    """
    CPU-optimized multi-agent trainer with LoRA fine-tuning.

    Orchestrates the full workflow: load the base model/tokenizer in
    float32, register agent-conditioning tokens, load the multi-agent
    dataset, build the LoRA config / TrainingArguments / SFTTrainer,
    train, save artifacts, and optionally push them to the Hub.
    """

    def __init__(self, config: MultiAgentTrainingConfig):
        """Store the configuration and set up logging.

        Args:
            config: Training configuration; consulted lazily by the other
                methods, so it must remain valid for the trainer's lifetime.
        """
        self.config = config
        # Heavyweight objects are created lazily by the dedicated setup
        # methods; __init__ itself performs no model or dataset I/O.
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[torch.nn.Module] = None
        self.agent_manager: Optional[AgentTokenManager] = None
        self.dataset_loader: Optional[MultiAgentDatasetLoader] = None
        self.trainer: Optional[SFTTrainer] = None
        self.agents: List[str] = []
        self.dataset_stats: Dict[str, Any] = {}

        self._setup_logging()

    def _setup_logging(self):
        """Configure root logging with a console handler and a file handler."""
        # BUG FIX: the log directory must exist BEFORE logging.FileHandler
        # opens its file, otherwise the handler raises FileNotFoundError on
        # a fresh run. Previously makedirs ran after basicConfig.
        os.makedirs(self.config.logging_dir, exist_ok=True)

        log_level = logging.INFO
        logging.basicConfig(
            level=log_level,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(os.path.join(self.config.logging_dir, 'training.log'))
            ]
        )

    def load_model_and_tokenizer(self) -> Tuple[AutoTokenizer, torch.nn.Module]:
        """Load model and tokenizer optimized for CPU.

        Returns:
            (tokenizer, model); both are also stored on ``self``.
        """
        logger.info(f"Loading model and tokenizer: {self.config.base_model}")

        tokenizer_kwargs = {
            "trust_remote_code": self.config.trust_remote_code,
            "cache_dir": self.config.model_cache_dir
        }

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.base_model,
            **tokenizer_kwargs
        )

        # Cap sequence length for CPU memory, and guarantee a pad token
        # (falls back to unk, then eos). Right padding keeps labels aligned
        # for causal-LM training.
        self.tokenizer.model_max_length = 2048
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.unk_token or self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token)
        self.tokenizer.padding_side = "right"

        # float32 + eager attention + use_cache=False is the safe
        # combination for pure-CPU training with gradient checkpointing.
        model_kwargs = {
            "trust_remote_code": self.config.trust_remote_code,
            "torch_dtype": torch.float32,
            "device_map": "cpu",
            "attn_implementation": "eager",
            "use_cache": False,
            "cache_dir": self.config.model_cache_dir
        }

        self.model = AutoModelForCausalLM.from_pretrained(
            self.config.base_model,
            **model_kwargs
        )

        logger.info(f"Model loaded with {self.model.num_parameters():,} parameters")
        return self.tokenizer, self.model

    def setup_agent_tokens(self, agents: List[str]) -> AgentTokenManager:
        """Register agent-conditioning tokens and resize model embeddings.

        Must be called after load_model_and_tokenizer().

        Args:
            agents: Agent names to mint tokens for
                (rendered as agent_prefix + name + agent_suffix).

        Returns:
            The configured AgentTokenManager (also stored on ``self``).
        """
        logger.info(f"Setting up agent tokens for {len(agents)} agents")

        agent_config = AgentTokenConfig(
            agent_prefix=self.config.agent_prefix,
            agent_suffix=self.config.agent_suffix,
            resize_embeddings=True
        )

        self.agent_manager = AgentTokenManager(agent_config)

        # Adding tokens grows the tokenizer vocab ...
        self.tokenizer, agent_tokens = self.agent_manager.add_agent_tokens_to_tokenizer(
            self.tokenizer, agents
        )

        # ... so the model's embedding matrix must be resized to match.
        self.model = self.agent_manager.resize_model_embeddings(self.model, self.tokenizer)

        logger.info(f"Agent tokens setup complete. Tokens: {agent_tokens}")
        return self.agent_manager

    def load_dataset(self, dataset_path: str) -> Tuple[DatasetDict, List[str], Dict[str, Any]]:
        """Load and process the multi-agent dataset.

        Args:
            dataset_path: Path handed to MultiAgentDatasetLoader.

        Returns:
            (dataset, agents, stats); agents and stats are also cached on
            ``self`` for the training report and README.
        """
        logger.info(f"Loading dataset from: {dataset_path}")

        # Build a default dataset config from the trainer settings unless
        # the caller provided one explicitly (in which case only the path
        # is overridden).
        if self.config.dataset_config is None:
            dataset_config = MultiAgentDatasetConfig(
                dataset_path=dataset_path,
                agent_prefix=self.config.agent_prefix,
                agent_suffix=self.config.agent_suffix,
                balance_agents=self.config.balance_agents,
                balance_cap=self.config.balance_cap
            )
        else:
            dataset_config = self.config.dataset_config
            dataset_config.dataset_path = dataset_path

        self.dataset_loader = MultiAgentDatasetLoader(dataset_config)

        dataset, agents, stats = self.dataset_loader.load_and_process(self.tokenizer)

        self.agents = agents
        self.dataset_stats = stats

        logger.info(f"Dataset loaded: {len(agents)} agents, {stats['total_samples']} samples")
        return dataset, agents, stats

    def create_lora_config(self) -> LoraConfig:
        """Create a LoRA configuration from the trainer settings.

        Returns:
            A causal-LM LoraConfig built from the lora_* config fields.
        """
        logger.info("Creating LoRA configuration")

        lora_config = LoraConfig(
            r=self.config.lora_r,
            lora_alpha=self.config.lora_alpha,
            lora_dropout=self.config.lora_dropout,
            bias=self.config.lora_bias,
            task_type=TaskType.CAUSAL_LM,
            target_modules=self.config.lora_target_modules
        )

        logger.info(f"LoRA config: r={lora_config.r}, alpha={lora_config.lora_alpha}, dropout={lora_config.lora_dropout}")
        return lora_config

    def create_training_arguments(self) -> TrainingArguments:
        """Create TrainingArguments tuned for CPU execution.

        Returns:
            TrainingArguments mirroring the config fields, with a fixed
            adamw_torch optimizer, 0.01 weight decay, and grad-norm 1.0.
        """
        logger.info("Creating training arguments")

        # NOTE(review): config.use_cpu is not forwarded here; CPU execution
        # currently relies on device_map="cpu" at model load time. Confirm
        # whether TrainingArguments(use_cpu=...) should be passed for the
        # pinned transformers version.
        training_args = TrainingArguments(
            output_dir=self.config.output_dir,
            overwrite_output_dir=True,
            num_train_epochs=self.config.num_train_epochs,
            max_steps=self.config.max_steps,
            per_device_train_batch_size=self.config.per_device_train_batch_size,
            per_device_eval_batch_size=self.config.per_device_eval_batch_size,
            gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            learning_rate=self.config.learning_rate,
            lr_scheduler_type=self.config.lr_scheduler_type,
            warmup_steps=self.config.warmup_steps,

            # Precision / memory knobs (all float32 on CPU by default).
            bf16=self.config.bf16,
            fp16=self.config.fp16,
            gradient_checkpointing=self.config.gradient_checkpointing,
            dataloader_num_workers=self.config.dataloader_num_workers,
            remove_unused_columns=self.config.remove_unused_columns,

            # Logging / checkpointing.
            logging_steps=self.config.logging_steps,
            save_steps=self.config.save_steps,
            eval_steps=self.config.eval_steps,
            save_total_limit=self.config.save_total_limit,
            logging_dir=self.config.logging_dir,
            report_to=self.config.report_to,

            # NOTE(review): newer transformers renamed this kwarg to
            # eval_strategy — confirm against the pinned version.
            evaluation_strategy="steps" if self.config.eval_steps > 0 else "no",

            optim="adamw_torch",
            weight_decay=0.01,
            max_grad_norm=1.0,

            push_to_hub=self.config.push_to_hub,
            hub_model_id=self.config.hub_repo_id,
            hub_token=self.config.hub_token,
        )

        logger.info(f"Training arguments created: {training_args.output_dir}")
        return training_args

    def create_trainer(self, dataset: DatasetDict, lora_config: LoraConfig, training_args: TrainingArguments) -> SFTTrainer:
        """Create the SFT trainer for multi-agent training.

        Args:
            dataset: DatasetDict with a "train" split and an optional
                "test" split used for evaluation.
            lora_config: PEFT LoRA configuration to wrap the model with.
            training_args: Output of create_training_arguments().

        Returns:
            The constructed SFTTrainer (also stored on ``self``).
        """
        logger.info("Creating SFT trainer")

        train_dataset = dataset["train"]
        # Evaluation is optional; missing "test" split yields None.
        eval_dataset = dataset.get("test", None)

        self.trainer = SFTTrainer(
            model=self.model,
            args=training_args,
            peft_config=lora_config,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=self.tokenizer,
            max_seq_length=2048,
            dataset_text_field="text",
            # Packing concatenates samples to fill each sequence — better
            # CPU throughput at small batch sizes.
            packing=True,
            data_collator=None,
        )

        logger.info("SFT trainer created successfully")
        return self.trainer

    def train(self) -> Dict[str, Any]:
        """Execute training, save artifacts, and build the report.

        Returns:
            The training report dict from generate_training_report().

        Raises:
            ValueError: If create_trainer() has not been called yet.
        """
        logger.info("Starting training process")

        if self.trainer is None:
            raise ValueError("Trainer not initialized. Call create_trainer() first.")

        training_result = self.trainer.train()

        # Persist the adapter + tokenizer immediately after training.
        self.save_model()

        # Persist the agent-token mapping next to the adapter so inference
        # code can reconstruct the conditioning tokens.
        if self.agent_manager:
            self.agent_manager.save_agent_tokens(self.config.output_dir)

        report = self.generate_training_report(training_result)

        logger.info("Training completed successfully")
        return report

    def save_model(self):
        """Save the trained adapter, tokenizer, and training config.

        Raises:
            ValueError: If create_trainer() has not been called yet
                (previously this surfaced as an opaque AttributeError).
        """
        if self.trainer is None:
            raise ValueError("Trainer not initialized. Call create_trainer() first.")

        logger.info(f"Saving model to {self.config.output_dir}")

        os.makedirs(self.config.output_dir, exist_ok=True)

        # Saves only the LoRA adapter weights, not the full base model.
        self.trainer.model.save_pretrained(self.config.output_dir)

        self.tokenizer.save_pretrained(self.config.output_dir)

        # default=str stringifies non-JSON values (e.g. the nested
        # dataset_config dataclass) so the dump never raises.
        config_file = os.path.join(self.config.output_dir, "training_config.json")
        with open(config_file, 'w') as f:
            json.dump(self.config.__dict__, f, indent=2, default=str)

        logger.info("Model saved successfully")

    def generate_training_report(self, training_result: Any) -> Dict[str, Any]:
        """Generate and persist a comprehensive training report.

        Args:
            training_result: The TrainOutput returned by trainer.train();
                metrics are read defensively with getattr.

        Returns:
            The report dict (also written to training_report.json).
        """
        report = {
            "training_config": self.config.__dict__,
            "dataset_stats": self.dataset_stats,
            "agents": self.agents,
            "agent_tokens": self.agent_manager.get_agent_statistics() if self.agent_manager else {},
            "training_metrics": {
                "train_loss": getattr(training_result, 'train_loss', None),
                "train_runtime": getattr(training_result, 'train_runtime', None),
                "train_samples_per_second": getattr(training_result, 'train_samples_per_second', None),
                "train_steps_per_second": getattr(training_result, 'train_steps_per_second', None),
            },
            "model_info": {
                "base_model": self.config.base_model,
                "num_parameters": self.model.num_parameters() if self.model else None,
                "vocab_size": len(self.tokenizer) if self.tokenizer else None,
            }
        }

        report_file = os.path.join(self.config.output_dir, "training_report.json")
        with open(report_file, 'w') as f:
            json.dump(report, f, indent=2, default=str)

        logger.info(f"Training report saved to {report_file}")
        return report

    def push_to_hub(self, repo_id: Optional[str] = None, commit_message: str = "Multi-agent LoRA adapter"):
        """Push the trained model directory to the Hugging Face Hub.

        Args:
            repo_id: Target repository; falls back to config.hub_repo_id.
            commit_message: Commit message for the upload.

        Raises:
            ValueError: If no repo id or no hub token is available.
        """
        if not self.config.push_to_hub:
            logger.info("Push to hub disabled")
            return

        repo_id = repo_id or self.config.hub_repo_id
        if not repo_id:
            raise ValueError("Repository ID not specified")

        if not self.config.hub_token:
            raise ValueError("Hub token not provided")

        logger.info(f"Pushing model to Hub: {repo_id}")

        create_repo(repo_id, repo_type="model", exist_ok=True, token=self.config.hub_token)

        api = HfApi(token=self.config.hub_token)
        api.upload_folder(
            folder_path=self.config.output_dir,
            repo_id=repo_id,
            repo_type="model",
            commit_message=commit_message,
            # BUG FIX: include *.safetensors — modern PEFT saves the
            # adapter as adapter_model.safetensors by default, so the
            # weights were silently excluded from the upload. *.model
            # covers the SentencePiece tokenizer file (tokenizer.model).
            allow_patterns=["*.json", "*.md", "*.bin", "*.safetensors", "*.model", "*.yaml", "*.txt"]
        )

        logger.info(f"Model pushed to https://huggingface.co/{repo_id}")

    def create_readme(self) -> str:
        """Create a model-card README for the trained adapter.

        Returns:
            Path to the written README.md inside output_dir.
        """
        readme_content = f"""# Multi-Agent LoRA Adapter for {self.config.base_model}

## Overview
This is a LoRA (Low-Rank Adaptation) adapter trained for multi-agent scenarios using {self.config.base_model}.

## Agent Conditioning Tokens
This adapter expects agent-specific tokens to condition the model behavior:

"""

        if self.agents:
            for agent in self.agents:
                token = f"{self.config.agent_prefix}{agent}{self.config.agent_suffix}"
                readme_content += f"- `{token}` - {agent} agent\n"

            readme_content += f"""
## Usage Example

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Load base model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("{self.config.base_model}")
model = AutoModelForCausalLM.from_pretrained("{self.config.base_model}")

# Load LoRA adapter
model = PeftModel.from_pretrained(model, "{self.config.hub_repo_id}")

# Example usage
prompt = "How do I implement a binary search algorithm?"
agent_token = "{self.config.agent_prefix}SWE{self.config.agent_suffix}\\n"
full_prompt = agent_token + prompt

inputs = tokenizer(full_prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=100)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```

## Training Configuration
- **Base Model**: {self.config.base_model}
- **LoRA Rank**: {self.config.lora_r}
- **LoRA Alpha**: {self.config.lora_alpha}
- **Learning Rate**: {self.config.learning_rate}
- **Max Steps**: {self.config.max_steps}
- **Batch Size**: {self.config.per_device_train_batch_size}

## Dataset Statistics
- **Total Samples**: {self.dataset_stats.get('total_samples', 'N/A')}
- **Agents**: {', '.join(self.agents) if self.agents else 'N/A'}

## License
This model is released under the same license as the base model.
"""
        else:
            readme_content += "No specific agents were configured for this adapter.\n"

        readme_file = os.path.join(self.config.output_dir, "README.md")
        with open(readme_file, 'w') as f:
            f.write(readme_content)

        logger.info(f"README created: {readme_file}")
        return readme_file
|
|
|
|
|
class MultiAgentTrainingPipeline:
    """
    Complete pipeline for multi-agent training.

    Thin orchestration wrapper that drives CPUOptimizedMultiAgentTrainer
    through its required call order in a single method.
    """

    def __init__(self, config: MultiAgentTrainingConfig):
        """Build the underlying trainer from the given configuration."""
        self.config = config
        self.trainer = CPUOptimizedMultiAgentTrainer(config)

    def run_training(self, dataset_path: str) -> Dict[str, Any]:
        """Run the complete training pipeline end to end.

        Args:
            dataset_path: Path to the multi-agent dataset.

        Returns:
            The training report dict produced by the trainer.

        Raises:
            Exception: Any failure is logged with traceback and re-raised.
        """
        logger.info("Starting multi-agent training pipeline")

        try:
            # 1. Base model + tokenizer (must exist before tokenization).
            self.trainer.load_model_and_tokenizer()

            # 2. Dataset; also discovers the agent list.
            dataset, agents, stats = self.trainer.load_dataset(dataset_path)

            # 3. Agent conditioning tokens + embedding resize.
            self.trainer.setup_agent_tokens(agents)

            # 4. LoRA config, training arguments, and the SFT trainer.
            lora_config = self.trainer.create_lora_config()
            training_args = self.trainer.create_training_arguments()
            self.trainer.create_trainer(dataset, lora_config, training_args)

            # 5. Train (also saves the model and writes the report).
            training_result = self.trainer.train()

            # 6. Model card.
            self.trainer.create_readme()

            # 7. Optional Hub upload.
            if self.config.push_to_hub:
                self.trainer.push_to_hub()

            logger.info("Training pipeline completed successfully")
            return training_result

        except Exception as e:
            # logger.exception records the full traceback (logger.error
            # did not), then the failure propagates to the caller.
            logger.exception(f"Training pipeline failed: {e}")
            raise
|
|
|
|
|
|
|
|
def _demo() -> None:
    """Smoke test: build a small configuration and confirm the pipeline
    object can be constructed without touching any model weights."""
    logging.basicConfig(level=logging.INFO)

    config = MultiAgentTrainingConfig(
        base_model="microsoft/Phi-3.5-MoE-instruct",
        output_dir="./outputs/multi_agent_test",
        max_steps=10,
        hub_repo_id="test/multi-agent-adapter",
        push_to_hub=False,
    )

    pipeline = MultiAgentTrainingPipeline(config)

    try:
        print("Multi-agent training pipeline ready")
    except Exception as e:
        print(f"Error: {e}")


if __name__ == "__main__":
    _demo()
|
|
|