File size: 11,171 Bytes
ed1b365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
#!/usr/bin/env python3
"""

Codette LoRA Adapter Merger
==============================

Merge one or more LoRA adapters into the base model to produce
a standalone fine-tuned model. Adapters are applied and merged
sequentially in the order specified.

Usage:
    python -m training.merge_adapters \
        --base-model meta-llama/Llama-3.1-8B-Instruct \
        --adapters adapters/newton/final adapters/davinci/final \
        --output merged_model

    python -m training.merge_adapters \
        --base-model meta-llama/Llama-3.1-8B-Instruct \
        --adapters adapters/rcxi/final \
        --output merged_model \
        --dtype bfloat16

"""

import argparse
import json
import logging
import os
import sys
import time
from datetime import datetime
from pathlib import Path

import torch


def setup_logging(output_dir: str) -> logging.Logger:
    """Configure logging for the merge process.

    Creates ``output_dir`` if it does not exist and attaches two handlers
    to the ``codette.merge`` logger: a DEBUG-level file handler writing to
    ``merge_<timestamp>.log`` inside ``output_dir`` and an INFO-level
    handler writing to stdout.

    Handlers left over from an earlier call are closed before they are
    detached, so calling this function repeatedly does not leak open
    log-file descriptors.

    Args:
        output_dir: Directory for log output.

    Returns:
        Configured logger instance.
    """
    log_dir = Path(output_dir)
    log_dir.mkdir(parents=True, exist_ok=True)

    logger = logging.getLogger("codette.merge")
    logger.setLevel(logging.DEBUG)

    # Detach AND close stale handlers: clearing the list alone would drop
    # the references while leaving the previous log file open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fh = logging.FileHandler(
        str(log_dir / f"merge_{timestamp}.log"), encoding="utf-8"
    )
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    ))
    logger.addHandler(fh)

    # Console output is limited to INFO and above; the file keeps DEBUG.
    ch = logging.StreamHandler(sys.stdout)
    ch.setLevel(logging.INFO)
    ch.setFormatter(logging.Formatter(
        "%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%H:%M:%S",
    ))
    logger.addHandler(ch)

    return logger


