Image Segmentation
Transformers
Safetensors
PyTorch
English
tren
feature-extraction
vision
image-feature-extraction
region-tokens
dinov3
custom_code
Instructions to use aryaaan12/T-REN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use aryaaan12/T-REN with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("image-segmentation", model="aryaaan12/T-REN", trust_remote_code=True)# Load model directly from transformers import AutoModel model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload modeling_tren.py with huggingface_hub
Browse files- modeling_tren.py +37 -0
modeling_tren.py
CHANGED
|
@@ -122,6 +122,43 @@ class TRENModel(PreTrainedModel):
|
|
| 122 |
object.__setattr__(self, "_image_encoder", image_encoder)
|
| 123 |
object.__setattr__(self, "_text_encoder", text_encoder)
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
def forward(
|
| 126 |
self,
|
| 127 |
pixel_values: torch.Tensor,
|
|
|
|
| 122 |
object.__setattr__(self, "_image_encoder", image_encoder)
|
| 123 |
object.__setattr__(self, "_text_encoder", text_encoder)
|
| 124 |
|
| 125 |
+
def adapt_to_resolution(self, image_resolution: int) -> None:
|
| 126 |
+
"""
|
| 127 |
+
Interpolate the RegionEncoder's positional embeddings to a new spatial
|
| 128 |
+
resolution. Call this after from_pretrained() when running inference at
|
| 129 |
+
a resolution different from the training resolution (512px by default).
|
| 130 |
+
|
| 131 |
+
Args:
|
| 132 |
+
image_resolution: Target image resolution in pixels (e.g. 384).
|
| 133 |
+
|
| 134 |
+
Example::
|
| 135 |
+
|
| 136 |
+
model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True)
|
| 137 |
+
model.load_backbone("/path/to/dinov3/weights/")
|
| 138 |
+
model.adapt_to_resolution(384) # eval at 384px instead of 512px
|
| 139 |
+
"""
|
| 140 |
+
if image_resolution == self.config.image_resolution:
|
| 141 |
+
return
|
| 142 |
+
|
| 143 |
+
saved_state = self.region_encoder.state_dict()
|
| 144 |
+
device = next(self.region_encoder.parameters()).device
|
| 145 |
+
ps = self.config.patch_size
|
| 146 |
+
num_patches = (image_resolution // ps) ** 2
|
| 147 |
+
C = self.region_encoder.feature_embeddings.shape[-1]
|
| 148 |
+
|
| 149 |
+
self.region_encoder.feature_embeddings = torch.nn.Parameter(
|
| 150 |
+
torch.zeros(num_patches, C, device=device)
|
| 151 |
+
)
|
| 152 |
+
self.region_encoder.load_state_dict_resolution_agnostic(saved_state)
|
| 153 |
+
self.region_encoder.to(device)
|
| 154 |
+
|
| 155 |
+
# Reset grid so it is rebuilt at the new resolution on the next forward().
|
| 156 |
+
object.__setattr__(self, "_grid_points", None)
|
| 157 |
+
|
| 158 |
+
logger.info(
|
| 159 |
+
f"Adapted positional embeddings: {self.config.image_resolution}px → {image_resolution}px"
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
def forward(
|
| 163 |
self,
|
| 164 |
pixel_values: torch.Tensor,
|