kobiakor15 commited on
Commit
95c5fe2
·
verified ·
1 Parent(s): a028cbf

Upload oculus_unified_model/modeling_oculus.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. oculus_unified_model/modeling_oculus.py +357 -379
oculus_unified_model/modeling_oculus.py CHANGED
@@ -3,10 +3,13 @@ Oculus Unified Model
3
 
4
  HuggingFace-compatible vision-language model with:
5
  - Multi-encoder vision (DINOv3 + SigLIP2)
6
- - Trained projector for vision-to-language
7
- - Optional reasoning with thinking traces
8
- - Multiple output modes (Text, Point, Box, Polygon)
9
- - Focus/Zoom tool calling for fine-grained perception
 
 
 
10
  """
11
 
12
  import os
@@ -27,9 +30,7 @@ from transformers import (
27
  AutoModel,
28
  AutoTokenizer,
29
  AutoModelForCausalLM,
30
- GenerationConfig,
31
  )
32
- from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
33
  from PIL import Image
34
 
35
  from .configuration_oculus import OculusConfig
@@ -55,6 +56,12 @@ class OculusTextOutput(OculusOutput):
55
  pass
56
 
57
 
 
 
 
 
 
 
58
  @dataclass
59
  class OculusPointOutput(OculusOutput):
60
  """Output for point detection mode (counting objects)."""
@@ -63,10 +70,10 @@ class OculusPointOutput(OculusOutput):
63
  confidences: Optional[List[float]] = None
64
 
65
 
66
- @dataclass
67
  class OculusBoxOutput(OculusOutput):
68
  """Output for bounding box detection mode."""
69
- boxes: Optional[List[Tuple[float, float, float, float]]] = None # x1, y1, x2, y2
70
  labels: Optional[List[str]] = None
71
  confidences: Optional[List[float]] = None
72
 
@@ -79,6 +86,19 @@ class OculusPolygonOutput(OculusOutput):
79
  mask: Optional[np.ndarray] = None
80
 
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  # ============================================================================
83
  # Vision Encoder (DINOv3 + SigLIP2)
84
  # ============================================================================
@@ -86,30 +106,29 @@ class OculusPolygonOutput(OculusOutput):
86
  class OculusVisionEncoder(nn.Module):
87
  """
88
  Dual vision encoder combining DINOv3 and SigLIP2.
89
-
90
  DINOv3: Excellent at semantic understanding, object boundaries
91
  SigLIP2: Strong at text/language alignment
92
  """
93
-
94
  def __init__(self, config: OculusConfig):
95
  super().__init__()
96
  self.config = config
97
-
98
- # Will be loaded lazily
99
  self.dinov3 = None
100
  self.dinov3_processor = None
101
  self.siglip = None
102
  self.siglip_processor = None
103
-
104
  self._loaded = False
105
-
106
  def load_encoders(self, device: str = "cpu"):
107
  """Load vision encoders from HuggingFace."""
108
  if self._loaded:
109
  return
110
-
111
  print("[Oculus] Loading vision encoders...")
112
-
113
  # DINOv3
114
  try:
115
  self.dinov3_processor = AutoImageProcessor.from_pretrained(
@@ -121,10 +140,10 @@ class OculusVisionEncoder(nn.Module):
121
  print(f" ✓ DINOv3: {self.config.dinov3_model_id}")
122
  except Exception as e:
123
  warnings.warn(f"Failed to load DINOv3: {e}")
124
- self.dinov3_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
125
- self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-base").eval().to(device)
126
- print(" ✓ DINOv2-base (fallback)")
127
-
128
  # SigLIP2
129
  try:
130
  self.siglip_processor = AutoImageProcessor.from_pretrained(
@@ -133,58 +152,52 @@ class OculusVisionEncoder(nn.Module):
133
  self.siglip = AutoModel.from_pretrained(
134
  self.config.siglip_model_id
135
  ).eval().to(device)
136
- print(f" ✓ SigLIP: {self.config.siglip_model_id}")
137
  except Exception as e:
138
- warnings.warn(f"Failed to load SigLIP: {e}")
139
  from transformers import SiglipVisionModel
140
  self.siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
141
  self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval().to(device)
142
  print(" ✓ SigLIP-base (fallback)")
143
-
144
  self._loaded = True
145
-
146
  @torch.no_grad()
147
  def forward(self, image: Union[Image.Image, torch.Tensor, np.ndarray]) -> torch.Tensor:
148
- """
149
- Encode image with both vision encoders and fuse features.
150
-
151
- Returns:
152
- Fused vision features [batch, fused_dim]
153
- """
154
  if not self._loaded:
155
  self.load_encoders()
156
-
157
- # Handle different input types
158
  if isinstance(image, np.ndarray):
159
  image = Image.fromarray(image)
160
  elif isinstance(image, torch.Tensor):
161
  image = Image.fromarray(image.cpu().numpy().astype(np.uint8))
162
-
163
  if isinstance(image, Image.Image):
164
  image = image.convert('RGB')
165
-
166
  device = next(self.dinov3.parameters()).device
167
-
168
  # DINOv3 encoding
169
  d_inputs = self.dinov3_processor(images=image, return_tensors="pt")
170
  d_inputs = {k: v.to(device) for k, v in d_inputs.items()}
171
  d_out = self.dinov3(**d_inputs)
172
  d_pooled = d_out.pooler_output if hasattr(d_out, 'pooler_output') and d_out.pooler_output is not None else d_out.last_hidden_state[:, 0]
173
-
174
- # SigLIP encoding
175
  s_inputs = self.siglip_processor(images=image, return_tensors="pt")
176
  s_inputs = {k: v.to(device) for k, v in s_inputs.items()}
177
-
178
  if hasattr(self.siglip, 'vision_model'):
179
  s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
180
  s_pooled = s_hidden.mean(dim=1)
181
  else:
182
  s_out = self.siglip(**s_inputs)
183
  s_pooled = s_out.pooler_output if hasattr(s_out, 'pooler_output') else s_out.last_hidden_state[:, 0]
184
-
185
  # Fuse features
186
  fused = torch.cat([d_pooled, s_pooled], dim=-1)
187
-
188
  return fused
189
 
190
 
@@ -193,143 +206,121 @@ class OculusVisionEncoder(nn.Module):
193
  # ============================================================================
194
 
195
  class OculusProjector(nn.Module):
196
- """
197
- Projects fused vision features to language model token space.
198
-
199
- Converts [batch, fused_dim] → [batch, num_tokens, lm_hidden_size]
200
- """
201
-
202
  def __init__(self, config: OculusConfig):
203
  super().__init__()
204
  self.config = config
205
-
206
  fused_dim = config.fused_vision_dim
207
  hidden_dim = config.projector_hidden_dim
208
  num_tokens = config.num_vision_tokens
209
  embed_dim = config.lm_hidden_size
210
-
211
  self.fc1 = nn.Linear(fused_dim, hidden_dim)
212
  self.act1 = nn.GELU()
213
  self.fc2 = nn.Linear(hidden_dim, hidden_dim)
214
  self.act2 = nn.GELU()
215
  self.fc3 = nn.Linear(hidden_dim, num_tokens * embed_dim)
216
  self.norm = nn.LayerNorm(embed_dim)
217
-
218
  self.num_tokens = num_tokens
219
  self.embed_dim = embed_dim
220
-
221
  def forward(self, x: torch.Tensor) -> torch.Tensor:
222
- """
223
- Project vision features to token embeddings.
224
-
225
- Args:
226
- x: Vision features [batch, fused_dim]
227
-
228
- Returns:
229
- Vision tokens [batch, num_tokens, embed_dim]
230
- """
231
  batch_size = x.shape[0]
232
-
233
  h = self.fc1(x)
234
  h = self.act1(h)
235
  h = self.fc2(h)
236
  h = self.act2(h)
237
  h = self.fc3(h)
238
-
239
  h = h.reshape(batch_size, self.num_tokens, self.embed_dim)
240
  h = self.norm(h)
241
-
242
  return h
243
-
244
  @classmethod
245
  def from_pretrained(cls, path: str, config: OculusConfig):
246
  """Load projector from saved weights."""
247
  projector = cls(config)
248
-
249
  weights_path = Path(path) / "projector.npz"
250
  if weights_path.exists():
251
- import numpy as np
252
  weights = np.load(weights_path, allow_pickle=True)
253
-
254
  state_dict = {}
255
  for key in weights.files:
256
  layer_dict = weights[key].item()
257
  for param_name, param_val in layer_dict.items():
258
  full_key = f"{key}.{param_name}"
259
- # Convert from MLX array if needed
260
  if hasattr(param_val, 'tolist'):
261
  param_val = np.array(param_val.tolist())
262
  state_dict[full_key] = torch.from_numpy(np.array(param_val))
263
-
264
  projector.load_state_dict(state_dict, strict=False)
265
  print(f" ✓ Loaded projector from {path}")
266
-
267
  return projector
268
 
269
 
270
  # ============================================================================
271
- # Detection/Segmentation Heads
272
  # ============================================================================
273
 
274
  class OculusDetectionHead(nn.Module):
275
  """Head for bounding box detection."""
276
-
277
  def __init__(self, config: OculusConfig):
278
  super().__init__()
279
  hidden_dim = config.lm_hidden_size
280
  num_classes = config.num_detection_classes
281
-
282
  self.cls_head = nn.Sequential(
283
  nn.Linear(hidden_dim, hidden_dim // 2),
284
  nn.GELU(),
285
  nn.Linear(hidden_dim // 2, num_classes)
286
  )
287
-
288
  self.box_head = nn.Sequential(
289
  nn.Linear(hidden_dim, hidden_dim // 2),
290
  nn.GELU(),
291
- nn.Linear(hidden_dim // 2, 4) # x1, y1, x2, y2
292
  )
293
-
294
  def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
295
- """
296
- Predict boxes and classes from vision tokens.
297
-
298
- Returns:
299
- cls_logits: [batch, num_tokens, num_classes]
300
- box_coords: [batch, num_tokens, 4]
301
- """
302
  cls_logits = self.cls_head(vision_tokens)
303
- box_coords = self.box_head(vision_tokens).sigmoid() # Normalize to [0, 1]
304
  return cls_logits, box_coords
305
 
306
 
307
  class OculusPointHead(nn.Module):
308
  """Head for point detection (object counting)."""
309
-
310
  def __init__(self, config: OculusConfig):
311
  super().__init__()
312
  hidden_dim = config.lm_hidden_size
313
  num_classes = config.num_detection_classes
314
-
315
  self.point_head = nn.Sequential(
316
  nn.Linear(hidden_dim, hidden_dim // 2),
317
  nn.GELU(),
318
- nn.Linear(hidden_dim // 2, 2) # x, y
319
  )
320
-
321
  self.cls_head = nn.Sequential(
322
  nn.Linear(hidden_dim, hidden_dim // 2),
323
  nn.GELU(),
324
  nn.Linear(hidden_dim // 2, num_classes)
325
  )
326
-
327
  self.conf_head = nn.Sequential(
328
  nn.Linear(hidden_dim, hidden_dim // 4),
329
  nn.GELU(),
330
  nn.Linear(hidden_dim // 4, 1)
331
  )
332
-
333
  def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
334
  points = self.point_head(vision_tokens).sigmoid()
335
  cls_logits = self.cls_head(vision_tokens)
@@ -339,21 +330,20 @@ class OculusPointHead(nn.Module):
339
 
340
  class OculusSegmentationHead(nn.Module):
341
  """Head for polygon/mask segmentation."""
342
-
343
  def __init__(self, config: OculusConfig):
344
  super().__init__()
345
  hidden_dim = config.lm_hidden_size
346
  num_classes = config.num_segmentation_classes
347
-
348
- # Predict mask logits
349
  self.mask_head = nn.Sequential(
350
  nn.Linear(hidden_dim, hidden_dim),
351
  nn.GELU(),
352
- nn.Linear(hidden_dim, 14 * 14 * num_classes) # Output spatial mask
353
  )
354
-
355
  self.num_classes = num_classes
356
-
357
  def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
358
  batch_size = vision_tokens.shape[0]
359
  pooled = vision_tokens.mean(dim=1)
@@ -362,6 +352,49 @@ class OculusSegmentationHead(nn.Module):
362
  return mask_logits
363
 
364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
365
  # ============================================================================
366
  # Main Model
367
  # ============================================================================
@@ -369,188 +402,109 @@ class OculusSegmentationHead(nn.Module):
369
  class OculusForConditionalGeneration(PreTrainedModel):
370
  """
