# Source: Speach-To-Text/scripts/train.py (MIP-Tech, commit 0db822c: "Deploy to HF Spaces")
"""
Fine-tune Whisper on the prepared dataset.

Usage:
    python scripts/train.py
    python scripts/train.py --config config/training_config.yaml
    python scripts/train.py --resume outputs/checkpoints/checkpoint-500
    python scripts/train.py --skip-smoke-test  # bypass pre-flight check

Run scripts/prepare_data.py first to build the dataset.

Pre-flight smoke test
---------------------
By default, a short sanity check runs BEFORE the full training loop. It
exercises 2 optimizer steps + 1 evaluation pass on a tiny subset (8 train,
4 eval samples) to verify:
- Audio preprocessing and feature extraction work end-to-end.
- The fp16/fp32 dtype alignment in the data collator is correct.
- A forward + backward pass succeeds without OOM or dtype errors.
- Evaluation generation (predict_with_generate) and metric computation work.
If the smoke test fails, the error is logged and the script exits before
wasting GPU time on a run that would crash hours later.
Pass --skip-smoke-test to bypass this check (e.g. when resuming a known-good run).
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import yaml
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.data_preparation.build_dataset import load_saved_dataset
from src.training.trainer import WhisperFinetuner
# Root-logger formatting shared by every message this script emits:
# wall-clock time, padded level name, then the message itself.
_LOG_FORMAT = "%(asctime)s %(levelname)-8s %(message)s"
_LOG_DATE_FORMAT = "%H:%M:%S"

logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)

# Module-level logger used by main().
logger = logging.getLogger(__name__)
def _load_config(config_file: Path) -> dict:
    """Parse the YAML training config at *config_file* and return it as a dict.

    Logs an error and exits with status 1 when the file is missing, is not
    valid YAML, or does not hold a top-level mapping. This mirrors how
    ``main`` handles a missing dataset, so misconfiguration never surfaces
    as a raw traceback.
    """
    try:
        # Explicit UTF-8: the platform default encoding (e.g. cp1252 on
        # Windows) could mis-decode non-ASCII values in the config.
        with config_file.open(encoding="utf-8") as fh:
            cfg = yaml.safe_load(fh)
    except FileNotFoundError:
        logger.error("Config file not found at %s", config_file)
        sys.exit(1)
    except yaml.YAMLError as err:
        logger.error("Could not parse config file %s: %s", config_file, err)
        sys.exit(1)
    if not isinstance(cfg, dict):
        # safe_load returns None for an empty file; everything below indexes
        # cfg as a nested mapping.
        logger.error("Config file %s is empty or is not a YAML mapping", config_file)
        sys.exit(1)
    return cfg


def main(config_path: str, resume_from: str | None = None, skip_smoke_test: bool = False) -> None:
    """Load config and dataset, run the pre-flight smoke test, then fine-tune.

    Args:
        config_path: Path to the YAML config, joined onto the repository root
            (the parent of ``scripts/``). An absolute path is used unchanged.
        resume_from: Optional checkpoint directory. When given, it replaces
            ``cfg["training"]["output_dir"]`` before the finetuner is built.
        skip_smoke_test: If True, skip ``finetuner.run_smoke_test`` and go
            straight to ``finetuner.train()``.

    Exits with status 1 if the config cannot be loaded, the prepared dataset
    directory is missing, or the smoke test reports failure.
    """
    root = Path(__file__).parent.parent
    cfg = _load_config(root / config_path)

    processed_dir = root / cfg["data"]["processed_dir"]
    dataset_path = processed_dir / "hf_dataset"
    if not dataset_path.exists():
        logger.error(
            "Dataset not found at %s\nRun python scripts/prepare_data.py first.",
            dataset_path,
        )
        sys.exit(1)

    logger.info("Loading dataset from %s", dataset_path)
    # presumably load_saved_dataset reads processed_dir/hf_dataset (the path
    # checked above) — verify against src.data_preparation.build_dataset.
    dataset = load_saved_dataset(processed_dir)
    logger.info(
        "Train examples: %d | Eval examples: %d",
        len(dataset["train"]),
        len(dataset["eval"]),
    )

    if resume_from:
        # NOTE(review): this only redirects output_dir to the checkpoint
        # directory; nothing here passes the checkpoint to train(). Confirm
        # that WhisperFinetuner resumes from output_dir. Otherwise this starts
        # a fresh run that writes new checkpoints inside the old one.
        cfg["training"]["output_dir"] = resume_from

    finetuner = WhisperFinetuner(cfg, dataset=dataset)
    finetuner.load_model_and_processor()

    # ------------------------------------------------------------------
    # Pre-flight smoke test: run a micro training loop to catch any
    # dtype / OOM / config errors BEFORE the multi-hour full training run.
    # ------------------------------------------------------------------
    if skip_smoke_test:
        logger.warning("Smoke test SKIPPED — proceeding directly to full training (--skip-smoke-test)")
    elif not finetuner.run_smoke_test(dataset):
        logger.error(
            "Aborting — fix the error reported above and re-run.\n"
            "If the smoke test environment differs from full training (e.g. different\n"
            "batch size for OOM), you can bypass it with --skip-smoke-test."
        )
        sys.exit(1)

    # ------------------------------------------------------------------
    # Full training run
    # ------------------------------------------------------------------
    finetuner.train()
    logger.info("Training complete. Best model: %s/best_model", cfg["training"]["output_dir"])
if __name__ == "__main__":
    # Command-line entry point: parse flags and delegate to main().
    parser = argparse.ArgumentParser(description="Fine-tune Whisper on Arabic speech")
    parser.add_argument(
        "--config",
        default="config/training_config.yaml",
        # main() joins this onto the repository root, not the current
        # working directory, so say so in --help.
        help="Path to the training YAML config, relative to the repository root "
        "(default: %(default)s)",
    )
    parser.add_argument(
        "--resume",
        default=None,
        help="Path to a checkpoint directory to resume training from",
    )
    parser.add_argument(
        "--skip-smoke-test",
        action="store_true",  # store_true already defaults to False
        help="Skip the pre-flight smoke test and go directly to full training "
        "(useful when resuming a known-good run)",
    )
    args = parser.parse_args()
    main(args.config, resume_from=args.resume, skip_smoke_test=args.skip_smoke_test)