|
|
""" |
|
|
Oculus Unified Model |
|
|
|
|
|
Oceanir-Oculus OO1 Architecture - Hybrid-reasoning vision-language model. |
|
|
|
|
|
Features: |
|
|
- Reasoning via Thinking Traces |
|
|
- Perceptive Tool Calling + Focus (Zoom & Crop) |
|
|
- Structured Outputs (JSON, Box, Point) |
|
|
- Complex OCR |
|
|
- Desktop UI Understanding |
|
|
|
|
|
Small models that outperform systems 10x larger on visual reasoning |
|
|
and perception tasks, running on commodity GPUs or edge devices. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import json |
|
|
import warnings |
|
|
from dataclasses import dataclass |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, List, Dict, Any, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import PreTrainedModel |
|
|
from PIL import Image |
|
|
|
|
|
from .configuration_oculus import OculusConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class OculusOutput:
    """Base output class for Oculus model.

    Common fields shared by every mode-specific output dataclass below;
    subclasses add their own task-specific fields.
    """
    # Final decoded text (including any special delimiter tokens), if produced.
    text: Optional[str] = None
    # Reasoning trace produced when generation is called with think=True.
    thinking_trace: Optional[str] = None
    # Raw language-model logits, when the caller requests them.
    logits: Optional[torch.Tensor] = None
    # Final hidden states from the language model, when requested.
    hidden_states: Optional[torch.Tensor] = None
    # Projected vision tokens that were fed to the language model.
    vision_tokens: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
@dataclass
class OculusTextOutput(OculusOutput):
    """Output for text/caption mode.

    Adds no fields of its own; everything is inherited from OculusOutput.
    """
    pass
|
|
|
|
|
|
|
|
@dataclass
class OculusJSONOutput(OculusOutput):
    """Output for structured JSON mode."""
    # Parsed JSON payload; the inherited `text` field carries the serialized
    # form wrapped in the model's JSON delimiter tokens.
    json_data: Optional[Dict[str, Any]] = None
|
|
|
|
|
|
|
|
@dataclass
class OculusPointOutput(OculusOutput):
    """Output for point detection mode (counting objects)."""
    # (x, y) per detection; values come from a sigmoid head, so they lie
    # in [0, 1] (presumably normalized image coordinates — confirm).
    points: Optional[List[Tuple[float, float]]] = None
    # One label per point (stringified class id).
    labels: Optional[List[str]] = None
    # One confidence per point, in [0, 1].
    confidences: Optional[List[float]] = None
|
|
|
|
|
|
|
|
@dataclass
class OculusBoxOutput(OculusOutput):
    """Output for bounding box detection mode."""
    # One 4-tuple per detection; values come from a sigmoid head, so they
    # lie in [0, 1] (presumably normalized (x1, y1, x2, y2) — confirm).
    boxes: Optional[List[Tuple[float, float, float, float]]] = None
    # One label per box (stringified class id).
    labels: Optional[List[str]] = None
    # One softmax-max confidence per box.
    confidences: Optional[List[float]] = None
|
|
|
|
|
|
|
|
@dataclass
class OculusPolygonOutput(OculusOutput):
    """Output for polygon/segmentation mode."""
    # One polygon per detected class; each polygon is a list of (x, y)
    # vertices in normalized coordinates.
    polygons: Optional[List[List[Tuple[float, float]]]] = None
    # One label per polygon (stringified class id).
    labels: Optional[List[str]] = None
    # Dense class-id mask (argmax over the coarse segmentation grid).
    mask: Optional[np.ndarray] = None
|
|
|
|
|
|
|
|
@dataclass
class OculusOCROutput(OculusOutput):
    """Output for OCR mode."""
    # One dict per detected block with keys "text", "bbox", "confidence".
    text_blocks: Optional[List[Dict[str, Any]]] = None
    # Space-joined concatenation of all block texts.
    full_text: Optional[str] = None
|
|
|
|
|
|
|
|
@dataclass
class OculusUIOutput(OculusOutput):
    """Output for UI element detection."""
    # One dict per element with keys "type", "bbox", "confidence".
    elements: Optional[List[Dict[str, Any]]] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OculusVisionEncoder(nn.Module):
    """
    Oceanir-Oculus OO1 Vision Encoder.

    Hybrid vision encoder optimized for visual reasoning and grounding.
    A ViT-style stack: non-overlapping patch embedding, a prepended [CLS]
    token, learned absolute position embeddings, and a transformer encoder.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        # Non-overlapping patch embedding: stride equals kernel size.
        self.patch_embed = nn.Conv2d(
            3,
            config.vision_hidden_size,
            kernel_size=config.patch_size,
            stride=config.patch_size,
        )

        # One extra position slot for the [CLS] token.
        # NOTE(review): both parameters are zero-initialized; a truncated
        # normal init is more common — confirm this is intentional.
        patch_count = (config.image_size // config.patch_size) ** 2
        self.pos_embed = nn.Parameter(
            torch.zeros(1, patch_count + 1, config.vision_hidden_size)
        )
        self.cls_token = nn.Parameter(
            torch.zeros(1, 1, config.vision_hidden_size)
        )

        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=config.vision_hidden_size,
                nhead=config.vision_num_heads,
                dim_feedforward=4 * config.vision_hidden_size,
                batch_first=True,
            )
            for _ in range(config.vision_num_layers)
        )

        self.norm = nn.LayerNorm(config.vision_hidden_size)

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """
        Encode images to vision features.

        Args:
            pixel_values: [batch, 3, H, W]

        Returns:
            Vision features [batch, hidden_size] — the normalized [CLS]
            token embedding after the encoder stack.
        """
        n = pixel_values.shape[0]

        # [batch, C, H', W'] -> [batch, H'*W', C] patch tokens.
        patches = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)

        # Prepend [CLS] and add position embeddings (sliced so shorter
        # sequences than the configured maximum still work).
        tokens = torch.cat([self.cls_token.expand(n, -1, -1), patches], dim=1)
        tokens = tokens + self.pos_embed[:, : tokens.shape[1], :]

        for block in self.layers:
            tokens = block(tokens)

        # Return only the [CLS] token as the pooled image representation.
        return self.norm(tokens)[:, 0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OculusProjector(nn.Module):
    """Projects vision features to language model token space.

    A three-layer GELU MLP maps the fused vision feature vector to
    ``num_vision_tokens`` embeddings of the LM hidden size, followed by a
    LayerNorm over the embedding dimension.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        fused_dim = config.fused_vision_dim
        hidden_dim = config.projector_hidden_dim
        num_tokens = config.num_vision_tokens
        embed_dim = config.lm_hidden_size

        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.act2 = nn.GELU()
        # Emits all vision tokens in one flat vector; reshaped in forward().
        self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

        self.num_tokens = num_tokens
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map fused vision features [batch, fused_dim] to vision tokens.

        Returns:
            Tensor of shape [batch, num_tokens, embed_dim].
        """
        batch_size = x.shape[0]

        h = self.fc1(x)
        h = self.act1(h)
        h = self.fc2(h)
        h = self.act2(h)
        h = self.fc3(h)

        h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
        h = self.norm(h)

        return h

    @classmethod
    def from_pretrained(cls, path: str, config: OculusConfig):
        """Load projector from saved weights.

        Expects ``<path>/projector.npz`` as written by
        ``OculusForConditionalGeneration.save_pretrained``: each archive
        entry maps a top-level submodule name (e.g. ``fc1``) to a pickled
        ``{param_name: ndarray}`` dict. If the file is missing, the
        randomly initialized projector is returned unchanged.
        """
        projector = cls(config)

        weights_path = Path(path) / "projector.npz"
        if weights_path.exists():
            # allow_pickle is required: each entry is an object array
            # wrapping a plain dict of parameter arrays.
            weights = np.load(weights_path, allow_pickle=True)

            state_dict = {}
            for key in weights.files:
                layer_dict = weights[key].item()
                for param_name, param_val in layer_dict.items():
                    # BUG FIX: the previous tolist()/np.array round-trip
                    # silently promoted float32 weights to float64 (and
                    # copied each tensor twice). np.asarray preserves the
                    # stored dtype and avoids the extra copies.
                    arr = np.asarray(param_val)
                    state_dict[f"{key}.{param_name}"] = torch.from_numpy(arr)

            # strict=False tolerates partially saved checkpoints.
            projector.load_state_dict(state_dict, strict=False)
            print(f" ✓ Loaded projector from {path}")

        return projector
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OculusLanguageModel(nn.Module):
    """
    Oceanir-Oculus OO1 Language Model.

    Hybrid transformer optimized for visual reasoning and structured output.

    Consumes an optional prefix of projected vision tokens followed by
    embedded text tokens, and returns next-token logits for the text
    positions only.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        # Token embeddings plus learned absolute position embeddings.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.lm_hidden_size)
        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.lm_hidden_size)

        self.layers = nn.ModuleList([
            nn.TransformerDecoderLayer(
                d_model=config.lm_hidden_size,
                nhead=config.lm_num_heads,
                dim_feedforward=config.lm_hidden_size * 4,
                batch_first=True
            )
            for _ in range(config.lm_num_layers)
        ])

        self.norm = nn.LayerNorm(config.lm_hidden_size)
        self.lm_head = nn.Linear(config.lm_hidden_size, config.vocab_size, bias=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        vision_tokens: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Generate logits from input tokens.

        Args:
            input_ids: Text token ids [batch, seq_len].
            vision_tokens: Optional vision prefix
                [batch, num_vision_tokens, hidden].
            attention_mask: Optional [batch, seq_len] mask over the text
                tokens; 1 = attend, 0 = padding.

        Returns:
            Logits [batch, seq_len, vocab_size], aligned with input_ids.
        """
        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        hidden = self.embed_tokens(input_ids)

        # Positions are assigned over the text tokens only; the vision
        # prefix carries no positional signal of its own.
        positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        hidden = hidden + self.pos_embed(positions)

        num_prefix = 0
        if vision_tokens is not None:
            num_prefix = vision_tokens.shape[1]
            hidden = torch.cat([vision_tokens, hidden], dim=1)

        total_len = hidden.shape[1]

        # BUG FIX: previously no mask was applied, so this autoregressive
        # model attended to future tokens. A causal mask over the full
        # (vision prefix + text) sequence preserves left-to-right decoding
        # while still letting every text token see the whole vision prefix
        # (the prefix precedes all text positions).
        causal_mask = torch.triu(
            torch.full((total_len, total_len), float("-inf"), device=device),
            diagonal=1,
        )

        # BUG FIX: `attention_mask` was accepted but silently ignored.
        # Convert it to a key-padding mask over the concatenated sequence,
        # keeping the vision prefix always attendable.
        key_padding_mask = None
        if attention_mask is not None:
            pad = attention_mask == 0
            if num_prefix:
                prefix_keep = torch.zeros(
                    batch_size, num_prefix, dtype=torch.bool, device=device
                )
                pad = torch.cat([prefix_keep, pad], dim=1)
            key_padding_mask = pad

        # Each decoder layer cross-attends to its own input (self-referential
        # "memory"), so the causal mask must gate both attention paths.
        for layer in self.layers:
            hidden = layer(
                hidden,
                hidden,
                tgt_mask=causal_mask,
                memory_mask=causal_mask,
                tgt_key_padding_mask=key_padding_mask,
                memory_key_padding_mask=key_padding_mask,
            )

        hidden = self.norm(hidden)

        # Drop the vision prefix so logits align 1:1 with input_ids.
        if vision_tokens is not None:
            hidden = hidden[:, num_prefix:, :]

        logits = self.lm_head(hidden)

        return logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OculusDetectionHead(nn.Module):
    """Head for bounding box detection.

    Two parallel two-layer MLPs over each vision token: one produces
    class logits, the other normalized box coordinates.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        def branch(out_dim: int) -> nn.Sequential:
            """Two-layer GELU MLP: width -> width//2 -> out_dim."""
            return nn.Sequential(
                nn.Linear(width, width // 2),
                nn.GELU(),
                nn.Linear(width // 2, out_dim),
            )

        # Per-token class logits.
        self.cls_head = branch(config.num_detection_classes)
        # Per-token box regression (squashed to [0, 1] in forward).
        self.box_head = branch(4)

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (class logits, sigmoid box coords) for every token."""
        logits = self.cls_head(vision_tokens)
        coords = torch.sigmoid(self.box_head(vision_tokens))
        return logits, coords
|
|
|
|
|
|
|
|
class OculusPointHead(nn.Module):
    """Head for point detection (object counting).

    Three parallel MLP branches over each vision token: a point location,
    class logits, and a scalar confidence.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        def branch(mid: int, out_dim: int) -> nn.Sequential:
            """Two-layer GELU MLP: width -> mid -> out_dim."""
            return nn.Sequential(
                nn.Linear(width, mid),
                nn.GELU(),
                nn.Linear(mid, out_dim),
            )

        # Normalized (x, y) location per token.
        self.point_head = branch(width // 2, 2)
        # Class logits per token.
        self.cls_head = branch(width // 2, config.num_detection_classes)
        # Scalar confidence per token (narrower bottleneck).
        self.conf_head = branch(width // 4, 1)

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (sigmoid points, class logits, sigmoid confidence)."""
        coords = torch.sigmoid(self.point_head(vision_tokens))
        logits = self.cls_head(vision_tokens)
        scores = torch.sigmoid(self.conf_head(vision_tokens))
        return coords, logits, scores
|
|
|
|
|
|
|
|
class OculusSegmentationHead(nn.Module):
    """Head for polygon/mask segmentation.

    Mean-pools the vision tokens and predicts a coarse, hard-coded 14x14
    per-class mask grid.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size
        classes = config.num_segmentation_classes

        # Flat output of 14*14 logits per class, reshaped in forward().
        self.mask_head = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, 14 * 14 * classes),
        )

        self.num_classes = classes

    def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
        """Return mask logits of shape [batch, num_classes, 14, 14]."""
        pooled = vision_tokens.mean(dim=1)
        flat = self.mask_head(pooled)
        return flat.reshape(vision_tokens.shape[0], self.num_classes, 14, 14)
|
|
|
|
|
|
|
|
class OculusOCRHead(nn.Module):
    """Head for OCR text detection and recognition.

    Predicts 5 values per vision token — presumably 4 bbox coordinates
    plus a confidence, given how callers slice det[:4] / det[4] — confirm.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        self.text_detector = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, 5),
        )

    def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
        """Return raw detections of shape [batch, tokens, 5]."""
        return self.text_detector(vision_tokens)
|
|
|
|
|
|
|
|
class OculusUIHead(nn.Module):
    """Head for UI element detection.

    Two parallel MLP branches over each vision token: element class
    logits and a normalized bounding box.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        def branch(out_dim: int) -> nn.Sequential:
            """Two-layer GELU MLP: width -> width//2 -> out_dim."""
            return nn.Sequential(
                nn.Linear(width, width // 2),
                nn.GELU(),
                nn.Linear(width // 2, out_dim),
            )

        # Per-token UI element class logits.
        self.element_cls = branch(config.ui_element_classes)
        # Per-token box regression (squashed to [0, 1] in forward).
        self.element_box = branch(4)

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (class logits, sigmoid box coords) for every token."""
        logits = self.element_cls(vision_tokens)
        coords = torch.sigmoid(self.element_box(vision_tokens))
        return logits, coords
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OculusForConditionalGeneration(PreTrainedModel):
    """
    Oculus: Hybrid-Reasoning Vision-Language Model

    Oceanir-Oculus OO1 Architecture

    Features:
    - Reasoning via Thinking Traces
    - Perceptive Tool Calling + Focus (Zoom & Crop)
    - Structured Outputs (JSON, Box, Point)
    - Complex OCR
    - Desktop UI Understanding

    Small models that outperform systems 10x larger on visual reasoning.

    Pipeline: image -> vision_encoder -> vision_adapter -> projector
    -> vision tokens, which feed the language model and the task heads.
    Several `_generate_*` paths below are placeholder/stub decoders
    (see the NOTE(review) comments).
    """

    config_class = OculusConfig
    base_model_prefix = "oculus"

    def __init__(self, config: OculusConfig):
        super().__init__(config)
        self.config = config

        # Vision tower: image -> pooled feature vector.
        self.vision_encoder = OculusVisionEncoder(config)

        # Adapts the encoder width to the fused dimension the projector expects.
        self.vision_adapter = nn.Linear(config.vision_hidden_size, config.fused_vision_dim)

        # Maps the fused vision feature to `num_vision_tokens` LM-space tokens.
        self.projector = OculusProjector(config)

        # Autoregressive LM consuming [vision tokens | text tokens].
        self.language_model = OculusLanguageModel(config)

        # Task-specific heads operating on the projected vision tokens.
        self.detection_head = OculusDetectionHead(config)
        self.point_head = OculusPointHead(config)
        self.segmentation_head = OculusSegmentationHead(config)
        self.ocr_head = OculusOCRHead(config)
        self.ui_head = OculusUIHead(config)

        # Special delimiter tokens used by the structured generation modes.
        self.thinking_token = config.thinking_token
        self.thinking_end_token = config.thinking_end_token
        self.focus_token = config.focus_token
        self.focus_end_token = config.focus_end_token
        self.json_token = config.json_token
        self.json_end_token = config.json_end_token
        self.box_token = config.box_token
        self.box_end_token = config.box_end_token
        self.point_token = config.point_token
        self.point_end_token = config.point_end_token

    def encode_image(self, image: Union[Image.Image, str, np.ndarray, torch.Tensor]) -> torch.Tensor:
        """Encode image to vision tokens.

        Accepts a file path, PIL image, HxWx3 array, or a pre-processed
        float tensor. Path/PIL/array inputs are resized to the configured
        square size and scaled to [0, 1]; tensor inputs are assumed to be
        already normalized (and correctly sized) by the caller.

        Returns:
            Vision tokens [batch, num_vision_tokens, lm_hidden_size].
        """
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')

        if isinstance(image, Image.Image):
            image = np.array(image.resize((self.config.image_size, self.config.image_size)))

        if isinstance(image, np.ndarray):
            image = torch.from_numpy(image).float()
            if image.dim() == 3:
                # HWC -> CHW and add the batch dimension.
                image = image.permute(2, 0, 1).unsqueeze(0)
            image = image / 255.0
        elif isinstance(image, torch.Tensor) and image.dim() == 3:
            # BUG FIX: a raw CHW tensor previously reached the conv
            # without a batch dimension; add it here.
            image = image.unsqueeze(0)

        device = next(self.parameters()).device
        image = image.to(device)

        vision_features = self.vision_encoder(image)
        vision_features = self.vision_adapter(vision_features)
        vision_tokens = self.projector(vision_features)

        return vision_tokens

    def _crop_region(self, image: Image.Image, bbox: Tuple[int, int, int, int]) -> Image.Image:
        """Crop image to the (x1, y1, x2, y2) region for focus/zoom."""
        x1, y1, x2, y2 = bbox
        return image.crop((x1, y1, x2, y2))

    def _generate_thinking_trace(self, prompt: str, context: str = "") -> str:
        """Generate a structured thinking trace.

        NOTE(review): template stub — produces a fixed-form trace from the
        prompt rather than running the language model; `context` is
        currently unused.
        """
        if self.config.thinking_style == "structured":
            return f"{self.thinking_token}Analyzing: {prompt[:50]}...{self.thinking_end_token}"
        elif self.config.thinking_style == "verbose":
            return f"{self.thinking_token}Let me think step by step: {prompt}{self.thinking_end_token}"
        else:
            return ""

    def generate(
        self,
        image: Union[Image.Image, str, np.ndarray],
        prompt: str = "Describe this image",
        mode: str = "text",
        think: bool = False,
        focus: bool = False,
        max_new_tokens: Optional[int] = None,
        temperature: float = 0.7,
        **kwargs
    ) -> Union[OculusTextOutput, OculusJSONOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput, OculusOCROutput, OculusUIOutput]:
        """
        Generate output from image.

        Args:
            image: Input image (path, PIL image, or HxWx3 array)
            prompt: Text prompt/question
            mode: "text", "json", "point", "box", "polygon", "ocr", "ui"
            think: Enable reasoning traces
            focus: Enable zoom/crop for fine-grained perception.
                NOTE(review): accepted but not yet applied here;
                `_crop_region` is the intended hook.
            max_new_tokens: Generation length cap (currently unused by the
                stub decoders)
            temperature: Sampling temperature (currently unused by the
                stub decoders)

        Raises:
            ValueError: If `mode` is not one of the supported modes.
        """
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')

        vision_tokens = self.encode_image(image)

        thinking_trace = None
        if think and self.config.reasoning_enabled:
            thinking_trace = self._generate_thinking_trace(prompt)

        if mode == "text":
            return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
        elif mode == "json":
            return self._generate_json(image, prompt, vision_tokens, thinking_trace, **kwargs)
        elif mode == "point":
            return self._generate_points(vision_tokens, thinking_trace, **kwargs)
        elif mode == "box":
            return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
        elif mode == "polygon":
            return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)
        elif mode == "ocr":
            return self._generate_ocr(vision_tokens, thinking_trace, **kwargs)
        elif mode == "ui":
            return self._generate_ui(vision_tokens, thinking_trace, **kwargs)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _generate_text(self, image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs) -> OculusTextOutput:
        """Generate text output.

        NOTE(review): placeholder — emits a canned string instead of
        decoding from the language model; `max_new_tokens` is unused.
        """
        text = f"[Generated response for: {prompt[:50]}...]"

        # Prepend the thinking trace so callers see reasoning before the answer.
        if thinking_trace:
            text = f"{thinking_trace} {text}"

        return OculusTextOutput(
            text=text,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_json(self, image, prompt, vision_tokens, thinking_trace, **kwargs) -> OculusJSONOutput:
        """Generate structured JSON output.

        NOTE(review): placeholder — returns a fixed payload; `text` wraps
        the serialized JSON in the model's JSON delimiter tokens.
        """
        json_data = {
            "prompt": prompt,
            "response": "generated",
            "objects": []
        }

        return OculusJSONOutput(
            json_data=json_data,
            text=f"{self.json_token}{json.dumps(json_data)}{self.json_end_token}",
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_points(self, vision_tokens, thinking_trace, threshold=0.5, **kwargs) -> OculusPointOutput:
        """Generate point detections.

        Keeps every vision-token prediction whose confidence exceeds
        `threshold`; labels are stringified argmax class ids. Detections
        from all batch items are flattened into single lists.
        """
        points, cls_logits, confidence = self.point_head(vision_tokens)

        mask = confidence.squeeze(-1) > threshold

        filtered_points = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            token_mask = mask[i]
            pts = points[i][token_mask].detach().cpu().numpy().tolist()
            confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_points.extend([tuple(p) for p in pts])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusPointOutput(
            points=filtered_points,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_boxes(self, vision_tokens, thinking_trace, threshold=0.3, **kwargs) -> OculusBoxOutput:
        """Generate bounding box detections.

        Keeps predictions whose max softmax class probability exceeds
        `threshold`; detections from all batch items are flattened into
        single lists.
        """
        cls_logits, box_coords = self.detection_head(vision_tokens)
        confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values

        filtered_boxes = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            mask = confidence[i] > threshold
            # CONSISTENCY FIX: .tolist() yields Python floats, matching
            # _generate_points (previously these tuples held np.float32).
            boxes = box_coords[i][mask].detach().cpu().numpy().tolist()
            confs = confidence[i][mask].detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_boxes.extend([tuple(b) for b in boxes])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusBoxOutput(
            boxes=filtered_boxes,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_polygons(self, vision_tokens, thinking_trace, **kwargs) -> OculusPolygonOutput:
        """Generate polygon/mask segmentation.

        NOTE(review): polygon extraction is a placeholder — every detected
        class gets a full-image rectangle; only `mask` reflects the actual
        head output. Only the first batch item is processed.
        """
        mask_logits = self.segmentation_head(vision_tokens)
        mask = mask_logits.argmax(dim=1).detach().cpu().numpy()

        polygons = []
        labels = []

        unique_classes = np.unique(mask[0])
        for cls_id in unique_classes:
            if cls_id == 0:
                # Class 0 is treated as background.
                continue
            labels.append(str(cls_id))
            polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])

        return OculusPolygonOutput(
            polygons=polygons,
            labels=labels,
            mask=mask[0],
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_ocr(self, vision_tokens, thinking_trace, **kwargs) -> OculusOCROutput:
        """Generate OCR output.

        NOTE(review): text recognition is a placeholder ("[detected]");
        only boxes/confidences come from the head. Only the first batch
        item is processed.
        """
        detections = self.ocr_head(vision_tokens)

        text_blocks = []
        for i in range(detections.shape[1]):
            det = detections[0, i].detach().cpu().numpy()
            # det layout: first 4 values are the bbox, det[4] the confidence.
            if det[4] > self.config.ocr_confidence_threshold:
                text_blocks.append({
                    "text": "[detected]",
                    "bbox": det[:4].tolist(),
                    "confidence": float(det[4])
                })

        return OculusOCROutput(
            text_blocks=text_blocks,
            full_text=" ".join([b["text"] for b in text_blocks]),
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_ui(self, vision_tokens, thinking_trace, threshold=0.5, **kwargs) -> OculusUIOutput:
        """Generate UI element detections.

        Only the first batch item is processed. Class ids are mapped onto
        a fixed palette of UI element type names (modulo its length).
        """
        cls_logits, box_coords = self.ui_head(vision_tokens)
        confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values

        UI_TYPES = ["button", "text_field", "checkbox", "radio", "dropdown", "link", "image", "icon", "label", "container"]

        elements = []
        for i in range(vision_tokens.shape[1]):
            if confidence[0, i] > threshold:
                cls_id = cls_logits[0, i].argmax().item()
                elements.append({
                    "type": UI_TYPES[cls_id % len(UI_TYPES)],
                    "bbox": box_coords[0, i].detach().cpu().numpy().tolist(),
                    "confidence": float(confidence[0, i])
                })

        return OculusUIOutput(
            elements=elements,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def ask(self, image, question: str, think: bool = False, focus: bool = False) -> str:
        """Ask a question about an image; returns the answer text."""
        output = self.generate(image, question, mode="text", think=think, focus=focus)
        return output.text

    def caption(self, image) -> str:
        """Generate a caption for an image."""
        output = self.generate(image, "Describe this image", mode="text")
        return output.text

    def detect(self, image) -> List[Dict]:
        """Detect objects; returns [{"label", "box", "confidence"}, ...]."""
        output = self.generate(image, mode="box")
        return [{"label": l, "box": b, "confidence": c}
                for l, b, c in zip(output.labels, output.boxes, output.confidences)]

    def segment(self, image) -> np.ndarray:
        """Segment an image; returns the dense class-id mask."""
        output = self.generate(image, mode="polygon")
        return output.mask

    def ocr(self, image) -> str:
        """Extract text from an image as one space-joined string."""
        output = self.generate(image, mode="ocr")
        return output.full_text

    def detect_ui(self, image) -> List[Dict]:
        """Detect UI elements; returns [{"type", "bbox", "confidence"}, ...]."""
        output = self.generate(image, mode="ui")
        return output.elements

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load model from pretrained weights.

        Reads config.json (falling back to defaults), then optionally the
        trained projector and task-head weights from trained_components/.
        Missing components stay randomly initialized.
        """
        path = Path(pretrained_model_name_or_path)

        config_path = path / "config.json"
        if config_path.exists():
            with open(config_path) as f:
                config_dict = json.load(f)
            config = OculusConfig(**config_dict)
        else:
            config = OculusConfig()

        model = cls(config)

        projector_path = path / "trained_components" / "projector.npz"
        if projector_path.exists():
            model.projector = OculusProjector.from_pretrained(path / "trained_components", config)

        heads_path = path / "trained_components" / "heads.pth"
        if heads_path.exists():
            # SECURITY FIX: weights_only=True stops torch.load from
            # unpickling arbitrary code out of an untrusted checkpoint;
            # heads.pth contains only tensor state dicts (torch >= 1.13).
            heads_state = torch.load(heads_path, map_location="cpu", weights_only=True)
            model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
            model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
            model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)
            model.ocr_head.load_state_dict(heads_state.get("ocr", {}), strict=False)
            model.ui_head.load_state_dict(heads_state.get("ui", {}), strict=False)
            print(f" ✓ Loaded heads from {heads_path}")

        return model

    def save_pretrained(self, save_directory: str):
        """Save config, projector, and task-head weights to a directory.

        The projector is saved as an npz archive mapping each top-level
        submodule name to a dict of its parameter arrays (the format
        OculusProjector.from_pretrained reads back); the heads go into a
        single torch checkpoint.
        """
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)

        self.config.save_pretrained(path)

        trained_path = path / "trained_components"
        trained_path.mkdir(exist_ok=True)

        # Group projector params by top-level submodule: "fc1.weight" ->
        # np_weights["fc1"]["weight"].
        projector_state = self.projector.state_dict()
        np_weights = {}
        for k, v in projector_state.items():
            parts = k.split(".")
            layer = parts[0]
            param = ".".join(parts[1:])
            if layer not in np_weights:
                np_weights[layer] = {}
            np_weights[layer][param] = v.cpu().numpy()
        np.savez(trained_path / "projector.npz", **np_weights)

        torch.save({
            "detection": self.detection_head.state_dict(),
            "point": self.point_head.state_dict(),
            "segmentation": self.segmentation_head.state_dict(),
            "ocr": self.ocr_head.state_dict(),
            "ui": self.ui_head.state_dict(),
        }, trained_path / "heads.pth")

        print(f"✓ Saved model to {path}")
|
|
|
|
|
|
|
# Register with the Hugging Face Auto classes so that
# AutoModelForVision2Seq.from_pretrained(..., trust_remote_code=True)
# can resolve this architecture from its config.
OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")
|
|
|