Speach-To-Text / scripts /merge_lora.py
MIP-Tech's picture
Deploy to HF Spaces
0db822c
"""
Merge the trained LoRA adapter weights into the base Whisper model.
This script loads the base model and the LoRA adapter from outputs/checkpoints/best_model
(or the provided adapter directory). It then calls `merge_and_unload()` on the PEFT
model to fold the adapter weights into the base linear layers.
The resulting standalone Hugging Face model is saved to `outputs/checkpoints/merged_model`
and can be used for inference directly without needing the `peft` library.
Usage:
python scripts/merge_lora.py
python scripts/merge_lora.py --adapter outputs/checkpoints/best_model --output outputs/checkpoints/merged_model
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import torch
import yaml
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor
# Script-wide logging: INFO and above, each record prefixed with a short
# HH:MM:SS timestamp and a fixed-width level name.
logging.basicConfig(
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
    level=logging.INFO,
)

# Module-level logger used by everything below.
logger = logging.getLogger(__name__)
def main(adapter_path: str, output_path: str, config_path: str) -> None:
    """Merge a LoRA adapter into its base Whisper model and save the result.

    All three paths are joined onto the repository root, taken to be the
    parent of the directory that contains this script (pathlib leaves an
    absolute argument unchanged when joining).

    Args:
        adapter_path: Directory holding the trained LoRA adapter; the
            processor is loaded from the same directory.
        output_path: Directory the merged model and the processor are saved
            to. Created (with parents) if it does not exist.
        config_path: YAML training config whose ``model.base_model`` entry
            names the base checkpoint to load.

    The process exits with status 1, after logging the reason, when the
    config file is missing or unparsable, when it does not define
    ``model.base_model``, or when the adapter directory does not exist.
    """
    root = Path(__file__).parent.parent

    config_file = root / config_path
    if not config_file.exists():
        logger.error("Configuration file not found at %s", config_file)
        sys.exit(1)

    try:
        with config_file.open("r", encoding="utf-8") as fh:
            cfg = yaml.safe_load(fh)
    except yaml.YAMLError as exc:
        logger.error("Could not parse configuration file %s: %s", config_file, exc)
        sys.exit(1)

    # An empty file parses to None and a differently shaped file may lack the
    # keys; report that the same way as the other failures instead of letting
    # a bare TypeError/KeyError traceback escape.
    model_cfg = cfg.get("model") if isinstance(cfg, dict) else None
    base_model_name = model_cfg.get("base_model") if isinstance(model_cfg, dict) else None
    if not base_model_name:
        logger.error("Configuration file %s does not define model.base_model", config_file)
        sys.exit(1)

    adapter_dir = root / adapter_path
    output_dir = root / output_path

    # is_dir() rather than exists(): a regular file at this path cannot be
    # loaded as an adapter directory either.
    if not adapter_dir.is_dir():
        logger.error("LoRA adapter directory not found at %s", adapter_dir)
        sys.exit(1)

    logger.info("Loading base standard model: %s", base_model_name)

    # float16 when CUDA is available, float32 otherwise.
    # NOTE(review): `device` only drives the dtype choice here; the model is
    # never moved to it in this function, so the dtype of the saved weights
    # depends on the machine the merge runs on. Confirm this is intended.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    logger.info("Loading base model in %s", torch_dtype)
    base_model = WhisperForConditionalGeneration.from_pretrained(
        base_model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
    )

    logger.info("Loading processor from %s", adapter_dir)
    processor = WhisperProcessor.from_pretrained(str(adapter_dir))

    logger.info("Loading LoRA adapter from %s", adapter_dir)
    peft_model = PeftModel.from_pretrained(base_model, str(adapter_dir))

    logger.info("Merging LoRA weights into the base model. This may take a moment...")
    merged_model = peft_model.merge_and_unload()
    logger.info("Merge complete.")

    logger.info("Saving standalone merged model to %s", output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Model and processor go to the same directory so it can be loaded on its own.
    merged_model.save_pretrained(str(output_dir))
    processor.save_pretrained(str(output_dir))

    logger.info("Model and processor successfully saved.")
    logger.info("You can now transcribe using:")
    logger.info("python scripts/transcribe.py --model %s your_audio.wav", output_path)
if __name__ == "__main__":
    # Command-line entry point: every option has a default, so the script can
    # be run with no arguments.
    cli = argparse.ArgumentParser(description="Merge Whisper LoRA adapter into base model")

    # (flag, default value, help text) for each supported option.
    option_table = (
        (
            "--adapter",
            "outputs/checkpoints/best_model",
            "Path to the trained LoRA adapter directory",
        ),
        (
            "--output",
            "outputs/checkpoints/merged_model",
            "Directory to save the merged standalone model",
        ),
        (
            "--config",
            "config/training_config.yaml",
            "Path to training config used (for base model lookup)",
        ),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, default=default_value, help=help_text)

    options = cli.parse_args()
    main(
        adapter_path=options.adapter,
        output_path=options.output,
        config_path=options.config,
    )