aryaaan12 commited on
Commit
e2b3733
·
verified ·
1 Parent(s): e98924d

Upload modeling_tren.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tren.py +177 -0
modeling_tren.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ T-REN HuggingFace model wrapper.
3
+
4
+ Usage:
5
+ from transformers import AutoModel
6
+ model = AutoModel.from_pretrained("savyak2/T-REN", trust_remote_code=True)
7
+ model.load_backbone("/path/to/dinov3/weights/")
8
+
9
+ # Or in one shot:
10
+ model = AutoModel.from_pretrained(
11
+ "savyak2/T-REN",
12
+ trust_remote_code=True,
13
+ dinov3_weights_dir="/path/to/dinov3/weights/",
14
+ )
15
+ outputs = model(pixel_values) # pixel_values: (B, 3, H, W) float in [0, 1]
16
+ """
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import PreTrainedModel
21
+ from transformers.utils import logging
22
+
23
+ try:
24
+ from .configuration_tren import TRENConfig
25
+ from .model import FeatureExtractor, RegionEncoder, TextEncoder
26
+ except ImportError:
27
+ from configuration_tren import TRENConfig
28
+ from model import FeatureExtractor, RegionEncoder, TextEncoder
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+ DINOV3_BACKBONE_FILENAME = "dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth"
33
+ DINOV3_HEAD_FILENAME = "dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth"
34
+
35
+
36
+ def _build_cfg_dict(config: TRENConfig, dinov3_weights_dir: str = None) -> dict:
37
+ """Convert TRENConfig into the dict format expected by existing model classes."""
38
+ return {
39
+ "pretrained": {
40
+ "feature_extractor": "dinov3_vitl16",
41
+ "text_encoder": "dinov3_vitl16",
42
+ },
43
+ "architecture": {
44
+ "patch_size": config.patch_size,
45
+ "hidden_dim": config.hidden_dim,
46
+ "text_embed_dim": config.text_embed_dim,
47
+ "num_decoder_layers": config.num_decoder_layers,
48
+ "num_attention_heads": config.num_attention_heads,
49
+ },
50
+ "parameters": {
51
+ "image_resolution": config.image_resolution,
52
+ "num_multiscale_regions": config.num_multiscale_regions,
53
+ "merging_iou_threshold": config.merging_iou_threshold,
54
+ "merging_similarity_threshold": config.merging_similarity_threshold,
55
+ },
56
+ # save_dir + exp_name join to give the directory containing DINOv3 weights.
57
+ # e.g. os.path.join("/path/to/dir", "", "filename.pth") -> "/path/to/dir/filename.pth"
58
+ "logging": {
59
+ "save_dir": dinov3_weights_dir or "",
60
+ "exp_name": "",
61
+ },
62
+ }
63
+
64
+
65
+ class TRENModel(PreTrainedModel):
66
+ """
67
+ T-REN: Text-aligned Region Encoder Network.
68
+
69
+ Takes raw images and returns dense region tokens aligned to a shared
70
+ vision-language embedding space (DINOv3 / DINOtxt).
71
+
72
+ The trainable RegionEncoder weights are stored in this HF repo and loaded
73
+ automatically. The DINOv3 ViT-L/16 backbone (~2 GB) must be provided
74
+ separately via load_backbone().
75
+
76
+ DINOv3 weights needed in the same directory:
77
+ - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
78
+ - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
79
+ """
80
+
81
+ config_class = TRENConfig
82
+ base_model_prefix = "region_encoder"
83
+
84
+ def __init__(self, config: TRENConfig, dinov3_weights_dir: str = None):
85
+ super().__init__(config)
86
+
87
+ cfg = _build_cfg_dict(config)
88
+
89
+ # RegionEncoder: the trained T-REN head. HF saves/loads these weights.
90
+ self.region_encoder = RegionEncoder(cfg)
91
+
92
+ # Dense grid of point prompts covering the full image at patch stride.
93
+ res = config.image_resolution
94
+ ps = config.patch_size
95
+ coords = np.linspace(1, res - 2, res // ps, dtype=int)
96
+ grid_points = torch.tensor([(y, x) for y in coords for x in coords])
97
+
98
+ # Store grid_points and lazy backbone refs without registering them as
99
+ # nn.Module submodules (so they are excluded from HF save/load).
100
+ object.__setattr__(self, "_grid_points", grid_points)
101
+ object.__setattr__(self, "_image_encoder", None)
102
+ object.__setattr__(self, "_text_encoder", None)
103
+
104
+ self.post_init()
105
+
106
+ if dinov3_weights_dir is not None:
107
+ self.load_backbone(dinov3_weights_dir)
108
+
109
+ def load_backbone(self, dinov3_weights_dir: str) -> None:
110
+ """
111
+ Load the frozen DINOv3 image and text encoder backbones.
112
+
113
+ Args:
114
+ dinov3_weights_dir: Directory containing both DINOv3 weight files:
115
+ - dinov3_vitl16_pretrain_lvd1689m-8aa4cbdd.pth
116
+ - dinov3_vitl16_dinotxt_vision_head_and_text_encoder-a442d8f5.pth
117
+ """
118
+ device = next(self.region_encoder.parameters()).device
119
+ cfg = _build_cfg_dict(self.config, dinov3_weights_dir)
120
+
121
+ logger.info("Loading DINOv3 image encoder...")
122
+ image_encoder = FeatureExtractor(cfg, device=str(device)).eval()
123
+
124
+ logger.info("Loading DINOv3 text encoder...")
125
+ text_encoder = TextEncoder(cfg, device=str(device)).eval()
126
+
127
+ object.__setattr__(self, "_image_encoder", image_encoder)
128
+ object.__setattr__(self, "_text_encoder", text_encoder)
129
+
130
+ def forward(
131
+ self,
132
+ pixel_values: torch.Tensor,
133
+ texts: list = None,
134
+ aggregate_tokens: bool = True,
135
+ ) -> dict:
136
+ """
137
+ Encode an image into region tokens.
138
+
139
+ Args:
140
+ pixel_values: Float tensor of shape (B, 3, H, W) in [0, 1].
141
+ texts: Optional list of text strings. When provided, text embeddings
142
+ are returned alongside region tokens for similarity scoring.
143
+ aggregate_tokens: Merge overlapping region tokens by mask IoU and
144
+ embedding cosine similarity (recommended for downstream use).
145
+
146
+ Returns:
147
+ dict with keys:
148
+ pred_tokens – (B, N, D) raw region feature tokens.
149
+ region_masks – (B, N, fH, fW) attention-derived region masks.
150
+ text_aligned_tokens – (B, N, D) tokens in the DINOtxt embedding space.
151
+ class_tokens – (B, D) image-level DINOv3 class tokens.
152
+ text_encodings – (T, D) text embeddings, only if texts is provided.
153
+ """
154
+ if self._image_encoder is None:
155
+ raise RuntimeError(
156
+ "DINOv3 backbone not loaded. "
157
+ "Call model.load_backbone(dinov3_weights_dir=...) first, "
158
+ "or pass dinov3_weights_dir= to from_pretrained()."
159
+ )
160
+
161
+ device = pixel_values.device
162
+ prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]
163
+
164
+ with torch.no_grad():
165
+ backbone_out = self._image_encoder(pixel_values)
166
+ feature_maps = backbone_out["feature_maps"].to(device)
167
+ class_tokens = backbone_out["text_aligned_class_tokens"].to(device)
168
+
169
+ outputs = self.region_encoder(feature_maps, prompts, aggregate_tokens=aggregate_tokens)
170
+ outputs["class_tokens"] = class_tokens
171
+
172
+ if texts is not None:
173
+ if self._text_encoder is None:
174
+ raise RuntimeError("Text encoder not loaded. Call load_backbone() first.")
175
+ outputs["text_encodings"] = self._text_encoder(texts)
176
+
177
+ return outputs