Oculus / oculus_inference.py
import sys
from io import BytesIO
from pathlib import Path
from typing import Union, List, Dict, Any, Optional

import requests
import torch
from PIL import Image

# Ensure the Oculus root is on the import path
OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, str(OCULUS_ROOT))

try:
    from oculus_unified_model import OculusForConditionalGeneration
except ImportError:
    # Attempt absolute import if the relative import fails
    from Oculus.oculus_unified_model import OculusForConditionalGeneration
class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.
    Supports Object Detection, VQA, and Captioning.
    """

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        self.device = device

        # Auto-discover the latest model if no path is provided
        if model_path is None:
            base_dir = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2"
            if (base_dir / "final").exists():
                model_path = str(base_dir / "final")
            else:
                # Fall back to V1
                model_path = str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Load detection heads
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Load instruction-tuned VQA model if available
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        print("✓ Model loaded successfully")
    def load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """Load an image from a path, URL, or PIL object."""
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")
        if image_source.startswith("http"):
            response = requests.get(image_source, headers={'User-Agent': 'Mozilla/5.0'}, timeout=30)
            return Image.open(BytesIO(response.content)).convert("RGB")
        return Image.open(image_source).convert("RGB")
    def detect(self, image_source: Union[str, Image.Image], prompt: str = "Detect objects", threshold: float = 0.2) -> Dict[str, Any]:
        """
        Run object detection.
        Returns: {'boxes': [[x1,y1,x2,y2], ...], 'labels': [...], 'confidences': [...]}
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        # Convert to a Python-friendly format
        return {
            'boxes': output.boxes,  # Normalized [0-1]
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size
        }
    def ask(self, image_source: Union[str, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA)."""
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text
    def caption(self, image_source: Union[str, Image.Image]) -> str:
        """Generate a caption for the image."""
        return self.ask(image_source, "A photo of")