Image Segmentation
Transformers
Safetensors
PyTorch
English
tren
feature-extraction
vision
image-feature-extraction
region-tokens
dinov3
custom_code
Instructions to use aryaaan12/T-REN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use aryaaan12/T-REN with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("image-segmentation", model="aryaaan12/T-REN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("aryaaan12/T-REN", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 5,778 Bytes
010834a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import os
import math
import itertools
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from matplotlib import pyplot as plt
from sklearn.decomposition import PCA
class CenterPadding(torch.nn.Module):
    """Pad every dimension after the first two up to a multiple of ``multiple``.

    The padding is split as evenly as possible between both sides of a
    dimension; when the total is odd the extra element goes to the trailing
    side. Padding uses the defaults of ``F.pad``.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(before, after)`` padding that rounds ``size`` up to the multiple."""
        target = self.multiple * math.ceil(size / self.multiple)
        total = target - size
        before = total // 2
        return before, total - before

    @torch.inference_mode()
    def forward(self, x):
        """Return ``x`` padded so each dimension from index 2 on is a multiple."""
        # F.pad expects (left, right) pairs starting from the LAST dimension,
        # hence the reversed walk over the trailing dimensions.
        pads = []
        for dim_size in reversed(x.shape[2:]):
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)
def upsample_features(image_features, new_h, new_w, padded_h, padded_w, upsampling_method='bilinear'):
    """Resize a feature map to the padded size, then center-crop it.

    Args:
        image_features: Tensor passed to ``interpolate``; with
            ``mode='bilinear'`` and a 2-element ``size`` this has to be 4-D.
        new_h, new_w: Output height and width after the center crop.
        padded_h, padded_w: Intermediate size the features are interpolated to.
        upsampling_method: Only ``'bilinear'`` is supported.

    Returns:
        The interpolated features cropped with ``T.CenterCrop((new_h, new_w))``.

    Raises:
        ValueError: If ``upsampling_method`` is not ``'bilinear'``.
    """
    if upsampling_method != 'bilinear':
        raise ValueError(f'{upsampling_method} is not a valid upsampling method.')

    resized = F.interpolate(image_features, size=[padded_h, padded_w], mode='bilinear')
    center_crop = T.CenterCrop((new_h, new_w))
    return center_crop(resized)
