""" Full training entry point. Run: python scripts/train.py --config configs/training_config.yaml """ import click import yaml import torch import os import gc from transformers import TrainingArguments, Seq2SeqTrainingArguments from loguru import logger try: import wandb HAS_WANDB = True except ImportError: HAS_WANDB = False from src.model.base_model import load_model_and_tokenizer from src.model.style_conditioner import StyleConditioner from src.training.dataset import WritingCorrectionDataset from src.training.loss_functions import CombinedCorrectionLoss, CombinedCorrectionLossV2 from src.training.trainer import CorrectionTrainer from src.training.callbacks import StyleMetricsCallback, EarlyStoppingOnStyleDrift from src.style.fingerprinter import StyleFingerprinter from src.evaluation.gleu_scorer import GLEUScorer # ── Hybrid GPU Management ─────────────────────────────────────────────────── def _setup_device(): """Detect GPU and configure hybrid VRAM management. Returns (device, gpu_info) where gpu_info is a dict with: - available: bool - name: str - vram_total_mb: int - vram_free_mb: int - compute_cap: tuple """ gpu_info = {"available": False, "name": "CPU", "vram_total_mb": 0, "vram_free_mb": 0, "compute_cap": (0, 0)} if not torch.cuda.is_available(): logger.info("No GPU detected — training on CPU") return "cpu", gpu_info gpu_info["available"] = True gpu_info["name"] = torch.cuda.get_device_name(0) gpu_info["compute_cap"] = torch.cuda.get_device_capability(0) # Query actual free VRAM vram_total = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024) vram_reserved = torch.cuda.memory_reserved(0) // (1024 * 1024) vram_allocated = torch.cuda.memory_allocated(0) // (1024 * 1024) vram_free = vram_total - vram_allocated gpu_info["vram_total_mb"] = vram_total gpu_info["vram_free_mb"] = vram_free logger.info( f"GPU: {gpu_info['name']} | " f"VRAM: {vram_allocated}MB used / {vram_total}MB total ({vram_free}MB free) | " f"Compute: {gpu_info['compute_cap']}" ) # Leave headroom for the system — reserve at most 85% of free VRAM # This prevents the desktop/compositor from starving usable_vram_mb = int(vram_free * 0.85) if usable_vram_mb > 0: # Set PyTorch memory limit to avoid hogging all VRAM fraction = min(usable_vram_mb / vram_total, 0.90) torch.cuda.set_per_process_memory_fraction(fraction, 0) logger.info( f"Hybrid GPU mode: capped PyTorch VRAM to {fraction:.0%} " f"(~{int(vram_total * fraction)}MB), leaving room for system" ) return "cuda", gpu_info def _auto_batch_size(model_key: str, device: str, gpu_info: dict, config_batch: int) -> int: """Pick optimal batch size based on model size and available resources.""" if device == "cpu": # CPU: T5-Small can handle batch=8 with 32GB RAM, larger models less if "small" in model_key: return min(config_batch, 8) return min(config_batch, 2) # GPU: estimate based on free VRAM free_mb = gpu_info["vram_free_mb"] # Rough VRAM per sample estimates (bf16, seq_len=128): # T5-Small: ~120MB model + ~50MB/sample # T5-Base: ~350MB model + ~90MB/sample # T5-Large: ~900MB model + ~150MB/sample model_vram_estimates = { "flan-t5-small": {"model_mb": 160, "per_sample_mb": 60}, "flan-t5-base": {"model_mb": 400, "per_sample_mb": 100}, "flan-t5-large": {"model_mb": 1000, "per_sample_mb": 160}, "flan-t5-xl": {"model_mb": 3000, "per_sample_mb": 300}, } est = model_vram_estimates.get(model_key, {"model_mb": 500, "per_sample_mb": 120}) # Available for batches = free VRAM - model footprint - 300MB safety buffer available_for_batches = free_mb - est["model_mb"] - 
    # Available for batches = free VRAM - model footprint - 300MB safety buffer
    available_for_batches = free_mb - est["model_mb"] - 300
    if available_for_batches <= 0:
        logger.warning("Very tight VRAM — using batch_size=1")
        return 1

    max_batch = max(1, available_for_batches // est["per_sample_mb"])
    optimal = min(config_batch, max_batch)
    logger.info(
        f"Auto batch size: {optimal} "
        f"(model ~{est['model_mb']}MB + {optimal}×{est['per_sample_mb']}MB "
        f"= ~{est['model_mb'] + optimal * est['per_sample_mb']}MB / {free_mb}MB free)"
    )
    return max(1, optimal)


@click.command()
@click.option("--config", default="configs/training_config.yaml")
@click.option("--use-v2-loss", is_flag=True, help="Use V2 loss with human pattern term")
def train(config: str, use_v2_loss: bool):
    """Launch the full training pipeline."""
    # Step 1: Load config
    logger.info("Step 1: Loading config...")
    with open(config) as f:
        cfg = yaml.safe_load(f)

    model_cfg = cfg.get("model", {})
    lora_cfg = cfg.get("lora", {})
    data_cfg = cfg.get("data", {})
    train_cfg = cfg.get("training", {})
    loss_cfg = cfg.get("loss", {})
    gen_cfg = cfg.get("generation", {})

    # Step 2: Initialise W&B (optional)
    logger.info("Step 2: Initialising experiment tracking...")
    if HAS_WANDB and os.environ.get("WANDB_API_KEY"):
        wandb.init(
            project="dyslexia-rewriter",
            name=f"train-{model_cfg.get('key', 'flan-t5')}",
            config=cfg,
        )
    else:
        logger.info("W&B not configured, logging to TensorBoard only")
        os.environ["WANDB_DISABLED"] = "true"

    # Step 3: Detect GPU and configure hybrid VRAM management
    logger.info("Step 3: Setting up device (hybrid GPU mode)...")
    device, gpu_info = _setup_device()

    # Step 4: Load model + tokenizer
    logger.info("Step 4: Loading model and tokenizer...")
    model_key = model_cfg.get("key", "flan-t5-small")
    model, tokenizer, is_seq2seq = load_model_and_tokenizer(
        model_key=model_key,
        quantize=model_cfg.get("quantize", False),
        use_lora=model_cfg.get("use_lora", True),
        lora_config_dict=lora_cfg,
    )

    # Required for PEFT + gradient checkpointing compatibility
    if hasattr(model, 'enable_input_require_grads'):
        model.enable_input_require_grads()

    # ── torch.compile for fused kernels (PyTorch 2.x) ───────────────────────
    if hasattr(torch, "compile") and device == "cuda":
        try:
            # "default" mode: fuses kernels via Triton without CUDA graphs.
            # "reduce-overhead" uses CUDA graphs which break with LoRA/PEFT
            # (tensor outputs get overwritten between graph replays).
logger.info("Applying torch.compile(mode='default')...") model = torch.compile(model, mode="default") logger.info("✓ torch.compile applied — first few steps will be slower (compiling)") except Exception as e: logger.warning(f"torch.compile failed (non-fatal): {e}") # Step 5: Create fingerprinter logger.info("Step 5: Creating style fingerprinter...") fingerprinter = StyleFingerprinter( spacy_model="en_core_web_sm", # Use small model for training speed awl_path="data/awl/coxhead_awl.txt", ) # Step 6: Create datasets logger.info("Step 6: Loading datasets...") train_dataset = WritingCorrectionDataset( data_path=data_cfg.get("train_path", "data/processed/train.jsonl"), tokenizer=tokenizer, fingerprinter=fingerprinter, max_input_length=data_cfg.get("max_input_length", 512), max_target_length=data_cfg.get("max_target_length", 512), augment_with_synthetic=data_cfg.get("augment_synthetic", True), synthetic_ratio=data_cfg.get("synthetic_ratio", 0.3), ) val_dataset = WritingCorrectionDataset( data_path=data_cfg.get("val_path", "data/processed/val.jsonl"), tokenizer=tokenizer, fingerprinter=fingerprinter, max_input_length=data_cfg.get("max_input_length", 512), max_target_length=data_cfg.get("max_target_length", 512), augment_with_synthetic=False, ) logger.info(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}") # Free memory after dataset loading gc.collect() if device == "cuda": torch.cuda.empty_cache() # Use simple CE-only loss for training — aux models (sentence-transformer, # GPT-2, HP classifier) are NOT loaded since they provide no gradient signal # (they decode via argmax under no_grad). This saves ~1GB+ memory. from torch import nn class CEOnlyLoss(nn.Module): """Cross-entropy only loss — the only loss that provides gradient signal.""" def __init__(self): super().__init__() self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) def forward(self, logits, labels, **kwargs): if logits.dim() == 3: ce_logits = logits.view(-1, logits.size(-1)) ce_labels = labels.view(-1) else: ce_logits = logits ce_labels = labels l_ce = self.ce_loss(ce_logits, ce_labels) return {"total_loss": l_ce, "ce_loss": l_ce} loss_fn = CEOnlyLoss() logger.info("Using CE-only loss (aux models skipped to save memory)") # Step 8: Create training arguments logger.info("Step 8: Creating training arguments...") # Auto-detect precision support use_bf16 = False use_fp16 = False if device == "cuda": if gpu_info["compute_cap"][0] >= 8: use_bf16 = True logger.info("Using BF16 (Ampere+ GPU)") else: use_fp16 = True logger.info("Using FP16 (pre-Ampere GPU)") elif device == "cpu": # Zen 3+ CPUs (Ryzen 5000+) support BF16 in PyTorch 2.x try: test = torch.tensor([1.0], dtype=torch.bfloat16) _ = test + test # Test BF16 compute works use_bf16 = True logger.info("Using BF16 on CPU (Zen 3+ detected)") except Exception: logger.info("BF16 not supported on this CPU, using FP32") # Smart batch size based on model + available resources config_batch = train_cfg.get("per_device_train_batch_size", 4) batch_size = _auto_batch_size(model_key, device, gpu_info, config_batch) # Smart gradient checkpointing: # - ENABLE for large models (saves VRAM at cost of compute) # - DISABLE for small models (they fit in VRAM, checkpointing is pure overhead) # - ALWAYS DISABLE on CPU (plenty of RAM, checkpointing wastes CPU cycles) large_models = {"flan-t5-large", "flan-t5-xl", "llama-3.1-8b"} use_grad_ckpt = model_key in large_models and device == "cuda" if use_grad_ckpt: logger.info("Gradient checkpointing: ON (large model, saving VRAM)") else: 
logger.info(f"Gradient checkpointing: OFF ({'small model fits in VRAM' if device == 'cuda' else 'CPU has plenty of RAM'})") # Dataloader workers: Python 3.14 changed default start method to "forkserver" # on Linux, which hits "too many fds" with num_workers > 0. # Use 0 (main-process loading) — dataset is pre-tokenized so overhead is minimal. num_workers = train_cfg.get("dataloader_num_workers", 0) # Filter report_to to only available tools report_to = [] if HAS_WANDB and os.environ.get("WANDB_API_KEY"): report_to.append("wandb") report_to.append("tensorboard") training_args = TrainingArguments( output_dir=train_cfg.get("output_dir", "checkpoints/"), num_train_epochs=train_cfg.get("num_train_epochs", 5), per_device_train_batch_size=batch_size, per_device_eval_batch_size=train_cfg.get("per_device_eval_batch_size", 8) if device == "cuda" else 2, gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 8), learning_rate=train_cfg.get("learning_rate", 3e-4), lr_scheduler_type=train_cfg.get("lr_scheduler_type", "cosine"), warmup_ratio=train_cfg.get("warmup_ratio", 0.05), weight_decay=train_cfg.get("weight_decay", 0.01), fp16=use_fp16, bf16=use_bf16, eval_strategy=train_cfg.get("evaluation_strategy", "steps"), eval_steps=train_cfg.get("eval_steps", 100), save_strategy=train_cfg.get("save_strategy", "steps"), save_steps=train_cfg.get("save_steps", 100), save_total_limit=train_cfg.get("save_total_limit", 3), load_best_model_at_end=False, # Handled manually below (PEFT adapters break Trainer's loader) metric_for_best_model=train_cfg.get("metric_for_best_model", "eval_loss"), greater_is_better=train_cfg.get("greater_is_better", False), logging_dir=train_cfg.get("logging_dir", "logs/"), logging_steps=train_cfg.get("logging_steps", 25), report_to=report_to, dataloader_num_workers=num_workers, seed=train_cfg.get("seed", 42), remove_unused_columns=False, # We have custom columns (style_vector, etc.) 
    # Step 9: Create trainer
    logger.info("Step 9: Creating trainer...")
    trainer = CorrectionTrainer(
        loss_fn=loss_fn,
        fingerprinter=fingerprinter,
        tokenizer=tokenizer,
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        callbacks=[
            StyleMetricsCallback(),
            EarlyStoppingOnStyleDrift(min_style_similarity=0.75),
        ],
    )

    # Step 10: Train
    logger.info("Step 10: Starting training...")
    logger.info(
        f"Config summary: model={model_key} | batch={batch_size} | "
        f"accum={training_args.gradient_accumulation_steps} | "
        f"effective_batch={batch_size * training_args.gradient_accumulation_steps} | "
        f"epochs={training_args.num_train_epochs} | "
        f"precision={'bf16' if use_bf16 else 'fp16' if use_fp16 else 'fp32'} | "
        f"grad_ckpt={use_grad_ckpt} | device={device}"
    )
    trainer.train()

    # Step 11: Save best model (manual PEFT-aware loading)
    logger.info("Step 11: Saving best model...")
    output_dir = train_cfg.get("output_dir", "checkpoints/")
    save_path = os.path.join(output_dir, "best_model")

    # Find the best checkpoint from trainer state: each checkpoint's
    # trainer_state.json records the best checkpoint seen up to that point.
    import glob
    import json as json_mod

    best_ckpt = None
    # Sort checkpoints numerically (lexicographic sort would put
    # checkpoint-1000 before checkpoint-999)
    ckpt_dirs = sorted(
        glob.glob(os.path.join(output_dir, "checkpoint-*")),
        key=lambda p: int(p.rsplit("-", 1)[-1]),
    )
    for ckpt_dir in ckpt_dirs:
        ts = os.path.join(ckpt_dir, "trainer_state.json")
        if os.path.exists(ts):
            with open(ts) as f:
                state = json_mod.load(f)
            best_path = state.get("best_model_checkpoint")
            if best_path:
                best_ckpt = best_path

    if best_ckpt and os.path.isdir(best_ckpt):
        logger.info(f"Loading best checkpoint from {best_ckpt}")
        # Reload the best adapter weights
        best_adapter = os.path.join(best_ckpt, "adapter_model.safetensors")
        if os.path.exists(best_adapter):
            model.load_adapter(best_ckpt, adapter_name="default")
            logger.info(f"Loaded best adapter from {best_ckpt}")
        else:
            logger.warning(f"No adapter found at {best_ckpt}, saving current model")
    else:
        logger.info("No best checkpoint found, saving final model state")

    trainer.save_model(save_path)
    tokenizer.save_pretrained(save_path)
    logger.info(f"Model saved to {save_path}")

    if HAS_WANDB and wandb.run is not None:
        wandb.finish()

    logger.info("✓ Training complete!")


if __name__ == "__main__":
    train()