def resolve_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name onto the matching ``torch.dtype``.

    Args:
        dtype_str: A full name ('float32', 'float16', 'bfloat16') or its
            short alias ('fp32', 'fp16', 'bf16'). Matching is exact and
            case-sensitive.

    Returns:
        Corresponding torch.dtype.

    Raises:
        ValueError: If the string is not a recognized dtype.
    """
    # Each torch dtype is reachable through its long name and a short alias;
    # insertion order determines the order shown in the error message.
    known = {}
    for long_name, alias, torch_type in (
        ("float32", "fp32", torch.float32),
        ("float16", "fp16", torch.float16),
        ("bfloat16", "bf16", torch.bfloat16),
    ):
        known[long_name] = torch_type
        known[alias] = torch_type

    resolved = known.get(dtype_str)
    if resolved is not None:
        return resolved

    raise ValueError(
        f"Unknown dtype: {dtype_str}. "
        f"Choose from: {list(known)}"
    )


def validate_adapter_paths(adapter_paths: list[str], logger: logging.Logger) -> None:
    """Check every adapter path before any expensive work starts.

    Paths are checked in the order given; the first invalid one aborts
    the check. Each path that passes is logged at INFO level.

    Args:
        adapter_paths: List of adapter directory paths.
        logger: Logger instance.

    Raises:
        FileNotFoundError: If a path does not exist or does not contain
            an adapter_config.json file.
    """
    for candidate in adapter_paths:
        root = Path(candidate)

        problem = None
        if not root.exists():
            problem = f"Adapter directory not found: {candidate}"
        elif not root.joinpath("adapter_config.json").exists():
            # adapter_config.json is the marker file of a PEFT adapter
            problem = (
                f"No adapter_config.json found in {candidate}. "
                f"Is this a valid PEFT adapter directory?"
            )

        if problem is not None:
            raise FileNotFoundError(problem)

        logger.info(f"Validated adapter: {candidate}")


def load_base_model(
    model_name: str,
    dtype: torch.dtype,
    device_map: str,
    logger: logging.Logger,
):
    """Load the base causal-LM and its tokenizer ahead of merging.

    The tokenizer is loaded first; when it defines no pad token, the EOS
    token (and its id) is reused as the pad token. The parameter count of
    the loaded model is logged.

    Args:
        model_name: HuggingFace model identifier.
        dtype: Torch dtype for model weights.
        device_map: Device map strategy.
        logger: Logger instance.

    Returns:
        Tuple of (model, tokenizer).
    """
    # Imported lazily so the module can be imported without transformers.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    logger.info(f"Loading base model: {model_name}")
    logger.info(f"  dtype: {dtype}, device_map: {device_map}")

    # NOTE(review): trust_remote_code=True is hard-coded for both loads;
    # confirm that only trusted model repositories are passed in.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id

    load_options = {
        "torch_dtype": dtype,
        "device_map": device_map,
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

    total_params = 0
    for tensor in model.parameters():
        total_params += tensor.numel()
    logger.info(f"Base model loaded: {total_params:,} parameters")

    return model, tokenizer


def apply_and_merge_adapter(
    model,
    adapter_path: str,
    adapter_index: int,
    total_adapters: int,
    logger: logging.Logger,
):
    """Apply a single LoRA adapter and merge it into the base weights.

    The incoming model is wrapped with ``PeftModel.from_pretrained`` and
    then folded back into a plain model with ``merge_and_unload``. Every
    adapter takes this same path regardless of its position: the model
    returned by a previous merge is no longer a ``PeftModel``, so it has
    to be wrapped again before the next adapter can be merged.

    Args:
        model: The current model (base or previously merged).
        adapter_path: Path to the PEFT adapter directory.
        adapter_index: 1-based index of this adapter (for logging).
        total_adapters: Total number of adapters to merge.
        logger: Logger instance.

    Returns:
        Model with the adapter merged in.

    Raises:
        OSError: If adapter_config.json cannot be opened.
        json.JSONDecodeError: If adapter_config.json is not valid JSON.
    """
    from peft import PeftModel

    adapter_dir = Path(adapter_path)
    # The usage examples lay adapters out as <name>/final, so the parent
    # directory carries the readable name; fall back to the leaf directory
    # when the path has no named parent (e.g. a bare "final").
    adapter_name = adapter_dir.parent.name or adapter_dir.name
    logger.info(
        f"[{adapter_index}/{total_adapters}] "
        f"Applying adapter: {adapter_name} ({adapter_path})"
    )

    # Read the adapter config purely to log its key hyperparameters.
    config_path = adapter_dir / "adapter_config.json"
    with open(config_path, "r", encoding="utf-8") as f:
        adapter_config = json.load(f)

    lora_rank = adapter_config.get("r", "unknown")
    lora_alpha = adapter_config.get("lora_alpha", "unknown")
    target_modules = adapter_config.get("target_modules", [])

    logger.info(
        f"  LoRA config: rank={lora_rank}, alpha={lora_alpha}, "
        f"modules={target_modules}"
    )

    # Wrap afresh for every adapter. The previous implementation only
    # wrapped the first one and then called merge_and_unload() on the
    # already-unwrapped model for later adapters, which fails because the
    # plain model has no such method.
    model = PeftModel.from_pretrained(
        model,
        adapter_path,
        is_trainable=False,
    )

    # Merge adapter weights into base model
    logger.info("  Merging adapter weights into base model...")
    model = model.merge_and_unload()

    param_count = sum(p.numel() for p in model.parameters())
    logger.info(f"  Merged successfully. Model params: {param_count:,}")

    return model


def save_merged_model(
    model,
    tokenizer,
    output_dir: str,
    logger: logging.Logger,
) -> None:
    """Write the merged model and tokenizer to disk and log the size.

    The model is saved with safetensors serialization. The logged size is
    the combined size of the ``*.safetensors`` and ``*.bin`` files found
    directly inside ``output_dir`` after saving.

    Args:
        model: The merged model.
        tokenizer: The tokenizer.
        output_dir: Directory to save the model.
        logger: Logger instance.
    """
    target = Path(output_dir)
    target.mkdir(parents=True, exist_ok=True)

    logger.info(f"Saving merged model to: {output_dir}")

    model.save_pretrained(output_dir, safe_serialization=True)
    tokenizer.save_pretrained(output_dir)

    # Add up the on-disk size of all weight files in the output directory
    weight_bytes = sum(
        weight_file.stat().st_size
        for pattern in ("*.safetensors", "*.bin")
        for weight_file in target.glob(pattern)
    )

    logger.info(f"Model saved: {weight_bytes / (1024 ** 3):.2f} GB")


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``base_model``, ``adapters``, ``output``, ``dtype``
        and ``device_map`` attributes.
    """
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="Merge LoRA adapters into the base model",
    )

    # One (flag, add_argument keyword arguments) entry per option, in the
    # order the options appear in --help.
    option_table = (
        ("--base-model", {
            "type": str,
            "default": "meta-llama/Llama-3.1-8B-Instruct",
            "help": "Base model to merge adapters into",
        }),
        ("--adapters", {
            "nargs": "+",
            "required": True,
            "help": "Paths to PEFT adapter directories (applied in order)",
        }),
        ("--output", {
            "type": str,
            "required": True,
            "help": "Output directory for merged model",
        }),
        ("--dtype", {
            "type": str,
            "default": "bfloat16",
            "choices": ["float32", "fp32", "float16", "fp16", "bfloat16", "bf16"],
            "help": "Model dtype for merging",
        }),
        ("--device-map", {
            "type": str,
            "default": "auto",
            "help": "Device map strategy (auto, cpu, cuda:0, etc.)",
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()


def main():
    """Run the merge pipeline from the command line.

    Parses the CLI arguments, sets up logging in the output directory,
    validates the adapter directories, loads the base model, merges each
    adapter in the order given, saves the result and writes
    ``merge_metadata.json`` into the output directory. Exits with status 1
    when adapter validation fails or when any later step raises.
    """
    args = parse_args()
    adapter_total = len(args.adapters)

    logger = setup_logging(args.output)
    for banner_line in (
        "=== Codette LoRA Adapter Merger ===",
        f"Base model: {args.base_model}",
        f"Adapters to merge ({adapter_total}): {args.adapters}",
        f"Output: {args.output}",
        f"dtype: {args.dtype}",
    ):
        logger.info(banner_line)

    dtype = resolve_dtype(args.dtype)

    # Fail before loading the model if any adapter directory is unusable
    try:
        validate_adapter_paths(args.adapters, logger)
    except FileNotFoundError as missing:
        logger.error(str(missing))
        sys.exit(1)

    started = time.time()

    try:
        model, tokenizer = load_base_model(
            args.base_model, dtype, args.device_map, logger
        )

        # Merge adapters one after another, numbering them from 1
        for position, adapter_path in enumerate(args.adapters, start=1):
            model = apply_and_merge_adapter(
                model=model,
                adapter_path=adapter_path,
                adapter_index=position,
                total_adapters=adapter_total,
                logger=logger,
            )

        save_merged_model(model, tokenizer, args.output, logger)

        elapsed = time.time() - started

        # Record what was merged alongside the saved model
        metadata = {
            "base_model": args.base_model,
            "adapters_merged": args.adapters,
            "adapter_count": adapter_total,
            "dtype": args.dtype,
            "merge_time_seconds": elapsed,
            "timestamp": datetime.now().isoformat(),
        }
        metadata_file = Path(args.output) / "merge_metadata.json"
        metadata_file.write_text(
            json.dumps(metadata, indent=2), encoding="utf-8"
        )

        logger.info(f"=== Merge complete in {elapsed:.1f}s ===")
        logger.info(f"Merged model saved to: {args.output}")

    except Exception as e:
        # Top-level boundary: log the full traceback, then exit non-zero
        logger.error(f"Merge failed: {e}", exc_info=True)
        sys.exit(1)


# Run the merge only when this file is executed as a script (directly or
# via ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()