"""
Full training entry point.
Run: python scripts/train.py --config configs/training_config.yaml
"""
import click
import yaml
import torch
import os
import gc
from transformers import TrainingArguments, Seq2SeqTrainingArguments
from loguru import logger
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
from src.model.base_model import load_model_and_tokenizer
from src.model.style_conditioner import StyleConditioner
from src.training.dataset import WritingCorrectionDataset
from src.training.loss_functions import CombinedCorrectionLoss, CombinedCorrectionLossV2
from src.training.trainer import CorrectionTrainer
from src.training.callbacks import StyleMetricsCallback, EarlyStoppingOnStyleDrift
from src.style.fingerprinter import StyleFingerprinter
from src.evaluation.gleu_scorer import GLEUScorer
# ── Hybrid GPU Management ────────────────────────────────────────────────────
def _setup_device():
    """Detect the GPU and configure hybrid VRAM management.

    When a CUDA device is present, caps PyTorch's per-process memory
    fraction so the desktop/compositor keeps some VRAM headroom.

    Returns:
        tuple[str, dict]: ``(device, gpu_info)`` where ``device`` is
        ``"cuda"`` or ``"cpu"`` and ``gpu_info`` has keys:
          - available (bool): whether a CUDA device was found
          - name (str): device name (``"CPU"`` when no GPU)
          - vram_total_mb (int): total VRAM in MiB
          - vram_free_mb (int): total minus *this process's* allocated VRAM,
            in MiB (VRAM held by other processes is not visible here)
          - compute_cap (tuple[int, int]): CUDA compute capability
    """
    gpu_info = {"available": False, "name": "CPU", "vram_total_mb": 0,
                "vram_free_mb": 0, "compute_cap": (0, 0)}
    if not torch.cuda.is_available():
        logger.info("No GPU detected — training on CPU")
        return "cpu", gpu_info

    gpu_info["available"] = True
    gpu_info["name"] = torch.cuda.get_device_name(0)
    gpu_info["compute_cap"] = torch.cuda.get_device_capability(0)

    # NOTE: "free" = total - allocated-by-this-process. torch cannot see
    # allocations made by other processes, so this is an optimistic estimate.
    vram_total = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)
    vram_allocated = torch.cuda.memory_allocated(0) // (1024 * 1024)
    vram_free = vram_total - vram_allocated
    gpu_info["vram_total_mb"] = vram_total
    gpu_info["vram_free_mb"] = vram_free
    logger.info(
        f"GPU: {gpu_info['name']} | "
        f"VRAM: {vram_allocated}MB used / {vram_total}MB total ({vram_free}MB free) | "
        f"Compute: {gpu_info['compute_cap']}"
    )

    # Leave headroom for the system — reserve at most 85% of free VRAM so the
    # desktop/compositor is not starved.
    usable_vram_mb = int(vram_free * 0.85)
    if usable_vram_mb > 0:
        # Cap PyTorch's share of total VRAM, never above 90%.
        fraction = min(usable_vram_mb / vram_total, 0.90)
        torch.cuda.set_per_process_memory_fraction(fraction, 0)
        logger.info(
            f"Hybrid GPU mode: capped PyTorch VRAM to {fraction:.0%} "
            f"(~{int(vram_total * fraction)}MB), leaving room for system"
        )
    return "cuda", gpu_info
