|
|
""" |
|
|
Refactored LoRA Knowledge Distillation Trainer using modular architecture. |
|
|
|
|
|
This module implements a clean, testable trainer that follows the interface contracts |
|
|
and provides better separation of concerns. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from ..core.base_components import BaseTrainer |
|
|
from ..core.exceptions import TrainingError |
|
|
from ..core.interfaces import TrainingConfig |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class ModularLoRATrainer(BaseTrainer):
    """Modular LoRA knowledge-distillation trainer.

    Runs a supervised / distillation training loop over a student model,
    optionally querying a frozen teacher model for soft targets. All loss
    math is delegated to an injected loss-function object (expected to
    expose ``compute(...)`` and ``get_metrics()``), keeping the trainer
    itself free of distillation-specific logic.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        loss_function,
        device: str = "cpu",
        teacher_model: Optional[nn.Module] = None,
    ):
        """
        Initialize the modular LoRA trainer.

        Args:
            model: Student model to train
            optimizer: Optimizer for training
            loss_function: Loss function implementing ILossFunction
                (must provide ``compute`` and ``get_metrics``)
            device: Training device (e.g. ``"cpu"`` or ``"cuda"``)
            teacher_model: Optional teacher model for distillation; it is
                moved to ``device`` and switched to eval mode
        """
        super().__init__(model, optimizer, device)
        self.loss_function = loss_function
        self.teacher_model = teacher_model
        # Explicit `is not None`: container modules such as nn.Sequential
        # define __len__, so an empty one would be falsy and silently skip
        # teacher setup under a plain truthiness check.
        if self.teacher_model is not None:
            self.teacher_model.to(self.device)
            # Teacher is inference-only: eval() disables dropout/batch-norm
            # updates; gradients are suppressed per batch via torch.no_grad().
            self.teacher_model.eval()

        # Optional override used in place of self.loss_function when set.
        self.custom_loss_fn = None
        # Epoch currently being trained; previously referenced via
        # getattr() in error reporting but never assigned, so the reported
        # epoch was always 0.
        self.current_epoch = 0

    def set_custom_loss_fn(self, loss_fn):
        """Set a custom loss callable for specialized training.

        Args:
            loss_fn: Callable ``(student_outputs, teacher_outputs, batch)``
                returning a scalar loss tensor.
        """
        self.custom_loss_fn = loss_fn

    def compute_distillation_loss(self, student_outputs, teacher_outputs, batch):
        """Compute the standard distillation loss via the injected loss function.

        Args:
            student_outputs: Student forward output; must expose ``.logits``.
            teacher_outputs: Teacher forward output (``.logits`` used when
                present) or a raw logits tensor.
            batch: Input batch dict; ``labels`` is forwarded if present.

        Returns:
            The loss value produced by ``self.loss_function.compute``.
        """
        teacher_logits = (
            teacher_outputs.logits
            if hasattr(teacher_outputs, "logits")
            else teacher_outputs
        )
        return self.loss_function.compute(
            student_outputs.logits,
            teacher_logits,
            labels=batch.get("labels"),
        )

    def train(self, dataloader: DataLoader, config: TrainingConfig) -> Dict[str, Any]:
        """
        Train the model with the given configuration.

        Args:
            dataloader: Training data loader yielding dict-shaped batches
            config: Training configuration (uses ``num_epochs``,
                ``save_steps`` and ``output_dir``)

        Returns:
            Training results and metrics: average loss, epoch/batch counts,
            per-epoch losses and the loss function's own metrics.

        Raises:
            TrainingError: If any step of the training loop fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            self.model.train()
            total_loss = 0.0
            num_batches = 0
            training_metrics: Dict[str, float] = {}

            for epoch in range(config.num_epochs):
                # Keep error-report context accurate (see except clause).
                self.current_epoch = epoch
                epoch_loss = 0.0
                epoch_batches = 0

                for batch_idx, batch in enumerate(dataloader):
                    batch = self._move_batch_to_device(batch)

                    self.optimizer.zero_grad()

                    student_outputs = self.model(**batch)

                    teacher_outputs = None
                    if self.teacher_model is not None:
                        # No gradients through the frozen teacher.
                        with torch.no_grad():
                            teacher_outputs = self.teacher_model(**batch)

                    if self.custom_loss_fn:
                        loss = self.custom_loss_fn(
                            student_outputs, teacher_outputs, batch
                        )
                    else:
                        student_logits = (
                            student_outputs.logits
                            if hasattr(student_outputs, "logits")
                            else student_outputs
                        )
                        # LM-style fallback: when the batch carries no
                        # explicit labels, use input_ids as targets.
                        loss = self.loss_function.compute(
                            student_logits,
                            batch.get("labels", batch.get("input_ids")),
                        )

                    loss.backward()
                    self.optimizer.step()

                    epoch_loss += loss.item()
                    epoch_batches += 1

                    # NOTE(review): save_steps doubles as the logging
                    # interval here — confirm TrainingConfig has no
                    # dedicated logging_steps field.
                    if batch_idx % config.save_steps == 0:
                        step_metrics = self.loss_function.get_metrics()
                        # _log_training_step is presumably inherited from
                        # BaseTrainer — verify it exists on the base class.
                        self._log_training_step(
                            epoch, batch_idx, loss.item(), step_metrics
                        )

                        # Lazy %-args: no string formatting when the record
                        # is filtered out by the log level.
                        logger.info(
                            "Epoch %d, Batch %d, Loss: %.4f",
                            epoch,
                            batch_idx,
                            loss.item(),
                        )

                # Guard against an empty dataloader (zero batches).
                avg_epoch_loss = (
                    epoch_loss / epoch_batches if epoch_batches > 0 else 0.0
                )
                total_loss += epoch_loss
                num_batches += epoch_batches

                training_metrics[f"epoch_{epoch}_loss"] = avg_epoch_loss

                if epoch % config.save_steps == 0:
                    checkpoint_path = (
                        Path(config.output_dir) / f"checkpoint_epoch_{epoch}.pt"
                    )
                    self.save_checkpoint(checkpoint_path, epoch)

            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            results = {
                "average_loss": avg_loss,
                "total_epochs": config.num_epochs,
                "total_batches": num_batches,
                "training_metrics": training_metrics,
                "loss_function_metrics": self.loss_function.get_metrics(),
            }

            logger.info("Training completed. Average loss: %.4f", avg_loss)
            return results

        except TrainingError:
            # Already domain-specific; don't re-wrap and lose the code/context.
            raise
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise TrainingError(
                f"Training failed: {str(e)}",
                "TRAINING_FAILED",
                {"epoch": self.current_epoch},
            ) from e

    def evaluate(self, dataloader: DataLoader) -> Dict[str, float]:
        """
        Evaluate the model on the given dataset.

        Args:
            dataloader: Evaluation data loader yielding dict-shaped batches

        Returns:
            Evaluation metrics: average loss, batch count, plus the loss
            function's own metrics.

        Raises:
            TrainingError: If evaluation fails; the original exception is
                chained as ``__cause__``.
        """
        try:
            self.model.eval()
            total_loss = 0.0
            num_batches = 0

            with torch.no_grad():
                for batch in dataloader:
                    batch = self._move_batch_to_device(batch)

                    outputs = self.model(**batch)

                    # Same label fallback as in train(): input_ids serve as
                    # targets when no explicit labels are provided.
                    loss = self.loss_function.compute(
                        outputs.logits if hasattr(outputs, "logits") else outputs,
                        batch.get("labels", batch.get("input_ids")),
                    )

                    total_loss += loss.item()
                    num_batches += 1

            avg_loss = total_loss / num_batches if num_batches > 0 else 0.0

            results = {"eval_loss": avg_loss, "eval_batches": num_batches}
            results.update(self.loss_function.get_metrics())

            logger.info("Evaluation completed. Average loss: %.4f", avg_loss)
            return results

        except TrainingError:
            raise
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise TrainingError(f"Evaluation failed: {str(e)}", "EVALUATION_FAILED") from e

    def _move_batch_to_device(
        self, batch: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Move tensor entries of *batch* to the training device.

        Non-tensor values (e.g. lists of strings, metadata) pass through
        unchanged.
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }
|
|
|