---
license: mit
language:
  - en
tags:
  - vision
  - image-segmentation
  - image-feature-extraction
  - region-tokens
  - dinov3
  - pytorch
library_name: transformers
---

# T-REN: Text-Aligned Region Encoder Network

**Authors**: [Savya Khosla](https://savya08.github.io/), [Sethuraman TV](https://github.com/sethuramanio), [Aryan Chadha](https://www.linkedin.com/in/aryan-chadha/), [Alex Schwing](https://www.alexander-schwing.de/), [Derek Hoiem](https://dhoiem.cs.illinois.edu/)

[![GitHub](https://img.shields.io/badge/GitHub-Code-black.svg)](https://github.com/savya08/T-REN)

T-REN (**T**ext-aligned **R**egion **E**ncoder **N**etwork) is an image encoder that produces region-level tokens aligned with text, built on top of [DINOv3](https://github.com/facebookresearch/dinov3) ViT-L/16. Compared to its patch-based backbone, T-REN delivers:

- **+5.9 mIoU** on ADE20K open-vocabulary segmentation
- **+18.4% recall** on COCO object-level text-image retrieval
- **+15.6% recall** on Ego4D video object localization (VQ2D)
- **+17.6% mIoU** on VSPW video scene parsing
- **24× fewer tokens** per image, **187× fewer** per video

---

## What's in this repo

This Hugging Face repo contains:
- `model.safetensors` — the trained `RegionEncoder` head weights (~1.2 GB)
- `configuration_tren.py`, `modeling_tren.py`, `model.py`, `task_utils.py` — source code for `trust_remote_code`

**The DINOv3 ViT-L/16 backbone is NOT included here.** It is released separately by Facebook Research and must be downloaded from the DINOv3 release (see Step 2 of the Quickstart).

---

## Quickstart

### Step 1 — Install dependencies

```bash
pip install transformers torch torchvision kornia
```

### Step 2 — Get the DINOv3 weights

T-REN's backbone is DINOv3 ViT-L/16 with a DINOtxt text-alignment head. You need two weight files from the [DINOv3 release](https://github.com/facebookresearch/dinov3):

| File | Description |
|------|-------------|
| `dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth` | DINOv3 ViT-L/16 backbone |
| `dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth` | DINOtxt vision head + text encoder |

Place both files in the same directory, e.g. `/path/to/dinov3_weights/`.
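
Before loading the model, it can help to confirm that both files are actually in that directory. The following is a minimal sketch using only the standard library; the helper name `missing_dinov3_weights` is hypothetical and not part of the T-REN code.

```python
from pathlib import Path

# File names copied from the table above.
REQUIRED_FILES = (
    "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth",
    "dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth",
)


def missing_dinov3_weights(weights_dir):
    """Return the required DINOv3 file names that are absent from weights_dir.

    Hypothetical helper for illustration; T-REN itself does not ship it.
    """
    root = Path(weights_dir)
    return [name for name in REQUIRED_FILES if not (root / name).is_file()]
```

Calling `missing_dinov3_weights("/path/to/dinov3_weights/")` returns an empty list when both files are present, so a non-empty result is a reasonable signal to stop before calling `load_backbone`.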

### Step 3 — Load and run T-REN

```python
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import AutoModel

# Load model (downloads T-REN weights from this repo automatically)
model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True)

# Load the DINOv3 backbone from your local directory
model.load_backbone("/path/to/dinov3_weights/")

model.eval()

# Prepare an image — resize to 512x512, values in [0, 1]
transform = T.Compose([
    T.Resize((512, 512)),
    T.ToTensor(),
])
image = transform(Image.open("your_image.jpg").convert("RGB"))
image = image.unsqueeze(0)  # (1, 3, 512, 512)

# Run T-REN
with torch.no_grad():
    outputs = model(image)

# Outputs
region_tokens = outputs["text_aligned_tokens"]  # list of (N, 1024) per image
region_masks  = outputs["region_masks"]          # list of (N, 32, 32) per image
class_token   = outputs["class_tokens"]          # (1, 1024) image-level token
print(f"Number of region tokens: {len(region_tokens[0])}")
```
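
The region masks come out at 32 × 32, so overlaying them on the 512 × 512 input requires upsampling. The sketch below shows one way to do that with `torch.nn.functional.interpolate`. It uses a dummy tensor in place of `outputs["region_masks"][0]`, and it assumes the masks are floating-point scores. The helper name `upsample_masks` is hypothetical.

```python
import torch
import torch.nn.functional as F


def upsample_masks(masks, size=(512, 512)):
    """Bilinearly resize (N, h, w) masks to (N, H, W). Hypothetical helper."""
    # interpolate expects a 4D (batch, channel, h, w) tensor for 2D resizing.
    resized = F.interpolate(
        masks.unsqueeze(1).float(), size=size, mode="bilinear", align_corners=False
    )
    return resized.squeeze(1)


# Dummy stand-in for outputs["region_masks"][0]: two regions, top and bottom half.
dummy_masks = torch.zeros(2, 32, 32)
dummy_masks[0, :16, :] = 1.0
dummy_masks[1, 16:, :] = 1.0

full_res = upsample_masks(dummy_masks)        # (2, 512, 512)
area_fraction = full_res.mean(dim=(1, 2))     # share of the image each region covers
print(full_res.shape, area_fraction)
```

With real outputs, replace `dummy_masks` with `region_masks[0]`. Whether a fixed threshold is appropriate for binarizing the result depends on how the masks are scaled, which this card does not specify.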

### Text-guided region matching

```python
import torch.nn.functional as F

# Reuses `torch`, `model`, and `image` from Step 3.

texts = ["sky", "car", "building", "tree", "road"]

with torch.no_grad():
    outputs = model(image, texts=texts)

region_tokens = outputs["text_aligned_tokens"][0]   # (N, 1024)
text_tokens   = outputs["text_encodings"]           # (5, 1024)

# Cosine similarity: which text label fits each region best?
sim = F.normalize(region_tokens, dim=-1) @ F.normalize(text_tokens, dim=-1).T  # (N, 5)
best_labels = sim.argmax(dim=-1)  # (N,) index into `texts` for each region
print([texts[i] for i in best_labels.tolist()])
```
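
The per-region labels can be turned into a dense label map by combining them with the region masks. The following is a rough sketch, not the authors' evaluation code: it assumes the mask values are comparable scores across regions, assigns each pixel to the highest-scoring region, and uses dummy tensors in place of real model outputs. The helper name `regions_to_label_map` is hypothetical.

```python
import torch
import torch.nn.functional as F


def regions_to_label_map(region_masks, region_labels, out_size=(512, 512)):
    """Build an (H, W) label map from (N, h, w) masks and (N,) region labels.

    Hypothetical helper: each pixel takes the label of its strongest region.
    """
    # Treat the N regions as channels of a single image for 2D interpolation.
    up = F.interpolate(
        region_masks.unsqueeze(0).float(),
        size=out_size,
        mode="bilinear",
        align_corners=False,
    )[0]                                  # (N, H, W)
    owner = up.argmax(dim=0)              # (H, W) index of the strongest region
    return region_labels[owner]           # (H, W) label index per pixel


# Dummy stand-ins: region 0 covers the left half, region 1 the right half.
dummy_masks = torch.zeros(2, 32, 32)
dummy_masks[0, :, :16] = 1.0
dummy_masks[1, :, 16:] = 1.0
dummy_labels = torch.tensor([3, 1])       # e.g. indices into `texts`

label_map = regions_to_label_map(dummy_masks, dummy_labels)
print(label_map.shape)
```

With real outputs, pass `outputs["region_masks"][0]` and `best_labels` from the snippet above; each entry of the result then indexes into `texts`.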

---

## Model details

| Property | Value |
|---|---|
| Architecture | RegionEncoder (cross-attention decoder) over DINOv3 ViT-L/16 features |
| Trainable parameters | 31.5M (RegionEncoder head only; backbone is frozen) |
| Input resolution | 512 × 512 |
| Output token dim | 1024 |
| Multiscale regions | 3 scales per prompt point |
| Text embedding space | DINOtxt (aligned with DINOv3 text encoder) |

---

## Citation

```bibtex
@misc{khosla2026tren,
      title={T-REN: Learning Text-Aligned Region Tokens Improves Dense Vision-Language Alignment and Scalability},
      author={Savya Khosla and Sethuraman T V and Aryan Chadha and Alexander Schwing and Derek Hoiem},
      year={2026},
}
```