File size: 7,413 Bytes
8d515d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
#!/usr/bin/env python3
"""
Oculus Car Part Detection Demo
Demonstrates detection on car images using the extended training model.
"""
import sys
import requests
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import torch
import numpy as np
# Add parent to path
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration
def visualize_results(image, output, filename="output_car_parts.png"):
    """Draw bounding boxes and confidence-labeled class names on an image.

    Args:
        image: PIL.Image to annotate (modified in place, then saved).
        output: model output with parallel sequences ``boxes`` (each
            ``[x1, y1, x2, y2]`` normalized to [0, 1]), ``labels``
            (COCO class indices), and ``confidences`` (floats).
        filename: path the annotated image is written to.
    """
    draw = ImageDraw.Draw(image)
    # Prefer a real TrueType font (macOS system path); fall back to PIL's
    # built-in bitmap font when it is unavailable.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
    except OSError:
        font = ImageFont.load_default()
    width, height = image.size
    # COCO 80-class names, indexed by the model's integer label output.
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]
    for box, label, conf in zip(output.boxes, output.labels, output.confidences):
        # Box is [x1, y1, x2, y2] normalized; clamp into the unit square.
        x1, y1, x2, y2 = (max(0.0, min(1.0, coord)) for coord in box)
        # Skip degenerate boxes (zero/negative area after clamping).
        if x2 <= x1 or y2 <= y1:
            continue
        # Scale normalized coordinates to pixel space.
        x1 *= width
        y1 *= height
        x2 *= width
        y2 *= height
        # Low-confidence detections are drawn in red, confident ones in green.
        color = "red" if conf < 0.5 else "green"
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        # Map the label index to a class name; fall back to the raw label when
        # it is out of range or not an integer.
        try:
            class_name = COCO_CLASSES[int(label)]
        except (ValueError, TypeError, IndexError):
            class_name = str(label)
        label_text = f"{class_name} ({conf:.2f})"
        # Draw a filled background behind the label text for legibility.
        text_bbox = draw.textbbox((x1, y1), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, y1), label_text, fill="white", font=font)
    image.save(filename)
    # Bug fix: report the actual output path instead of the literal "(unknown)".
    print(f"Saved visualization to {filename}")
def _find_checkpoint():
    """Return the path of the newest V2 step checkpoint, falling back to V2
    'final' and then to the V1 checkpoint if neither exists."""
    checkpoint_dir = Path("checkpoints/oculus_detection_v2")
    model_path = None
    if checkpoint_dir.exists():
        # Collect (step_number, dir) pairs from folders named "step_<n>".
        steps = []
        for d in checkpoint_dir.iterdir():
            if d.is_dir() and d.name.startswith("step_"):
                try:
                    steps.append((int(d.name.split("_")[1]), d))
                except ValueError:
                    pass  # e.g. "step_final" — not a numbered step, skip it
        if steps:
            # Highest step number wins.
            steps.sort(key=lambda item: item[0], reverse=True)
            model_path = str(steps[0][1])
            print(f"✨ Found latest checkpoint: {model_path}")
    if model_path is None:
        model_path = str(checkpoint_dir / "final")
    # Fallback to initial detection checkpoint if extended one isn't ready
    if not Path(model_path).exists():
        model_path = "checkpoints/oculus_detection/final"
        print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}")
    return model_path


def _open_image(image_path):
    """Open *image_path* as RGB, downloading a sample image when it is missing."""
    if Path(image_path).exists():
        return Image.open(image_path).convert('RGB')
    # Fallback to an online image when the local file does not exist.
    url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg"
    print(f"Image not found, downloading sample {url}...")
    # Bug fix: a timeout prevents the demo from hanging forever on a dead
    # connection; raise_for_status surfaces HTTP errors with a clear message.
    response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert('RGB')


def _run_inference(model, image, args):
    """Dispatch to detection, captioning, or VQA based on ``args.mode``."""
    if args.mode == "box":
        print(f"Running detection with prompt: '{args.prompt}'...")
        output = model.generate(
            image,
            mode="box",
            prompt=args.prompt,
            threshold=args.threshold
        )
        print(f"Found {len(output.boxes)} objects")
        visualize_results(image, output, args.output)
    elif args.mode == "caption":
        print("Generating caption...")
        output = model.generate(image, mode="text", prompt="A photo of")
        print(f"\n📝 Caption: {output.text}\n")
    elif args.mode == "vqa":
        # If the user kept the detection default prompt, ask a generic question.
        question = args.prompt if args.prompt != "Detect objects" else "What is in this image?"
        print(f"Thinking about question: '{question}'...")
        output = model.generate(image, mode="text", prompt=question)
        print(f"\n🤔 Answer: {output.text}\n")


def main():
    """CLI entry point: parse arguments, load the model, and run one inference."""
    import argparse
    parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo")
    parser.add_argument("--image", type=str, help="Path to image file to test")
    parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model")
    parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode")
    parser.add_argument("--threshold", type=float, default=0.2, help="Detection threshold")
    parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename")
    args = parser.parse_args()

    model_path = _find_checkpoint()
    print(f"Loading model from {model_path}...")
    # Broad catch is deliberate here: this is the top-level CLI boundary and a
    # missing/corrupt checkpoint should produce a message, not a traceback.
    try:
        model = OculusForConditionalGeneration.from_pretrained(model_path)
        # Load the auxiliary detection head weights saved alongside the model.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads['detection'])
            print("✓ Loaded detection heads")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Choose the image: user-supplied path or the bundled COCO sample.
    if args.image:
        image_path = args.image
        print(f"\nProcessing Custom Image: {image_path}...")
    else:
        # Use a generic COCO sample (dining table/people) instead of car if possible
        # defaulting to the car one is fine, but let's see if we have others
        image_path = "data/coco/images/000000071345.jpg"
        print(f"\nProcessing Default Image: {image_path}...")

    # Second CLI boundary: image loading / download / inference failures are
    # reported rather than crashing the demo.
    try:
        image = _open_image(image_path)
        _run_inference(model, image, args)
    except Exception as e:
        print(f"Error processing image: {e}")


if __name__ == "__main__":
    main()
|