|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Test ALL frames for manual curation. |
|
|
|
|
|
Saves all results with images for human review. |
|
|
Does NOT auto-select - human curator will pick best examples. |
|
|
|
|
|
Run with: hf jobs uv run --flavor a10g-large --secrets HF_TOKEN test_all_frames_for_curation.py |
|
|
""" |
|
|
|
|
|
import os |
|
|
import cv2 |
|
|
import re |
|
|
import json |
|
|
import torch |
|
|
import base64 |
|
|
from io import BytesIO |
|
|
from PIL import Image, ImageDraw, ImageFont |
|
|
from pathlib import Path |
|
|
from typing import Optional, List, Tuple |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hub repository holding the fine-tuned adapters, and the dataset that
# provides the source videos.
UNIFIED_MODEL = "mmrech/pitvqa-qwen2vl-unified-v2"
VIDEO_DATASET = "UCL-WEISS/PitVis-2023"

# Local cache for downloaded videos. parents=True keeps start-up from failing
# when an intermediate directory does not exist yet.
VIDEO_CACHE = Path("/tmp/videos")
VIDEO_CACHE.mkdir(parents=True, exist_ok=True)

# Everything the curator needs (JSON, images, HTML index) is written here.
OUTPUT_DIR = Path("./curation_review")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Videos to sample and the frame indices taken from each of them.
VIDEOS_TO_TEST = [
    "video_01",
    "video_02",
    "video_03",
    "video_05",
    "video_06",
    "video_10",
    "video_15",
    "video_20",
]
FRAMES_PER_VIDEO = [200, 500, 800, 1200, 1800]

# Objects queried in the pointing task and in the bounding-box task.
POINT_TARGETS = ["suction device", "surgical instruments"]
BBOX_TARGETS = ["suction device", "surgical instruments"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from huggingface_hub import login, HfApi, hf_hub_download |
|
|
|
|
|
# Authenticate against the Hugging Face Hub when a token is provided through
# the environment; otherwise continue unauthenticated but say so, because
# later download/upload failures are otherwise hard to diagnose.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
    print("✓ Logged in to HuggingFace")
else:
    print("⚠ HF_TOKEN not set - continuing without HuggingFace login")

# Hub client reused later for creating the review repo and uploading results.
api = HfApi()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n🤖 Loading model...")

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel

# The processor and the quantized base weights come from the same checkpoint.
base_model_id = "Qwen/Qwen2-VL-2B-Instruct"

processor = AutoProcessor.from_pretrained(base_model_id, trust_remote_code=True)

# 4-bit NF4 quantization with double quantization and bfloat16 compute.
quant_settings = dict(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
bnb_config = BitsAndBytesConfig(**quant_settings)

base = Qwen2VLForConditionalGeneration.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# Attach both adapters to the one base model; run_inference() switches
# between them by name ("stage1" / "stage2").
model = PeftModel.from_pretrained(
    base,
    UNIFIED_MODEL,
    adapter_name="stage1",
    subfolder="stage1",
)
model.load_adapter(UNIFIED_MODEL, adapter_name="stage2", subfolder="stage2")

print("✓ Model loaded")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def download_video(video_id: str) -> Optional[Path]:
    """Return the local path of ``video_id``, downloading it on first use.

    The video is fetched from ``VIDEO_DATASET`` on the Hugging Face Hub and
    copied into ``VIDEO_CACHE`` as ``<video_id>.mp4``. A file already present
    in the cache is returned without contacting the Hub.

    Args:
        video_id: Base name of the video, e.g. ``"video_01"``.

    Returns:
        Path to the cached ``.mp4`` file, or ``None`` when the download or
        the copy failed (a warning is printed in that case).
    """
    import shutil

    video_path = VIDEO_CACHE / f"{video_id}.mp4"
    if video_path.exists():
        return video_path

    # Copy under a temporary name and rename afterwards, so an interrupted
    # copy can never leave a truncated file that later calls would treat as
    # a valid cached video.
    partial_path = video_path.with_name(video_path.name + ".part")
    try:
        downloaded = hf_hub_download(
            repo_id=VIDEO_DATASET,
            filename=f"videos/{video_id}.mp4",
            repo_type="dataset",
        )
        shutil.copy(downloaded, partial_path)
        partial_path.replace(video_path)
    except Exception as e:
        # Deliberately best-effort: one missing video must not abort the
        # whole run, so report it and let the caller skip this video.
        if partial_path.exists():
            partial_path.unlink()
        print(f" ⚠ Could not download {video_id}: {e}")
        return None
    return video_path
|
|
|
|
|
def extract_frame(video_id: str, frame_idx: int) -> Optional[Image.Image]:
    """Read a single frame of a video as an RGB PIL image.

    Args:
        video_id: Base name of the video, passed to ``download_video``.
        frame_idx: Index of the frame to seek to before reading.

    Returns:
        The frame converted from OpenCV's BGR order to an RGB
        ``PIL.Image.Image``, or ``None`` if the video is unavailable, cannot
        be opened, or the frame cannot be read.
    """
    video_path = download_video(video_id)
    if video_path is None:
        return None

    cap = cv2.VideoCapture(str(video_path))
    try:
        # A missing or corrupt file yields an unopened capture; bail out early.
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
    finally:
        # Always free the capture handle, even if seeking/reading raises.
        cap.release()

    if not ret or frame is None:
        return None
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
|
|
|
|
def run_inference(image, prompt, adapter="stage1"):
    """Ask the model a question about one image and return its answer text.

    Args:
        image: Image given to the processor (the caller passes a PIL image).
        prompt: Text instruction shown to the model together with the image.
        adapter: Name of the adapter to activate before generating
            (``"stage1"`` or ``"stage2"`` as loaded at start-up).

    Returns:
        The generated answer only, with special tokens removed and
        surrounding whitespace stripped.
    """
    model.set_adapter(adapter)

    content = [{"type": "image", "image": image}, {"type": "text", "text": prompt}]
    messages = [{"role": "user", "content": content}]
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = processor(text=[text], images=[image], padding=True, return_tensors="pt").to(model.device)

    # Greedy decoding keeps the output deterministic for a given frame.
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=256, do_sample=False)

    # generate() returns prompt + continuation. Drop the prompt tokens rather
    # than splitting the decoded text on the word "assistant", which would
    # truncate any answer that itself contains that word.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = output[0][prompt_length:]
    return processor.decode(generated_tokens, skip_special_tokens=True).strip()
|
|
|
|
|
# A well-formed non-negative decimal ("12", "12.", "12.5", ".5"). Unlike the
# loose class [\d.]+ it never matches strings such as "1.2.3" or ".", which
# would make float() raise.
_POINT_NUMBER = r"(\d+(?:\.\d*)?|\.\d+)"
_POINT_QUOTE = r"""['"]"""
_POINT_RE = re.compile(
    rf"<point\s+x={_POINT_QUOTE}{_POINT_NUMBER}{_POINT_QUOTE}"
    rf"\s+y={_POINT_QUOTE}{_POINT_NUMBER}{_POINT_QUOTE}\s*>"
)


def extract_point(text: str) -> Tuple[Optional[float], Optional[float]]:
    """Parse the first ``<point x='..' y='..'>`` tag found in ``text``.

    Attribute values may be wrapped in single or double quotes.

    Args:
        text: Raw model response.

    Returns:
        ``(x, y)`` as floats, or ``(None, None)`` when no well-formed point
        tag is present.
    """
    match = _POINT_RE.search(text)
    if match is None:
        return None, None
    return float(match.group(1)), float(match.group(2))
|
|
|
|
|
# A well-formed non-negative decimal ("12", "12.", "12.5", ".5"). Unlike the
# loose class [\d.]+ it never matches strings such as "1.2.3" or ".", which
# would make float() raise.
_BBOX_NUMBER = r"(\d+(?:\.\d*)?|\.\d+)"
_BBOX_QUOTE = r"""['"]"""
_BBOX_RE = re.compile(
    r"<box"
    + "".join(
        rf"\s+{name}={_BBOX_QUOTE}{_BBOX_NUMBER}{_BBOX_QUOTE}"
        for name in ("x1", "y1", "x2", "y2")
    )
    + r"\s*>"
)


def extract_bbox(text: str) -> Optional[List[float]]:
    """Parse the first ``<box x1='..' y1='..' x2='..' y2='..'>`` tag in ``text``.

    Attribute values may be wrapped in single or double quotes.

    Args:
        text: Raw model response.

    Returns:
        ``[x1, y1, x2, y2]`` as floats, or ``None`` when no well-formed box
        tag is present.
    """
    match = _BBOX_RE.search(text)
    if match is None:
        return None
    return [float(value) for value in match.groups()]
|
|
|
|
|
def draw_point_on_image(image: Image.Image, x: float, y: float, label: str) -> Image.Image:
    """Return a copy of ``image`` with a crosshair marker at ``(x, y)``.

    ``x`` and ``y`` are percentages of the image width and height; they are
    scaled to pixels before drawing. The input image is left untouched.
    """
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)

    width, height = canvas.size
    cx = int(x * width / 100)
    cy = int(y * height / 100)

    # Red dot with a white crosshair through it.
    dot_radius = 8
    arm_length = 20
    pen.ellipse(
        [cx - dot_radius, cy - dot_radius, cx + dot_radius, cy + dot_radius],
        fill="red",
        outline="white",
        width=2,
    )
    pen.line([cx - arm_length, cy, cx + arm_length, cy], fill="white", width=2)
    pen.line([cx, cy - arm_length, cx, cy + arm_length], fill="white", width=2)

    # Caption in the top-left corner with the percentage coordinates.
    caption = f"{label}: ({x:.1f}, {y:.1f})"
    pen.text((10, 10), caption, fill="white")

    return canvas
