File size: 6,935 Bytes
01d5a5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
"""Utility for merging LoRA weights into base language models.

This module provides functions to merge trained LoRA adapter weights with a base model,
producing a standalone model that incorporates the adaptations without needing the
LoRA architecture during inference.
"""

import argparse
import os
import gc
import sys
import logging
import traceback
import torch
import datetime
from typing import Optional, Dict, Any

from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from lpm_kernel.L2.memory_manager import get_memory_manager

# Configure logging
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def merge_lora_weights(base_model_path, lora_adapter_path, output_model_path):
    """Merge LoRA weights into a base model and save the result.

    Loads the base model and its tokenizer from ``base_model_path``, applies
    the LoRA adapter found at ``lora_adapter_path``, folds the adapter into
    the base weights with ``merge_and_unload()`` and writes the merged model
    and the tokenizer to ``output_model_path``.

    When the memory manager reports CUDA as available, the model is loaded in
    float16 with ``device_map="auto"``, a few inference hints are added to the
    model config, and a ``gpu_optimized.json`` marker file is written next to
    the saved weights. Otherwise the model is loaded in float32 with no
    device map.

    The local model references are dropped and ``gc.collect()`` /
    ``torch.cuda.empty_cache()`` are called when the function exits, on the
    success path as well as on the error path.

    Args:
        base_model_path: Path to the base model directory.
        lora_adapter_path: Path to the LoRA adapter directory.
        output_model_path: Path where the merged model will be saved.

    Raises:
        Exception: Any error raised while loading, merging or saving is
            logged together with its traceback and then re-raised unchanged.
    """
    import json  # Only needed for the GPU marker file written at the end.

    # Get memory manager
    memory_manager = get_memory_manager()

    # Pre-bind the model names so the ``finally`` clause can always release
    # them, even when loading fails before they are assigned.
    base_model = lora_model = merged_model = None

    try:
        # Log initial memory state
        memory_info = memory_manager.get_memory_info()
        logger.info(f"Initial memory state: RAM used: {memory_info['ram_used_gb']:.2f}GB, "
                   f"available: {memory_info['ram_available_gb']:.2f}GB")

        # Determine if CUDA is available and should be used
        use_cuda = memory_manager.cuda_available
        device = "cuda" if use_cuda else "cpu"

        if use_cuda:
            logger.info(f"CUDA is available. VRAM used: {memory_info.get('vram_used_gb', 0):.2f}GB")
        else:
            logger.warning("CUDA not available or not enabled. Using CPU for model operations.")

        # Clean up memory before starting
        memory_manager.cleanup_memory(force=True)

        # Explicitly set device configuration based on available hardware:
        # half precision with automatic placement on GPU, full precision on CPU.
        device_map = "auto" if use_cuda else None
        dtype = torch.float16 if use_cuda else torch.float32

        logger.info(f"Loading base model from {base_model_path} with device_map={device_map}, dtype={dtype}")

        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_path,
            torch_dtype=dtype,
            device_map=device_map
        )

        # The tokenizer is saved alongside the merged weights so the output
        # directory is usable on its own.
        tokenizer = AutoTokenizer.from_pretrained(base_model_path)

        # Load the LoRA adapter and apply it to the base model
        logger.info(f"Loading LoRA adapter from {lora_adapter_path}")
        lora_model = PeftModel.from_pretrained(base_model, lora_adapter_path)

        # Fold the adapter into the base weights; the returned model no
        # longer needs the LoRA wrapper.
        logger.info(f"Merging LoRA weights into base model on {device}")
        merged_model = lora_model.merge_and_unload()

        # Clean up before saving
        memory_manager.cleanup_memory()

        # Add inference optimization config to the merged model for faster startup
        if use_cuda:
            config = merged_model.config
            if hasattr(config, "torch_dtype"):
                config.torch_dtype = "float16"  # Prefer float16 for inference
            if not hasattr(config, "pretraining_tp"):
                config.pretraining_tp = 1  # For tensor parallelism during inference

            # Set default inference device
            if not hasattr(config, "_default_inference_device"):
                config._default_inference_device = "cuda:0"

            logger.info("Added GPU optimization settings to model configuration")

        # Save merged model with shard size to prevent OOM errors during save
        logger.info(f"Saving merged model to {output_model_path}")
        merged_model.save_pretrained(
            output_model_path,
            safe_serialization=True,
            max_shard_size="2GB"  # Sharded saving to avoid memory spikes
        )
        tokenizer.save_pretrained(output_model_path)

        # Save a special marker file to indicate this model should use GPU for inference
        if use_cuda:
            marker_path = os.path.join(output_model_path, "gpu_optimized.json")
            marker = {"gpu_optimized": True, "optimized_on": datetime.datetime.now().isoformat()}
            with open(marker_path, "w", encoding="utf-8") as f:
                json.dump(marker, f)
            logger.info("Added GPU optimization marker file for faster service startup")

        logger.info("Model successfully merged and saved!")

    except Exception as e:
        # One log record carrying both the message and the traceback.
        logger.exception(f"Error during model merge: {str(e)}")
        raise
    finally:
        # Drop the model references first; otherwise the collector and the
        # CUDA cache flush below have nothing they are allowed to release.
        base_model = lora_model = merged_model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def merge_model_weights(
    base_model_path="resources/L2/base_models",
    lora_adapter_path="resources/model/output/personal_model",
    output_model_path="resources/model/output/merged_model",
):
    """Run the LoRA merge using the project's default directory layout.

    Thin convenience wrapper: all three paths are forwarded unchanged to
    merge_lora_weights, so a caller can omit any argument to fall back to
    its default location.

    Args:
        base_model_path: Location of the base model.
            Defaults to "resources/L2/base_models".
        lora_adapter_path: Location of the LoRA adapter.
            Defaults to "resources/model/output/personal_model".
        output_model_path: Destination for the merged model.
            Defaults to "resources/model/output/merged_model".
    """
    merge_lora_weights(
        base_model_path=base_model_path,
        lora_adapter_path=lora_adapter_path,
        output_model_path=output_model_path,
    )


def parse_arguments():
    """Build the command line interface and parse the process arguments.

    Three string options are declared, all mandatory: ``--base_model_path``,
    ``--lora_adapter_path`` and ``--output_model_path``.

    Returns:
        argparse.Namespace: Namespace exposing ``base_model_path``,
        ``lora_adapter_path`` and ``output_model_path``.
    """
    cli = argparse.ArgumentParser(
        description="Merge LoRA weights into a base model."
    )

    # (flag, help text) pairs, registered in this order; every option is a
    # required string.
    required_options = (
        ("--base_model_path", "Path to the base model."),
        ("--lora_adapter_path", "Path to the LoRA adapter."),
        ("--output_model_path", "Path to save the merged model."),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    return cli.parse_args()


if __name__ == "__main__":
    # Script entry point: read the three required paths from the command
    # line and run the merge with them.
    args = parse_arguments()
    merge_lora_weights(
        base_model_path=args.base_model_path,
        lora_adapter_path=args.lora_adapter_path,
        output_model_path=args.output_model_path,
    )