|
|
""" |
|
|
LoRA Knowledge Distillation Trainer for MangoMAS Local

This module implements the main training loop for knowledge distillation
with LoRA fine-tuning optimized for Mac Mini hardware constraints.
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import yaml |
|
|
from peft import LoraConfig, TaskType, get_peft_model |
|
|
from torch.utils.data import DataLoader |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
from tqdm import tqdm |
|
|
from transformers import (AutoModelForCausalLM, AutoTokenizer, |
|
|
get_linear_schedule_with_warmup) |
|
|
|
|
|
|
|
|
try: |
|
|
from context7 import Context7 |
|
|
|
|
|
CONTEXT7_AVAILABLE = True |
|
|
except ImportError: |
|
|
CONTEXT7_AVAILABLE = False |
|
|
Context7 = None |
|
|
|
|
|
|
|
|
try: |
|
|
import mlflow |
|
|
|
|
|
MLFLOW_AVAILABLE = True |
|
|
except ImportError: |
|
|
MLFLOW_AVAILABLE = False |
|
|
mlflow = None |
|
|
|
|
|
|
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
|
|
|
try:
    from distillation_loss import AdaptiveDistillationLoss, DistillationLoss
except ImportError:
    try:
        from training.distillation_loss import (AdaptiveDistillationLoss,
                                                DistillationLoss)
    except ImportError:

        class DistillationLoss:
            """Minimal in-file fallback used when no distillation_loss module imports.

            Combines a next-token cross-entropy term with a temperature-scaled
            KL-divergence term between student and teacher distributions:
            ``(1 - alpha) * task_loss + alpha * distillation_loss``.
            """

            def __init__(self, alpha=0.5, temperature=2.0):
                """Store the mixing weight and the softmax temperature.

                Args:
                    alpha: Weight of the distillation term; the task term
                        gets ``1 - alpha``.
                    temperature: Divisor applied to both sets of logits
                        before the softmax.
                """
                self.alpha = alpha
                self.temperature = temperature
                self.task_loss = nn.CrossEntropyLoss()

            def compute_loss(
                self, student_logits, teacher_logits, labels, attention_mask=None
            ):
                """Compute the combined loss and a dict of scalar components.

                Args:
                    student_logits: Student logits; the last axis is the
                        vocabulary, the one before it the sequence position.
                    teacher_logits: Teacher logits of the same shape, or
                        ``None`` to skip the distillation term.
                    labels: Target token ids aligned with the input positions
                        (they are shifted by one inside this method).
                    attention_mask: Optional mask with non-zero entries at
                        real tokens. When given, positions where it is zero
                        are excluded from both loss terms. When ``None``,
                        every position contributes.

                Returns:
                    ``(total_loss, components)`` where ``components`` holds
                    the Python floats ``total_loss``, ``task_loss`` and
                    ``distillation_loss``.
                """
                # Logits at position i are scored against the label at i + 1.
                shift_logits = student_logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                has_targets = True
                if attention_mask is not None:
                    # A target is only valid when the token being predicted
                    # is a real (non-padding) token.
                    target_is_real = attention_mask[..., 1:].to(torch.bool)
                    shift_labels = shift_labels.masked_fill(
                        ~target_is_real, self.task_loss.ignore_index
                    )
                    has_targets = bool(target_is_real.any())

                if has_targets:
                    task_loss = self.task_loss(
                        shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1),
                    )
                else:
                    # The mean over zero valid targets would be NaN; use a
                    # zero that is still attached to the graph instead.
                    task_loss = shift_logits.sum() * 0.0

                if teacher_logits is not None:
                    # The teacher may live on another device or use another
                    # dtype than the student; align before comparing.
                    teacher_logits = teacher_logits.to(
                        device=student_logits.device, dtype=student_logits.dtype
                    )
                    student_log_probs = nn.functional.log_softmax(
                        student_logits / self.temperature, dim=-1
                    )
                    teacher_probs = nn.functional.softmax(
                        teacher_logits / self.temperature, dim=-1
                    )
                    if attention_mask is None:
                        distill_loss = nn.functional.kl_div(
                            student_log_probs, teacher_probs, reduction="batchmean"
                        )
                    else:
                        # Same normalisation as "batchmean" (sum divided by
                        # the size of the first axis), but padded positions
                        # contribute nothing.
                        token_kl = nn.functional.kl_div(
                            student_log_probs, teacher_probs, reduction="none"
                        ).sum(dim=-1)
                        keep = attention_mask.to(token_kl.dtype)
                        distill_loss = (token_kl * keep).sum() / student_logits.size(0)
                    distill_loss = distill_loss * self.temperature**2
                else:
                    # Zero scalar on the same device/dtype as the task loss.
                    distill_loss = task_loss.new_zeros(())

                total_loss = (1 - self.alpha) * task_loss + self.alpha * distill_loss

                return total_loss, {
                    "total_loss": total_loss.item(),
                    "task_loss": task_loss.item(),
                    "distillation_loss": distill_loss.item(),
                }

        AdaptiveDistillationLoss = DistillationLoss

logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class ConversationDataset:
    """Map-style dataset that tokenizes conversation records from a JSONL file.

    Each record is rendered to role-tagged text (``<role>\\n...\\n</role>``)
    and tokenized to a fixed length of ``max_length``.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 512):
        """Read all records from ``data_path`` into memory.

        Args:
            data_path: Path to a JSONL file, one JSON record per line.
            tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
                padding="max_length", max_length=..., return_tensors="pt")``
                whose result is indexed by ``"input_ids"`` and
                ``"attention_mask"``.
            max_length: Sequence length passed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load conversation data from a JSONL file.

        Blank or whitespace-only lines are skipped so that a trailing empty
        line does not abort loading.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON; the
                message names the offending line number and file.
        """
        records = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line_number, raw_line in enumerate(f, start=1):
                payload = raw_line.strip()
                if not payload:
                    continue
                try:
                    records.append(json.loads(payload))
                except json.JSONDecodeError as exc:
                    raise json.JSONDecodeError(
                        f"{exc.msg} (line {line_number} of {data_path})",
                        exc.doc,
                        exc.pos,
                    ) from exc
        return records

    @staticmethod
    def _render_text(item) -> str:
        """Render one record as role-tagged conversation text.

        Supports ``messages`` lists, ``instruction``/``response`` pairs and
        ``prompt``/``completion`` pairs; anything else falls back to
        ``str(item)``.
        """
        if "messages" in item:
            return "".join(
                f"<{message['role']}>\n{message['content']}\n</{message['role']}>\n\n"
                for message in item["messages"]
            )
        if "instruction" in item and "response" in item:
            user_text, assistant_text = item["instruction"], item["response"]
        elif "prompt" in item and "completion" in item:
            user_text, assistant_text = item["prompt"], item["completion"]
        else:
            return str(item)
        return (
            f"<user>\n{user_text}\n</user>\n\n"
            f"<assistant>\n{assistant_text}\n</assistant>\n\n"
        )

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get the tokenized conversation item at ``idx``.

        Returns:
            A dict with ``input_ids``, ``attention_mask``, ``labels`` (a copy
            of ``input_ids``) and the record's ``agent_type`` string
            (``"unknown"`` when absent).
        """
        item = self.data[idx]

        encoding = self.tokenizer(
            self._render_text(item),
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # Drop only the leading batch axis; a bare squeeze() would also
        # collapse the sequence axis when max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone(),
            "agent_type": item.get("agent_type", "unknown"),
        }
|
|
|
|
|
|
|
|
class LoRADistillationTrainer:
    """Main trainer class for LoRA knowledge distillation.

    Driven by a YAML configuration; ``train_agent`` runs the full
    load / train / evaluate / checkpoint cycle for one agent type.
    """

    def __init__(self, config_path: str):
        """Initialize trainer with configuration.

        Args:
            config_path: Path to the YAML training configuration file.
        """
        with open(config_path, "r", encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        # Optional override for the training data file. It is assigned from
        # outside (the CLI sets it from ``--data``) and read by load_datasets.
        self.custom_data_path = None

        self.setup_logging()
        self.setup_device()
        self.setup_monitoring()

        logger.info("Initialized LoRA Distillation Trainer")
        logger.info(f"Device: {self.device}")
        logger.info(f"Config: {config_path}")

    @staticmethod
    def _optimizer_steps_per_epoch(num_batches: int, accumulation_steps: int) -> int:
        """Return how many optimizer steps ``num_batches`` batches produce.

        A trailing partial accumulation window still triggers one step, so
        this is the ceiling of ``num_batches / accumulation_steps``.
        """
        return (num_batches + accumulation_steps - 1) // accumulation_steps

    @staticmethod
    def _safe_mean(total: float, count: int) -> float:
        """Return ``total / count``, or NaN when nothing was accumulated."""
        return total / count if count else float("nan")

    def setup_logging(self) -> None:
        """Set up logging to ``logs/training.log`` and to the console."""
        log_dir = Path("logs")
        log_dir.mkdir(exist_ok=True)

        logging.basicConfig(
            level=logging.INFO,
            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            handlers=[
                logging.FileHandler(log_dir / "training.log"),
                logging.StreamHandler(),
            ],
        )

    def setup_device(self) -> None:
        """Set up compute device (MPS for Mac Mini).

        Uses the configured ``hardware.device`` when it is ``"mps"`` or
        ``"cuda"`` and available; otherwise falls back to the CPU.
        """
        device_config = self.config["hardware"]["device"]

        if device_config == "mps" and torch.backends.mps.is_available():
            self.device = torch.device("mps")
            logger.info("Using Apple Metal Performance Shaders (MPS)")
        elif device_config == "cuda" and torch.cuda.is_available():
            self.device = torch.device("cuda")
            logger.info(f"Using CUDA: {torch.cuda.get_device_name()}")
        else:
            self.device = torch.device("cpu")
            logger.warning("Using CPU - training will be slow")

    def setup_monitoring(self) -> None:
        """Set up experiment tracking and monitoring.

        Creates a TensorBoard writer and/or selects the MLflow experiment
        according to the ``monitoring`` config section. MLflow is disabled
        when it cannot be imported or initialised.
        """
        self.use_tensorboard = self.config["monitoring"]["use_tensorboard"]
        self.use_mlflow = self.config["monitoring"]["use_mlflow"]

        if self.use_tensorboard:
            log_dir = self.config["monitoring"]["log_dir"]
            Path(log_dir).mkdir(parents=True, exist_ok=True)
            self.tb_writer = SummaryWriter(log_dir)
            logger.info(f"TensorBoard logging to: {log_dir}")

        if self.use_mlflow:
            try:
                import mlflow

                experiment_name = self.config["monitoring"]["experiment_name"]
                mlflow.set_experiment(experiment_name)
                logger.info(f"MLflow experiment: {experiment_name}")
            except (ImportError, AttributeError) as e:
                logger.warning(
                    f"MLflow not available or not properly initialized, disabling: {e}"
                )
                self.use_mlflow = False

    def load_models(self) -> None:
        """Load teacher and student models.

        Loads the tokenizer and student base model, wraps the student with
        LoRA adapters, moves it to the selected device and creates the
        teacher manager.
        """
        model_name = self.config["models"]["student"]["base_model"]
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Padding to max_length needs a pad token; reuse EOS when none exists.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Half precision is only selected on CUDA.
        dtype = (
            torch.float16
            if self.config["optimization"]["use_fp16"] and self.device.type == "cuda"
            else torch.float32
        )

        self.student_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=dtype,
            device_map="auto" if self.device.type == "cuda" else None,
            trust_remote_code=True,
        )

        target_modules = self.config["lora"]["target_modules"]

        if target_modules == ["q_proj", "v_proj", "k_proj", "o_proj"]:
            # Only swap in the GPT-2 style names when none of the configured
            # names exist in the loaded model; otherwise the configured
            # names are valid and must be kept.
            module_leaf_names = {
                name.rsplit(".", 1)[-1]
                for name, _ in self.student_model.named_modules()
            }
            if not module_leaf_names.intersection(target_modules):
                target_modules = ["c_attn", "c_proj", "c_fc"]
                logger.info("Adjusted LoRA target modules for DialoGPT architecture")

        lora_config = LoraConfig(
            r=self.config["lora"]["r"],
            lora_alpha=self.config["lora"]["lora_alpha"],
            target_modules=target_modules,
            lora_dropout=self.config["lora"]["lora_dropout"],
            bias=self.config["lora"]["bias"],
            task_type=TaskType.CAUSAL_LM,
        )

        self.student_model = get_peft_model(self.student_model, lora_config)
        self.student_model.to(self.device)

        self.teacher_manager = TeacherModelManager(
            self.config["models"]["teacher"], self.tokenizer
        )

        logger.info("Loaded student model with LoRA")
        logger.info(
            f"Trainable parameters: {self.student_model.num_parameters(only_trainable=True):,}"
        )
        logger.info("Loaded teacher model")

    def load_datasets(self, agent_type: str) -> tuple:
        """Load training and validation datasets for specific agent.

        The training file defaults to
        ``data/processed/<agent_type>_train.jsonl``; when
        ``custom_data_path`` is set on the trainer it is used instead.
        Validation always reads
        ``data/processed/<agent_type>_validation.jsonl``.

        Raises:
            FileNotFoundError: If either file does not exist.
        """
        data_dir = Path("data/processed")

        custom_train_path = getattr(self, "custom_data_path", None)
        if custom_train_path:
            train_path = Path(custom_train_path)
            logger.info(f"Using custom training data: {train_path}")
        else:
            train_path = data_dir / f"{agent_type}_train.jsonl"
        val_path = data_dir / f"{agent_type}_validation.jsonl"

        if not train_path.exists():
            raise FileNotFoundError(f"Training data not found: {train_path}")
        if not val_path.exists():
            raise FileNotFoundError(f"Validation data not found: {val_path}")

        max_length = self.config["data"]["max_sequence_length"]

        train_dataset = ConversationDataset(train_path, self.tokenizer, max_length)
        val_dataset = ConversationDataset(val_path, self.tokenizer, max_length)

        logger.info(
            f"Loaded datasets: {len(train_dataset)} train, {len(val_dataset)} val"
        )

        return train_dataset, val_dataset

    def create_data_loaders(self, train_dataset, val_dataset) -> tuple:
        """Create data loaders for training and validation.

        The training loader shuffles and drops the last incomplete batch;
        the validation loader keeps order and keeps every sample.
        """
        batch_size = self.config["training"]["batch_size"]
        num_workers = self.config["optimization"]["dataloader_num_workers"]
        pin_memory = self.config["optimization"]["pin_memory"]

        train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=True,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            pin_memory=pin_memory,
            drop_last=False,
        )

        return train_loader, val_loader

    def setup_training(self, train_dataset_size: int) -> None:
        """Set up optimizer, scheduler, and loss function.

        Args:
            train_dataset_size: Number of training samples, used to size the
                learning-rate schedule.
        """
        batch_size = self.config["training"]["batch_size"]
        gradient_accumulation_steps = self.config["training"][
            "gradient_accumulation_steps"
        ]
        num_epochs = self.config["training"]["num_epochs"]

        # The training loader drops the last incomplete batch, and
        # train_epoch steps once more for a trailing partial accumulation
        # window; count optimizer steps the same way here.
        batches_per_epoch = train_dataset_size // batch_size
        steps_per_epoch = self._optimizer_steps_per_epoch(
            batches_per_epoch, gradient_accumulation_steps
        )
        self.total_steps = steps_per_epoch * num_epochs

        self.optimizer = torch.optim.AdamW(
            self.student_model.parameters(),
            lr=self.config["training"]["learning_rate"],
            weight_decay=0.01,
        )

        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=self.config["training"]["warmup_steps"],
            num_training_steps=self.total_steps,
        )

        self.distill_loss = DistillationLoss(
            alpha=self.config["distillation"]["alpha"],
            temperature=self.config["distillation"]["temperature"],
        )

        logger.info(f"Setup training: {self.total_steps} total steps")

    def train_epoch(self, train_loader: DataLoader, epoch: int) -> Dict[str, float]:
        """Train for one epoch.

        Args:
            train_loader: Loader yielding dicts with ``input_ids``,
                ``attention_mask`` and ``labels`` tensors.
            epoch: Zero-based epoch index (used for display and log steps).

        Returns:
            Dict with ``avg_loss``, ``avg_task_loss`` and
            ``avg_distill_loss`` averaged over the batches seen; the values
            are NaN if the loader yielded no batches.
        """
        self.student_model.train()

        accumulation_steps = self.config["training"]["gradient_accumulation_steps"]
        logging_steps = self.config["training"]["logging_steps"]
        batches_in_epoch = len(train_loader)

        total_loss = 0.0
        total_task_loss = 0.0
        total_distill_loss = 0.0
        num_batches = 0

        # Never start an epoch with gradients left over from earlier work.
        self.optimizer.zero_grad()

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}", disable=False)

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch["attention_mask"].to(self.device)
            labels = batch["labels"].to(self.device)

            student_outputs = self.student_model(
                input_ids=input_ids, attention_mask=attention_mask
            )
            student_logits = student_outputs.logits

            with torch.no_grad():
                teacher_logits = self.teacher_manager.get_logits(
                    input_ids, attention_mask
                )

            loss, loss_dict = self.distill_loss.compute_loss(
                student_logits, teacher_logits, labels, attention_mask
            )

            # Scale so the accumulated gradient averages over the window.
            loss = loss / accumulation_steps
            loss.backward()

            window_complete = (batch_idx + 1) % accumulation_steps == 0
            is_last_batch = (batch_idx + 1) == batches_in_epoch
            # Also step on the final batch so a partial window is applied
            # instead of leaking its gradients into the next epoch.
            if window_complete or is_last_batch:
                torch.nn.utils.clip_grad_norm_(
                    self.student_model.parameters(),
                    self.config["training"]["max_grad_norm"],
                )
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

            total_loss += loss_dict["total_loss"]
            total_task_loss += loss_dict["task_loss"]
            total_distill_loss += loss_dict["distillation_loss"]
            num_batches += 1

            progress_bar.set_postfix(
                {
                    "loss": f"{loss_dict['total_loss']:.4f}",
                    "task": f"{loss_dict['task_loss']:.4f}",
                    "distill": f"{loss_dict['distillation_loss']:.4f}",
                }
            )

            if self.use_tensorboard and batch_idx % logging_steps == 0:
                step = epoch * batches_in_epoch + batch_idx
                self.tb_writer.add_scalar(
                    "train/total_loss", loss_dict["total_loss"], step
                )
                self.tb_writer.add_scalar(
                    "train/task_loss", loss_dict["task_loss"], step
                )
                self.tb_writer.add_scalar(
                    "train/distillation_loss", loss_dict["distillation_loss"], step
                )

        if num_batches == 0:
            logger.warning(
                "Training loader produced no batches; "
                "the dataset may be smaller than the batch size"
            )

        epoch_metrics = {
            "avg_loss": self._safe_mean(total_loss, num_batches),
            "avg_task_loss": self._safe_mean(total_task_loss, num_batches),
            "avg_distill_loss": self._safe_mean(total_distill_loss, num_batches),
        }

        return epoch_metrics

    def evaluate(self, val_loader: DataLoader) -> Dict[str, float]:
        """Evaluate model on validation set.

        Returns:
            Dict with ``val_loss``, ``val_task_loss`` and
            ``val_distill_loss`` averaged over the batches seen; the values
            are NaN if the loader yielded no batches.
        """
        self.student_model.eval()

        total_loss = 0.0
        total_task_loss = 0.0
        total_distill_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Evaluating"):
                input_ids = batch["input_ids"].to(self.device)
                attention_mask = batch["attention_mask"].to(self.device)
                labels = batch["labels"].to(self.device)

                student_outputs = self.student_model(
                    input_ids=input_ids, attention_mask=attention_mask
                )
                student_logits = student_outputs.logits

                teacher_logits = self.teacher_manager.get_logits(
                    input_ids, attention_mask
                )

                loss, loss_dict = self.distill_loss.compute_loss(
                    student_logits, teacher_logits, labels, attention_mask
                )

                total_loss += loss_dict["total_loss"]
                total_task_loss += loss_dict["task_loss"]
                total_distill_loss += loss_dict["distillation_loss"]
                num_batches += 1

        if num_batches == 0:
            logger.warning("Validation loader produced no batches")

        val_metrics = {
            "val_loss": self._safe_mean(total_loss, num_batches),
            "val_task_loss": self._safe_mean(total_task_loss, num_batches),
            "val_distill_loss": self._safe_mean(total_distill_loss, num_batches),
        }

        return val_metrics

    def save_model(self, output_dir: str, agent_type: str, epoch: int) -> None:
        """Save model checkpoint.

        Writes the student model, the tokenizer and a copy of the training
        configuration to ``<output_dir>/<agent_type>/epoch_<epoch>``.
        """
        output_path = Path(output_dir) / agent_type / f"epoch_{epoch}"
        output_path.mkdir(parents=True, exist_ok=True)

        self.student_model.save_pretrained(output_path)
        self.tokenizer.save_pretrained(output_path)

        config_path = output_path / "training_config.yaml"
        with open(config_path, "w", encoding="utf-8") as f:
            yaml.dump(self.config, f)

        logger.info(f"Saved model to: {output_path}")

    def train_agent(self, agent_type: str) -> None:
        """Train a specific agent with knowledge distillation.

        Loads models on first use, builds datasets and loaders, then trains
        for the configured number of epochs, saving a checkpoint whenever
        the validation loss improves.
        """
        logger.info(f"Starting training for {agent_type} agent")

        # Models are loaded once and reused across agents.
        if not hasattr(self, "student_model"):
            self.load_models()

        train_dataset, val_dataset = self.load_datasets(agent_type)
        train_loader, val_loader = self.create_data_loaders(train_dataset, val_dataset)

        self.setup_training(len(train_dataset))

        try:
            # The run is opened inside the try so that the finally clause
            # ends it even if parameter logging fails.
            if self.use_mlflow:
                mlflow.start_run(
                    run_name=f"{agent_type}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
                )
                mlflow.log_params(
                    {
                        "agent_type": agent_type,
                        "model_name": self.config["models"]["student"]["base_model"],
                        "lora_r": self.config["lora"]["r"],
                        "lora_alpha": self.config["lora"]["lora_alpha"],
                        "batch_size": self.config["training"]["batch_size"],
                        "learning_rate": self.config["training"]["learning_rate"],
                        "distillation_alpha": self.config["distillation"]["alpha"],
                        "temperature": self.config["distillation"]["temperature"],
                    }
                )

            best_val_loss = float("inf")
            num_epochs = self.config["training"]["num_epochs"]

            for epoch in range(num_epochs):
                logger.info(f"Epoch {epoch+1}/{num_epochs}")

                train_metrics = self.train_epoch(train_loader, epoch)
                logger.info(
                    f"Train - Loss: {train_metrics['avg_loss']:.4f}, "
                    f"Task: {train_metrics['avg_task_loss']:.4f}, "
                    f"Distill: {train_metrics['avg_distill_loss']:.4f}"
                )

                val_metrics = self.evaluate(val_loader)
                logger.info(
                    f"Val - Loss: {val_metrics['val_loss']:.4f}, "
                    f"Task: {val_metrics['val_task_loss']:.4f}, "
                    f"Distill: {val_metrics['val_distill_loss']:.4f}"
                )

                if self.use_mlflow:
                    mlflow.log_metrics({**train_metrics, **val_metrics}, step=epoch)

                if self.use_tensorboard:
                    for key, value in train_metrics.items():
                        self.tb_writer.add_scalar(f"epoch/{key}", value, epoch)
                    for key, value in val_metrics.items():
                        self.tb_writer.add_scalar(f"epoch/{key}", value, epoch)

                # A NaN validation loss never compares as an improvement.
                if val_metrics["val_loss"] < best_val_loss:
                    best_val_loss = val_metrics["val_loss"]
                    self.save_model(
                        self.config["output"]["base_dir"], agent_type, epoch
                    )
                    logger.info(f"New best model saved (val_loss: {best_val_loss:.4f})")

        finally:
            if self.use_mlflow:
                mlflow.end_run()
            if self.use_tensorboard:
                # Make sure buffered scalars reach disk even on failure.
                self.tb_writer.flush()

        logger.info(f"Training completed for {agent_type} agent")
|
|
|
|
|
|
|
|
class TeacherModelManager:
    """Manages teacher model interactions (API or local)."""

    def __init__(self, teacher_config: Dict, tokenizer):
        """Select and set up the teacher backend.

        Args:
            teacher_config: Teacher section of the configuration. ``type``
                equal to ``"api"`` selects the API backend; any other value
                selects a locally loaded model.
            tokenizer: Tokenizer whose ``vocab_size`` sizes the API
                placeholder logits.
        """
        self.config = teacher_config
        self.tokenizer = tokenizer

        if teacher_config["type"] == "api":
            self.setup_api_teacher()
        else:
            self.setup_local_teacher()

    def setup_api_teacher(self) -> None:
        """Set up API-based teacher model.

        Only records the model name; no API client is created here.
        """
        self.model_name = self.config["model_name"]
        logger.info(f"Using API teacher model: {self.model_name}")
        # get_logits does not call any API for this backend; make that
        # visible instead of silently distilling from noise.
        logger.warning(
            "API teacher logits are not implemented; "
            "get_logits returns random placeholder logits"
        )

    def setup_local_teacher(self) -> None:
        """Set up local teacher model.

        Loads ``local_model_path`` from the config, defaulting to
        ``microsoft/DialoGPT-large``.
        """
        model_path = self.config.get("local_model_path", "microsoft/DialoGPT-large")

        self.teacher_model = AutoModelForCausalLM.from_pretrained(
            model_path, torch_dtype=torch.float16, device_map="auto"
        )
        logger.info(f"Loaded local teacher model: {model_path}")

    def get_logits(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Get teacher model logits.

        Args:
            input_ids: Token ids; unpacked as ``(batch_size, seq_len)`` for
                the API backend.
            attention_mask: Attention mask forwarded to the local teacher.

        Returns:
            Logits on the same device as ``input_ids``. For the API backend
            these are random placeholder values of shape
            ``(batch_size, seq_len, tokenizer.vocab_size)``.
        """
        if self.config["type"] == "api":
            batch_size, seq_len = input_ids.shape
            vocab_size = self.tokenizer.vocab_size
            # Placeholder: random logits, allocated directly on the
            # target device.
            return torch.randn(
                batch_size, seq_len, vocab_size, device=input_ids.device
            )

        with torch.no_grad():
            outputs = self.teacher_model(
                input_ids=input_ids, attention_mask=attention_mask
            )
        # Match the API branch: always hand logits back on the caller's
        # device, wherever the teacher's weights were placed.
        return outputs.logits.to(input_ids.device)
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments and train the chosen agents.

    ``--agent all`` trains the infrastructure, devsecops and risk_assessment
    agents in that order; any other choice trains just that agent.
    """
    cli = argparse.ArgumentParser(
        description="Train MangoMAS agents with LoRA and knowledge distillation"
    )
    cli.add_argument(
        "--config",
        type=str,
        default="config/training/distillation.yaml",
        help="Path to training configuration file",
    )
    cli.add_argument(
        "--agent",
        type=str,
        choices=["infrastructure", "devsecops", "risk_assessment", "all"],
        default="all",
        help="Which agent to train",
    )
    cli.add_argument("--data", type=str, help="Path to training data file")
    options = cli.parse_args()

    trainer = LoRADistillationTrainer(options.config)

    # Record the optional data path on the trainer when one was supplied.
    if options.data:
        trainer.custom_data_path = options.data

    selected_agents = (
        ["infrastructure", "devsecops", "risk_assessment"]
        if options.agent == "all"
        else [options.agent]
    )

    for agent_name in selected_agents:
        trainer.train_agent(agent_name)


if __name__ == "__main__":
    main()
|
|
|