"""
Oculus Unified Model
Oceanir-Oculus OO1 Architecture - Hybrid-reasoning vision-language model.
Features:
- Reasoning via Thinking Traces
- Perceptive Tool Calling + Focus (Zoom & Crop)
- Structured Outputs (JSON, Box, Point)
- Complex OCR
- Desktop UI Understanding
Small models that outperform systems 10x larger on visual reasoning
and perception tasks, running on commodity GPUs or edge devices.
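
Example (illustrative sketch; the checkpoint directory and image paths are
placeholders for local files you actually have):

    from oculus_unified_model.modeling_oculus import OculusForConditionalGeneration

    model = OculusForConditionalGeneration.from_pretrained("path/to/oculus-checkpoint")
    print(model.caption("photo.jpg"))
    print(model.ask("photo.jpg", "How many people are visible?", think=True))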
"""
import os
import json
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, List, Dict, Any, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from PIL import Image
from .configuration_oculus import OculusConfig
# ============================================================================
# Output Data Classes
# ============================================================================
@dataclass
class OculusOutput:
"""Base output class for Oculus model."""
text: Optional[str] = None
thinking_trace: Optional[str] = None
logits: Optional[torch.Tensor] = None
hidden_states: Optional[torch.Tensor] = None
vision_tokens: Optional[torch.Tensor] = None
@dataclass
class OculusTextOutput(OculusOutput):
"""Output for text/caption mode."""
pass
@dataclass
class OculusJSONOutput(OculusOutput):
"""Output for structured JSON mode."""
json_data: Optional[Dict[str, Any]] = None
@dataclass
class OculusPointOutput(OculusOutput):
"""Output for point detection mode (counting objects)."""
points: Optional[List[Tuple[float, float]]] = None
labels: Optional[List[str]] = None
confidences: Optional[List[float]] = None
@dataclass
class OculusBoxOutput(OculusOutput):
"""Output for bounding box detection mode."""
boxes: Optional[List[Tuple[float, float, float, float]]] = None
labels: Optional[List[str]] = None
confidences: Optional[List[float]] = None
@dataclass
class OculusPolygonOutput(OculusOutput):
"""Output for polygon/segmentation mode."""
polygons: Optional[List[List[Tuple[float, float]]]] = None
labels: Optional[List[str]] = None
mask: Optional[np.ndarray] = None
@dataclass
class OculusOCROutput(OculusOutput):
"""Output for OCR mode."""
text_blocks: Optional[List[Dict[str, Any]]] = None
full_text: Optional[str] = None
@dataclass
class OculusUIOutput(OculusOutput):
"""Output for UI element detection."""
elements: Optional[List[Dict[str, Any]]] = None
# ============================================================================
# Vision Encoder
# ============================================================================
class OculusVisionEncoder(nn.Module):
"""
Oceanir-Oculus OO1 Vision Encoder.
Hybrid vision encoder optimized for visual reasoning and grounding.
"""
def __init__(self, config: OculusConfig):
super().__init__()
self.config = config
# Vision transformer components
self.patch_embed = nn.Conv2d(
3, config.vision_hidden_size,
kernel_size=config.patch_size,
stride=config.patch_size
)
num_patches = (config.image_size // config.patch_size) ** 2
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, config.vision_hidden_size)
)
self.cls_token = nn.Parameter(
torch.zeros(1, 1, config.vision_hidden_size)
)
# Transformer layers
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(
d_model=config.vision_hidden_size,
nhead=config.vision_num_heads,
dim_feedforward=config.vision_hidden_size * 4,
batch_first=True
)
for _ in range(config.vision_num_layers)
])
self.norm = nn.LayerNorm(config.vision_hidden_size)
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
"""
Encode images to vision features.
Args:
pixel_values: [batch, 3, H, W]
Returns:
            Pooled CLS token embedding [batch, vision_hidden_size]
"""
batch_size = pixel_values.shape[0]
# Patch embedding
x = self.patch_embed(pixel_values)
x = x.flatten(2).transpose(1, 2)
# Add CLS token
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat([cls_tokens, x], dim=1)
# Add position embedding
x = x + self.pos_embed[:, :x.shape[1], :]
# Transformer layers
for layer in self.layers:
x = layer(x)
x = self.norm(x)
# Return CLS token
return x[:, 0]
# ============================================================================
# Vision Projector
# ============================================================================
class OculusProjector(nn.Module):
"""Projects vision features to language model token space."""
def __init__(self, config: OculusConfig):
super().__init__()
self.config = config
fused_dim = config.fused_vision_dim
hidden_dim = config.projector_hidden_dim
num_tokens = config.num_vision_tokens
embed_dim = config.lm_hidden_size
self.fc1 = nn.Linear(fused_dim, hidden_dim)
self.act1 = nn.GELU()
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.act2 = nn.GELU()
self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
self.norm = nn.LayerNorm(embed_dim)
self.num_tokens = num_tokens
self.embed_dim = embed_dim
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.shape[0]
h = self.fc1(x)
h = self.act1(h)
h = self.fc2(h)
h = self.act2(h)
h = self.fc3(h)
h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
h = self.norm(h)
return h
@classmethod
def from_pretrained(cls, path: str, config: OculusConfig):
"""Load projector from saved weights."""
projector = cls(config)
weights_path = Path(path) / "projector.npz"
        if weights_path.exists():
            weights = np.load(weights_path, allow_pickle=True)
            state_dict = {}
            # Each npz entry is a layer name (e.g. "fc1") mapping to a dict of
            # parameter arrays, mirroring the layout written by save_pretrained().
            for key in weights.files:
                layer_dict = weights[key].item()
                for param_name, param_val in layer_dict.items():
                    full_key = f"{key}.{param_name}"
                    if hasattr(param_val, 'tolist'):
                        param_val = np.array(param_val.tolist())
                    state_dict[full_key] = torch.from_numpy(np.array(param_val))
            projector.load_state_dict(state_dict, strict=False)
            print(f" ✓ Loaded projector from {path}")
        else:
            warnings.warn(f"No projector weights found at {weights_path}; keeping random initialization.")
        return projector
# ============================================================================
# Language Model
# ============================================================================
class OculusLanguageModel(nn.Module):
"""
Oceanir-Oculus OO1 Language Model.
Hybrid transformer optimized for visual reasoning and structured output.
"""
def __init__(self, config: OculusConfig):
super().__init__()
self.config = config
self.embed_tokens = nn.Embedding(config.vocab_size, config.lm_hidden_size)
self.pos_embed = nn.Embedding(config.max_position_embeddings, config.lm_hidden_size)
self.layers = nn.ModuleList([
nn.TransformerDecoderLayer(
d_model=config.lm_hidden_size,
nhead=config.lm_num_heads,
dim_feedforward=config.lm_hidden_size * 4,
batch_first=True
)
for _ in range(config.lm_num_layers)
])
self.norm = nn.LayerNorm(config.lm_hidden_size)
self.lm_head = nn.Linear(config.lm_hidden_size, config.vocab_size, bias=False)
def forward(
self,
input_ids: torch.Tensor,
vision_tokens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Generate logits from input tokens."""
batch_size, seq_len = input_ids.shape
device = input_ids.device
# Token embeddings
hidden = self.embed_tokens(input_ids)
# Position embeddings
positions = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
hidden = hidden + self.pos_embed(positions)
# Prepend vision tokens if provided
if vision_tokens is not None:
hidden = torch.cat([vision_tokens, hidden], dim=1)
        # Transformer layers. Each decoder layer attends over the full
        # (vision + text) sequence, reusing the same sequence as memory;
        # no causal mask is applied in this simplified forward pass.
        for layer in self.layers:
            hidden = layer(hidden, hidden)
hidden = self.norm(hidden)
# Only return logits for text tokens
if vision_tokens is not None:
hidden = hidden[:, vision_tokens.shape[1]:, :]
logits = self.lm_head(hidden)
return logits
# ============================================================================
# Task Heads
# ============================================================================
class OculusDetectionHead(nn.Module):
"""Head for bounding box detection."""
def __init__(self, config: OculusConfig):
super().__init__()
hidden_dim = config.lm_hidden_size
num_classes = config.num_detection_classes
self.cls_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, num_classes)
)
self.box_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 4)
)
def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
cls_logits = self.cls_head(vision_tokens)
box_coords = self.box_head(vision_tokens).sigmoid()
return cls_logits, box_coords
class OculusPointHead(nn.Module):
"""Head for point detection (object counting)."""
def __init__(self, config: OculusConfig):
super().__init__()
hidden_dim = config.lm_hidden_size
num_classes = config.num_detection_classes
self.point_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 2)
)
self.cls_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, num_classes)
)
self.conf_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 4),
nn.GELU(),
nn.Linear(hidden_dim // 4, 1)
)
def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
points = self.point_head(vision_tokens).sigmoid()
cls_logits = self.cls_head(vision_tokens)
confidence = self.conf_head(vision_tokens).sigmoid()
return points, cls_logits, confidence
class OculusSegmentationHead(nn.Module):
"""Head for polygon/mask segmentation."""
def __init__(self, config: OculusConfig):
super().__init__()
hidden_dim = config.lm_hidden_size
num_classes = config.num_segmentation_classes
self.mask_head = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 14 * 14 * num_classes)
)
self.num_classes = num_classes
def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
batch_size = vision_tokens.shape[0]
pooled = vision_tokens.mean(dim=1)
mask_logits = self.mask_head(pooled)
mask_logits = mask_logits.reshape(batch_size, self.num_classes, 14, 14)
return mask_logits
class OculusOCRHead(nn.Module):
"""Head for OCR text detection and recognition."""
def __init__(self, config: OculusConfig):
super().__init__()
hidden_dim = config.lm_hidden_size
self.text_detector = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, 5)
)
def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
return self.text_detector(vision_tokens)
class OculusUIHead(nn.Module):
"""Head for UI element detection."""
def __init__(self, config: OculusConfig):
super().__init__()
hidden_dim = config.lm_hidden_size
num_classes = config.ui_element_classes
self.element_cls = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, num_classes)
)
self.element_box = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim // 2),
nn.GELU(),
nn.Linear(hidden_dim // 2, 4)
)
def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
cls_logits = self.element_cls(vision_tokens)
box_coords = self.element_box(vision_tokens).sigmoid()
return cls_logits, box_coords
# ============================================================================
# Main Model
# ============================================================================
class OculusForConditionalGeneration(PreTrainedModel):
"""
Oculus: Hybrid-Reasoning Vision-Language Model
Oceanir-Oculus OO1 Architecture
Features:
- Reasoning via Thinking Traces
- Perceptive Tool Calling + Focus (Zoom & Crop)
- Structured Outputs (JSON, Box, Point)
- Complex OCR
- Desktop UI Understanding
Small models that outperform systems 10x larger on visual reasoning.
"""
config_class = OculusConfig
base_model_prefix = "oculus"
def __init__(self, config: OculusConfig):
super().__init__(config)
self.config = config
# Vision encoder
self.vision_encoder = OculusVisionEncoder(config)
# Vision adapter for dimension matching
self.vision_adapter = nn.Linear(config.vision_hidden_size, config.fused_vision_dim)
# Projector
self.projector = OculusProjector(config)
# Language model
self.language_model = OculusLanguageModel(config)
# Task-specific heads
self.detection_head = OculusDetectionHead(config)
self.point_head = OculusPointHead(config)
self.segmentation_head = OculusSegmentationHead(config)
self.ocr_head = OculusOCRHead(config)
self.ui_head = OculusUIHead(config)
# Special tokens
self.thinking_token = config.thinking_token
self.thinking_end_token = config.thinking_end_token
self.focus_token = config.focus_token
self.focus_end_token = config.focus_end_token
self.json_token = config.json_token
self.json_end_token = config.json_end_token
self.box_token = config.box_token
self.box_end_token = config.box_end_token
self.point_token = config.point_token
self.point_end_token = config.point_end_token
def encode_image(self, image: Union[Image.Image, str, np.ndarray, torch.Tensor]) -> torch.Tensor:
"""Encode image to vision tokens."""
if isinstance(image, str):
image = Image.open(image).convert('RGB')
        if isinstance(image, Image.Image):
            # Force 3-channel RGB so grayscale/RGBA inputs do not break the
            # channel permute below.
            image = np.array(image.convert('RGB').resize((self.config.image_size, self.config.image_size)))
if isinstance(image, np.ndarray):
image = torch.from_numpy(image).float()
if image.dim() == 3:
image = image.permute(2, 0, 1).unsqueeze(0)
image = image / 255.0
device = next(self.parameters()).device
image = image.to(device)
# Encode with vision encoder
vision_features = self.vision_encoder(image)
# Adapt dimensions
vision_features = self.vision_adapter(vision_features)
# Project to language space
vision_tokens = self.projector(vision_features)
return vision_tokens
def _crop_region(self, image: Image.Image, bbox: Tuple[int, int, int, int]) -> Image.Image:
"""Crop image to specified region for focus/zoom."""
x1, y1, x2, y2 = bbox
return image.crop((x1, y1, x2, y2))
def _generate_thinking_trace(self, prompt: str, context: str = "") -> str:
"""Generate structured thinking trace."""
if self.config.thinking_style == "structured":
return f"{self.thinking_token}Analyzing: {prompt[:50]}...{self.thinking_end_token}"
elif self.config.thinking_style == "verbose":
return f"{self.thinking_token}Let me think step by step: {prompt}{self.thinking_end_token}"
else:
return ""
def generate(
self,
image: Union[Image.Image, str, np.ndarray],
prompt: str = "Describe this image",
mode: str = "text",
think: bool = False,
focus: bool = False,
max_new_tokens: Optional[int] = None,
temperature: float = 0.7,
**kwargs
) -> Union[OculusTextOutput, OculusJSONOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput, OculusOCROutput, OculusUIOutput]:
"""
Generate output from image.
Args:
image: Input image
prompt: Text prompt/question
mode: "text", "json", "point", "box", "polygon", "ocr", "ui"
think: Enable reasoning traces
focus: Enable zoom/crop for fine-grained perception
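
        Example (sketch; the image path is a placeholder):
            >>> out = model.generate("street.jpg", "Find all cars", mode="box", think=True)
            >>> print(out.boxes, out.labels, out.confidences)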
"""
if isinstance(image, str):
image = Image.open(image).convert('RGB')
elif isinstance(image, np.ndarray):
image = Image.fromarray(image).convert('RGB')
vision_tokens = self.encode_image(image)
thinking_trace = None
if think and self.config.reasoning_enabled:
thinking_trace = self._generate_thinking_trace(prompt)
if mode == "text":
return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
elif mode == "json":
return self._generate_json(image, prompt, vision_tokens, thinking_trace, **kwargs)
elif mode == "point":
return self._generate_points(vision_tokens, thinking_trace, **kwargs)
elif mode == "box":
return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
elif mode == "polygon":
return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)
elif mode == "ocr":
return self._generate_ocr(vision_tokens, thinking_trace, **kwargs)
elif mode == "ui":
return self._generate_ui(vision_tokens, thinking_trace, **kwargs)
else:
raise ValueError(f"Unknown mode: {mode}")
def _generate_text(self, image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs) -> OculusTextOutput:
"""Generate text output."""
# Placeholder - full implementation would do autoregressive generation
text = f"[Generated response for: {prompt[:50]}...]"
if thinking_trace:
text = f"{thinking_trace} {text}"
return OculusTextOutput(
text=text,
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
def _generate_json(self, image, prompt, vision_tokens, thinking_trace, **kwargs) -> OculusJSONOutput:
"""Generate structured JSON output."""
json_data = {
"prompt": prompt,
"response": "generated",
"objects": []
}
return OculusJSONOutput(
json_data=json_data,
text=f"{self.json_token}{json.dumps(json_data)}{self.json_end_token}",
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
def _generate_points(self, vision_tokens, thinking_trace, threshold=0.5, **kwargs) -> OculusPointOutput:
"""Generate point detections."""
points, cls_logits, confidence = self.point_head(vision_tokens)
mask = confidence.squeeze(-1) > threshold
filtered_points = []
filtered_labels = []
filtered_conf = []
for i in range(vision_tokens.shape[0]):
token_mask = mask[i]
pts = points[i][token_mask].detach().cpu().numpy().tolist()
confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()
filtered_points.extend([tuple(p) for p in pts])
filtered_conf.extend(confs)
filtered_labels.extend([str(c) for c in cls_ids])
return OculusPointOutput(
points=filtered_points,
labels=filtered_labels,
confidences=filtered_conf,
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
def _generate_boxes(self, vision_tokens, thinking_trace, threshold=0.3, **kwargs) -> OculusBoxOutput:
"""Generate bounding box detections."""
cls_logits, box_coords = self.detection_head(vision_tokens)
confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values
filtered_boxes = []
filtered_labels = []
filtered_conf = []
for i in range(vision_tokens.shape[0]):
mask = confidence[i] > threshold
boxes = box_coords[i][mask].detach().cpu().numpy()
confs = confidence[i][mask].detach().cpu().numpy().tolist()
cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()
filtered_boxes.extend([tuple(b) for b in boxes])
filtered_conf.extend(confs)
filtered_labels.extend([str(c) for c in cls_ids])
return OculusBoxOutput(
boxes=filtered_boxes,
labels=filtered_labels,
confidences=filtered_conf,
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
def _generate_polygons(self, vision_tokens, thinking_trace, **kwargs) -> OculusPolygonOutput:
"""Generate polygon/mask segmentation."""
mask_logits = self.segmentation_head(vision_tokens)
mask = mask_logits.argmax(dim=1).detach().cpu().numpy()
        polygons = []
        labels = []
        unique_classes = np.unique(mask[0])
        for cls_id in unique_classes:
            if cls_id == 0:
                continue
            labels.append(str(cls_id))
            # Placeholder full-image polygon per detected class; the returned
            # `mask` array carries the actual per-pixel prediction.
            polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])
return OculusPolygonOutput(
polygons=polygons,
labels=labels,
mask=mask[0],
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
def _generate_ocr(self, vision_tokens, thinking_trace, **kwargs) -> OculusOCROutput:
"""Generate OCR output."""
        detections = self.ocr_head(vision_tokens)
        # The OCR head emits raw values (4 box coordinates + 1 confidence per
        # token); squash them to [0, 1] before thresholding.
        detections = detections.sigmoid()
        text_blocks = []
        for i in range(detections.shape[1]):
            det = detections[0, i].detach().cpu().numpy()
            if det[4] > self.config.ocr_confidence_threshold:
text_blocks.append({
"text": "[detected]",
"bbox": det[:4].tolist(),
"confidence": float(det[4])
})
return OculusOCROutput(
text_blocks=text_blocks,
full_text=" ".join([b["text"] for b in text_blocks]),
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
def _generate_ui(self, vision_tokens, thinking_trace, threshold=0.5, **kwargs) -> OculusUIOutput:
"""Generate UI element detections."""
cls_logits, box_coords = self.ui_head(vision_tokens)
confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values
UI_TYPES = ["button", "text_field", "checkbox", "radio", "dropdown", "link", "image", "icon", "label", "container"]
elements = []
for i in range(vision_tokens.shape[1]):
if confidence[0, i] > threshold:
                cls_id = cls_logits[0, i].argmax().item()
                elements.append({
                    # Modulo keeps the label in range if the config defines more
                    # classes than this illustrative name list.
                    "type": UI_TYPES[cls_id % len(UI_TYPES)],
"bbox": box_coords[0, i].detach().cpu().numpy().tolist(),
"confidence": float(confidence[0, i])
})
return OculusUIOutput(
elements=elements,
thinking_trace=thinking_trace,
vision_tokens=vision_tokens
)
# Convenience methods
def ask(self, image, question: str, think: bool = False, focus: bool = False) -> str:
"""Ask a question about an image."""
output = self.generate(image, question, mode="text", think=think, focus=focus)
return output.text
def caption(self, image) -> str:
"""Generate a caption for an image."""
output = self.generate(image, "Describe this image", mode="text")
return output.text
def detect(self, image) -> List[Dict]:
"""Detect objects in an image."""
output = self.generate(image, mode="box")
return [{"label": l, "box": b, "confidence": c}
for l, b, c in zip(output.labels, output.boxes, output.confidences)]
def segment(self, image) -> np.ndarray:
"""Segment an image."""
output = self.generate(image, mode="polygon")
return output.mask
def ocr(self, image) -> str:
"""Extract text from an image."""
output = self.generate(image, mode="ocr")
return output.full_text
def detect_ui(self, image) -> List[Dict]:
"""Detect UI elements in a screenshot."""
output = self.generate(image, mode="ui")
return output.elements
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
"""Load model from pretrained weights."""
path = Path(pretrained_model_name_or_path)
config_path = path / "config.json"
if config_path.exists():
with open(config_path) as f:
config_dict = json.load(f)
config = OculusConfig(**config_dict)
else:
config = OculusConfig()
model = cls(config)
# Load trained components
projector_path = path / "trained_components" / "projector.npz"
if projector_path.exists():
model.projector = OculusProjector.from_pretrained(path / "trained_components", config)
heads_path = path / "trained_components" / "heads.pth"
if heads_path.exists():
heads_state = torch.load(heads_path, map_location="cpu")
model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)
model.ocr_head.load_state_dict(heads_state.get("ocr", {}), strict=False)
model.ui_head.load_state_dict(heads_state.get("ui", {}), strict=False)
print(f" ✓ Loaded heads from {heads_path}")
return model
def save_pretrained(self, save_directory: str):
"""Save model to directory."""
path = Path(save_directory)
path.mkdir(parents=True, exist_ok=True)
self.config.save_pretrained(path)
# Save projector
trained_path = path / "trained_components"
trained_path.mkdir(exist_ok=True)
projector_state = self.projector.state_dict()
np_weights = {}
for k, v in projector_state.items():
parts = k.split(".")
layer = parts[0]
param = ".".join(parts[1:])
if layer not in np_weights:
np_weights[layer] = {}
np_weights[layer][param] = v.cpu().numpy()
        np.savez(trained_path / "projector.npz", **np_weights)
# Save heads
torch.save({
"detection": self.detection_head.state_dict(),
"point": self.point_head.state_dict(),
"segmentation": self.segmentation_head.state_dict(),
"ocr": self.ocr_head.state_dict(),
"ui": self.ui_head.state_dict(),
}, trained_path / "heads.pth")
print(f"✓ Saved model to {path}")
OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")
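

if __name__ == "__main__":
    # Minimal smoke-test sketch. Illustrative assumptions: default OculusConfig
    # values and a synthetic random image; no trained weights are loaded, so the
    # outputs below are only meaningful as a shape/plumbing check.
    _config = OculusConfig()
    _model = OculusForConditionalGeneration(_config).eval()
    _image = Image.fromarray(
        np.random.randint(0, 256, (_config.image_size, _config.image_size, 3), dtype=np.uint8)
    )
    with torch.no_grad():
        print(_model.caption(_image))
        print(f"{len(_model.detect(_image))} boxes above threshold")
        print(f"{len(_model.detect_ui(_image))} UI elements above threshold")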