"""
Visual and Text Backbone Encoders for MR-JEPA.

Visual: DINOv3-L/16 via timm (primary), DINOv2-L/14 via transformers (ablation),
        SigLIP vision tower via transformers (ablation).
Text:   Qwen3-Embedding-0.6B (1024-dim causal LM used as encoder).

Both backbones are frozen in Phase 1 and partially unfrozen in Phase 2.
"""

import torch
import torch.nn as nn
from typing import Optional, Dict, Any

from ..configs.model_config import VisualBackboneConfig, TextEncoderConfig


def _unfreeze_last_n(layers, n: int) -> None:
    """Set ``requires_grad = True`` on all parameters of the last ``n`` entries of ``layers``.

    ``layers`` is any indexable/iterable container of modules with ``len()``
    (e.g. ``nn.ModuleList`` / ``nn.Sequential``). ``n <= 0`` unfreezes nothing;
    ``n >= len(layers)`` unfreezes every layer. Layers outside the tail are not
    touched (they are not re-frozen).
    """
    first_unfrozen = len(layers) - n
    for i, layer in enumerate(layers):
        if i >= first_unfrozen:
            for param in layer.parameters():
                param.requires_grad = True


class VisualBackbone(nn.Module):
    """
    Dense visual feature extractor using DINOv3/v2 or SigLIP2.

    Outputs patch-level tokens (excluding CLS and register tokens).
    DINOv3-L at 256px: 256 patch tokens × 1024 dim.
    DINOv2-L at 518px: 1369 patch tokens × 1024 dim.
    """

    def __init__(self, config: VisualBackboneConfig):
        super().__init__()
        self.config = config
        self.backbone = None
        # HF image processor; stays None for timm backbones (see get_transform).
        self.processor = None
        self.hidden_size = config.hidden_size
        self._use_timm = False
        # Number of leading non-patch tokens (CLS + registers) in the backbone output.
        self._skip_tokens = 0
        # Whether position 0 of the backbone output is a genuine CLS token.
        self._has_cls = False
        self._build_backbone()

        if config.freeze:
            self.freeze_all()

    def _build_backbone(self):
        """Load the pretrained backbone selected by ``config.backbone_type``.

        Sets ``self.backbone``, ``self._use_timm``, ``self._skip_tokens`` and
        ``self._has_cls``; also ``self.processor`` for transformers backbones.

        Raises:
            ValueError: if ``config.backbone_type`` is not ``"dinov3"``,
                ``"dinov2"`` or ``"siglip2"``.
        """
        backbone_type = self.config.backbone_type

        if backbone_type == "dinov3":
            # DINOv3 loaded via timm
            import timm

            # model_name format: "timm/vit_large_patch16_dinov3.lvd1689m" → extract timm id
            timm_id = self.config.model_name
            if timm_id.startswith("timm/"):
                timm_id = timm_id[len("timm/"):]
            self.backbone = timm.create_model(
                timm_id,
                pretrained=True,
                num_classes=0,
            )
            self._use_timm = True
            # forward_features returns [B, num_prefix_tokens + patches, D]. Prefer the
            # model's own prefix count (CLS + registers) so the patch slice cannot drift
            # from the checkpoint actually loaded; fall back to the config value.
            self._skip_tokens = getattr(
                self.backbone,
                "num_prefix_tokens",
                1 + self.config.num_register_tokens,
            )
            self._has_cls = getattr(self.backbone, "cls_token", None) is not None

        elif backbone_type == "dinov2":
            # DINOv2 loaded via transformers
            from transformers import AutoModel, AutoImageProcessor

            self.backbone = AutoModel.from_pretrained(
                self.config.model_name,
                torch_dtype=torch.float32,
            )
            self.processor = AutoImageProcessor.from_pretrained(self.config.model_name)
            self._use_timm = False
            self._skip_tokens = 1 + self.config.num_register_tokens
            self._has_cls = True

        elif backbone_type == "siglip2":
            from transformers import SiglipVisionModel, SiglipImageProcessor

            self.backbone = SiglipVisionModel.from_pretrained(
                self.config.model_name,
                torch_dtype=torch.float32,
            )
            self.processor = SiglipImageProcessor.from_pretrained(self.config.model_name)
            self._use_timm = False
            self._skip_tokens = 0  # SigLIP has no CLS or register tokens
            self._has_cls = False

        else:
            raise ValueError(
                f"Unknown backbone_type {backbone_type!r} for {self.config.model_name}; "
                "expected 'dinov3', 'dinov2' or 'siglip2'"
            )

    def freeze_all(self):
        """Freeze all backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n: int):
        """Unfreeze the last N transformer layers (Phase 2).

        Only the transformer blocks are unfrozen; embeddings and final norms keep
        whatever ``requires_grad`` state they already had.

        Raises:
            ValueError: if the backbone's layer container cannot be located.
        """
        if self._use_timm:
            # timm ViT/EVA: backbone.blocks is a sequence (ModuleList/Sequential) of blocks
            layers = self.backbone.blocks
        elif hasattr(self.backbone, 'encoder'):
            # DINOv2 via transformers
            layers = self.backbone.encoder.layer
        elif hasattr(self.backbone, 'vision_model'):
            # SigLIP
            layers = self.backbone.vision_model.encoder.layers
        else:
            raise ValueError(f"Unknown backbone structure for {self.config.model_name}")

        _unfreeze_last_n(layers, n)

    def get_transform(self):
        """Return image preprocessing transform.

        For timm backbones this is a torchvision-style eval transform built from the
        model's pretrained config; otherwise the HF image processor.
        """
        if self._use_timm:
            from timm.data import resolve_data_config, create_transform

            # Let timm read the model's pretrained_cfg (input size, mean/std, crop, ...).
            data_cfg = resolve_data_config(model=self.backbone)
            return create_transform(**data_cfg, is_training=False)
        return self.processor

    def forward(
        self,
        pixel_values: torch.Tensor,  # [B, C, H, W]
        return_cls: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Extract dense patch tokens from images.

        Returns:
            dict with:
                'patch_tokens': [B, num_patches, hidden_size]
                'cls_token': [B, hidden_size] (if return_cls=True). For backbones
                    without a CLS token (SigLIP) this is the model's pooled output
                    when available, otherwise the mean of the patch tokens.
        """
        pooled = None
        if self._use_timm:
            hidden_states = self.backbone.forward_features(pixel_values)  # [B, N, D]
        else:
            outputs = self.backbone(pixel_values=pixel_values)
            hidden_states = outputs.last_hidden_state
            pooled = getattr(outputs, "pooler_output", None)

        patch_tokens = hidden_states[:, self._skip_tokens:]
        result = {'patch_tokens': patch_tokens}

        if return_cls:
            if self._has_cls:
                result['cls_token'] = hidden_states[:, 0]
            elif pooled is not None:
                # No CLS token (SigLIP): position 0 is a patch, so use the pooled head output.
                result['cls_token'] = pooled
            else:
                result['cls_token'] = patch_tokens.mean(dim=1)

        return result


