"""
Training module for KerdosAI.
"""
from typing import Dict, Any, Optional
import torch
from torch.utils.data import DataLoader
from transformers import (
    Trainer as HFTrainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from datasets import Dataset
from tqdm import tqdm
import wandb


class Trainer:
    """
    Handles the training process for the LLM.
    """

    def __init__(
        self,
        model: Any,
        tokenizer: Any,
        device: str,
        use_wandb: bool = True
    ):
        """
        Initialize the trainer.

        Args:
            model: The model to train
            tokenizer: The tokenizer for the model
            device: Device to run training on
            use_wandb: Whether to use Weights & Biases for logging
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(project="kerdosai")
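        # Note: wandb.init typically requires prior authentication (`wandb login`
        # or the WANDB_API_KEY environment variable); construct the trainer with
        # use_wandb=False to skip Weights & Biases logging entirely.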

    def train(
        self,
        dataset: Dataset,
        epochs: int = 3,
        batch_size: int = 4,
        learning_rate: float = 2e-5,
        gradient_accumulation_steps: int = 1,
        warmup_steps: int = 100,
        weight_decay: float = 0.01,
        logging_steps: int = 10,
        save_steps: int = 100,
        output_dir: str = "output",
        **kwargs
    ) -> Dict[str, Any]:
        """
        Train the model on the provided dataset.

        Args:
            dataset: Training dataset
            epochs: Number of training epochs
            batch_size: Training batch size
            learning_rate: Learning rate
            gradient_accumulation_steps: Number of steps for gradient accumulation
            warmup_steps: Number of warmup steps
            weight_decay: Weight decay for optimizer
            logging_steps: Number of steps between logging
            save_steps: Number of steps between model saves
            output_dir: Directory to save checkpoints
            **kwargs: Additional training arguments

        Returns:
            Dictionary containing training metrics
        """
        # Prepare training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=epochs,
            per_device_train_batch_size=batch_size,
            learning_rate=learning_rate,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_steps=logging_steps,
            save_steps=save_steps,
            fp16=self.device == "cuda",
            report_to="wandb" if self.use_wandb else "none",
            **kwargs
        )

        # Initialize data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )

        # Initialize HuggingFace trainer
        trainer = HFTrainer(
            model=self.model,
            args=training_args,
            train_dataset=dataset,
            data_collator=data_collator
        )

        # Train the model
        train_result = trainer.train()

        # Save the final model
        trainer.save_model(output_dir)

        # Log final metrics
        metrics = train_result.metrics
        if self.use_wandb:
            wandb.log(metrics)

        return metrics
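
    # Usage sketch for train() (illustrative only): assumes `dataset` is a tokenized
    # datasets.Dataset with an "input_ids" column, as expected by the collator above.
    # The "checkpoints" directory name is an arbitrary stand-in.
    #
    #   metrics = trainer.train(dataset, epochs=1, batch_size=2, output_dir="checkpoints")
    #   print(metrics.get("train_loss"))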

    def evaluate(
        self,
        dataset: Dataset,
        batch_size: int = 4,
        **kwargs
    ) -> Dict[str, float]:
        """
        Evaluate the model on the provided dataset.

        Args:
            dataset: Evaluation dataset
            batch_size: Evaluation batch size
            **kwargs: Additional evaluation arguments

        Returns:
            Dictionary containing evaluation metrics
        """
        self.model.eval()
        total_loss = 0
        num_batches = 0

        # Collate with the same causal-LM collator used in train() so batches are
        # padded tensors and include the labels required for outputs.loss.
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False
        )
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=data_collator
        )

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Evaluating"):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                outputs = self.model(**batch)
                loss = outputs.loss
                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches
        metrics = {"eval_loss": avg_loss}
        if self.use_wandb:
            wandb.log(metrics)

        return metrics
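

# The block below is a self-contained usage sketch, not part of the KerdosAI API:
# it wires an arbitrary causal LM ("gpt2") and a two-sentence toy corpus through
# train() and evaluate(). All concrete names here are illustrative assumptions.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

    # Tokenize a toy corpus into the "input_ids" format both methods expect.
    texts = ["KerdosAI fine-tuning example.", "Language models learn from text."]
    dataset = Dataset.from_dict(dict(tokenizer(texts, truncation=True, max_length=128)))

    trainer = Trainer(model, tokenizer, device=device, use_wandb=False)
    print(trainer.train(dataset, epochs=1, batch_size=2, output_dir="output"))
    print(trainer.evaluate(dataset, batch_size=2))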