mmrech commited on
Commit
dc287cc
·
verified ·
1 Parent(s): 529cdab

Upload test_all_frames_for_curation.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. test_all_frames_for_curation.py +340 -0
test_all_frames_for_curation.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # requires-python = ">=3.10"
4
+ # dependencies = [
5
+ # "torch",
6
+ # "torchvision",
7
+ # "transformers>=4.40.0",
8
+ # "peft>=0.10.0",
9
+ # "datasets>=2.18.0",
10
+ # "accelerate",
11
+ # "bitsandbytes",
12
+ # "qwen-vl-utils",
13
+ # "pillow",
14
+ # "opencv-python-headless",
15
+ # "huggingface_hub>=0.21.0",
16
+ # "av",
17
+ # ]
18
+ # ///
19
+ """
20
+ Test ALL frames for manual curation.
21
+
22
+ Saves all results with images for human review.
23
+ Does NOT auto-select - human curator will pick best examples.
24
+
25
+ Run with: hf jobs uv run --flavor a10g-large --secrets HF_TOKEN test_all_frames_for_curation.py
26
+ """
27
+
28
+ import os
29
+ import cv2
30
+ import re
31
+ import json
32
+ import torch
33
+ import base64
34
+ from io import BytesIO
35
+ from PIL import Image, ImageDraw, ImageFont
36
+ from pathlib import Path
37
+ from typing import Optional, List, Tuple
38
+
39
+ # ============================================================
40
+ # Config
41
+ # ============================================================
42
+
43
+ UNIFIED_MODEL = "mmrech/pitvqa-qwen2vl-unified-v2"
44
+ VIDEO_DATASET = "UCL-WEISS/PitVis-2023"
45
+
46
+ VIDEO_CACHE = Path("/tmp/videos")
47
+ VIDEO_CACHE.mkdir(exist_ok=True)
48
+
49
+ OUTPUT_DIR = Path("./curation_review")
50
+ OUTPUT_DIR.mkdir(exist_ok=True)
51
+
52
+ # Test configurations - EXTENSIVE
53
+ # Sample frames from each video at regular intervals
54
+ VIDEOS_TO_TEST = ["video_01", "video_02", "video_03", "video_05", "video_06", "video_10", "video_15", "video_20"]
55
+ FRAMES_PER_VIDEO = [200, 500, 800, 1200, 1800] # Sample at these frame indices
56
+
57
+ # Targets to test
58
+ POINT_TARGETS = ["suction device", "surgical instruments"] # Focus on main targets
59
+ BBOX_TARGETS = ["suction device", "surgical instruments"]
60
+
61
+ # ============================================================
62
+ # Setup
63
+ # ============================================================
64
+
65
+ from huggingface_hub import login, HfApi, hf_hub_download
66
+
67
+ hf_token = os.environ.get("HF_TOKEN")
68
+ if hf_token:
69
+ login(token=hf_token)
70
+ print("✓ Logged in to HuggingFace")
71
+
72
+ api = HfApi()
73
+
74
+ # ============================================================
75
+ # Load Model
76
+ # ============================================================
77
+
78
+ print("\n🤖 Loading model...")
79
+
80
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
81
+ from peft import PeftModel
82
+
83
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True)
84
+
85
+ bnb_config = BitsAndBytesConfig(
86
+ load_in_4bit=True,
87
+ bnb_4bit_quant_type="nf4",
88
+ bnb_4bit_compute_dtype=torch.bfloat16,
89
+ bnb_4bit_use_double_quant=True,
90
+ )
91
+
92
+ base = Qwen2VLForConditionalGeneration.from_pretrained(
93
+ "Qwen/Qwen2-VL-2B-Instruct",
94
+ quantization_config=bnb_config,
95
+ device_map="auto",
96
+ trust_remote_code=True
97
+ )
98
+
99
+ model = PeftModel.from_pretrained(base, UNIFIED_MODEL, adapter_name="stage1", subfolder="stage1")
100
+ model.load_adapter(UNIFIED_MODEL, adapter_name="stage2", subfolder="stage2")
101
+
102
+ print(f"✓ Model loaded")
103
+
104
+ # ============================================================
105
+ # Helpers
106
+ # ============================================================
107
+
108
+ def download_video(video_id: str) -> Optional[Path]:
109
+ video_path = VIDEO_CACHE / f"{video_id}.mp4"
110
+ if not video_path.exists():
111
+ try:
112
+ downloaded = hf_hub_download(
113
+ repo_id=VIDEO_DATASET,
114
+ filename=f"videos/{video_id}.mp4",
115
+ repo_type="dataset"
116
+ )
117
+ import shutil
118
+ shutil.copy(downloaded, video_path)
119
+ except Exception as e:
120
+ print(f" ⚠ Could not download {video_id}: {e}")
121
+ return None
122
+ return video_path
123
+
124
+ def extract_frame(video_id: str, frame_idx: int) -> Optional[Image.Image]:
125
+ video_path = download_video(video_id)
126
+ if video_path is None:
127
+ return None
128
+ cap = cv2.VideoCapture(str(video_path))
129
+ cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
130
+ ret, frame = cap.read()
131
+ cap.release()
132
+ if ret:
133
+ return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
134
+ return None
135
+
136
+ def run_inference(image, prompt, adapter="stage1"):
137
+ model.set_adapter(adapter)
138
+ content = [{"type": "image", "image": image}, {"type": "text", "text": prompt}]
139
+ messages = [{"role": "user", "content": content}]
140
+ text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
141
+ inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)
142
+ with torch.no_grad():
143
+ output = model.generate(**inputs, max_new_tokens=256, do_sample=False)
144
+ response = processor.decode(output[0], skip_special_tokens=True)
145
+ if "assistant" in response.lower():
146
+ response = response.split("assistant")[-1].strip()
147
+ return response
148
+
149
+ def extract_point(text) -> Tuple[Optional[float], Optional[float]]:
150
+ match = re.search(r"<point x='([\d.]+)' y='([\d.]+)'>", text)
151
+ if match:
152
+ return float(match.group(1)), float(match.group(2))
153
+ return None, None
154
+
155
+ def extract_bbox(text) -> Optional[List[float]]:
156
+ match = re.search(r"<box x1='([\d.]+)' y1='([\d.]+)' x2='([\d.]+)' y2='([\d.]+)'>", text)
157
+ if match:
158
+ return [float(match.group(i)) for i in range(1, 5)]
159
+ return None
160
+
161
+ def draw_point_on_image(image: Image.Image, x: float, y: float, label: str) -> Image.Image:
162
+ """Draw point marker on image for visualization."""
163
+ img = image.copy()
164
+ draw = ImageDraw.Draw(img)
165
+ w, h = img.size
166
+ px, py = int(x * w / 100), int(y * h / 100)
167
+
168
+ # Draw crosshair
169
+ draw.ellipse([px-8, py-8, px+8, py+8], fill="red", outline="white", width=2)
170
+ draw.line([px-20, py, px+20, py], fill="white", width=2)
171
+ draw.line([px, py-20, px, py+20], fill="white", width=2)
172
+
173
+ # Draw label
174
+ draw.text((10, 10), f"{label}: ({x:.1f}, {y:.1f})", fill="white")
175
+
176
+ return img
177
+
178
+ def draw_bbox_on_image(image: Image.Image, bbox: List[float], label: str) -> Image.Image:
179
+ """Draw bounding box on image for visualization."""
180
+ img = image.copy()
181
+ draw = ImageDraw.Draw(img)
182
+ w, h = img.size
183
+ x1, y1, x2, y2 = [int(c * w / 100) if i % 2 == 0 else int(c * h / 100) for i, c in enumerate(bbox)]
184
+
185
+ draw.rectangle([x1, y1, x2, y2], outline="lime", width=3)
186
+ draw.text((10, 10), f"{label}: [{bbox[0]:.0f},{bbox[1]:.0f}]-[{bbox[2]:.0f},{bbox[3]:.0f}]", fill="white")
187
+
188
+ return img
189
+
190
+ # ============================================================
191
+ # Test All Frames
192
+ # ============================================================
193
+
194
+ print("\n" + "=" * 60)
195
+ print("🧪 TESTING ALL FRAMES FOR CURATION")
196
+ print("=" * 60)
197
+
198
+ all_results = []
199
+
200
+ for video_id in VIDEOS_TO_TEST:
201
+ print(f"\n📹 Processing {video_id}...")
202
+
203
+ for frame_idx in FRAMES_PER_VIDEO:
204
+ frame = extract_frame(video_id, frame_idx)
205
+ if frame is None:
206
+ print(f" ⚠ Frame {frame_idx} failed")
207
+ continue
208
+
209
+ print(f" Frame {frame_idx}:")
210
+
211
+ # Test pointing
212
+ for target in POINT_TARGETS:
213
+ prompt = f"Point to the {target} in this surgical image."
214
+ response = run_inference(frame, prompt, adapter="stage1")
215
+ x, y = extract_point(response)
216
+ success = x is not None and 0 <= x <= 100 and 0 <= y <= 100
217
+
218
+ result = {
219
+ "id": f"{video_id}_{frame_idx}_point_{target.replace(' ', '_')}",
220
+ "video_id": video_id,
221
+ "frame_idx": frame_idx,
222
+ "task": "point",
223
+ "target": target,
224
+ "response": response,
225
+ "x": x,
226
+ "y": y,
227
+ "success": success,
228
+ }
229
+ all_results.append(result)
230
+
231
+ # Save visualization
232
+ if success:
233
+ viz = draw_point_on_image(frame, x, y, target)
234
+ viz_path = OUTPUT_DIR / f"{video_id}_{frame_idx}_point_{target.replace(' ', '_')}.jpg"
235
+ viz.save(viz_path, quality=90)
236
+
237
+ status = "✅" if success else "❌"
238
+ coords = f"({x:.1f}, {y:.1f})" if success else "FAILED"
239
+ print(f" {status} Point {target}: {coords}")
240
+
241
+ # Test bbox
242
+ for target in BBOX_TARGETS:
243
+ prompt = f"Draw a bounding box around the {target}."
244
+ response = run_inference(frame, prompt, adapter="stage2")
245
+ bbox = extract_bbox(response)
246
+ success = bbox is not None and all(0 <= c <= 100 for c in bbox)
247
+
248
+ result = {
249
+ "id": f"{video_id}_{frame_idx}_bbox_{target.replace(' ', '_')}",
250
+ "video_id": video_id,
251
+ "frame_idx": frame_idx,
252
+ "task": "bbox",
253
+ "target": target,
254
+ "response": response,
255
+ "bbox": bbox,
256
+ "success": success,
257
+ }
258
+ all_results.append(result)
259
+
260
+ # Save visualization
261
+ if success:
262
+ viz = draw_bbox_on_image(frame, bbox, target)
263
+ viz_path = OUTPUT_DIR / f"{video_id}_{frame_idx}_bbox_{target.replace(' ', '_')}.jpg"
264
+ viz.save(viz_path, quality=90)
265
+
266
+ status = "✅" if success else "❌"
267
+ coords = f"[{bbox[0]:.0f}-{bbox[2]:.0f}]x[{bbox[1]:.0f}-{bbox[3]:.0f}]" if success else "FAILED"
268
+ print(f" {status} BBox {target}: {coords}")
269
+
270
+ # Also save raw frame for reference
271
+ raw_path = OUTPUT_DIR / f"{video_id}_{frame_idx}_raw.jpg"
272
+ frame.save(raw_path, quality=90)
273
+
274
+ # ============================================================
275
+ # Save Results
276
+ # ============================================================
277
+
278
+ print("\n" + "=" * 60)
279
+ print("💾 SAVING FOR CURATION")
280
+ print("=" * 60)
281
+
282
+ # Save all results as JSON
283
+ with open(OUTPUT_DIR / "all_results.json", "w") as f:
284
+ json.dump(all_results, f, indent=2)
285
+
286
+ # Summary
287
+ successful = [r for r in all_results if r["success"]]
288
+ print(f"Total tests: {len(all_results)}")
289
+ print(f"Successful: {len(successful)} ({100*len(successful)/len(all_results):.1f}%)")
290
+
291
+ # Create curation index
292
+ index_html = """<!DOCTYPE html>
293
+ <html>
294
+ <head><title>PitVQA Curation Review</title>
295
+ <style>
296
+ body { font-family: sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
297
+ .result { display: inline-block; margin: 10px; text-align: center; }
298
+ .result img { max-width: 300px; border: 2px solid #ccc; }
299
+ .success { border-color: green !important; }
300
+ .fail { border-color: red !important; }
301
+ </style>
302
+ </head>
303
+ <body>
304
+ <h1>PitVQA Curation Review</h1>
305
+ <p>Review these results and note which ones are good examples.</p>
306
+ """
307
+
308
+ for r in successful:
309
+ img_name = f"{r['id']}.jpg"
310
+ index_html += f"""
311
+ <div class="result">
312
+ <img src="{img_name}" class="success">
313
+ <br><small>{r['video_id']} f{r['frame_idx']}<br>{r['task']}: {r['target']}</small>
314
+ </div>
315
+ """
316
+
317
+ index_html += "</body></html>"
318
+
319
+ with open(OUTPUT_DIR / "index.html", "w") as f:
320
+ f.write(index_html)
321
+
322
+ # Upload to HuggingFace as dataset for review
323
+ print("\n📤 Uploading for review...")
324
+
325
+ try:
326
+ # Create/upload to a review dataset
327
+ REVIEW_REPO = "mmrech/pitvqa-curation-review"
328
+ api.create_repo(REVIEW_REPO, repo_type="dataset", exist_ok=True)
329
+ api.upload_folder(
330
+ folder_path=str(OUTPUT_DIR),
331
+ repo_id=REVIEW_REPO,
332
+ repo_type="dataset"
333
+ )
334
+ print(f"✓ Uploaded to https://huggingface.co/datasets/{REVIEW_REPO}")
335
+ except Exception as e:
336
+ print(f"⚠ Upload error: {e}")
337
+
338
+ print("\n✅ DONE!")
339
+ print(f"Review the results at: https://huggingface.co/datasets/mmrech/pitvqa-curation-review")
340
+ print("Then tell me which examples to use for the showcase.")