Spaces:
Sleeping
Sleeping
| """ | |
| Fine-tune openai/whisper-large-v3 on Arabic (Egyptian) speech using LoRA. | |
| Key design decisions: | |
| - LoRA targets q/k/v/out_proj + fc1/fc2 in both encoder and decoder for | |
| maximum dialect adaptation with minimal VRAM overhead. | |
| - Training prepare_fn applies speed perturbation + Gaussian noise; the eval | |
| prepare_fn runs the same pipeline without augmentation so metrics are clean. | |
| - SpecAugment is applied inside the DataCollator on every training step | |
| (checked via model.training) so it is freshly random each batch rather than | |
| being cached to disk like map()-applied augmentations. | |
| - Evaluation reports both CER (primary, more reliable for Arabic morphology) | |
| and WER (secondary, for comparison with published baselines). | |
| - generation_config.language/task lock the decoder to Arabic transcription at | |
| every step; the checkpoint's legacy forced_decoder_ids are cleared. | |
| """ | |
| from __future__ import annotations | |
| import io | |
| import logging | |
| import platform | |
| import shutil | |
| import sys | |
| import tempfile | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional | |
| import evaluate | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| from datasets import Audio, DatasetDict | |
| from peft import LoraConfig, get_peft_model | |
| from transformers import ( | |
| EarlyStoppingCallback, | |
| Seq2SeqTrainer, | |
| Seq2SeqTrainingArguments, | |
| WhisperForConditionalGeneration, | |
| WhisperProcessor, | |
| ) | |
| from src.data_preparation.augmentation import ( | |
| apply_spec_augment, | |
| maybe_apply_noise, | |
| maybe_apply_speed, | |
| ) | |
| from src.data_preparation.parse_transcripts import normalize_arabic | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Data collator | |
| # --------------------------------------------------------------------------- | |
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    """
    Pad a batch of (input_features, labels) pairs into model-ready tensors.

    The class is a dataclass so that it can be built with keyword arguments
    (``processor=...``, ``decoder_start_token_id=...``, ``model=...``,
    ``spec_augment_config=...``); without the decorator the annotated fields
    below would not generate an ``__init__``.

    Handles two important correctness issues:
      1. dtype alignment: feature_extractor always returns float32, but the model
         may be loaded in float16 (GPU). During eval the AMP autocast context is
         NOT active for generate(), so we must cast input_features to the model
         dtype here — otherwise conv1 gets float32 inputs with float16 bias and
         raises "Input type (float) and bias type (Half) should be the same".
      2. SpecAugment: applied only during training (model.training == True) so
         eval metrics are computed on clean, un-augmented features.

    Attributes:
        processor: Object exposing ``feature_extractor.pad`` and ``tokenizer.pad``.
        decoder_start_token_id: Token id stripped from the front of the labels
            when every sequence in the batch starts with it.
        model: Optional model; used only to read its parameter dtype and its
            ``training`` flag.
        spec_augment_config: Optional dict with ``enabled`` plus the SpecAugment
            mask parameters; ``None`` disables SpecAugment.
    """

    processor: Any
    decoder_start_token_id: int
    model: Any = field(default=None, repr=False)
    spec_augment_config: Optional[Dict] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into ``input_features``, ``attention_mask`` and ``labels``."""
        # --- Pad mel-spectrogram input features ---
        input_features = [{"input_features": f["input_features"]} for f in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # Cast input_features to the model's dtype. The feature extractor
        # produces float32 while the model may be float16; generate() runs
        # outside the autocast context during evaluation, so the cast has to
        # happen here to cover both the training and the evaluation path.
        if self.model is not None:
            model_dtype = next(self.model.parameters()).dtype
            if batch["input_features"].dtype != model_dtype:
                batch["input_features"] = batch["input_features"].to(dtype=model_dtype)

        # Provide an explicit all-ones encoder attention mask of shape
        # (batch, time frames). Without it generate() tries to infer the mask
        # from pad_token_id, which equals eos_token_id in Whisper, and warns
        # that the attention mask cannot be inferred.
        batch["attention_mask"] = torch.ones(
            batch["input_features"].shape[0],  # batch size
            batch["input_features"].shape[2],  # time frames
            dtype=torch.long,
        )

        # --- Pad label token sequences; mask padding with -100 (ignored in loss) ---
        label_features = [{"input_ids": f["labels"]} for f in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100
        )

        # Remove the leading start token that the tokenizer inserts.
        # Seq2SeqTrainer shifts labels internally; keeping it causes an
        # off-by-one error.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels

        # --- SpecAugment (training only, applied on every step for fresh randomness) ---
        if (
            self.spec_augment_config
            and self.spec_augment_config.get("enabled", False)
            and self.model is not None
            and self.model.training
        ):
            batch["input_features"] = apply_spec_augment(
                batch["input_features"],
                time_mask_param=self.spec_augment_config.get("time_mask_param", 80),
                freq_mask_param=self.spec_augment_config.get("freq_mask_param", 27),
                num_time_masks=self.spec_augment_config.get("num_time_masks", 2),
                num_freq_masks=self.spec_augment_config.get("num_freq_masks", 2),
            )
        return batch
| # --------------------------------------------------------------------------- | |
| # Feature extraction (with optional audio augmentation) | |
| # --------------------------------------------------------------------------- | |
def make_prepare_fn(
    processor: WhisperProcessor,
    augment_config: Optional[Dict] = None,
):
    """
    Build the per-example callback handed to ``Dataset.map``.

    The returned function decodes the example's audio with soundfile, mixes it
    down to mono float32, optionally augments it, and stores the extracted
    ``input_features`` and tokenized ``labels`` back on the example.

    Speed perturbation and noise are applied only when ``augment_config`` is
    given and its ``enabled`` flag is truthy; pass ``None`` for eval/test.
    """
    settings = augment_config or {}
    aug_enabled = augment_config is not None and settings.get("enabled", False)
    speed_cfg = settings.get("speed_perturbation", {})
    noise_cfg = settings.get("noise", {})

    def prepare_dataset(batch):
        audio_entry = batch["audio"]

        # Manual decode with soundfile (avoids the torchcodec dependency):
        # prefer in-memory bytes, otherwise fall back to the file path.
        if audio_entry.get("bytes"):
            source = io.BytesIO(audio_entry["bytes"])
        else:
            source = audio_entry["path"]
        waveform, sampling_rate = sf.read(source)

        # Collapse stereo / multi-channel audio to a single channel.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)
        waveform = waveform.astype(np.float32)

        # Audio-level augmentation (training split only).
        if aug_enabled:
            waveform = maybe_apply_speed(waveform, sampling_rate, speed_cfg)
            waveform = maybe_apply_noise(waveform, noise_cfg)

        extracted = processor.feature_extractor(
            waveform,
            sampling_rate=sampling_rate,
        )
        batch["input_features"] = extracted.input_features[0]
        batch["labels"] = processor.tokenizer(batch["sentence"]).input_ids
        return batch

    return prepare_dataset
| # --------------------------------------------------------------------------- | |
| # Evaluation metrics: CER (primary) + WER (secondary) | |
| # --------------------------------------------------------------------------- | |
def make_compute_metrics_fn(processor: WhisperProcessor):
    """
    Return a compute_metrics function that reports both CER and WER.

    Both predictions and references are normalized with normalize_arabic()
    before scoring so that metric values reflect real transcription quality
    rather than superficial differences in diacritics or punctuation.
    CER is the primary metric for Arabic because Arabic morphology causes
    word-boundary tokenization to be unreliable for WER comparisons.

    Args:
        processor: Processor whose ``tokenizer`` is used to decode token ids.

    Returns:
        A callable taking the trainer's prediction object (``predictions`` and
        ``label_ids`` attributes) and returning ``{"cer": ..., "wer": ...}``,
        both as percentages.
    """
    wer_metric = evaluate.load("wer")
    cer_metric = evaluate.load("cer")

    def compute_metrics(pred):
        tokenizer = processor.tokenizer
        pad_id = tokenizer.pad_token_id

        pred_ids = pred.predictions
        # Some models hand back a tuple whose first element holds the tokens.
        if isinstance(pred_ids, tuple):
            pred_ids = pred_ids[0]

        # Swap the -100 sentinel for the pad token so the tokenizer can decode.
        # np.where builds new arrays, leaving the caller's label_ids untouched;
        # predictions are sanitised too because batches of different generated
        # length can be padded with -100 when they are concatenated.
        pred_ids = np.asarray(pred_ids)
        label_ids = np.asarray(pred.label_ids)
        pred_ids = np.where(pred_ids == -100, pad_id, pred_ids)
        label_ids = np.where(label_ids == -100, pad_id, label_ids)

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

        # Score normalized text on both sides.
        pred_str = [normalize_arabic(s) for s in pred_str]
        label_str = [normalize_arabic(s) for s in label_str]

        wer = 100 * wer_metric.compute(predictions=pred_str, references=label_str)
        cer = 100 * cer_metric.compute(predictions=pred_str, references=label_str)
        return {"cer": cer, "wer": wer}

    return compute_metrics
| # --------------------------------------------------------------------------- | |
| # Main trainer class | |
| # --------------------------------------------------------------------------- | |
class WhisperFinetuner:
    """
    Drive Whisper fine-tuning (optionally with LoRA) from a nested config dict.

    Typical order of calls: ``load_model_and_processor()``, optionally
    ``run_smoke_test(raw_dataset)``, then ``train()``. ``train()`` loads the
    model itself when that has not happened yet.

    The config is read from these sections: ``model`` (base_model, language,
    task), ``training`` (output_dir plus trainer options), and the optional
    ``lora`` and ``augmentation`` sections.
    """

    # Raw columns dropped once mel features and label ids have been extracted.
    _RAW_COLUMNS = ("audio", "sentence", "duration", "source_audio")

    def __init__(self, config: dict, dataset: Optional[DatasetDict] = None):
        """Store the config/dataset and create the output directory."""
        self.cfg = config
        self.dataset = dataset
        self.model_name = config["model"]["base_model"]
        self.language = config["model"]["language"]
        self.task = config["model"]["task"]
        self.output_dir = Path(config["training"]["output_dir"])
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.processor: Optional[WhisperProcessor] = None
        self.model: Optional[WhisperForConditionalGeneration] = None

    # ------------------------------------------------------------------
    # Setup
    # ------------------------------------------------------------------
    def _apply_lora(self) -> None:
        """Wrap the model with LoRA adapters based on config."""
        lora_cfg = self.cfg.get("lora", {})
        if not lora_cfg.get("enabled", True):
            total_params = sum(p.numel() for p in self.model.parameters())
            logger.info("LoRA disabled — all %d parameters (%.1f M) will be trained",
                        total_params, total_params / 1e6)
            return

        r = lora_cfg.get("r", 32)
        lora_alpha = lora_cfg.get("lora_alpha", 64)
        lora_dropout = lora_cfg.get("lora_dropout", 0.05)
        bias = lora_cfg.get("bias", "none")
        target_modules = lora_cfg.get(
            "target_modules",
            ["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2"],
        )
        logger.info("Applying LoRA adapters:")
        logger.info(" rank (r) : %d", r)
        logger.info(" lora_alpha : %d (effective scale = alpha/r = %.1f)", lora_alpha, lora_alpha / r)
        logger.info(" dropout : %.2f", lora_dropout)
        logger.info(" bias : %s", bias)
        logger.info(" target modules : %s", target_modules)

        # Do NOT set task_type=SEQ_2_SEQ_LM — Whisper uses input_features (not
        # input_ids) for its encoder; PeftModelForSeq2SeqLM injects a duplicate
        # input_ids kwarg and crashes. task_type omitted keeps the base PeftModel
        # wrapper which passes all kwargs through unchanged.
        lora_config = LoraConfig(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=lora_dropout,
            bias=bias,
        )
        self.model = get_peft_model(self.model, lora_config)

        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        logger.info("LoRA applied successfully:")
        logger.info(" trainable parameters : %d (%.2f%%)", trainable, 100 * trainable / total)
        logger.info(" frozen parameters : %d (%.2f%%)", total - trainable, 100 * (total - trainable) / total)
        logger.info(" total parameters : %d (%.1f M)", total, total / 1e6)

    def load_model_and_processor(self) -> None:
        """Load processor and model, configure generation, then apply LoRA."""
        logger.info("=" * 60)
        logger.info("STEP 1/3 — LOADING PROCESSOR")
        logger.info(" model : %s", self.model_name)
        logger.info(" language: %s", self.language)
        logger.info(" task : %s", self.task)
        self.processor = WhisperProcessor.from_pretrained(
            self.model_name,
            language=self.language,
            task=self.task,
        )
        vocab_size = self.processor.tokenizer.vocab_size
        logger.info("Processor ready — vocabulary size: %d tokens", vocab_size)

        # Decide dtype based on hardware
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            gpu_name = torch.cuda.get_device_name(0)
            vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
            logger.info("GPU detected: %s (%.1f GB VRAM total)", gpu_name, vram_gb)
        else:
            logger.warning(
                "No CUDA GPU detected — training will run on CPU. "
                "whisper-large-v3 on CPU is extremely slow (hours per epoch). "
                "Consider running on a machine with a CUDA-capable GPU."
            )

        # Load in fp16 on GPU (halves VRAM vs fp32); keep fp32 on CPU where fp16
        # training support is limited.
        load_dtype = torch.float16 if use_cuda else torch.float32
        logger.info("Loading model weights in %s ...", "float16 (GPU — halves VRAM)" if use_cuda else "float32 (CPU)")
        self.model = WhisperForConditionalGeneration.from_pretrained(
            self.model_name, torch_dtype=load_dtype
        )
        total_params = sum(p.numel() for p in self.model.parameters())
        logger.info("Model loaded: %d parameters (%.1f M)", total_params, total_params / 1e6)

        # ── Decoder configuration ──────────────────────────────────────────
        # 1. Language / task: set on generation_config so every generate() call
        #    decodes in the configured language/task without any extra kwarg.
        self.model.generation_config.language = self.language
        self.model.generation_config.task = self.task
        logger.info("Generation config: language='%s', task='%s'", self.language, self.task)

        # 2. forced_decoder_ids = None: older Whisper checkpoints ship with a
        #    pre-built forced_decoder_ids list. Clearing it makes generate()
        #    rely on generation_config.language/task instead, which avoids the
        #    fp32 language-detection pass that crashes on fp16 models.
        self.model.generation_config.forced_decoder_ids = None
        self.model.config.forced_decoder_ids = None
        logger.info("forced_decoder_ids cleared — language detection pass disabled "
                    "(prevents fp16/fp32 dtype crash during generate())")

        # 3. suppress_tokens / begin_suppress_tokens = None: Whisper's own
        #    generate() builds these logits processors itself; leaving the
        #    fields set makes the generic generate() create duplicates and warn.
        self.model.config.suppress_tokens = None
        self.model.generation_config.suppress_tokens = None
        self.model.generation_config.begin_suppress_tokens = None
        logger.info("suppress_tokens / begin_suppress_tokens cleared from generation_config "
                    "(Whisper's generate() builds these processors internally — "
                    "clearing prevents duplicate-processor warnings)")

        # gradient_checkpointing must be enabled before LoRA wrapping.
        # enable_input_require_grads() ensures LoRA leaf tensors receive
        # gradients through the frozen backbone.
        # NOTE: gradient checkpointing is NOT gated on CUDA — it trades
        # recomputation for memory and is equally valid on CPU.
        if self.cfg["training"].get("gradient_checkpointing", False):
            self.model.config.use_cache = False
            self.model.gradient_checkpointing_enable()
            self.model.enable_input_require_grads()
            logger.info("Gradient checkpointing enabled — activations recomputed on backward pass to save memory")
        else:
            logger.info("Gradient checkpointing disabled — all activations kept in memory")

        self._apply_lora()

    def _make_prepare_fns(self):
        """Return ``(train_fn, eval_fn)``; only the train function may augment."""
        aug_config = self.cfg.get("augmentation", None)
        aug_enabled = aug_config is not None and aug_config.get("enabled", False)
        train_fn = make_prepare_fn(self.processor, augment_config=aug_config if aug_enabled else None)
        eval_fn = make_prepare_fn(self.processor, augment_config=None)
        return train_fn, eval_fn

    @classmethod
    def _columns_to_remove(cls, split) -> List[str]:
        """Return the raw columns that actually exist on ``split``.

        ``map(remove_columns=...)`` raises for unknown columns, so optional
        metadata columns are only listed when the split carries them.
        """
        present = set(split.column_names)
        return [name for name in cls._RAW_COLUMNS if name in present]

    def _log_augmentation_plan(self) -> None:
        """Log which training-time augmentations the config enables."""
        aug_config = self.cfg.get("augmentation", None)
        aug_enabled = aug_config is not None and aug_config.get("enabled", False)
        if not aug_enabled:
            logger.info("Training augmentation: DISABLED — raw audio used as-is")
        else:
            speed_cfg = aug_config.get("speed_perturbation", {})
            noise_cfg = aug_config.get("noise", {})
            spec_cfg = aug_config.get("spec_augment", {})
            logger.info("Training augmentation: ENABLED")
            if speed_cfg.get("enabled", False):
                logger.info(" speed perturbation : factors=%s, probability=%.0f%%",
                            speed_cfg.get("factors", [0.9, 0.95, 1.05, 1.1]),
                            100 * speed_cfg.get("probability", 0.3))
            else:
                logger.info(" speed perturbation : disabled")
            if noise_cfg.get("enabled", False):
                logger.info(" noise injection : SNR=[%.0f–%.0f] dB, probability=%.0f%%",
                            noise_cfg.get("min_snr_db", 15.0),
                            noise_cfg.get("max_snr_db", 30.0),
                            100 * noise_cfg.get("probability", 0.3))
            else:
                logger.info(" noise injection : disabled")
            if spec_cfg.get("enabled", False):
                logger.info(" SpecAugment : time_mask=%d, freq_mask=%d, "
                            "num_time=%d, num_freq=%d (applied per step, not cached)",
                            spec_cfg.get("time_mask_param", 80),
                            spec_cfg.get("freq_mask_param", 27),
                            spec_cfg.get("num_time_masks", 2),
                            spec_cfg.get("num_freq_masks", 2))
            else:
                logger.info(" SpecAugment : disabled")
        logger.info("Eval/test augmentation: always DISABLED — clean audio for accurate metrics")

    def prepare_datasets(self) -> DatasetDict:
        """
        Extract audio features and tokenize text labels for each split.

        The training split is mapped with the augmenting prepare function (when
        augmentation is enabled in the config); eval and test splits are mapped
        without augmentation so their metrics are not affected by it.

        Returns:
            A DatasetDict with ``train`` and ``eval`` splits, plus ``test`` when
            the source dataset has one.
        """
        assert self.processor is not None, "Call load_model_and_processor() first"
        assert self.dataset is not None, "No dataset provided"

        logger.info("=" * 60)
        logger.info("STEP 2/3 — PREPARING DATASETS")
        logger.info(" train split : %d samples", len(self.dataset["train"]))
        logger.info(" eval split : %d samples", len(self.dataset["eval"]))
        if "test" in self.dataset:
            logger.info(" test split : %d samples", len(self.dataset["test"]))
        else:
            logger.info(" test split : not present")

        self._log_augmentation_plan()
        train_prepare_fn, eval_prepare_fn = self._make_prepare_fns()

        # Disable torchcodec-based decoding; audio is decoded manually in prepare_fn
        logger.info("Disabling HuggingFace auto-decode on audio column (manual decode via soundfile)")
        dataset = self.dataset.cast_column("audio", Audio(decode=False))
        logger.info("Columns to remove after feature extraction: %s", list(self._RAW_COLUMNS))

        logger.info("Processing training split — extracting mel features + tokenizing labels ...")
        train_processed = dataset["train"].map(
            train_prepare_fn,
            remove_columns=self._columns_to_remove(dataset["train"]),
            num_proc=1,
        )
        logger.info("Training split done: %d examples → columns: %s",
                    len(train_processed), train_processed.column_names)

        logger.info("Processing eval split ...")
        eval_processed = dataset["eval"].map(
            eval_prepare_fn,
            remove_columns=self._columns_to_remove(dataset["eval"]),
            num_proc=1,
        )
        logger.info("Eval split done: %d examples → columns: %s",
                    len(eval_processed), eval_processed.column_names)

        processed = DatasetDict({"train": train_processed, "eval": eval_processed})

        # Include the held-out test split if present
        if "test" in dataset:
            logger.info("Processing test split ...")
            processed["test"] = dataset["test"].map(
                eval_prepare_fn,
                remove_columns=self._columns_to_remove(dataset["test"]),
                num_proc=1,
            )
            logger.info("Test split done: %d examples", len(processed["test"]))

        logger.info("All splits prepared successfully")
        return processed

    # ------------------------------------------------------------------
    # Private helpers shared by smoke_test and train
    # ------------------------------------------------------------------
    def _build_data_collator(self) -> DataCollatorSpeechSeq2SeqWithPadding:
        """Instantiate the data collator from current config and model."""
        assert self.processor is not None and self.model is not None
        aug_config = self.cfg.get("augmentation", None)
        spec_aug_cfg = aug_config.get("spec_augment", None) if aug_config else None
        return DataCollatorSpeechSeq2SeqWithPadding(
            processor=self.processor,
            decoder_start_token_id=self.model.config.decoder_start_token_id,
            model=self.model,
            spec_augment_config=spec_aug_cfg,
        )

    def _prepare_raw_subset(self, raw_dataset: DatasetDict, n_train: int, n_eval: int) -> DatasetDict:
        """
        Run feature extraction on a small subset of the raw (un-processed) dataset.

        Used by the smoke test to avoid processing the full dataset. The
        requested sizes are clamped to the number of available examples.
        """
        assert self.processor is not None, "Call load_model_and_processor() first"
        n_train = min(n_train, len(raw_dataset["train"]))
        n_eval = min(n_eval, len(raw_dataset["eval"]))
        train_fn, eval_fn = self._make_prepare_fns()

        tiny = raw_dataset.cast_column("audio", Audio(decode=False))
        tiny_train = tiny["train"].select(range(n_train))
        tiny_eval = tiny["eval"].select(range(n_eval))
        return DatasetDict({
            "train": tiny_train.map(
                train_fn, remove_columns=self._columns_to_remove(tiny_train), num_proc=1
            ),
            "eval": tiny_eval.map(
                eval_fn, remove_columns=self._columns_to_remove(tiny_eval), num_proc=1
            ),
        })

    def _snapshot_trainable_parameters(self) -> Dict[str, torch.Tensor]:
        """Copy every trainable parameter to CPU, keyed by parameter name."""
        return {
            name: param.detach().cpu().clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    def _restore_trainable_parameters(self, snapshot: Dict[str, torch.Tensor]) -> None:
        """Write a snapshot from ``_snapshot_trainable_parameters`` back into the model."""
        if not snapshot:
            return
        with torch.no_grad():
            for name, param in self.model.named_parameters():
                saved = snapshot.get(name)
                if saved is not None:
                    # copy_ handles the device/dtype transfer back to the parameter.
                    param.copy_(saved)

    # ------------------------------------------------------------------
    # Smoke test — run before full training to catch errors early
    # ------------------------------------------------------------------
    def run_smoke_test(
        self,
        raw_dataset: DatasetDict,
        n_train: int = 8,
        n_eval: int = 4,
    ) -> bool:
        """
        Run a micro training loop (2 optimizer steps + 1 evaluation) on a tiny
        subset of the raw dataset to verify that the pipeline is functional
        before committing to a full training run.

        Checks that:
          - Audio preprocessing + feature extraction work end-to-end.
          - The data collator dtype cast is correct (fp16/fp32 alignment).
          - The model can execute a forward + backward pass.
          - Evaluation generation (predict_with_generate) completes successfully.
          - Metric computation (CER/WER) runs without errors.

        The trainable weights are snapshotted before the two optimizer steps
        and restored afterwards, so the smoke test does not alter the weights
        the real training run starts from.

        Args:
            raw_dataset: The original (un-processed) DatasetDict with raw audio.
            n_train: Number of training samples to include in the smoke test.
            n_eval: Number of eval samples to include in the smoke test.

        Returns:
            True  — smoke test passed; safe to start full training.
            False — smoke test failed; error details are logged.
        """
        assert self.processor is not None and self.model is not None, \
            "Call load_model_and_processor() first"

        logger.info("=" * 60)
        logger.info("SMOKE TEST — pre-flight check before full training")
        logger.info(" train samples : %d (from %d total)", n_train, len(raw_dataset["train"]))
        logger.info(" eval samples : %d (from %d total)", n_eval, len(raw_dataset["eval"]))
        logger.info(" steps : 2 optimizer steps + 1 evaluation pass")
        logger.info(" purpose : verify dtype alignment, forward/backward pass, "
                    "generate(), and metric computation")

        smoke_dir = Path(tempfile.mkdtemp(prefix="whisper_smoke_"))
        snapshot: Dict[str, torch.Tensor] = {}
        try:
            # Taken inside the try so a failure here is reported like any
            # other smoke-test failure instead of escaping to the caller.
            snapshot = self._snapshot_trainable_parameters()

            logger.info("Preparing tiny smoke-test dataset ...")
            tiny_processed = self._prepare_raw_subset(raw_dataset, n_train, n_eval)
            logger.info("Tiny dataset ready — train=%d, eval=%d",
                        len(tiny_processed["train"]), len(tiny_processed["eval"]))

            use_cuda = torch.cuda.is_available()
            fp16_active = self.cfg["training"].get("fp16", False) and use_cuda
            gc_enabled = self.cfg["training"].get("gradient_checkpointing", False)
            gen_max_len = self.cfg["training"].get("generation_max_length", 225)

            smoke_args = Seq2SeqTrainingArguments(
                output_dir=str(smoke_dir),
                # Exactly 2 optimizer steps — enough to exercise forward,
                # backward and the optimizer update.
                max_steps=2,
                per_device_train_batch_size=1,
                per_device_eval_batch_size=1,
                gradient_accumulation_steps=1,
                fp16=fp16_active,
                gradient_checkpointing=gc_enabled,
                predict_with_generate=True,
                generation_max_length=gen_max_len,
                # No in-loop evaluation: the single evaluation pass is the
                # explicit evaluate() call below, so generate() runs once.
                eval_strategy="no",
                save_strategy="no",        # no checkpoints for smoke test
                logging_steps=1,
                report_to="none",          # never report smoke-test metrics
                dataloader_num_workers=0,
                dataloader_pin_memory=use_cuda,
                torch_empty_cache_steps=None,
                remove_unused_columns=False,
            )
            smoke_trainer = Seq2SeqTrainer(
                model=self.model,
                args=smoke_args,
                train_dataset=tiny_processed["train"],
                eval_dataset=tiny_processed["eval"],
                data_collator=self._build_data_collator(),
                compute_metrics=make_compute_metrics_fn(self.processor),
                processing_class=self.processor.feature_extractor,
            )

            logger.info("Running 2 training steps ...")
            smoke_trainer.train()
            logger.info("Running evaluation pass ...")
            metrics = smoke_trainer.evaluate()

            logger.info("=" * 60)
            logger.info("SMOKE TEST PASSED")
            logger.info(" eval metrics: %s", {k: f"{v:.4f}" if isinstance(v, float) else v
                                              for k, v in metrics.items()})
            logger.info("Full training run is safe to proceed.")
            logger.info("=" * 60)
            return True
        except Exception as exc:
            logger.error("=" * 60)
            logger.error("SMOKE TEST FAILED — full training will NOT start")
            logger.error("Error type : %s", type(exc).__name__)
            logger.error("Error msg : %s", exc)
            logger.error("Full traceback:", exc_info=True)
            logger.error("Fix the error above, then re-run training.")
            logger.error("=" * 60)
            return False
        finally:
            try:
                # Undo the two throw-away optimizer steps.
                self._restore_trainable_parameters(snapshot)
            finally:
                shutil.rmtree(smoke_dir, ignore_errors=True)
                logger.debug("Smoke-test temp dir cleaned up: %s", smoke_dir)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def train(self, dataset: Optional[DatasetDict] = None) -> None:
        """
        Prepare the data, run the full training loop and save the result.

        Args:
            dataset: Optional raw DatasetDict; replaces the one given to the
                constructor when provided.

        The final model and the processor are written to
        ``<output_dir>/best_model``.
        """
        if dataset is not None:
            self.dataset = dataset
        if self.model is None or self.processor is None:
            self.load_model_and_processor()

        processed = self.prepare_datasets()
        data_collator = self._build_data_collator()
        compute_metrics = make_compute_metrics_fn(self.processor)

        use_cuda = torch.cuda.is_available()
        t = self.cfg["training"]
        on_windows = platform.system() == "Windows"
        num_workers = 0 if on_windows else t.get("dataloader_num_workers", 4)

        # Compute training shape for informational logging
        train_samples = len(processed["train"])
        batch_size = t.get("per_device_train_batch_size", 2)
        grad_accum = t.get("gradient_accumulation_steps", 8)
        effective_batch = batch_size * grad_accum
        max_epochs = t.get("num_train_epochs", 5)
        steps_per_epoch = max(1, train_samples // effective_batch)
        total_steps_estimate = steps_per_epoch * max_epochs
        warmup_steps = t.get("warmup_steps", 500)
        early_patience = t.get("early_stopping_patience", 0)
        fp16_active = t.get("fp16", False) and use_cuda
        save_strategy = t.get("save_strategy", "epoch")
        eval_strategy = t.get("eval_strategy", "epoch")
        learning_rate = float(t.get("learning_rate", 1e-5))
        logging_steps = t.get("logging_steps", 10)
        best_metric = t.get("metric_for_best_model", "cer")
        # Same default (False) as load_model_and_processor(): the model is only
        # prepared for checkpointing there (use_cache disabled, input grads
        # enabled before LoRA wrapping), so the Trainer must not switch
        # checkpointing on when the config key is absent.
        gc_enabled = t.get("gradient_checkpointing", False)

        logger.info("=" * 60)
        logger.info("STEP 3/3 — TRAINING")
        logger.info(" Train samples : %d", train_samples)
        logger.info(" Eval samples : %d", len(processed["eval"]))
        logger.info(" Batch size (per device) : %d", batch_size)
        logger.info(" Gradient accumulation : %d steps", grad_accum)
        logger.info(" Effective batch size : %d samples per update", effective_batch)
        logger.info(" Max epochs : %d", max_epochs)
        logger.info(" Steps per epoch (~) : %d", steps_per_epoch)
        logger.info(" Total steps (~) : %d", total_steps_estimate)
        logger.info(" Learning rate : %g", learning_rate)
        logger.info(" LR warmup steps : %d (%.1f%% of total steps)",
                    warmup_steps, 100 * warmup_steps / max(1, total_steps_estimate))
        logger.info(" Mixed precision (fp16) : %s", "enabled (GPU)" if fp16_active else "disabled (CPU/fp32)")
        logger.info(" Eval strategy : %s", eval_strategy)
        logger.info(" Save strategy : %s", save_strategy)
        logger.info(" Checkpoint dir : %s", self.output_dir)
        logger.info(" Save total limit : %d checkpoints kept", t.get("save_total_limit", 3))
        logger.info(" Best model metric : %s (%s is better)",
                    best_metric,
                    "lower" if not t.get("greater_is_better", False) else "higher")
        if early_patience > 0:
            logger.info(" Early stopping : patience=%d epochs — training will stop "
                        "if eval/%s does not improve for %d consecutive epochs",
                        early_patience, best_metric, early_patience)
        else:
            logger.info(" Early stopping : disabled — will run all %d epochs", max_epochs)
        if use_cuda:
            logger.info(" GPU cache flush : every %d steps", t.get("torch_empty_cache_steps", 50))
        logger.info(" Dataloader workers : %d%s",
                    num_workers, " (Windows requires 0)" if on_windows else "")
        logger.info("=" * 60)

        training_args = Seq2SeqTrainingArguments(
            output_dir=str(self.output_dir),
            num_train_epochs=max_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=t.get("per_device_eval_batch_size", 2),
            gradient_accumulation_steps=grad_accum,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            eval_strategy=eval_strategy,
            save_strategy=save_strategy,
            save_total_limit=t.get("save_total_limit", 3),
            load_best_model_at_end=t.get("load_best_model_at_end", True),
            metric_for_best_model=best_metric,
            greater_is_better=t.get("greater_is_better", False),
            # fp16 only makes sense on GPU; CPU training stays in fp32
            fp16=fp16_active,
            # Not gated on CUDA; follows the same config key and default as
            # the model preparation step.
            gradient_checkpointing=gc_enabled,
            predict_with_generate=True,
            generation_max_length=t.get("generation_max_length", 225),
            logging_steps=logging_steps,
            report_to=t.get("report_to", "none"),
            dataloader_num_workers=num_workers,
            dataloader_pin_memory=use_cuda,
            # Flush GPU cache every N steps to avoid memory fragmentation
            torch_empty_cache_steps=t.get("torch_empty_cache_steps", 50) if use_cuda else None,
            remove_unused_columns=False,  # required for PEFT models
        )

        callbacks = []
        if early_patience > 0:
            callbacks.append(EarlyStoppingCallback(early_stopping_patience=early_patience))
            logger.info("EarlyStoppingCallback registered with patience=%d", early_patience)

        trainer = Seq2SeqTrainer(
            model=self.model,
            args=training_args,
            train_dataset=processed["train"],
            eval_dataset=processed["eval"],
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            processing_class=self.processor.feature_extractor,
            callbacks=callbacks or None,
        )

        logger.info("Trainer initialised — starting training loop now ...")
        logger.info("Each epoch = %d steps. Loss is logged every %d steps.",
                    steps_per_epoch, logging_steps)
        logger.info("Evaluation and checkpoint save occur at the end of each epoch.")
        trainer.train()
        logger.info("Training loop finished")

        # Save LoRA adapter weights + processor together
        best_dir = self.output_dir / "best_model"
        logger.info("Saving best model (LoRA adapter + processor) to %s ...", best_dir)
        trainer.save_model(str(best_dir))
        self.processor.save_pretrained(str(best_dir))
        logger.info("Best model saved successfully to %s", best_dir)
        logger.info(
            "To use at inference time: load with PeftModel.from_pretrained('%s') "
            "then call .merge_and_unload() to merge LoRA weights into the base model",
            best_dir,
        )