#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "torch",
# "transformers>=4.40.0",
# "peft>=0.10.0",
# "accelerate",
# "bitsandbytes",
# "huggingface_hub>=0.21.0",
# ]
# ///
"""
Merge Stage 4 (Unified) adapter into base model.
Stage 4 is trained on ALL tasks, so it can handle:
- Point localization
- Bounding box detection
- Classification
- Free-form queries
Run with: hf jobs uv run --flavor a10g-large --secrets HF_TOKEN merge_stage4_adapter.py
"""
import os
import torch
from pathlib import Path
# ============================================================
# Config
# ============================================================
UNIFIED_MODEL = "mmrech/pitvqa-qwen2vl-unified-v2"  # Hub repo holding the Stage 4 LoRA adapter (in subfolder "stage4")
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"  # base model the adapter was trained on top of
OUTPUT_REPO = "mmrech/pitvqa-qwen2vl-merged"  # destination Hub repo for the merged weights
# ============================================================
# Setup
# ============================================================
from huggingface_hub import login, HfApi

# Authenticate with the Hub so the upload step can push to OUTPUT_REPO.
# The merge itself works without a token, but the upload will fail, so
# warn early instead of failing silently at the end of the job.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("[OK] Logged in to HuggingFace")
else:
    print("[WARN] HF_TOKEN not set - upload to the Hub will likely fail")
api = HfApi()
# ============================================================
# Load and Merge
# ============================================================
print("\nLoading base model...")
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel

# Load the base model in bf16 (matches the training dtype; full fp32 is
# not required for a LoRA merge and would double memory).
base = Qwen2VLForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
print("[OK] Base model loaded")

# Processor (tokenizer + image preprocessor) from the base repo; it is
# re-saved alongside the merged weights below.
processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)
print("[OK] Processor loaded")

# Load the Stage 4 (Unified) adapter - trained on all tasks, so the merged
# model needs no adapter switching. The adapter lives in the "stage4"
# subfolder of the adapter repo.
print("\nLoading Stage 4 (Unified) adapter...")
model = PeftModel.from_pretrained(
    base,
    UNIFIED_MODEL,
    adapter_name="stage4",
    subfolder="stage4",
)
print("[OK] Stage 4 adapter loaded")

# Fold the LoRA deltas into the base weights and drop the PEFT wrappers,
# leaving a plain Qwen2VLForConditionalGeneration.
print("\nMerging adapter...")
merged_model = model.merge_and_unload()
print("[OK] Adapter merged")
# ============================================================
# Save and Upload
# ============================================================
print("\nSaving merged model...")
output_dir = Path("./pitvqa-merged")
# parents=True keeps this robust if the job runs from a fresh workdir.
output_dir.mkdir(parents=True, exist_ok=True)
merged_model.save_pretrained(output_dir)
# Ship the tokenizer/image-processor files with the weights so the repo
# is loadable with AutoProcessor alone.
processor.save_pretrained(output_dir)
print(f"[OK] Saved to {output_dir}")
# Create the model card that is uploaded as README.md with the weights.
model_card = """---
license: apache-2.0
base_model: Qwen/Qwen2-VL-2B-Instruct
tags:
- medical
- vision-language
- surgical-ai
- pituitary-surgery
- qwen2-vl
- merged-adapter
---
# PitVQA Merged Model
A **merged** version of the PitVQA unified model for pituitary surgery understanding.
## Model Description
This model merges the Stage 4 (Unified) LoRA adapter into the Qwen2-VL-2B base model.
It can handle ALL tasks without adapter switching:
- **Point Localization**: `suction device`
- **Bounding Box**: `tumor region`
- **Classification**: Surgical phase identification
- **Free-form queries**: Any question about the surgical scene
## Usage
```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch
model = Qwen2VLForConditionalGeneration.from_pretrained(
"mmrech/pitvqa-qwen2vl-merged",
torch_dtype=torch.bfloat16,
device_map="auto"
)
processor = AutoProcessor.from_pretrained("mmrech/pitvqa-qwen2vl-merged")
# No adapter switching needed - just inference
messages = [{"role": "user", "content": [
{"type": "image", "image": your_image},
{"type": "text", "text": "Point to the suction device"}
]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[your_image], return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))
```
## Source
- Base: `Qwen/Qwen2-VL-2B-Instruct`
- Adapter source: `mmrech/pitvqa-qwen2vl-unified-v2` (Stage 4)
- Training dataset: `mmrech/pitvqa-comprehensive-spatial`
"""
# Explicit UTF-8: the default open() encoding is platform-dependent.
(output_dir / "README.md").write_text(model_card, encoding="utf-8")
print("[OK] Created README.md")
# Upload the local output directory to the Hugging Face Hub.
print(f"\nUploading to {OUTPUT_REPO}...")
try:
    # Create the target repo if it does not exist yet (no-op otherwise).
    api.create_repo(OUTPUT_REPO, exist_ok=True)
    # Push weights, processor files, and README in a single call.
    api.upload_folder(
        folder_path=str(output_dir),
        repo_id=OUTPUT_REPO,
        repo_type="model"
    )
    print(f"[OK] Uploaded to https://huggingface.co/{OUTPUT_REPO}")
except Exception as e:
    # Best-effort: a failed upload should not crash the job - the merged
    # model is still saved locally and the verification below can run.
    print(f"[ERROR] Upload error: {e}")
# ============================================================
# Verify
# ============================================================
print("\nVerifying merged model...")
from PIL import Image
import numpy as np

# Smoke test: run one greedy generation on a random image to confirm the
# merged weights load and produce output (the content is not checked).
test_image = Image.fromarray(np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
messages = [{"role": "user", "content": [
    {"type": "image", "image": test_image},
    {"type": "text", "text": "What do you see in this image?"}
]}]
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[test_image], return_tensors="pt").to(merged_model.device)
with torch.no_grad():
    output = merged_model.generate(**inputs, max_new_tokens=50, do_sample=False)
response = processor.decode(output[0], skip_special_tokens=True)
print(f"Test response: {response[:200]}...")
# NOTE: the original final print was corrupted (an unterminated string
# literal split across two lines by a mangled emoji) - a SyntaxError.
print("\n[OK] DONE! Merged model available at:")
print(f"  https://huggingface.co/{OUTPUT_REPO}")