aryaaan12 committed on
Commit
18c1533
·
verified ·
1 Parent(s): 2b883a1

Upload modeling_tren.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tren.py +37 -0
modeling_tren.py CHANGED
@@ -122,6 +122,43 @@ class TRENModel(PreTrainedModel):
122
  object.__setattr__(self, "_image_encoder", image_encoder)
123
  object.__setattr__(self, "_text_encoder", text_encoder)
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  def forward(
126
  self,
127
  pixel_values: torch.Tensor,
 
122
  object.__setattr__(self, "_image_encoder", image_encoder)
123
  object.__setattr__(self, "_text_encoder", text_encoder)
124
 
125
def adapt_to_resolution(self, image_resolution: int) -> None:
    """
    Interpolate the RegionEncoder's positional embeddings to a new spatial
    resolution. Call this after from_pretrained() when running inference at
    a resolution different from the training resolution (512px by default).

    The call is a no-op when ``image_resolution`` already equals
    ``self.config.image_resolution``. Otherwise ``feature_embeddings`` is
    replaced by a parameter with ``(image_resolution // patch_size) ** 2``
    rows, the previous weights are reloaded through
    ``load_state_dict_resolution_agnostic``, the cached grid is cleared and
    ``self.config.image_resolution`` is set to the new value, so that later
    calls compare against the resolution the embeddings are actually at.

    Args:
        image_resolution: Target image resolution in pixels (e.g. 384).

    Raises:
        ValueError: If ``image_resolution`` is smaller than the patch size,
            which would leave the encoder with zero patches.

    Example::

        model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True)
        model.load_backbone("/path/to/dinov3/weights/")
        model.adapt_to_resolution(384)  # eval at 384px instead of 512px
    """
    previous_resolution = self.config.image_resolution
    if image_resolution == previous_resolution:
        return

    ps = self.config.patch_size
    if image_resolution < ps:
        # Fail early: floor division below would give zero patches and an
        # empty embedding table.
        raise ValueError(
            f"image_resolution ({image_resolution}) must be at least the "
            f"patch size ({ps})"
        )

    # Snapshot the weights before the embedding table is swapped out.
    saved_state = self.region_encoder.state_dict()
    device = next(self.region_encoder.parameters()).device
    old_embeddings = self.region_encoder.feature_embeddings
    num_patches = (image_resolution // ps) ** 2
    C = old_embeddings.shape[-1]

    # Placeholder with the target shape; keep the existing dtype so a model
    # loaded in reduced precision does not end up with a float32 table.
    self.region_encoder.feature_embeddings = torch.nn.Parameter(
        torch.zeros(num_patches, C, device=device, dtype=old_embeddings.dtype)
    )
    self.region_encoder.load_state_dict_resolution_agnostic(saved_state)
    self.region_encoder.to(device)

    # Reset grid so it is rebuilt at the new resolution on the next forward().
    object.__setattr__(self, "_grid_points", None)

    logger.info(
        f"Adapted positional embeddings: {previous_resolution}px → {image_resolution}px"
    )

    # Record the new resolution; without this the guard above keeps comparing
    # against the original value and a later call back to it is skipped.
    self.config.image_resolution = image_resolution
161
+
162
  def forward(
163
  self,
164
  pixel_values: torch.Tensor,