def visualize_features(features, image, save_path):
    """Save a side-by-side figure of an image and a 3-component PCA of its features.

    Args:
        features: Tensor that is permuted with ``(1, 2, 0)`` and flattened to
            ``image_height * image_width`` rows, i.e. it must provide one
            feature vector per pixel of ``image`` (presumably ``(C, H, W)``).
        image: Tensor whose dims 1 and 2 are read as height and width and that
            is displayed after ``permute(1, 2, 0)`` (presumably ``(3, H, W)``).
        save_path: Destination handed to ``plt.savefig``.
    """
    image_height, image_width = image.shape[1], image.shape[2]

    # detach()/cpu() make the NumPy conversion work for CUDA tensors and for
    # tensors that still track gradients.
    reshaped_features = (
        features.detach().cpu()
        .permute(1, 2, 0)
        .reshape(image_height * image_width, -1)
        .float()
        .numpy()
    )

    pca = PCA(n_components=3)
    pca.fit(reshaped_features)
    pca_features = pca.transform(reshaped_features)

    # NOTE(review): min/max are taken per pixel across the 3 components
    # (axis=-1), not per component across pixels - confirm this is intended.
    per_pixel_min = pca_features.min(axis=-1)[..., None]
    per_pixel_max = pca_features.max(axis=-1)[..., None]
    pca_features = (pca_features - per_pixel_min) / (per_pixel_max - per_pixel_min)
    vis_features = pca_features.reshape(image_height, image_width, 3)

    fig = plt.figure()
    try:
        plt.subplot(1, 2, 1)
        plt.imshow(image.detach().cpu().permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(vis_features)
        plt.axis('off')
        plt.savefig(save_path)
    finally:
        # Close (not just clear) the figure so repeated calls do not
        # accumulate open figures.
        plt.close(fig)
def visualize_cosine_similarity(features, images, save_dir, grid_size=64):
    """Save, for every token, its cosine-similarity map next to the input image.

    Features are L2-normalized along dim 1 and their last two dims are
    flattened into a token axis, so the token-by-token matrix product is a
    cosine similarity. One figure per token is written to
    ``{save_dir}/batch-{batch_idx}/token-{token_idx}.jpg`` with a red cross
    marking the query token.

    Args:
        features: Tensor that is 3-D after flattening its last two dims
            (unpacked as batch, channels, tokens).
        images: Indexable by batch index; each item is displayed after
            ``permute(1, 2, 0)``.
        save_dir: Root output directory (created if missing).
        grid_size: Each similarity row is reshaped to
            ``(grid_size, grid_size)``, so the token count must equal
            ``grid_size ** 2``.
    """
    os.makedirs(save_dir, exist_ok=True)
    features = F.normalize(features, p=2, dim=1).flatten(-2)
    batch_size, _, num_tokens = features.shape

    for batch_idx in range(batch_size):
        # Everything below up to the token loop is invariant per batch item,
        # so it is done once instead of once per token.
        batch_dir = f'{save_dir}/batch-{batch_idx}'
        os.makedirs(batch_dir, exist_ok=True)
        image_np = images[batch_idx].detach().cpu().permute(1, 2, 0).float().numpy()

        batch_features = features[batch_idx]
        similarity_map = batch_features.t().mm(batch_features)
        # Single device-to-host transfer for the whole similarity matrix.
        similarity_np = similarity_map.float().detach().cpu().numpy()

        for token_idx in range(num_tokens):
            token_similarity_map = similarity_np[token_idx].reshape(grid_size, grid_size)
            row, col = divmod(token_idx, grid_size)

            fig = plt.figure()
            plt.subplot(1, 2, 1)
            plt.imshow(image_np)
            plt.axis('off')
            plt.subplot(1, 2, 2)
            plt.imshow(token_similarity_map)
            plt.plot(col, row, 'rx', markersize=3, markeredgewidth=2, label='Query token')
            plt.axis('off')
            plt.savefig(f'{batch_dir}/token-{token_idx}.jpg')
            plt.close(fig)
def visualize_regions(regions, image, save_dir):
    """Write one masked view of ``image`` per region, plus the plain image.

    Each mask is broadcast over the channel axis and multiplied with the
    image; the result is saved as ``{idx}.jpg`` inside ``save_dir``. The
    unmasked image is saved last as ``image.jpg``.

    Args:
        regions: Iterable of 2-D masks (each is indexed as ``mask[:, :, None]``).
        image: Tensor displayed after ``permute(1, 2, 0)``.
        save_dir: Output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)
    image_hwc = image.permute(1, 2, 0).numpy()

    def _save_and_clear(file_name):
        # Shared tail of every plot: hide axes, write the file, reset the figure.
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, file_name))
        plt.clf()

    for region_idx, region_mask in enumerate(regions):
        plt.imshow(region_mask[:, :, None] * image_hwc)
        _save_and_clear(f'{region_idx}.jpg')

    plt.imshow(image_hwc)
    _save_and_clear('image.jpg')
def visualize_attn_weights(attn_weights, images, patch_size, grid_points=None, attn_aggregation='max', save_dir='attn_vis'):
    """Save per-query attention maps, aggregated over heads, for each image.

    For every batch item the input image is written to
    ``{save_dir}/batch-{batch_idx}/image.jpg`` and one map per query to
    ``{save_dir}/batch-{batch_idx}/query-{q_idx}.jpg``.

    Args:
        attn_weights: 4-D tensor unpacked as (batch, heads, queries, last);
            the last dim is reshaped to
            ``(h // patch_size, w // patch_size)``. A sigmoid is applied
            before plotting.
        images: Tensor whose last two dims are the image height and width;
            each item is displayed after ``permute(1, 2, 0)``.
        patch_size: Divisor mapping pixel coordinates to attention-grid cells.
        grid_points: Optional; ``grid_points[batch][query]`` is indexed with
            ``[1]`` for the x coordinate and ``[0]`` for the y coordinate of a
            red marker, both divided by ``patch_size``.
        attn_aggregation: ``'max'`` or ``'mean'`` reduction over heads.
        save_dir: Root output directory (created if missing).

    Raises:
        ValueError: If ``attn_aggregation`` is neither ``'max'`` nor ``'mean'``.
    """
    # Validate before doing any work: previously an unknown value left the
    # aggregated map unbound and failed (or reused a stale map) mid-loop.
    if attn_aggregation == 'max':
        aggregate = np.max
    elif attn_aggregation == 'mean':
        aggregate = np.mean
    else:
        raise ValueError(f'{attn_aggregation} is not a valid attention aggregation method.')

    batch_size, num_heads, num_q, _ = attn_weights.shape
    h, w = images.shape[-2:]
    # Loop-invariant, so reshaped once; reshape (unlike view) also accepts
    # non-contiguous input.
    attn_weights = attn_weights.reshape(batch_size, num_heads, num_q, h // patch_size, w // patch_size)

    for batch_idx in range(images.shape[0]):
        batch_dir = f'{save_dir}/batch-{batch_idx}'
        os.makedirs(batch_dir, exist_ok=True)

        fig = plt.figure()
        plt.imshow(images[batch_idx].permute(1, 2, 0).detach().cpu().numpy())
        plt.axis('off')
        plt.savefig(f'{batch_dir}/image.jpg')
        plt.close(fig)

        # One sigmoid + host transfer per batch item: (heads, queries, gh, gw).
        batch_attn = torch.sigmoid(attn_weights[batch_idx]).detach().float().cpu().numpy()

        for q_idx in range(num_q):
            combined_attn_map = aggregate(batch_attn[:, q_idx], axis=0)

            fig = plt.figure()
            plt.imshow(combined_attn_map)
            plt.axis('off')
            if grid_points is not None:
                point = grid_points[batch_idx][q_idx]
                # float() lets tensor-valued coordinates (any device) be plotted.
                plt.scatter([float(point[1]) / patch_size], [float(point[0]) / patch_size],
                            marker='o', s=20, c='red')
            plt.savefig(f'{batch_dir}/query-{q_idx}.jpg')
            plt.close(fig)
def pad_or_truncate_tokens(tokens, pad_length, pad_value):
    """Return ``tokens`` with exactly ``pad_length`` rows.

    Args:
        tokens: 2-D tensor (rows, dim).
        pad_length: Desired number of rows.
        pad_value: Fill value for rows appended when ``tokens`` is too short.

    Returns:
        The first ``pad_length`` rows when ``tokens`` is long enough (which
        includes the exact-length case), otherwise ``tokens`` followed by
        rows filled with ``pad_value`` in the same dtype and on the same
        device.
    """
    current_length, dim_size = tokens.shape

    # ">=" covers the exact-length case, which previously matched neither
    # branch and so produced no tensor at all.
    if current_length >= pad_length:
        return tokens[:pad_length]

    padding = torch.full((pad_length - current_length, dim_size), pad_value,
                         dtype=tokens.dtype, device=tokens.device)
    return torch.cat([tokens, padding], dim=0)
def print_log(log_str, save_dir=None):
    """Print a message and optionally append it to ``log.txt`` in ``save_dir``.

    Args:
        log_str: Message to print; it is written to the file followed by a
            newline.
        save_dir: Directory holding ``log.txt``. When ``None`` the message is
            only printed. The directory is created if it does not exist.
    """
    print(log_str)
    if save_dir is None:
        return

    # Create the directory first so the first log line of a fresh run does
    # not fail with FileNotFoundError.
    os.makedirs(save_dir, exist_ok=True)
    log_file = os.path.join(save_dir, 'log.txt')
    # Explicit encoding keeps the log readable regardless of platform locale.
    with open(log_file, 'a', encoding='utf-8') as f:
        # f-string formatting accepts the same non-string messages print() does.
        f.write(f'{log_str}\n')