"""
T-REN HuggingFace model wrapper.

Usage:
    from transformers import AutoModel

    model = AutoModel.from_pretrained("savyak2/T-REN", trust_remote_code=True)
    model.load_backbone("/path/to/dinov3/weights/")

    # Or in one shot:
    model = AutoModel.from_pretrained(
        "savyak2/T-REN",
        trust_remote_code=True,
        dinov3_weights_dir="/path/to/dinov3/weights/",
    )

    outputs = model(pixel_values)  # pixel_values: (B, 3, H, W) float in [0, 1]
"""

from typing import List, Optional

import numpy as np
import torch
from transformers import PreTrainedModel
from transformers.utils import logging

try:
    from .configuration_tren import TRENConfig
    from .model import FeatureExtractor, RegionEncoder, TextEncoder
except ImportError:
    from configuration_tren import TRENConfig
    from model import FeatureExtractor, RegionEncoder, TextEncoder

logger = logging.get_logger(__name__)

DINOV3_BACKBONE_FILENAME = "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"
DINOV3_HEAD_FILENAME = "dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth"


def _build_cfg_dict(config: TRENConfig, dinov3_weights_dir: Optional[str] = None) -> dict:
    """Convert TRENConfig into the dict format expected by existing model classes.

    Args:
        config: The HF config holding architecture and inference parameters.
        dinov3_weights_dir: Directory containing the DINOv3 weight files, or
            None when only the RegionEncoder is being built.

    Returns:
        A nested dict with "pretrained", "architecture", "parameters" and
        "logging" sections, as consumed by RegionEncoder / FeatureExtractor /
        TextEncoder.
    """
    return {
        "pretrained": {
            "feature_extractor": "dinov3_vitl16",
            "text_encoder": "dinov3_vitl16",
        },
        "architecture": {
            "patch_size": config.patch_size,
            "hidden_dim": config.hidden_dim,
            "text_embed_dim": config.text_embed_dim,
            "num_decoder_layers": config.num_decoder_layers,
            "num_attention_heads": config.num_attention_heads,
        },
        "parameters": {
            "image_resolution": config.image_resolution,
            "num_multiscale_regions": config.num_multiscale_regions,
            "merging_iou_threshold": config.merging_iou_threshold,
            "merging_similarity_threshold": config.merging_similarity_threshold,
        },
        # save_dir + exp_name join to give the directory containing DINOv3 weights.
        # e.g. os.path.join("/path/to/dir", "", "filename.pth") -> "/path/to/dir/filename.pth"
        "logging": {
            "save_dir": dinov3_weights_dir or "",
            "exp_name": "",
        },
    }


