"""
MiniMind Training Utilities

Standard training loop with mixed precision and gradient accumulation.
"""

import math
import time
from typing import Optional, Dict
from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast

import sys

# Make the project root importable before loading project modules.
sys.path.insert(0, str(Path(__file__).parent.parent))
from configs.model_config import Mind2Config


@dataclass
class TrainingConfig:
    """Training configuration."""

    # Optimizer and LR schedule
    learning_rate: float = 3e-4
    min_learning_rate: float = 3e-5
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0
    warmup_steps: int = 1000

    # Batching and duration
    num_epochs: int = 3
    batch_size: int = 8
    gradient_accumulation_steps: int = 4
    max_steps: Optional[int] = None  # overrides num_epochs when set

    # Mixed precision
    use_amp: bool = True
    amp_dtype: str = "float16"  # "float16" or "bfloat16"

    # Checkpointing and evaluation
    save_steps: int = 1000
    eval_steps: int = 500
    output_dir: str = "./outputs"
    resume_from: Optional[str] = None

    # Logging
    log_steps: int = 10
    wandb_project: Optional[str] = None
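
# Usage sketch: override only the fields you need, e.g.
# TrainingConfig(batch_size=4, gradient_accumulation_steps=8) keeps the default
# effective batch size of 32 while roughly halving per-step activation memory.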


class Mind2Trainer:
    """Trainer for MiniMind models."""

    def __init__(
        self,
        model: nn.Module,
        train_dataloader: DataLoader,
        eval_dataloader: Optional[DataLoader] = None,
        config: Optional[TrainingConfig] = None,
    ):
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.config = config or TrainingConfig()

        self.device = next(model.parameters()).device
        self.global_step = 0
        self.epoch = 0

        self.optimizer = self._create_optimizer()
        self.scheduler = self._create_scheduler()

        # bfloat16 shares float32's exponent range, so loss scaling is only
        # needed for float16.
        self.amp_dtype = torch.float16 if self.config.amp_dtype == "float16" else torch.bfloat16
        self.scaler = (
            GradScaler(enabled=self.amp_dtype == torch.float16)
            if self.config.use_amp
            else None
        )

        Path(self.config.output_dir).mkdir(parents=True, exist_ok=True)

        # Honor resume_from so the config field actually takes effect.
        if self.config.resume_from:
            self.load_checkpoint(self.config.resume_from)

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create AdamW; biases and normalization weights skip weight decay."""
        decay_params = []
        no_decay_params = []

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            # "norm" already covers "layernorm"-style parameter names
            if "bias" in name or "norm" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)
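
        # Illustration (hypothetical parameter names): "blocks.0.attn.wq.weight"
        # would join the decay group, while "blocks.0.norm.weight" and any
        # "*.bias" parameter land in the no-decay group.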

        optimizer_groups = [
            {"params": decay_params, "weight_decay": self.config.weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        return torch.optim.AdamW(
            optimizer_groups,
            lr=self.config.learning_rate,
            betas=(self.config.beta1, self.config.beta2),
        )

    def _create_scheduler(self):
        """Create cosine annealing scheduler with linear warmup."""
        total_steps = self._get_total_steps()

        def lr_lambda(step):
            if step < self.config.warmup_steps:
                return step / max(1, self.config.warmup_steps)
            progress = (step - self.config.warmup_steps) / max(1, total_steps - self.config.warmup_steps)
            return max(
                self.config.min_learning_rate / self.config.learning_rate,
                0.5 * (1.0 + math.cos(math.pi * progress))
            )
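
        # Worked example with the defaults: the multiplier is 1.0 at the end of
        # warmup (lr = 3e-4), 0.5 halfway through decay (lr = 1.5e-4), and is
        # floored at min_learning_rate / learning_rate = 0.1 at the end (lr = 3e-5).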

        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    def _get_total_steps(self) -> int:
        if self.config.max_steps:
            return self.config.max_steps
        steps_per_epoch = len(self.train_dataloader) // self.config.gradient_accumulation_steps
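        # e.g. 10,000 batches per epoch with accumulation 4 gives 2,500
        # optimizer steps per epoch, 7,500 over 3 epochs (illustrative numbers).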
        return steps_per_epoch * self.config.num_epochs

    def train(self) -> Dict[str, float]:
        """Main training loop."""
        self.model.train()
        total_steps = self._get_total_steps()

        print(f"Starting training for {total_steps} steps")
        print(f"  Batch size: {self.config.batch_size}")
        print(f"  Gradient accumulation: {self.config.gradient_accumulation_steps}")
        print(f"  Effective batch size: {self.config.batch_size * self.config.gradient_accumulation_steps}")

        running_loss = 0.0
        micro_batches = 0
        avg_loss = 0.0
        start_time = time.time()

        for epoch in range(self.config.num_epochs):
            self.epoch = epoch

            for step, batch in enumerate(self.train_dataloader):
                loss = self._training_step(batch)
                running_loss += loss
                micro_batches += 1

                if (step + 1) % self.config.gradient_accumulation_steps == 0:
                    self._optimizer_step()
                    self.global_step += 1

                    if self.global_step % self.config.log_steps == 0:
                        # running_loss holds one value per micro-batch, so
                        # average over micro_batches rather than log_steps.
                        avg_loss = running_loss / max(1, micro_batches)
                        elapsed = time.time() - start_time
                        tokens_per_sec = (
                            self.config.batch_size * self.config.gradient_accumulation_steps *
                            batch["input_ids"].shape[1] * self.config.log_steps / elapsed
                        )
                        print(
                            f"Step {self.global_step}/{total_steps} | "
                            f"Loss: {avg_loss:.4f} | "
                            f"LR: {self.scheduler.get_last_lr()[0]:.2e} | "
                            f"Tokens/s: {tokens_per_sec:.0f}"
                        )
                        running_loss = 0.0
                        micro_batches = 0
                        start_time = time.time()

                    if self.eval_dataloader and self.global_step % self.config.eval_steps == 0:
                        eval_loss = self.evaluate()
                        print(f"Eval Loss: {eval_loss:.4f}")
                        self.model.train()

                    if self.global_step % self.config.save_steps == 0:
                        self.save_checkpoint()

                if self.config.max_steps and self.global_step >= self.config.max_steps:
                    break

            if self.config.max_steps and self.global_step >= self.config.max_steps:
                break

        self.save_checkpoint(final=True)
        # Report the mean loss since the last log; fall back to the last
        # logged average if nothing has accumulated since then.
        final_loss = running_loss / micro_batches if micro_batches else avg_loss
        return {"final_loss": final_loss}

    def _training_step(self, batch: Dict[str, torch.Tensor]) -> float:
        """Single training step."""
        input_ids = batch["input_ids"].to(self.device)
        attention_mask = batch.get("attention_mask", None)
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.device)
        labels = batch["labels"].to(self.device)

        if self.config.use_amp:
            with autocast(dtype=self.amp_dtype):
                loss, _, _, _ = self.model(input_ids, attention_mask, labels)
            loss = loss / self.config.gradient_accumulation_steps
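            # Dividing by the accumulation count keeps the summed gradients
            # equivalent to one backward pass over the full effective batch.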
            self.scaler.scale(loss).backward()
        else:
            loss, _, _, _ = self.model(input_ids, attention_mask, labels)
            loss = loss / self.config.gradient_accumulation_steps
            loss.backward()

        # Report the unscaled per-batch loss for logging.
        return loss.item() * self.config.gradient_accumulation_steps

    def _optimizer_step(self):
        """Optimizer step with gradient clipping."""
        if self.config.use_amp:
            self.scaler.unscale_(self.optimizer)
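
        # Unscale first so clipping operates on true gradient magnitudes.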
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)

        if self.config.use_amp:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.scheduler.step()
        self.optimizer.zero_grad()

    @torch.no_grad()
    def evaluate(self) -> float:
        """Evaluate model on eval dataset."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        for batch in self.eval_dataloader:
            input_ids = batch["input_ids"].to(self.device)
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.device)
            labels = batch["labels"].to(self.device)

            loss, _, _, _ = self.model(input_ids, attention_mask, labels)
            total_loss += loss.item()
            num_batches += 1

        return total_loss / max(1, num_batches)

    def save_checkpoint(self, final: bool = False):
        """Save model checkpoint."""
        checkpoint_name = "final" if final else f"step_{self.global_step}"
        checkpoint_path = Path(self.config.output_dir) / checkpoint_name
        checkpoint_path.mkdir(parents=True, exist_ok=True)

        torch.save(self.model.state_dict(), checkpoint_path / "model.pt")
        torch.save(self.optimizer.state_dict(), checkpoint_path / "optimizer.pt")
        torch.save({
            "global_step": self.global_step,
            "epoch": self.epoch,
            "config": self.config,
        }, checkpoint_path / "trainer_state.pt")

        print(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, checkpoint_path: str):
        """Load model checkpoint."""
        path = Path(checkpoint_path)
        self.model.load_state_dict(torch.load(path / "model.pt", map_location=self.device))
        self.optimizer.load_state_dict(torch.load(path / "optimizer.pt", map_location=self.device))

        # trainer_state.pt stores a TrainingConfig dataclass, which newer
        # PyTorch versions refuse to unpickle under the weights_only=True default.
        state = torch.load(path / "trainer_state.pt", map_location=self.device, weights_only=False)
        self.global_step = state["global_step"]
        self.epoch = state["epoch"]

        print(f"Checkpoint loaded from {checkpoint_path}")