Upload cluster_visualize.py
Browse files- cluster_visualize.py +176 -0
cluster_visualize.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# --------------------------------------------------------
|
| 2 |
+
# This file is a modified version of https://github.com/ma-xu/Context-Cluster/blob/main/cluster_visualize.py
|
| 3 |
+
# It is modified in order to make it compatible with Gradio.
|
| 4 |
+
# --------------------------------------------------------
|
| 5 |
+
|
| 6 |
+
import context_cluster.models as models
|
| 7 |
+
import timm
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
import argparse
|
| 11 |
+
import cv2
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torchvision.transforms.functional as TransF
|
| 15 |
+
from torchvision import transforms
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
import random
|
| 18 |
+
from timm.models import load_checkpoint
|
| 19 |
+
from torchvision.utils import draw_segmentation_masks
|
| 20 |
+
|
| 21 |
+
# Labels read from the id-to-label file, in file order. Each non-empty line is
# expected to look like "<id>:<label>"; only the label part is kept.
object_categories = []
with open("./context_cluster/imagenet1k_id_to_label.txt", "r") as f:
    for line in f:
        entry = line.strip()
        if not entry:
            # Tolerate blank lines (e.g. a trailing empty line at end of file).
            continue
        # Split on the first ':' only, so a label that itself contains a colon
        # does not break the two-way unpacking.
        _, val = entry.split(":", 1)
        object_categories.append(val)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class PredictionArgs:
    def __init__(
        self,
        model,
        checkpoint,
        image,
        shape=224,
        stage=0,
        block=0,
        head=1,
        resize_img=False,
        alpha=0.5,
    ):
        """
        Plain container for the settings used when running a prediction and
        drawing its cluster visualization.

        Args:
            model: `str` name of the model, e.g. 'coc_tiny', 'coc_small', 'coc_medium'.
                Must be one of the names returned by `timm.list_models()`.
            checkpoint: `str` path of the model checkpoint.
            image: `np.array` holding the image itself (not a file path).
            shape: `int` side length of the square image.
            stage: `int` index of the visualized stage, 0-3.
            block: `int` index of the visualized block inside the stage; -1 is the last block.
            head: `int` index of the visualized head, 0-3 or 0-7.
            resize_img: Boolean, whether to resize the image to the feature-map size.
            alpha: `float` transparency of the overlay, 0-1.
        """
        # Store everything verbatim; no value is transformed here.
        self.model, self.checkpoint, self.image = model, checkpoint, image
        self.shape, self.stage, self.block, self.head = shape, stage, block, head
        self.resize_img, self.alpha = resize_img, alpha

        # Reject model names that timm does not know about.
        known_models = timm.list_models()
        assert self.model in known_models, "Please use a timm pre-trined model, see timm.list_models()"
|
| 63 |
+
|
| 64 |
+
# Preprocessing
|
| 65 |
+
def _preprocess(raw_image):
    """
    Resize an image to 224x224 and build the normalized tensor for the model.

    Args:
        raw_image: image array accepted by `cv2.resize`, with channels on the last axis.

    Returns:
        A tuple `(image, raw_image)`: the normalized tensor produced by
        `ToTensor` + `Normalize`, and the resized (otherwise untouched) array.
    """
    side = 224
    resized = cv2.resize(raw_image, (side, side))

    # Reverse the channel axis before the tensor conversion; copy() materializes
    # the reversed view (presumably because the conversion needs positive strides).
    # NOTE(review): the returned array is NOT flipped, so the model input and the
    # array used for drawing have opposite channel order - confirm which order the
    # caller supplies (cv2-style BGR vs. RGB).
    channel_reversed = resized[..., ::-1].copy()

    to_normalized_tensor = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    return to_normalized_tensor(channel_reversed), resized
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def pairwise_cos_sim(x1: torch.Tensor, x2: torch.Tensor):
    """
    Return the pair-wise cosine similarity matrix between two sets of vectors.

    Both inputs are L2-normalized along the last dimension, so every entry of
    the result is the cosine of the angle between one vector of `x1` and one
    vector of `x2`. Any number of leading batch dimensions (including none) is
    accepted, as long as they are compatible for `torch.matmul`.

    :param x1: [..., M, D]
    :param x2: [..., N, D]
    :return: similarity matrix [..., M, N]
    """
    x1 = F.normalize(x1, dim=-1)
    x2 = F.normalize(x2, dim=-1)
    # Swap only the last two axes so the function is not tied to 3-D inputs.
    return torch.matmul(x1, x2.transpose(-2, -1))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# forward hook function
