Spaces:
Sleeping
Sleeping
| """ | |
| Fine-tune Whisper on the prepared dataset. | |
| Usage: | |
| python scripts/train.py | |
| python scripts/train.py --config config/training_config.yaml | |
| python scripts/train.py --resume outputs/checkpoints/checkpoint-500 | |
| python scripts/train.py --skip-smoke-test # bypass pre-flight check | |
| Run prepare_data.py first to build the dataset. | |
| Pre-flight smoke test | |
| --------------------- | |
| By default, a short sanity check runs BEFORE the full training loop. It | |
| exercises 2 optimizer steps + 1 evaluation pass on a tiny subset (8 train, | |
| 4 eval samples) to verify: | |
| - Audio preprocessing and feature extraction work end-to-end. | |
| - The fp16/fp32 dtype alignment in the data collator is correct. | |
| - A forward + backward pass succeeds without OOM or dtype errors. | |
| - Evaluation generation (predict_with_generate) and metric computation work. | |
| If the smoke test fails, the error is logged and the script exits before | |
| wasting GPU time on a run that would crash hours later. | |
| Pass --skip-smoke-test to bypass this check (e.g. when resuming a known-good run). | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import yaml | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from src.data_preparation.build_dataset import load_saved_dataset | |
| from src.training.trainer import WhisperFinetuner | |
# Configure root logging once at import time: INFO and above, each record
# rendered as "HH:MM:SS LEVEL    message" (level name left-padded to 8 chars).
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)

# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
def main(config_path: str, resume_from: str | None = None, skip_smoke_test: bool = False) -> None:
    """Load the config and dataset, optionally smoke-test, then run fine-tuning.

    Args:
        config_path: Path to the YAML training config. It is joined onto the
            directory two levels above this file, so a relative path is
            resolved from there (an absolute path is used as-is).
        resume_from: Optional checkpoint directory. When given, it replaces
            ``cfg["training"]["output_dir"]`` before the finetuner is built.
        skip_smoke_test: When true, the pre-flight smoke test is skipped and
            full training starts immediately.

    The process exits with status 1 (via ``sys.exit``) when the config file is
    missing or is not a YAML mapping, when the prepared dataset directory does
    not exist, or when the smoke test reports failure.
    """
    root = Path(__file__).parent.parent
    config_file = root / config_path

    # YAML is UTF-8 by specification; read it explicitly as UTF-8 so the result
    # does not depend on the platform's locale encoding.
    try:
        with config_file.open(encoding="utf-8") as fh:
            cfg = yaml.safe_load(fh)
    except FileNotFoundError:
        # Same reporting style as the missing-dataset case below.
        logger.error("Config file not found at %s", config_file)
        sys.exit(1)

    # An empty file yields None and a scalar document yields a str/int; both
    # would otherwise fail later with an opaque TypeError on cfg["data"].
    if not isinstance(cfg, dict):
        logger.error("Config file %s must contain a YAML mapping at the top level.", config_file)
        sys.exit(1)

    processed_dir = root / cfg["data"]["processed_dir"]
    dataset_path = processed_dir / "hf_dataset"
    if not dataset_path.exists():
        logger.error(
            "Dataset not found at %s\nRun python scripts/prepare_data.py first.",
            dataset_path,
        )
        sys.exit(1)

    logger.info("Loading dataset from %s", dataset_path)
    dataset = load_saved_dataset(processed_dir)
    logger.info(
        "Train examples: %d | Eval examples: %d",
        len(dataset["train"]),
        len(dataset["eval"]),
    )

    if resume_from:
        # NOTE(review): this only points output_dir at the checkpoint path;
        # whether training actually resumes from it is decided inside
        # WhisperFinetuner — confirm against src/training/trainer.py.
        cfg["training"]["output_dir"] = resume_from

    finetuner = WhisperFinetuner(cfg, dataset=dataset)
    finetuner.load_model_and_processor()

    # ------------------------------------------------------------------
    # Pre-flight smoke test: run a micro training loop to catch any
    # dtype / OOM / config errors BEFORE the multi-hour full training run.
    # ------------------------------------------------------------------
    if skip_smoke_test:
        logger.warning("Smoke test SKIPPED — proceeding directly to full training (--skip-smoke-test)")
    else:
        passed = finetuner.run_smoke_test(dataset)
        if not passed:
            logger.error(
                "Aborting — fix the error reported above and re-run.\n"
                "If the smoke test environment differs from full training (e.g. different\n"
                "batch size for OOM), you can bypass it with --skip-smoke-test."
            )
            sys.exit(1)

    # ------------------------------------------------------------------
    # Full training run
    # ------------------------------------------------------------------
    finetuner.train()
    logger.info("Training complete. Best model: %s/best_model", cfg["training"]["output_dir"])
if __name__ == "__main__":
    # Command-line entry point: parse the three supported options and hand
    # them to main() by keyword.
    cli = argparse.ArgumentParser(description="Fine-tune Whisper on Arabic speech")
    cli.add_argument("--config", default="config/training_config.yaml")
    cli.add_argument(
        "--resume",
        help="Path to a checkpoint directory to resume training from",
        default=None,
    )
    cli.add_argument(
        "--skip-smoke-test",
        default=False,
        action="store_true",
        help=(
            "Skip the pre-flight smoke test and go directly to full training "
            "(useful when resuming a known-good run)"
        ),
    )
    opts = cli.parse_args()
    main(
        config_path=opts.config,
        resume_from=opts.resume,
        skip_smoke_test=opts.skip_smoke_test,
    )