| """Utility for merging LoRA weights into base language models. | |
| This module provides functions to merge trained LoRA adapter weights with a base model, | |
| producing a standalone model that incorporates the adaptations without needing the | |
| LoRA architecture during inference. | |
| """ | |
| import argparse | |
| import os | |
| import gc | |
| import sys | |
| import logging | |
| import traceback | |
| import torch | |
| import datetime | |
| from typing import Optional, Dict, Any | |
| from peft import PeftModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from lpm_kernel.L2.memory_manager import get_memory_manager | |
| # Configure logging | |
| logger = logging.getLogger(__name__) | |
| logging.basicConfig(level=logging.INFO) | |
def merge_lora_weights(base_model_path, lora_adapter_path, output_model_path):
    """Merge LoRA weights into a base model and save the result.

    This function loads a base model and a LoRA adapter, merges them together,
    and saves the resulting model to the specified output path. It uses the
    project's memory manager to monitor and free memory throughout the merge.

    Args:
        base_model_path: Path to the base model directory.
        lora_adapter_path: Path to the LoRA adapter directory.
        output_model_path: Path where the merged model will be saved.
    """
    # Get memory manager
    memory_manager = get_memory_manager()

    try:
        # Log initial memory state
        memory_info = memory_manager.get_memory_info()
        logger.info(
            f"Initial memory state: RAM used: {memory_info['ram_used_gb']:.2f}GB, "
            f"available: {memory_info['ram_available_gb']:.2f}GB"
        )

        # Determine whether CUDA is available and should be used
        use_cuda = memory_manager.cuda_available
        device = "cuda" if use_cuda else "cpu"

        if use_cuda:
            logger.info(f"CUDA is available. VRAM used: {memory_info.get('vram_used_gb', 0):.2f}GB")
        else:
            logger.warning("CUDA not available or not enabled. Using CPU for model operations.")

        # Clean up memory before starting
        memory_manager.cleanup_memory(force=True)

        # Explicitly set device configuration based on the available hardware
        device_map = "auto" if use_cuda else None
        dtype = torch.float16 if use_cuda else torch.float32

        logger.info(f"Loading base model from {base_model_path} with device_map={device_map}, dtype={dtype}")

        # Use an explicit configuration for GPU utilization
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=dtype,
            device_map=device_map,
        )

        # Load the tokenizer - this does not consume much memory
        tokenizer = AutoTokenizer.from_pretrained(base_model_path)

        # Load the LoRA adapter and apply it to the base model
        logger.info(f"Loading LoRA adapter from {lora_adapter_path}")
        lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

        # Merge the adapter weights into the base weights and drop the LoRA wrappers
        logger.info(f"Merging LoRA weights into base model on {device}")
        merged_model = lora_model.merge_and_unload()

        # Clean up before saving
        memory_manager.cleanup_memory()

        # Add inference optimization config to the merged model for faster startup
        if use_cuda:
            # Set inference-specific configuration in the model config
            if hasattr(merged_model.config, "torch_dtype"):
                merged_model.config.torch_dtype = "float16"  # Prefer float16 for inference
            if not hasattr(merged_model.config, "pretraining_tp"):
                merged_model.config.pretraining_tp = 1  # For tensor parallelism during inference

            # Set default inference device
            if not hasattr(merged_model.config, "_default_inference_device"):
                merged_model.config._default_inference_device = "cuda:0"

            logger.info("Added GPU optimization settings to model configuration")

        # Save merged model in shards to prevent OOM errors during save
        logger.info(f"Saving merged model to {output_model_path}")
        merged_model.save_pretrained(
            output_model_path,
            safe_serialization=True,
            max_shard_size="2GB",  # Sharded saving to avoid memory spikes
        )
        tokenizer.save_pretrained(output_model_path)

        # Save a marker file to indicate this model should use GPU for inference
        if use_cuda:
            with open(os.path.join(output_model_path, "gpu_optimized.json"), "w") as f:
                json.dump(
                    {"gpu_optimized": True, "optimized_on": datetime.datetime.now().isoformat()},
                    f,
                )
            logger.info("Added GPU optimization marker file for faster service startup")

        logger.info("Model successfully merged and saved!")

    except Exception as e:
        logger.error(f"Error during model merge: {str(e)}")
        logger.error(traceback.format_exc())
        # Force cleanup before re-raising
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise
def merge_model_weights(
    base_model_path="resources/L2/base_models",
    lora_adapter_path="resources/model/output/personal_model",
    output_model_path="resources/model/output/merged_model",
):
    """Merge LoRA weights into base model with default paths.

    This is a convenience function that calls merge_lora_weights with default
    paths that match the expected directory structure of the project.

    Args:
        base_model_path: Path to the base model. Defaults to "resources/L2/base_models".
        lora_adapter_path: Path to the LoRA adapter. Defaults to "resources/model/output/personal_model".
        output_model_path: Path to save the merged model. Defaults to "resources/model/output/merged_model".
    """
    merge_lora_weights(base_model_path, lora_adapter_path, output_model_path)
def parse_arguments():
    """Parse command line arguments for the script.

    Returns:
        argparse.Namespace: The parsed command line arguments.
    """
    parser = argparse.ArgumentParser(
        description="Merge LoRA weights into a base model."
    )
    parser.add_argument(
        "--base_model_path", type=str, required=True, help="Path to the base model."
    )
    parser.add_argument(
        "--lora_adapter_path", type=str, required=True, help="Path to the LoRA adapter."
    )
    parser.add_argument(
        "--output_model_path",
        type=str,
        required=True,
        help="Path to save the merged model.",
    )
    return parser.parse_args()
if __name__ == "__main__":
    args = parse_arguments()
    merge_lora_weights(
        args.base_model_path, args.lora_adapter_path, args.output_model_path
    )
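
# Example invocation. The script name "merge_weights.py" is illustrative (use this
# module's actual filename); the paths shown are the project defaults used by
# merge_model_weights above:
#
#   python merge_weights.py \
#       --base_model_path resources/L2/base_models \
#       --lora_adapter_path resources/model/output/personal_model \
#       --output_model_path resources/model/output/merged_model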