# ai_exec / src / training / merge_adapter.py
"""
Merge Adapter Module
Merge LoRA adapter weights into base model for deployment.
Optionally push the merged model to Hugging Face Hub.
Example usage:
from src.training.merge_adapter import merge_adapter
merge_adapter(
base_model="Qwen/Qwen3-4B-Instruct",
adapter_path="./outputs/final_adapter",
output_path="./outputs/merged_model",
push_to_hub=True,
hub_model_id="username/ceo-voice-model-merged",
)
"""
import os
from pathlib import Path
from typing import Optional
from loguru import logger
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
PEFT_AVAILABLE = True
except ImportError:
PEFT_AVAILABLE = False
logger.warning("PEFT not available")
def merge_adapter(
    base_model: str,
    adapter_path: str | Path,
    output_path: str | Path,
    torch_dtype: str = "bfloat16",
    push_to_hub: bool = False,
    hub_model_id: Optional[str] = None,
    hub_token: Optional[str] = None,
    private: bool = True,
) -> str:
    """
    Merge LoRA adapter weights into a base model and save the result.

    Args:
        base_model: Base model name or path.
        adapter_path: Path to the LoRA adapter directory.
        output_path: Directory to save the merged model into.
        torch_dtype: Torch dtype for loading ("float16", "bfloat16",
            "float32"). Unknown values fall back to bfloat16 with a warning.
        push_to_hub: Whether to push the merged model to the HF Hub.
        hub_model_id: Hub repository ID (required when push_to_hub is True).
        hub_token: HF token; falls back to the HF_TOKEN env var.
        private: Whether the Hub repo should be private.

    Returns:
        Path to the merged model directory (as a string).

    Raises:
        ImportError: If peft/transformers are not installed.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT required. Run: pip install peft")
    adapter_path = Path(adapter_path)
    output_path = Path(output_path)
    # Resolve dtype; warn instead of silently substituting bfloat16.
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if torch_dtype not in dtype_map:
        logger.warning(f"Unknown torch_dtype '{torch_dtype}', using bfloat16")
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)
    logger.info(f"Loading base model: {base_model}")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )
    logger.info(f"Loading adapter from: {adapter_path}")
    model = PeftModel.from_pretrained(base, str(adapter_path))
    logger.info("Merging adapter weights...")
    model = model.merge_and_unload()
    # Save merged model
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving merged model to: {output_path}")
    model.save_pretrained(str(output_path), safe_serialization=True)
    # Save tokenizer. Adapter directories do not always contain tokenizer
    # files, so fall back to the base model's tokenizer if loading fails.
    logger.info("Saving tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            str(adapter_path),
            trust_remote_code=True,
        )
    except Exception:
        logger.warning(
            f"No tokenizer found in {adapter_path}; loading from base model"
        )
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            trust_remote_code=True,
        )
    tokenizer.save_pretrained(str(output_path))
    # Push to hub if requested
    if push_to_hub:
        if not hub_model_id:
            # Previously this case was skipped silently; make it visible.
            logger.warning(
                "push_to_hub=True but no hub_model_id given; skipping push"
            )
        else:
            token = hub_token or os.environ.get("HF_TOKEN")
            logger.info(f"Pushing to Hub: {hub_model_id}")
            model.push_to_hub(
                hub_model_id,
                token=token,
                private=private,
                safe_serialization=True,
            )
            tokenizer.push_to_hub(
                hub_model_id,
                token=token,
                private=private,
            )
            logger.info(f"Model pushed to: https://huggingface.co/{hub_model_id}")
    logger.info("Merge complete!")
    return str(output_path)
def merge_adapter_for_inference(
    base_model: str,
    adapter_path: str | Path,
    torch_dtype: str = "bfloat16",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
) -> tuple:
    """
    Load a base model, merge a LoRA adapter into it, and keep it in memory.

    Nothing is written to disk; the merged model and its tokenizer are
    returned ready for inference.

    Args:
        base_model: Base model name or path.
        adapter_path: Path to the LoRA adapter directory.
        torch_dtype: Torch dtype name ("float16", "bfloat16", "float32").
        load_in_4bit: Use 4-bit (NF4) quantization.
        load_in_8bit: Use 8-bit quantization (4-bit wins if both are set).

    Returns:
        Tuple of (model, tokenizer).

    Raises:
        ImportError: If peft is not installed.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT required. Run: pip install peft")
    # Resolve the dtype name, defaulting to bfloat16 for unknown values.
    dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }.get(torch_dtype, torch.bfloat16)
    # Optional bitsandbytes quantization; 4-bit takes precedence over 8-bit.
    quantization_config = None
    if load_in_4bit:
        from transformers import BitsAndBytesConfig
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )
    elif load_in_8bit:
        from transformers import BitsAndBytesConfig
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    logger.info(f"Loading base model: {base_model}")
    backbone = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=quantization_config,
    )
    logger.info(f"Loading adapter from: {adapter_path}")
    peft_model = PeftModel.from_pretrained(backbone, str(adapter_path))
    logger.info("Merging adapter weights...")
    merged = peft_model.merge_and_unload()
    # Tokenizer travels with the adapter directory.
    tokenizer = AutoTokenizer.from_pretrained(
        str(adapter_path),
        trust_remote_code=True,
    )
    return merged, tokenizer
