secondme-api / lpm_kernel /L2 /merge_lora_weights.py
Gemini
feat: add detailed logging
01d5a5d
"""Utility for merging LoRA weights into base language models.
This module provides functions to merge trained LoRA adapter weights with a base model,
producing a standalone model that incorporates the adaptations without needing the
LoRA architecture during inference.
"""
import argparse
import os
import gc
import sys
import logging
import traceback
import torch
import datetime
from typing import Optional, Dict, Any
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from lpm_kernel.L2.memory_manager import get_memory_manager
# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
def merge_lora_weights(base_model_path, lora_adapter_path, output_model_path):
"""Merge LoRA weights into a base model and save the result.
This function loads a base model and a LoRA adapter, merges them together,
and saves the resulting model to the specified output path. It leverages
PyTorch's built-in memory management features.
Args:
base_model_path: Path to the base model directory.
lora_adapter_path: Path to the LoRA adapter directory.
output_model_path: Path where the merged model will be saved.
"""
# Get memory manager
memory_manager = get_memory_manager()
try:
# Log initial memory state
memory_info = memory_manager.get_memory_info()
logger.info(f"Initial memory state: RAM used: {memory_info['ram_used_gb']:.2f}GB, "
f"available: {memory_info['ram_available_gb']:.2f}GB")
# Determine if CUDA is available and should be used
use_cuda = memory_manager.cuda_available
device = "cuda" if use_cuda else "cpu"
if use_cuda:
logger.info(f"CUDA is available. VRAM used: {memory_info.get('vram_used_gb', 0):.2f}GB")
else:
logger.warning("CUDA not available or not enabled. Using CPU for model operations.")
# Clean up memory before starting
memory_manager.cleanup_memory(force=True)
# Explicitly set device configuration based on available hardware
device_map = "auto" if use_cuda else None
dtype = torch.float16 if use_cuda else torch.float32
logger.info(f"Loading base model from {base_model_path} with device_map={device_map}, dtype={dtype}")
# Use explicit configuration for GPU utilization
base_model = AutoModelForCausalLM.from_pretrained(
base_model_path,
torch_dtype=dtype,
device_map=device_map
)
# Load tokenizer - this doesn't consume much memory
tokenizer = AutoTokenizer.from_pretrained(base_model_path)
# Load the LoRA adapter and apply it to the base model
logger.info(f"Loading LoRA adapter from {lora_adapter_path}")
lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)
# Merge weights - this is done automatically by PyTorch on appropriate devices
logger.info(f"Merging LoRA weights into base model on {device}")
merged_model = lora_model.merge_and_unload()
# Clean up before saving
memory_manager.cleanup_memory()
# Add inference optimization config to the merged model for faster startup
if use_cuda:
# Set inference-specific configuration in model config
if hasattr(merged_model.config, "torch_dtype"):
merged_model.config.torch_dtype = "float16" # Prefer float16 for inference
if not hasattr(merged_model.config, "pretraining_tp"):
merged_model.config.pretraining_tp = 1 # For tensor parallelism during inference
# Set default inference device
if not hasattr(merged_model.config, "_default_inference_device"):
merged_model.config._default_inference_device = "cuda:0"
logger.info("Added GPU optimization settings to model configuration")
# Save merged model with shard size to prevent OOM errors during save
logger.info(f"Saving merged model to {output_model_path}")
merged_model.save_pretrained(
output_model_path,
safe_serialization=True,
max_shard_size="2GB" # Sharded saving to avoid memory spikes
)
tokenizer.save_pretrained(output_model_path)
# Save a special marker file to indicate this model should use GPU for inference
if use_cuda:
with open(os.path.join(output_model_path, "gpu_optimized.json"), "w") as f:
import json
json.dump({"gpu_optimized": True, "optimized_on": datetime.datetime.now().isoformat()}, f)
logger.info("Added GPU optimization marker file for faster service startup")
logger.info("Model successfully merged and saved!")
except Exception as e:
logger.error(f"Error during model merge: {str(e)}")
logger.error(traceback.format_exc())
# Force cleanup
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
raise
def merge_model_weights(
base_model_path="resources/L2/base_models",
lora_adapter_path="resources/model/output/personal_model",
output_model_path="resources/model/output/merged_model",
):
"""Merge LoRA weights into base model with default paths.
This is a convenience function that calls merge_lora_weights with default
paths that match the expected directory structure of the project.
Args:
base_model_path: Path to the base model. Defaults to "resources/L2/base_models".
lora_adapter_path: Path to the LoRA adapter. Defaults to "resources/model/output/personal_model".
output_model_path: Path to save the merged model. Defaults to "resources/model/output/merged_model".
"""
merge_lora_weights(base_model_path, lora_adapter_path, output_model_path)
def parse_arguments():
"""Parse command line arguments for the script.
Returns:
argparse.Namespace: The parsed command line arguments.
"""
parser = argparse.ArgumentParser(
description="Merge LoRA weights into a base model."
)
parser.add_argument(
"--base_model_path", type=str, required=True, help="Path to the base model."
)
parser.add_argument(
"--lora_adapter_path", type=str, required=True, help="Path to the LoRA adapter."
)
parser.add_argument(
"--output_model_path",
type=str,
required=True,
help="Path to save the merged model.",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
merge_lora_weights(
args.base_model_path, args.lora_adapter_path, args.output_model_path
)