"""
Full training entry point.
Run: python scripts/train.py --config configs/training_config.yaml
"""
import click
import yaml
import torch
import os
import gc
from transformers import TrainingArguments
from loguru import logger
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
from src.model.base_model import load_model_and_tokenizer
from src.model.style_conditioner import StyleConditioner
from src.training.dataset import WritingCorrectionDataset
from src.training.loss_functions import CombinedCorrectionLoss, CombinedCorrectionLossV2
from src.training.trainer import CorrectionTrainer
from src.training.callbacks import StyleMetricsCallback, EarlyStoppingOnStyleDrift
from src.style.fingerprinter import StyleFingerprinter
from src.evaluation.gleu_scorer import GLEUScorer
# ── Hybrid GPU Management ───────────────────────────────────────────────────
def _setup_device():
"""Detect GPU and configure hybrid VRAM management.
Returns (device, gpu_info) where gpu_info is a dict with:
- available: bool
- name: str
- vram_total_mb: int
- vram_free_mb: int
- compute_cap: tuple
"""
gpu_info = {"available": False, "name": "CPU", "vram_total_mb": 0,
"vram_free_mb": 0, "compute_cap": (0, 0)}
if not torch.cuda.is_available():
logger.info("No GPU detected — training on CPU")
return "cpu", gpu_info
gpu_info["available"] = True
gpu_info["name"] = torch.cuda.get_device_name(0)
gpu_info["compute_cap"] = torch.cuda.get_device_capability(0)
    # Query actual free VRAM; mem_get_info() reports device-wide free memory,
    # so VRAM held by other processes (desktop, browser) is accounted for too
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    vram_total = total_bytes // (1024 * 1024)
    vram_free = free_bytes // (1024 * 1024)
    vram_allocated = torch.cuda.memory_allocated(0) // (1024 * 1024)
gpu_info["vram_total_mb"] = vram_total
gpu_info["vram_free_mb"] = vram_free
logger.info(
f"GPU: {gpu_info['name']} | "
f"VRAM: {vram_allocated}MB used / {vram_total}MB total ({vram_free}MB free) | "
f"Compute: {gpu_info['compute_cap']}"
)
# Leave headroom for the system — reserve at most 85% of free VRAM
# This prevents the desktop/compositor from starving
usable_vram_mb = int(vram_free * 0.85)
if usable_vram_mb > 0:
# Set PyTorch memory limit to avoid hogging all VRAM
fraction = min(usable_vram_mb / vram_total, 0.90)
torch.cuda.set_per_process_memory_fraction(fraction, 0)
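        # The fraction caps PyTorch's caching allocator: allocations beyond the
        # cap raise OOM immediately rather than growing further, which is what
        # keeps the desktop/compositor responsive while training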
logger.info(
f"Hybrid GPU mode: capped PyTorch VRAM to {fraction:.0%} "
f"(~{int(vram_total * fraction)}MB), leaving room for system"
)
return "cuda", gpu_info
def _auto_batch_size(model_key: str, device: str, gpu_info: dict,
config_batch: int) -> int:
"""Pick optimal batch size based on model size and available resources."""
if device == "cpu":
# CPU: T5-Small can handle batch=8 with 32GB RAM, larger models less
if "small" in model_key:
return min(config_batch, 8)
return min(config_batch, 2)
# GPU: estimate based on free VRAM
free_mb = gpu_info["vram_free_mb"]
    # Rough VRAM estimates (bf16, seq_len=128), already padded with headroom;
    # model footprint plus marginal cost per batch sample, heuristics not measurements
model_vram_estimates = {
"flan-t5-small": {"model_mb": 160, "per_sample_mb": 60},
"flan-t5-base": {"model_mb": 400, "per_sample_mb": 100},
"flan-t5-large": {"model_mb": 1000, "per_sample_mb": 160},
"flan-t5-xl": {"model_mb": 3000, "per_sample_mb": 300},
}
est = model_vram_estimates.get(model_key, {"model_mb": 500, "per_sample_mb": 120})
# Available for batches = free VRAM - model footprint - 300MB safety buffer
available_for_batches = free_mb - est["model_mb"] - 300
if available_for_batches <= 0:
logger.warning("Very tight VRAM — using batch_size=1")
return 1
max_batch = max(1, available_for_batches // est["per_sample_mb"])
optimal = min(config_batch, max_batch)
logger.info(
f"Auto batch size: {optimal} "
f"(model ~{est['model_mb']}MB + {optimal}×{est['per_sample_mb']}MB "
f"= ~{est['model_mb'] + optimal * est['per_sample_mb']}MB / {free_mb}MB free)"
)
return max(1, optimal)
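# Worked example (illustrative numbers): flan-t5-base with 6000MB free gives
# (6000 - 400 - 300) // 100 = 53 samples of headroom, so on 8GB+ cards the
# config batch size, not VRAM, is usually the binding limit.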
@click.command()
@click.option("--config", default="configs/training_config.yaml")
@click.option("--use-v2-loss", is_flag=True, help="Use V2 loss with human pattern term")
def train(config: str, use_v2_loss: bool):
"""Launch the full training pipeline."""
# Step 1: Load config
logger.info("Step 1: Loading config...")
with open(config) as f:
cfg = yaml.safe_load(f)
model_cfg = cfg.get("model", {})
lora_cfg = cfg.get("lora", {})
data_cfg = cfg.get("data", {})
train_cfg = cfg.get("training", {})
loss_cfg = cfg.get("loss", {})
gen_cfg = cfg.get("generation", {})
# Step 2: Initialise W&B (optional)
logger.info("Step 2: Initialising experiment tracking...")
if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
wandb.init(
project="dyslexia-rewriter",
name=f"train-{model_cfg.get('key', 'flan-t5')}",
config=cfg,
)
else:
logger.info("W&B not configured, logging to TensorBoard only")
os.environ["WANDB_DISABLED"] = "true"
# Step 3: Detect GPU and configure hybrid VRAM management
logger.info("Step 3: Setting up device (hybrid GPU mode)...")
device, gpu_info = _setup_device()
# Step 4: Load model + tokenizer
logger.info("Step 4: Loading model and tokenizer...")
model_key = model_cfg.get("key", "flan-t5-small")
model, tokenizer, is_seq2seq = load_model_and_tokenizer(
model_key=model_key,
quantize=model_cfg.get("quantize", False),
use_lora=model_cfg.get("use_lora", True),
lora_config_dict=lora_cfg,
)
# Required for PEFT + gradient checkpointing compatibility
if hasattr(model, 'enable_input_require_grads'):
model.enable_input_require_grads()
# ── torch.compile for fused kernels (PyTorch 2.x) ───────────────────────
if hasattr(torch, "compile") and device == "cuda":
try:
# "default" mode: fuses kernels via Triton without CUDA graphs.
# "reduce-overhead" uses CUDA graphs which break with LoRA/PEFT
# (tensor outputs get overwritten between graph replays).
logger.info("Applying torch.compile(mode='default')...")
model = torch.compile(model, mode="default")
logger.info("✓ torch.compile applied — first few steps will be slower (compiling)")
except Exception as e:
logger.warning(f"torch.compile failed (non-fatal): {e}")
# Step 5: Create fingerprinter
logger.info("Step 5: Creating style fingerprinter...")
fingerprinter = StyleFingerprinter(
spacy_model="en_core_web_sm", # Use small model for training speed
awl_path="data/awl/coxhead_awl.txt",
)
# Step 6: Create datasets
logger.info("Step 6: Loading datasets...")
train_dataset = WritingCorrectionDataset(
data_path=data_cfg.get("train_path", "data/processed/train.jsonl"),
tokenizer=tokenizer,
fingerprinter=fingerprinter,
max_input_length=data_cfg.get("max_input_length", 512),
max_target_length=data_cfg.get("max_target_length", 512),
augment_with_synthetic=data_cfg.get("augment_synthetic", True),
synthetic_ratio=data_cfg.get("synthetic_ratio", 0.3),
)
val_dataset = WritingCorrectionDataset(
data_path=data_cfg.get("val_path", "data/processed/val.jsonl"),
tokenizer=tokenizer,
fingerprinter=fingerprinter,
max_input_length=data_cfg.get("max_input_length", 512),
max_target_length=data_cfg.get("max_target_length", 512),
augment_with_synthetic=False,
)
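    # Validation is deliberately unaugmented so eval scores reflect real data only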
logger.info(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")
# Free memory after dataset loading
gc.collect()
if device == "cuda":
torch.cuda.empty_cache()
    # Step 7: Create loss function
    logger.info("Step 7: Creating loss function...")
    # Use a simple CE-only loss for training: the aux models (sentence-transformer,
    # GPT-2, HP classifier) are NOT loaded, since they provide no gradient signal
    # (they decode via argmax under no_grad). Skipping them saves ~1GB+ of memory.
    from torch import nn
class CEOnlyLoss(nn.Module):
"""Cross-entropy only loss — the only loss that provides gradient signal."""
def __init__(self):
super().__init__()
self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
def forward(self, logits, labels, **kwargs):
if logits.dim() == 3:
ce_logits = logits.view(-1, logits.size(-1))
ce_labels = labels.view(-1)
else:
ce_logits = logits
ce_labels = labels
l_ce = self.ce_loss(ce_logits, ce_labels)
return {"total_loss": l_ce, "ce_loss": l_ce}
    if use_v2_loss:
        # The V2 combined loss is imported but intentionally unused: its aux
        # terms carry no gradient, so honouring the flag would only cost memory
        logger.warning("--use-v2-loss requested, but training uses the CE-only loss")
    loss_fn = CEOnlyLoss()
    logger.info("Using CE-only loss (aux models skipped to save memory)")
# Step 8: Create training arguments
logger.info("Step 8: Creating training arguments...")
# Auto-detect precision support
use_bf16 = False
use_fp16 = False
if device == "cuda":
if gpu_info["compute_cap"][0] >= 8:
use_bf16 = True
logger.info("Using BF16 (Ampere+ GPU)")
else:
use_fp16 = True
logger.info("Using FP16 (pre-Ampere GPU)")
elif device == "cpu":
# Zen 3+ CPUs (Ryzen 5000+) support BF16 in PyTorch 2.x
try:
test = torch.tensor([1.0], dtype=torch.bfloat16)
_ = test + test # Test BF16 compute works
use_bf16 = True
logger.info("Using BF16 on CPU (Zen 3+ detected)")
except Exception:
logger.info("BF16 not supported on this CPU, using FP32")
# Smart batch size based on model + available resources
config_batch = train_cfg.get("per_device_train_batch_size", 4)
batch_size = _auto_batch_size(model_key, device, gpu_info, config_batch)
# Smart gradient checkpointing:
# - ENABLE for large models (saves VRAM at cost of compute)
# - DISABLE for small models (they fit in VRAM, checkpointing is pure overhead)
# - ALWAYS DISABLE on CPU (plenty of RAM, checkpointing wastes CPU cycles)
large_models = {"flan-t5-large", "flan-t5-xl", "llama-3.1-8b"}
use_grad_ckpt = model_key in large_models and device == "cuda"
if use_grad_ckpt:
logger.info("Gradient checkpointing: ON (large model, saving VRAM)")
else:
logger.info(f"Gradient checkpointing: OFF ({'small model fits in VRAM' if device == 'cuda' else 'CPU has plenty of RAM'})")
# Dataloader workers: Python 3.14 changed default start method to "forkserver"
# on Linux, which hits "too many fds" with num_workers > 0.
# Use 0 (main-process loading) — dataset is pre-tokenized so overhead is minimal.
num_workers = train_cfg.get("dataloader_num_workers", 0)
# Filter report_to to only available tools
report_to = []
if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
report_to.append("wandb")
report_to.append("tensorboard")
training_args = TrainingArguments(
output_dir=train_cfg.get("output_dir", "checkpoints/"),
num_train_epochs=train_cfg.get("num_train_epochs", 5),
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=train_cfg.get("per_device_eval_batch_size", 8) if device == "cuda" else 2,
gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 8),
learning_rate=train_cfg.get("learning_rate", 3e-4),
lr_scheduler_type=train_cfg.get("lr_scheduler_type", "cosine"),
warmup_ratio=train_cfg.get("warmup_ratio", 0.05),
weight_decay=train_cfg.get("weight_decay", 0.01),
fp16=use_fp16,
bf16=use_bf16,
eval_strategy=train_cfg.get("evaluation_strategy", "steps"),
eval_steps=train_cfg.get("eval_steps", 100),
save_strategy=train_cfg.get("save_strategy", "steps"),
save_steps=train_cfg.get("save_steps", 100),
save_total_limit=train_cfg.get("save_total_limit", 3),
load_best_model_at_end=False, # Handled manually below (PEFT adapters break Trainer's loader)
metric_for_best_model=train_cfg.get("metric_for_best_model", "eval_loss"),
greater_is_better=train_cfg.get("greater_is_better", False),
logging_dir=train_cfg.get("logging_dir", "logs/"),
logging_steps=train_cfg.get("logging_steps", 25),
report_to=report_to,
dataloader_num_workers=num_workers,
seed=train_cfg.get("seed", 42),
remove_unused_columns=False, # We have custom columns (style_vector, etc.)
gradient_checkpointing=use_grad_ckpt,
)
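    # NB: the eval_strategy kwarg assumes transformers >= 4.41; older releases
    # used evaluation_strategy (the config key above keeps the old name)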
# Step 9: Create trainer
logger.info("Step 9: Creating trainer...")
trainer = CorrectionTrainer(
loss_fn=loss_fn,
fingerprinter=fingerprinter,
tokenizer=tokenizer,
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
callbacks=[
StyleMetricsCallback(),
EarlyStoppingOnStyleDrift(min_style_similarity=0.75),
],
)
# Step 10: Train
logger.info("Step 10: Starting training...")
logger.info(
f"Config summary: model={model_key} | batch={batch_size} | "
f"accum={training_args.gradient_accumulation_steps} | "
f"effective_batch={batch_size * training_args.gradient_accumulation_steps} | "
f"epochs={training_args.num_train_epochs} | "
f"precision={'bf16' if use_bf16 else 'fp16' if use_fp16 else 'fp32'} | "
f"grad_ckpt={use_grad_ckpt} | device={device}"
)
trainer.train()
# Step 11: Save best model (manual PEFT-aware loading)
logger.info("Step 11: Saving best model...")
output_dir = train_cfg.get("output_dir", "checkpoints/")
save_path = os.path.join(output_dir, "best_model")
    # Find the best checkpoint recorded in the checkpoints' trainer_state.json
    import glob
    import json

    best_ckpt = None
    ckpt_dirs = glob.glob(os.path.join(output_dir, "checkpoint-*"))
    # Sort numerically by step; lexicographic order would visit checkpoint-1000
    # before checkpoint-900 and leave best_ckpt pointing at a stale state file
    for ckpt_dir in sorted(ckpt_dirs, key=lambda p: int(p.rsplit("-", 1)[-1])):
        ts = os.path.join(ckpt_dir, "trainer_state.json")
        if os.path.exists(ts):
            with open(ts) as f:
                state = json.load(f)
            best_path = state.get("best_model_checkpoint")
            if best_path:
                best_ckpt = best_path
if best_ckpt and os.path.isdir(best_ckpt):
logger.info(f"Loading best checkpoint from {best_ckpt}")
        # Reload the best adapter weights (PEFT writes adapter_model.safetensors)
        best_adapter = os.path.join(best_ckpt, "adapter_model.safetensors")
if os.path.exists(best_adapter):
model.load_adapter(best_ckpt, adapter_name="default")
logger.info(f"Loaded best adapter from {best_ckpt}")
else:
logger.warning(f"No adapter found at {best_ckpt}, saving current model")
else:
logger.info("No best checkpoint found, saving final model state")
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)
logger.info(f"Model saved to {save_path}")
if HAS_WANDB and wandb.run is not None:
wandb.finish()
logger.info("✓ Training complete!")
if __name__ == "__main__":
train()