371
  Oculus: Unified Vision-Language Model
372
-
373
- Features:
374
- - Multi-encoder vision (DINOv3 + SigLIP2)
375
- - Optional reasoning with thinking traces
376
- - Multiple output modes: Text, Point, Box, Polygon
377
- - Focus/Zoom tool calling for fine-grained perception
378
-
379
- Usage:
380
- ```python
381
- from oculus_unified_model import OculusForConditionalGeneration
382
-
383
- model = OculusForConditionalGeneration.from_pretrained("OceanirAI/oculus-0.2")
384
-
385
- # Caption mode
386
- output = model.generate(image, mode="text", prompt="Describe this image")
387
-
388
- # VQA mode
389
- output = model.generate(image, mode="text", prompt="What color is the cat?")
390
-
391
- # With reasoning
392
- output = model.generate(image, mode="text", prompt="Count the people", think=True)
393
-
394
- # Detection mode
395
- output = model.generate(image, mode="box", prompt="Find all cars")
396
-
397
- # Point mode (counting)
398
- output = model.generate(image, mode="point", prompt="Count the birds")
399
-
400
- # Segmentation mode
401
- output = model.generate(image, mode="polygon", prompt="Segment the road")
402
- ```
403
  """
404
-
405
  config_class = OculusConfig
406
  base_model_prefix = "oculus"
407
-
408
  def __init__(self, config: OculusConfig):
409
  super().__init__(config)
410
  self.config = config
411
-
412
  # Vision encoder
413
  self.vision_encoder = OculusVisionEncoder(config)
414
-
415
- # Vision adapter (handles dimension mismatch if needed)
416
  self.vision_adapter = None
417
  self._actual_vision_dim = None
418
-
419
  # Projector
420
  self.projector = OculusProjector(config)
421
-
422
  # Task-specific heads
423
  self.detection_head = OculusDetectionHead(config)
424
  self.point_head = OculusPointHead(config)
425
  self.segmentation_head = OculusSegmentationHead(config)
426
-
427
- # Language model (loaded lazily)
 
 
428
  self.lm_tokenizer = None
429
  self.lm_model = None
430
  self._lm_loaded = False
431
-
432
- # Special tokens for reasoning
433
  self.thinking_token = config.thinking_token
434
  self.thinking_end_token = config.thinking_end_token
435
  self.focus_token = config.focus_token
436
  self.focus_end_token = config.focus_end_token
437
-
 
 
438
  def load_language_model(self, device: str = "cpu"):
439
- """Load language model for text generation."""
440
  if self._lm_loaded:
441
  return
442
-
443
  print("[Oculus] Loading language model...")
444
-
445
  try:
446
- # Try BLIP for now (works well for captioning/VQA)
447
- from transformers import BlipProcessor, BlipForConditionalGeneration, BlipForQuestionAnswering
448
-
449
- self.lm_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
450
- self.lm_caption_model = BlipForConditionalGeneration.from_pretrained(
451
- "Salesforce/blip-image-captioning-base"
452
  ).to(device)
453
-
454
- self.lm_vqa_processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
455
- self.lm_vqa_model = BlipForQuestionAnswering.from_pretrained(
456
- "Salesforce/blip-vqa-base"
457
- ).to(device)
458
-
459
- print(" ✓ BLIP (captioning + VQA)")
460
  self._lm_loaded = True
461
-
462
  except Exception as e:
463
- warnings.warn(f"Failed to load language model: {e}")
464
-
465
  def encode_image(self, image: Union[Image.Image, str, np.ndarray]) -> torch.Tensor:
466
- """
467
- Encode image to vision tokens.
468
-
469
- Args:
470
- image: PIL Image, file path, or numpy array
471
-
472
- Returns:
473
- Vision tokens [1, num_tokens, embed_dim]
474
- """
475
- # Load image if path
476
  if isinstance(image, str):
477
  image = Image.open(image)
478
-
479
- # Encode with vision encoders
480
  vision_features = self.vision_encoder(image)
481
-
482
- # Check if we need an adapter for dimension mismatch
483
  actual_dim = vision_features.shape[-1]
484
  expected_dim = self.config.fused_vision_dim
485
-
486
  if actual_dim != expected_dim:
487
  if self.vision_adapter is None or self._actual_vision_dim != actual_dim:
488
- # Create adapter layer
489
  print(f" [Adapter] Creating vision adapter: {actual_dim} -> {expected_dim}")
490
  self.vision_adapter = nn.Linear(actual_dim, expected_dim)
491
  self._actual_vision_dim = actual_dim
492
- # Initialize with small weights
493
  nn.init.xavier_uniform_(self.vision_adapter.weight)
494
  nn.init.zeros_(self.vision_adapter.bias)
495
-
496
  vision_features = self.vision_adapter(vision_features)
497
-
498
- # Project to language space
499
  vision_tokens = self.projector(vision_features)
500
-
501
  return vision_tokens
502
-
503
- def _generate_thinking_trace(
504
- self,
505
- image: Image.Image,
506
- prompt: str,
507
- max_tokens: int = 256
508
- ) -> str:
509
- """
510
- Generate a thinking/reasoning trace before answering.
511
-
512
- This enables multi-step reasoning for complex tasks.
513
- """
514
- thinking_prompt = f"""Let me think about this step by step:
515
- 1. First, I'll analyze what I see in the image.
516
- 2. Then, I'll consider the question: "{prompt}"
517
- 3. Finally, I'll formulate my answer.
518
-
519
- Observation: """
520
-
521
- # Generate reasoning (simplified for now)
522
- if self._lm_loaded and hasattr(self, 'lm_caption_model'):
523
- inputs = self.lm_processor(image, thinking_prompt, return_tensors="pt")
524
- inputs = {k: v.to(self.lm_caption_model.device) for k, v in inputs.items()}
525
-
526
- with torch.no_grad():
527
- out = self.lm_caption_model.generate(
528
- **inputs,
529
- max_new_tokens=max_tokens,
530
- do_sample=True,
531
- temperature=0.7
532
- )
533
- thinking = self.lm_processor.decode(out[0], skip_special_tokens=True)
534
  else:
535
- thinking = "I observe the image and analyze its contents."
536
-
537
- return thinking
538
-
539
- def _detect_focus_regions(
540
- self,
541
- image: Image.Image,
542
- prompt: str
543
- ) -> List[Tuple[int, int, int, int]]:
544
- """
545
- Detect regions that need closer inspection (Focus/Zoom system).
546
-
547
- Returns list of (x1, y1, x2, y2) crop regions.
548
- """
549
- # Simplified: return full image as single region
550
- # In full implementation, would use attention maps to find regions of interest
551
- w, h = image.size
552
- return [(0, 0, w, h)]
553
-
554
  def generate(
555
  self,
556
  image: Union[Image.Image, str, np.ndarray],
@@ -560,129 +514,109 @@ Observation: """
560
  focus: bool = False,
