File size: 3,756 Bytes
0db822c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
Merge the trained LoRA adapter weights into the base Whisper model.

This script loads the base model and the LoRA adapter from outputs/checkpoints/best_model
(or the provided adapter directory). It then calls `merge_and_unload()` on the PEFT
model to fold the adapter weights into the base linear layers.

The resulting standalone Hugging Face model is saved to `outputs/checkpoints/merged_model`
and can be used for inference directly without needing the `peft` library.

Usage:
    python scripts/merge_lora.py
    python scripts/merge_lora.py --adapter outputs/checkpoints/best_model --output outputs/checkpoints/merged_model
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import torch
import yaml
from peft import PeftModel
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Module-level logger; records are emitted through the root handler set up below.
logger = logging.getLogger(__name__)

# Root logging setup: INFO and above, prefixed with an HH:MM:SS timestamp and
# a left-aligned, 8-character-wide level name.
logging.basicConfig(
    datefmt="%H:%M:%S",
    format="%(asctime)s  %(levelname)-8s  %(message)s",
    level=logging.INFO,
)


def main(adapter_path: str, output_path: str, config_path: str) -> None:
    """Merge a LoRA adapter into its base Whisper model and save the result.

    The three path arguments are joined onto the repository root, taken to be
    the parent of the directory containing this script. Because the join uses
    ``pathlib``, an absolute path argument is used as-is.

    Steps:
      1. Read the YAML config and look up ``model.base_model``.
      2. Load the base model (float16 on CUDA, float32 on CPU) and move it to
         the selected device.
      3. Load the processor and the LoRA adapter from the adapter directory.
      4. Call ``merge_and_unload()`` and save the merged model together with
         the processor into the output directory.

    Args:
        adapter_path: Directory holding the trained LoRA adapter (and the
            processor files).
        output_path: Directory the merged model and processor are written to;
            created if it does not exist.
        config_path: YAML training config that names the base model under
            ``model.base_model``.

    Exits:
        Logs an error and calls ``sys.exit(1)`` when the config file is
        missing, when it does not define ``model.base_model``, or when the
        adapter directory is missing.
    """
    root = Path(__file__).parent.parent

    config_file = root / config_path
    if not config_file.exists():
        logger.error("Configuration file not found at %s", config_file)
        sys.exit(1)

    with config_file.open("r", encoding="utf-8") as fh:
        cfg = yaml.safe_load(fh)

    # An empty file makes safe_load return None (TypeError on subscript); a
    # missing section or key raises KeyError. Report both the same way as the
    # other validation failures instead of surfacing a raw traceback.
    try:
        base_model_name = cfg["model"]["base_model"]
    except (KeyError, TypeError):
        logger.error(
            "Configuration file %s does not define model.base_model", config_file
        )
        sys.exit(1)

    adapter_dir = root / adapter_path
    output_dir = root / output_path

    if not adapter_dir.exists():
        logger.error("LoRA adapter directory not found at %s", adapter_dir)
        sys.exit(1)

    logger.info("Loading base standard model: %s", base_model_name)
    # Pick the device first; the dtype follows from it (float16 on CUDA,
    # float32 on CPU).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32

    logger.info("Loading base model in %s", torch_dtype)
    base_model = WhisperForConditionalGeneration.from_pretrained(
        base_model_name, torch_dtype=torch_dtype, low_cpu_mem_usage=True
    )
    # from_pretrained leaves the weights on the CPU; move them explicitly so
    # the float16 path actually runs on the GPU it was selected for.
    base_model = base_model.to(device)

    logger.info("Loading processor from %s", adapter_dir)
    processor = WhisperProcessor.from_pretrained(str(adapter_dir))

    logger.info("Loading LoRA adapter from %s", adapter_dir)
    peft_model = PeftModel.from_pretrained(base_model, str(adapter_dir))

    logger.info("Merging LoRA weights into the base model. This may take a moment...")
    merged_model = peft_model.merge_and_unload()
    logger.info("Merge complete.")

    logger.info("Saving standalone merged model to %s", output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    merged_model.save_pretrained(str(output_dir))
    processor.save_pretrained(str(output_dir))
    logger.info("Model and processor successfully saved.")

    logger.info("You can now transcribe using:")
    logger.info("python scripts/transcribe.py --model %s your_audio.wav", output_path)


if __name__ == "__main__":
    # Command-line entry point: each option is (flag, default, help text),
    # registered in the order it should appear in --help.
    option_specs = (
        (
            "--adapter",
            "outputs/checkpoints/best_model",
            "Path to the trained LoRA adapter directory",
        ),
        (
            "--output",
            "outputs/checkpoints/merged_model",
            "Directory to save the merged standalone model",
        ),
        (
            "--config",
            "config/training_config.yaml",
            "Path to training config used (for base model lookup)",
        ),
    )

    cli = argparse.ArgumentParser(description="Merge Whisper LoRA adapter into base model")
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, default=default_value, help=help_text)

    opts = cli.parse_args()
    main(
        adapter_path=opts.adapter,
        output_path=opts.output,
        config_path=opts.config,
    )