| """ |
| Full training entry point. |
| Run: python scripts/train.py --config configs/training_config.yaml |
| """ |
|
|
| import click |
| import yaml |
| import torch |
| import os |
| import gc |
from transformers import TrainingArguments
| from loguru import logger |
|
|
| try: |
| import wandb |
| HAS_WANDB = True |
| except ImportError: |
| HAS_WANDB = False |
|
|
from src.model.base_model import load_model_and_tokenizer
from src.training.dataset import WritingCorrectionDataset
from src.training.loss_functions import CombinedCorrectionLossV2
from src.training.trainer import CorrectionTrainer
from src.training.callbacks import StyleMetricsCallback, EarlyStoppingOnStyleDrift
from src.style.fingerprinter import StyleFingerprinter
|
|
|
|
| |
| def _setup_device(): |
| """Detect GPU and configure hybrid VRAM management. |
| |
| Returns (device, gpu_info) where gpu_info is a dict with: |
| - available: bool |
| - name: str |
| - vram_total_mb: int |
| - vram_free_mb: int |
| - compute_cap: tuple |
| """ |
| gpu_info = {"available": False, "name": "CPU", "vram_total_mb": 0, |
| "vram_free_mb": 0, "compute_cap": (0, 0)} |
|
|
| if not torch.cuda.is_available(): |
| logger.info("No GPU detected — training on CPU") |
| return "cpu", gpu_info |
|
|
| gpu_info["available"] = True |
| gpu_info["name"] = torch.cuda.get_device_name(0) |
| gpu_info["compute_cap"] = torch.cuda.get_device_capability(0) |
|
|
| |
    vram_total = torch.cuda.get_device_properties(0).total_memory // (1024 * 1024)
    vram_allocated = torch.cuda.memory_allocated(0) // (1024 * 1024)
    vram_free = vram_total - vram_allocated
|
|
| gpu_info["vram_total_mb"] = vram_total |
| gpu_info["vram_free_mb"] = vram_free |
|
|
| logger.info( |
| f"GPU: {gpu_info['name']} | " |
| f"VRAM: {vram_allocated}MB used / {vram_total}MB total ({vram_free}MB free) | " |
| f"Compute: {gpu_info['compute_cap']}" |
| ) |
|
|
| |
| |
| usable_vram_mb = int(vram_free * 0.85) |
| if usable_vram_mb > 0: |
| |
| fraction = min(usable_vram_mb / vram_total, 0.90) |
| torch.cuda.set_per_process_memory_fraction(fraction, 0) |
| logger.info( |
| f"Hybrid GPU mode: capped PyTorch VRAM to {fraction:.0%} " |
| f"(~{int(vram_total * fraction)}MB), leaving room for system" |
| ) |
|
|
| return "cuda", gpu_info |
|
|
|
|
| def _auto_batch_size(model_key: str, device: str, gpu_info: dict, |
| config_batch: int) -> int: |
| """Pick optimal batch size based on model size and available resources.""" |
| if device == "cpu": |
| |
| if "small" in model_key: |
| return min(config_batch, 8) |
| return min(config_batch, 2) |
|
|
| |
| free_mb = gpu_info["vram_free_mb"] |
|
|
| |
| |
| |
| |
| model_vram_estimates = { |
| "flan-t5-small": {"model_mb": 160, "per_sample_mb": 60}, |
| "flan-t5-base": {"model_mb": 400, "per_sample_mb": 100}, |
| "flan-t5-large": {"model_mb": 1000, "per_sample_mb": 160}, |
| "flan-t5-xl": {"model_mb": 3000, "per_sample_mb": 300}, |
| } |
| est = model_vram_estimates.get(model_key, {"model_mb": 500, "per_sample_mb": 120}) |
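
    # Reserve the weights plus a fixed ~300MB margin for the CUDA context and
    # allocator fragmentation; what remains is divided by the per-sample cost.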
|
|
| |
| available_for_batches = free_mb - est["model_mb"] - 300 |
| if available_for_batches <= 0: |
| logger.warning("Very tight VRAM — using batch_size=1") |
| return 1 |
|
|
| max_batch = max(1, available_for_batches // est["per_sample_mb"]) |
| optimal = min(config_batch, max_batch) |
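    # Example: flan-t5-base with 8000MB free: (8000 - 400 - 300) // 100 = 73
    # candidate samples, then capped by the configured batch size.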
|
|
| logger.info( |
| f"Auto batch size: {optimal} " |
| f"(model ~{est['model_mb']}MB + {optimal}×{est['per_sample_mb']}MB " |
| f"= ~{est['model_mb'] + optimal * est['per_sample_mb']}MB / {free_mb}MB free)" |
| ) |
| return max(1, optimal) |
|
|
|
|
| @click.command() |
| @click.option("--config", default="configs/training_config.yaml") |
| @click.option("--use-v2-loss", is_flag=True, help="Use V2 loss with human pattern term") |
| def train(config: str, use_v2_loss: bool): |
| """Launch the full training pipeline.""" |
| |
| logger.info("Step 1: Loading config...") |
| with open(config) as f: |
| cfg = yaml.safe_load(f) |
|
|
| model_cfg = cfg.get("model", {}) |
| lora_cfg = cfg.get("lora", {}) |
| data_cfg = cfg.get("data", {}) |
| train_cfg = cfg.get("training", {}) |
| loss_cfg = cfg.get("loss", {}) |
| gen_cfg = cfg.get("generation", {}) |
|
|
| |
| logger.info("Step 2: Initialising experiment tracking...") |
    use_wandb = HAS_WANDB and bool(os.environ.get("WANDB_API_KEY"))
    if use_wandb:
        wandb.init(
            project="dyslexia-rewriter",
            name=f"train-{model_cfg.get('key', 'flan-t5-small')}",
            config=cfg,
        )
    else:
        logger.info("W&B not configured, logging to TensorBoard only")
        os.environ["WANDB_DISABLED"] = "true"
|
|
| |
| logger.info("Step 3: Setting up device (hybrid GPU mode)...") |
| device, gpu_info = _setup_device() |
|
|
| |
| logger.info("Step 4: Loading model and tokenizer...") |
| model_key = model_cfg.get("key", "flan-t5-small") |
| model, tokenizer, is_seq2seq = load_model_and_tokenizer( |
| model_key=model_key, |
| quantize=model_cfg.get("quantize", False), |
| use_lora=model_cfg.get("use_lora", True), |
| lora_config_dict=lora_cfg, |
| ) |
|
|
| |
| if hasattr(model, 'enable_input_require_grads'): |
| model.enable_input_require_grads() |
|
|
| |
| if hasattr(torch, "compile") and device == "cuda": |
| try: |
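            # Compile the forward pass; "default" mode favours fast compilation
            # over maximum speedup. Wrapped in try/except because support
            # varies across models and PyTorch versions.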
| |
| |
| |
| logger.info("Applying torch.compile(mode='default')...") |
| model = torch.compile(model, mode="default") |
| logger.info("✓ torch.compile applied — first few steps will be slower (compiling)") |
| except Exception as e: |
| logger.warning(f"torch.compile failed (non-fatal): {e}") |
|
|
| |
| logger.info("Step 5: Creating style fingerprinter...") |
| fingerprinter = StyleFingerprinter( |
| spacy_model="en_core_web_sm", |
| awl_path="data/awl/coxhead_awl.txt", |
| ) |
|
|
| |
| logger.info("Step 6: Loading datasets...") |
| train_dataset = WritingCorrectionDataset( |
| data_path=data_cfg.get("train_path", "data/processed/train.jsonl"), |
| tokenizer=tokenizer, |
| fingerprinter=fingerprinter, |
| max_input_length=data_cfg.get("max_input_length", 512), |
| max_target_length=data_cfg.get("max_target_length", 512), |
| augment_with_synthetic=data_cfg.get("augment_synthetic", True), |
| synthetic_ratio=data_cfg.get("synthetic_ratio", 0.3), |
| ) |
|
|
| val_dataset = WritingCorrectionDataset( |
| data_path=data_cfg.get("val_path", "data/processed/val.jsonl"), |
| tokenizer=tokenizer, |
| fingerprinter=fingerprinter, |
| max_input_length=data_cfg.get("max_input_length", 512), |
| max_target_length=data_cfg.get("max_target_length", 512), |
| augment_with_synthetic=False, |
| ) |
|
|
| logger.info(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}") |
|
|
| |
| gc.collect() |
| if device == "cuda": |
| torch.cuda.empty_cache() |
|
|
| |
| |
| |
    # Step 7: build the loss function.
    logger.info("Step 7: Building loss function...")
    from torch import nn

| class CEOnlyLoss(nn.Module): |
| """Cross-entropy only loss — the only loss that provides gradient signal.""" |
| def __init__(self): |
| super().__init__() |
| self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100) |
|
|
        def forward(self, logits, labels, **kwargs):
            # Flatten (batch, seq, vocab) to (batch*seq, vocab); reshape also
            # handles non-contiguous logits, where .view() would raise.
            if logits.dim() == 3:
                ce_logits = logits.reshape(-1, logits.size(-1))
                ce_labels = labels.reshape(-1)
            else:
                ce_logits = logits
                ce_labels = labels
            l_ce = self.ce_loss(ce_logits, ce_labels)
            return {"total_loss": l_ce, "ce_loss": l_ce}
|
|
    if use_v2_loss:
        # NOTE: assumes CombinedCorrectionLossV2 accepts the `loss:` config
        # section as keyword arguments; adjust here if its signature differs.
        loss_fn = CombinedCorrectionLossV2(**loss_cfg)
        logger.info("Using CombinedCorrectionLossV2 (--use-v2-loss)")
    else:
        loss_fn = CEOnlyLoss()
        logger.info("Using CE-only loss (aux models skipped to save memory)")
|
|
| |
| logger.info("Step 8: Creating training arguments...") |
|
|
| |
| use_bf16 = False |
| use_fp16 = False |
| if device == "cuda": |
| if gpu_info["compute_cap"][0] >= 8: |
| use_bf16 = True |
| logger.info("Using BF16 (Ampere+ GPU)") |
| else: |
| use_fp16 = True |
| logger.info("Using FP16 (pre-Ampere GPU)") |
| elif device == "cpu": |
| |
        # Cheap capability probe: if bf16 tensor math runs, prefer it over
        # FP32. (This checks the PyTorch build, not the CPU microarchitecture.)
        try:
            test = torch.tensor([1.0], dtype=torch.bfloat16)
            _ = test + test
            use_bf16 = True
            logger.info("Using BF16 on CPU (bf16 ops available in this build)")
        except Exception:
            logger.info("BF16 not supported on this CPU build, using FP32")
|
|
| |
| config_batch = train_cfg.get("per_device_train_batch_size", 4) |
| batch_size = _auto_batch_size(model_key, device, gpu_info, config_batch) |
|
|
| |
| |
| |
| |
| large_models = {"flan-t5-large", "flan-t5-xl", "llama-3.1-8b"} |
| use_grad_ckpt = model_key in large_models and device == "cuda" |
| if use_grad_ckpt: |
| logger.info("Gradient checkpointing: ON (large model, saving VRAM)") |
| else: |
| logger.info(f"Gradient checkpointing: OFF ({'small model fits in VRAM' if device == 'cuda' else 'CPU has plenty of RAM'})") |
|
|
| |
| |
| |
| num_workers = train_cfg.get("dataloader_num_workers", 0) |
|
|
| |
    report_to = []
    if use_wandb:
        report_to.append("wandb")
    report_to.append("tensorboard")
|
|
| training_args = TrainingArguments( |
| output_dir=train_cfg.get("output_dir", "checkpoints/"), |
| num_train_epochs=train_cfg.get("num_train_epochs", 5), |
| per_device_train_batch_size=batch_size, |
| per_device_eval_batch_size=train_cfg.get("per_device_eval_batch_size", 8) if device == "cuda" else 2, |
| gradient_accumulation_steps=train_cfg.get("gradient_accumulation_steps", 8), |
| learning_rate=train_cfg.get("learning_rate", 3e-4), |
| lr_scheduler_type=train_cfg.get("lr_scheduler_type", "cosine"), |
| warmup_ratio=train_cfg.get("warmup_ratio", 0.05), |
| weight_decay=train_cfg.get("weight_decay", 0.01), |
| fp16=use_fp16, |
| bf16=use_bf16, |
| eval_strategy=train_cfg.get("evaluation_strategy", "steps"), |
| eval_steps=train_cfg.get("eval_steps", 100), |
| save_strategy=train_cfg.get("save_strategy", "steps"), |
| save_steps=train_cfg.get("save_steps", 100), |
| save_total_limit=train_cfg.get("save_total_limit", 3), |
        # the best checkpoint is re-loaded manually in Step 11, so skip
        # HF's automatic reload here
        load_best_model_at_end=False,
| metric_for_best_model=train_cfg.get("metric_for_best_model", "eval_loss"), |
| greater_is_better=train_cfg.get("greater_is_better", False), |
| logging_dir=train_cfg.get("logging_dir", "logs/"), |
| logging_steps=train_cfg.get("logging_steps", 25), |
| report_to=report_to, |
| dataloader_num_workers=num_workers, |
| seed=train_cfg.get("seed", 42), |
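        # keep non-model columns (e.g. style features); the custom trainer
        # consumes fields beyond input_ids / labels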
| remove_unused_columns=False, |
| gradient_checkpointing=use_grad_ckpt, |
| ) |
|
|
| |
| logger.info("Step 9: Creating trainer...") |
| trainer = CorrectionTrainer( |
| loss_fn=loss_fn, |
| fingerprinter=fingerprinter, |
| tokenizer=tokenizer, |
| model=model, |
| args=training_args, |
| train_dataset=train_dataset, |
| eval_dataset=val_dataset, |
| callbacks=[ |
| StyleMetricsCallback(), |
| EarlyStoppingOnStyleDrift(min_style_similarity=0.75), |
| ], |
| ) |
|
|
| |
| logger.info("Step 10: Starting training...") |
| logger.info( |
| f"Config summary: model={model_key} | batch={batch_size} | " |
| f"accum={training_args.gradient_accumulation_steps} | " |
| f"effective_batch={batch_size * training_args.gradient_accumulation_steps} | " |
| f"epochs={training_args.num_train_epochs} | " |
| f"precision={'bf16' if use_bf16 else 'fp16' if use_fp16 else 'fp32'} | " |
| f"grad_ckpt={use_grad_ckpt} | device={device}" |
| ) |
| trainer.train() |
|
|
| |
| logger.info("Step 11: Saving best model...") |
| output_dir = train_cfg.get("output_dir", "checkpoints/") |
| save_path = os.path.join(output_dir, "best_model") |
|
|
| |
    # Scan checkpoints for the recorded best_model_checkpoint, sorting
    # numerically by step (a plain lexicographic sort would put
    # checkpoint-1000 before checkpoint-900) so the newest state wins.
    import glob
    import json

    best_ckpt = None
    ckpt_dirs = sorted(
        glob.glob(os.path.join(output_dir, "checkpoint-*")),
        key=lambda d: int(os.path.basename(d).rsplit("-", 1)[-1]),
    )
    for ckpt_dir in ckpt_dirs:
        ts = os.path.join(ckpt_dir, "trainer_state.json")
        if not os.path.exists(ts):
            continue
        with open(ts) as f:
            state = json.load(f)
        best_path = state.get("best_model_checkpoint")
        if best_path:
            best_ckpt = best_path
|
|
| if best_ckpt and os.path.isdir(best_ckpt): |
| logger.info(f"Loading best checkpoint from {best_ckpt}") |
        # Re-apply the best LoRA adapter to the live model before saving.
        best_adapter = os.path.join(best_ckpt, "adapter_model.safetensors")
        if os.path.exists(best_adapter):
            model.load_adapter(best_ckpt, adapter_name="default")
| logger.info(f"Loaded best adapter from {best_ckpt}") |
| else: |
| logger.warning(f"No adapter found at {best_ckpt}, saving current model") |
| else: |
| logger.info("No best checkpoint found, saving final model state") |
|
|
| trainer.save_model(save_path) |
| tokenizer.save_pretrained(save_path) |
| logger.info(f"Model saved to {save_path}") |
|
|
| if HAS_WANDB and wandb.run is not None: |
| wandb.finish() |
|
|
| logger.info("✓ Training complete!") |
|
|
|
|
| if __name__ == "__main__": |
| train() |
|
|