561
  max_new_tokens: Optional[int] = None,
562
  temperature: float = 0.7,
563
- return_thinking: bool = True,
564
  **kwargs
565
- ) -> Union[OculusTextOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput]:
566
  """
567
  Generate output from image.
568
-
569
  Args:
570
- image: Input image (PIL, path, or array)
571
  prompt: Text prompt/question
572
- mode: Output mode ("text", "point", "box", "polygon")
573
  think: Enable reasoning traces
574
  focus: Enable zoom/crop for fine-grained perception
575
- max_new_tokens: Maximum tokens to generate
576
- temperature: Sampling temperature
577
- return_thinking: Include thinking trace in output
578
-
579
- Returns:
580
- Mode-specific output dataclass
581
  """
582
- # Load models if needed
583
  self.vision_encoder.load_encoders()
584
- if mode == "text":
585
- self.load_language_model()
586
-
587
- # Load image
588
  if isinstance(image, str):
589
  image = Image.open(image).convert('RGB')
590
  elif isinstance(image, np.ndarray):
591
  image = Image.fromarray(image).convert('RGB')
592
-
593
- # Encode image
594
  vision_tokens = self.encode_image(image)
595
-
596
- # Generate thinking trace if enabled
597
  thinking_trace = None
598
  if think and self.config.reasoning_enabled:
599
- thinking_trace = self._generate_thinking_trace(image, prompt)
600
-
601
- # Focus system: zoom/crop if needed
602
- if focus and self.config.enable_focus:
603
- focus_regions = self._detect_focus_regions(image, prompt)
604
- # Could re-encode cropped regions here
605
-
606
- # Mode-specific generation
607
  if mode == "text":
608
  return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
 
 
609
  elif mode == "point":
610
  return self._generate_points(vision_tokens, thinking_trace, **kwargs)
611
  elif mode == "box":
612
  return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
613
  elif mode == "polygon":
614
  return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)
 
 
 
 
615
  else:
616
  raise ValueError(f"Unknown mode: {mode}")
617
-
618
- def _generate_text(
619
- self,
620
- image: Image.Image,
621
- prompt: str,
622
- vision_tokens: torch.Tensor,
623
- thinking_trace: Optional[str],
624
- max_new_tokens: Optional[int],
625
- **kwargs
626
- ) -> OculusTextOutput:
627
- """Generate text output (caption or VQA)."""
628
-
629
- device = vision_tokens.device if vision_tokens.is_cuda else "cpu"
630
- max_tokens = max_new_tokens or self.config.max_new_tokens
631
-
632
- # Determine if this is a question
633
- is_question = any(q in prompt.lower() for q in ["what", "where", "who", "how", "why", "is", "are", "does", "do", "can", "?"])
634
-
635
- if is_question and hasattr(self, 'lm_vqa_model'):
636
- # VQA mode
637
- inputs = self.lm_vqa_processor(image, prompt, return_tensors="pt")
638
- inputs = {k: v.to(device) for k, v in inputs.items()}
639
-
640
- with torch.no_grad():
641
- out = self.lm_vqa_model.generate(**inputs, max_new_tokens=50)
642
- text = self.lm_vqa_processor.decode(out[0], skip_special_tokens=True)
643
- else:
644
- # Caption mode
645
- inputs = self.lm_processor(image, prompt, return_tensors="pt")
646
- inputs = {k: v.to(device) for k, v in inputs.items()}
647
-
648
- with torch.no_grad():
649
- out = self.lm_caption_model.generate(**inputs, max_new_tokens=max_tokens)
650
- text = self.lm_processor.decode(out[0], skip_special_tokens=True)
651
-
652
  return OculusTextOutput(
653
  text=text,
654
  thinking_trace=thinking_trace,
655
  vision_tokens=vision_tokens
656
  )
