kobiakor15 commited on
Commit
4b92f99
·
verified ·
1 Parent(s): 7cefab8

Upload oculus_unified_model/modeling_oculus.py with huggingface_hub

Browse files
oculus_unified_model/modeling_oculus.py ADDED
@@ -0,0 +1,842 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Oculus Unified Model
3
+
4
+ HuggingFace-compatible vision-language model with:
5
+ - Multi-encoder vision (DINOv3 + SigLIP2)
6
+ - Trained projector for vision-to-language
7
+ - Optional reasoning with thinking traces
8
+ - Multiple output modes (Text, Point, Box, Polygon)
9
+ - Focus/Zoom tool calling for fine-grained perception
10
+ """
11
+
12
+ import os
13
+ import json
14
+ import warnings
15
+ from dataclasses import dataclass
16
+ from pathlib import Path
17
+ from typing import Optional, Tuple, List, Dict, Any, Union
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from transformers import (
24
+ PreTrainedModel,
25
+ PretrainedConfig,
26
+ AutoImageProcessor,
27
+ AutoModel,
28
+ AutoTokenizer,
29
+ AutoModelForCausalLM,
30
+ GenerationConfig,
31
+ )
32
+ from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
33
+ from PIL import Image
34
+
35
+ from .configuration_oculus import OculusConfig
36
+
37
+
38
+ # ============================================================================
39
+ # Output Data Classes
40
+ # ============================================================================
41
+
42
@dataclass
class OculusOutput:
    """Base output class for Oculus model.

    All mode-specific outputs derive from this; fields default to None so
    each mode populates only what it produces.
    """
    # Generated text (caption / VQA answer); set by text mode only.
    text: Optional[str] = None
    # Reasoning trace produced when generate(..., think=True); else None.
    thinking_trace: Optional[str] = None
    # NOTE(review): logits and hidden_states are never populated by any
    # generate path in this file — reserved for future use, apparently.
    logits: Optional[torch.Tensor] = None
    hidden_states: Optional[torch.Tensor] = None
    # Projected vision tokens [1, num_tokens, embed_dim] from encode_image();
    # attached to every mode's output.
    vision_tokens: Optional[torch.Tensor] = None
50
+
51
+
52
@dataclass
class OculusTextOutput(OculusOutput):
    """Output for text/caption mode.

    Adds no fields beyond the base class; exists so callers can
    distinguish the mode by type.
    """
    pass
56
+
57
+
58
@dataclass
class OculusPointOutput(OculusOutput):
    """Output for point detection mode (counting objects)."""
    # Detected (x, y) points, sigmoid-normalized to [0, 1] by the point head.
    points: Optional[List[Tuple[float, float]]] = None
    # Class-id strings, parallel to `points` (stringified argmax indices).
    labels: Optional[List[str]] = None
    # Per-point confidence in [0, 1], parallel to `points`.
    confidences: Optional[List[float]] = None
64
+
65
+
66
@dataclass
class OculusBoxOutput(OculusOutput):
    """Output for bounding box detection mode."""
    # Boxes as (x1, y1, x2, y2), sigmoid-normalized to [0, 1].
    boxes: Optional[List[Tuple[float, float, float, float]]] = None  # x1, y1, x2, y2
    # Class-id strings, parallel to `boxes`.
    labels: Optional[List[str]] = None
    # Max softmax class probability per box, parallel to `boxes`.
    confidences: Optional[List[float]] = None
72
+
73
+
74
@dataclass
class OculusPolygonOutput(OculusOutput):
    """Output for polygon/segmentation mode."""
    # One vertex list per detected region (currently placeholder squares;
    # see _generate_polygons).
    polygons: Optional[List[List[Tuple[float, float]]]] = None
    # Class-id strings, parallel to `polygons`.
    labels: Optional[List[str]] = None
    # Dense argmax class-id mask from the segmentation head (low-res grid).
    mask: Optional[np.ndarray] = None
80
+
81
+
82
+ # ============================================================================
83
+ # Vision Encoder (DINOv3 + SigLIP2)
84
+ # ============================================================================
85
+
86
class OculusVisionEncoder(nn.Module):
    """
    Dual vision encoder combining DINOv3 and SigLIP2.

    DINOv3: Excellent at semantic understanding, object boundaries
    SigLIP2: Strong at text/language alignment

    Encoders are downloaded lazily on first use (load_encoders) and run
    under torch.no_grad() — they are frozen feature extractors, not
    trained as part of this module.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        # Will be loaded lazily
        self.dinov3 = None
        self.dinov3_processor = None
        self.siglip = None
        self.siglip_processor = None

        # Guard so load_encoders() is idempotent.
        self._loaded = False

    def load_encoders(self, device: str = "cpu"):
        """Load vision encoders from HuggingFace.

        On any failure each encoder falls back to a smaller public
        checkpoint (DINOv2-base / SigLIP-base) with a warning, so the
        model stays usable offline from the local HF cache.
        """
        if self._loaded:
            return

        print("[Oculus] Loading vision encoders...")

        # DINOv3
        try:
            self.dinov3_processor = AutoImageProcessor.from_pretrained(
                self.config.dinov3_model_id
            )
            self.dinov3 = AutoModel.from_pretrained(
                self.config.dinov3_model_id
            ).eval().to(device)
            print(f" ✓ DINOv3: {self.config.dinov3_model_id}")
        except Exception as e:
            # Broad catch is deliberate: hub/network/config errors all
            # route to the fallback checkpoint rather than crashing.
            warnings.warn(f"Failed to load DINOv3: {e}")
            self.dinov3_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
            self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-base").eval().to(device)
            print(" ✓ DINOv2-base (fallback)")

        # SigLIP2
        try:
            self.siglip_processor = AutoImageProcessor.from_pretrained(
                self.config.siglip_model_id
            )
            self.siglip = AutoModel.from_pretrained(
                self.config.siglip_model_id
            ).eval().to(device)
            print(f" ✓ SigLIP: {self.config.siglip_model_id}")
        except Exception as e:
            warnings.warn(f"Failed to load SigLIP: {e}")
            from transformers import SiglipVisionModel
            self.siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
            self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval().to(device)
            print(" ✓ SigLIP-base (fallback)")

        self._loaded = True

    @torch.no_grad()
    def forward(self, image: Union[Image.Image, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Encode image with both vision encoders and fuse features.

        Returns:
            Fused vision features [batch, fused_dim]
        """
        if not self._loaded:
            # Lazy load defaults to CPU; pass device explicitly via
            # load_encoders() beforehand to use an accelerator.
            self.load_encoders()

        # Handle different input types
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            # NOTE(review): assumes an HWC uint8-compatible tensor; a CHW
            # or batched tensor would fail in Image.fromarray — confirm
            # callers only pass HWC images.
            image = Image.fromarray(image.cpu().numpy().astype(np.uint8))

        if isinstance(image, Image.Image):
            image = image.convert('RGB')

        # Place inputs wherever the encoders live.
        device = next(self.dinov3.parameters()).device

        # DINOv3 encoding
        d_inputs = self.dinov3_processor(images=image, return_tensors="pt")
        d_inputs = {k: v.to(device) for k, v in d_inputs.items()}
        d_out = self.dinov3(**d_inputs)
        # Prefer the pooled output; fall back to the CLS token embedding.
        d_pooled = d_out.pooler_output if hasattr(d_out, 'pooler_output') and d_out.pooler_output is not None else d_out.last_hidden_state[:, 0]

        # SigLIP encoding
        s_inputs = self.siglip_processor(images=image, return_tensors="pt")
        s_inputs = {k: v.to(device) for k, v in s_inputs.items()}

        if hasattr(self.siglip, 'vision_model'):
            # NOTE(review): this branch runs ONLY the patch-embedding layer
            # and mean-pools it — no transformer blocks are applied.
            # Confirm this is intentional (it is much weaker than the
            # full vision tower used in the else-branch).
            s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
            s_pooled = s_hidden.mean(dim=1)
        else:
            s_out = self.siglip(**s_inputs)
            s_pooled = s_out.pooler_output if hasattr(s_out, 'pooler_output') else s_out.last_hidden_state[:, 0]

        # Fuse features by channel-wise concatenation:
        # [batch, d_dim + s_dim].
        fused = torch.cat([d_pooled, s_pooled], dim=-1)

        return fused
189
+
190
+
191
+ # ============================================================================
192
+ # Vision Projector
193
+ # ============================================================================
194
+
195
class OculusProjector(nn.Module):
    """
    Projects fused vision features to language model token space.

    Converts [batch, fused_dim] → [batch, num_tokens, lm_hidden_size]
    via a three-layer GELU MLP whose final layer emits all token
    embeddings at once, followed by LayerNorm over the embedding dim.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        fused_dim = config.fused_vision_dim
        hidden_dim = config.projector_hidden_dim
        num_tokens = config.num_vision_tokens
        embed_dim = config.lm_hidden_size

        self.fc1 = nn.Linear(fused_dim, hidden_dim)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.act2 = nn.GELU()
        # One wide output layer; reshaped into tokens in forward().
        self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
        self.norm = nn.LayerNorm(embed_dim)

        self.num_tokens = num_tokens
        self.embed_dim = embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Project vision features to token embeddings.

        Args:
            x: Vision features [batch, fused_dim]

        Returns:
            Vision tokens [batch, num_tokens, embed_dim]
        """
        batch_size = x.shape[0]

        h = self.fc1(x)
        h = self.act1(h)
        h = self.fc2(h)
        h = self.act2(h)
        h = self.fc3(h)

        h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
        h = self.norm(h)

        return h

    @classmethod
    def from_pretrained(cls, path: str, config: OculusConfig):
        """Load projector from saved weights.

        Expects ``<path>/projector.npz`` with one entry per layer, each a
        pickled dict of ``param_name -> array`` (arrays may originate from
        MLX). If the file is missing, a freshly initialized projector is
        returned silently.

        Args:
            path: Directory containing ``projector.npz``.
            config: Config defining the projector dimensions.

        Returns:
            An ``OculusProjector`` with weights loaded where available.
        """
        projector = cls(config)

        weights_path = Path(path) / "projector.npz"
        if weights_path.exists():
            # allow_pickle is required: each npz entry is an object array
            # wrapping a per-layer parameter dict. Only load checkpoints
            # from trusted sources.
            weights = np.load(weights_path, allow_pickle=True)

            state_dict = {}
            for key in weights.files:
                layer_dict = weights[key].item()
                for param_name, param_val in layer_dict.items():
                    full_key = f"{key}.{param_name}"
                    # Normalize to a numpy array exactly once. The previous
                    # code re-imported numpy locally and round-tripped every
                    # array through tolist() + two np.array() copies, even
                    # when it was already a plain ndarray.
                    if not isinstance(param_val, np.ndarray):
                        param_val = np.asarray(
                            param_val.tolist() if hasattr(param_val, 'tolist') else param_val
                        )
                    state_dict[full_key] = torch.from_numpy(np.ascontiguousarray(param_val))

            # strict=False tolerates extra/missing keys (older layouts).
            projector.load_state_dict(state_dict, strict=False)
            print(f" ✓ Loaded projector from {path}")

        return projector
268
+
269
+
270
+ # ============================================================================
271
+ # Detection/Segmentation Heads
272
+ # ============================================================================
273
+
274
class OculusDetectionHead(nn.Module):
    """Head for bounding box detection.

    Two parallel two-layer GELU MLPs over the vision tokens: one scores
    classes, one regresses box corners (sigmoid-normalized to [0, 1]).
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size
        num_classes = config.num_detection_classes

        def _branch(out_features):
            # Shared shape for both branches: width -> width//2 -> out.
            return nn.Sequential(
                nn.Linear(width, width // 2),
                nn.GELU(),
                nn.Linear(width // 2, out_features),
            )

        self.cls_head = _branch(num_classes)
        self.box_head = _branch(4)  # x1, y1, x2, y2

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Predict boxes and classes from vision tokens.

        Returns:
            cls_logits: [batch, num_tokens, num_classes]
            box_coords: [batch, num_tokens, 4]
        """
        logits = self.cls_head(vision_tokens)
        coords = self.box_head(vision_tokens).sigmoid()  # Normalize to [0, 1]
        return logits, coords
305
+
306
+
307
class OculusPointHead(nn.Module):
    """Head for point detection (object counting).

    Three parallel MLP branches over the vision tokens: point location
    (x, y), class logits, and a scalar confidence — locations and
    confidences are sigmoid-squashed to [0, 1].
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size
        num_classes = config.num_detection_classes

        def _branch(hidden, out_features):
            # Common two-layer GELU MLP shape used by every branch.
            return nn.Sequential(
                nn.Linear(width, hidden),
                nn.GELU(),
                nn.Linear(hidden, out_features),
            )

        self.point_head = _branch(width // 2, 2)            # x, y
        self.cls_head = _branch(width // 2, num_classes)
        self.conf_head = _branch(width // 4, 1)

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (points in [0,1], class logits, confidences in [0,1])."""
        locations = self.point_head(vision_tokens).sigmoid()
        class_logits = self.cls_head(vision_tokens)
        scores = self.conf_head(vision_tokens).sigmoid()
        return locations, class_logits, scores
338
+
339
+
340
class OculusSegmentationHead(nn.Module):
    """Head for polygon/mask segmentation.

    Mean-pools the vision tokens and predicts a coarse class-logit grid.
    The grid resolution defaults to 14x14 (a typical ViT patch grid) but
    can now be overridden via ``config.mask_resolution`` — previously the
    14 was hard-coded in two places.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_segmentation_classes
        # Backward compatible: configs without the attribute keep 14.
        mask_resolution = getattr(config, "mask_resolution", 14)

        # Predict mask logits as one flat vector; reshaped in forward().
        self.mask_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, mask_resolution * mask_resolution * num_classes)  # Output spatial mask
        )

        self.num_classes = num_classes
        self.mask_resolution = mask_resolution

    def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
        """
        Args:
            vision_tokens: [batch, num_tokens, hidden_dim]

        Returns:
            Mask logits [batch, num_classes, R, R] where R = mask_resolution.
        """
        batch_size = vision_tokens.shape[0]
        # Collapse the token axis; the head sees one pooled vector per image.
        pooled = vision_tokens.mean(dim=1)
        mask_logits = self.mask_head(pooled)
        mask_logits = mask_logits.reshape(
            batch_size, self.num_classes, self.mask_resolution, self.mask_resolution
        )
        return mask_logits
363
+
364
+
365
+ # ============================================================================
366
+ # Main Model
367
+ # ============================================================================
368
+
369
class OculusForConditionalGeneration(PreTrainedModel):
    """
    Oculus: Unified Vision-Language Model

    Features:
    - Multi-encoder vision (DINOv3 + SigLIP2)
    - Optional reasoning with thinking traces
    - Multiple output modes: Text, Point, Box, Polygon
    - Focus/Zoom tool calling for fine-grained perception

    Usage:
    ```python
    from oculus_unified_model import OculusForConditionalGeneration

    model = OculusForConditionalGeneration.from_pretrained("OceanirAI/oculus-0.2")

    # Caption mode
    output = model.generate(image, mode="text", prompt="Describe this image")

    # VQA mode
    output = model.generate(image, mode="text", prompt="What color is the cat?")

    # With reasoning
    output = model.generate(image, mode="text", prompt="Count the people", think=True)

    # Detection mode
    output = model.generate(image, mode="box", prompt="Find all cars")

    # Point mode (counting)
    output = model.generate(image, mode="point", prompt="Count the birds")

    # Segmentation mode
    output = model.generate(image, mode="polygon", prompt="Segment the road")
    ```
    """

    config_class = OculusConfig
    base_model_prefix = "oculus"

    def __init__(self, config: OculusConfig):
        super().__init__(config)
        self.config = config

        # Vision encoder
        self.vision_encoder = OculusVisionEncoder(config)

        # Vision adapter (handles dimension mismatch if needed).
        # Created lazily in encode_image() when the encoders' actual
        # output width differs from config.fused_vision_dim.
        self.vision_adapter = None
        self._actual_vision_dim = None

        # Projector
        self.projector = OculusProjector(config)

        # Task-specific heads
        self.detection_head = OculusDetectionHead(config)
        self.point_head = OculusPointHead(config)
        self.segmentation_head = OculusSegmentationHead(config)

        # Language model (loaded lazily)
        # NOTE(review): lm_tokenizer / lm_model are never assigned again in
        # this file — load_language_model() sets lm_processor /
        # lm_caption_model / lm_vqa_* instead.
        self.lm_tokenizer = None
        self.lm_model = None
        self._lm_loaded = False

        # Special tokens for reasoning
        self.thinking_token = config.thinking_token
        self.thinking_end_token = config.thinking_end_token
        self.focus_token = config.focus_token
        self.focus_end_token = config.focus_end_token

    def load_language_model(self, device: str = "cpu"):
        """Load language model for text generation.

        Currently delegates all text generation to two BLIP checkpoints
        (captioning + VQA). Failure to load is downgraded to a warning,
        leaving _lm_loaded False so text mode degrades gracefully.
        """
        if self._lm_loaded:
            return

        print("[Oculus] Loading language model...")

        try:
            # Try BLIP for now (works well for captioning/VQA)
            from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering

            self.lm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
            self.lm_caption_model = BlipForConditionalGeneration.from_pretrained(
                "Salesforce/blip-image-captioning-base"
            ).to(device)

            self.lm_vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
            self.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(
                "Salesforce/blip-vqa-base"
            ).to(device)

            print(" ✓ BLIP (captioning + VQA)")
            self._lm_loaded = True

        except Exception as e:
            warnings.warn(f"Failed to load language model: {e}")

    def encode_image(self, image: Union[Image.Image, str, np.ndarray]) -> torch.Tensor:
        """
        Encode image to vision tokens.

        Args:
            image: PIL Image, file path, or numpy array

        Returns:
            Vision tokens [1, num_tokens, embed_dim]
        """
        # Load image if path
        if isinstance(image, str):
            image = Image.open(image)

        # Encode with vision encoders
        vision_features = self.vision_encoder(image)

        # Check if we need an adapter for dimension mismatch
        actual_dim = vision_features.shape[-1]
        expected_dim = self.config.fused_vision_dim

        if actual_dim != expected_dim:
            if self.vision_adapter is None or self._actual_vision_dim != actual_dim:
                # Create adapter layer
                # NOTE(review): this adapter is created at inference time
                # with random (xavier) weights and is never trained or
                # persisted — outputs through it are effectively a random
                # projection. Confirm this is the intended fallback.
                print(f" [Adapter] Creating vision adapter: {actual_dim} -> {expected_dim}")
                self.vision_adapter = nn.Linear(actual_dim, expected_dim)
                self._actual_vision_dim = actual_dim
                # Initialize with small weights
                nn.init.xavier_uniform_(self.vision_adapter.weight)
                nn.init.zeros_(self.vision_adapter.bias)

            vision_features = self.vision_adapter(vision_features)

        # Project to language space
        vision_tokens = self.projector(vision_features)

        return vision_tokens

    def _generate_thinking_trace(
        self,
        image: Image.Image,
        prompt: str,
        max_tokens: int = 256
    ) -> str:
        """
        Generate a thinking/reasoning trace before answering.

        This enables multi-step reasoning for complex tasks.
        Falls back to a canned sentence when no LM is loaded.
        """
        thinking_prompt = f"""Let me think about this step by step:
1. First, I'll analyze what I see in the image.
2. Then, I'll consider the question: "{prompt}"
3. Finally, I'll formulate my answer.

Observation: """

        # Generate reasoning (simplified for now)
        if self._lm_loaded and hasattr(self, 'lm_caption_model'):
            inputs = self.lm_processor(image, thinking_prompt, return_tensors="pt")
            inputs = {k: v.to(self.lm_caption_model.device) for k, v in inputs.items()}

            with torch.no_grad():
                # Sampling (not greedy) so traces vary run to run.
                out = self.lm_caption_model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=True,
                    temperature=0.7
                )
            thinking = self.lm_processor.decode(out[0], skip_special_tokens=True)
        else:
            thinking = "I observe the image and analyze its contents."

        return thinking

    def _detect_focus_regions(
        self,
        image: Image.Image,
        prompt: str
    ) -> List[Tuple[int, int, int, int]]:
        """
        Detect regions that need closer inspection (Focus/Zoom system).

        Returns list of (x1, y1, x2, y2) crop regions.
        """
        # Simplified: return full image as single region
        # In full implementation, would use attention maps to find regions of interest
        w, h = image.size
        return [(0, 0, w, h)]

    def generate(
        self,
        image: Union[Image.Image, str, np.ndarray],
        prompt: str = "Describe this image",
        mode: str = "text",
        think: bool = False,
        focus: bool = False,
        max_new_tokens: Optional[int] = None,
        temperature: float = 0.7,
        return_thinking: bool = True,
        **kwargs
    ) -> Union[OculusTextOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput]:
        """
        Generate output from image.

        Args:
            image: Input image (PIL, path, or array)
            prompt: Text prompt/question
            mode: Output mode ("text", "point", "box", "polygon")
            think: Enable reasoning traces
            focus: Enable zoom/crop for fine-grained perception
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
                NOTE(review): currently unused — sampling temperature is
                hard-coded in _generate_thinking_trace.
            return_thinking: Include thinking trace in output
                NOTE(review): currently unused; the trace is always
                attached when think=True.

        Returns:
            Mode-specific output dataclass

        Raises:
            ValueError: if `mode` is not one of the four supported modes.
        """
        # Load models if needed
        self.vision_encoder.load_encoders()
        if mode == "text":
            self.load_language_model()

        # Load image
        if isinstance(image, str):
            image = Image.open(image).convert('RGB')
        elif isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert('RGB')

        # Encode image
        vision_tokens = self.encode_image(image)

        # Generate thinking trace if enabled
        thinking_trace = None
        if think and self.config.reasoning_enabled:
            thinking_trace = self._generate_thinking_trace(image, prompt)

        # Focus system: zoom/crop if needed
        # NOTE(review): focus_regions is computed but never consumed yet.
        if focus and self.config.enable_focus:
            focus_regions = self._detect_focus_regions(image, prompt)
            # Could re-encode cropped regions here

        # Mode-specific generation
        if mode == "text":
            return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
        elif mode == "point":
            return self._generate_points(vision_tokens, thinking_trace, **kwargs)
        elif mode == "box":
            return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
        elif mode == "polygon":
            return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    def _generate_text(
        self,
        image: Image.Image,
        prompt: str,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        max_new_tokens: Optional[int],
        **kwargs
    ) -> OculusTextOutput:
        """Generate text output (caption or VQA).

        Delegates to BLIP: the Oculus vision tokens are only attached to
        the output for inspection, not fed into BLIP (it re-encodes the
        image with its own vision tower).
        """

        device = vision_tokens.device if vision_tokens.is_cuda else "cpu"
        max_tokens = max_new_tokens or self.config.max_new_tokens

        # Determine if this is a question
        # Heuristic keyword match routes between the VQA and caption models.
        is_question = any(q in prompt.lower() for q in ["what", "where", "who", "how", "why", "is", "are", "does", "do", "can", "?"])

        if is_question and hasattr(self, 'lm_vqa_model'):
            # VQA mode
            inputs = self.lm_vqa_processor(image, prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                # NOTE(review): VQA ignores max_tokens and uses a fixed 50.
                out = self.lm_vqa_model.generate(**inputs, max_new_tokens=50)
            text = self.lm_vqa_processor.decode(out[0], skip_special_tokens=True)
        else:
            # Caption mode
            inputs = self.lm_processor(image, prompt, return_tensors="pt")
            inputs = {k: v.to(device) for k, v in inputs.items()}

            with torch.no_grad():
                out = self.lm_caption_model.generate(**inputs, max_new_tokens=max_tokens)
            text = self.lm_processor.decode(out[0], skip_special_tokens=True)

        return OculusTextOutput(
            text=text,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_points(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        threshold: float = 0.5,
        **kwargs
    ) -> OculusPointOutput:
        """Generate point detections.

        Runs the point head over every vision token and keeps tokens whose
        confidence exceeds `threshold`; results from all batch items are
        flattened into single lists.
        """

        points, cls_logits, confidence = self.point_head(vision_tokens)

        # Filter by confidence
        mask = confidence.squeeze(-1) > threshold

        filtered_points = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            token_mask = mask[i]
            pts = points[i][token_mask].detach().cpu().numpy().tolist()
            confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_points.extend([tuple(p) for p in pts])
            filtered_conf.extend(confs)
            # Labels are stringified class indices (no id->name mapping here).
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusPointOutput(
            points=filtered_points,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_boxes(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        threshold: float = 0.3,
        **kwargs
    ) -> OculusBoxOutput:
        """Generate bounding box detections.

        Confidence is the max softmax class probability per token; tokens
        below `threshold` are dropped. No NMS is applied.
        """

        cls_logits, box_coords = self.detection_head(vision_tokens)

        # Get confidence from class logits
        confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values

        filtered_boxes = []
        filtered_labels = []
        filtered_conf = []

        for i in range(vision_tokens.shape[0]):
            mask = confidence[i] > threshold
            boxes = box_coords[i][mask].detach().cpu().numpy()
            confs = confidence[i][mask].detach().cpu().numpy().tolist()
            cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()

            filtered_boxes.extend([tuple(b) for b in boxes])
            filtered_conf.extend(confs)
            filtered_labels.extend([str(c) for c in cls_ids])

        return OculusBoxOutput(
            boxes=filtered_boxes,
            labels=filtered_labels,
            confidences=filtered_conf,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    def _generate_polygons(
        self,
        vision_tokens: torch.Tensor,
        thinking_trace: Optional[str],
        **kwargs
    ) -> OculusPolygonOutput:
        """Generate polygon/mask segmentation.

        Only the first batch item's mask is returned. Polygons are
        placeholder unit squares, one per non-background class present.
        """

        mask_logits = self.segmentation_head(vision_tokens)

        # Get predicted mask
        mask = mask_logits.argmax(dim=1).detach().cpu().numpy()

        # Convert to polygons (simplified)
        # In full implementation, would use cv2.findContours
        polygons = []
        labels = []

        unique_classes = np.unique(mask[0])
        for cls_id in unique_classes:
            if cls_id == 0:  # Skip background
                continue
            labels.append(str(cls_id))
            # Placeholder polygon
            polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])

        return OculusPolygonOutput(
            polygons=polygons,
            labels=labels,
            mask=mask[0],
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from pretrained weights.

        Args:
            pretrained_model_name_or_path: HuggingFace repo ID or local path

        NOTE(review): this override bypasses PreTrainedModel.from_pretrained
        entirely and only reads LOCAL files — a hub repo ID is treated as a
        (nonexistent) path and silently falls back to a default config.
        Confirm remote loading is handled elsewhere.
        """
        path = Path(pretrained_model_name_or_path)

        # Load config
        config_path = path / "config.json"
        if config_path.exists():
            # NOTE(review): json is already imported at module level; this
            # local import is redundant.
            import json
            with open(config_path) as f:
                proj_config = json.load(f)

            # Create config with correct dimensions from projector
            config = OculusConfig(
                dinov3_hidden_size=proj_config.get("fused_dim", 2048) - 768,  # Infer from fused
                siglip_hidden_size=768,
                projector_hidden_dim=proj_config.get("hidden_dim", 2048),
                num_vision_tokens=proj_config.get("num_tokens", 64),
                lm_hidden_size=proj_config.get("embed_dim", 1536),
            )
        else:
            config = OculusConfig()

        # Create model
        model = cls(config)

        # Load projector weights
        projector_path = path / "projector.npz"
        if projector_path.exists():
            model.projector = OculusProjector.from_pretrained(path, config)

        # Load detection/segmentation heads if available
        heads_path = path / "heads.pth"
        if heads_path.exists():
            # NOTE(review): torch.load unpickles arbitrary objects — only
            # load heads.pth from trusted checkpoints.
            heads_state = torch.load(heads_path, map_location="cpu")
            model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
            model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
            model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)

        return model

    def save_pretrained(self, save_directory: str):
        """Save model to directory.

        Writes config.json (via the config), projector.npz (nested
        layer->param numpy dicts, the format from_pretrained expects),
        and heads.pth (torch state dicts for the three task heads).
        Vision encoders and the BLIP LM are NOT saved — they are
        re-downloaded by their hub IDs.
        """
        path = Path(save_directory)
        path.mkdir(parents=True, exist_ok=True)

        # Save config
        self.config.save_pretrained(path)

        # Save projector
        projector_state = self.projector.state_dict()
        # Convert to numpy for MLX compatibility
        np_weights = {}
        for k, v in projector_state.items():
            # Split "layer.param" keys into a nested dict per layer.
            parts = k.split(".")
            layer = parts[0]
            param = ".".join(parts[1:])
            if layer not in np_weights:
                np_weights[layer] = {}
            np_weights[layer][param] = v.cpu().numpy()
        # Each entry is a dict -> stored as a pickled object array, which is
        # why from_pretrained loads with allow_pickle=True.
        np.savez(path / "projector.npz", **{k: v for k, v in np_weights.items()})

        # Save heads
        torch.save({
            "detection": self.detection_head.state_dict(),
            "point": self.point_head.state_dict(),
            "segmentation": self.segmentation_head.state_dict(),
        }, path / "heads.pth")

        print(f"✓ Saved model to {path}")
839
+
840
+
841
# Register for auto-loading
# Lets transformers' Auto* machinery map this custom class to
# AutoModelForVision2Seq when the model repo is loaded with remote code.
OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")