def push_adapter_to_hub(
    adapter_path: str | Path,
    hub_model_id: str,
    hub_token: Optional[str] = None,
    private: bool = True,
) -> str:
    """
    Upload a LoRA adapter directory to the Hugging Face Hub as-is.

    The adapter is pushed without merging it into the base model.

    Args:
        adapter_path: Path to the adapter directory.
        hub_model_id: Hub repository ID (e.g. "user/repo").
        hub_token: HF token; falls back to the HF_TOKEN env var.
        private: Whether the repo should be private.

    Returns:
        URL of the repository on the Hub.
    """
    from huggingface_hub import HfApi

    token = hub_token or os.environ.get("HF_TOKEN")
    api = HfApi(token=token)
    adapter_dir = Path(adapter_path)
    logger.info(f"Creating/updating repo: {hub_model_id}")
    api.create_repo(repo_id=hub_model_id, private=private, exist_ok=True)
    logger.info(f"Uploading adapter from: {adapter_dir}")
    api.upload_folder(
        folder_path=str(adapter_dir),
        repo_id=hub_model_id,
        token=token,
    )
    hub_url = f"https://huggingface.co/{hub_model_id}"
    logger.info(f"Adapter pushed to: {hub_url}")
    return hub_url
def main():
    """CLI entry point for adapter merging.

    Fixes vs. previous revision:
    - ``--private`` previously used ``action="store_true"`` with
      ``default=True``, which made it impossible to create a public repo.
      It now uses ``BooleanOptionalAction``, so ``--no-private`` works and
      plain ``--private`` keeps its old meaning.
    - ``--push-to-hub`` without ``--hub-model-id`` is rejected up front
      instead of silently skipping the push and printing a bogus URL.
    """
    import argparse
    parser = argparse.ArgumentParser(
        description="Merge LoRA adapter into base model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Merge and save locally
python merge_adapter.py \\
--base-model Qwen/Qwen3-4B-Instruct \\
--adapter ./outputs/final_adapter \\
--output ./outputs/merged_model
# Merge and push to Hub
python merge_adapter.py \\
--base-model Qwen/Qwen3-4B-Instruct \\
--adapter ./outputs/final_adapter \\
--output ./outputs/merged_model \\
--push-to-hub \\
--hub-model-id username/ceo-voice-model-merged
# Push adapter only (no merge)
python merge_adapter.py \\
--adapter ./outputs/final_adapter \\
--push-adapter-only \\
--hub-model-id username/ceo-voice-adapter
""",
    )
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model name",
    )
    parser.add_argument(
        "--adapter",
        required=True,
        help="Path to LoRA adapter",
    )
    parser.add_argument(
        "--output",
        help="Output path for merged model",
    )
    parser.add_argument(
        "--dtype",
        choices=["float16", "bfloat16", "float32"],
        default="bfloat16",
        help="Torch dtype",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push merged model to Hub",
    )
    parser.add_argument(
        "--push-adapter-only",
        action="store_true",
        help="Push adapter only (no merge)",
    )
    parser.add_argument(
        "--hub-model-id",
        help="Hub model ID",
    )
    parser.add_argument(
        "--private",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Make Hub repo private (use --no-private for a public repo)",
    )
    args = parser.parse_args()
    if args.push_adapter_only:
        # Just push adapter
        if not args.hub_model_id:
            parser.error("--hub-model-id required with --push-adapter-only")
        url = push_adapter_to_hub(
            adapter_path=args.adapter,
            hub_model_id=args.hub_model_id,
            private=args.private,
        )
        print(f"\nAdapter pushed to: {url}")
    else:
        # Merge and optionally push
        if not args.output and not args.push_to_hub:
            parser.error("Either --output or --push-to-hub required")
        if args.push_to_hub and not args.hub_model_id:
            parser.error("--hub-model-id required with --push-to-hub")
        output_path = args.output or "./merged_model"
        merge_adapter(
            base_model=args.base_model,
            adapter_path=args.adapter,
            output_path=output_path,
            torch_dtype=args.dtype,
            push_to_hub=args.push_to_hub,
            hub_model_id=args.hub_model_id,
            private=args.private,
        )
        print(f"\nMerged model saved to: {output_path}")
        if args.push_to_hub:
            print(f"Also pushed to: https://huggingface.co/{args.hub_model_id}")
if __name__ == "__main__":
    main()