Spaces:
Running on Zero
| import matplotlib.pyplot as plt | |
| # from pytorch_lightning import seed_everything | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
| import torchvision.transforms as T | |
| import cv2 | |
| from PIL import Image, ImageOps | |
| import numpy as np | |
| from sklearn.decomposition import PCA | |
| import os, glob | |
| import math | |
| from pathlib import Path | |
| from contextlib import nullcontext | |
def _remove_axes(ax):
    """Strip tick marks and tick labels from a single matplotlib axis."""
    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
def remove_axes(axes):
    """Apply ``_remove_axes`` to every axis in a 1-D or 2-D array of axes."""
    if axes.ndim == 2:
        for row in axes:
            for ax in row:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)
def pca(image_feats_list, dim=3, fit_pca=None, use_torch_pca=True, max_samples=None):
    """Jointly reduce a list of feature maps to `dim` channels for visualization.

    Fits a PCA (TorchPCA or sklearn) on the concatenated features, projects each
    input, then min/max-normalizes every component to [0, 1] for display.

    Args:
        image_feats_list: tensors of shape [B, C, H, W] (or [N, C] for flat feats).
        dim: number of PCA components to keep.
        fit_pca: optional pre-fitted PCA object; if given, fitting is skipped.
        use_torch_pca: use TorchPCA when fitting; otherwise sklearn's PCA.
        max_samples: if set, randomly subsample the fit data to this many rows.

    Returns:
        (reduced_feats, fit_pca): projected tensors (same layout family as the
        inputs; 4-D inputs come back as [B, dim, H, W] on the input device) and
        the fitted PCA object for reuse.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # Flat features are already [N, C]; just detach to CPU.
        if len(tensor.shape) == 2:
            return tensor.detach().cpu()
        # Only during fitting (fit_pca still None): resize maps to a common
        # spatial size so rows from different maps are comparable.
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        # [B, C, H, W] -> [B*H*W, C]: one row per spatial location.
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    # Choose a shared spatial size only when fitting across several 4-D maps.
    if len(image_feats_list) > 1 and fit_pca is None:
        if len(image_feats_list[0].shape) == 2:
            target_size = None
        else:
            target_size = image_feats_list[0].shape[2]
    else:
        target_size = None
    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)
    # Subsample the data if max_samples is set and the number of samples exceeds max_samples
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]
    if fit_pca is None:
        if use_torch_pca:
            fit_pca = TorchPCA(n_components=dim).fit(x)
        else:
            fit_pca = PCA(n_components=dim).fit(x)
    reduced_feats = []
    for feats in image_feats_list:
        # NOTE: transform runs at each map's original resolution (no resize,
        # since fit_pca is non-None by this point).
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        # Per-component min/max normalization to [0, 1] for display.
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values
        if len(feats.shape) == 2:
            reduced_feats.append(x_red)  # 1D
        else:
            B, C, H, W = feats.shape
            reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))  # 3D
    return reduced_feats, fit_pca
def plot_feats(image, lr, hr, save_name='feats.png'):
    """Save a 1x3 panel: input image, PCA of low-res feats, PCA of high-res feats.

    Args:
        image: [3, H, W] image tensor.
        lr: [C, h, w] low-resolution feature map.
        hr: [C, H', W'] upsampled feature map.
        save_name: output path for the figure.
    """
    assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3
    # seed_everything(0)
    [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)])
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(image.permute(1, 2, 0).detach().cpu())
    ax[0].set_title("Image")
    ax[1].imshow(lr_feats_pca[0].permute(1, 2, 0).detach().cpu())
    ax[1].set_title("Original Features")
    ax[2].imshow(hr_feats_pca[0].permute(1, 2, 0).detach().cpu())
    ax[2].set_title("Upsampled Features")
    remove_axes(ax)
    plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    # FIX: close the figure so repeated calls don't leak open figures
    # (consistent with plot_feats_batch, which already closes its figure).
    plt.close(fig)
def plot_feats_batch(images, lr, hr, save_name='feats_batch.png'):
    """Save a grid of up to 4 rows: image, low-res PCA feats, upsampled PCA feats.

    Args:
        images: (B, 3, H, W)
        lr: (B, dim, h', w')
        hr: (B, dim, H, W)
        save_name: output path for the figure.

    Raises:
        ValueError: if the three batch sizes disagree.
    """
    b_i, b_lr, b_hr = images.shape[0], lr.shape[0], hr.shape[0]
    if not (b_i == b_lr == b_hr):
        raise ValueError(f"Batch size mismatch: images={b_i}, lr={b_lr}, hr={b_hr}")
    B = min(b_i, 4)  # plot at most 4 rows
    [lr_pca, hr_pca], _ = pca([lr, hr])
    # FIX: squeeze=False keeps `axes` 2-D even when B == 1; the old code
    # indexed axes[i, 0] on a 1-D array and raised IndexError for B == 1.
    fig, axes = plt.subplots(B, 3, figsize=(3 * 3, B * 2), squeeze=False)
    for i in range(B):
        img = images[i].permute(1, 2, 0).cpu()
        lf = lr_pca[i].permute(1, 2, 0).cpu()
        hf = hr_pca[i].permute(1, 2, 0).cpu()
        axes[i, 0].imshow(img); axes[i, 0].set_title(f"Image {i}")
        axes[i, 1].imshow(lf); axes[i, 1].set_title("Original Feats")
        axes[i, 2].imshow(hf); axes[i, 2].set_title("Upsampled Feats")
        remove_axes(axes[i])
    plt.tight_layout()
    plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)
def visualize_dissim_maps(dissim_maps: torch.Tensor, cmap: str = "jet"):
    """Show one heatmap per dissimilarity map with a shared colorbar.

    Args:
        dissim_maps: [B, H, W] tensor; values are displayed clamped to [0, 1].
        cmap: matplotlib colormap name.
    """
    # FIX: dropped the redundant local `import matplotlib.pyplot as plt`;
    # the module already imports it at the top.
    dissim_maps = dissim_maps.cpu().numpy()
    Bm1 = dissim_maps.shape[0]
    fig, axes = plt.subplots(1, Bm1, figsize=(4 * Bm1, 4))
    if Bm1 == 1:
        axes = [axes]  # make the single Axes object iterable
    for i, ax in enumerate(axes):
        # NOTE(review): cosine dissimilarity can reach 2, but the display is
        # clamped to [0, 1] by vmin/vmax — confirm this clipping is intended.
        im = ax.imshow(dissim_maps[i], cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"View {i+1} vs Ref")
        ax.axis("off")
    fig.colorbar(im, ax=axes, fraction=0.015, pad=0.04)
    plt.show()
def visualize_dissim_maps_minmax(dissim_maps, cmap='jet'):
    """Show per-map min-max normalized heatmaps with a shared colorbar.

    Args:
        dissim_maps: [B, H, W] tensor.
        cmap: matplotlib colormap name.
    """
    maps_np = dissim_maps.cpu().numpy()
    n_maps = maps_np.shape[0]
    fig, axes = plt.subplots(1, n_maps, figsize=(4 * n_maps, 4))
    if n_maps == 1:
        axes = [axes]
    images = []
    for idx, ax in enumerate(axes):
        raw = maps_np[idx]
        lo, hi = raw.min(), raw.max()
        normalized = (raw - lo) / (hi - lo + 1e-8)  # per-map min-max scaling
        images.append(ax.imshow(normalized, cmap=cmap, vmin=0, vmax=1))
        ax.set_title(f"View {idx+1}")
        ax.axis("off")
    # fraction controls the colorbar length; aspect its thickness/length ratio
    cbar = fig.colorbar(
        images[0],
        ax=axes,
        fraction=0.015,
        pad=0.04,
        aspect=10
    )
    cbar.ax.tick_params(labelsize=8)
    plt.show()
class ToTensorWithoutScaling:
    """Turn a PIL image (or ndarray) into a ``long`` tensor, keeping raw pixel values."""

    def __call__(self, pic):
        # pil_to_tensor keeps the original integer values (no /255 scaling).
        tensor = TF.pil_to_tensor(pic)
        return tensor.long()
class TorchPCA(object):
    """Minimal PCA built on ``torch.pca_lowrank`` with an sklearn-like fit/transform API."""

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Estimate the mean and principal axes from the rows of X; returns self."""
        self.mean_ = X.mean(dim=0)
        centered = X - self.mean_.unsqueeze(0)
        _, S, V = torch.pca_lowrank(centered, q=self.n_components, center=False, niter=4)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        """Project the rows of X onto the fitted principal axes."""
        centered = X - self.mean_.unsqueeze(0)
        return centered @ self.components_.T
