# T-REN / modeling_tren.py
# Uploaded by aryaaan12 via huggingface_hub ("Upload modeling_tren.py with
# huggingface_hub"), commit 18c1533 (verified).
"""
T-REN HuggingFace model wrapper.

Usage:
    from transformers import AutoModel

    model = AutoModel.from_pretrained("savyak2/T-REN", trust_remote_code=True)
    model.load_backbone("/path/to/dinov3/weights/")

    # Or in one shot:
    model = AutoModel.from_pretrained(
        "savyak2/T-REN",
        trust_remote_code=True,
        dinov3_weights_dir="/path/to/dinov3/weights/",
    )

    outputs = model(pixel_values)  # pixel_values: (B, 3, H, W) float in [0, 1]

The weights directory must contain both DINOv3 checkpoint files listed in the
TRENModel docstring.
"""
import numpy as np
import torch
from transformers import PreTrainedModel
from transformers.utils import logging
try:
from .configuration_tren import TRENConfig
from .model import FeatureExtractor, RegionEncoder, TextEncoder
except ImportError:
from configuration_tren import TRENConfig
from model import FeatureExtractor, RegionEncoder, TextEncoder
# Module-level logger obtained through transformers' logging utilities.
logger = logging.get_logger(__name__)
# Expected DINOv3 checkpoint filenames inside the weights directory; the same
# two names are listed in TRENModel's docstring.
# NOTE(review): neither constant is referenced in the code shown here —
# presumably kept for external callers or the backbone loaders; confirm
# before removing.
DINOV3_BACKBONE_FILENAME = "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"
DINOV3_HEAD_FILENAME = "dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth"
def _build_cfg_dict(config: TRENConfig, dinov3_weights_dir: str = None) -> dict:
    """
    Translate a TRENConfig into the nested dict layout the model classes consume.

    Args:
        config: Source of the architecture and inference hyper-parameters.
        dinov3_weights_dir: Directory holding the DINOv3 checkpoints. Any falsy
            value (including the default None) becomes "".

    Returns:
        A dict with "pretrained", "architecture", "parameters" and "logging"
        sections, each copied field-for-field from ``config``.
    """
    architecture_fields = (
        "patch_size",
        "hidden_dim",
        "text_embed_dim",
        "num_decoder_layers",
        "num_attention_heads",
    )
    parameter_fields = (
        "image_resolution",
        "num_multiscale_regions",
        "merging_iou_threshold",
        "merging_similarity_threshold",
    )
    backbone_name = "dinov3_vitl16"
    return {
        "pretrained": {
            "feature_extractor": backbone_name,
            "text_encoder": backbone_name,
        },
        "architecture": {field: getattr(config, field) for field in architecture_fields},
        "parameters": {field: getattr(config, field) for field in parameter_fields},
        # The model classes join save_dir and exp_name to locate the DINOv3
        # weights; an empty exp_name makes save_dir itself that directory, e.g.
        # os.path.join("/path/to/dir", "", "filename.pth") -> "/path/to/dir/filename.pth"
        "logging": {
            "save_dir": dinov3_weights_dir or "",
            "exp_name": "",
        },
    }
