""" Fine-tune Whisper on the prepared dataset. Usage: python scripts/train.py python scripts/train.py --config config/training_config.yaml python scripts/train.py --resume outputs/checkpoints/checkpoint-500 python scripts/train.py --skip-smoke-test # bypass pre-flight check Run prepare_data.py first to build the dataset. Pre-flight smoke test --------------------- By default, a short sanity check runs BEFORE the full training loop. It exercises 2 optimizer steps + 1 evaluation pass on a tiny subset (8 train, 4 eval samples) to verify: - Audio preprocessing and feature extraction work end-to-end. - The fp16/fp32 dtype alignment in the data collator is correct. - A forward + backward pass succeeds without OOM or dtype errors. - Evaluation generation (predict_with_generate) and metric computation work. If the smoke test fails, the error is logged and the script exits before wasting GPU time on a run that would crash hours later. Pass --skip-smoke-test to bypass this check (e.g. when resuming a known-good run). """ from __future__ import annotations import argparse import logging import sys from pathlib import Path import yaml sys.path.insert(0, str(Path(__file__).parent.parent)) from src.data_preparation.build_dataset import load_saved_dataset from src.training.trainer import WhisperFinetuner logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) def main(config_path: str, resume_from: str | None = None, skip_smoke_test: bool = False) -> None: root = Path(__file__).parent.parent with (root / config_path).open() as fh: cfg = yaml.safe_load(fh) processed_dir = root / cfg["data"]["processed_dir"] dataset_path = processed_dir / "hf_dataset" if not dataset_path.exists(): logger.error( "Dataset not found at %s\nRun python scripts/prepare_data.py first.", dataset_path, ) sys.exit(1) logger.info("Loading dataset from %s", dataset_path) dataset = load_saved_dataset(processed_dir) logger.info( "Train examples: %d | Eval examples: %d", len(dataset["train"]), len(dataset["eval"]), ) if resume_from: cfg["training"]["output_dir"] = resume_from finetuner = WhisperFinetuner(cfg, dataset=dataset) finetuner.load_model_and_processor() # ------------------------------------------------------------------ # Pre-flight smoke test: run a micro training loop to catch any # dtype / OOM / config errors BEFORE the multi-hour full training run. # ------------------------------------------------------------------ if skip_smoke_test: logger.warning("Smoke test SKIPPED — proceeding directly to full training (--skip-smoke-test)") else: passed = finetuner.run_smoke_test(dataset) if not passed: logger.error( "Aborting — fix the error reported above and re-run.\n" "If the smoke test environment differs from full training (e.g. different\n" "batch size for OOM), you can bypass it with --skip-smoke-test." ) sys.exit(1) # ------------------------------------------------------------------ # Full training run # ------------------------------------------------------------------ finetuner.train() logger.info("Training complete. Best model: %s/best_model", cfg["training"]["output_dir"]) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Fine-tune Whisper on Arabic speech") parser.add_argument("--config", default="config/training_config.yaml") parser.add_argument( "--resume", default=None, help="Path to a checkpoint directory to resume training from", ) parser.add_argument( "--skip-smoke-test", action="store_true", default=False, help="Skip the pre-flight smoke test and go directly to full training " "(useful when resuming a known-good run)", ) args = parser.parse_args() main(args.config, args.resume, args.skip_smoke_test)