aryaaan12 committed on
Commit
010834a
·
verified ·
1 Parent(s): b19f6dc

Upload task_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. task_utils.py +151 -0
task_utils.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import itertools
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ import torchvision.transforms as T
8
+ from matplotlib import pyplot as plt
9
+ from sklearn.decomposition import PCA
10
+
11
+
12
class CenterPadding(torch.nn.Module):
    """Pad the trailing dimensions of a tensor up to a multiple of a size.

    Every dimension from index 2 onwards is padded (dimensions 0 and 1 are
    left untouched) so that its length becomes the next multiple of
    ``multiple``. The padding is split between both sides of the dimension,
    with the extra element going to the right when the amount is odd.

    Args:
        multiple: The value each padded dimension must be a multiple of.
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        """Return the ``(left, right)`` padding for a dimension of ``size``."""
        target = math.ceil(size / self.multiple) * self.multiple
        total = target - size
        left = total // 2
        return left, total - left

    @torch.inference_mode()
    def forward(self, x):
        """Return ``x`` padded with ``F.pad`` using its default fill mode."""
        # F.pad expects (left, right) pairs starting from the LAST dimension,
        # hence the reversed walk over the shape that stops before index 1.
        pads = []
        for dim_size in x.shape[:1:-1]:
            pads.extend(self._get_pad(dim_size))
        return F.pad(x, pads)
29
+
30
+
31
def upsample_features(image_features, new_h, new_w, padded_h, padded_w, upsampling_method='bilinear'):
    """Resize a feature map to the padded size, then center-crop it.

    Args:
        image_features: Tensor handed to ``interpolate``.
            NOTE(review): presumably (batch, channels, h, w) - confirm.
        new_h: Height of the center crop applied after resizing.
        new_w: Width of the center crop applied after resizing.
        padded_h: Height the features are interpolated to.
        padded_w: Width the features are interpolated to.
        upsampling_method: Only ``'bilinear'`` is implemented.

    Returns:
        The interpolated features after the ``(new_h, new_w)`` center crop.

    Raises:
        ValueError: If ``upsampling_method`` is not ``'bilinear'``.
    """
    # Reject unsupported methods before doing any work.
    if upsampling_method != 'bilinear':
        raise ValueError(f'{upsampling_method} is not a valid upsampling method.')

    resized = F.interpolate(image_features, size=[padded_h, padded_w], mode='bilinear')
    center_crop = T.CenterCrop((new_h, new_w))
    return center_crop(resized)
39
+
40
+
41
def visualize_features(features, image, save_path):
    """Save a side-by-side plot of an image and a 3-component PCA of features.

    Args:
        features: 3-D tensor that is permuted with ``(1, 2, 0)`` and flattened
            to ``image_height * image_width`` rows, so it must hold exactly
            one feature vector per image pixel.
            NOTE(review): presumably laid out as (channels, h, w) - confirm.
        image: 3-D tensor whose dimensions 1 and 2 are read as height and
            width; it is permuted with ``(1, 2, 0)`` for display.
        save_path: Destination passed to ``plt.savefig``.
    """
    image_height, image_width = image.shape[1], image.shape[2]

    # detach()/cpu() are no-ops for plain CPU tensors but allow CUDA or
    # autograd-tracked inputs, which a bare .numpy() would reject.
    reshaped_features = (
        features.detach().cpu()
        .permute(1, 2, 0)
        .reshape(image_height * image_width, -1)
        .float()
        .numpy()
    )

    pca = PCA(n_components=3)
    pca.fit(reshaped_features)
    pca_features = pca.transform(reshaped_features)

    # Min-max normalisation along the last axis, i.e. across the three PCA
    # components of each pixel (not across pixels).
    # NOTE(review): a pixel with three equal components divides by zero here;
    # behaviour kept as in the original.
    row_min = pca_features.min(axis=-1)[..., None]
    row_max = pca_features.max(axis=-1)[..., None]
    pca_features = (pca_features - row_min) / (row_max - row_min)
    vis_features = pca_features.reshape(image_height, image_width, 3)

    fig = plt.figure()
    try:
        plt.subplot(1, 2, 1)
        plt.imshow(image.detach().cpu().permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(vis_features)
        plt.axis('off')
        plt.savefig(save_path)
    finally:
        # clf() alone leaves the figure registered with pyplot, so every call
        # used to leak one figure; close it explicitly, even on error.
        plt.close(fig)
61
+
62
+
63
def visualize_cosine_similarity(features, images, save_dir, grid_size=64):
    """Save one cosine-similarity map per token for every batch item.

    Features are L2-normalised along dimension 1 and their last two
    dimensions are flattened into a token axis. For each token, its
    similarity to all tokens is reshaped to ``(grid_size, grid_size)`` and
    plotted next to the image, with a red cross on the query token. Files are
    written to ``{save_dir}/batch-{i}/token-{j}.jpg``.

    Args:
        features: Tensor whose dimension 1 is normalised and whose last two
            dimensions are flattened; the result must be 3-D.
            NOTE(review): presumably (batch, channels, h, w) - confirm.
        images: Indexable by batch index; each item is permuted with
            ``(1, 2, 0)`` for display.
        save_dir: Root output directory (created if missing).
        grid_size: Side length of the token grid; the number of tokens must
            equal ``grid_size * grid_size`` for the reshape to succeed.
    """
    os.makedirs(save_dir, exist_ok=True)
    features = F.normalize(features, p=2, dim=1).flatten(-2)
    batch_size, _, num_tokens = features.shape

    for batch_idx in range(batch_size):
        # Per-batch work hoisted out of the token loop: the output directory
        # and the displayed image are identical for every token.
        batch_dir = f'{save_dir}/batch-{batch_idx}'
        os.makedirs(batch_dir, exist_ok=True)
        image_np = images[batch_idx].cpu().permute(1, 2, 0).float().numpy()

        batch_features = features[batch_idx]
        # (tokens, channels) @ (channels, tokens): since the features are
        # unit-norm, each entry is a cosine similarity between two tokens.
        similarity_map = batch_features.t().mm(batch_features)

        for token_idx in range(num_tokens):
            token_similarity_map = similarity_map[token_idx].reshape(grid_size, grid_size)

            # Position of the query token inside the grid.
            row, col = divmod(token_idx, grid_size)

            fig = plt.figure()
            plt.subplot(1, 2, 1)
            plt.imshow(image_np)
            plt.axis('off')
            plt.subplot(1, 2, 2)
            plt.imshow(token_similarity_map.float().detach().cpu().numpy())
            plt.plot(col, row, 'rx', markersize=3, markeredgewidth=2, label='Query token')
            plt.axis('off')
            plt.savefig(f'{batch_dir}/token-{token_idx}.jpg')
            plt.close(fig)
89
+
90
+
91
def visualize_regions(regions, image, save_dir):
    """Save each region mask applied to the image, plus the image itself.

    Writes ``{idx}.jpg`` for every mask in ``regions`` and ``image.jpg`` for
    the unmasked image into ``save_dir``.

    Args:
        regions: Iterable of 2-D masks; each is given a trailing axis and
            multiplied with the displayed image.
            NOTE(review): presumably (h, w) arrays matching the image size
            - confirm against the caller.
        image: Tensor permuted with ``(1, 2, 0)`` and converted with
            ``.numpy()`` for display.
        save_dir: Output directory (created if missing).
    """
    os.makedirs(save_dir, exist_ok=True)

    # The displayed image does not depend on the mask, so convert it once
    # instead of once per region.
    image_np = image.permute(1, 2, 0).numpy()

    # Draw on a dedicated figure so a figure the caller has open is not
    # cleared, and close it afterwards so nothing stays registered in pyplot.
    fig = plt.figure()
    try:
        for idx, mask in enumerate(regions):
            plt.imshow(mask[:, :, None] * image_np)
            plt.axis('off')
            plt.savefig(os.path.join(save_dir, f'{idx}.jpg'))
            plt.clf()

        plt.imshow(image_np)
        plt.axis('off')
        plt.savefig(os.path.join(save_dir, 'image.jpg'))
    finally:
        plt.close(fig)
103
+
104
+
105
def visualize_attn_weights(attn_weights, images, patch_size, grid_points=None, attn_aggregation='max', save_dir='attn_vis'):
    """Save per-query attention maps (aggregated over heads) for each image.

    For every batch item the image is written to
    ``{save_dir}/batch-{i}/image.jpg`` and each query's map to
    ``{save_dir}/batch-{i}/query-{q}.jpg``. A sigmoid is applied to the
    weights before the heads are aggregated.

    Args:
        attn_weights: 4-D tensor unpacked as (batch, heads, queries, last);
            the last dimension is viewed as an
            ``(h // patch_size, w // patch_size)`` grid.
        images: Tensor whose last two dimensions are height and width; each
            batch item is permuted with ``(1, 2, 0)`` for display.
        patch_size: Patch edge length used to derive the grid size.
        grid_points: Optional per-batch, per-query points. Index 1 is plotted
            on the x axis and index 0 on the y axis, both divided by
            ``patch_size``.
            NOTE(review): presumably (row, col) pixel coordinates - confirm.
        attn_aggregation: ``'max'`` or ``'mean'`` over the head axis.
        save_dir: Root output directory.

    Raises:
        ValueError: If ``attn_aggregation`` is neither ``'max'`` nor
            ``'mean'``.
    """
    # Validate up front. Previously an unknown value left the aggregated map
    # unbound (or silently reused the one from the previous query).
    if attn_aggregation == 'max':
        aggregate = np.max
    elif attn_aggregation == 'mean':
        aggregate = np.mean
    else:
        raise ValueError(f'{attn_aggregation} is not a valid attention aggregation method.')

    batch_size, num_heads, num_q, _ = attn_weights.shape
    h, w = images.shape[-2:]

    # Loop-invariant: reshape the flat key axis into the patch grid once.
    attn_weights = attn_weights.view(batch_size, num_heads, num_q, h // patch_size, w // patch_size)

    for batch_idx in range(images.shape[0]):
        batch_dir = f'{save_dir}/batch-{batch_idx}'
        os.makedirs(batch_dir, exist_ok=True)
        plt.imshow(images[batch_idx].permute(1, 2, 0).detach().cpu().numpy())
        plt.axis('off')
        plt.savefig(f'{batch_dir}/image.jpg')
        plt.clf()

        for q_idx in range(num_q):
            # torch.sigmoid replaces the deprecated F.sigmoid alias.
            attn_map = torch.sigmoid(attn_weights[batch_idx, :, q_idx]).detach().cpu().numpy()
            combined_attn_map = aggregate(attn_map, axis=0)
            plt.imshow(combined_attn_map)
            plt.axis('off')
            if grid_points is not None:
                point = grid_points[batch_idx][q_idx]
                # Scale the point down to attention-grid coordinates.
                plt.scatter([point[1] / patch_size], [point[0] / patch_size],
                            marker='o', s=20, c='red')
            plt.savefig(f'{batch_dir}/query-{q_idx}.jpg')
            plt.close()
132
+
133
+
134
def pad_or_truncate_tokens(tokens, pad_length, pad_value):
    """Force a 2-D token tensor to have exactly ``pad_length`` rows.

    Args:
        tokens: 2-D tensor of shape ``(current_length, dim_size)``.
        pad_length: Number of rows the result must have.
        pad_value: Fill value for rows appended when ``tokens`` is too short.

    Returns:
        The first ``pad_length`` rows when ``tokens`` is longer; ``tokens``
        followed by rows of ``pad_value`` (same dtype and device) when it is
        shorter; ``tokens`` itself when it already has ``pad_length`` rows.
    """
    current_length, dim_size = tokens.shape

    if current_length > pad_length:
        return tokens[:pad_length]

    if current_length < pad_length:
        padding = torch.full((pad_length - current_length, dim_size), pad_value,
                             dtype=tokens.dtype, device=tokens.device)
        return torch.cat([tokens, padding], dim=0)

    # Already the requested length. The original fell off the end here and
    # returned None.
    return tokens
144
+
145
+
146
def print_log(log_str, save_dir=None):
    """Print a message and optionally append it to ``<save_dir>/log.txt``.

    Args:
        log_str: Message to log. Non-string values are formatted the same way
            ``print`` shows them.
        save_dir: Directory holding ``log.txt``. When ``None`` the message is
            only printed. The directory is created if it does not exist.
    """
    print(log_str)
    if save_dir is None:
        return

    # An empty save_dir means the current directory; makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    log_file = os.path.join(save_dir, 'log.txt')
    # Explicit encoding keeps the log readable across platforms; the f-string
    # avoids a TypeError when log_str is not a str.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f'{log_str}\n')