657
-
658
- def _generate_points(
659
- self,
660
- vision_tokens: torch.Tensor,
661
- thinking_trace: Optional[str],
662
- threshold: float = 0.5,
663
- **kwargs
664
- ) -> OculusPointOutput:
 
 
 
665
  """Generate point detections."""
666
-
667
  points, cls_logits, confidence = self.point_head(vision_tokens)
668
-
669
- # Filter by confidence
670
  mask = confidence.squeeze(-1) > threshold
671
-
672
  filtered_points = []
673
  filtered_labels = []
674
  filtered_conf = []
675
-
676
  for i in range(vision_tokens.shape[0]):
677
  token_mask = mask[i]
678
  pts = points[i][token_mask].detach().cpu().numpy().tolist()
679
  confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
680
  cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()
681
-
682
  filtered_points.extend([tuple(p) for p in pts])
683
  filtered_conf.extend(confs)
684
  filtered_labels.extend([str(c) for c in cls_ids])
685
-
686
  return OculusPointOutput(
687
  points=filtered_points,
688
  labels=filtered_labels,
@@ -690,35 +624,26 @@ Observation: """
690
  thinking_trace=thinking_trace,
691
  vision_tokens=vision_tokens
692
  )
693
-
694
- def _generate_boxes(
695
- self,
696
- vision_tokens: torch.Tensor,
697
- thinking_trace: Optional[str],
698
- threshold: float = 0.3,
699
- **kwargs
700
- ) -> OculusBoxOutput:
701
  """Generate bounding box detections."""
702
-
703
  cls_logits, box_coords = self.detection_head(vision_tokens)
704
-
705
- # Get confidence from class logits
706
  confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values
707
-
708
  filtered_boxes = []
709
  filtered_labels = []
710
  filtered_conf = []
711
-
712
  for i in range(vision_tokens.shape[0]):
713
  mask = confidence[i] > threshold
714
  boxes = box_coords[i][mask].detach().cpu().numpy()
715
  confs = confidence[i][mask].detach().cpu().numpy().tolist()
716
  cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()
717
-
718
  filtered_boxes.extend([tuple(b) for b in boxes])
719
  filtered_conf.extend(confs)
720
  filtered_labels.extend([str(c) for c in cls_ids])
721
-
722
  return OculusBoxOutput(
723
  boxes=filtered_boxes,
724
  labels=filtered_labels,
@@ -726,33 +651,22 @@ Observation: """
726
  thinking_trace=thinking_trace,
727
  vision_tokens=vision_tokens
728
  )
729
-
730
- def _generate_polygons(
731
- self,
732
- vision_tokens: torch.Tensor,
733
- thinking_trace: Optional[str],
734
- **kwargs
735
- ) -> OculusPolygonOutput:
736
  """Generate polygon/mask segmentation."""
737
-
738
  mask_logits = self.segmentation_head(vision_tokens)
739
-
740
- # Get predicted mask
741
  mask = mask_logits.argmax(dim=1).detach().cpu().numpy()
742
-
743
- # Convert to polygons (simplified)
744
- # In full implementation, would use cv2.findContours
745
  polygons = []
746
  labels = []
747
-
748
  unique_classes = np.unique(mask[0])
749
  for cls_id in unique_classes:
750
- if cls_id == 0: # Skip background
751
  continue
752
  labels.append(str(cls_id))
753
- # Placeholder polygon
754
  polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])
