# MR-JEPA / mr_jepa / models / backbones.py
# JorgeAV
# fix: backbones.py — DINOv3 via timm + Qwen3-Embedding-0.6B, add get_transform(), proper layer unfreezing for both architectures
# ad6ce96 (verified)
"""
Visual and Text Backbone Encoders for MR-JEPA.
Visual: DINOv3-L/16 via timm (primary), DINOv2-L/14 via transformers (ablation)
Text: Qwen3-Embedding-0.6B (1024-dim causal LM used as encoder)
Both backbones are frozen in Phase 1 and partially unfrozen in Phase 2.
"""
import torch
import torch.nn as nn
from typing import Optional, Dict, Any
from ..configs.model_config import VisualBackboneConfig, TextEncoderConfig
class VisualBackbone(nn.Module):
    """
    Dense visual feature extractor using DINOv3/v2 or SigLIP2.

    Outputs patch-level tokens (excluding CLS and register tokens).
    DINOv3-L at 256px: 256 patch tokens × 1024 dim.
    DINOv2-L at 518px: 1369 patch tokens × 1024 dim.

    Supported ``config.backbone_type`` values are ``"dinov3"`` (loaded via
    timm), ``"dinov2"`` and ``"siglip2"`` (both loaded via transformers).
    Any other value raises ``ValueError`` while the module is constructed.
    """

    def __init__(self, config: "VisualBackboneConfig"):
        """
        Build the backbone described by ``config`` and optionally freeze it.

        Args:
            config: Backbone settings. The fields read here are
                ``backbone_type``, ``model_name``, ``num_register_tokens``,
                ``hidden_size`` and ``freeze``.

        Raises:
            ValueError: If ``config.backbone_type`` is not supported.
        """
        super().__init__()
        self.config = config
        self.backbone = None
        # Only the transformers-based backbones ship an image processor.
        self.processor = None
        self.hidden_size = config.hidden_size
        self._use_timm = False
        # Number of leading non-patch tokens (CLS + registers) to drop.
        self._skip_tokens = 0
        self._build_backbone()
        if config.freeze:
            self.freeze_all()

    def _build_backbone(self):
        """
        Load the pretrained backbone selected by ``config.backbone_type``.

        Sets ``self.backbone``, ``self._use_timm``, ``self._skip_tokens`` and,
        for transformers-based models, ``self.processor``.

        Raises:
            ValueError: If ``config.backbone_type`` is not one of
                ``"dinov3"``, ``"dinov2"`` or ``"siglip2"``.
        """
        backbone_type = self.config.backbone_type

        if backbone_type == "dinov3":
            # DINOv3 loaded via timm
            import timm

            # model_name format: "timm/vit_large_patch16_dinov3.lvd1689m"
            # → strip the hub prefix to obtain the timm model id.
            prefix = "timm/"
            timm_id = self.config.model_name
            if timm_id.startswith(prefix):
                timm_id = timm_id[len(prefix):]
            self.backbone = timm.create_model(
                timm_id,
                pretrained=True,
                num_classes=0,
            )
            self._use_timm = True
            # DINOv3 via timm: forward_features returns
            # [B, CLS + 4_reg + patches, D]
            self._skip_tokens = 1 + self.config.num_register_tokens  # CLS + regs
        elif backbone_type == "dinov2":
            # DINOv2 loaded via transformers
            from transformers import AutoModel, AutoImageProcessor

            self.backbone = AutoModel.from_pretrained(
                self.config.model_name,
                torch_dtype=torch.float32,
            )
            self.processor = AutoImageProcessor.from_pretrained(
                self.config.model_name
            )
            self._use_timm = False
            self._skip_tokens = 1 + self.config.num_register_tokens
        elif backbone_type == "siglip2":
            from transformers import SiglipVisionModel, SiglipImageProcessor

            self.backbone = SiglipVisionModel.from_pretrained(
                self.config.model_name,
                torch_dtype=torch.float32,
            )
            self.processor = SiglipImageProcessor.from_pretrained(
                self.config.model_name
            )
            self._use_timm = False
            self._skip_tokens = 0  # SigLIP has no CLS or register tokens
        else:
            # Fail here instead of leaving ``self.backbone`` as None, which
            # would only surface later as an opaque AttributeError.
            raise ValueError(
                f"Unsupported backbone_type {backbone_type!r}; "
                "expected one of 'dinov3', 'dinov2', 'siglip2'"
            )

    def freeze_all(self):
        """Freeze all backbone parameters."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n: int):
        """
        Unfreeze the last N transformer layers (Phase 2).

        Args:
            n: Number of trailing layers whose parameters become trainable.
                Values ``<= 0`` change nothing; values larger than the number
                of layers unfreeze every layer.

        Raises:
            ValueError: If the layer stack cannot be located on the backbone.
        """
        if self._use_timm:
            # timm ViT: backbone.blocks holds the sequence of Block modules
            layers = self.backbone.blocks
        elif hasattr(self.backbone, 'encoder'):
            # DINOv2 via transformers
            layers = self.backbone.encoder.layer
        elif hasattr(self.backbone, 'vision_model'):
            # SigLIP
            layers = self.backbone.vision_model.encoder.layers
        else:
            raise ValueError(f"Unknown backbone structure for {self.config.model_name}")

        first_trainable = len(layers) - n
        for index, layer in enumerate(layers):
            if index < first_trainable:
                continue
            for param in layer.parameters():
                param.requires_grad = True

    def get_transform(self):
        """
        Return image preprocessing transform.

        For timm backbones this is an evaluation-mode transform built from the
        model's ``pretrained_cfg``; for transformers backbones it is the image
        processor that was loaded together with the model.
        """
        if self._use_timm:
            from timm.data import resolve_data_config, create_transform

            data_cfg = resolve_data_config(self.backbone.pretrained_cfg)
            return create_transform(**data_cfg, is_training=False)
        return self.processor

    def _global_token(self, hidden_states, outputs):
        """Return one global [B, D] feature: CLS if present, else a pooled one."""
        if self._skip_tokens > 0:
            # The sequence starts with a CLS token (DINOv2 / DINOv3).
            return hidden_states[:, 0]
        # No prefix tokens (SigLIP): index 0 is an ordinary patch token, so
        # use the model's pooled output, or average the patches without one.
        pooled = getattr(outputs, 'pooler_output', None)
        if pooled is not None:
            return pooled
        return hidden_states.mean(dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,  # [B, C, H, W]
        return_cls: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Extract dense patch tokens from images.

        Args:
            pixel_values: Preprocessed image batch.
            return_cls: Also return a global image token.

        Returns:
            dict with:
                'patch_tokens': [B, num_patches, hidden_size]
                'cls_token': [B, hidden_size] (if return_cls=True). For
                    backbones without a CLS token this is the pooled output
                    (or the mean patch token when none is provided).
        """
        if self._use_timm:
            outputs = None
            hidden_states = self.backbone.forward_features(pixel_values)  # [B, N, D]
        else:
            outputs = self.backbone(pixel_values=pixel_values)
            hidden_states = outputs.last_hidden_state

        result = {'patch_tokens': hidden_states[:, self._skip_tokens:]}
        if return_cls:
            result['cls_token'] = self._global_token(hidden_states, outputs)
        return result
