File size: 1,919 Bytes
4d939fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Merge LoRA adapters into base models."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Iterable, Optional

from peft import PeftModel
from transformers import AutoModel, AutoTokenizer


def merge_lora_adapter(
    base_model: str,
    adapter_path: Path,
    output_dir: Path,
    safe_serialization: bool = True,
    *,
    trust_remote_code: bool = True,
    model_class: Optional[type] = None,
) -> None:
    """Merge a LoRA adapter into its base model and save the result.

    The base model and its tokenizer are loaded from ``base_model``, the
    adapter at ``adapter_path`` is applied with PEFT, the LoRA weights are
    folded into the base weights (``merge_and_unload``), and both the merged
    model and the tokenizer are written to ``output_dir``.

    Args:
        base_model: Hugging Face model name or local path of the base model.
        adapter_path: Location of the LoRA adapter (passed to PEFT as a str).
        output_dir: Destination directory; created (with parents) only after
            the merge has succeeded. A plain ``str`` is also accepted.
        safe_serialization: Save weights as safetensors when True.
        trust_remote_code: Forwarded to both the model and the tokenizer
            loaders. When True, code shipped in the model repository is
            executed — only enable for repositories you trust.
        model_class: A transformers Auto* class whose ``from_pretrained`` loads
            the base model. Defaults to ``AutoModel``, which loads the bare
            backbone without a task head.
            NOTE(review): for adapters trained on a task-specific model
            (e.g. via ``AutoModelForCausalLM``), the adapter's module names
            presumably will not match the bare backbone and the head is not
            saved — pass the matching class in that case.
    """
    # Resolved at call time so the default does not depend on import order.
    loader = AutoModel if model_class is None else model_class

    model = loader.from_pretrained(base_model, trust_remote_code=trust_remote_code)
    # Use the same remote-code policy as the model; a tokenizer that ships
    # custom code would otherwise fail to load even though the model did.
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=trust_remote_code)

    peft_model = PeftModel.from_pretrained(model, str(adapter_path))
    merged_model = peft_model.merge_and_unload()

    # Create the output directory only once the merge succeeded, so a failed
    # load does not leave an empty directory behind. Accept str paths too.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    merged_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    tokenizer.save_pretrained(output_dir)


def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the LoRA merge tool.

    Returns:
        A parser exposing the required ``--base-model`` (str),
        ``--adapter-path`` (Path) and ``--output-dir`` (Path) options, plus
        the ``--no-safe-serialization`` boolean flag.
    """
    cli = argparse.ArgumentParser(
        description="Merge a LoRA adapter into its base Hugging Face model."
    )

    # (flag, converter, help) for every mandatory option, in declaration order.
    required_options = (
        ("--base-model", str, "Base model name or path."),
        ("--adapter-path", Path, "Directory containing the LoRA adapter weights."),
        ("--output-dir", Path, "Directory to store the merged model."),
    )
    for flag, converter, help_text in required_options:
        cli.add_argument(flag, type=converter, required=True, help=help_text)

    cli.add_argument(
        "--no-safe-serialization",
        action="store_true",
        help="Disable safetensors when saving the merged model.",
    )
    return cli


def main(argv: Optional[Iterable[str]] = None) -> None:
    """Run the LoRA merge from the command line.

    Args:
        argv: Command-line tokens to parse. ``None`` lets argparse fall back
            to ``sys.argv[1:]``.
    """
    options = build_argument_parser().parse_args(argv)
    # The CLI exposes an opt-out flag; the merge function takes an opt-in.
    use_safetensors = not options.no_safe_serialization
    merge_lora_adapter(
        options.base_model,
        options.adapter_path,
        options.output_dir,
        safe_serialization=use_safetensors,
    )


# Public API of this module.
__all__ = ["build_argument_parser", "merge_lora_adapter", "main"]

# Run the CLI only when executed as a script, never on import.
if __name__ == "__main__":
    main()