Spaces:
Sleeping
Sleeping
| """ | |
| Merge the trained LoRA adapter weights into the base Whisper model. | |
| This script loads the base model and the LoRA adapter from outputs/checkpoints/best_model | |
| (or the provided adapter directory). It then calls `merge_and_unload()` on the PEFT | |
| model to fold the adapter weights into the base linear layers. | |
| The resulting standalone Hugging Face model is saved to `outputs/checkpoints/merged_model` | |
| and can be used for inference directly without needing the `peft` library. | |
| Usage: | |
| python scripts/merge_lora.py | |
| python scripts/merge_lora.py --adapter outputs/checkpoints/best_model --output outputs/checkpoints/merged_model | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import logging | |
| import sys | |
| from pathlib import Path | |
| import torch | |
| import yaml | |
| from peft import PeftModel | |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor | |
# Root logging setup for the script: INFO and above, each record rendered as
# "HH:MM:SS LEVEL    message" (level name left-aligned in an 8-character field).
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s %(levelname)-8s %(message)s",
    level=logging.INFO,
)
# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
def main(adapter_path: str, output_path: str, config_path: str) -> None:
    """Merge a trained LoRA adapter into its base Whisper model and save it.

    The base model name is read from ``model.base_model`` in the YAML config.
    The base model and the adapter are loaded, the adapter is folded into the
    base weights with ``merge_and_unload()``, and the merged model plus the
    processor found in the adapter directory are written to the output
    directory.

    All three paths are joined onto the repository root (the parent of the
    directory that contains this script); an absolute path is used as given.

    Args:
        adapter_path: Directory holding the LoRA adapter and processor files.
        output_path: Directory the merged model is written to (created if
            missing).
        config_path: YAML training config that names the base model.

    Raises:
        SystemExit: With status 1 if the config file is missing, does not
            define ``model.base_model``, or the adapter directory is missing.
    """
    root = Path(__file__).parent.parent

    config_file = root / config_path
    if not config_file.exists():
        logger.error("Configuration file not found at %s", config_file)
        sys.exit(1)

    with config_file.open("r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    # An empty YAML document parses to None and a config may lack the keys;
    # report that like the other validation failures instead of crashing
    # with a bare TypeError/KeyError.
    model_cfg = cfg.get("model") if isinstance(cfg, dict) else None
    base_model_name = model_cfg.get("base_model") if isinstance(model_cfg, dict) else None
    if not base_model_name:
        logger.error("Configuration file %s does not define model.base_model", config_file)
        sys.exit(1)

    adapter_dir = root / adapter_path
    output_dir = root / output_path

    # is_dir() rather than exists(): a regular file at this path would only
    # fail later, inside the model/processor loaders.
    if not adapter_dir.is_dir():
        logger.error("LoRA adapter directory not found at %s", adapter_dir)
        sys.exit(1)

    logger.info("Loading base standard model: %s", base_model_name)

    # Use the GPU when one is available; half precision is only chosen
    # together with CUDA, full precision is kept on CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    logger.info("Loading base model in %s", torch_dtype)
    base_model = WhisperForConditionalGeneration.from_pretrained(
        base_model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
    )
    # from_pretrained leaves the weights on the CPU; actually place the model
    # on the selected device so the float16 choice above matches where the
    # merge runs.
    base_model.to(device)

    logger.info("Loading processor from %s", adapter_dir)
    processor = WhisperProcessor.from_pretrained(str(adapter_dir))

    logger.info("Loading LoRA adapter from %s", adapter_dir)
    peft_model = PeftModel.from_pretrained(base_model, str(adapter_dir))

    logger.info("Merging LoRA weights into the base model. This may take a moment...")
    merged_model = peft_model.merge_and_unload()
    logger.info("Merge complete.")

    logger.info("Saving standalone merged model to %s", output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    merged_model.save_pretrained(str(output_dir))
    # Save the processor next to the weights so the output directory is
    # loadable on its own.
    processor.save_pretrained(str(output_dir))

    logger.info("Model and processor successfully saved.")
    logger.info("You can now transcribe using:")
    logger.info("python scripts/transcribe.py --model %s your_audio.wav", output_path)
if __name__ == "__main__":
    # One (flag, default, help) row per command-line option, in the order
    # they are registered with the parser.
    cli_options = (
        (
            "--adapter",
            "outputs/checkpoints/best_model",
            "Path to the trained LoRA adapter directory",
        ),
        (
            "--output",
            "outputs/checkpoints/merged_model",
            "Directory to save the merged standalone model",
        ),
        (
            "--config",
            "config/training_config.yaml",
            "Path to training config used (for base model lookup)",
        ),
    )

    cli = argparse.ArgumentParser(description="Merge Whisper LoRA adapter into base model")
    for flag, default_value, help_text in cli_options:
        cli.add_argument(flag, default=default_value, help=help_text)

    parsed = cli.parse_args()
    main(parsed.adapter, parsed.output, parsed.config)