| import argparse |
| import json |
| import os |
| import time |
| from pathlib import Path |
| from typing import Any, Dict, Optional |
|
|
| import torch |
| from PIL import Image |
| from unsloth import FastVisionModel |
|
|
# Default GPU selection for this script. setdefault() keeps a value the user
# already exported in the shell instead of unconditionally overwriting it; the
# --cuda_visible_devices CLI flag can still override it at merge time.
# NOTE(review): this runs after the torch/unsloth imports, so it presumably only
# takes effect if CUDA has not been initialized by then — confirm, or move it
# above those imports.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
|
|
|
|
# Fallback values used when the matching command-line flags are omitted.
# Directory holding the LoRA adapter checkpoint (default for --adapter_path).
DEFAULT_ADAPTER_PATH = (
    "outputs/mimic_qwen3vl_lora_8bit_5/checkpoint-17454"
)
# Directory the merged model is written to (default for --output_dir).
DEFAULT_OUTPUT_PATH = (
    "outputs/mimic_qwen3vl_lora_8bit_5_merged"
)
# Prompt used for the optional post-merge generation check
# (default for --verify_instruction).
DEFAULT_INSTRUCTION = (
    "Analyze this chest X-ray image and "
    "generate the corresponding radiology report."
)
|
|
|
|
# Translates an explicit --dtype choice into the torch dtype of the same name.
# "auto" is deliberately absent: it is resolved from the hardware at runtime.
DTYPE_MAP = {
    name: getattr(torch, name) for name in ("float16", "bfloat16", "float32")
}
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse the process arguments.

    Returns:
        The parsed namespace. Besides the path/dtype/save options it carries
        the ``verify_*`` flags for the optional reload check and two
        serialization switches: ``safe_serialization`` (always True, because
        the flag is ``store_true`` with ``default=True``) and
        ``no_safe_serialization`` (the actual off switch).
    """
    cli = argparse.ArgumentParser(
        description=(
            "Merge a Qwen3-VL LoRA adapter into base weights using Unsloth. "
            "This intentionally avoids loading 8-bit modules during merge "
            "to preserve accuracy."
        )
    )
    add = cli.add_argument

    # --- Where to read the adapter from, where to write, and how to save.
    add("--adapter_path", type=str, default=DEFAULT_ADAPTER_PATH)
    add("--output_dir", type=str, default=DEFAULT_OUTPUT_PATH)
    save_method_help = (
        "Unsloth save mode. For best fidelity use merged_16bit. "
        "Use 4bit variants only if you explicitly need compact merged weights."
    )
    add(
        "--save_method",
        type=str,
        default="merged_16bit",
        choices=["merged_16bit", "merged_4bit", "merged_4bit_forced"],
        help=save_method_help,
    )

    # --- Load-time options for the merge itself.
    add(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Compute dtype used while loading model for merge.",
    )
    add("--device_map", type=str, default="auto")
    add("--cuda_visible_devices", type=str, default="")
    add("--max_shard_size", type=str, default="10GB")

    # --- Optional post-merge verification.
    add(
        "--verify_reload",
        action="store_true",
        help="Reload merged model after save.",
    )
    add(
        "--verify_load_in_8bit",
        action="store_true",
        help="When verifying, reload merged model in 8-bit inference mode.",
    )
    add(
        "--verify_max_new_tokens",
        type=int,
        default=32,
        help="Max new tokens for optional verification generation.",
    )
    add(
        "--verify_instruction",
        type=str,
        default=DEFAULT_INSTRUCTION,
        help="Prompt text used for optional verification generation.",
    )
    add(
        "--skip_generate_check",
        action="store_true",
        help="If set, verification only checks loadability and skips generation.",
    )

    # --- Serialization format. The positive flag already defaults to True,
    # so only --no_safe_serialization changes the outcome.
    add(
        "--safe_serialization",
        action="store_true",
        default=True,
        help="Save in safetensors format when supported.",
    )
    add(
        "--no_safe_serialization",
        action="store_true",
        help="Disable safetensors serialization.",
    )

    return cli.parse_args()
|
|
|
|
def resolve_dtype(dtype_arg: str) -> torch.dtype:
    """Turn the ``--dtype`` choice into a concrete torch dtype.

    Args:
        dtype_arg: ``"auto"`` or a key of ``DTYPE_MAP``.

    Returns:
        The mapped dtype for an explicit choice. For ``"auto"``: bfloat16 on
        CUDA devices that report bf16 support, float16 on other CUDA devices,
        and float32 when CUDA is unavailable.

    Raises:
        KeyError: ``dtype_arg`` is neither ``"auto"`` nor a ``DTYPE_MAP`` key.
    """
    explicit_choice = dtype_arg != "auto"
    if explicit_choice:
        return DTYPE_MAP[dtype_arg]

    # No GPU: fall back to full precision.
    if not torch.cuda.is_available():
        return torch.float32

    # On GPU, prefer bf16 when the device reports support for it.
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
|
|
|
def read_adapter_base_model(adapter_path: Path) -> Optional[str]:
    """Read the base model name recorded in the adapter's ``adapter_config.json``.

    This lookup is best-effort and never raises for a bad or missing config.

    Args:
        adapter_path: Directory expected to contain ``adapter_config.json``.

    Returns:
        ``str(base_model_name_or_path)`` when the config is a JSON object with
        a truthy value under that key; otherwise ``None`` (file missing or
        unreadable, invalid JSON, JSON that is not an object, or an
        empty/absent value).
    """
    adapter_config_path = adapter_path / "adapter_config.json"

    try:
        data = json.loads(adapter_config_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing or unreadable file. ValueError: malformed JSON
        # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError),
        # both of which subclass ValueError.
        return None

    # Valid JSON is not necessarily an object; a list, null or string has no .get().
    if not isinstance(data, dict):
        return None

    base_model = data.get("base_model_name_or_path")
    return str(base_model) if base_model else None
|
|
|
|
def print_runtime_info(args: argparse.Namespace, merge_dtype: torch.dtype, adapter_path: Path, output_dir: Path) -> None:
    """Print a banner summarizing the settings the merge is about to use.

    Args:
        args: Parsed CLI namespace; reads ``save_method``, ``device_map``,
            ``safe_serialization``, ``no_safe_serialization``,
            ``verify_reload`` and ``verify_load_in_8bit``.
        merge_dtype: Dtype the model will be loaded with.
        adapter_path: Adapter directory.
        output_dir: Destination directory for the merged model.
    """
    rule = "=" * 88

    def show(label: str, value: Any) -> None:
        # One "label : value" row; values are rendered with str().
        print(f"{label} : {value}")

    print(rule)
    print("Merge configuration")
    print(rule)
    show("adapter_path", adapter_path)
    show("output_dir", output_dir)
    show("save_method", args.save_method)
    show("dtype", merge_dtype)
    show("device_map", args.device_map)
    # Effective setting: the negative flag wins over the always-True positive one.
    show("safe_serialization", args.safe_serialization and not args.no_safe_serialization)
    show("verify_reload", args.verify_reload)
    show("verify_load_in_8bit", args.verify_load_in_8bit)
    print(rule)
|
|
|
|
def save_merge_metadata(output_dir: Path, metadata: Dict[str, Any]) -> None:
    """Write ``metadata`` as indented JSON to ``<output_dir>/merge_metadata.json``.

    The directory (including parents) is created when missing, and the path
    of the written file is printed afterwards.

    Args:
        output_dir: Directory that receives ``merge_metadata.json``.
        metadata: JSON-serializable mapping to persist.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / "merge_metadata.json"

    # Serialize before opening the file so a serialization error cannot
    # leave a truncated file behind.
    payload = json.dumps(metadata, indent=2)
    with target.open("w", encoding="utf-8") as handle:
        handle.write(payload)

    print(f"Wrote merge metadata: {target}")
