#!/usr/bin/env python3
"""
Oculus Car Part Detection Demo
Demonstrates detection on car images using the extended training model.
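
Example invocations (flags come from the argparse setup in main(); the image
file name is illustrative):
    python demo_oculus.py --mode box --threshold 0.2
    python demo_oculus.py --image my_car.jpg --mode vqa --prompt "What color is the car?"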
"""
import sys
import argparse
from io import BytesIO
from pathlib import Path

import requests
import torch
from PIL import Image, ImageDraw, ImageFont

# Make the script's directory importable so the local module resolves.
sys.path.insert(0, str(Path(__file__).parent))

from oculus_unified_model import OculusForConditionalGeneration
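
# NOTE: this demo assumes model.generate() returns an object exposing
# .boxes / .labels / .confidences in "box" mode and .text in "text" mode,
# as exercised below; see oculus_unified_model for the actual interface.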


def visualize_results(image, output, filename="output_car_parts.png"):
    """Draw bounding boxes and labels on the image, then save it to disk."""
    draw = ImageDraw.Draw(image)

    # Try to load a font (macOS path); fall back to PIL's built-in default.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
    except OSError:
        font = ImageFont.load_default()

    width, height = image.size

    # COCO class names, indexed by label id.
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]

    for box, label, conf in zip(output.boxes, output.labels, output.confidences):
        # Box is [x1, y1, x2, y2] in normalized coordinates.
        x1, y1, x2, y2 = box

        # Clamp normalized coords to [0, 1].
        x1 = max(0.0, min(1.0, x1))
        y1 = max(0.0, min(1.0, y1))
        x2 = max(0.0, min(1.0, x2))
        y2 = max(0.0, min(1.0, y2))

        # Skip degenerate boxes.
        if x2 <= x1 or y2 <= y1:
            continue

        # Scale to pixel coordinates.
        x1 *= width
        y1 *= height
        x2 *= width
        y2 *= height

        # Color by confidence: low-confidence boxes in red, the rest in green.
        color = "red" if conf < 0.5 else "green"
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Resolve the class name; fall back to the raw label if it is out of range.
        try:
            class_name = COCO_CLASSES[int(label)]
        except (ValueError, IndexError):
            class_name = str(label)
        label_text = f"{class_name} ({conf:.2f})"

        # Draw a filled background behind the label text for readability.
        text_bbox = draw.textbbox((x1, y1), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, y1), label_text, fill="white", font=font)

    image.save(filename)
    print(f"Saved visualization to {filename}")


def main():
    parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo")
    parser.add_argument("--image", type=str, help="Path to an image file to test")
    parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model")
    parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode")
    parser.add_argument("--threshold", type=float, default=0.2, help="Detection confidence threshold")
    parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename")
    args = parser.parse_args()

    # Find the latest checkpoint under the extended (V2) detection run.
    checkpoint_dir = Path("checkpoints/oculus_detection_v2")
    model_path = None
    if checkpoint_dir.exists():
        # Collect all step_N folders along with their step numbers.
        steps = []
        for d in checkpoint_dir.iterdir():
            if d.is_dir() and d.name.startswith("step_"):
                try:
                    step = int(d.name.split("_")[1])
                    steps.append((step, d))
                except ValueError:
                    pass
        # Sort by step number and pick the latest.
        if steps:
            steps.sort(key=lambda x: x[0], reverse=True)
            model_path = str(steps[0][1])
            print(f"✨ Found latest checkpoint: {model_path}")
    if model_path is None:
        model_path = str(checkpoint_dir / "final")

    # Fall back to the initial detection checkpoint if the extended one isn't ready.
    if not Path(model_path).exists():
        model_path = "checkpoints/oculus_detection/final"
        print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}")
print(f"Loading model from {model_path}...")
try:
model = OculusForConditionalGeneration.from_pretrained(model_path)
# Load heads
heads_path = Path(model_path) / "heads.pth"
if heads_path.exists():
heads = torch.load(heads_path, map_location="cpu")
model.detection_head.load_state_dict(heads['detection'])
print("✓ Loaded detection heads")
except Exception as e:
print(f"Error loading model: {e}")
return

    # Resolve the input image.
    if args.image:
        image_path = args.image
        print(f"\nProcessing custom image: {image_path}...")
    else:
        # Default to a generic COCO sample; a crowded scene suits generic detection.
        image_path = "data/coco/images/000000071345.jpg"
        print(f"\nProcessing default image: {image_path}...")

    try:
        if Path(image_path).exists():
            image = Image.open(image_path).convert('RGB')
        else:
            # Fall back to downloading a sample image from the web.
            url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg"
            print(f"Image not found, downloading sample {url}...")
            response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')

        # Dispatch on the requested inference mode.
        if args.mode == "box":
            print(f"Running detection with prompt: '{args.prompt}'...")
            output = model.generate(
                image,
                mode="box",
                prompt=args.prompt,
                threshold=args.threshold
            )
            print(f"Found {len(output.boxes)} objects")
            visualize_results(image, output, args.output)
        elif args.mode == "caption":
            print("Generating caption...")
            output = model.generate(image, mode="text", prompt="A photo of")
            print(f"\n📝 Caption: {output.text}\n")
        elif args.mode == "vqa":
            question = args.prompt if args.prompt != "Detect objects" else "What is in this image?"
            print(f"Thinking about question: '{question}'...")
            output = model.generate(image, mode="text", prompt=question)
            print(f"\n🤔 Answer: {output.text}\n")
    except Exception as e:
        print(f"Error processing image: {e}")


if __name__ == "__main__":
    main()