# NOTE(review): the lines below are Hugging Face page residue captured with the
# file; kept as comments so the module still parses.
# Oculus / demo_oculus_unified.py
# Uploaded by kobiakor15 via huggingface_hub (commit 61cc71c, verified)
#!/usr/bin/env python3
"""
Oculus 0.2 Unified Demo
Demonstrates all features of the unified Oculus model:
- Text mode (captioning, VQA)
- Point mode (counting objects)
- Box mode (detection with bounding boxes)
- Polygon mode (segmentation)
- Optional reasoning with thinking traces
- Focus system for fine-grained perception
"""
import os
import sys
import requests
from pathlib import Path
from io import BytesIO
from PIL import Image
import torch
# Add parent to path
sys.path.insert(0, str(Path(__file__).parent))
from oculus_unified_model import OculusForConditionalGeneration, OculusConfig
def download_image(url: str, timeout: float = 10) -> Image.Image:
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Direct URL of the image resource.
        timeout: Request timeout in seconds. Defaults to 10, matching the
            previously hard-coded value, so existing callers are unaffected.

    Returns:
        The downloaded image decoded and converted to RGB.

    Raises:
        requests.HTTPError: If the server returns a 4xx/5xx status.
        requests.RequestException: On connection failures or timeout.
    """
    # Some hosts (e.g. Wikimedia) reject requests without a browser-like UA.
    headers = {'User-Agent': 'Mozilla/5.0'}
    response = requests.get(url, headers=headers, timeout=timeout)
    response.raise_for_status()
    # Convert to RGB so downstream code never sees palette/RGBA images.
    return Image.open(BytesIO(response.content)).convert('RGB')
def print_header(title: str):
    """Print a prominent banner: blank line, '=' rule, the title, closing rule."""
    banner = "=" * 70
    print("\n" + banner)
    print("๐Ÿ”ฎ " + title)
    print(banner)
def print_section(title: str):
    """Print a light section divider: blank line, thin rule, title, closing rule."""
    rule = "โ”€" * 70
    print()
    print(rule)
    print(f" {title}")
    print(rule)
def demo():
    """Run the end-to-end Oculus demo.

    Loads the model (a trained checkpoint if one exists on disk, otherwise a
    default untrained configuration), downloads a couple of test images, and
    exercises every generation mode — text (caption/VQA), point, box, and
    polygon — plus the optional reasoning trace. Failures on one image are
    caught and reported so the remaining images still run.
    """
    print_header("OCULUS 0.2 UNIFIED MODEL DEMO")
    # ================================================================
    # Load Model
    # ================================================================
    print("\n[1] Loading Oculus Model...")
    # Check if we have trained weights
    weights_path = Path(__file__).parent / "checkpoints" / "oculus_coco" / "final"
    if weights_path.exists():
        print(f" Found trained weights at: {weights_path}")
        model = OculusForConditionalGeneration.from_pretrained(weights_path)
    else:
        # No checkpoint found: build an untrained model from a default
        # config with reasoning and the focus system switched on.
        print(" Using default configuration")
        config = OculusConfig(
            reasoning_enabled=True,
            enable_focus=True,
        )
        model = OculusForConditionalGeneration(config)
    print(" โœ“ Model loaded!")
    # ================================================================
    # Test Images
    # ================================================================
    test_images = [
        {
            "name": "Cat on Couch",
            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
        },
        {
            "name": "Golden Gate Bridge",
            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/0/0c/GoldenGateBridge-001.jpg/1200px-GoldenGateBridge-001.jpg"
        },
    ]
    for test in test_images:
        print_header(f"Testing: {test['name']}")
        # Each image runs inside its own try block so a failure (network,
        # model) on one image does not abort the whole demo.
        try:
            print("\n[Downloading image...]")
            image = download_image(test["url"])
            print(f" Image size: {image.size}")
            # ========================================================
            # Mode 1: TEXT (Captioning)
            # ========================================================
            print_section("๐Ÿ“ TEXT MODE - Captioning")
            output = model.generate(
                image=image,
                prompt="Describe this image in detail",
                mode="text",
                think=False
            )
            print(f" Caption: \"{output.text}\"")
            # ========================================================
            # Mode 2: TEXT with Reasoning
            # ========================================================
            print_section("๐Ÿง  TEXT MODE - With Reasoning")
            output = model.generate(
                image=image,
                prompt="What is the main subject of this image?",
                mode="text",
                think=True  # Enable thinking traces
            )
            if output.thinking_trace:
                # Only the first 200 chars of the trace are shown.
                print(f" ๐Ÿ’ญ Thinking: {output.thinking_trace[:200]}...")
            print(f" Answer: \"{output.text}\"")
            # ========================================================
            # Mode 3: TEXT (VQA)
            # ========================================================
            print_section("โ“ TEXT MODE - VQA")
            questions = [
                "What colors are visible in this image?",
                "Is this indoors or outdoors?",
            ]
            for q in questions:
                output = model.generate(
                    image=image,
                    prompt=q,
                    mode="text"
                )
                print(f" Q: {q}")
                print(f" A: {output.text}")
            # ========================================================
            # Mode 4: POINT (Counting)
            # ========================================================
            print_section("๐Ÿ“ POINT MODE - Object Counting")
            output = model.generate(
                image=image,
                prompt="Find objects",
                mode="point"
            )
            print(f" Detected {len(output.points)} points")
            # Show at most the first five detections.
            for i, (pt, label, conf) in enumerate(zip(
                output.points[:5],
                output.labels[:5],
                output.confidences[:5]
            )):
                print(f" Point {i+1}: {pt} (class={label}, conf={conf:.2f})")
            # ========================================================
            # Mode 5: BOX (Detection)
            # ========================================================
            print_section("๐Ÿ“ฆ BOX MODE - Object Detection")
            output = model.generate(
                image=image,
                prompt="Detect all objects",
                mode="box"
            )
            print(f" Detected {len(output.boxes)} boxes")
            # Show at most the first five boxes.
            for i, (box, label, conf) in enumerate(zip(
                output.boxes[:5],
                output.labels[:5],
                output.confidences[:5]
            )):
                print(f" Box {i+1}: {[f'{b:.2f}' for b in box]} (class={label}, conf={conf:.2f})")
            # ========================================================
            # Mode 6: POLYGON (Segmentation)
            # ========================================================
            print_section("๐Ÿ”ท POLYGON MODE - Segmentation")
            output = model.generate(
                image=image,
                prompt="Segment the scene",
                mode="polygon"
            )
            print(f" Segmentation mask shape: {output.mask.shape if output.mask is not None else 'N/A'}")
            print(f" Detected {len(output.polygons)} regions")
            # Show at most the first three regions.
            for i, (poly, label) in enumerate(zip(
                output.polygons[:3],
                output.labels[:3]
            )):
                print(f" Region {i+1}: class={label}, vertices={len(poly)}")
            print("\n โœ… All modes successful!")
        except Exception as e:
            # Report the failure and move on to the next test image.
            print(f"\n โŒ Error: {e}")
            import traceback
            traceback.print_exc()
    # ================================================================
    # Summary
    # ================================================================
    print_header("DEMO COMPLETE")
    print("""
Oculus 0.2 supports:
๐Ÿ“ TEXT MODE
- Image captioning
- Visual question answering
- With optional reasoning traces
๐Ÿ“ POINT MODE
- Object counting
- Point localization
๐Ÿ“ฆ BOX MODE
- Object detection
- Bounding box prediction
๐Ÿ”ท POLYGON MODE
- Semantic segmentation
- Instance segmentation
๐Ÿง  REASONING
- Optional thinking traces
- Multi-step reasoning
๐Ÿ” FOCUS SYSTEM
- Zoom & crop for fine-grained perception
- Automatic region detection
Usage:
```python
from oculus_unified_model import OculusForConditionalGeneration
model = OculusForConditionalGeneration.from_pretrained("./checkpoints/oculus_coco/final")
# Caption
output = model.generate(image, mode="text", prompt="Describe this")
# VQA with reasoning
output = model.generate(image, mode="text", prompt="What color is it?", think=True)
# Detection
output = model.generate(image, mode="box", prompt="Find cars")
# Segmentation
output = model.generate(image, mode="polygon")
```
""")
# Standard script entry point: run the demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    demo()