""" Merge the trained LoRA adapter weights into the base Whisper model. This script loads the base model and the LoRA adapter from outputs/checkpoints/best_model (or the provided adapter directory). It then calls `merge_and_unload()` on the the PEFT model to fold the adapter weights into the base linear layers. The resulting standalone Hugging Face model is saved to `outputs/checkpoints/merged_model` and can be used for inference directly without needing the `peft` library. Usage: python scripts/merge_lora.py python scripts/merge_lora.py --adapter outputs/checkpoints/best_model --output outputs/checkpoints/merged_model """ from __future__ import annotations import argparse import logging import sys from pathlib import Path import torch import yaml from peft import PeftModel from transformers import WhisperForConditionalGeneration, WhisperProcessor # Configure logging logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger(__name__) def main(adapter_path: str, output_path: str, config_path: str) -> None: root = Path(__file__).parent.parent config_file = root / config_path if not config_file.exists(): logger.error("Configuration file not found at %s", config_file) sys.exit(1) with config_file.open("r", encoding="utf-8") as fh: cfg = yaml.safe_load(fh) base_model_name = cfg["model"]["base_model"] adapter_dir = root / adapter_path output_dir = root / output_path if not adapter_dir.exists(): logger.error("LoRA adapter directory not found at %s", adapter_dir) sys.exit(1) logger.info("Loading base standard model: %s", base_model_name) # Load model. We load on CPU or GPU based on availability device = "cuda" if torch.cuda.is_available() else "cpu" # Using float16 if cuda is available, else float32 torch_dtype = torch.float16 if device == "cuda" else torch.float32 logger.info("Loading base model in %s", torch_dtype) base_model = WhisperForConditionalGeneration.from_pretrained( base_model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True ) logger.info("Loading processor from %s", adapter_dir) processor = WhisperProcessor.from_pretrained(str(adapter_dir)) logger.info("Loading LoRA adapter from %s", adapter_dir) peft_model = PeftModel.from_pretrained(base_model, str(adapter_dir)) logger.info("Merging LoRA weights into the base model. This may take a moment...") merged_model = peft_model.merge_and_unload() logger.info("Merge complete.") logger.info("Saving standalone merged model to %s", output_dir) output_dir.mkdir(parents=True, exist_ok=True) merged_model.save_pretrained(str(output_dir)) processor.save_pretrained(str(output_dir)) logger.info("Model and processor successfully saved.") logger.info("You can now transcribe using:") logger.info("python scripts/transcribe.py --model %s your_audio.wav", output_path) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Merge Whisper LoRA adapter into base model") parser.add_argument( "--adapter", default="outputs/checkpoints/best_model", help="Path to the trained LoRA adapter directory", ) parser.add_argument( "--output", default="outputs/checkpoints/merged_model", help="Directory to save the merged standalone model", ) parser.add_argument( "--config", default="config/training_config.yaml", help="Path to training config used (for base model lookup)", ) args = parser.parse_args() main(args.adapter, args.output, args.config)