Image Segmentation
Transformers
Safetensors
PyTorch
English
tren
feature-extraction
vision
image-feature-extraction
region-tokens
dinov3
custom_code
Instructions to use aryaaan12/T-REN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use aryaaan12/T-REN with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="aryaaan12/T-REN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import math | |
| import itertools | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms as T | |
| from matplotlib import pyplot as plt | |
| from sklearn.decomposition import PCA | |
class CenterPadding(torch.nn.Module):
    """Pad the trailing dimensions of a tensor up to a multiple of a size.

    Every dimension from index 2 onward is extended with ``F.pad`` (default
    constant fill) until its length is the next multiple of ``multiple``.
    The padding is split between both sides of the dimension; when the
    amount is odd, the extra element goes on the trailing side.
    """

    def __init__(self, multiple):
        """Store the value each padded dimension must be divisible by."""
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return the ``(before, after)`` padding that rounds ``size`` up."""
        target = math.ceil(size / self.multiple) * self.multiple
        total = target - size
        before = total // 2
        return before, total - before

    def forward(self, x):
        """Return ``x`` padded on every dimension after the first two."""
        # F.pad expects (before, after) pairs ordered from the last dimension
        # backwards, so the trailing dimensions are visited in reverse.
        pads = []
        for dim_size in reversed(x.shape[2:]):
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)
def upsample_features(image_features, new_h, new_w, padded_h, padded_w, upsampling_method='bilinear'):
    """Resize a feature map to the padded size, then center-crop it.

    Args:
        image_features: Feature tensor passed to ``F.interpolate`` in
            bilinear mode (which requires a 4-D input).
        new_h: Height of the center crop taken after resizing.
        new_w: Width of the center crop taken after resizing.
        padded_h: Height the features are interpolated to.
        padded_w: Width the features are interpolated to.
        upsampling_method: Interpolation method; only ``'bilinear'`` is
            supported.

    Returns:
        The interpolated features, center-cropped to ``(new_h, new_w)``.

    Raises:
        ValueError: If ``upsampling_method`` is not ``'bilinear'``.
    """
    if upsampling_method != 'bilinear':
        raise ValueError(f'{upsampling_method} is not a valid upsampling method.')

    resized = F.interpolate(image_features, size=[padded_h, padded_w], mode='bilinear')
    crop = T.CenterCrop((new_h, new_w))
    return crop(resized)
def visualize_features(features, image, save_path):
    """Save a side-by-side figure of an image and a PCA view of its features.

    The features are projected onto three PCA components, rescaled, and
    shown as an RGB image next to the input image.

    Args:
        features: Tensor laid out as (channels, height, width) whose height
            and width match ``image``; it is flattened to one row per pixel.
        image: Tensor laid out as (channels, height, width); it is permuted
            to (height, width, channels) for display.
        save_path: Destination passed to ``plt.savefig``.
    """
    image_height, image_width = image.shape[1], image.shape[2]

    # Detach and move to host memory first so tensors that live on an
    # accelerator or track gradients can still be converted to NumPy.
    reshaped_features = (
        features.detach().cpu().permute(1, 2, 0)
        .reshape(image_height * image_width, -1).float().numpy()
    )

    pca = PCA(n_components=3)
    pca.fit(reshaped_features)
    pca_features = pca.transform(reshaped_features)

    # Min-max scaling runs across the three components of each pixel
    # (axis=-1), not across pixels.
    # NOTE(review): scaling each component over all pixels (axis=0) is the
    # more common choice, and a pixel whose three components are equal
    # divides by zero here -- confirm this normalisation is intended.
    lowest = pca_features.min(axis=-1)[..., None]
    highest = pca_features.max(axis=-1)[..., None]
    pca_features = (pca_features - lowest) / (highest - lowest)
    vis_features = pca_features.reshape(image_height, image_width, 3)

    fig = plt.figure()
    try:
        plt.subplot(1, 2, 1)
        plt.imshow(image.detach().cpu().permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(vis_features)
        plt.axis('off')
        plt.savefig(save_path)
    finally:
        # clf() alone keeps the figure registered with pyplot, so repeated
        # calls would accumulate open figures; close it explicitly.
        plt.close(fig)
def visualize_cosine_similarity(features, images, save_dir, grid_size=64):
    """Save one cosine-similarity map per token for every batch element.

    For each token, the figure shows the input image on the left and, on
    the right, that token's similarity to all tokens reshaped to a
    ``grid_size x grid_size`` map with the query token marked by a red cross.
    Files are written to ``{save_dir}/batch-{b}/token-{t}.jpg``.

    Args:
        features: 4-D tensor; it is L2-normalised along dim 1 and its last
            two dims are flattened into the token axis, whose length must
            equal ``grid_size * grid_size``.
        images: Indexable batch of (channels, height, width) tensors.
        save_dir: Root output directory; created if missing.
        grid_size: Side length of the square token grid.
    """
    os.makedirs(save_dir, exist_ok=True)
    features = F.normalize(features, p=2, dim=1).flatten(-2)
    batch_size, _, num_tokens = features.shape

    for batch_idx in range(batch_size):
        # Everything below is identical for all tokens of this batch
        # element, so compute it once instead of once per token.
        batch_dir = f'{save_dir}/batch-{batch_idx}'
        os.makedirs(batch_dir, exist_ok=True)
        image_np = images[batch_idx].detach().cpu().permute(1, 2, 0).float().numpy()
        tokens = features[batch_idx]
        # Unit-norm columns make the Gram matrix the pairwise cosine
        # similarity; transfer it to host memory in a single copy.
        similarity_map = tokens.t().mm(tokens).float().detach().cpu().numpy()

        for token_idx in range(num_tokens):
            token_similarity_map = similarity_map[token_idx].reshape(grid_size, grid_size)
            row, col = divmod(token_idx, grid_size)

            fig = plt.figure()
            plt.subplot(1, 2, 1)
            plt.imshow(image_np)
            plt.axis('off')
            plt.subplot(1, 2, 2)
            plt.imshow(token_similarity_map)
            plt.plot(col, row, 'rx', markersize=3, markeredgewidth=2, label='Query token')
            plt.axis('off')
            plt.savefig(f'{batch_dir}/token-{token_idx}.jpg')
            plt.close(fig)
def visualize_regions(regions, image, save_dir):
    """Write one masked copy of the image per region, plus the image itself.

    Region ``i`` is saved as ``{i}.jpg`` (the image multiplied by the mask)
    and the unmasked image as ``image.jpg``, all inside ``save_dir``.

    Args:
        regions: Iterable of 2-D masks; each gets a trailing axis added so it
            multiplies against the (height, width, channels) image.
        image: Tensor laid out as (channels, height, width).
        save_dir: Output directory; created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)
    image_hwc = image.permute(1, 2, 0).numpy()

    def _save_current(file_name):
        # Hide the axes, write the current figure, then clear it for reuse.
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, file_name))
        plt.clf()

    for idx, mask in enumerate(regions):
        plt.imshow(mask[:, :, None] * image_hwc)
        _save_current(f'{idx}.jpg')

    plt.imshow(image_hwc)
    _save_current('image.jpg')
