Spaces:
Running
Running
File size: 1,919 Bytes
4d939fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
"""Merge LoRA adapters into base models."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Iterable, Optional
from peft import PeftModel
from transformers import AutoModel, AutoTokenizer
def merge_lora_adapter(
    base_model: str,
    adapter_path: Path,
    output_dir: Path,
    safe_serialization: bool = True,
) -> None:
    """Merge a LoRA adapter into its base model and save the result.

    Args:
        base_model: Hugging Face model name or local path of the base model.
        adapter_path: Directory containing the trained LoRA adapter weights.
        output_dir: Destination directory for the merged model and tokenizer;
            created (including parents) if it does not exist.
        safe_serialization: Save weights in safetensors format when True.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    model = AutoModel.from_pretrained(base_model, trust_remote_code=True)
    # Keep trust_remote_code consistent with the model load above: checkpoints
    # that ship custom model code usually ship a custom tokenizer as well, and
    # loading it without the flag fails for those models.
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    peft_model = PeftModel.from_pretrained(model, str(adapter_path))
    # Fold the adapter deltas into the base weights and drop the PEFT wrapper.
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(output_dir, safe_serialization=safe_serialization)
    # Save the tokenizer too so output_dir is a self-contained checkpoint.
    tokenizer.save_pretrained(output_dir)
def build_argument_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the LoRA merge tool."""
    parser = argparse.ArgumentParser(
        description="Merge a LoRA adapter into its base Hugging Face model."
    )
    # The three required options share the same shape, so declare them as data.
    required_options = (
        ("--base-model", str, "Base model name or path."),
        ("--adapter-path", Path, "Directory containing the LoRA adapter weights."),
        ("--output-dir", Path, "Directory to store the merged model."),
    )
    for flag, value_type, help_text in required_options:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    parser.add_argument(
        "--no-safe-serialization",
        action="store_true",
        help="Disable safetensors when saving the merged model.",
    )
    return parser
def main(argv: Optional[Iterable[str]] = None) -> None:
    """CLI entry point: parse arguments and perform the merge."""
    args = build_argument_parser().parse_args(argv)
    # --no-safe-serialization is a disable flag, so invert it for the API.
    use_safetensors = not args.no_safe_serialization
    merge_lora_adapter(
        args.base_model,
        args.adapter_path,
        args.output_dir,
        safe_serialization=use_safetensors,
    )
if __name__ == "__main__":
main()
__all__ = ["build_argument_parser", "merge_lora_adapter", "main"]
|