kobiakor15 commited on
Commit
8d515d0
·
verified ·
1 Parent(s): 52ac305

Upload demo_oculus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo_oculus.py +192 -0
demo_oculus.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Oculus Car Part Detection Demo
4
+
5
+ Demonstrates detection on car images using the extended training model.
6
+ """
7
+
8
+ import sys
9
+ import requests
10
+ from io import BytesIO
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ import torch
13
+ import numpy as np
14
+
15
+ # Add parent to path
16
+ from pathlib import Path
17
+ sys.path.insert(0, str(Path(__file__).parent))
18
+
19
+ from oculus_unified_model import OculusForConditionalGeneration
20
+
21
# COCO class names, indexed by integer label id (80 classes).
# Hoisted to module level so the list is built once, not on every call.
COCO_CLASSES = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck',
    'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
    'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
    'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
    'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
    'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse',
    'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]


def visualize_results(image, output, filename="output_car_parts.png"):
    """Draw bounding boxes and confidence labels on *image* and save it.

    Args:
        image: PIL Image; annotated in place and then written to *filename*.
        output: detection result exposing parallel sequences `boxes`
            ([x1, y1, x2, y2], normalized to [0, 1]), `labels` (class ids),
            and `confidences`.  # assumes the three sequences are parallel —
            # TODO confirm against the model's output type
        filename: path the annotated image is saved to (PNG by default).
    """
    draw = ImageDraw.Draw(image)

    # Prefer a real TrueType font; fall back to PIL's builtin bitmap font
    # on systems where the macOS Helvetica path does not exist.
    try:
        font = ImageFont.truetype("/System/Library/Fonts/Helvetica.ttc", 16)
    except OSError:
        font = ImageFont.load_default()

    width, height = image.size

    for box, label, conf in zip(output.boxes, output.labels, output.confidences):
        # Box is [x1, y1, x2, y2] normalized
        x1, y1, x2, y2 = box

        # Clamp normalized coords into [0, 1].
        x1 = max(0.0, min(1.0, x1))
        y1 = max(0.0, min(1.0, y1))
        x2 = max(0.0, min(1.0, x2))
        y2 = max(0.0, min(1.0, y2))

        # Skip degenerate (zero or negative area) boxes.
        if x2 <= x1 or y2 <= y1:
            continue

        # Scale to pixel coordinates.
        x1 *= width
        y1 *= height
        x2 *= width
        y2 *= height

        # Color based on confidence: red = low (< 0.5), green = high.
        color = "red" if conf < 0.5 else "green"

        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        # Map the label id to a class name; fall back to the raw label
        # when it is out of range or not convertible to int.
        try:
            class_name = COCO_CLASSES[int(label)]
        except (ValueError, IndexError, TypeError):
            class_name = str(label)

        label_text = f"{class_name} ({conf:.2f})"

        # Draw a filled background behind the text for readability.
        text_bbox = draw.textbbox((x1, y1), label_text, font=font)
        draw.rectangle(text_bbox, fill=color)
        draw.text((x1, y1), label_text, fill="white", font=font)

    image.save(filename)
    # Bug fix: the original f-string had no placeholder and always
    # printed the literal "(unknown)" instead of the output path.
    print(f"Saved visualization to {filename}")
89
+
90
def _find_model_path():
    """Return the path of the newest V2 step checkpoint.

    Falls back to the V2 'final' directory, then to the V1 checkpoint
    when the extended V2 model is not on disk yet.
    """
    checkpoint_dir = Path("checkpoints/oculus_detection_v2")
    model_path = None

    if checkpoint_dir.exists():
        # Collect (step_number, dir) pairs from folders named step_<int>.
        steps = []
        for d in checkpoint_dir.iterdir():
            if d.is_dir() and d.name.startswith("step_"):
                try:
                    steps.append((int(d.name.split("_")[1]), d))
                except (ValueError, IndexError):
                    # Folder does not follow the step_<number> scheme; skip it.
                    pass

        # Sort descending by step and pick the latest.
        if steps:
            steps.sort(key=lambda x: x[0], reverse=True)
            model_path = str(steps[0][1])
            print(f"✨ Found latest checkpoint: {model_path}")

    if model_path is None:
        model_path = str(checkpoint_dir / "final")

    # Fallback to initial detection checkpoint if extended one isn't ready
    if not Path(model_path).exists():
        model_path = "checkpoints/oculus_detection/final"
        print(f"⚠️ Extended V2 model not found, falling back to V1: {model_path}")

    return model_path


def _load_model(model_path):
    """Load the base model and its detection head; return None on failure."""
    print(f"Loading model from {model_path}...")
    try:
        model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Detection head weights are stored separately from the base model.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads['detection'])
            print("✓ Loaded detection heads")
        return model
    except Exception as e:
        # Top-level boundary: report the failure and let the caller abort.
        print(f"Error loading model: {e}")
        return None


def _load_image(args):
    """Open the user-supplied image, a bundled COCO sample, or a web sample."""
    if args.image:
        image_path = args.image
        print(f"\nProcessing Custom Image: {image_path}...")
    else:
        # Default to a bundled COCO sample image.
        image_path = "data/coco/images/000000071345.jpg"
        print(f"\nProcessing Default Image: {image_path}...")

    if Path(image_path).exists():
        return Image.open(image_path).convert('RGB')

    # Fallback to an online sample when nothing is available locally.
    url = "https://upload.wikimedia.org/wikipedia/commons/thumb/8/8d/President_Barack_Obama.jpg/800px-President_Barack_Obama.jpg"
    print(f"Image not found, downloading sample {url}...")
    # timeout added so a dead network fails fast instead of hanging forever;
    # any error here is handled by main()'s boundary except-block.
    response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
    return Image.open(BytesIO(response.content)).convert('RGB')


def main():
    """CLI entry point: parse args, load the model, run one inference mode."""
    import argparse
    parser = argparse.ArgumentParser(description="Oculus General Object Detection Demo")
    parser.add_argument("--image", type=str, help="Path to image file to test")
    parser.add_argument("--prompt", type=str, default="Detect objects", help="Text prompt for the model")
    parser.add_argument("--mode", type=str, default="box", choices=["box", "vqa", "caption"], help="Inference mode")
    parser.add_argument("--threshold", type=float, default=0.2, help="Detection threshold")
    parser.add_argument("--output", type=str, default="detection_result.png", help="Output filename")
    args = parser.parse_args()

    model = _load_model(_find_model_path())
    if model is None:
        return

    try:
        image = _load_image(args)

        # Dispatch on the requested inference mode.
        if args.mode == "box":
            print(f"Running detection with prompt: '{args.prompt}'...")
            output = model.generate(
                image,
                mode="box",
                prompt=args.prompt,
                threshold=args.threshold
            )
            print(f"Found {len(output.boxes)} objects")
            visualize_results(image, output, args.output)

        elif args.mode == "caption":
            print("Generating caption...")
            output = model.generate(image, mode="text", prompt="A photo of")
            print(f"\n📝 Caption: {output.text}\n")

        elif args.mode == "vqa":
            # Use the prompt as the question unless it is still the default.
            question = args.prompt if args.prompt != "Detect objects" else "What is in this image?"
            print(f"Thinking about question: '{question}'...")
            output = model.generate(image, mode="text", prompt=question)
            print(f"\n🤔 Answer: {output.text}\n")

    except Exception as e:
        # Top-level boundary for all image-loading / inference failures.
        print(f"Error processing image: {e}")
190
+
191
# Script entry point: run the demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()