def visualize_attn_weights(attn_weights, images, patch_size, grid_points=None, attn_aggregation='max', save_dir='attn_vis'):
    """Save per-query attention maps for every image in a batch.

    For batch element ``b`` the input image is written to
    ``{save_dir}/batch-{b}/image.jpg`` and each query's map to
    ``{save_dir}/batch-{b}/query-{q}.jpg``. The weights are passed through a
    sigmoid and the heads are combined with ``attn_aggregation``.

    Args:
        attn_weights: Tensor of shape (batch, heads, queries, keys), where
            keys equals ``(h // patch_size) * (w // patch_size)`` for images
            of spatial size ``(h, w)``.
        images: Tensor whose first dim is the batch and whose last two dims
            are height and width; each element is permuted to
            (height, width, channels) for display.
        patch_size: Side length of a patch in pixels.
        grid_points: Optional, indexed as ``grid_points[batch][query]``.
            Index 1 is plotted on the x axis and index 0 on the y axis, both
            divided by ``patch_size`` (presumably (y, x) pixel coordinates --
            verify against the caller).
        attn_aggregation: ``'max'`` or ``'mean'``; how heads are combined.
        save_dir: Root output directory; created if missing.

    Raises:
        ValueError: If ``attn_aggregation`` is not ``'max'`` or ``'mean'``.
    """
    # Validate before any I/O: an unknown method used to leave the combined
    # map unbound (or silently reuse the previous query's map).
    aggregators = {'max': np.max, 'mean': np.mean}
    if attn_aggregation not in aggregators:
        raise ValueError(f'{attn_aggregation} is not a valid attention aggregation method.')
    aggregate = aggregators[attn_aggregation]

    batch_size, num_heads, num_q, _ = attn_weights.shape
    h, w = images.shape[-2:]
    # Unflatten the key axis into the patch grid once; reshape (unlike view)
    # also accepts non-contiguous input.
    attn_weights = attn_weights.reshape(batch_size, num_heads, num_q, h // patch_size, w // patch_size)

    for batch_idx in range(images.shape[0]):
        batch_dir = f'{save_dir}/batch-{batch_idx}'
        os.makedirs(batch_dir, exist_ok=True)

        plt.imshow(images[batch_idx].permute(1, 2, 0).detach().cpu().numpy())
        plt.axis('off')
        plt.savefig(f'{batch_dir}/image.jpg')
        plt.clf()

        for q_idx in range(num_q):
            # Cast to float32 so reduced-precision dtypes that NumPy cannot
            # represent (e.g. bfloat16) convert cleanly.
            attn_map = torch.sigmoid(attn_weights[batch_idx, :, q_idx]).float().detach().cpu().numpy()
            combined_attn_map = aggregate(attn_map, axis=0)

            plt.imshow(combined_attn_map)
            plt.axis('off')
            if grid_points is not None:
                point = grid_points[batch_idx][q_idx]
                plt.scatter([point[1] / patch_size], [point[0] / patch_size],
                            marker='o', s=20, c='red')
            plt.savefig(f'{batch_dir}/query-{q_idx}.jpg')
            plt.close()
def pad_or_truncate_tokens(tokens, pad_length, pad_value):
    """Return ``tokens`` with exactly ``pad_length`` rows.

    Args:
        tokens: 2-D tensor of shape (length, dim).
        pad_length: Number of rows the result must have.
        pad_value: Fill value for rows appended when ``tokens`` is too short.

    Returns:
        The first ``pad_length`` rows when ``tokens`` is longer, ``tokens``
        itself when it already has ``pad_length`` rows, and otherwise
        ``tokens`` followed by rows of ``pad_value`` with the same dtype and
        device.
    """
    current_length, dim_size = tokens.shape
    if current_length > pad_length:
        return tokens[:pad_length]
    if current_length == pad_length:
        # Previously this case fell through both branches and returned None.
        return tokens
    padding = torch.full((pad_length - current_length, dim_size), pad_value,
                         dtype=tokens.dtype, device=tokens.device)
    return torch.cat([tokens, padding], dim=0)
def print_log(log_str, save_dir=None):
    """Print a message and optionally append it to ``save_dir/log.txt``.

    Args:
        log_str: Message to log. Any object ``print`` accepts is allowed; it
            is formatted with ``str()`` for the file.
        save_dir: Directory holding ``log.txt``. When ``None`` the message is
            only printed. The directory is created if it does not exist.
    """
    print(log_str)
    if save_dir is None:
        return

    # Create the directory on demand so the first log call of a run does not
    # fail just because the output folder has not been made yet.
    os.makedirs(save_dir, exist_ok=True)
    log_file = os.path.join(save_dir, 'log.txt')
    # UTF-8 is set explicitly so non-ASCII messages do not depend on the
    # platform's default encoding.
    with open(log_file, 'a', encoding='utf-8') as f:
        # Format rather than concatenate so non-string messages, which
        # print() already handled above, do not raise TypeError here.
        f.write(f'{log_str}\n')