|
|
|
|
|
def draw_bbox_on_image(image: Image.Image, bbox: List[float], label: str) -> Image.Image:
    """Return a copy of ``image`` with ``bbox`` outlined for visualization.

    Args:
        image: Source image; it is copied, not modified.
        bbox: ``[x1, y1, x2, y2]`` given as percentages of the image width
            (x values) and height (y values).
        label: Text shown in the caption next to the box coordinates.

    Returns:
        The annotated copy of the image.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)
    w, h = img.size

    bx1, by1, bx2, by2 = bbox
    # The model may emit the corners in reverse order; recent Pillow versions
    # reject a rectangle whose second corner lies before the first, so
    # normalise to left <= right and top <= bottom.
    left, right = sorted((int(bx1 * w / 100), int(bx2 * w / 100)))
    top, bottom = sorted((int(by1 * h / 100), int(by2 * h / 100)))

    draw.rectangle([left, top, right, bottom], outline="lime", width=3)
    # The caption reports the coordinates exactly as the model produced them.
    draw.text(
        (10, 10),
        f"{label}: [{bbox[0]:.0f},{bbox[1]:.0f}]-[{bbox[2]:.0f},{bbox[3]:.0f}]",
        fill="white",
    )

    return img
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
banner = "=" * 60
print("\n" + banner)
print("🧪 TESTING ALL FRAMES FOR CURATION")
print(banner)

# One dict per (video, frame, task, target) combination, successful or not.
all_results = []

for video_id in VIDEOS_TO_TEST:
    print(f"\n📹 Processing {video_id}...")

    for frame_idx in FRAMES_PER_VIDEO:
        frame = extract_frame(video_id, frame_idx)
        if frame is None:
            print(f" ⚠ Frame {frame_idx} failed")
            continue

        print(f" Frame {frame_idx}:")

        # --- Pointing task (stage1 adapter) ---
        for target in POINT_TARGETS:
            record_id = f"{video_id}_{frame_idx}_point_{target.replace(' ', '_')}"
            response = run_inference(
                frame,
                f"Point to the {target} in this surgical image.",
                adapter="stage1",
            )
            x, y = extract_point(response)
            # A usable point must parse and lie inside the 0..100 range.
            success = x is not None and 0 <= x <= 100 and 0 <= y <= 100

            all_results.append(
                {
                    "id": record_id,
                    "video_id": video_id,
                    "frame_idx": frame_idx,
                    "task": "point",
                    "target": target,
                    "response": response,
                    "x": x,
                    "y": y,
                    "success": success,
                }
            )

            if success:
                annotated = draw_point_on_image(frame, x, y, target)
                annotated.save(OUTPUT_DIR / f"{record_id}.jpg", quality=90)
                print(f" ✅ Point {target}: ({x:.1f}, {y:.1f})")
            else:
                print(f" ❌ Point {target}: FAILED")

        # --- Bounding-box task (stage2 adapter) ---
        for target in BBOX_TARGETS:
            record_id = f"{video_id}_{frame_idx}_bbox_{target.replace(' ', '_')}"
            response = run_inference(
                frame,
                f"Draw a bounding box around the {target}.",
                adapter="stage2",
            )
            bbox = extract_bbox(response)
            # A usable box must parse and have all four values in 0..100.
            success = bbox is not None and all(0 <= c <= 100 for c in bbox)

            all_results.append(
                {
                    "id": record_id,
                    "video_id": video_id,
                    "frame_idx": frame_idx,
                    "task": "bbox",
                    "target": target,
                    "response": response,
                    "bbox": bbox,
                    "success": success,
                }
            )

            if success:
                annotated = draw_bbox_on_image(frame, bbox, target)
                annotated.save(OUTPUT_DIR / f"{record_id}.jpg", quality=90)
                bx1, by1, bx2, by2 = bbox
                print(f" ✅ BBox {target}: [{bx1:.0f}-{bx2:.0f}]x[{by1:.0f}-{by2:.0f}]")
            else:
                print(f" ❌ BBox {target}: FAILED")

        # Keep the untouched frame next to the annotated versions.
        frame.save(OUTPUT_DIR / f"{video_id}_{frame_idx}_raw.jpg", quality=90)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
print("\n" + "=" * 60)
print("💾 SAVING FOR CURATION")
print("=" * 60)

# Persist every result (including failures) so the curator sees the raw
# model responses as well.
with open(OUTPUT_DIR / "all_results.json", "w", encoding="utf-8") as f:
    json.dump(all_results, f, indent=2)

successful = [r for r in all_results if r["success"]]
print(f"Total tests: {len(all_results)}")

# No test ran at all when every frame failed to load; report 0% instead of
# crashing with ZeroDivisionError before the results are exported.
success_rate = 100 * len(successful) / len(all_results) if all_results else 0.0
print(f"Successful: {len(successful)} ({success_rate:.1f}%)")
|
|
|
|
|
|
|
|
# Static head of the review page: title, minimal styling and intro text.
page_head = """<!DOCTYPE html>
<html>
<head><title>PitVQA Curation Review</title>
<style>
body { font-family: sans-serif; max-width: 1200px; margin: 0 auto; padding: 20px; }
.result { display: inline-block; margin: 10px; text-align: center; }
.result img { max-width: 300px; border: 2px solid #ccc; }
.success { border-color: green !important; }
.fail { border-color: red !important; }
</style>
</head>
<body>
<h1>PitVQA Curation Review</h1>
<p>Review these results and note which ones are good examples.</p>
"""

# One card per successful result; the image file name is the result id,
# matching the annotated images written during the test loop.
result_cards = [
    f"""
<div class="result">
<img src="{entry['id']}.jpg" class="success">
<br><small>{entry['video_id']} f{entry['frame_idx']}<br>{entry['task']}: {entry['target']}</small>
</div>
"""
    for entry in successful
]

index_html = page_head + "".join(result_cards) + "</body></html>"

(OUTPUT_DIR / "index.html").write_text(index_html)
|
|
|
|
|
|
|
|
# Defined before the try block so the closing messages can always refer to
# it, and so the URL is built in exactly one place.
REVIEW_REPO = "mmrech/pitvqa-curation-review"
review_url = f"https://huggingface.co/datasets/{REVIEW_REPO}"

print("\n📤 Uploading for review...")

uploaded = False
try:
    api.create_repo(REVIEW_REPO, repo_type="dataset", exist_ok=True)
    api.upload_folder(
        folder_path=str(OUTPUT_DIR),
        repo_id=REVIEW_REPO,
        repo_type="dataset",
    )
except Exception as e:
    # Best-effort: the results are already on disk, so report and carry on.
    print(f"⚠ Upload error: {e}")
else:
    uploaded = True
    print(f"✓ Uploaded to {review_url}")

print("\n✅ DONE!")
if uploaded:
    print(f"Review the results at: {review_url}")
else:
    # Do not send the curator to a repo that was never updated.
    print(f"Upload failed - review the results locally in: {OUTPUT_DIR}")
print("Then tell me which examples to use for the showcase.")
|
|
|