aryaaan12 committed on
Commit
2b883a1
·
verified ·
1 Parent(s): 66c4f84

Upload modeling_tren.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_tren.py +13 -9
modeling_tren.py CHANGED
@@ -89,15 +89,10 @@ class TRENModel(PreTrainedModel):
89
  # RegionEncoder: the trained T-REN head. HF saves/loads these weights.
90
  self.region_encoder = RegionEncoder(cfg)
91
 
92
- # Dense grid of point prompts covering the full image at patch stride.
93
- res = config.image_resolution
94
- ps = config.patch_size
95
- coords = np.linspace(1, res - 2, res // ps, dtype=int)
96
- grid_points = torch.tensor([(y, x) for y in coords for x in coords])
97
-
98
- # Store grid_points and lazy backbone refs without registering them as
99
- # nn.Module submodules (so they are excluded from HF save/load).
100
- object.__setattr__(self, "_grid_points", grid_points)
101
  object.__setattr__(self, "_image_encoder", None)
102
  object.__setattr__(self, "_text_encoder", None)
103
 
@@ -159,6 +154,15 @@ class TRENModel(PreTrainedModel):
159
  )
160
 
161
  device = pixel_values.device
 
 
 
 
 
 
 
 
 
162
  prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]
163
 
164
  with torch.no_grad():
 
89
  # RegionEncoder: the trained T-REN head. HF saves/loads these weights.
90
  self.region_encoder = RegionEncoder(cfg)
91
 
92
+ # Lazy placeholders not registered as nn.Module submodules so they
93
+ # are excluded from HF save/load. _grid_points is computed on first
94
+ # forward() call to avoid meta-device issues during from_pretrained().
95
+ object.__setattr__(self, "_grid_points", None)
 
 
 
 
 
96
  object.__setattr__(self, "_image_encoder", None)
97
  object.__setattr__(self, "_text_encoder", None)
98
 
 
154
  )
155
 
156
  device = pixel_values.device
157
+
158
+ # Build grid on first call (avoids meta-device issues during from_pretrained).
159
+ if self._grid_points is None:
160
+ res = self.config.image_resolution
161
+ ps = self.config.patch_size
162
+ coords = np.linspace(1, res - 2, res // ps, dtype=int)
163
+ object.__setattr__(self, "_grid_points",
164
+ torch.tensor([(y, x) for y in coords for x in coords]))
165
+
166
  prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]
167
 
168
  with torch.no_grad():