"""Merge a Qwen3-VL LoRA adapter into its base weights using Unsloth.

The adapter (together with its base model) is loaded with 4-bit and 8-bit
quantization disabled, merged and saved via ``save_pretrained_merged``, and
optionally reloaded for a quick load/generation sanity check.
"""

import argparse
import gc
import json
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional

# CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is initialized
# in this process, so it is assigned before importing torch/unsloth (the unsloth
# import may already query the GPU).
# NOTE(review): this hard-codes GPU 7 and overrides any value exported in the shell.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

import torch  # noqa: E402
from PIL import Image  # noqa: E402
from unsloth import FastVisionModel  # noqa: E402

DEFAULT_ADAPTER_PATH = "outputs/mimic_qwen3vl_lora_8bit_5/checkpoint-17454"
DEFAULT_OUTPUT_PATH = "outputs/mimic_qwen3vl_lora_8bit_5_merged"
DEFAULT_INSTRUCTION = "Analyze this chest X-ray image and generate the corresponding radiology report."

# Maps the --dtype CLI choices (other than "auto") to torch dtypes.
DTYPE_MAP = {
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    "float32": torch.float32,
}


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the merge and the optional verification step."""
    parser = argparse.ArgumentParser(
        description=(
            "Merge a Qwen3-VL LoRA adapter into base weights using Unsloth. "
            "This intentionally avoids loading 8-bit modules during merge to preserve accuracy."
        )
    )
    parser.add_argument("--adapter_path", type=str, default=DEFAULT_ADAPTER_PATH)
    parser.add_argument("--output_dir", type=str, default=DEFAULT_OUTPUT_PATH)
    parser.add_argument(
        "--save_method",
        type=str,
        default="merged_16bit",
        choices=["merged_16bit", "merged_4bit", "merged_4bit_forced"],
        help=(
            "Unsloth save mode. For best fidelity use merged_16bit. "
            "Use 4bit variants only if you explicitly need compact merged weights."
        ),
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "float16", "bfloat16", "float32"],
        help="Compute dtype used while loading model for merge.",
    )
    parser.add_argument("--device_map", type=str, default="auto")
    parser.add_argument("--cuda_visible_devices", type=str, default="")
    parser.add_argument("--max_shard_size", type=str, default="10GB")
    parser.add_argument("--verify_reload", action="store_true", help="Reload merged model after save.")
    parser.add_argument(
        "--verify_load_in_8bit",
        action="store_true",
        help="When verifying, reload merged model in 8-bit inference mode.",
    )
    parser.add_argument(
        "--verify_max_new_tokens",
        type=int,
        default=32,
        help="Max new tokens for optional verification generation.",
    )
    parser.add_argument(
        "--verify_instruction",
        type=str,
        default=DEFAULT_INSTRUCTION,
        help="Prompt text used for optional verification generation.",
    )
    parser.add_argument(
        "--skip_generate_check",
        action="store_true",
        help="If set, verification only checks loadability and skips generation.",
    )
    # Kept for backward compatibility: with default=True this flag is always on;
    # use --no_safe_serialization to turn safetensors off.
    parser.add_argument(
        "--safe_serialization",
        action="store_true",
        default=True,
        help=(
            "Save in safetensors format when supported (enabled by default; "
            "use --no_safe_serialization to disable)."
        ),
    )
    parser.add_argument(
        "--no_safe_serialization",
        action="store_true",
        help="Disable safetensors serialization.",
    )
    return parser.parse_args()


def resolve_dtype(dtype_arg: str) -> torch.dtype:
    """Return the torch dtype for ``dtype_arg``.

    "auto" resolves to bfloat16 on CUDA devices that support it, float16 on
    other CUDA devices, and float32 when CUDA is unavailable.
    """
    if dtype_arg != "auto":
        return DTYPE_MAP[dtype_arg]
    if torch.cuda.is_available():
        if torch.cuda.is_bf16_supported():
            return torch.bfloat16
        return torch.float16
    return torch.float32


def read_adapter_base_model(adapter_path: Path) -> Optional[str]:
    """Return ``base_model_name_or_path`` from the adapter's adapter_config.json.

    Returns None when the file is missing, unreadable, not valid JSON, not a
    JSON object, or lacks a non-empty ``base_model_name_or_path`` entry.
    """
    adapter_config_path = adapter_path / "adapter_config.json"
    if not adapter_config_path.exists():
        return None
    try:
        data = json.loads(adapter_config_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        return None
    if not isinstance(data, dict):
        # Valid JSON but not an object (e.g. a list); there is nothing to look up.
        return None
    base_model = data.get("base_model_name_or_path")
    return str(base_model) if base_model else None


def print_runtime_info(args: argparse.Namespace, merge_dtype: torch.dtype, adapter_path: Path, output_dir: Path) -> None:
    """Print a summary of the effective merge configuration."""
    print("=" * 88)
    print("Merge configuration")
    print("=" * 88)
    print(f"adapter_path : {adapter_path}")
    print(f"output_dir : {output_dir}")
    print(f"save_method : {args.save_method}")
    print(f"dtype : {merge_dtype}")
    print(f"device_map : {args.device_map}")
    print(f"safe_serialization : {args.safe_serialization and not args.no_safe_serialization}")
    print(f"verify_reload : {args.verify_reload}")
    print(f"verify_load_in_8bit : {args.verify_load_in_8bit}")
    print("=" * 88)


def save_merge_metadata(output_dir: Path, metadata: Dict[str, Any]) -> None:
    """Write ``metadata`` as indented JSON to ``output_dir/merge_metadata.json``."""
    output_dir.mkdir(parents=True, exist_ok=True)
    metadata_path = output_dir / "merge_metadata.json"
    metadata_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8")
    print(f"Wrote merge metadata: {metadata_path}")


def _release_cuda_memory() -> None:
    """Collect unreferenced Python objects, then return cached CUDA memory to the driver."""
    # gc first so tensors held only by reference cycles are actually freed
    # before empty_cache() tries to release the cached blocks.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def merge_adapter(args: argparse.Namespace) -> Path:
    """Load the adapter with its base model, merge it, save it, and return the output dir.

    Raises:
        FileNotFoundError: if ``args.adapter_path`` is not an existing directory.
    """
    adapter_path = Path(args.adapter_path).expanduser().resolve()
    output_dir = Path(args.output_dir).expanduser().resolve()
    if not adapter_path.exists() or not adapter_path.is_dir():
        raise FileNotFoundError(f"Adapter path does not exist or is not a directory: {adapter_path}")

    if args.cuda_visible_devices:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_visible_devices
        if torch.cuda.is_initialized():
            # Changing the variable after CUDA init does not remap devices in this process.
            print(
                "Warning: CUDA is already initialized; --cuda_visible_devices may not take "
                "effect. Export CUDA_VISIBLE_DEVICES before launching the script instead."
            )

    merge_dtype = resolve_dtype(args.dtype)
    print_runtime_info(args, merge_dtype, adapter_path, output_dir)

    inferred_base_model = read_adapter_base_model(adapter_path)
    if inferred_base_model:
        print(f"Adapter base model from adapter_config.json: {inferred_base_model}")

    print("Loading adapter (with base) in non-quantized mode for accurate merge...")
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=str(adapter_path),
        load_in_4bit=False,
        load_in_8bit=False,
        dtype=merge_dtype,
        device_map=args.device_map,
    )
    # Ensure inference graph after load to avoid training-mode side effects.
    FastVisionModel.for_inference(model)

    output_dir.mkdir(parents=True, exist_ok=True)
    safe_serialization = args.safe_serialization and not args.no_safe_serialization

    merge_start = time.time()
    print("Saving merged model...")
    save_kwargs_honored = True
    try:
        model.save_pretrained_merged(
            str(output_dir),
            tokenizer,
            save_method=args.save_method,
            safe_serialization=safe_serialization,
            max_shard_size=args.max_shard_size,
        )
    except TypeError as error:
        # Older Unsloth versions may not support all save kwargs. Retry without
        # them, but say so: the requested serialization/shard size is not applied.
        save_kwargs_honored = False
        print(
            f"Warning: save_pretrained_merged rejected extra kwargs ({error}); retrying "
            "without safe_serialization/max_shard_size, which will use library defaults."
        )
        model.save_pretrained_merged(
            str(output_dir),
            tokenizer,
            save_method=args.save_method,
        )
    merge_seconds = time.time() - merge_start
    print(f"Merge complete in {merge_seconds:.1f}s")

    save_merge_metadata(
        output_dir,
        {
            "adapter_path": str(adapter_path),
            "inferred_base_model": inferred_base_model,
            "output_dir": str(output_dir),
            "save_method": args.save_method,
            "merge_dtype": str(merge_dtype),
            "device_map": args.device_map,
            "safe_serialization": safe_serialization,
            "max_shard_size": args.max_shard_size,
            # False when the TypeError fallback dropped safe_serialization/max_shard_size.
            "save_kwargs_honored": save_kwargs_honored,
            "merged_at_unix": int(time.time()),
        },
    )

    del model, tokenizer
    _release_cuda_memory()
    return output_dir


def _run_generation_check(model: Any, tokenizer: Any, args: argparse.Namespace) -> None:
    """Generate from a blank synthetic image and print the (truncated) decoded output."""
    # Minimal generation sanity check with one synthetic image.
    test_image = Image.new("RGB", (224, 224), color=(0, 0, 0))
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": test_image},
                {"type": "text", "text": args.verify_instruction},
            ],
        }
    ]
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # The processor's image keyword/nesting differs across versions, so probe
    # the known variants and keep the first that tokenizes successfully.
    inputs = None
    tokenization_errors = []
    for image_argument in (
        {"images": [test_image]},
        {"images": [[test_image]]},
        {"image": [test_image]},
        {"image": [[test_image]]},
    ):
        try:
            inputs = tokenizer(
                text=[prompt_text],
                padding=True,
                return_tensors="pt",
                **image_argument,
            )
            break
        except Exception as error:  # pragma: no cover
            tokenization_errors.append(str(error))
    if inputs is None:
        raise RuntimeError("Tokenization failed during verify: " + " | ".join(tokenization_errors))

    model_device = next(model.parameters()).device
    inputs = {k: (v.to(model_device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max(1, args.verify_max_new_tokens),
            do_sample=False,
        )

    # Strip the prompt tokens so only newly generated tokens are decoded.
    input_token_count = inputs["input_ids"].shape[-1] if "input_ids" in inputs else 0
    generated_ids = outputs[:, input_token_count:]
    generated_text = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    print("Verification generation successful.")
    print("--- Verification Output (truncated to 500 chars) ---")
    print(generated_text[:500])


def verify_merged_model(args: argparse.Namespace, merged_dir: Path) -> None:
    """Optionally reload the merged model and run a small generation sanity check.

    Does nothing unless ``--verify_reload`` was given. Falls back to a non-8-bit
    load when 8-bit was requested but CUDA is unavailable. The reloaded model is
    released before returning.
    """
    if not args.verify_reload:
        return

    print("Reloading merged model for verification...")
    verify_load_in_8bit = bool(args.verify_load_in_8bit)
    if verify_load_in_8bit and not torch.cuda.is_available():
        print("CUDA unavailable. Falling back to non-8bit verification load.")
        verify_load_in_8bit = False

    model, tokenizer = FastVisionModel.from_pretrained(
        model_name=str(merged_dir),
        load_in_4bit=False,
        load_in_8bit=verify_load_in_8bit,
        device_map=args.device_map,
    )
    try:
        FastVisionModel.for_inference(model)
        if args.skip_generate_check:
            print("Verification load successful (generation check skipped).")
            return
        _run_generation_check(model, tokenizer, args)
    finally:
        del model, tokenizer
        _release_cuda_memory()


def main() -> None:
    """CLI entry point: validate args, merge the adapter, then optionally verify."""
    args = parse_args()
    if args.verify_max_new_tokens <= 0:
        raise ValueError("--verify_max_new_tokens must be > 0")
    merged_dir = merge_adapter(args)
    verify_merged_model(args, merged_dir)
    print(f"Merged model saved to: {merged_dir}")
    print("For production inference speed/memory, load merged model with load_in_8bit=True.")


if __name__ == "__main__":
    main()