mmrech committed on
Commit
2e3edd5
·
verified ·
1 Parent(s): dc287cc

Upload merge_stage4_adapter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. merge_stage4_adapter.py +207 -0
merge_stage4_adapter.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "torch",
#     "transformers>=4.40.0",
#     "peft>=0.10.0",
#     "accelerate",
#     "bitsandbytes",
#     "huggingface_hub>=0.21.0",
# ]
# ///
"""
Merge Stage 4 (Unified) adapter into base model.

Stage 4 is trained on ALL tasks, so it can handle:
- Point localization
- Bounding box detection
- Classification
- Free-form queries

Run with: hf jobs uv run --flavor a10g-large --secrets HF_TOKEN merge_stage4_adapter.py
"""

import os
from pathlib import Path

import torch

# ============================================================
# Config
# ============================================================

UNIFIED_MODEL = "mmrech/pitvqa-qwen2vl-unified-v2"
BASE_MODEL = "Qwen/Qwen2-VL-2B-Instruct"
OUTPUT_REPO = "mmrech/pitvqa-qwen2vl-merged"

# Model card written next to the merged weights and pushed to the Hub.
MODEL_CARD = """---
license: apache-2.0
base_model: Qwen/Qwen2-VL-2B-Instruct
tags:
- medical
- vision-language
- surgical-ai
- pituitary-surgery
- qwen2-vl
- merged-adapter
---

# PitVQA Merged Model

A **merged** version of the PitVQA unified model for pituitary surgery understanding.

## Model Description

This model merges the Stage 4 (Unified) LoRA adapter into the Qwen2-VL-2B base model.
It can handle ALL tasks without adapter switching:

- **Point Localization**: `<point x='45.2' y='68.3'>suction device</point>`
- **Bounding Box**: `<box x1='20' y1='30' x2='60' y2='70'>tumor region</box>`
- **Classification**: Surgical phase identification
- **Free-form queries**: Any question about the surgical scene

## Usage

```python
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
import torch

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "mmrech/pitvqa-qwen2vl-merged",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)
processor = AutoProcessor.from_pretrained("mmrech/pitvqa-qwen2vl-merged")

# No adapter switching needed - just inference
messages = [{"role": "user", "content": [
    {"type": "image", "image": your_image},
    {"type": "text", "text": "Point to the suction device"}
]}]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=[text], images=[your_image], return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=128)
print(processor.decode(output[0], skip_special_tokens=True))
```

## Source

- Base: `Qwen/Qwen2-VL-2B-Instruct`
- Adapter source: `mmrech/pitvqa-qwen2vl-unified-v2` (Stage 4)
- Training dataset: `mmrech/pitvqa-comprehensive-spatial`
"""


def _hub_login():
    """Log in to the HuggingFace Hub if HF_TOKEN is set; return an HfApi client.

    Login is optional: without a token, downloads of public repos still work,
    but the later upload step will fail (and is handled as best-effort there).
    """
    from huggingface_hub import HfApi, login

    hf_token = os.environ.get("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("✓ Logged in to HuggingFace")
    return HfApi()


def _load_and_merge():
    """Load the base model and the Stage 4 LoRA adapter, then merge them.

    Returns:
        (merged_model, processor): the base model with adapter weights folded
        in (a plain ``Qwen2VLForConditionalGeneration``) and its processor.
    """
    # Heavy third-party imports stay local so importing this module is cheap.
    from peft import PeftModel
    from transformers import AutoProcessor, Qwen2VLForConditionalGeneration

    print("\n🤖 Loading base model...")
    # Full-precision (bf16) load: merging requires real weights, not quantized.
    base = Qwen2VLForConditionalGeneration.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
    print("✓ Base model loaded")

    processor = AutoProcessor.from_pretrained(BASE_MODEL, trust_remote_code=True)
    print("✓ Processor loaded")

    # Stage 4 (Unified) adapter handles all tasks, so only one merge is needed.
    print("\n📦 Loading Stage 4 (Unified) adapter...")
    model = PeftModel.from_pretrained(
        base,
        UNIFIED_MODEL,
        adapter_name="stage4",
        subfolder="stage4",
    )
    print("✓ Stage 4 adapter loaded")

    # Fold LoRA deltas into the base weights and drop the PEFT wrapper.
    print("\n🔗 Merging adapter...")
    merged_model = model.merge_and_unload()
    print("✓ Adapter merged")
    return merged_model, processor


def _save_merged(merged_model, processor) -> Path:
    """Save merged weights, processor files, and the model card locally.

    Returns:
        Path of the local output directory (./pitvqa-merged).
    """
    print("\n💾 Saving merged model...")
    output_dir = Path("./pitvqa-merged")
    # parents=True so the script also works from a fresh/nested working dir.
    output_dir.mkdir(parents=True, exist_ok=True)

    merged_model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)
    print(f"✓ Saved to {output_dir}")

    # Explicit UTF-8: the card contains emoji/markdown and must not depend on
    # the platform's default encoding (e.g. cp1252 on Windows would crash).
    (output_dir / "README.md").write_text(MODEL_CARD, encoding="utf-8")
    print("✓ Created README.md")
    return output_dir


def _upload(api, output_dir: Path) -> None:
    """Best-effort upload of the merged model folder to OUTPUT_REPO.

    Failures are reported but not raised, so local verification still runs.
    """
    print(f"\n📤 Uploading to {OUTPUT_REPO}...")
    try:
        # Create repo if needed, then push every file in the output folder.
        api.create_repo(OUTPUT_REPO, exist_ok=True)
        api.upload_folder(
            folder_path=str(output_dir),
            repo_id=OUTPUT_REPO,
            repo_type="model",
        )
        print(f"✓ Uploaded to https://huggingface.co/{OUTPUT_REPO}")
    except Exception as e:  # deliberate best-effort: log and continue
        print(f"⚠ Upload error: {e}")


def _verify(merged_model, processor) -> None:
    """Smoke-test the merged model with a single greedy generation on noise."""
    import numpy as np
    from PIL import Image

    print("\n🧪 Verifying merged model...")

    # Random RGB image — we only check that the pipeline runs end to end.
    test_image = Image.fromarray(
        np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8)
    )

    messages = [{"role": "user", "content": [
        {"type": "image", "image": test_image},
        {"type": "text", "text": "What do you see in this image?"},
    ]}]

    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[test_image], return_tensors="pt"
    ).to(merged_model.device)

    with torch.no_grad():
        output = merged_model.generate(**inputs, max_new_tokens=50, do_sample=False)

    response = processor.decode(output[0], skip_special_tokens=True)
    print(f"Test response: {response[:200]}...")


def main() -> None:
    """Merge the Stage 4 adapter, save, upload, and verify the result."""
    api = _hub_login()
    merged_model, processor = _load_and_merge()
    output_dir = _save_merged(merged_model, processor)
    _upload(api, output_dir)
    _verify(merged_model, processor)

    print("\n✅ DONE! Merged model available at:")
    print(f"   https://huggingface.co/{OUTPUT_REPO}")


# Guarded entry point: importing this module no longer triggers a multi-GB
# model download/merge/upload as the original top-level script did.
if __name__ == "__main__":
    main()