class TextEncoder(nn.Module):
    """
    Text encoder for questions, options, and optional context.

    Uses Qwen3-Embedding-0.6B (causal LM architecture as encoder).
    Outputs:
      - Token-level representations for cross-attention
      - Mean-pooled representation for global text understanding
    """

    def __init__(self, config: TextEncoderConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self._build_encoder()

        if config.freeze:
            self.freeze_all()

    def _build_encoder(self):
        """Load pretrained text encoder (bfloat16 weights) and its tokenizer."""
        from transformers import AutoModel, AutoTokenizer

        self.encoder = AutoModel.from_pretrained(
            self.config.model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
        )

    def freeze_all(self):
        """Freeze all text-encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n: int):
        """Unfreeze the last N transformer layers (Phase 2).

        Raises:
            ValueError: if the encoder's layer container cannot be located.
        """
        encoder = self.encoder
        if hasattr(encoder, 'model') and hasattr(encoder.model, 'layers'):
            # Causal-LM wrapper (e.g. Qwen3ForCausalLM): encoder.model.layers[i]
            layers = encoder.model.layers
        elif hasattr(encoder, 'layers'):
            # Bare decoder as returned by AutoModel (e.g. Qwen3Model): encoder.layers[i]
            layers = encoder.layers
        elif hasattr(encoder, 'encoder') and hasattr(encoder.encoder, 'layer'):
            # Fallback for BERT/DeBERTa-style models
            layers = encoder.encoder.layer
        else:
            raise ValueError(f"Unknown encoder structure for {self.config.model_name}")

        _unfreeze_last_n(layers, n)

    def forward(
        self,
        input_ids: torch.Tensor,       # [B, seq_len]
        attention_mask: torch.Tensor,  # [B, seq_len]
    ) -> Dict[str, torch.Tensor]:
        """
        Encode text (question + options).

        Returns:
            dict with:
                'token_embeddings': [B, seq_len, hidden_size] (encoder dtype)
                'cls_embedding': [B, hidden_size] mean of non-padding token
                    embeddings, in float32 (see note below)
                'attention_mask': [B, seq_len] (passed through unchanged)
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_embs = outputs.last_hidden_state

        # Mean pooling over non-padding tokens (Qwen3-Embedding has no CLS token).
        # The float32 mask promotes bf16 embeddings, so the sum is accumulated in float32.
        # NOTE(review): upstream Qwen3-Embedding usage pools the last token instead of the
        # mean — confirm mean pooling is the intended design here.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        # clamp avoids division by zero for all-padding rows.
        pooled = (token_embs * mask_expanded).sum(1) / mask_expanded.sum(1).clamp(min=1)

        return {
            'token_embeddings': token_embs,
            'cls_embedding': pooled,
            'attention_mask': attention_mask,
        }