755
-
756
  return OculusPolygonOutput(
757
  polygons=polygons,
758
  labels=labels,
@@ -760,64 +674,127 @@ Observation: """
760
  thinking_trace=thinking_trace,
761
  vision_tokens=vision_tokens
762
  )
763
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
764
  @classmethod
765
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
766
- """
767
- Load model from pretrained weights.
768
-
769
- Args:
770
- pretrained_model_name_or_path: HuggingFace repo ID or local path
771
- """
772
  path = Path(pretrained_model_name_or_path)
773
-
774
- # Load config
775
  config_path = path / "config.json"
776
  if config_path.exists():
777
- import json
778
  with open(config_path) as f:
779
- proj_config = json.load(f)
780
-
781
- # Create config with correct dimensions from projector
782
- config = OculusConfig(
783
- dinov3_hidden_size=proj_config.get("fused_dim", 2048) - 768, # Infer from fused
784
- siglip_hidden_size=768,
785
- projector_hidden_dim=proj_config.get("hidden_dim", 2048),
786
- num_vision_tokens=proj_config.get("num_tokens", 64),
787
- lm_hidden_size=proj_config.get("embed_dim", 1536),
788
- )
789
  else:
790
  config = OculusConfig()
791
-
792
- # Create model
793
  model = cls(config)
794
-
795
- # Load projector weights
796
- projector_path = path / "projector.npz"
797
  if projector_path.exists():
798
- model.projector = OculusProjector.from_pretrained(path, config)
799
-
800
- # Load detection/segmentation heads if available
801
- heads_path = path / "heads.pth"
802
  if heads_path.exists():
803
  heads_state = torch.load(heads_path, map_location="cpu")
804
  model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
805
  model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
806
  model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)
807
-
 
 
 
808
  return model
809
-
810
  def save_pretrained(self, save_directory: str):
811
  """Save model to directory."""
812
  path = Path(save_directory)
813
  path.mkdir(parents=True, exist_ok=True)
814
-
815
- # Save config
816
  self.config.save_pretrained(path)
817
-
818
  # Save projector
 
 
 
819
  projector_state = self.projector.state_dict()
820
- # Convert to numpy for MLX compatibility
821
  np_weights = {}
822
  for k, v in projector_state.items():
823
  parts = k.split(".")
@@ -826,17 +803,18 @@ Observation: """
826
  if layer not in np_weights:
827
  np_weights[layer] = {}
828
  np_weights[layer][param] = v.cpu().numpy()
829
- np.savez(path / "projector.npz", **{k: v for k, v in np_weights.items()})
830
-
831
  # Save heads
832
  torch.save({
833
  "detection": self.detection_head.state_dict(),
834
  "point": self.point_head.state_dict(),
835
  "segmentation": self.segmentation_head.state_dict(),
836
- }, path / "heads.pth")
837
-
 
 
838
  print(f"✓ Saved model to {path}")
839
 
840
 
841
- # Register for auto-loading
842
  OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")
 
3
 
4
  HuggingFace-compatible vision-language model with:
5
  - Multi-encoder vision (DINOv3 + SigLIP2)
6
+ - LFM2.5-1.2B language model (Liquid AI)
7
+ - Isaac 0.2 features:
8
+ - Reasoning via Thinking Traces
9
+ - Perceptive Tool Calling + Focus (Zoom & Crop)
10
+ - Structured Outputs (JSON)
11
+ - Complex OCR
12
+ - Desktop UI Understanding
13
  """
14
 
15
  import os
 
30
  AutoModel,
31
  AutoTokenizer,
32
  AutoModelForCausalLM,
 
33
  )
 
34
  from PIL import Image
35
 
36
  from .configuration_oculus import OculusConfig
 
56
  pass
57
 
58
 
59
@dataclass
class OculusJSONOutput(OculusOutput):
    """Output container for structured JSON generation mode.

    Returned by the model's JSON mode; inherits the shared fields
    (text, thinking_trace, vision_tokens) from OculusOutput.
    """
    # Parsed JSON payload; None until generation produced a result.
    json_data: Optional[Dict[str, Any]] = None
63
+
64
+
65
  @dataclass
66
  class OculusPointOutput(OculusOutput):
67
  """Output for point detection mode (counting objects)."""
 
70
  confidences: Optional[List[float]] = None
71
 
72
 
73
@dataclass
class OculusBoxOutput(OculusOutput):
    """Output for bounding box detection mode.

    Boxes are (x1, y1, x2, y2) tuples; the detection head emits
    sigmoid-normalized coordinates in [0, 1].
    """
    # One (x1, y1, x2, y2) tuple per detection.
    boxes: Optional[List[Tuple[float, float, float, float]]] = None
    # Stringified class id per box.
    labels: Optional[List[str]] = None
    # Per-box confidence score.
    confidences: Optional[List[float]] = None
79
 
 
86
  mask: Optional[np.ndarray] = None
87
 
88
 
89
@dataclass
class OculusOCROutput(OculusOutput):
    """Output for OCR mode."""
    # One dict per detected region: {"text": str, "bbox": [x, y, w, h], "confidence": float}.
    text_blocks: Optional[List[Dict[str, Any]]] = None
    # All block texts joined with single spaces.
    full_text: Optional[str] = None
94
+
95
+
96
@dataclass
class OculusUIOutput(OculusOutput):
    """Output for desktop UI element detection."""
    # One dict per element: {"type": str, "bbox": [x1, y1, x2, y2], "confidence": float}.
    elements: Optional[List[Dict[str, Any]]] = None
100
+
101
+
102
  # ============================================================================
103
  # Vision Encoder (DINOv3 + SigLIP2)
104
  # ============================================================================
 
106
  class OculusVisionEncoder(nn.Module):
107
  """
108
  Dual vision encoder combining DINOv3 and SigLIP2.
109
+
110
  DINOv3: Excellent at semantic understanding, object boundaries
111
  SigLIP2: Strong at text/language alignment
112
  """
113
+
114
  def __init__(self, config: OculusConfig):
115
  super().__init__()
116
  self.config = config
117
+
 
118
  self.dinov3 = None
119
  self.dinov3_processor = None
120
  self.siglip = None
121
  self.siglip_processor = None
122
+
123
  self._loaded = False
124
+
125
  def load_encoders(self, device: str = "cpu"):
126
  """Load vision encoders from HuggingFace."""
127
  if self._loaded:
128
  return
129
+
130
  print("[Oculus] Loading vision encoders...")
131
+
132
  # DINOv3
133
  try:
134
  self.dinov3_processor = AutoImageProcessor.from_pretrained(
 
140
  print(f" ✓ DINOv3: {self.config.dinov3_model_id}")
141
  except Exception as e:
142
  warnings.warn(f"Failed to load DINOv3: {e}")
143
+ self.dinov3_processor = AutoImageProcessor.from_pretrained("facebook/dinov2-large")
144
+ self.dinov3 = AutoModel.from_pretrained("facebook/dinov2-large").eval().to(device)
145
+ print(" ✓ DINOv2-large (fallback)")
146
+
147
  # SigLIP2
148
  try:
149
  self.siglip_processor = AutoImageProcessor.from_pretrained(
 
152
  self.siglip = AutoModel.from_pretrained(
153
  self.config.siglip_model_id
154
  ).eval().to(device)
155
+ print(f" ✓ SigLIP2: {self.config.siglip_model_id}")
156
  except Exception as e:
157
+ warnings.warn(f"Failed to load SigLIP2: {e}")
158
  from transformers import SiglipVisionModel
159
  self.siglip_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")
160
  self.siglip = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224").eval().to(device)
161
  print(" ✓ SigLIP-base (fallback)")
162
+
163
  self._loaded = True
164
+
165
    @torch.no_grad()
    def forward(self, image: Union[Image.Image, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Encode image with both vision encoders and fuse features.

        Args:
            image: PIL image, numpy array, or tensor (converted via
                Image.fromarray — assumes HxWxC uint8 layout; TODO confirm).

        Returns:
            Concatenated pooled DINOv3 + SigLIP2 features, shape
            (batch, dinov3_dim + siglip_dim).
        """
        # Encoders are downloaded lazily on first call.
        if not self._loaded:
            self.load_encoders()

        # Normalize every accepted input type to an RGB PIL image.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif isinstance(image, torch.Tensor):
            image = Image.fromarray(image.cpu().numpy().astype(np.uint8))

        if isinstance(image, Image.Image):
            image = image.convert('RGB')

        device = next(self.dinov3.parameters()).device

        # DINOv3 encoding: prefer pooler_output, fall back to the CLS token.
        d_inputs = self.dinov3_processor(images=image, return_tensors="pt")
        d_inputs = {k: v.to(device) for k, v in d_inputs.items()}
        d_out = self.dinov3(**d_inputs)
        d_pooled = d_out.pooler_output if hasattr(d_out, 'pooler_output') and d_out.pooler_output is not None else d_out.last_hidden_state[:, 0]

        # SigLIP2 encoding
        s_inputs = self.siglip_processor(images=image, return_tensors="pt")
        s_inputs = {k: v.to(device) for k, v in s_inputs.items()}

        if hasattr(self.siglip, 'vision_model'):
            # Cheap path: mean-pool the patch embeddings only, skipping the
            # transformer stack of the vision tower.
            s_hidden = self.siglip.vision_model.embeddings(s_inputs['pixel_values'])
            s_pooled = s_hidden.mean(dim=1)
        else:
            s_out = self.siglip(**s_inputs)
            s_pooled = s_out.pooler_output if hasattr(s_out, 'pooler_output') else s_out.last_hidden_state[:, 0]

        # Fuse: channel-wise concatenation of the two pooled vectors.
        fused = torch.cat([d_pooled, s_pooled], dim=-1)

        return fused
202
 
203
 
 
206
  # ============================================================================
207
 
208
class OculusProjector(nn.Module):
    """Projects fused vision features into the language model's token space.

    A three-layer GELU MLP whose final layer emits ``num_tokens`` embeddings
    of width ``lm_hidden_size``, followed by a LayerNorm per token.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        self.config = config

        in_dim = config.fused_vision_dim
        width = config.projector_hidden_dim
        n_tokens = config.num_vision_tokens
        token_dim = config.lm_hidden_size

        # Attribute names are kept so projector.npz checkpoints keyed on
        # fc1/fc2/fc3/norm continue to match this module's state_dict.
        self.fc1 = nn.Linear(in_dim, width)
        self.act1 = nn.GELU()
        self.fc2 = nn.Linear(width, width)
        self.act2 = nn.GELU()
        self.fc3 = nn.Linear(width, n_tokens * token_dim)
        self.norm = nn.LayerNorm(token_dim)

        self.num_tokens = n_tokens
        self.embed_dim = token_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, fused_dim) features to (batch, num_tokens, embed_dim) tokens."""
        batch = x.shape[0]
        hidden = self.act2(self.fc2(self.act1(self.fc1(x))))
        flat = self.fc3(hidden)
        tokens = flat.reshape(batch, self.num_tokens, self.embed_dim)
        return self.norm(tokens)

    @classmethod
    def from_pretrained(cls, path: str, config: OculusConfig):
        """Build a projector and restore weights from ``projector.npz`` if present."""
        projector = cls(config)

        weights_path = Path(path) / "projector.npz"
        if weights_path.exists():
            archive = np.load(weights_path, allow_pickle=True)
            state_dict = {}
            for layer_name in archive.files:
                # Each archive entry holds a dict of param-name -> array
                # for a single layer (see save_pretrained's packing).
                for param_name, value in archive[layer_name].item().items():
                    if hasattr(value, 'tolist'):
                        value = np.array(value.tolist())
                    state_dict[f"{layer_name}.{param_name}"] = torch.from_numpy(np.array(value))
            projector.load_state_dict(state_dict, strict=False)
            print(f" ✓ Loaded projector from {path}")

        return projector
266
 
267
 
268
  # ============================================================================
269
+ # Task Heads
270
  # ============================================================================
271
 
272
class OculusDetectionHead(nn.Module):
    """Per-token head producing class logits and normalized bounding boxes."""

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        def _branch(out_features: int) -> nn.Sequential:
            # Two-layer GELU MLP; layer order matches the original layout so
            # state_dict keys (cls_head.*, box_head.*) are unchanged.
            return nn.Sequential(
                nn.Linear(width, width // 2),
                nn.GELU(),
                nn.Linear(width // 2, out_features),
            )

        self.cls_head = _branch(config.num_detection_classes)
        self.box_head = _branch(4)

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (class logits, sigmoid-normalized box coordinates in [0, 1])."""
        return self.cls_head(vision_tokens), self.box_head(vision_tokens).sigmoid()
296
 
297
 
298
  class OculusPointHead(nn.Module):
299
  """Head for point detection (object counting)."""
300
+
301
    def __init__(self, config: OculusConfig):
        """Build the three per-token MLP branches used for point detection.

        Branches share the LM hidden size as input width: point coordinates,
        class logits, and a scalar confidence.
        """
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_detection_classes

        # (x, y) location per token; forward() sigmoid-normalizes to [0, 1].
        self.point_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, 2)
        )

        # Class logits per token.
        self.cls_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Linear(hidden_dim // 2, num_classes)
        )

        # Scalar confidence per token (narrower bottleneck than the others).
        self.conf_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 4),
            nn.GELU(),
            nn.Linear(hidden_dim // 4, 1)
        )
323
+
324
  def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
325
  points = self.point_head(vision_tokens).sigmoid()
326
  cls_logits = self.cls_head(vision_tokens)
 
330
 
331
  class OculusSegmentationHead(nn.Module):
332
  """Head for polygon/mask segmentation."""
333
+
334
    def __init__(self, config: OculusConfig):
        """Build the mask decoder MLP.

        Emits 14*14*num_classes logits per input vector; the 14x14 grid size
        is hard-coded here — presumably reshaped by forward(); confirm against
        the consumer.
        """
        super().__init__()
        hidden_dim = config.lm_hidden_size
        num_classes = config.num_segmentation_classes

        # Pooled token embedding -> flattened per-class mask logits.
        self.mask_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, 14 * 14 * num_classes)
        )

        self.num_classes = num_classes
346
+
347
  def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
348
  batch_size = vision_tokens.shape[0]
349
  pooled = vision_tokens.mean(dim=1)
 
352
  return mask_logits
353
 
354
 
355
class OculusOCRHead(nn.Module):
    """Per-token text-region detector used by the OCR mode.

    Emits five raw (unactivated) values per vision token:
    x, y, w, h, confidence.
    """

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        self.text_detector = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Linear(width, 5),  # x, y, w, h, confidence
        )

    def forward(self, vision_tokens: torch.Tensor) -> torch.Tensor:
        """Return the raw 5-value detection vector for every token."""
        return self.text_detector(vision_tokens)
370
+
371
+
372
class OculusUIHead(nn.Module):
    """Per-token head for desktop UI element detection."""

    def __init__(self, config: OculusConfig):
        super().__init__()
        width = config.lm_hidden_size

        def _branch(out_features: int) -> nn.Sequential:
            # Two-layer GELU MLP; construction order mirrors the original so
            # checkpoints keyed on element_cls.*/element_box.* still load.
            return nn.Sequential(
                nn.Linear(width, width // 2),
                nn.GELU(),
                nn.Linear(width // 2, out_features),
            )

        self.element_cls = _branch(config.ui_element_classes)
        self.element_box = _branch(4)

    def forward(self, vision_tokens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (element class logits, sigmoid-normalized box coordinates)."""
        logits = self.element_cls(vision_tokens)
        boxes = self.element_box(vision_tokens).sigmoid()
        return logits, boxes
396
+
397
+
398
  # ============================================================================
399
  # Main Model
400
  # ============================================================================
 
402
  class OculusForConditionalGeneration(PreTrainedModel):
403
  """
404
  Oculus: Unified Vision-Language Model
405
+
406
+ Architecture: DINOv3 + SigLIP2 + LFM2.5-1.2B
407
+
408
+ Isaac 0.2 Features:
409
+ - Reasoning via Thinking Traces
410
+ - Perceptive Tool Calling + Focus (Zoom & Crop)
411
+ - Structured Outputs (JSON)
412
+ - Complex OCR
413
+ - Desktop UI Understanding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
414
  """
415
+
416
  config_class = OculusConfig
417
  base_model_prefix = "oculus"
418
+
419
    def __init__(self, config: OculusConfig):
        """Assemble the Oculus model graph from *config*.

        The vision encoder wrapper, projector and all task heads are built
        eagerly; the HF language model is loaded lazily via
        load_language_model().
        """
        super().__init__(config)
        self.config = config

        # Vision encoder (DINOv3 + SigLIP2; actual weights load lazily)
        self.vision_encoder = OculusVisionEncoder(config)

        # Optional linear bridge created on demand by encode_image() when the
        # encoder's real output width differs from config.fused_vision_dim.
        self.vision_adapter = None
        self._actual_vision_dim = None

        # Projector: fused vision features -> LM token embeddings
        self.projector = OculusProjector(config)

        # Task-specific heads
        self.detection_head = OculusDetectionHead(config)
        self.point_head = OculusPointHead(config)
        self.segmentation_head = OculusSegmentationHead(config)
        self.ocr_head = OculusOCRHead(config)
        self.ui_head = OculusUIHead(config)

        # Language model (LFM2.5) — populated by load_language_model()
        self.lm_tokenizer = None
        self.lm_model = None
        self._lm_loaded = False

        # Special control tokens for thinking / focus / JSON modes
        self.thinking_token = config.thinking_token
        self.thinking_end_token = config.thinking_end_token
        self.focus_token = config.focus_token
        self.focus_end_token = config.focus_end_token
        self.json_token = config.json_token
        self.json_end_token = config.json_end_token
452
+
453
    def load_language_model(self, device: str = "cpu"):
        """Lazily load the LFM2.5 tokenizer and causal LM onto *device*.

        Idempotent: returns immediately once a previous call succeeded.
        On failure only a warning is emitted, so the rest of the model
        (vision heads) keeps working without text generation.
        """
        if self._lm_loaded:
            return

        print("[Oculus] Loading language model...")

        try:
            self.lm_tokenizer = AutoTokenizer.from_pretrained(self.config.lm_model_id)
            self.lm_model = AutoModelForCausalLM.from_pretrained(
                self.config.lm_model_id
            ).to(device)
            print(f" ✓ LFM2.5: {self.config.lm_model_id}")
            self._lm_loaded = True
        except Exception as e:
            # _lm_loaded stays False, so a later call will retry the download.
            warnings.warn(f"Failed to load LFM2.5: {e}. Text generation unavailable.")
469
+
470
  def encode_image(self, image: Union[Image.Image, str, np.ndarray]) -> torch.Tensor:
471
+ """Encode image to vision tokens."""
 
 
 
 
 
 
 
 
 
472
  if isinstance(image, str):
473
  image = Image.open(image)
474
+
 
475
  vision_features = self.vision_encoder(image)
476
+
 
477
  actual_dim = vision_features.shape[-1]
478
  expected_dim = self.config.fused_vision_dim
479
+
480
  if actual_dim != expected_dim:
481
  if self.vision_adapter is None or self._actual_vision_dim != actual_dim:
 
482
  print(f" [Adapter] Creating vision adapter: {actual_dim} -> {expected_dim}")
483
  self.vision_adapter = nn.Linear(actual_dim, expected_dim)
484
  self._actual_vision_dim = actual_dim
 
485
  nn.init.xavier_uniform_(self.vision_adapter.weight)
486
  nn.init.zeros_(self.vision_adapter.bias)
487
+
488
  vision_features = self.vision_adapter(vision_features)
489
+
 
490
  vision_tokens = self.projector(vision_features)
491
+
492
  return vision_tokens
493
+
494
+ def _crop_region(self, image: Image.Image, bbox: Tuple[int, int, int, int]) -> Image.Image:
495
+ """Crop image to specified region for focus/zoom."""
496
+ x1, y1, x2, y2 = bbox
497
+ return image.crop((x1, y1, x2, y2))
498
+
499
+ def _generate_thinking_trace(self, prompt: str, context: str = "") -> str:
500
+ """Generate structured thinking trace."""
501
+ if self.config.thinking_style == "structured":
502
+ return f"Analyzing: {prompt[:50]}... | Observations: {context[:100]}"
503
+ elif self.config.thinking_style == "verbose":
504
+ return f"Let me think step by step about: {prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  else:
506
+ return ""
507
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
  def generate(
509
  self,
510
  image: Union[Image.Image, str, np.ndarray],
 
514
  focus: bool = False,
515
  max_new_tokens: Optional[int] = None,
516
  temperature: float = 0.7,
 
517
  **kwargs
518
+ ) -> Union[OculusTextOutput, OculusJSONOutput, OculusPointOutput, OculusBoxOutput, OculusPolygonOutput, OculusOCROutput, OculusUIOutput]:
519
  """
520
  Generate output from image.
521
+
522
  Args:
523
+ image: Input image
524
  prompt: Text prompt/question
525
+ mode: "text", "json", "point", "box", "polygon", "ocr", "ui"
526
  think: Enable reasoning traces
527
  focus: Enable zoom/crop for fine-grained perception
 
 
 
 
 
 
528
  """
 
529
  self.vision_encoder.load_encoders()
530
+
 
 
 
531
  if isinstance(image, str):
532
  image = Image.open(image).convert('RGB')
533
  elif isinstance(image, np.ndarray):
534
  image = Image.fromarray(image).convert('RGB')
535
+
 
536
  vision_tokens = self.encode_image(image)
537
+
 
538
  thinking_trace = None
539
  if think and self.config.reasoning_enabled:
540
+ thinking_trace = self._generate_thinking_trace(prompt)
541
+
 
 
 
 
 
 
542
  if mode == "text":
543
  return self._generate_text(image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs)
544
+ elif mode == "json":
545
+ return self._generate_json(image, prompt, vision_tokens, thinking_trace, **kwargs)
546
  elif mode == "point":
547
  return self._generate_points(vision_tokens, thinking_trace, **kwargs)
548
  elif mode == "box":
549
  return self._generate_boxes(vision_tokens, thinking_trace, **kwargs)
550
  elif mode == "polygon":
551
  return self._generate_polygons(vision_tokens, thinking_trace, **kwargs)
552
+ elif mode == "ocr":
553
+ return self._generate_ocr(vision_tokens, thinking_trace, **kwargs)
554
+ elif mode == "ui":
555
+ return self._generate_ui(vision_tokens, thinking_trace, **kwargs)
556
  else:
557
  raise ValueError(f"Unknown mode: {mode}")
558
+
559
    def _generate_text(self, image, prompt, vision_tokens, thinking_trace, max_new_tokens, **kwargs) -> OculusTextOutput:
        """Generate a free-form text answer with the LFM2.5 LM.

        Note: *image* and *vision_tokens* are currently NOT injected into the
        LM input — generation conditions on *prompt* alone (see the inline
        comment); the tokens are only echoed back in the output object.
        """
        if not self._lm_loaded:
            self.load_language_model()

        # Graceful degradation when the LM failed to load.
        if self.lm_model is None:
            return OculusTextOutput(
                text="[Language model not available]",
                thinking_trace=thinking_trace,
                vision_tokens=vision_tokens
            )

        # Simple text generation (full implementation would inject vision tokens)
        inputs = self.lm_tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(self.lm_model.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.lm_model.generate(
                **inputs,
                # Falls back to the configured default when None/0 is passed.
                max_new_tokens=max_new_tokens or self.config.max_new_tokens,
                temperature=self.config.temperature,
                do_sample=True
            )

        # Decodes the full sequence, prompt tokens included (no slicing).
        text = self.lm_tokenizer.decode(outputs[0], skip_special_tokens=True)

        return OculusTextOutput(
            text=text,
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )
590
+
591
+ def _generate_json(self, image, prompt, vision_tokens, thinking_trace, **kwargs) -> OculusJSONOutput:
592
+ """Generate structured JSON output."""
593
+ # Placeholder - would use constrained decoding
594
+ return OculusJSONOutput(
595
+ json_data={"prompt": prompt, "status": "generated"},
596
+ thinking_trace=thinking_trace,
597
+ vision_tokens=vision_tokens
598
+ )
599
+
600
+ def _generate_points(self, vision_tokens, thinking_trace, threshold=0.5, **kwargs) -> OculusPointOutput:
601
  """Generate point detections."""
 
602
  points, cls_logits, confidence = self.point_head(vision_tokens)
603
+
 
604
  mask = confidence.squeeze(-1) > threshold
605
+
606
  filtered_points = []
607
  filtered_labels = []
608
  filtered_conf = []
609
+
610
  for i in range(vision_tokens.shape[0]):
611
  token_mask = mask[i]
612
  pts = points[i][token_mask].detach().cpu().numpy().tolist()
613
  confs = confidence[i][token_mask].squeeze(-1).detach().cpu().numpy().tolist()
614
  cls_ids = cls_logits[i][token_mask].argmax(dim=-1).detach().cpu().numpy().tolist()
615
+
616
  filtered_points.extend([tuple(p) for p in pts])
617
  filtered_conf.extend(confs)
618
  filtered_labels.extend([str(c) for c in cls_ids])
619
+
620
  return OculusPointOutput(
621
  points=filtered_points,
622
  labels=filtered_labels,
 
624
  thinking_trace=thinking_trace,
625
  vision_tokens=vision_tokens
626
  )
627
+
628
+ def _generate_boxes(self, vision_tokens, thinking_trace, threshold=0.3, **kwargs) -> OculusBoxOutput:
 
 
 
 
 
 
629
  """Generate bounding box detections."""
 
630
  cls_logits, box_coords = self.detection_head(vision_tokens)
 
 
631
  confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values
632
+
633
  filtered_boxes = []
634
  filtered_labels = []
635
  filtered_conf = []
636
+
637
  for i in range(vision_tokens.shape[0]):
638
  mask = confidence[i] > threshold
639
  boxes = box_coords[i][mask].detach().cpu().numpy()
640
  confs = confidence[i][mask].detach().cpu().numpy().tolist()
641
  cls_ids = cls_logits[i][mask].argmax(dim=-1).detach().cpu().numpy().tolist()
642
+
643
  filtered_boxes.extend([tuple(b) for b in boxes])
644
  filtered_conf.extend(confs)
645
  filtered_labels.extend([str(c) for c in cls_ids])
646
+
647
  return OculusBoxOutput(
648
  boxes=filtered_boxes,
649
  labels=filtered_labels,
 
651
  thinking_trace=thinking_trace,
652
  vision_tokens=vision_tokens
653
  )
654
+
655
+ def _generate_polygons(self, vision_tokens, thinking_trace, **kwargs) -> OculusPolygonOutput:
 
 
 
 
 
656
  """Generate polygon/mask segmentation."""
 
657
  mask_logits = self.segmentation_head(vision_tokens)
 
 
658
  mask = mask_logits.argmax(dim=1).detach().cpu().numpy()
659
+
 
 
660
  polygons = []
661
  labels = []
662
+
663
  unique_classes = np.unique(mask[0])
664
  for cls_id in unique_classes:
665
+ if cls_id == 0:
666
  continue
667
  labels.append(str(cls_id))
 
668
  polygons.append([(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (0.0, 1.0)])
669
+
670
  return OculusPolygonOutput(
671
  polygons=polygons,
672
  labels=labels,
 
674
  thinking_trace=thinking_trace,
675
  vision_tokens=vision_tokens
676
  )
677
+
678
    def _generate_ocr(self, vision_tokens, thinking_trace, **kwargs) -> OculusOCROutput:
        """Detect text regions with the OCR head.

        Only region geometry/confidence is produced; text recognition is not
        implemented, so every block's text is the placeholder "[detected]".
        """
        # (batch, tokens, 5) raw values per token: x, y, w, h, confidence.
        detections = self.ocr_head(vision_tokens)

        text_blocks = []
        # Only the first batch element is inspected.
        for i in range(detections.shape[1]):
            det = detections[0, i].detach().cpu().numpy()
            # NOTE(review): det[4] is a raw linear output (no sigmoid) compared
            # against ocr_confidence_threshold — confirm threshold semantics.
            if det[4] > self.config.ocr_confidence_threshold:
                text_blocks.append({
                    "text": "[detected]",
                    "bbox": det[:4].tolist(),
                    "confidence": float(det[4])
                })

        return OculusOCROutput(
            text_blocks=text_blocks,
            full_text=" ".join([b["text"] for b in text_blocks]),
            thinking_trace=thinking_trace,
            vision_tokens=vision_tokens
        )
698
+
699
+ def _generate_ui(self, vision_tokens, thinking_trace, threshold=0.5, **kwargs) -> OculusUIOutput:
700
+ """Generate UI element detections."""
701
+ cls_logits, box_coords = self.ui_head(vision_tokens)
702
+ confidence = F.softmax(cls_logits, dim=-1).max(dim=-1).values
703
+
704
+ UI_TYPES = ["button", "text_field", "checkbox", "radio", "dropdown", "link", "image", "icon", "label", "container"]
705
+
706
+ elements = []
707
+ for i in range(vision_tokens.shape[1]):
708
+ if confidence[0, i] > threshold:
709
+ cls_id = cls_logits[0, i].argmax().item()
710
+ elements.append({
711
+ "type": UI_TYPES[cls_id % len(UI_TYPES)],
712
+ "bbox": box_coords[0, i].detach().cpu().numpy().tolist(),
713
+ "confidence": float(confidence[0, i])
714
+ })
715
+
716
+ return OculusUIOutput(
717
+ elements=elements,
718
+ thinking_trace=thinking_trace,
719
+ vision_tokens=vision_tokens
720
+ )
721
+
722
+ # Convenience methods
723
+ def ask(self, image, question: str, think: bool = False, focus: bool = False) -> str:
724
+ """Ask a question about an image."""
725
+ output = self.generate(image, question, mode="text", think=think, focus=focus)
726
+ return output.text
727
+
728
+ def caption(self, image) -> str:
729
+ """Generate a caption for an image."""
730
+ output = self.generate(image, "Describe this image", mode="text")
731
+ return output.text
732
+
733
+ def detect(self, image) -> List[Dict]:
734
+ """Detect objects in an image."""
735
+ output = self.generate(image, mode="box")
736
+ return [{"label": l, "box": b, "confidence": c}
737
+ for l, b, c in zip(output.labels, output.boxes, output.confidences)]
738
+
739
+ def segment(self, image) -> np.ndarray:
740
+ """Segment an image."""
741
+ output = self.generate(image, mode="polygon")
742
+ return output.mask
743
+
744
+ def ocr(self, image) -> str:
745
+ """Extract text from an image."""
746
+ output = self.generate(image, mode="ocr")
747
+ return output.full_text
748
+
749
+ def detect_ui(self, image) -> List[Dict]:
750
+ """Detect UI elements in a screenshot."""
751
+ output = self.generate(image, mode="ui")
752
+ return output.elements
753
+
754
  @classmethod
755
  def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
756
+ """Load model from pretrained weights."""
 
 
 
 
 
757
  path = Path(pretrained_model_name_or_path)
758
+
 
759
  config_path = path / "config.json"
760
  if config_path.exists():
 
761
  with open(config_path) as f:
762
+ config_dict = json.load(f)
763
+ config = OculusConfig(**config_dict)
 
 
 
 
 
 
 
 
764
  else:
765
  config = OculusConfig()
766
+
 
767
  model = cls(config)
768
+
769
+ # Load trained components
770
+ projector_path = path / "trained_components" / "projector.npz"
771
  if projector_path.exists():
772
+ model.projector = OculusProjector.from_pretrained(path / "trained_components", config)
773
+
774
+ heads_path = path / "trained_components" / "heads.pth"
 
775
  if heads_path.exists():
776
  heads_state = torch.load(heads_path, map_location="cpu")
777
  model.detection_head.load_state_dict(heads_state.get("detection", {}), strict=False)
778
  model.point_head.load_state_dict(heads_state.get("point", {}), strict=False)
779
  model.segmentation_head.load_state_dict(heads_state.get("segmentation", {}), strict=False)
780
+ model.ocr_head.load_state_dict(heads_state.get("ocr", {}), strict=False)
781
+ model.ui_head.load_state_dict(heads_state.get("ui", {}), strict=False)
782
+ print(f" ✓ Loaded heads from {heads_path}")
783
+
784
  return model
785
+
786
  def save_pretrained(self, save_directory: str):
787
  """Save model to directory."""
788
  path = Path(save_directory)
789
  path.mkdir(parents=True, exist_ok=True)
790
+
 
791
  self.config.save_pretrained(path)
792
+
793
  # Save projector
794
+ trained_path = path / "trained_components"
795
+ trained_path.mkdir(exist_ok=True)
796
+
797
  projector_state = self.projector.state_dict()
 
798
  np_weights = {}
799
  for k, v in projector_state.items():
800
  parts = k.split(".")
 
803
  if layer not in np_weights:
804
  np_weights[layer] = {}
805
  np_weights[layer][param] = v.cpu().numpy()
806
+ np.savez(trained_path / "projector.npz", **{k: v for k, v in np_weights.items()})
807
+
808
  # Save heads
809
  torch.save({
810
  "detection": self.detection_head.state_dict(),
811
  "point": self.point_head.state_dict(),
812
  "segmentation": self.segmentation_head.state_dict(),
813
+ "ocr": self.ocr_head.state_dict(),
814
+ "ui": self.ui_head.state_dict(),
815
+ }, trained_path / "heads.pth")
816
+
817
  print(f"✓ Saved model to {path}")
818
 
819
 
 
820
  OculusForConditionalGeneration.register_for_auto_class("AutoModelForVision2Seq")