class TRENModel(PreTrainedModel):
    """
    T-REN: Text-aligned Region Encoder Network.

    Takes raw images and returns dense region tokens aligned to a shared
    vision-language embedding space (DINOv3 / DINOtxt).

    The trainable RegionEncoder weights are stored in this HF repo and loaded
    automatically. The DINOv3 ViT-L/16 backbone (~2 GB) must be provided
    separately via load_backbone().

    DINOv3 weights needed in the same directory:
        - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
        - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth

    Note: the DINOv3 backbones are deliberately kept out of the nn.Module
    hierarchy, so ``model.to(...)``, ``model.half()`` and save/load only
    affect the RegionEncoder.
    """

    config_class = TRENConfig
    base_model_prefix = "region_encoder"

    def __init__(self, config: TRENConfig, dinov3_weights_dir: str = None):
        """
        Build the RegionEncoder and, optionally, load the DINOv3 backbones.

        Args:
            config: Model configuration.
            dinov3_weights_dir: Optional directory with the DINOv3 weight files.
                If the RegionEncoder is being constructed on the meta device
                (as from_pretrained() may do before real weights are loaded),
                backbone loading is deferred to the first forward() call;
                otherwise it happens immediately.
        """
        super().__init__(config)
        cfg = _build_cfg_dict(config)
        # RegionEncoder: the trained T-REN head. HF saves/loads these weights.
        self.region_encoder = RegionEncoder(cfg)
        # Lazy placeholders, set via object.__setattr__ so they are NOT
        # registered as nn.Module submodules and are therefore excluded from
        # HF save/load. _grid_points is computed on the first forward() call
        # to avoid meta-device issues during from_pretrained().
        object.__setattr__(self, "_grid_points", None)
        object.__setattr__(self, "_image_encoder", None)
        object.__setattr__(self, "_text_encoder", None)
        # Weights directory whose loading was deferred (meta-device construction).
        object.__setattr__(self, "_pending_dinov3_weights_dir", None)
        self.post_init()
        if dinov3_weights_dir is not None:
            if self._region_encoder_device().type == "meta":
                # Backbones built with device="meta" would have no real data to
                # compute with; wait until the RegionEncoder is materialized.
                object.__setattr__(self, "_pending_dinov3_weights_dir", dinov3_weights_dir)
            else:
                self.load_backbone(dinov3_weights_dir)

    def _region_encoder_device(self) -> torch.device:
        """Return the device of the RegionEncoder's first parameter."""
        return next(self.region_encoder.parameters()).device

    def _get_grid_points(self) -> torch.Tensor:
        """
        Return the cached integer tensor of (y, x) prompt coordinates, building
        it on first use.

        Along each axis there are ``image_resolution // patch_size`` points
        spread evenly from pixel 1 to pixel ``image_resolution - 2`` (rounded
        down to int); the result has shape (n * n, 2), enumerated row by row
        (y outer, x inner).
        """
        if self._grid_points is None:
            res = self.config.image_resolution
            ps = self.config.patch_size
            coords = np.linspace(1, res - 2, res // ps, dtype=int)
            object.__setattr__(
                self,
                "_grid_points",
                torch.tensor([(y, x) for y in coords for x in coords]),
            )
        return self._grid_points

    def load_backbone(self, dinov3_weights_dir: str) -> None:
        """
        Load the frozen DINOv3 image and text encoder backbones.

        Both encoders are constructed on the device currently holding the
        RegionEncoder parameters and switched to eval mode. They are not
        registered as submodules, so later ``model.to(...)`` calls do not move
        them.

        Args:
            dinov3_weights_dir: Directory containing both DINOv3 weight files:
                - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
                - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
        """
        device = str(self._region_encoder_device())
        cfg = _build_cfg_dict(self.config, dinov3_weights_dir)
        logger.info("Loading DINOv3 image encoder...")
        image_encoder = FeatureExtractor(cfg, device=device).eval()
        logger.info("Loading DINOv3 text encoder...")
        text_encoder = TextEncoder(cfg, device=device).eval()
        object.__setattr__(self, "_image_encoder", image_encoder)
        object.__setattr__(self, "_text_encoder", text_encoder)
        # Any deferred load request is now satisfied.
        object.__setattr__(self, "_pending_dinov3_weights_dir", None)

    def adapt_to_resolution(self, image_resolution: int) -> None:
        """
        Interpolate the RegionEncoder's positional embeddings to a new spatial
        resolution. Call this after from_pretrained() when running inference at
        a resolution different from the training resolution (512px by default).

        ``config.image_resolution`` is updated as well, so the prompt grid built
        by forward() (and any later save_pretrained()) matches the new
        embeddings. The new embedding tensor keeps the dtype and requires_grad
        flag of the old one.

        Args:
            image_resolution: Target image resolution in pixels (e.g. 384).

        Raises:
            ValueError: If image_resolution is smaller than one patch.

        Example::

            model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True)
            model.load_backbone("/path/to/dinov3/weights/")
            model.adapt_to_resolution(384)  # eval at 384px instead of 512px
        """
        old_resolution = self.config.image_resolution
        if image_resolution == old_resolution:
            return
        ps = self.config.patch_size
        if image_resolution < ps:
            raise ValueError(
                f"image_resolution ({image_resolution}) must be at least "
                f"patch_size ({ps})."
            )
        saved_state = self.region_encoder.state_dict()
        device = self._region_encoder_device()
        old_embeddings = self.region_encoder.feature_embeddings
        num_patches = (image_resolution // ps) ** 2
        # Correctly-sized placeholder; preserving dtype / requires_grad keeps a
        # half-precision or frozen model in that state.
        self.region_encoder.feature_embeddings = torch.nn.Parameter(
            torch.zeros(
                num_patches,
                old_embeddings.shape[-1],
                device=device,
                dtype=old_embeddings.dtype,
            ),
            requires_grad=old_embeddings.requires_grad,
        )
        # Presumably resizes the saved positional embeddings into the new
        # shape — see RegionEncoder.load_state_dict_resolution_agnostic.
        self.region_encoder.load_state_dict_resolution_agnostic(saved_state)
        self.region_encoder.to(device)
        # forward() builds its grid from the config and the early return above
        # compares against it, so the config must reflect the new resolution.
        self.config.image_resolution = image_resolution
        # Reset grid so it is rebuilt at the new resolution on the next forward().
        object.__setattr__(self, "_grid_points", None)
        logger.info(
            "Adapted positional embeddings: %spx → %spx", old_resolution, image_resolution
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        texts: list = None,
        aggregate_tokens: bool = True,
    ) -> dict:
        """
        Encode an image into region tokens.

        Args:
            pixel_values: Float tensor of shape (B, 3, H, W) in [0, 1].
            texts: Optional list of text strings. When provided, text embeddings
                are returned alongside region tokens for similarity scoring.
            aggregate_tokens: Merge overlapping region tokens by mask IoU and
                embedding cosine similarity (recommended for downstream use).

        Returns:
            dict with keys:
                pred_tokens – (B, N, D) raw region feature tokens.
                region_masks – (B, N, fH, fW) attention-derived region masks.
                text_aligned_tokens – (B, N, D) tokens in the DINOtxt embedding space.
                class_tokens – (B, D) image-level text-aligned DINOv3 class tokens.
                text_encodings – (T, D) text embeddings, only if texts is provided.

        Raises:
            RuntimeError: If the backbone (or, when texts is given, the text
                encoder) has not been loaded.
        """
        if self._image_encoder is None and self._pending_dinov3_weights_dir is not None:
            # Backbone loading was deferred in __init__ (meta-device construction).
            self.load_backbone(self._pending_dinov3_weights_dir)
        if self._image_encoder is None:
            raise RuntimeError(
                "DINOv3 backbone not loaded. "
                "Call model.load_backbone(dinov3_weights_dir=...) first, "
                "or pass dinov3_weights_dir= to from_pretrained()."
            )
        # Fail before the expensive image pass rather than after it.
        if texts is not None and self._text_encoder is None:
            raise RuntimeError("Text encoder not loaded. Call load_backbone() first.")

        device = pixel_values.device
        grid_points = self._get_grid_points()
        # The full grid is used as the prompt set for every image in the batch.
        prompts = [grid_points.to(device) for _ in range(pixel_values.shape[0])]

        # The DINOv3 backbones are frozen: no autograd graph through them.
        with torch.no_grad():
            backbone_out = self._image_encoder(pixel_values)
        feature_maps = backbone_out["feature_maps"].to(device)
        class_tokens = backbone_out["text_aligned_class_tokens"].to(device)

        outputs = self.region_encoder(feature_maps, prompts, aggregate_tokens=aggregate_tokens)
        outputs["class_tokens"] = class_tokens
        if texts is not None:
            with torch.no_grad():
                outputs["text_encodings"] = self._text_encoder(texts)
        return outputs