|
|
|
|
def merge_adapter(args: argparse.Namespace) -> Path:
    """Load the LoRA adapter un-quantized, merge it into the base weights and save.

    Args:
        args: Parsed CLI namespace. Reads ``adapter_path``, ``output_dir``,
            ``cuda_visible_devices``, ``dtype``, ``device_map``,
            ``save_method``, ``safe_serialization``,
            ``no_safe_serialization`` and ``max_shard_size`` (plus the fields
            ``print_runtime_info`` reports).

    Returns:
        The resolved output directory containing the merged model and a
        ``merge_metadata.json`` describing how it was produced.

    Raises:
        FileNotFoundError: ``args.adapter_path`` is not an existing directory.
    """
    import gc  # local import: only needed for the cleanup at the end

    adapter_path = Path(args.adapter_path).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()

    # is_dir() is already False for a path that does not exist.
    if not adapter_path.is_dir():
        raise FileNotFoundError(f"Adapter path does not exist or is not a directory: {adapter_path}")

    if args.cuda_visible_devices:
        # NOTE(review): presumably only effective if CUDA has not been
        # initialized yet in this process — confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices

    merge_dtype = resolve_dtype(args.dtype)
    print_runtime_info(args, merge_dtype, adapter_path, output_dir)

    inferred_base_model = read_adapter_base_model(adapter_path)
    if inferred_base_model:
        print(f"Adapter base model from adapter_config.json: {inferred_base_model}")

    print("Loading adapter (with base) in non-quantized mode for accurate merge...")
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=str(adapter_path),
        load_in_4bit=False,
        load_in_8bit=False,
        dtype=merge_dtype,
        device_map=args.device_map,
    )

    # Switch to Unsloth's inference mode before merging (kept from the original flow).
    FastVisionModel.for_inference(model)

    output_dir.mkdir(parents=True, exist_ok=True)

    # The negative flag wins; the positive one is always True.
    safe_serialization = args.safe_serialization and not args.no_safe_serialization
    merge_start = time.time()

    print("Saving merged model...")
    save_kwargs_dropped = False
    try:
        model.save_pretrained_merged(
            str(output_dir),
            tokenizer,
            save_method=args.save_method,
            safe_serialization=safe_serialization,
            max_shard_size=args.max_shard_size,
        )
    except TypeError as error:
        # Best-effort fallback for save_pretrained_merged variants that reject
        # the extra keyword arguments. Print the reason and remember that the
        # requested serialization settings were NOT applied.
        save_kwargs_dropped = True
        print(
            "save_pretrained_merged raised TypeError "
            f"({error}); retrying without safe_serialization/max_shard_size."
        )
        model.save_pretrained_merged(
            str(output_dir),
            tokenizer,
            save_method=args.save_method,
        )

    merge_seconds = time.time() - merge_start
    print(f"Merge complete in {merge_seconds:.1f}s")

    save_merge_metadata(
        output_dir,
        {
            "adapter_path": str(adapter_path),
            "inferred_base_model": inferred_base_model,
            "output_dir": str(output_dir),
            "save_method": args.save_method,
            "merge_dtype": str(merge_dtype),
            "device_map": args.device_map,
            # Requested value; see save_kwargs_dropped for whether it was applied.
            "safe_serialization": safe_serialization,
            "save_kwargs_dropped": save_kwargs_dropped,
            "merged_at_unix": int(time.time()),
        },
    )

    # Drop both references and collect cycles so the cached GPU memory can be
    # released before an optional verification reload.
    del model, tokenizer
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return output_dir
|
|
|
|
def verify_merged_model(args: argparse.Namespace, merged_dir: Path) -> None:
    """Optionally reload the merged checkpoint and run a short smoke-test generation.

    Does nothing unless ``args.verify_reload`` is set. With
    ``args.skip_generate_check`` the function stops after a successful load.
    Otherwise it feeds a blank 224x224 image plus ``args.verify_instruction``
    through the model and prints the first 500 characters of the decoded output.

    Args:
        args: Parsed CLI namespace; reads ``verify_reload``,
            ``verify_load_in_8bit``, ``device_map``, ``skip_generate_check``,
            ``verify_instruction`` and ``verify_max_new_tokens``.
        merged_dir: Directory containing the merged model.

    Raises:
        RuntimeError: Every attempted image keyword/nesting variant failed
            while encoding the verification inputs.
    """
    if not args.verify_reload:
        return

    print("Reloading merged model for verification...")

    # 8-bit reload is only attempted when CUDA is available.
    use_8bit = bool(args.verify_load_in_8bit)
    if use_8bit and not torch.cuda.is_available():
        print("CUDA unavailable. Falling back to non-8bit verification load.")
        use_8bit = False

    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=str(merged_dir),
        load_in_4bit=False,
        load_in_8bit=use_8bit,
        device_map=args.device_map,
    )
    FastVisionModel.for_inference(model)

    if args.skip_generate_check:
        print("Verification load successful (generation check skipped).")
        return

    # Synthetic all-black image: the check only proves the pipeline runs,
    # not that the output is meaningful.
    probe_image = Image.new("RGB", (224, 224), color=(0, 0, 0))
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": probe_image},
            {"type": "text", "text": args.verify_instruction},
        ],
    }
    prompt_text = tokenizer.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

    # Try several keyword names and nesting levels for the image argument,
    # keeping the first call that succeeds.
    # NOTE(review): presumably this accommodates differing processor
    # signatures across library versions — confirm.
    candidate_image_kwargs = (
        {"images": [probe_image]},
        {"images": [[probe_image]]},
        {"image": [probe_image]},
        {"image": [[probe_image]]},
    )
    encoded = None
    failures = []
    for image_kwargs in candidate_image_kwargs:
        try:
            encoded = tokenizer(
                text=[prompt_text],
                padding=True,
                return_tensors="pt",
                **image_kwargs,
            )
        except Exception as error:
            failures.append(str(error))
        else:
            break

    if encoded is None:
        raise RuntimeError("Tokenization failed during verify: " + " | ".join(failures))

    # Move tensor inputs to the device of the model's first parameter;
    # non-tensor entries are passed through untouched.
    target_device = next(model.parameters()).device
    model_inputs = {}
    for key, value in encoded.items():
        if isinstance(value, torch.Tensor):
            model_inputs[key] = value.to(target_device)
        else:
            model_inputs[key] = value

    with torch.inference_mode():
        outputs = model.generate(
            **model_inputs,
            max_new_tokens=max(1, args.verify_max_new_tokens),
            do_sample=False,
        )

    # Strip the prompt tokens so only newly generated text is decoded.
    if "input_ids" in model_inputs:
        prompt_length = model_inputs["input_ids"].shape[-1]
    else:
        prompt_length = 0
    decoded = tokenizer.batch_decode(
        outputs[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    generated_text = decoded[0]

    print("Verification generation successful.")
    print("--- Verification Output (truncated to 500 chars) ---")
    print(generated_text[:500])
|
|
|
|
def main() -> None:
    """CLI entry point: parse flags, merge the adapter, then optionally verify.

    Raises:
        ValueError: ``--verify_max_new_tokens`` is not a positive integer.
    """
    cli_args = parse_args()

    # Reject a non-positive generation budget before any expensive work starts.
    if cli_args.verify_max_new_tokens <= 0:
        raise ValueError("--verify_max_new_tokens must be > 0")

    output_location = merge_adapter(cli_args)
    verify_merged_model(cli_args, output_location)

    closing_lines = (
        f"Merged model saved to: {output_location}",
        "For production inference speed/memory, load merged model with load_in_8bit=True.",
    )
    for line in closing_lines:
        print(line)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|