class TRENModel(PreTrainedModel):
    """
    T-REN: Text-aligned Region Encoder Network.

    Takes raw images and returns dense region tokens aligned to a shared
    vision-language embedding space (DINOv3 / DINOtxt).

    The trainable RegionEncoder weights are stored in this HF repo and loaded
    automatically. The DINOv3 ViT-L/16 backbone (~2 GB) must be provided
    separately via load_backbone().

    DINOv3 weights needed in the same directory:
        - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
        - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
    """

    config_class = TRENConfig
    base_model_prefix = "region_encoder"

    def __init__(self, config: TRENConfig, dinov3_weights_dir: Optional[str] = None):
        """Build the trainable RegionEncoder and optionally load the backbone.

        Args:
            config: T-REN configuration.
            dinov3_weights_dir: If given, load_backbone() is called immediately.
                NOTE(review): when reached via from_pretrained(), parameters may
                still be on the meta device at this point, in which case the
                backbone would be built on that device — confirm this path works
                with the installed transformers version.
        """
        super().__init__(config)
        cfg = _build_cfg_dict(config)

        # RegionEncoder: the trained T-REN head. HF saves/loads these weights.
        self.region_encoder = RegionEncoder(cfg)

        # Lazy placeholders — not registered as nn.Module submodules so they
        # are excluded from HF save/load. _grid_points is computed on first
        # forward() call to avoid meta-device issues during from_pretrained().
        # NOTE(review): because the backbones are not submodules, model.to(...)
        # does not move them; they stay on the device chosen in load_backbone().
        object.__setattr__(self, "_grid_points", None)
        object.__setattr__(self, "_image_encoder", None)
        object.__setattr__(self, "_text_encoder", None)

        self.post_init()

        if dinov3_weights_dir is not None:
            self.load_backbone(dinov3_weights_dir)

    def load_backbone(self, dinov3_weights_dir: str) -> None:
        """
        Load the frozen DINOv3 image and text encoder backbones.

        The encoders are created on the device currently holding the
        RegionEncoder parameters and put in eval mode.

        Args:
            dinov3_weights_dir: Directory containing both DINOv3 weight files:
                - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
                - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
        """
        device = next(self.region_encoder.parameters()).device
        cfg = _build_cfg_dict(self.config, dinov3_weights_dir)

        logger.info("Loading DINOv3 image encoder...")
        image_encoder = FeatureExtractor(cfg, device=str(device)).eval()

        logger.info("Loading DINOv3 text encoder...")
        text_encoder = TextEncoder(cfg, device=str(device)).eval()

        # Bypass nn.Module.__setattr__ so the backbones are not registered as
        # submodules (keeps them out of state_dict / save_pretrained).
        object.__setattr__(self, "_image_encoder", image_encoder)
        object.__setattr__(self, "_text_encoder", text_encoder)

    def adapt_to_resolution(self, image_resolution: int) -> None:
        """
        Interpolate the RegionEncoder's positional embeddings to a new spatial resolution.

        Call this after from_pretrained() when running inference at a resolution
        different from the training resolution (512px by default). The model
        config is updated to the new resolution so that forward() builds a
        matching prompt grid and save_pretrained() stores a consistent config.

        Args:
            image_resolution: Target image resolution in pixels (e.g. 384).

        Raises:
            ValueError: If image_resolution is smaller than the patch size.

        Example::

            model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True)
            model.load_backbone("/path/to/dinov3/weights/")
            model.adapt_to_resolution(384)  # eval at 384px instead of 512px
        """
        old_resolution = self.config.image_resolution
        if image_resolution == old_resolution:
            return

        ps = self.config.patch_size
        if image_resolution < ps:
            raise ValueError(
                f"image_resolution ({image_resolution}) must be at least patch_size ({ps})"
            )

        saved_state = self.region_encoder.state_dict()
        old_embeddings = self.region_encoder.feature_embeddings
        device = next(self.region_encoder.parameters()).device
        num_patches = (image_resolution // ps) ** 2
        C = old_embeddings.shape[-1]

        # Keep the original dtype so a half-precision model is not silently
        # mixed with a float32 positional-embedding table.
        self.region_encoder.feature_embeddings = torch.nn.Parameter(
            torch.zeros(num_patches, C, device=device, dtype=old_embeddings.dtype)
        )
        self.region_encoder.load_state_dict_resolution_agnostic(saved_state)
        self.region_encoder.to(device)

        # forward() derives the prompt grid from config.image_resolution, so the
        # config must reflect the new resolution; otherwise the rebuilt grid
        # would still use the old resolution and disagree with the embeddings.
        self.config.image_resolution = image_resolution

        # Reset grid so it is rebuilt at the new resolution on the next forward().
        object.__setattr__(self, "_grid_points", None)

        logger.info(
            f"Adapted positional embeddings: {old_resolution}px → {image_resolution}px"
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        texts: Optional[List[str]] = None,
        aggregate_tokens: bool = True,
    ) -> dict:
        """
        Encode an image into region tokens.

        Args:
            pixel_values: Float tensor of shape (B, 3, H, W) in [0, 1].
            texts: Optional list of text strings. When provided, text embeddings
                are returned alongside region tokens for similarity scoring.
            aggregate_tokens: Merge overlapping region tokens by mask IoU and
                embedding cosine similarity (recommended for downstream use).

        Returns:
            dict with keys:
                pred_tokens          – (B, N, D) raw region feature tokens.
                region_masks         – (B, N, fH, fW) attention-derived region masks.
                text_aligned_tokens  – (B, N, D) tokens in the DINOtxt embedding space.
                class_tokens         – (B, D) image-level DINOv3 class tokens.
                text_encodings       – (T, D) text embeddings, only if texts is provided.

        Raises:
            RuntimeError: If load_backbone() has not been called.
        """
        if self._image_encoder is None:
            raise RuntimeError(
                "DINOv3 backbone not loaded. "
                "Call model.load_backbone(dinov3_weights_dir=...) first, "
                "or pass dinov3_weights_dir= to from_pretrained()."
            )

        device = pixel_values.device

        # Build grid on first call (avoids meta-device issues during from_pretrained).
        # One (y, x) prompt per patch, evenly spaced over [1, res - 2].
        if self._grid_points is None:
            res = self.config.image_resolution
            ps = self.config.patch_size
            coords = np.linspace(1, res - 2, res // ps, dtype=int)
            object.__setattr__(
                self,
                "_grid_points",
                torch.tensor([(y, x) for y in coords for x in coords]),
            )

        # The same grid is used as the prompt set for every image in the batch.
        prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]

        # The backbone is frozen: skip autograd bookkeeping for it.
        with torch.no_grad():
            backbone_out = self._image_encoder(pixel_values)
            feature_maps = backbone_out["feature_maps"].to(device)
            class_tokens = backbone_out["text_aligned_class_tokens"].to(device)

        outputs = self.region_encoder(feature_maps, prompts, aggregate_tokens=aggregate_tokens)
        outputs["class_tokens"] = class_tokens

        if texts is not None:
            if self._text_encoder is None:
                raise RuntimeError("Text encoder not loaded. Call load_backbone() first.")
            outputs["text_encodings"] = self._text_encoder(texts)

        return outputs