| import math | |
| from typing import List, Tuple | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| import torchvision.transforms.functional as TF | |
def load_and_preprocess_images(
    image_path_list: List[str],
    mode: str = "crop",
    k: int = 14,
    target_size: int = None,  # default: 37*k (518 when k=14)
    pad_value: float = 1.0,  # padding color in [0,1]
) -> torch.Tensor:
    """
    Load & preprocess images for model input.

    Args:
        image_path_list: list of image file paths.
        mode: "crop" or "pad".
            - "crop": set width to target_size (multiple of k), keep AR, then center-crop height to target_size if taller.
            - "pad": scale long side to target_size (keep AR), then pad to square target_size x target_size.
        k: multiple base (patch size). Both H and W are kept multiples of k.
        target_size: desired side length; default is 37*k (518 when k=14).
        pad_value: padding value in [0,1] (before normalization).

    Returns:
        Tensor of shape (N, 3, H, W). If inputs differ in shape, they are padded to common (maxH, maxW).
        Unreadable files are skipped (with a printed message), so N may be smaller
        than len(image_path_list); an empty result is (0, 3, target_size, target_size).

    Raises:
        ValueError: if image_path_list is empty or mode is invalid.
    """
    if not image_path_list:
        raise ValueError("At least 1 image is required")
    if mode not in {"crop", "pad"}:
        raise ValueError("Mode must be either 'crop' or 'pad'")
    if target_size is None:
        target_size = 37 * k  # 518 for k=14
    # ensure target_size is multiple of k
    target_size = (target_size // k) * k
    target_size = max(k, target_size)
    imgs = []
    shapes = set()
    for path in image_path_list:
        try:
            with Image.open(path) as im0:
                im = im0
                # alpha → white: composite RGBA images onto a white background
                if im.mode == "RGBA":
                    bg = Image.new("RGBA", im.size, (255, 255, 255, 255))
                    im = Image.alpha_composite(bg, im)
                im = im.convert("RGB")
                w, h = im.size
                if mode == "pad":
                    # scale the long side to target_size, keeping aspect ratio
                    if w >= h:
                        new_w = target_size
                        new_h = int(h * (new_w / max(1, w)))
                    else:
                        new_h = target_size
                        new_w = int(w * (new_h / max(1, h)))
                    # snap to multiples of k (floor) and clamp
                    new_w = max(k, (new_w // k) * k)
                    new_h = max(k, (new_h // k) * k)
                else:  # "crop": fix width, scale height by aspect ratio
                    new_w = target_size
                    new_h = int(h * (new_w / max(1, w)))
                    new_h = max(k, (new_h // k) * k)  # floor to multiple of k
                # resize with good downscale filter
                im = im.resize((new_w, new_h), Image.Resampling.LANCZOS)
                x = TF.to_tensor(im)  # [3, H, W] in [0,1]
                if mode == "crop" and new_h > target_size:
                    # center-crop height to target_size
                    start_y = (new_h - target_size) // 2
                    x = x[:, start_y:start_y + target_size, :]
                    new_h = target_size  # keep book-keeping in sync
                if mode == "pad":
                    # pad to square target_size x target_size, centered
                    hp = target_size - new_h
                    wp = target_size - new_w
                    if hp > 0 or wp > 0:
                        top = hp // 2
                        bottom = hp - top
                        left = wp // 2
                        right = wp - left
                        x = F.pad(x, (left, right, top, bottom), value=pad_value)
                # collect
                H, W = x.shape[-2], x.shape[-1]
                # final safety: keep multiples of k
                assert H % k == 0 and W % k == 0, f"Not k-multiple: {(H, W)}"
                imgs.append(x)
                shapes.add((H, W))
        except Exception as e:
            # best-effort loading: report and skip unreadable files
            print(f"skip {path}: {e}")
    if not imgs:
        return torch.empty(0, 3, target_size, target_size)
    # unify shapes if needed: center-pad every image up to the common canvas
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        maxH = max(h for h, _ in shapes)
        maxW = max(w for _, w in shapes)
        # ensure the common canvas is also multiples of k (it already is if each is)
        assert maxH % k == 0 and maxW % k == 0
        padded = []
        for x in imgs:
            hp = maxH - x.shape[-2]
            wp = maxW - x.shape[-1]
            if hp > 0 or wp > 0:
                top = hp // 2
                bottom = hp - top
                left = wp // 2
                right = wp - left
                x = F.pad(x, (left, right, top, bottom), value=pad_value)
            padded.append(x)
        imgs = padded
    return torch.stack(imgs)  # [N,3,H,W]
def to_multiple_of_k(img, k=14, mode="crop", k_down_steps=0):
    """
    Adjust a PIL image so that both width and height are multiples of k.

    Args:
        img: PIL image.
        k (int): multiple base (e.g., 14).
        mode (str): "pad" -> pad to ceil multiple; "crop" -> center-crop to floor multiple.
        k_down_steps (int): when mode="crop", go 'steps' multiples below the floor multiple.
                            0 -> floor (default, previous behavior)
                            1 -> floor - k
                            2 -> floor - 2k
                            ...

    Returns:
        PIL image whose width and height are multiples of k.
    """
    w, h = img.size
    if mode == "pad":
        nw = math.ceil(w / k) * k
        nh = math.ceil(h / k) * k
        return ImageOps.expand(img, border=(0, 0, nw - w, nh - h), fill=0)
    # crop mode with optional extra-down steps
    fw = (w // k) * k
    fh = (h // k) * k
    # go additional steps down, never below one multiple of k
    nw = max(k, fw - k * max(0, int(k_down_steps)))
    nh = max(k, fh - k * max(0, int(k_down_steps)))
    # BUG FIX: the old guard tested `nw <= 0 or nh <= 0`, which can never fire
    # because both are clamped to >= k just above. The real degenerate case is
    # an image smaller than the requested multiple (nw > w or nh > h), where
    # the center-crop origin would go negative. Pad the short side(s) up first,
    # then the crop below trims everything to exactly (nw, nh).
    if nw > w or nh > h:
        img = ImageOps.expand(img, border=(0, 0, max(0, nw - w), max(0, nh - h)), fill=0)
        w, h = img.size
    left = (w - nw) // 2
    top = (h - nh) // 2
    return img.crop((left, top, left + nw, top + nh))
def load_image_batch(
    dir_path,
    size=224,
    recursive=False,
    resize=False,
    k=None,
    k_mode="crop",
    k_down_steps=0,
):
    """
    Load every image under a directory as a tensor batch.

    Optionally resizes/center-crops to [size, size] and/or snaps H, W to
    multiples of k (via to_multiple_of_k). With k_mode='crop', k_down_steps
    selects a smaller multiple.

    Args:
        dir_path: directory to scan.
        size (int): side length used when resize=True.
        recursive (bool): recurse into subdirectories.
        resize (bool): if True, Resize + CenterCrop to [size, size] first.
        k (int | None): if set, enforce H, W to be multiples of k.
        k_mode (str): 'pad' or 'crop'.
        k_down_steps (int): extra multiples to step down when k_mode='crop'.

    Returns:
        (batch_or_list, kept_paths): a stacked [N,3,H,W] tensor when all
        images share a shape (otherwise a list of tensors), plus the list of
        paths that loaded successfully.
    """
    allowed = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    root = Path(dir_path)
    entries = root.rglob('*') if recursive else root.glob('*')
    paths = sorted(
        (str(p) for p in entries if p.is_file() and p.suffix.lower() in allowed),
        key=str.lower,
    )
    steps = []
    if resize:
        steps.append(T.Resize(size, T.InterpolationMode.BILINEAR))
        steps.append(T.CenterCrop(size))
    if k is not None:
        steps.append(T.Lambda(lambda im: to_multiple_of_k(im, k=k, mode=k_mode, k_down_steps=k_down_steps)))
    steps.append(T.ToTensor())
    transform = T.Compose(steps)
    imgs, kept = [], []
    for p in paths:
        try:
            with Image.open(p) as im:
                imgs.append(transform(im.convert("RGB")))
                kept.append(p)
        except Exception as e:
            print(f"skip {p}: {e}")
    if not imgs:
        # preserve the historical empty-return shapes
        if resize and k is None:
            return torch.empty(0, 3, size, size), kept
        return [], kept
    try:
        return torch.stack(imgs), kept
    except RuntimeError:
        # mixed shapes cannot be stacked; hand back the list instead
        return imgs, kept
| def _align_spatial(x, size_hw): | |
| """Resize x to (H,W) with bilinear (no align_corners).""" | |
| H, W = size_hw | |
| if x.shape[-2:] == (H, W): | |
| return x | |
| return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False) | |
| def _apply_mask_weights(x, mask): | |
| """ | |
| x: [B, ..., H, W]; mask: [B,1,H,W] or None. | |
| Returns (x, weights) with weights summed per-batch later. | |
| """ | |
| if mask is None: | |
| return x, None | |
| # broadcast | |
| while mask.dim() < x.dim(): | |
| mask = mask | |
| return x * mask, mask | |
def cosine_map_and_score(f1, f2, mask=None, resize_to='f1', eps=1e-8, reduce='mean'):
    """
    Per-pixel cosine similarity over the channel axis, plus a pooled score.

    f1, f2: [B,C,H,W]; mask: [B,1,H,W] (1 = valid) or None.
    resize_to: 'f1' | 'f2' | None — which tensor's spatial size to align to.
    reduce: 'mean' | 'median' (with a mask, 'median' falls back to a masked mean).
    Returns: (sim_map [B,1,H,W], score [B])
    """
    assert f1.dim() == 4 and f2.dim() == 4, "inputs must be [B,C,H,W]"
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:], "spatial mismatch and resize_to=None"
    # unit-normalize along channels, then the dot product is the cosine
    unit1 = F.normalize(f1, p=2, dim=1, eps=eps)
    unit2 = F.normalize(f2, p=2, dim=1, eps=eps)
    sim = (unit1 * unit2).sum(dim=1, keepdim=True)  # [B,1,H,W], in [-1,1]
    if mask is None:
        if reduce == 'mean':
            score = sim.mean(dim=(-2, -1)).squeeze(1)  # [B]
        else:
            score = sim.flatten(2).median(dim=-1).values.squeeze(1)  # [B]
    else:
        mask = mask.to(sim.dtype)
        sim = sim * mask  # returned map is also masked
        valid = mask.sum(dim=(-2, -1)).clamp_min(1e-6)  # [B,1]
        # masked median is approximated by a masked mean (simple)
        score = (sim.sum(dim=(-2, -1)) / valid).squeeze(1)
    return sim, score
def linear_cka(f1, f2, mask=None, resize_to='f1', eps=1e-8):
    """
    Linear CKA similarity per image (Kornblith et al., 2019), roughly in [0,1];
    higher = more similar. Spatial positions act as samples (N = H*W),
    channels as features.

    f1, f2: [B,C,H,W]; mask: [B,1,H,W] (1 = valid) or None.
    Returns: scores [B] (float32 tensor)
    """
    assert f1.dim() == 4 and f2.dim() == 4
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:]
    B, C = f1.shape[0], f1.shape[1]
    assert f2.shape[1] == C, "channel dim must match for linear CKA"
    # [B,C,H,W] -> [B,N,C] with one row per spatial location
    feats_x = f1.permute(0, 2, 3, 1).reshape(B, -1, C).contiguous()
    feats_y = f2.permute(0, 2, 3, 1).reshape(B, -1, C).contiguous()
    if mask is not None:
        # keep only valid rows per image
        valid = mask.reshape(B, -1, 1).bool()
        pairs = [
            (x[m.expand_as(x)].view(-1, C), y[m.expand_as(y)].view(-1, C))
            for x, y, m in zip(feats_x, feats_y, valid)
        ]
    else:
        pairs = list(zip(feats_x, feats_y))
    scores = []
    for x, y in pairs:
        # center feature columns
        x = x - x.mean(dim=0, keepdim=True)
        y = y - y.mean(dim=0, keepdim=True)
        # (C,C) cross- and self-covariances (up to scale)
        cross = x.T @ y
        gram_x = x.T @ x
        gram_y = y.T @ y
        num = (cross ** 2).sum()
        denom = (gram_x ** 2).sum().sqrt() * (gram_y ** 2).sum().sqrt()
        scores.append((num / denom.clamp_min(eps)).item())
    return torch.tensor(scores, dtype=torch.float32)
def norm_mse(f1, f2, mask=None, resize_to='f1', eps=1e-8):
    """
    MSE between channel-L2-normalized features (scale-invariant).
    Lower = more similar.

    f1, f2: [B,C,H,W]; mask: [B,1,H,W] (1 = valid) or None.
    resize_to: 'f1' | 'f2' | None — spatial alignment target.
    Returns: per-image scores [B].
    """
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:]
    f1n = F.normalize(f1, p=2, dim=1, eps=eps)
    f2n = F.normalize(f2, p=2, dim=1, eps=eps)
    diff2 = (f1n - f2n) ** 2
    if mask is None:
        return diff2.mean(dim=(1, 2, 3))
    # BUG FIX: the old masked path collapsed the whole batch into one scalar
    # (shape [1]) even though the docstring promises per-image scores [B],
    # and its denominator ignored the channel count. Average over valid
    # positions per image: each valid pixel contributes C squared terms.
    mask = mask.to(diff2.dtype)
    diff2 = diff2 * mask  # zero out invalid positions
    channels = diff2.shape[1]
    denom = (mask.sum(dim=(1, 2, 3)) * channels).clamp_min(eps)  # [B]
    return diff2.sum(dim=(1, 2, 3)) / denom
def save_similarity_maps(
    sim_map: torch.Tensor,  # [B,1,Hs,Ws]
    imgs = None,  # [B,3,Hi,Wi] in [0,1] (optional, for overlay)
    out_dir: str = "sim_vis",
    prefix: str = "sim",
    vmin: float = -1.0, vmax: float = 1.0,
    cmap: str = "jet",
    alpha: float = 0.5,
    upsample_to_img: bool = True,
    # NEW: output sizing
    heatmap_size = None,  # (H_out, W_out)
    heatmap_scale = None,  # e.g., 2.0 → 2x
    match_img_size = False,  # match the heatmap to the source image's size
):
    """
    Saves heatmap and (optionally) image overlay, with controllable output size.

    Values are mapped to [0,1] by clamping into [vmin, vmax] before colorizing.

    Priority of output size (for heatmap image):
      1) match_img_size=True & imgs provided -> use that image's size
      2) heatmap_size=(H_out, W_out)
      3) heatmap_scale (relative to sim_map size)
      4) default: sim_map size
    """
    assert sim_map.dim() == 4 and sim_map.shape[1] == 1, "sim_map must be [B,1,H,W]"
    os.makedirs(out_dir, exist_ok=True)
    cm = plt.get_cmap(cmap)
    B, _, Hs, Ws = sim_map.shape
    for i in range(B):
        m = sim_map[i, 0]  # [Hs,Ws]
        # normalize to [0,1] by clamping into [vmin, vmax]
        m01 = (m.clamp(vmin, vmax) - vmin) / max(1e-8, (vmax - vmin))
        # ---- decide target size for HEATMAP file ----
        if match_img_size and imgs is not None:
            Ht, Wt = imgs[i].shape[-2:]
        elif heatmap_size is not None:
            Ht, Wt = heatmap_size
        elif heatmap_scale is not None:
            Ht, Wt = int(round(Hs * heatmap_scale)), int(round(Ws * heatmap_scale))
        else:
            Ht, Wt = Hs, Ws
        # resize map (for heatmap)
        if (Ht, Wt) != (Hs, Ws):
            m_for_heat = F.interpolate(
                m01.unsqueeze(0).unsqueeze(0), size=(Ht, Wt),
                mode="bilinear", align_corners=False
            )[0,0]
        else:
            m_for_heat = m01
        # colorize and save heatmap with PIL (guarantees the exact pixel size)
        m_rgb = (cm(m_for_heat.cpu().numpy())[...,:3] * 255).astype(np.uint8)  # HxWx3
        heat_path = os.path.join(out_dir, f"{prefix}_{i:03d}.png")
        Image.fromarray(m_rgb).save(heat_path)
        # ---- overlay ----
        if imgs is not None:
            img = imgs[i]  # [3,Hi,Wi] in [0,1]
            Hi, Wi = img.shape[-2:]
            # decide overlay target size
            if upsample_to_img:
                Ho, Wo = Hi, Wi
            else:
                Ho, Wo = (Ht, Wt)  # use heatmap size
            if (m01.shape[-2:] != (Ho, Wo)):
                m_for_overlay = F.interpolate(
                    m01.unsqueeze(0).unsqueeze(0), size=(Ho, Wo),
                    mode="bilinear", align_corners=False
                )[0,0]
            else:
                m_for_overlay = m01
            # colorize & alpha-blend the heatmap onto the image
            m_rgb = cm(m_for_overlay.cpu().numpy())[...,:3]  # [0,1]
            img_np = img.permute(1,2,0).cpu().clamp(0,1).numpy()
            # if sizes mismatch, resize image to (Ho,Wo)
            if img_np.shape[:2] != (Ho, Wo):
                img_np = np.array(Image.fromarray((img_np*255).astype(np.uint8)).resize((Wo,Ho), Image.BILINEAR)) / 255.0
            overlay = np.clip((1 - alpha) * img_np + alpha * m_rgb, 0, 1)
            overlay_path = os.path.join(out_dir, f"{prefix}_{i:03d}_overlay.png")
            Image.fromarray((overlay * 255).astype(np.uint8)).save(overlay_path)
    print(f"Saved to: {out_dir}")
def save_similarity_maps_normalized(
    sim_map: torch.Tensor,  # [B,1,Hs,Ws]
    imgs = None,  # [B,3,Hi,Wi] in [0,1] (optional, for overlay)
    out_dir: str = "sim_vis",
    prefix: str = "sim",
    # --- normalization ---
    norm: str = "minmax",  # 'minmax' | 'zscore' | 'range'
    vmin: float = -1.0, vmax: float = 1.0,  # used when norm='range'
    # --- coloring/overlay ---
    cmap: str = "jet",
    alpha: float = 0.5,
    upsample_to_img: bool = True,
    # --- output size control (heatmap file) ---
    heatmap_size = None,  # (H_out, W_out)
    heatmap_scale = None,  # e.g., 2.0
    match_img_size = False,  # match heatmap to image size
    # --- optional extra outputs ---
    save_gray = False,  # save normalized grayscale PNG
    save_npy = False,  # save normalized numpy array
):
    """
    Save similarity maps with explicit normalization and controllable output size.

    Normalization modes (all map each image's values into [0,1]):
      - 'minmax': per-map min/max scaling
      - 'zscore': per-map standardization followed by a sigmoid
      - 'range':  clamp into [vmin, vmax], then linear rescale

    Heatmap size priority:
      1) match_img_size=True & imgs provided -> use that image's size
      2) heatmap_size
      3) heatmap_scale
      4) original sim_map size

    Raises:
        ValueError: if norm is not one of the supported modes.
    """
    assert sim_map.dim() == 4 and sim_map.shape[1] == 1, "sim_map must be [B,1,H,W]"
    os.makedirs(out_dir, exist_ok=True)
    cm = plt.get_cmap(cmap)
    B, _, Hs, Ws = sim_map.shape
    for i in range(B):
        m = sim_map[i, 0]  # [Hs,Ws]
        # ---- normalization -> m01 in [0,1] ----
        if norm == "minmax":
            m01 = (m - m.min()) / (m.max() - m.min() + 1e-8)
        elif norm == "zscore":
            std = m.std(unbiased=False).clamp_min(1e-8)
            m01 = ((m - m.mean()) / std).sigmoid()
        elif norm == "range":
            m01 = (m.clamp(vmin, vmax) - vmin) / max(1e-8, (vmax - vmin))
        else:
            raise ValueError("norm must be 'minmax', 'zscore', or 'range'")
        # ---- decide HEATMAP output size ----
        if match_img_size and imgs is not None:
            Ht, Wt = imgs[i].shape[-2:]
        elif heatmap_size is not None:
            Ht, Wt = heatmap_size
        elif heatmap_scale is not None:
            Ht, Wt = int(round(Hs * heatmap_scale)), int(round(Ws * heatmap_scale))
        else:
            Ht, Wt = Hs, Ws
        # resize normalized map for heatmap save
        if (Ht, Wt) != (Hs, Ws):
            m_heat = F.interpolate(
                m01.unsqueeze(0).unsqueeze(0), size=(Ht, Wt),
                mode="bilinear", align_corners=False
            )[0,0]
        else:
            m_heat = m01
        # colorize & save heatmap with PIL (exact pixel size)
        m_rgb = (cm(m_heat.cpu().numpy())[...,:3] * 255).astype(np.uint8)
        heat_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm.png")
        Image.fromarray(m_rgb).save(heat_path)
        # optional grayscale + npy (use the same resized map)
        if save_gray:
            gray_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm_gray.png")
            Image.fromarray((m_heat.cpu().numpy() * 255).astype(np.uint8)).save(gray_path)
        if save_npy:
            npy_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm.npy")
            np.save(npy_path, m_heat.cpu().numpy())
        # ---- overlay ----
        if imgs is not None:
            img = imgs[i]  # [3,Hi,Wi]
            Hi, Wi = img.shape[-2:]
            # overlay target size
            if upsample_to_img:
                Ho, Wo = Hi, Wi
            else:
                Ho, Wo = (Ht, Wt)
            if (m01.shape[-2:] != (Ho, Wo)):
                m_overlay = F.interpolate(
                    m01.unsqueeze(0).unsqueeze(0), size=(Ho, Wo),
                    mode="bilinear", align_corners=False
                )[0,0]
            else:
                m_overlay = m01
            # colorize & alpha-blend the heatmap onto the image
            m_rgb = cm(m_overlay.cpu().numpy())[...,:3]
            img_np = img.permute(1,2,0).cpu().clamp(0,1).numpy()
            # if sizes mismatch, resize the image to (Ho,Wo) first
            if img_np.shape[:2] != (Ho, Wo):
                img_np = np.array(Image.fromarray((img_np*255).astype(np.uint8)).resize((Wo,Ho), Image.BILINEAR)) / 255.0
            overlay = np.clip((1 - alpha) * img_np + alpha * m_rgb, 0, 1)
            overlay_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm_overlay.png")
            Image.fromarray((overlay * 255).astype(np.uint8)).save(overlay_path)
    print(f"Saved to: {out_dir}")
| ADE20K_150_CATEGORIES = [ | |
| {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"}, | |
| {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"}, | |
| {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"}, | |
| {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"}, | |
| {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"}, | |
| {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"}, | |
| {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"}, | |
| {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"}, | |
| {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "}, | |
| {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"}, | |
| {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"}, | |
| {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"}, | |
| {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"}, | |
| {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"}, | |
| {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"}, | |
| {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"}, | |
| {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"}, | |
| {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"}, | |
| {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"}, | |
| {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"}, | |
| {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"}, | |
| {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"}, | |
| {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"}, | |
| {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"}, | |
| {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"}, | |
| {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"}, | |
| {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"}, | |
| {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"}, | |
| {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"}, | |
| {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"}, | |
| {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"}, | |
| {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"}, | |
| {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"}, | |
| {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"}, | |
| {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"}, | |
| {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"}, | |
| {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"}, | |
| {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"}, | |
| {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"}, | |
| {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"}, | |
| {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"}, | |
| {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"}, | |
| {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"}, | |
| {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"}, | |
| { | |
| "color": [6, 51, 255], | |
| "id": 44, | |
| "isthing": 1, | |
| "name": "chest of drawers, chest, bureau, dresser", | |
| }, | |
| {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"}, | |
| {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"}, | |
| {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"}, | |
| {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"}, | |
| {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"}, | |
| {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"}, | |
| {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"}, | |
| {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"}, | |
| {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"}, | |
| {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"}, | |
| {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"}, | |
| { | |
| "color": [255, 71, 0], | |
| "id": 56, | |
| "isthing": 1, | |
| "name": "pool table, billiard table, snooker table", | |
| }, | |
| {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"}, | |
| {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"}, | |
| {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"}, | |
| {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"}, | |
| {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"}, | |
| {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"}, | |
| {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"}, | |
| {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"}, | |
| { | |
| "color": [0, 255, 133], | |
| "id": 65, | |
| "isthing": 1, | |
| "name": "toilet, can, commode, crapper, pot, potty, stool, throne", | |
| }, | |
| {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"}, | |
| {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"}, | |
| {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"}, | |
| {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"}, | |
| {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"}, | |
| {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"}, | |
| {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"}, | |
| {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"}, | |
| {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"}, | |
| {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"}, | |
| {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"}, | |
| {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"}, | |
| {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"}, | |
| {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"}, | |
| {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"}, | |
| {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"}, | |
| {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"}, | |
| {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"}, | |
| {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"}, | |
| {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"}, | |
| {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"}, | |
| {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"}, | |
| {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"}, | |
| {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"}, | |
| {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"}, | |
| {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"}, | |
| {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"}, | |
| {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"}, | |
| {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"}, | |
| { | |
| "color": [0, 122, 255], | |
| "id": 95, | |
| "isthing": 1, | |
| "name": "bannister, banister, balustrade, balusters, handrail", | |
| }, | |
| { | |
| "color": [0, 255, 163], | |
| "id": 96, | |
| "isthing": 0, | |
| "name": "escalator, moving staircase, moving stairway", | |
| }, | |
| { | |
| "color": [255, 153, 0], | |
| "id": 97, | |
| "isthing": 1, | |
| "name": "ottoman, pouf, pouffe, puff, hassock", | |
| }, | |
| {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"}, | |
| {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"}, | |
| { | |
| "color": [143, 255, 0], | |
| "id": 100, | |
| "isthing": 0, | |
| "name": "poster, posting, placard, notice, bill, card", | |
| }, | |
| {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"}, | |
| {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"}, | |
| {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"}, | |
| {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"}, | |
| { | |
| "color": [133, 0, 255], | |
| "id": 105, | |
| "isthing": 0, | |
| "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter", | |
| }, | |
| {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"}, | |
| { | |
| "color": [184, 0, 255], | |
| "id": 107, | |
| "isthing": 1, | |
| "name": "washer, automatic washer, washing machine", | |
| }, | |
| {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"}, | |
| {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"}, | |
| {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"}, | |
| {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"}, | |
| {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"}, | |
| {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"}, | |
| {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"}, | |
| {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"}, | |
| {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"}, | |
| {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"}, | |
| {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"}, | |
| {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"}, | |
| {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"}, | |
| {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"}, | |
| {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"}, | |
| {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"}, | |
| {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"}, | |
| {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"}, | |
| {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"}, | |
| {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"}, | |
| {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"}, | |
| {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"}, | |
| {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"}, | |
| {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"}, | |
| {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"}, | |
| {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"}, | |
| {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"}, | |
| {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"}, | |
| {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"}, | |
| {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"}, | |
| {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"}, | |
| {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"}, | |
| {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"}, | |
| {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"}, | |
| {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"}, | |
| {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"}, | |
| {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"}, | |
| {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"}, | |
| {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"}, | |
| {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"}, | |
| {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"}, | |
| {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"}, | |
| ] |