"""
Training module for KerdosAI.
"""

from typing import Dict, Any

import torch
from torch.utils.data import DataLoader
from transformers import (
    Trainer as HFTrainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
)
from datasets import Dataset
from tqdm import tqdm
import wandb


class Trainer:
    """
    Handles the training process for the LLM.
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        device: str,
        use_wandb: bool = True,
    ):
        """
        Initialize the trainer.

        Args:
            model: The model to train
            tokenizer: The tokenizer for the model
            device: Device to run training on
            use_wandb: Whether to use Weights & Biases for logging
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.use_wandb = use_wandb

        if use_wandb:
            wandb.init(project="kerdosai")

    def train(
        self,
        dataset: Dataset,
        epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 2e-5,
        gradient_accumulation_steps: int = 1,
        warmup_steps: int = 100,
        weight_decay: float = 0.01,
        logging_steps: int = 10,
        save_steps: int = 100,
        output_dir: str = "output",
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Train the model on the provided dataset.

        Args:
            dataset: Training dataset
            epochs: Number of training epochs
            batch_size: Training batch size
            learning_rate: Learning rate
            gradient_accumulation_steps: Number of steps for gradient accumulation
            warmup_steps: Number of warmup steps
            weight_decay: Weight decay for optimizer
            logging_steps: Number of steps between logging
            save_steps: Number of steps between model saves
            output_dir: Directory to save checkpoints
            **kwargs: Additional arguments forwarded to TrainingArguments

        Returns:
            Dictionary containing training metrics
        """
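        # Map the method arguments onto Hugging Face TrainingArguments;
        # fp16 is enabled only when running on CUDA, and metrics go to
        # wandb only when logging is enabled.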
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_steps=logging_steps,
            save_steps=save_steps,
            fp16=self.device == "cuda",
            report_to="wandb" if self.use_wandb else "none",
            **kwargs,
        )

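        # mlm=False configures the collator for causal language modeling:
        # it pads each batch and copies input_ids into labels (padding
        # positions are masked out of the loss).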
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False,
        )

        trainer = HFTrainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator,
        )

        train_result = trainer.train()
        trainer.save_model(output_dir)

        metrics = train_result.metrics
        if self.use_wandb:
            wandb.log(metrics)

        return metrics

    def evaluate(
        self,
        dataset: Dataset,
        batch_size: int = 4,
        **kwargs,
    ) -> Dict[str, float]:
        """
        Evaluate the model on the provided dataset.

        Args:
            dataset: Evaluation dataset
            batch_size: Evaluation batch size
            **kwargs: Additional evaluation arguments (currently unused)

        Returns:
            Dictionary containing evaluation metrics
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        # Use the causal-LM collator so variable-length examples are padded
        # into batched tensors and labels are created; a bare DataLoader
        # over a tokenized dataset would fail on unequal sequence lengths.
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=DataCollatorForLanguageModeling(
                tokenizer=self.tokenizer, mlm=False
            ),
        )

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                # Move the batch tensors to the target device.
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # The model returns a language-modeling loss because the
                # collated batch includes labels.
                outputs = self.model(**batch)
                loss = outputs.loss

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        metrics = {"eval_loss": avg_loss}

        if self.use_wandb:
            wandb.log(metrics)

        return metrics
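

# A minimal usage sketch for this module. The base model name ("gpt2"),
# the pad-token workaround, and the toy dataset below are illustrative
# assumptions, not part of the KerdosAI pipeline.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical base model
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    # Tiny in-memory dataset tokenized into the fields the collator expects.
    dataset = Dataset.from_dict(
        {"text": ["Hello world.", "KerdosAI is training."]}
    ).map(
        lambda ex: tokenizer(ex["text"], truncation=True, max_length=128),
        remove_columns=["text"],
    )

    trainer = Trainer(model, tokenizer, device, use_wandb=False)
    print(trainer.train(dataset, epochs=1, batch_size=2))
    print(trainer.evaluate(dataset, batch_size=2))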