Spaces:
Paused
Paused
"""
Merge Adapter Module

Merge LoRA adapter weights into a base model for deployment, and
optionally push the merged model to the Hugging Face Hub.

Example usage:
    from src.training.merge_adapter import merge_adapter

    merge_adapter(
        base_model="Qwen/Qwen3-4B-Instruct",
        adapter_path="./outputs/final_adapter",
        output_path="./outputs/merged_model",
        push_to_hub=True,
        hub_model_id="username/ceo-voice-model-merged",
    )
"""
import os
from pathlib import Path
from typing import Optional

from loguru import logger

# torch / transformers / peft are heavyweight optional dependencies;
# record availability so callers get a clear error instead of an import crash.
try:
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    PEFT_AVAILABLE = True
except ImportError:
    PEFT_AVAILABLE = False
    logger.warning("PEFT not available")
def merge_adapter(
    base_model: str,
    adapter_path: str | Path,
    output_path: str | Path,
    torch_dtype: str = "bfloat16",
    push_to_hub: bool = False,
    hub_model_id: Optional[str] = None,
    hub_token: Optional[str] = None,
    private: bool = True,
) -> str:
    """
    Merge LoRA adapter into base model and save the result to disk.

    Args:
        base_model: Base model name or path
        adapter_path: Path to LoRA adapter
        output_path: Path to save merged model
        torch_dtype: Torch dtype for loading ("float16", "bfloat16", "float32";
            unknown values fall back to bfloat16)
        push_to_hub: Whether to push to HF Hub (requires hub_model_id)
        hub_model_id: Hub repository ID
        hub_token: HF token (falls back to the HF_TOKEN environment variable)
        private: Whether hub repo should be private

    Returns:
        Path to merged model (as a string).

    Raises:
        ImportError: If peft/transformers are not installed.
        ValueError: If push_to_hub is True but hub_model_id is not given.
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT required. Run: pip install peft")
    if push_to_hub and not hub_model_id:
        # Fail fast instead of silently skipping the upload.
        raise ValueError("hub_model_id is required when push_to_hub=True")

    adapter_path = Path(adapter_path)
    output_path = Path(output_path)

    # Map the dtype name to a torch dtype; default to bfloat16.
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map.get(torch_dtype, torch.bfloat16)

    logger.info(f"Loading base model: {base_model}")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )

    logger.info(f"Loading adapter from: {adapter_path}")
    model = PeftModel.from_pretrained(base, str(adapter_path))

    logger.info("Merging adapter weights...")
    model = model.merge_and_unload()

    # Save merged model
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Saving merged model to: {output_path}")
    model.save_pretrained(str(output_path), safe_serialization=True)

    # Save tokenizer. Prefer the adapter directory (it may carry added
    # special tokens), but adapter checkpoints often contain only LoRA
    # weights — fall back to the base model's tokenizer in that case.
    logger.info("Saving tokenizer...")
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            str(adapter_path),
            trust_remote_code=True,
        )
    except (OSError, ValueError):
        logger.warning("No tokenizer in adapter dir; using base model tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(
            base_model,
            trust_remote_code=True,
        )
    tokenizer.save_pretrained(str(output_path))

    # Push to hub if requested (hub_model_id validated above)
    if push_to_hub:
        token = hub_token or os.environ.get("HF_TOKEN")
        logger.info(f"Pushing to Hub: {hub_model_id}")
        model.push_to_hub(
            hub_model_id,
            token=token,
            private=private,
            safe_serialization=True,
        )
        tokenizer.push_to_hub(
            hub_model_id,
            token=token,
            private=private,
        )
        logger.info(f"Model pushed to: https://huggingface.co/{hub_model_id}")

    logger.info("Merge complete!")
    return str(output_path)
def merge_adapter_for_inference(
    base_model: str,
    adapter_path: str | Path,
    torch_dtype: str = "bfloat16",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
) -> tuple:
    """
    Load a base model, apply a LoRA adapter, and merge it in memory.

    Unlike merge_adapter(), nothing is written to disk — the merged model
    and its tokenizer are returned for immediate inference use.

    Args:
        base_model: Base model name or path
        adapter_path: Path to LoRA adapter
        torch_dtype: Torch dtype (unknown values fall back to bfloat16)
        load_in_4bit: Use 4-bit quantization
        load_in_8bit: Use 8-bit quantization (ignored if load_in_4bit is set)

    Returns:
        Tuple of (model, tokenizer)
    """
    if not PEFT_AVAILABLE:
        raise ImportError("PEFT required. Run: pip install peft")

    # Resolve the requested dtype, defaulting to bfloat16 for unknown names.
    dtype = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }.get(torch_dtype, torch.bfloat16)

    # Build an optional quantization config; 4-bit takes precedence.
    quantization_config = None
    if load_in_4bit:
        from transformers import BitsAndBytesConfig

        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=dtype,
        )
    elif load_in_8bit:
        from transformers import BitsAndBytesConfig

        quantization_config = BitsAndBytesConfig(load_in_8bit=True)

    logger.info(f"Loading base model: {base_model}")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=quantization_config,
    )

    logger.info(f"Loading adapter from: {adapter_path}")
    merged = PeftModel.from_pretrained(base, str(adapter_path))
    logger.info("Merging adapter weights...")
    merged = merged.merge_and_unload()

    tokenizer = AutoTokenizer.from_pretrained(
        str(adapter_path),
        trust_remote_code=True,
    )
    return merged, tokenizer
def push_adapter_to_hub(
    adapter_path: str | Path,
    hub_model_id: str,
    hub_token: Optional[str] = None,
    private: bool = True,
) -> str:
    """
    Push a LoRA adapter (without merging) to the Hugging Face Hub.

    Args:
        adapter_path: Path to adapter
        hub_model_id: Hub repository ID
        hub_token: HF token (falls back to the HF_TOKEN environment variable)
        private: Whether repo should be private

    Returns:
        Hub URL of the uploaded adapter.
    """
    from huggingface_hub import HfApi

    token = hub_token or os.environ.get("HF_TOKEN")
    api = HfApi(token=token)

    # Create the repo if needed (no-op when it already exists).
    logger.info(f"Creating/updating repo: {hub_model_id}")
    api.create_repo(
        repo_id=hub_model_id,
        private=private,
        exist_ok=True,
    )

    source = Path(adapter_path)
    logger.info(f"Uploading adapter from: {source}")
    api.upload_folder(
        folder_path=str(source),
        repo_id=hub_model_id,
        token=token,
    )

    hub_url = f"https://huggingface.co/{hub_model_id}"
    logger.info(f"Adapter pushed to: {hub_url}")
    return hub_url
def main():
    """CLI entry point for adapter merging.

    Validates flag combinations up front (via parser.error) and then
    dispatches to push_adapter_to_hub() or merge_adapter().
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Merge LoRA adapter into base model",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Merge and save locally
  python merge_adapter.py \\
      --base-model Qwen/Qwen3-4B-Instruct \\
      --adapter ./outputs/final_adapter \\
      --output ./outputs/merged_model

  # Merge and push to Hub
  python merge_adapter.py \\
      --base-model Qwen/Qwen3-4B-Instruct \\
      --adapter ./outputs/final_adapter \\
      --output ./outputs/merged_model \\
      --push-to-hub \\
      --hub-model-id username/ceo-voice-model-merged

  # Push adapter only (no merge)
  python merge_adapter.py \\
      --adapter ./outputs/final_adapter \\
      --push-adapter-only \\
      --hub-model-id username/ceo-voice-adapter