def _auto_batch_size(model_key: str, device: str, gpu_info: dict,
config_batch: int) -> int:
"""Pick optimal batch size based on model size and available resources."""
if device == "cpu":
# CPU: T5-Small can handle batch=8 with 32GB RAM, larger models less
if "small" in model_key:
return min(config_batch, 8)
return min(config_batch, 2)
# GPU: estimate based on free VRAM
free_mb = gpu_info["vram_free_mb"]
# Rough VRAM per sample estimates (bf16, seq_len=128):
# T5-Small: ~120MB model + ~50MB/sample
# T5-Base: ~350MB model + ~90MB/sample
# T5-Large: ~900MB model + ~150MB/sample
model_vram_estimates = {
"flan-t5-small": {"model_mb": 160, "per_sample_mb": 60},
"flan-t5-base": {"model_mb": 400, "per_sample_mb": 100},
"flan-t5-large": {"model_mb": 1000, "per_sample_mb": 160},
"flan-t5-xl": {"model_mb": 3000, "per_sample_mb": 300},
}
est = model_vram_estimates.get(model_key, {"model_mb": 500, "per_sample_mb": 120})
# Available for batches = free VRAM - model footprint - 300MB safety buffer
available_for_batches = free_mb - est["model_mb"] - 300
if available_for_batches <= 0:
logger.warning("Very tight VRAM β using batch_size=1")
return 1
max_batch = max(1, available_for_batches // est["per_sample_mb"])
optimal = min(config_batch, max_batch)
logger.info(
f"Auto batch size: {optimal} "
f"(model ~{est['model_mb']}MB + {optimal}Γ{est['per_sample_mb']}MB "
f"= ~{est['model_mb'] + optimal * est['per_sample_mb']}MB / {free_mb}MB free)"
)
return max(1, optimal)
@click.command()
@click.option("--config", default="configs/training_config.yaml")
@click.option("--use-v2-loss", is_flag=True, help="Use V2 loss with human pattern term")
def train(config: str, use_v2_loss: bool):
    """Launch the full training pipeline.

    Reads the YAML config, configures device/precision, builds the model,
    datasets and trainer, runs training, then saves the best (or final)
    model and tokenizer to ``<output_dir>/best_model``.
    """
    # Hoisted here once (previously `json` was re-imported on every loop
    # iteration in the checkpoint scan below).
    import glob
    import json

    # Step 1: Load config
    logger.info("Step 1: Loading config...")
    with open(config) as f:
        cfg = yaml.safe_load(f)
    model_cfg = cfg.get("model", {})
    lora_cfg = cfg.get("lora", {})
    data_cfg = cfg.get("data", {})
    train_cfg = cfg.get("training", {})
    # (the "loss" and "generation" config sections are not read by this script)

    # --use-v2-loss is currently a no-op: this entry point always trains with
    # the CE-only loss defined below. Warn instead of silently ignoring it.
    if use_v2_loss:
        logger.warning(
            "--use-v2-loss has no effect in this entry point: "
            "CE-only loss is used (aux losses are disabled)"
        )

    # Step 2: Initialise W&B (optional)
    logger.info("Step 2: Initialising experiment tracking...")
    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        wandb.init(
            project="dyslexia-rewriter",
            name=f"train-{model_cfg.get('key', 'flan-t5')}",
            config=cfg,
        )
    else:
        logger.info("W&B not configured, logging to TensorBoard only")
        os.environ["WANDB_DISABLED"] = "true"

    # Step 3: Detect GPU and configure hybrid VRAM management
    logger.info("Step 3: Setting up device (hybrid GPU mode)...")
    device, gpu_info = _setup_device()

    # Step 4: Load model + tokenizer
    logger.info("Step 4: Loading model and tokenizer...")
    model_key = model_cfg.get("key", "flan-t5-small")
    model, tokenizer, is_seq2seq = load_model_and_tokenizer(
        model_key=model_key,
        quantize=model_cfg.get("quantize", False),
        use_lora=model_cfg.get("use_lora", True),
        lora_config_dict=lora_cfg,
    )
    # Required for PEFT + gradient checkpointing compatibility
    if hasattr(model, 'enable_input_require_grads'):
        model.enable_input_require_grads()

    # ── torch.compile for fused kernels (PyTorch 2.x) ───────────────────────
    if hasattr(torch, "compile") and device == "cuda":
        try:
            # "default" mode: fuses kernels via Triton without CUDA graphs.
            # "reduce-overhead" uses CUDA graphs which break with LoRA/PEFT
            # (tensor outputs get overwritten between graph replays).
            logger.info("Applying torch.compile(mode='default')...")
            model = torch.compile(model, mode="default")
            logger.info("✓ torch.compile applied — first few steps will be slower (compiling)")
        except Exception as e:
            logger.warning(f"torch.compile failed (non-fatal): {e}")

    # Step 5: Create fingerprinter
    logger.info("Step 5: Creating style fingerprinter...")
    fingerprinter = StyleFingerprinter(
        spacy_model="en_core_web_sm",  # Use small model for training speed
        awl_path="data/awl/coxhead_awl.txt",
    )

    # Step 6: Create datasets
    logger.info("Step 6: Loading datasets...")
    train_dataset = WritingCorrectionDataset(
        data_path=data_cfg.get("train_path", "data/processed/train.jsonl"),
        tokenizer=tokenizer,
        fingerprinter=fingerprinter,
        max_input_length=data_cfg.get("max_input_length", 512),
        max_target_length=data_cfg.get("max_target_length", 512),
        augment_with_synthetic=data_cfg.get("augment_synthetic", True),
        synthetic_ratio=data_cfg.get("synthetic_ratio", 0.3),
    )
    val_dataset = WritingCorrectionDataset(
        data_path=data_cfg.get("val_path", "data/processed/val.jsonl"),
        tokenizer=tokenizer,
        fingerprinter=fingerprinter,
        max_input_length=data_cfg.get("max_input_length", 512),
        max_target_length=data_cfg.get("max_target_length", 512),
        augment_with_synthetic=False,  # never augment validation data
    )
    logger.info(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")

    # Free memory after dataset loading
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    # Use simple CE-only loss for training — aux models (sentence-transformer,
    # GPT-2, HP classifier) are NOT loaded since they provide no gradient signal
    # (they decode via argmax under no_grad). This saves ~1GB+ memory.
    from torch import nn

    class CEOnlyLoss(nn.Module):
        """Cross-entropy only loss — the only loss that provides gradient signal."""

        def __init__(self):
            super().__init__()
            # ignore_index=-100 matches the HF convention for padded labels
            self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

        def forward(self, logits, labels, **kwargs):
            # Flatten (batch, seq, vocab) logits to (batch*seq, vocab) for CE.
            if logits.dim() == 3:
                ce_logits = logits.view(-1, logits.size(-1))
                ce_labels = labels.view(-1)
            else:
                ce_logits = logits
                ce_labels = labels
            l_ce = self.ce_loss(ce_logits, ce_labels)
            return {"total_loss": l_ce, "ce_loss": l_ce}

    loss_fn = CEOnlyLoss()
    logger.info("Using CE-only loss (aux models skipped to save memory)")

    # Step 8: Create training arguments
    logger.info("Step 8: Creating training arguments...")
    # Auto-detect precision support
    use_bf16 = False
    use_fp16 = False
    if device == "cuda":
        if gpu_info["compute_cap"][0] >= 8:
            use_bf16 = True
            logger.info("Using BF16 (Ampere+ GPU)")
        else:
            use_fp16 = True
            logger.info("Using FP16 (pre-Ampere GPU)")
    elif device == "cpu":
        # NOTE(review): this only verifies that a BF16 tensor op runs, which
        # succeeds on most CPUs with PyTorch 2.x — it does not actually
        # detect Zen 3+ hardware BF16 support.
        try:
            test = torch.tensor([1.0], dtype=torch.bfloat16)
            _ = test + test  # Test BF16 compute works
            use_bf16 = True
            logger.info("Using BF16 on CPU (Zen 3+ detected)")
        except Exception:
            logger.info("BF16 not supported on this CPU, using FP32")

    # Smart batch size based on model + available resources
    config_batch = train_cfg.get("per_device_train_batch_size", 4)
    batch_size = _auto_batch_size(model_key, device, gpu_info, config_batch)

    # Smart gradient checkpointing:
    # - ENABLE for large models (saves VRAM at cost of compute)
    # - DISABLE for small models (they fit in VRAM, checkpointing is pure overhead)
    # - ALWAYS DISABLE on CPU (plenty of RAM, checkpointing wastes CPU cycles)
    large_models = {"flan-t5-large", "flan-t5-xl", "llama-3.1-8b"}
    use_grad_ckpt = model_key in large_models and device == "cuda"
    if use_grad_ckpt:
        logger.info("Gradient checkpointing: ON (large model, saving VRAM)")
    else:
        logger.info(f"Gradient checkpointing: OFF ({'small model fits in VRAM' if device == 'cuda' else 'CPU has plenty of RAM'})")

    # Dataloader workers: Python 3.14 changed default start method to "forkserver"
    # on Linux, which hits "too many fds" with num_workers > 0.
    # Use 0 (main-process loading) — dataset is pre-tokenized so overhead is minimal.
    num_workers = train_cfg.get("dataloader_num_workers", 0)

    # Filter report_to to only available tools
    report_to = []
    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        report_to.append("wandb")
    report_to.append("tensorboard")

    training_args = TrainingArguments(
        output_dir=train_cfg.get("output_dir", "checkpoints/"),
        num_train_epochs=train_cfg.get("num_train_epochs", 5),
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=train_cfg.get("per_device_eval_batch_size", 8) if device == "cuda" else 2,
        gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 8),
        learning_rate=train_cfg.get("learning_rate", 3e-4),
        lr_scheduler_type=train_cfg.get("lr_scheduler_type", "cosine"),
        warmup_ratio=train_cfg.get("warmup_ratio", 0.05),
        weight_decay=train_cfg.get("weight_decay", 0.01),
        fp16=use_fp16,
        bf16=use_bf16,
        eval_strategy=train_cfg.get("evaluation_strategy", "steps"),
        eval_steps=train_cfg.get("eval_steps", 100),
        save_strategy=train_cfg.get("save_strategy", "steps"),
        save_steps=train_cfg.get("save_steps", 100),
        save_total_limit=train_cfg.get("save_total_limit", 3),
        load_best_model_at_end=False,  # Handled manually below (PEFT adapters break Trainer's loader)
        metric_for_best_model=train_cfg.get("metric_for_best_model", "eval_loss"),
        greater_is_better=train_cfg.get("greater_is_better", False),
        logging_dir=train_cfg.get("logging_dir", "logs/"),
        logging_steps=train_cfg.get("logging_steps", 25),
        report_to=report_to,
        dataloader_num_workers=num_workers,
        seed=train_cfg.get("seed", 42),
        remove_unused_columns=False,  # We have custom columns (style_vector, etc.)
        gradient_checkpointing=use_grad_ckpt,
    )

    # Step 9: Create trainer
    logger.info("Step 9: Creating trainer...")
    trainer = CorrectionTrainer(
        loss_fn=loss_fn,
        fingerprinter=fingerprinter,
        tokenizer=tokenizer,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[
            StyleMetricsCallback(),
            EarlyStoppingOnStyleDrift(min_style_similarity=0.75),
        ],
    )

    # Step 10: Train
    logger.info("Step 10: Starting training...")
    logger.info(
        f"Config summary: model={model_key} | batch={batch_size} | "
        f"accum={training_args.gradient_accumulation_steps} | "
        f"effective_batch={batch_size * training_args.gradient_accumulation_steps} | "
        f"epochs={training_args.num_train_epochs} | "
        f"precision={'bf16' if use_bf16 else 'fp16' if use_fp16 else 'fp32'} | "
        f"grad_ckpt={use_grad_ckpt} | device={device}"
    )
    trainer.train()

    # Step 11: Save best model (manual PEFT-aware loading)
    logger.info("Step 11: Saving best model...")
    output_dir = train_cfg.get("output_dir", "checkpoints/")
    save_path = os.path.join(output_dir, "best_model")

    # Each checkpoint dir carries a trainer_state.json whose
    # "best_model_checkpoint" points at the best checkpoint so far; the
    # state from the LATEST checkpoint is authoritative. Sort numerically by
    # step — plain lexicographic sort would rank "checkpoint-1000" before
    # "checkpoint-999" and pick a stale state.
    def _ckpt_step(path: str) -> int:
        """Extract the step number from a 'checkpoint-N' dir name (-1 if malformed)."""
        suffix = path.rsplit("-", 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    best_ckpt = None
    for ckpt_dir in sorted(glob.glob(os.path.join(output_dir, "checkpoint-*")),
                           key=_ckpt_step):
        ts = os.path.join(ckpt_dir, "trainer_state.json")
        if os.path.exists(ts):
            with open(ts) as f:
                state = json.load(f)
            best_path = state.get("best_model_checkpoint")
            if best_path:
                best_ckpt = best_path

    if best_ckpt and os.path.isdir(best_ckpt):
        logger.info(f"Loading best checkpoint from {best_ckpt}")
        # Reload the best adapter weights (PEFT saves them as safetensors)
        best_adapter = os.path.join(best_ckpt, "adapter_model.safetensors")
        if os.path.exists(best_adapter):
            model.load_adapter(best_ckpt, adapter_name="default")
            logger.info(f"Loaded best adapter from {best_ckpt}")
        else:
            logger.warning(f"No adapter found at {best_ckpt}, saving current model")
    else:
        logger.info("No best checkpoint found, saving final model state")

    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
    logger.info(f"Model saved to {save_path}")

    if HAS_WANDB and wandb.run is not None:
        wandb.finish()
    logger.info("✓ Training complete!")


if __name__ == "__main__":
    train()