|
| 90 |
+
def get_attention_score(self, input, output):
    """
    Forward hook that records which center every feature-map point is assigned to.

    The hooked module (`self`) is expected to expose `f`, `heads`, `fold_w`,
    `fold_h`, `centers_proposal`, `sim_alpha` and `sim_beta`. The hook
    re-computes the point-to-center similarity from the module input, assigns
    each point to its most similar center and stores the resulting binary
    masks in the module-level global `attention`.

    Args:
        self: the module the hook is registered on.
        input: tuple of positional inputs of the module; only `input[0]` is used.
        output: output of the module (unused).

    Returns:
        None. The result is written to the global `attention` with shape
        [n, heads, W, H], where n = (number of regions) * (centers per region)
        and W, H is the size of the un-folded feature map.

    Note:
        The mask rearrangement requires the leading dimension to equal
        heads * regions, i.e. the hook only works for a single input image.
    """
    global attention

    x = input[0]  # input tensor in a tuple
    x = self.f(x)
    x = rearrange(x, "b (e c) w h -> (b e) c w h", e=self.heads)

    # The feature map is split into regions only when BOTH fold factors exceed 1.
    # Remember the factors that were actually applied, so the masks are stitched
    # back together with the same layout further below.
    if self.fold_w > 1 and self.fold_h > 1:
        fold_w, fold_h = self.fold_w, self.fold_h
        _, _, w0, h0 = x.shape
        assert w0 % self.fold_w == 0 and h0 % self.fold_h == 0, \
            f"Ensure the feature map size ({w0}*{h0}) can be divided by fold {self.fold_w}*{self.fold_h}"
        x = rearrange(x, "b c (f1 w) (f2 h) -> (b f1 f2) c w h",
                      f1=fold_w, f2=fold_h)  # [bs*blocks,c,ks[0],ks[1]]
    else:
        fold_w = fold_h = 1

    _, _, w, h = x.shape
    centers = self.centers_proposal(x)  # [b,c,C_W,C_H], we set M = C_W*C_H and N = w*h
    b, c = centers.shape[:2]
    sim = torch.sigmoid(
        self.sim_beta
        + self.sim_alpha * pairwise_cos_sim(
            centers.reshape(b, c, -1).permute(0, 2, 1),
            x.reshape(b, c, -1).permute(0, 2, 1),
        )
    )  # [B,M,N]

    # solely assign each point to one center
    _, sim_max_idx = sim.max(dim=1, keepdim=True)
    mask = torch.zeros_like(sim)  # binary [B,M,N]
    mask.scatter_(1, sim_max_idx, 1.)  # binary [B,M,N]

    # changed, for plotting mask.
    mask = mask.reshape(mask.shape[0], mask.shape[1], w, h)  # [(head*fold*fold), m, w, h]
    mask = rearrange(mask, "(h0 f1 f2) m w h -> h0 (f1 f2) m w h",
                     h0=self.heads, f1=fold_w, f2=fold_h)  # [head, (fold*fold), m, w, h]

    # One full-size canvas per (region, center): zero everywhere except inside
    # the region the center belongs to.
    mask_list = []
    for i in range(fold_w):
        for j in range(fold_h):
            # "(f1 f2)" flattens with f2 as the fastest axis, so region (i, j)
            # lives at index i * fold_h + j.
            region = i * fold_h + j
            for k in range(mask.shape[2]):
                temp = torch.zeros(self.heads, w * fold_w, h * fold_h)
                temp[:, i * w:(i + 1) * w, j * h:(j + 1) * h] = mask[:, region, k, :, :]
                mask_list.append(temp.unsqueeze(dim=0))  # [1, heads, w, h]

    attention = torch.cat(mask_list, dim=0).detach()  # [n, heads, w, h]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def generate_visualization(args):
    """
    Run the model on `args.image` and overlay the cluster masks of one head.

    A forward hook (`get_attention_score`) is registered on the token mixer
    selected by `args.stage` / `args.block`; it stores the per-head cluster
    masks in the module-level global `attention`, which is read back here.

    Args:
        args: object exposing `image`, `model`, `checkpoint`, `stage`, `block`,
            `head` and `alpha` (e.g. a `PredictionArgs` instance).

    Returns:
        A tuple `(img_with_masks, possibility)`:
        - img_with_masks: `np.ndarray` of the resized input with the colored
          cluster masks drawn on top.
        - possibility: `str`, the highest softmax score in percent, formatted
          with three decimals.

    Raises:
        Exception: if `args.checkpoint` is empty / None.
    """
    # `attention` is populated by the forward hook during the model call below.
    global attention

    # Fail fast, before any image processing or model construction.
    if not args.checkpoint:
        raise Exception("Checkpoint doesn't exist at specified path: {}".format(args.checkpoint))

    image, raw_image = _preprocess(args.image)
    image = image.unsqueeze(dim=0)  # add the batch dimension

    model = timm.create_model(model_name=args.model, pretrained=True)
    load_checkpoint(model, args.checkpoint, True)
    print("\n\n==> Loaded checkpoint")

    # Inference only: switch off training-time behavior and gradient tracking.
    model.eval()
    model.network[args.stage * 2][args.block].token_mixer.register_forward_hook(get_attention_score)
    with torch.no_grad():
        out = model(image)
    if isinstance(out, tuple):
        out = out[0]
    possibility = torch.softmax(out, dim=1).max() * 100
    possibility = "{:.3f}".format(possibility)

    img = torch.tensor(raw_image).permute(2, 0, 1)  # HWC -> CHW

    # process the attention map: keep the requested head and upsample the
    # masks to the image resolution.
    head_attention = attention[:, args.head, :, :]
    mask = F.interpolate(head_attention.unsqueeze(dim=0), (img.shape[-2], img.shape[-1]))
    mask = mask.squeeze(dim=0) > 0.5

    # randomly selected some good colors.
    colors = ["brown", "green", "deepskyblue", "blue", "darkgreen", "darkcyan", "coral", "aliceblue",
              "white", "black", "beige", "red", "tomato", "yellowgreen", "violet", "mediumseagreen"]
    num_masks = mask.shape[0]
    if num_masks == 4:
        colors = colors[0:4]
    elif num_masks > 4:
        # Repeat the palette often enough to cover every mask (ceiling division),
        # so mask counts that are not a multiple of the palette size still work.
        repeats = -(-num_masks // len(colors))
        colors = colors * repeats
        # Fixed seed keeps the color assignment reproducible; a private RNG
        # avoids reseeding the global `random` state of the process.
        random.Random(123).shuffle(colors)

    img_with_masks = draw_segmentation_masks(img, masks=mask, alpha=args.alpha, colors=colors)
    img_with_masks = TransF.to_pil_image(img_with_masks.detach())
    return np.asarray(img_with_masks), possibility
|