Spaces:
Runtime error
Runtime error
File size: 4,676 Bytes
c75adab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
def generate_position_tensor(h, w, box_xy, raw_h, raw_w):
    """Build an (h, w, 2) grid of normalized (x, y) coordinates.

    Each cell of an h x w feature grid is mapped back into the original
    image frame: first scaled into the crop box `box_xy = (x0, y0, x1, y1)`,
    then normalized by the raw image size (raw_w - 1, raw_h - 1).

    Returns:
        Tensor of shape (h, w, 2); last dim is (x, y).
    """
    x0, y0, x1, y1 = box_xy[0], box_xy[1], box_xy[2], box_xy[3]
    # Column coordinates: fractional position in the grid -> box -> raw frame.
    xs = torch.arange(w, dtype=torch.float32) / w
    xs = (xs * (x1 - x0) + x0) / (raw_w - 1)
    # Row coordinates, same mapping along the vertical axis.
    ys = torch.arange(h, dtype=torch.float32) / h
    ys = (ys * (y1 - y0) + y0) / (raw_h - 1)
    grid_x = xs.unsqueeze(0).expand(h, w)
    grid_y = ys.unsqueeze(1).expand(h, w)
    return torch.stack([grid_x, grid_y], dim=2)
class MaskExtractor(nn.Module):
    """Extracts per-mask region features from image feature maps.

    Each mask is pooled over position-augmented image features via
    MaskPooling, and the pooled tokens are projected to the language-model
    hidden size by a small MLP.
    """
    def __init__(self, config, mm_hidden_size, depth=2):
        """
        Args:
            config: model config; only ``config.hidden_size`` is read.
            mm_hidden_size: channel dim of the multimodal image features.
            depth: number of Linear layers in the projection MLP (>= 1).
        """
        super().__init__()
        self.mask_pooling = MaskPooling()
        # Projection MLP: Linear(mm -> hidden), then (depth-1) x [GELU, Linear].
        modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        self.feat_linear = nn.Sequential(*modules)
        # Maps normalized 2-D (x, y) coordinates to a positional embedding
        # that is added onto the image features before pooling.
        self.pos_linear = nn.Linear(2, mm_hidden_size)
    def forward(self, feats, masks, box_params, mask_num):
        """Pool and project region features for every mask of every image.

        Args:
            feats: per-region feature maps, indexed by a running counter
                (NOTE(review): the counter advances once per *mask*, so feats
                is presumably one feature map per (image, mask) pair — confirm
                against the caller).
            masks: per-image lists of binary masks; an entry may be None.
            box_params: per-mask (box_xy, raw_h, raw_w) tuples, parallel to masks.
            mask_num: token budget forwarded to MaskPooling.

        Returns:
            Projected mask features concatenated over all masks, or None when
            masks is None or no valid mask was found.
        """
        query_feats = []
        if masks is None:  # inference path without regions
            return None
        num_imgs = len(masks)
        image_idx = 0
        for idx in range(num_imgs):
            # Use identity check: `== None` on a tensor is fragile.
            if masks[idx] is None:
                continue
            for mask_idx in range(len(masks[idx])):
                mask = masks[idx][mask_idx].unsqueeze(0).unsqueeze(0).float()
                box_param = box_params[idx][mask_idx]
                box_xy, raw_h, raw_w = box_param
                # NOTE(review): for a (1, 1, H, W) mask, len(mask[0]) is always 1,
                # so this guard looks unreachable — possibly meant len(mask[0][0]).
                if len(mask[0])==0:
                    print('mask error')
                    mask = torch.zeros((1, 1, 336, 336)).to(feats.device).float()
                feat = feats[image_idx].unsqueeze(0)
                # Add a positional embedding derived from the crop-box geometry.
                pos_emb = generate_position_tensor(feat.shape[1], feat.shape[2], box_xy, raw_h, raw_w).unsqueeze(0).to(feat)
                pos_emb = self.pos_linear(pos_emb)
                feat = feat + pos_emb
                image_idx += 1
                feat = feat.permute(0, 3, 1, 2)  # (b, h, w, c) -> (b, c, h, w)
                feat = feat.to(mask.dtype)
                mask_feat_raw = self.mask_pooling(feat, mask, mask_token_num=mask_num)  # [n, mm_hidden_size]
                query_feats.append(mask_feat_raw)
        if len(query_feats) == 0:
            return None
        mask_feats = torch.cat(query_feats, dim=0)
        mask_feats = mask_feats.to(feats[0].dtype)
        mask_feats_linear = self.feat_linear(mask_feats)
        return mask_feats_linear
def kmeans_fast(tokens, num_clusters=10, num_iterations=5):
    """Cluster `tokens` ([n, d]) into `num_clusters` centroids with k-means.

    Centroids are initialized from a random permutation of the tokens and
    refined for at most `num_iterations` rounds, stopping early once the
    centroids converge. An empty cluster keeps its previous centroid.

    Returns:
        Tensor of shape [num_clusters, d] (fewer rows if n < num_clusters).
    """
    n, _ = tokens.shape
    centers = tokens[torch.randperm(n)[:num_clusters]]
    for _ in range(num_iterations):
        # Squared euclidean distance of every token to every centroid.
        diffs = tokens.unsqueeze(1) - centers.unsqueeze(0)  # [n, k, d]
        dist_sq = (diffs ** 2).sum(dim=2)                   # [n, k]
        assignment = dist_sq.argmin(dim=1)                  # [n]
        updated = []
        for k in range(num_clusters):
            members = tokens[assignment == k]
            updated.append(members.mean(dim=0) if members.size(0) > 0 else centers[k])
        updated = torch.stack(updated)
        # Converged: return the pre-update centroids (matches early-exit contract).
        if torch.allclose(centers, updated, atol=1e-6):
            break
        centers = updated
    return centers
class MaskPooling(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, mask, mask_token_num=10):
if not x.shape[-2:] == mask.shape[-2:]:
# reshape mask to x
x = F.interpolate(x, size=mask.shape[-2:], mode='bilinear', align_corners=False)
# mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
if not x.device == mask.device:
mask = mask.to(x.device)
# b, c, h ,w = x.shape
# b, q, h, w = mask.shape
mask = (mask > 0).to(mask.dtype)
mask = mask.permute(1,0,2,3)
denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
mask_emb = x * mask
mask = torch.any(mask_emb != 0, dim=(0, 1))
mask_emb = mask_emb[:,:, mask]
mask_embedding = mask_emb[0].permute(1,0)
if len(mask_embedding)>mask_token_num: #FIXME
mask_embedding = kmeans_fast(mask_embedding, mask_token_num)
return mask_embedding
def build_region_encoder(config, mm_hidden_size):
    """Factory: build the region (mask) feature encoder.

    Args:
        config: model config forwarded to MaskExtractor (its hidden_size is used).
        mm_hidden_size: channel dim of the multimodal image features.

    Returns:
        A MaskExtractor with the default projection-MLP depth.
    """
    return MaskExtractor(config, mm_hidden_size)
|