"""Convenience wrapper around the Oculus unified model.

Exposes :class:`OculusPredictor`, which loads an Oculus checkpoint and offers
``detect`` (object detection), ``ask`` (visual question answering) and
``caption`` helpers that accept a local path, an http(s) URL, or a PIL image.
"""
import sys
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import requests
import torch
from PIL import Image

# Ensure Oculus root is in path so the sibling model module is importable.
OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, str(OCULUS_ROOT))

try:
    from oculus_unified_model import OculusForConditionalGeneration
except ImportError:
    # Fall back to the package-qualified import (e.g. when the parent of the
    # Oculus directory is the import root instead of Oculus itself).
    from Oculus.oculus_unified_model import OculusForConditionalGeneration

# Seconds to wait for a remote image. Without a timeout, ``requests.get`` can
# block indefinitely on an unresponsive server.
_HTTP_TIMEOUT = 30


class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.

    Supports Object Detection, VQA, and Captioning.
    """

    def __init__(self, model_path: Optional[str] = None, device: str = "cpu"):
        """Load the Oculus model, its detection heads and an optional VQA model.

        Args:
            model_path: Checkpoint directory passed to
                ``OculusForConditionalGeneration.from_pretrained``. When
                ``None``, ``checkpoints/oculus_detection_v2/final`` under the
                Oculus root is used if it exists, otherwise
                ``checkpoints/oculus_detection/final``.
            device: Stored on the instance and used as ``map_location`` when
                loading ``heads.pth``.
        """
        self.device = device

        # Auto-discover latest model if not provided.
        if model_path is None:
            model_path = self._default_model_path()

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)
        # NOTE(review): the model is never moved to ``device``; only the heads
        # checkpoint is mapped there. Confirm how ``generate`` handles device
        # placement before relying on a non-CPU ``device``.

        # Load detection heads, if the checkpoint ships them.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Load instruction-tuned VQA model if available.
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            # Imported lazily so transformers' BLIP classes are only needed
            # when the instruction-tuned checkpoint is actually present.
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        print("✓ Model loaded successfully")

    @staticmethod
    def _default_model_path() -> str:
        """Return the V2 detection checkpoint directory if present, else V1's."""
        v2_final = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2" / "final"
        if v2_final.exists():
            return str(v2_final)
        # Fallback to V1
        return str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

    def load_image(self, image_source: Union[str, Path, Image.Image]) -> Image.Image:
        """Load image from path, URL, or PIL object.

        Args:
            image_source: A PIL image, an ``http://``/``https://`` URL, or a
                local file path (``str`` or ``pathlib.Path``).

        Returns:
            A new RGB ``PIL.Image.Image``.

        Raises:
            requests.RequestException: If a URL cannot be fetched (non-2xx
                status, connection error, or timeout).
            OSError: If the file or downloaded bytes cannot be read as an image.
        """
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")

        # Match the scheme explicitly so local paths like "httpdocs/x.png"
        # are not mistaken for URLs.
        if isinstance(image_source, str) and image_source.lower().startswith(("http://", "https://")):
            response = requests.get(
                image_source,
                headers={'User-Agent': 'Mozilla/5.0'},
                timeout=_HTTP_TIMEOUT,
            )
            # Surface HTTP errors directly instead of handing an error page to PIL.
            response.raise_for_status()
            with Image.open(BytesIO(response.content)) as img:
                return img.convert("RGB")

        # The context manager closes the underlying file; convert() returns a
        # fully loaded copy that does not depend on it.
        with Image.open(image_source) as img:
            return img.convert("RGB")

    def detect(
        self,
        image_source: Union[str, Path, Image.Image],
        prompt: str = "Detect objects",
        threshold: float = 0.2,
    ) -> Dict[str, Any]:
        """
        Run object detection.

        Args:
            image_source: Anything accepted by :meth:`load_image`.
            prompt: Text prompt forwarded to the model in ``"box"`` mode.
            threshold: Threshold forwarded to ``generate``.

        Returns:
            {'boxes': [[x1,y1,x2,y2], ...], 'labels': [...],
             'confidences': [...], 'image_size': (width, height)}
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        # Convert to python friendly format
        return {
            'boxes': output.boxes,  # Normalized [0-1]
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size,  # PIL size: (width, height)
        }

    def ask(self, image_source: Union[str, Path, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA).

        Args:
            image_source: Anything accepted by :meth:`load_image`.
            question: Prompt forwarded to the model in ``"text"`` mode.

        Returns:
            The model's text answer.
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text

    def caption(self, image_source: Union[str, Path, Image.Image]) -> str:
        """Generate a caption for the image (text mode with the prompt "A photo of")."""
        return self.ask(image_source, "A photo of")