File size: 4,221 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Fine-tune Whisper on the prepared dataset.

Usage:
    python scripts/train.py
    python scripts/train.py --config config/training_config.yaml
    python scripts/train.py --resume outputs/checkpoints/checkpoint-500
    python scripts/train.py --skip-smoke-test   # bypass pre-flight check

Run prepare_data.py first to build the dataset.

Pre-flight smoke test
---------------------
By default, a short sanity check runs BEFORE the full training loop.  It
exercises 2 optimizer steps + 1 evaluation pass on a tiny subset (8 train,
4 eval samples) to verify:
  - Audio preprocessing and feature extraction work end-to-end.
  - The fp16/fp32 dtype alignment in the data collator is correct.
  - A forward + backward pass succeeds without OOM or dtype errors.
  - Evaluation generation (predict_with_generate) and metric computation work.

If the smoke test fails, the error is logged and the script exits before
wasting GPU time on a run that would crash hours later.
Pass --skip-smoke-test to bypass this check (e.g. when resuming a known-good run).
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import yaml

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.data_preparation.build_dataset import load_saved_dataset
from src.training.trainer import WhisperFinetuner

# Root logging is configured at import time so every message emitted by this
# script shares the same "HH:MM:SS  LEVEL  message" layout.
_LOGGING_OPTIONS = {
    "datefmt": "%H:%M:%S",
    "format": "%(asctime)s  %(levelname)-8s  %(message)s",
    "level": logging.INFO,
}
logging.basicConfig(**_LOGGING_OPTIONS)

# Module-level logger used by main() below.
logger = logging.getLogger(name=__name__)


def main(config_path: str, resume_from: str | None = None, skip_smoke_test: bool = False) -> None:
    """Load the config and prepared dataset, optionally smoke-test, then train.

    Args:
        config_path: Path to the YAML training config, joined onto the
            repository root (the parent of this script's directory).
        resume_from: Optional checkpoint directory. When given, it replaces
            ``cfg["training"]["output_dir"]`` before the fine-tuner is built.
        skip_smoke_test: When true, ``run_smoke_test`` is not called and the
            full training run starts immediately.

    Exits the process with status 1 (via ``sys.exit``) when the prepared
    dataset directory is missing or when the smoke test reports failure.
    """
    root = Path(__file__).parent.parent

    # Decode explicitly as UTF-8 so non-ASCII config values load the same on
    # every platform instead of depending on the locale's default encoding.
    with (root / config_path).open(encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    processed_dir = root / cfg["data"]["processed_dir"]
    dataset_path = processed_dir / "hf_dataset"

    # Fail fast with an actionable message rather than a loader traceback.
    if not dataset_path.exists():
        logger.error(
            "Dataset not found at %s\nRun  python scripts/prepare_data.py  first.",
            dataset_path,
        )
        sys.exit(1)

    logger.info("Loading dataset from %s", dataset_path)
    # The loader receives the parent directory, not dataset_path itself.
    dataset = load_saved_dataset(processed_dir)
    logger.info(
        "Train examples: %d  |  Eval examples: %d",
        len(dataset["train"]),
        len(dataset["eval"]),
    )

    # NOTE(review): resuming is expressed by pointing output_dir at the
    # checkpoint path; how WhisperFinetuner picks the run up from there is not
    # visible in this file — confirm against src/training/trainer.py.
    if resume_from:
        cfg["training"]["output_dir"] = resume_from

    finetuner = WhisperFinetuner(cfg, dataset=dataset)
    finetuner.load_model_and_processor()

    # ------------------------------------------------------------------
    # Pre-flight smoke test: run a micro training loop to catch any
    # dtype / OOM / config errors BEFORE the multi-hour full training run.
    # ------------------------------------------------------------------
    if skip_smoke_test:
        logger.warning("Smoke test SKIPPED — proceeding directly to full training (--skip-smoke-test)")
    else:
        passed = finetuner.run_smoke_test(dataset)
        if not passed:
            logger.error(
                "Aborting — fix the error reported above and re-run.\n"
                "If the smoke test environment differs from full training (e.g. different\n"
                "batch size for OOM), you can bypass it with --skip-smoke-test."
            )
            sys.exit(1)

    # ------------------------------------------------------------------
    # Full training run
    # ------------------------------------------------------------------
    finetuner.train()

    logger.info("Training complete. Best model: %s/best_model", cfg["training"]["output_dir"])


if __name__ == "__main__":
    # Command-line entry point: declare the options as data, register them in
    # one pass, then hand the parsed values to main().
    cli = argparse.ArgumentParser(description="Fine-tune Whisper on Arabic speech")

    option_specs = (
        ("--config", {"default": "config/training_config.yaml"}),
        (
            "--resume",
            {
                "default": None,
                "help": "Path to a checkpoint directory to resume training from",
            },
        ),
        (
            "--skip-smoke-test",
            {
                "action": "store_true",
                "default": False,
                "help": (
                    "Skip the pre-flight smoke test and go directly to full training "
                    "(useful when resuming a known-good run)"
                ),
            },
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()
    main(
        config_path=opts.config,
        resume_from=opts.resume,
        skip_smoke_test=opts.skip_smoke_test,
    )