kobiakor15 commited on
Commit
52ac305
·
verified ·
1 Parent(s): ad39c92

Upload oculus_inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. oculus_inference.py +92 -0
oculus_inference.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import torch
import requests
from PIL import Image
from io import BytesIO
from pathlib import Path
from typing import Union, List, Dict, Any
import sys

# Ensure Oculus root is in path so the sibling module below is importable
# when this file is run as a script from another working directory.
OCULUS_ROOT = Path(__file__).parent  # directory containing this file
sys.path.insert(0, str(OCULUS_ROOT))

try:
    from oculus_unified_model import OculusForConditionalGeneration
except ImportError:
    # Attempt absolute import if relative fails
    # (e.g. when the package is imported as `Oculus.oculus_inference`).
    from Oculus.oculus_unified_model import OculusForConditionalGeneration
19
+
20
class OculusPredictor:
    """
    Easy-to-use interface for the Oculus Unified Model.
    Supports Object Detection (`detect`), VQA (`ask`), and Captioning (`caption`).
    """

    def __init__(self, model_path: Union[str, None] = None, device: str = "cpu"):
        """Load the base model plus any optional heads/checkpoints found on disk.

        Args:
            model_path: Directory of the pretrained model. When None,
                auto-discovers a local checkpoint: the v2 ``final`` directory
                if present, otherwise the v1 layout.
            device: Device string used when loading the detection heads.
        """
        self.device = device
        # NOTE(review): `device` is stored and passed to torch.load, but the
        # model itself is never moved with `.to(device)` here — confirm that
        # `generate` handles device placement.

        # Auto-discover latest model if not provided
        if model_path is None:
            base_dir = OCULUS_ROOT / "checkpoints" / "oculus_detection_v2"
            if (base_dir / "final").exists():
                model_path = str(base_dir / "final")
            else:
                # Fallback to V1
                model_path = str(OCULUS_ROOT / "checkpoints" / "oculus_detection" / "final")

        print(f"Loading Oculus model from: {model_path}")
        self.model = OculusForConditionalGeneration.from_pretrained(model_path)

        # Load detection heads (saved separately from the base weights)
        heads_path = Path(model_path) / "heads.pth"
        if heads_path.exists():
            # SECURITY: torch.load unpickles arbitrary objects — only load
            # checkpoints from trusted sources (consider weights_only=True,
            # which suffices for a plain state_dict like this one).
            heads = torch.load(heads_path, map_location=device)
            self.model.detection_head.load_state_dict(heads['detection'])
            print("✓ Detection heads loaded")

        # Swap in the instruction-tuned VQA model when a local copy exists.
        instruct_path = OCULUS_ROOT / "checkpoints" / "oculus_instruct_v1" / "vqa_model"
        if instruct_path.exists():
            from transformers import BlipForQuestionAnswering
            self.model.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(instruct_path)
            print("✓ Instruction-tuned VQA model loaded")

        print("✓ Model loaded successfully")

    def load_image(self, image_source: Union[str, Image.Image]) -> Image.Image:
        """Load image from path, URL, or PIL object; always returns RGB."""
        if isinstance(image_source, Image.Image):
            return image_source.convert("RGB")

        if image_source.startswith("http"):
            # timeout= prevents an indefinite hang on an unresponsive host;
            # raise_for_status() surfaces HTTP errors (404/403/...) instead of
            # letting PIL fail cryptically on an HTML error page.
            response = requests.get(
                image_source,
                headers={'User-Agent': 'Mozilla/5.0'},
                timeout=30,
            )
            response.raise_for_status()
            return Image.open(BytesIO(response.content)).convert("RGB")

        return Image.open(image_source).convert("RGB")

    def detect(self, image_source: Union[str, Image.Image], prompt: str = "Detect objects", threshold: float = 0.2) -> Dict[str, Any]:
        """
        Run object detection.

        Args:
            image_source: Path, URL, or PIL image.
            prompt: Text prompt steering the detector.
            threshold: Confidence cutoff passed to the model.

        Returns: {'boxes': [[x1,y1,x2,y2], ...], 'labels': [...], 'confidences': [...]}
        """
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="box", prompt=prompt, threshold=threshold)

        # Convert to python friendly format
        return {
            'boxes': output.boxes,  # Normalized [0-1]
            'labels': output.labels,
            'confidences': output.confidences,
            'image_size': image.size  # (width, height) — PIL convention
        }

    def ask(self, image_source: Union[str, Image.Image], question: str) -> str:
        """Ask a question about the image (VQA); returns the answer text."""
        image = self.load_image(image_source)
        output = self.model.generate(image, mode="text", prompt=question)
        return output.text

    def caption(self, image_source: Union[str, Image.Image]) -> str:
        """Generate a caption for the image.

        Implemented as VQA with the completion prompt "A photo of".
        """
        return self.ask(image_source, "A photo of")