File size: 7,413 Bytes
8d515d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#!/usr/bin/env python3
"""
Oculus Car Part Detection Demo

Demonstrates detection on car images using the extended training model.
"""

import sys
import requests
from io import BytesIO
from PIL import Image, ImageDraw, ImageFont
import torch
import numpy as np

# Add parent to path
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))

from oculus_unified_model import OculusForConditionalGeneration

def visualize_results(image, output, filename="output_car_parts.png"):
    """Draw bounding boxes and class labels on an image and save it to disk.

    Args:
        image: PIL Image (RGB); drawn on in place.
        output: model output with parallel sequences `.boxes` (normalized
            [x1, y1, x2, y2] per box), `.labels` (COCO class indices), and
            `.confidences` (floats).
            NOTE(review): assumed equal-length sequences — confirm upstream.
        filename: path the annotated image is saved to.
    """
    draw = ImageDraw.Draw(image)
    
    # Prefer a real TrueType font (macOS system path); fall back to
    # Pillow's built-in bitmap font if it isn't available.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
    except OSError:
        font = ImageFont.load_default()
    
    width, height = image.size
    
    # COCO 80-class names, indexed by class id.
    COCO_CLASSES = [
        'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
        'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
        'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
        'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
        'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
        'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
        'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
        'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
        'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
        'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
        'toothbrush'
    ]

    # Draw each detection.
    for box, label, conf in zip(output.boxes, output.labels, output.confidences):
        # Box is [x1, y1, x2, y2] in normalized [0, 1] coordinates.
        x1, y1, x2, y2 = box
        
        # Clamp to the valid normalized range.
        x1 = max(0.0, min(1.0, x1))
        y1 = max(0.0, min(1.0, y1))
        x2 = max(0.0, min(1.0, x2))
        y2 = max(0.0, min(1.0, y2))
        
        # Skip degenerate boxes (zero/negative area after clamping).
        if x2 <= x1 or y2 <= y1:
            continue
            
        # Scale to pixel coordinates.
        x1 *= width
        y1 *= height
        x2 *= width
        y2 *= height
        
        # Low-confidence detections in red, confident ones in green.
        color = "red" if conf < 0.5 else "green"
        
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)
        
        # Map class index to a name; fall back to the raw label value.
        try:
            class_name = COCO_CLASSES[int(label)]
        except (IndexError, ValueError, TypeError):
            class_name = str(label)
            
        label_text = f"{class_name} ({conf:.2f})"
        
        # Filled background behind the text for readability.
        text_bbox = draw.textbbox((x1, y1), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, y1), label_text, fill="white", font=font)
        
    image.save(filename)
    # BUG FIX: original printed the literal string "(unknown)" instead of
    # interpolating the actual output path.
    print(f"Saved visualization to {filename}")

def _find_checkpoint():
    """Locate the checkpoint directory to load, preferring the newest V2 step.

    Search order:
      1. Highest-numbered ``step_N`` folder under ``checkpoints/oculus_detection_v2``.
      2. ``checkpoints/oculus_detection_v2/final``.
      3. ``checkpoints/oculus_detection/final`` (V1 fallback) if the V2 path
         does not exist on disk.

    Returns:
        str: path to the chosen checkpoint directory (may not exist if even
        the V1 fallback is missing; the caller handles load failures).
    """
    checkpoint_dir = Path("checkpoints/oculus_detection_v2")
    model_path = None

    if checkpoint_dir.exists():
        # Collect (step_number, dir) for every step_* folder.
        steps = []
        for d in checkpoint_dir.iterdir():
            if d.is_dir() and d.name.startswith("step_"):
                try:
                    steps.append((int(d.name.split("_")[1]), d))
                except (ValueError, IndexError):
                    # Folder name doesn't follow step_<int>; skip it.
                    pass

        # Sort and pick latest
        if steps:
            steps.sort(key=lambda x: x[0], reverse=True)
            model_path = str(steps[0][1])
            print(f"✨ Found latest checkpoint: {model_path}")

    if model_path is None:
        model_path = str(checkpoint_dir / "final")

    # Fallback to initial detection checkpoint if extended one isn't ready
    if not Path(model_path).exists():
        model_path = "checkpoints/oculus_detection/final"
        print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}")
    return model_path

def main():
    """Parse CLI args, load the latest checkpoint, and run one inference pass.

    Modes: ``box`` (detection + visualization), ``caption`` (image
    captioning), ``vqa`` (visual question answering via the prompt).
    """
    import argparse
    parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo")
    parser.add_argument("--image", type=str, help="Path to image file to test")
    parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model")
    parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode")
    parser.add_argument("--threshold", type=float, default=0.2, help="Detection threshold")
    parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename")
    args = parser.parse_args()

    model_path = _find_checkpoint()

    print(f"Loading model from {model_path}...")
    try:
        model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Load the task-specific head weights saved alongside the backbone.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads['detection'])
            print("✓ Loaded detection heads")
    except Exception as e:
        # Abort the demo cleanly on a missing/corrupt checkpoint.
        print(f"Error loading model: {e}")
        return

    # Image logic
    if args.image:
        image_path = args.image
        print(f"\nProcessing Custom Image: {image_path}...")
    else:
        # Default to a bundled COCO sample.
        image_path = "data/coco/images/000000071345.jpg" 
        print(f"\nProcessing Default Image: {image_path}...")
    
    try:
        if Path(image_path).exists():
            image = Image.open(image_path).convert('RGB')
        else:
            # Fallback: download a sample image from the web.
            url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg"
            print(f"Image not found, downloading sample {url}...")
            # timeout added so a dead connection can't hang the demo forever.
            response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
            image = Image.open(BytesIO(response.content)).convert('RGB')

        # Mode selection
        if args.mode == "box":
            print(f"Running detection with prompt: '{args.prompt}'...")
            output = model.generate(
                image, 
                mode="box", 
                prompt=args.prompt,
                threshold=args.threshold
            )
            print(f"Found {len(output.boxes)} objects")
            visualize_results(image, output, args.output)
            
        elif args.mode == "caption":
            print("Generating caption...")
            output = model.generate(image, mode="text", prompt="A photo of")
            print(f"\n📝 Caption: {output.text}\n")
            
        elif args.mode == "vqa":
            # Default prompt doubles as "no question given" sentinel.
            question = args.prompt if args.prompt != "Detect objects" else "What is in this image?"
            print(f"Thinking about question: '{question}'...")
            output = model.generate(image, mode="text", prompt=question)
            print(f"\n🤔 Answer: {output.text}\n")
            
    except Exception as e:
        print(f"Error processing image: {e}")

# Script entry point: run the demo only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()