Image Segmentation
Transformers
Safetensors
PyTorch
English
tren
feature-extraction
vision
image-feature-extraction
region-tokens
dinov3
custom_code
Instructions to use aryaaan12/T-REN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use aryaaan12/T-REN with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="aryaaan12/T-REN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """ | |
| T-REN HuggingFace model wrapper. | |
| Usage: | |
| from transformers import AutoModel | |
| model = AutoModel.from_pretrained("savyak2/T-REN", trust_remote_code=True) | |
| model.load_backbone("/path/to/dinov3/weights/") | |
| # Or in one shot: | |
| model = AutoModel.from_pretrained( | |
| "savyak2/T-REN", | |
| trust_remote_code=True, | |
| dinov3_weights_dir="/path/to/dinov3/weights/", | |
| ) | |
| outputs = model(pixel_values) # pixel_values: (B, 3, H, W) float in [0, 1] | |
| """ | |
| import numpy as np | |
| import torch | |
| from transformers import PreTrainedModel | |
| from transformers.utils import logging | |
| try: | |
| from .configuration_tren import TRENConfig | |
| from .model import FeatureExtractor, RegionEncoder, TextEncoder | |
| except ImportError: | |
| from configuration_tren import TRENConfig | |
| from model import FeatureExtractor, RegionEncoder, TextEncoder | |
# Module-level logger obtained from transformers' logging utilities.
logger = logging.get_logger(__name__)

# File names of the two DINOv3 checkpoints; they match the names listed in the
# TRENModel and load_backbone() docstrings as required inside the weights
# directory.
# NOTE(review): neither constant is referenced in the visible part of this
# file; presumably the loaders in model.py resolve these names - confirm.
DINOV3_BACKBONE_FILENAME = "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"
DINOV3_HEAD_FILENAME = "dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth"
def _build_cfg_dict(config: TRENConfig, dinov3_weights_dir: str = None) -> dict:
    """Translate a TRENConfig into the nested dict the model classes consume.

    Args:
        config: Configuration object providing the architecture and parameter
            attributes copied below.
        dinov3_weights_dir: Directory holding the DINOv3 weight files. A falsy
            value (``None`` or ``""``) is stored as the empty string.

    Returns:
        A dict with the sections ``"pretrained"``, ``"architecture"``,
        ``"parameters"`` and ``"logging"``, in that order.
    """
    architecture_fields = (
        "patch_size",
        "hidden_dim",
        "text_embed_dim",
        "num_decoder_layers",
        "num_attention_heads",
    )
    parameter_fields = (
        "image_resolution",
        "num_multiscale_regions",
        "merging_iou_threshold",
        "merging_similarity_threshold",
    )

    # Both encoders are pinned to the same DINOv3 ViT-L/16 identifier.
    pretrained_section = dict.fromkeys(
        ("feature_extractor", "text_encoder"), "dinov3_vitl16"
    )
    architecture_section = {
        field: getattr(config, field) for field in architecture_fields
    }
    parameters_section = {
        field: getattr(config, field) for field in parameter_fields
    }

    # save_dir + exp_name join to give the directory containing DINOv3 weights.
    # e.g. os.path.join("/path/to/dir", "", "filename.pth") -> "/path/to/dir/filename.pth"
    weights_root = dinov3_weights_dir if dinov3_weights_dir else ""
    logging_section = {"save_dir": weights_root, "exp_name": ""}

    return {
        "pretrained": pretrained_section,
        "architecture": architecture_section,
        "parameters": parameters_section,
        "logging": logging_section,
    }
class TRENModel(PreTrainedModel):
    """
    T-REN: Text-aligned Region Encoder Network.

    Takes raw images and returns dense region tokens aligned to a shared
    vision-language embedding space (DINOv3 / DINOtxt).

    The trainable RegionEncoder weights are stored in this HF repo and loaded
    automatically. The DINOv3 ViT-L/16 backbone (~2 GB) must be provided
    separately via load_backbone().

    DINOv3 weights needed in the same directory:
        - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
        - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
    """

    config_class = TRENConfig
    base_model_prefix = "region_encoder"

    def __init__(self, config: TRENConfig, dinov3_weights_dir: str = None):
        """
        Build the RegionEncoder head and, optionally, load the DINOv3 backbone.

        Args:
            config: Model configuration.
            dinov3_weights_dir: Optional directory with the DINOv3 weight
                files. When given, load_backbone() is called at the end of
                construction.
        """
        super().__init__(config)
        cfg = _build_cfg_dict(config)

        # RegionEncoder: the trained T-REN head. HF saves/loads these weights.
        self.region_encoder = RegionEncoder(cfg)

        # Lazy placeholders — not registered as nn.Module submodules so they
        # are excluded from HF save/load. _grid_points is computed on first
        # forward() call to avoid meta-device issues during from_pretrained().
        object.__setattr__(self, "_grid_points", None)
        object.__setattr__(self, "_image_encoder", None)
        object.__setattr__(self, "_text_encoder", None)

        self.post_init()

        if dinov3_weights_dir is not None:
            self.load_backbone(dinov3_weights_dir)

    def load_backbone(self, dinov3_weights_dir: str) -> None:
        """
        Load the frozen DINOv3 image and text encoder backbones.

        Both encoders are created on the device of the RegionEncoder's first
        parameter and switched to eval mode. They are stored with
        object.__setattr__, i.e. outside the registered submodules.

        Args:
            dinov3_weights_dir: Directory containing both DINOv3 weight files:
                - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
                - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
        """
        device = next(self.region_encoder.parameters()).device
        cfg = _build_cfg_dict(self.config, dinov3_weights_dir)

        logger.info("Loading DINOv3 image encoder...")
        image_encoder = FeatureExtractor(cfg, device=str(device)).eval()

        logger.info("Loading DINOv3 text encoder...")
        text_encoder = TextEncoder(cfg, device=str(device)).eval()

        object.__setattr__(self, "_image_encoder", image_encoder)
        object.__setattr__(self, "_text_encoder", text_encoder)

    def adapt_to_resolution(self, image_resolution: int) -> None:
        """
        Interpolate the RegionEncoder's positional embeddings to a new spatial
        resolution. Call this after from_pretrained() when running inference at
        a resolution different from the training resolution (512px by default).

        After a successful adaptation ``self.config.image_resolution`` holds the
        new resolution, so the prompt grid built by forward() and any later
        call to this method both see the resolution that is actually active.

        Args:
            image_resolution: Target image resolution in pixels (e.g. 384).

        Raises:
            ValueError: If ``image_resolution`` is smaller than one patch,
                which would leave the encoder with zero patches.

        Example::

            model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True)
            model.load_backbone("/path/to/dinov3/weights/")
            model.adapt_to_resolution(384)  # eval at 384px instead of 512px
        """
        previous_resolution = self.config.image_resolution
        if image_resolution == previous_resolution:
            return

        ps = self.config.patch_size
        patches_per_side = image_resolution // ps
        if patches_per_side < 1:
            raise ValueError(
                f"image_resolution ({image_resolution}) must be at least one "
                f"patch ({ps}px)."
            )

        # Snapshot the trained weights before the embedding table is replaced.
        saved_state = self.region_encoder.state_dict()
        old_embeddings = self.region_encoder.feature_embeddings
        device = next(self.region_encoder.parameters()).device
        num_patches = patches_per_side ** 2
        C = old_embeddings.shape[-1]

        # Allocate the resized table with the dtype of the one it replaces, so
        # a model held in reduced precision does not end up with a float32
        # parameter.
        self.region_encoder.feature_embeddings = torch.nn.Parameter(
            torch.zeros(num_patches, C, device=device, dtype=old_embeddings.dtype)
        )
        self.region_encoder.load_state_dict_resolution_agnostic(saved_state)
        self.region_encoder.to(device)

        # Record the new resolution. forward() derives the prompt grid from
        # config.image_resolution, so without this update the grid would be
        # rebuilt at the old resolution and a later call asking for the
        # original resolution would return early without restoring it.
        # NOTE(review): RegionEncoder and an already-loaded backbone were built
        # from a cfg dict created before this change; confirm they do not read
        # cfg["parameters"]["image_resolution"] at inference time.
        self.config.image_resolution = image_resolution

        # Reset grid so it is rebuilt at the new resolution on the next forward().
        object.__setattr__(self, "_grid_points", None)

        logger.info(
            f"Adapted positional embeddings: {previous_resolution}px → {image_resolution}px"
        )

    def forward(
        self,
        pixel_values: torch.Tensor,
        texts: list = None,
        aggregate_tokens: bool = True,
    ) -> dict:
        """
        Encode an image into region tokens.

        Args:
            pixel_values: Float tensor of shape (B, 3, H, W) in [0, 1].
            texts: Optional list of text strings. When provided, text embeddings
                are returned alongside region tokens for similarity scoring.
            aggregate_tokens: Merge overlapping region tokens by mask IoU and
                embedding cosine similarity (recommended for downstream use).

        Returns:
            dict with keys:
                pred_tokens         – (B, N, D) raw region feature tokens.
                region_masks        – (B, N, fH, fW) attention-derived region masks.
                text_aligned_tokens – (B, N, D) tokens in the DINOtxt embedding space.
                class_tokens        – (B, D) image-level DINOv3 class tokens.
                text_encodings      – (T, D) text embeddings, only if texts is provided.

        Raises:
            RuntimeError: If the DINOv3 backbone has not been loaded.
        """
        if self._image_encoder is None:
            raise RuntimeError(
                "DINOv3 backbone not loaded. "
                "Call model.load_backbone(dinov3_weights_dir=...) first, "
                "or pass dinov3_weights_dir= to from_pretrained()."
            )

        device = pixel_values.device

        # Build grid on first call (avoids meta-device issues during from_pretrained).
        if self._grid_points is None:
            res = self.config.image_resolution
            ps = self.config.patch_size
            # One (y, x) prompt per patch, spread between pixel 1 and res - 2.
            coords = np.linspace(1, res - 2, res // ps, dtype=int)
            object.__setattr__(
                self,
                "_grid_points",
                torch.tensor([(y, x) for y in coords for x in coords]),
            )

        # Every image in the batch is prompted with the same grid.
        prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]

        # The backbone call runs without gradient tracking.
        with torch.no_grad():
            backbone_out = self._image_encoder(pixel_values)

        feature_maps = backbone_out["feature_maps"].to(device)
        class_tokens = backbone_out["text_aligned_class_tokens"].to(device)

        outputs = self.region_encoder(feature_maps, prompts, aggregate_tokens=aggregate_tokens)
        outputs["class_tokens"] = class_tokens

        if texts is not None:
            if self._text_encoder is None:
                raise RuntimeError("Text encoder not loaded. Call load_backbone() first.")
            outputs["text_encodings"] = self._text_encoder(texts)

        return outputs