class TextEncoder(nn.Module):
    """
    Text encoder for questions, options, and optional context.

    Uses Qwen3-Embedding-0.6B (causal LM architecture as encoder).

    Outputs:
    - Token-level representations for cross-attention
    - Mean-pooled representation for global text understanding
    """

    def __init__(self, config: "TextEncoderConfig"):
        """
        Load the encoder and tokenizer named by ``config`` and optionally freeze.

        Args:
            config: Encoder settings. The fields read here are ``model_name``,
                ``hidden_size`` and ``freeze``.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self._build_encoder()
        if config.freeze:
            self.freeze_all()

    def _build_encoder(self):
        """Load pretrained text encoder (bfloat16 weights) and its tokenizer."""
        from transformers import AutoModel, AutoTokenizer

        self.encoder = AutoModel.from_pretrained(
            self.config.model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.config.model_name,
            trust_remote_code=True,
        )

    def freeze_all(self):
        """Freeze all encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def _transformer_layers(self):
        """Locate the encoder's stack of transformer layers."""
        # Attribute paths, tried in order:
        #   model.layers  - wrapper models that nest the base model under .model
        #   layers        - base decoder model as returned by AutoModel (the
        #                   Qwen3 base model exposes its blocks directly)
        #   encoder.layer - BERT/DeBERTa-style models
        candidate_paths = (
            ('model', 'layers'),
            ('layers',),
            ('encoder', 'layer'),
        )
        for path in candidate_paths:
            node = self.encoder
            for attr in path:
                node = getattr(node, attr, None)
                if node is None:
                    break
            else:
                return node
        raise ValueError(f"Unknown encoder structure for {self.config.model_name}")

    def unfreeze_last_n_layers(self, n: int):
        """
        Unfreeze the last N transformer layers (Phase 2).

        Args:
            n: Number of trailing layers whose parameters become trainable.
                Values ``<= 0`` change nothing; values larger than the number
                of layers unfreeze every layer.

        Raises:
            ValueError: If the layer stack cannot be located on the encoder.
        """
        layers = self._transformer_layers()
        first_trainable = len(layers) - n
        for index, layer in enumerate(layers):
            if index < first_trainable:
                continue
            for param in layer.parameters():
                param.requires_grad = True

    def forward(
        self,
        input_ids: torch.Tensor,  # [B, seq_len]
        attention_mask: torch.Tensor,  # [B, seq_len]
    ) -> Dict[str, torch.Tensor]:
        """
        Encode text (question + options).

        Args:
            input_ids: Token ids.
            attention_mask: 1 for real tokens, 0 for padding.

        Returns:
            dict with:
                'token_embeddings': [B, seq_len, hidden_size]
                'cls_embedding': [B, hidden_size] (mean pooling for Qwen3-Embedding)
                'attention_mask': [B, seq_len]
        """
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        token_embs = outputs.last_hidden_state

        # Mean pooling over non-padding tokens (Qwen3-Embedding has no CLS token).
        # The mask is cast to float32, so the pooled sum is computed in at
        # least float32 even when the token embeddings are lower precision.
        mask_expanded = attention_mask.unsqueeze(-1).float()
        summed = (token_embs * mask_expanded).sum(1)
        # clamp guards against division by zero for all-padding rows.
        token_counts = mask_expanded.sum(1).clamp(min=1)
        pooled = summed / token_counts

        return {
            'token_embeddings': token_embs,
            'cls_embedding': pooled,
            'attention_mask': attention_mask,
        }