| """ |
| Visual and Text Backbone Encoders for MR-JEPA. |
| |
| Visual: DINOv3-L/16 via timm (primary), DINOv2-L/14 via transformers (ablation) |
| Text: Qwen3-Embedding-0.6B (1024-dim causal LM used as encoder) |
| |
| Both backbones are frozen in Phase 1 and partially unfrozen in Phase 2. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional, Dict, Any |
|
|
| from ..configs.model_config import VisualBackboneConfig, TextEncoderConfig |
|
|
|
|
class VisualBackbone(nn.Module):
    """
    Dense visual feature extractor using DINOv3/v2 or SigLIP2.

    Outputs patch-level tokens (excluding CLS and register tokens).
    DINOv3-L at 256px: 256 patch tokens × 1024 dim.
    DINOv2-L at 518px: 1369 patch tokens × 1024 dim.

    Supported ``config.backbone_type`` values:
      - ``"dinov3"``: loaded with ``timm.create_model``.
      - ``"dinov2"``: loaded with ``transformers.AutoModel``.
      - ``"siglip2"``: loaded with ``transformers.SiglipVisionModel``; no
        prefix tokens are stripped for this family.
    Any other value raises ``ValueError`` at construction time.
    """

    def __init__(self, config: "VisualBackboneConfig"):
        super().__init__()
        self.config = config
        self.backbone = None
        # Only populated for transformers-based backbones; the timm path
        # builds its transform on demand in ``get_transform``.
        self.processor = None
        self.hidden_size = config.hidden_size
        self._use_timm = False
        # Number of leading non-patch tokens (CLS + registers) that precede
        # the patch tokens in the backbone's output sequence.
        self._skip_tokens = 0
        self._build_backbone()

        if config.freeze:
            self.freeze_all()

    def _build_backbone(self):
        """Load the pretrained backbone selected by ``config.backbone_type``.

        Sets ``self.backbone``, ``self._use_timm``, ``self._skip_tokens`` and,
        for transformers-based backbones, ``self.processor``.

        Raises:
            ValueError: if ``config.backbone_type`` is not ``"dinov3"``,
                ``"dinov2"`` or ``"siglip2"``.
        """
        backbone_type = self.config.backbone_type

        if backbone_type == "dinov3":
            import timm

            timm_id = self.config.model_name
            # Accept hub-style ids ("timm/<name>") as well as bare timm names.
            if timm_id.startswith("timm/"):
                timm_id = timm_id[len("timm/"):]
            # num_classes=0 removes the classification head.
            self.backbone = timm.create_model(
                timm_id,
                pretrained=True,
                num_classes=0,
            )
            self._use_timm = True
            # Prefer the model's own prefix-token count so a misconfigured
            # ``num_register_tokens`` cannot shift the patch grid; fall back
            # to the config value when the attribute is not present.
            self._skip_tokens = getattr(
                self.backbone,
                "num_prefix_tokens",
                1 + self.config.num_register_tokens,
            )

        elif backbone_type == "dinov2":
            from transformers import AutoModel, AutoImageProcessor

            self.backbone = AutoModel.from_pretrained(
                self.config.model_name,
                torch_dtype=torch.float32,
            )
            self.processor = AutoImageProcessor.from_pretrained(
                self.config.model_name
            )
            self._use_timm = False
            # One CLS token plus the configured number of register tokens.
            self._skip_tokens = 1 + self.config.num_register_tokens

        elif backbone_type == "siglip2":
            from transformers import SiglipVisionModel, SiglipImageProcessor

            self.backbone = SiglipVisionModel.from_pretrained(
                self.config.model_name,
                torch_dtype=torch.float32,
            )
            self.processor = SiglipImageProcessor.from_pretrained(
                self.config.model_name
            )
            self._use_timm = False
            # The whole sequence is patch tokens.
            self._skip_tokens = 0

        else:
            # Fail fast: previously an unknown type left backbone=None and
            # crashed later with an unrelated AttributeError.
            raise ValueError(
                f"Unsupported visual backbone_type {backbone_type!r}; "
                "expected 'dinov3', 'dinov2' or 'siglip2'."
            )

    def freeze_all(self):
        """Freeze all backbone parameters (``requires_grad = False``)."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def _transformer_layers(self):
        """Return the backbone's ordered container of transformer blocks.

        Raises:
            ValueError: if none of the known layouts matches.
        """
        if self._use_timm:
            return self.backbone.blocks
        if hasattr(self.backbone, 'encoder'):
            # transformers ViT-style layout: encoder.layer
            return self.backbone.encoder.layer
        if hasattr(self.backbone, 'vision_model'):
            # SigLIP vision tower layout: vision_model.encoder.layers
            return self.backbone.vision_model.encoder.layers
        raise ValueError(f"Unknown backbone structure for {self.config.model_name}")

    def unfreeze_last_n_layers(self, n: int):
        """Unfreeze the last N transformer layers (Phase 2).

        Only parameters inside the selected blocks are changed; every other
        parameter keeps its current ``requires_grad``. ``n <= 0`` is a no-op
        and ``n`` larger than the depth unfreezes every block.

        Raises:
            ValueError: if the backbone's layer stack cannot be located.
        """
        layers = list(self._transformer_layers())
        start = max(len(layers) - n, 0)
        for layer in layers[start:]:
            for param in layer.parameters():
                param.requires_grad = True

    def get_transform(self):
        """Return the image preprocessing transform for this backbone.

        For timm backbones this is an evaluation transform built from the
        model's pretrained data config. Otherwise it is the Hugging Face image
        processor loaded together with the model.
        """
        if self._use_timm:
            from timm.data import resolve_data_config, create_transform

            # The first positional parameter of resolve_data_config is
            # ``args`` (user overrides); pass the model by keyword so timm
            # reads the pretrained config from it.
            data_cfg = resolve_data_config({}, model=self.backbone)
            return create_transform(**data_cfg, is_training=False)
        return self.processor

    def forward(
        self,
        pixel_values: torch.Tensor,
        return_cls: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Extract dense patch tokens from images.

        Args:
            pixel_values: preprocessed image batch, passed straight to the
                backbone.
            return_cls: also return a global image embedding.

        Returns:
            dict with:
                'patch_tokens': [B, num_patches, hidden_size]
                'cls_token': [B, hidden_size] (if return_cls=True). This is
                    the first token when the backbone emits prefix tokens.
                    Without prefix tokens (SigLIP2) it is the model's
                    ``pooler_output`` when available, else the mean of the
                    patch tokens.
        """
        outputs = None
        if self._use_timm:
            hidden_states = self.backbone.forward_features(pixel_values)
        else:
            outputs = self.backbone(pixel_values=pixel_values)
            hidden_states = outputs.last_hidden_state

        patch_tokens = hidden_states[:, self._skip_tokens:]
        result = {'patch_tokens': patch_tokens}

        if return_cls:
            if self._skip_tokens > 0:
                # Leading prefix token is the CLS token.
                cls_token = hidden_states[:, 0]
            else:
                # No CLS in the sequence: index 0 would be a patch token.
                cls_token = getattr(outputs, 'pooler_output', None)
                if cls_token is None:
                    cls_token = patch_tokens.mean(dim=1)
            result['cls_token'] = cls_token

        return result
|
|
|
|
class TextEncoder(nn.Module):
    """
    Text encoder for questions, options, and optional context.

    Uses Qwen3-Embedding-0.6B (causal LM architecture as encoder).
    Outputs:
        - Token-level representations for cross-attention
        - Mean-pooled representation for global text understanding
    """

    def __init__(self, config: "TextEncoderConfig"):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self._build_encoder()

        if config.freeze:
            self.freeze_all()

    def _build_encoder(self):
        """Load the pretrained text encoder (bfloat16 weights) and its tokenizer."""
        from transformers import AutoModel, AutoTokenizer

        # NOTE(review): trust_remote_code allows code shipped with the
        # checkpoint to run locally; keep config.model_name pointed at a
        # trusted repository.
        self.encoder = AutoModel.from_pretrained(
            self.config.model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
        )

    def freeze_all(self):
        """Freeze all text-encoder parameters (``requires_grad = False``)."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def _transformer_layers(self):
        """Return the encoder's ordered container of transformer layers.

        Recognized layouts:
          - ``encoder.model.layers``: a wrapper model around a decoder stack.
          - ``encoder.layers``: a bare decoder stack, e.g. the base model
            ``AutoModel`` loads for decoder-only checkpoints.
          - ``encoder.encoder.layer``: BERT-style encoders.

        Raises:
            ValueError: if none of the known layouts matches.
        """
        if hasattr(self.encoder, 'model') and hasattr(self.encoder.model, 'layers'):
            return self.encoder.model.layers
        if hasattr(self.encoder, 'layers'):
            return self.encoder.layers
        if hasattr(self.encoder, 'encoder') and hasattr(self.encoder.encoder, 'layer'):
            return self.encoder.encoder.layer
        raise ValueError(f"Unknown encoder structure for {self.config.model_name}")

    def unfreeze_last_n_layers(self, n: int):
        """Unfreeze the last N transformer layers (Phase 2).

        Only parameters inside the selected layers are changed; every other
        parameter keeps its current ``requires_grad``. ``n <= 0`` is a no-op
        and ``n`` larger than the depth unfreezes every layer.

        Raises:
            ValueError: if the encoder's layer stack cannot be located.
        """
        layers = list(self._transformer_layers())
        start = max(len(layers) - n, 0)
        for layer in layers[start:]:
            for param in layer.parameters():
                param.requires_grad = True

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Encode text (question + options).

        Args:
            input_ids: token ids, [B, seq_len].
            attention_mask: nonzero for tokens to keep, 0 for padding,
                [B, seq_len].

        Returns:
            dict with:
                'token_embeddings': [B, seq_len, hidden_size], in the
                    encoder's output dtype.
                'cls_embedding': [B, hidden_size], masked mean over the
                    non-padding tokens (float32, because the mask is cast to
                    float before pooling).
                'attention_mask': [B, seq_len], passed through unchanged.
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_embs = outputs.last_hidden_state

        # Masked mean pooling. Padding positions contribute nothing, and the
        # clamp keeps an all-padding row at zero instead of dividing by zero.
        # NOTE(review): Qwen3-Embedding's published usage pools the last
        # non-padding token rather than the mean; confirm mean pooling is
        # intended here.
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (token_embs * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

        return {
            'token_embeddings': token_embs,
            'cls_embedding': pooled,
            'attention_mask': attention_mask,
        }
|
|