Image Segmentation
Transformers
Safetensors
PyTorch
English
tren
feature-extraction
vision
image-feature-extraction
region-tokens
dinov3
custom_code
Instructions to use aryaaan12/T-REN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use aryaaan12/T-REN with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-segmentation", model="aryaaan12/T-REN", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload modeling_tren.py with huggingface_hub
Browse files- modeling_tren.py +13 -9
modeling_tren.py
CHANGED
|
@@ -89,15 +89,10 @@ class TRENModel(PreTrainedModel):
|
|
| 89 |
# RegionEncoder: the trained T-REN head. HF saves/loads these weights.
|
| 90 |
self.region_encoder = RegionEncoder(cfg)
|
| 91 |
|
| 92 |
-
#
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
grid_points = torch.tensor([(y, x) for y in coords for x in coords])
|
| 97 |
-
|
| 98 |
-
# Store grid_points and lazy backbone refs without registering them as
|
| 99 |
-
# nn.Module submodules (so they are excluded from HF save/load).
|
| 100 |
-
object.__setattr__(self, "_grid_points", grid_points)
|
| 101 |
object.__setattr__(self, "_image_encoder", None)
|
| 102 |
object.__setattr__(self, "_text_encoder", None)
|
| 103 |
|
|
@@ -159,6 +154,15 @@ class TRENModel(PreTrainedModel):
|
|
| 159 |
)
|
| 160 |
|
| 161 |
device = pixel_values.device
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]
|
| 163 |
|
| 164 |
with torch.no_grad():
|
|
|
|
| 89 |
# RegionEncoder: the trained T-REN head. HF saves/loads these weights.
|
| 90 |
self.region_encoder = RegionEncoder(cfg)
|
| 91 |
|
| 92 |
+
# Lazy placeholders — not registered as nn.Module submodules so they
|
| 93 |
+
# are excluded from HF save/load. _grid_points is computed on first
|
| 94 |
+
# forward() call to avoid meta-device issues during from_pretrained().
|
| 95 |
+
object.__setattr__(self, "_grid_points", None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
object.__setattr__(self, "_image_encoder", None)
|
| 97 |
object.__setattr__(self, "_text_encoder", None)
|
| 98 |
|
|
|
|
| 154 |
)
|
| 155 |
|
| 156 |
device = pixel_values.device
|
| 157 |
+
|
| 158 |
+
# Build grid on first call (avoids meta-device issues during from_pretrained).
|
| 159 |
+
if self._grid_points is None:
|
| 160 |
+
res = self.config.image_resolution
|
| 161 |
+
ps = self.config.patch_size
|
| 162 |
+
coords = np.linspace(1, res - 2, res // ps, dtype=int)
|
| 163 |
+
object.__setattr__(self, "_grid_points",
|
| 164 |
+
torch.tensor([(y, x) for y in coords for x in coords]))
|
| 165 |
+
|
| 166 |
prompts = [self._grid_points.to(device) for _ in range(pixel_values.shape[0])]
|
| 167 |
|
| 168 |
with torch.no_grad():
|