File size: 3,697 Bytes
52ac305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93

import torch
import requests
from PIL import Image
from io import BytesIO
from pathlib import Path
from typing import Union, List, Dict, Any
import sys

# Ensure Oculus root is in path
# The directory containing this file is treated as the package root so the
# sibling module (oculus_unified_model) resolves when this file is run as a
# standalone script rather than as part of an installed package.
OCULUS_ROOT = Path(__file__).parent
sys.path.insert(0, str(OCULUS_ROOT))

try:
    from oculus_unified_model import OculusForConditionalGeneration
except ImportError:
    # Attempt absolute import if relative fails
    # (covers callers that import this module via the `Oculus` package name,
    # where only the package-qualified path is importable).
    from Oculus.oculus_unified_model import OculusForConditionalGeneration

class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.
    Supports Object Detection, VQA, and Captioning.
    """

    def __init__(self, model_path: str = None, device: str = "cpu"):
        """
        Load the unified model plus optional detection heads and VQA weights.

        Args:
            model_path: Directory of a saved checkpoint. If None, auto-discovers
                the newest local checkpoint (detection V2, falling back to V1).
            device: Torch device string to run inference on (default "cpu").
        """
        self.device = device

        # Auto-discover latest model if not provided
        if model_path is None:
            base_dir = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2"
            if (base_dir / "final").exists():
                model_path = str(base_dir / "final")
            else:
                # Fallback to V1
                model_path = str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Load detection heads.
        # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only
        # load checkpoints from trusted sources. `weights_only=True` would be
        # safer for a plain state-dict file like this; confirm the project's
        # minimum torch version (>= 1.13) before enabling it.
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Load instruction-tuned VQA model if available
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        # Fix: `device` was previously stored but never applied, so passing
        # device="cuda" still ran everything on CPU. Moving after all weights
        # are attached also carries the detection head and VQA model along.
        # No-op for the default device="cpu".
        self.model.to(self.device)

        print("✓ Model loaded successfully")

    def load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """
        Load an image from a filesystem path, an http(s) URL, or a PIL object.

        Returns:
            The image converted to RGB.

        Raises:
            requests.HTTPError: if a URL download returns a non-2xx status.
            requests.Timeout: if the download stalls past the timeout.
        """
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")

        # Fix: match only real http(s) URLs — the old `startswith("http")`
        # also matched local paths such as "httpdocs/pic.png".
        if image_source.startswith(("http://", "https://")):
            # Fix: add a timeout so a stalled server cannot hang the caller,
            # and fail loudly on error statuses instead of handing an HTML
            # error page to the image decoder.
            response = requests.get(
                image_source,
                headers={'User-Agent': 'Mozilla/5.0'},
                timeout=30,
            )
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        return Image.open(image_source).convert("RGB")

    def detect(self, image_source: Union[str, Image.Image], prompt: str = "Detect objects", threshold: float = 0.2) -> Dict[str, Any]:
        """
        Run object detection.

        Args:
            image_source: Path, URL, or PIL image (see `load_image`).
            prompt: Detection instruction passed to the model.
            threshold: Confidence cutoff forwarded to `generate`.

        Returns: {'boxes': [[x1,y1,x2,y2], ...], 'labels': [...], 'confidences': [...],
                  'image_size': (width, height)} with boxes normalized to [0-1].
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        # Convert to python friendly format
        return {
            'boxes': output.boxes,  # Normalized [0-1]
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size  # PIL convention: (width, height)
        }

    def ask(self, image_source: Union[str, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA). Returns the answer text."""
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text

    def caption(self, image_source: Union[str, Image.Image]) -> str:
        """Generate a caption by prompting VQA with a completion prefix."""
        # "A photo of" acts as a caption-completion prompt for the text head.
        return self.ask(image_source, "A photo of")