""",
    )
    parser.add_argument(
        "--base-model",
        default="Qwen/Qwen3-4B-Instruct",
        help="Base model name",
    )
    parser.add_argument(
        "--adapter",
        required=True,
        help="Path to LoRA adapter",
    )
    parser.add_argument(
        "--output",
        help="Output path for merged model",
    )
    parser.add_argument(
        "--dtype",
        choices=["float16", "bfloat16", "float32"],
        default="bfloat16",
        help="Torch dtype",
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push merged model to Hub",
    )
    parser.add_argument(
        "--push-adapter-only",
        action="store_true",
        help="Push adapter only (no merge)",
    )
    parser.add_argument(
        "--hub-model-id",
        help="Hub model ID",
    )
    # BUG FIX: action="store_true" with default=True made --private always
    # True and impossible to disable. BooleanOptionalAction keeps --private
    # working and the default unchanged, and adds --no-private for public repos.
    parser.add_argument(
        "--private",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Make Hub repo private (use --no-private for a public repo)",
    )

    args = parser.parse_args()

    if args.push_adapter_only:
        # Just push adapter
        if not args.hub_model_id:
            parser.error("--hub-model-id required with --push-adapter-only")
        url = push_adapter_to_hub(
            adapter_path=args.adapter,
            hub_model_id=args.hub_model_id,
            private=args.private,
        )
        print(f"\nAdapter pushed to: {url}")
    else:
        # Merge and optionally push
        if not args.output and not args.push_to_hub:
            parser.error("Either --output or --push-to-hub required")
        # BUG FIX: previously a missing --hub-model-id silently skipped the
        # upload yet still printed "Also pushed to: .../None".
        if args.push_to_hub and not args.hub_model_id:
            parser.error("--hub-model-id required with --push-to-hub")
        output_path = args.output or "./merged_model"
        merge_adapter(
            base_model=args.base_model,
            adapter_path=args.adapter,
            output_path=output_path,
            torch_dtype=args.dtype,
            push_to_hub=args.push_to_hub,
            hub_model_id=args.hub_model_id,
            private=args.private,
        )
        print(f"\nMerged model saved to: {output_path}")
        if args.push_to_hub:
            print(f"Also pushed to: https://huggingface.co/{args.hub_model_id}")


if __name__ == "__main__":
    main()