"""Utilities for visualizing ViT-style dense features (PCA projection, similarity
heatmaps, dissimilarity maps), image loading/preprocessing helpers that keep
H/W multiples of the patch size, and the ADE20K-150 category palette."""

import matplotlib.pyplot as plt
# from pytorch_lightning import seed_everything
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import cv2
from PIL import Image, ImageOps
import numpy as np
from sklearn.decomposition import PCA
import os, glob
import math
from pathlib import Path
from contextlib import nullcontext


def _remove_axes(ax):
    """Strip tick labels and tick marks from a single matplotlib axis."""
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
    ax.set_xticks([])
    ax.set_yticks([])


def remove_axes(axes):
    """Strip ticks from every axis in a 1-D or 2-D array of axes."""
    if len(axes.shape) == 2:
        for ax1 in axes:
            for ax in ax1:
                _remove_axes(ax)
    else:
        for ax in axes:
            _remove_axes(ax)


def pca(image_feats_list, dim=3, fit_pca=None, use_torch_pca=True, max_samples=None):
    """Jointly PCA-project a list of feature maps to `dim` channels.

    Args:
        image_feats_list: list of tensors, each [B,C,H,W] or [N,C].
        dim: number of output components.
        fit_pca: optional pre-fitted PCA object; if given, only transform.
        use_torch_pca: use TorchPCA instead of sklearn's PCA when fitting.
        max_samples: optional cap on the number of samples used for fitting.

    Returns:
        (reduced_feats, fit_pca): per-input projected features min-max
        normalized to [0,1] per component, and the fitted PCA object.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # 2-D inputs are already (samples, channels)
        if len(tensor.shape) == 2:
            return tensor.detach().cpu()
        # resize spatially so all inputs contribute comparable sample counts
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    # Align spatial sizes only when fitting on multiple 4-D inputs.
    if len(image_feats_list) > 1 and fit_pca is None:
        if len(image_feats_list[0].shape) == 2:
            target_size = None
        else:
            target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    # Subsample the data if max_samples is set and the number of samples exceeds max_samples
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]

    if fit_pca is None:
        if use_torch_pca:
            fit_pca = TorchPCA(n_components=dim).fit(x)
        else:
            fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        # min-max normalize each component; clamp avoids NaN on constant columns
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values.clamp_min(1e-8)
        if len(feats.shape) == 2:
            reduced_feats.append(x_red)  # 1D
        else:
            B, C, H, W = feats.shape
            reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))  # 3D
    return reduced_feats, fit_pca


@torch.no_grad()
def plot_feats(image, lr, hr, save_name='feats.png'):
    """Save a 1x3 panel: input image, PCA of low-res feats, PCA of high-res feats."""
    assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3
    # seed_everything(0)
    [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)])
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(image.permute(1, 2, 0).detach().cpu())
    ax[0].set_title("Image")
    ax[1].imshow(lr_feats_pca[0].permute(1, 2, 0).detach().cpu())
    ax[1].set_title("Original Features")
    ax[2].imshow(hr_feats_pca[0].permute(1, 2, 0).detach().cpu())
    ax[2].set_title("Upsampled Features")
    remove_axes(ax)
    plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)  # avoid leaking figures across repeated calls


@torch.no_grad()
def plot_feats_batch(images, lr, hr, save_name='feats_batch.png'):
    """Save a grid (up to 4 rows): image / low-res PCA / high-res PCA per sample.

    Args:
        images: (B, 3, H, W)
        lr: (B, dim, h', w')
        hr: (B, dim, H, W)
    """
    b_i, b_lr, b_hr = images.shape[0], lr.shape[0], hr.shape[0]
    if not (b_i == b_lr == b_hr):
        raise ValueError(f"Batch size mismatch: images={b_i}, lr={b_lr}, hr={b_hr}")
    B = min(b_i, 4)
    [lr_pca, hr_pca], _ = pca([lr, hr])
    # squeeze=False keeps axes 2-D even when B == 1, so axes[i, j] indexing works
    fig, axes = plt.subplots(B, 3, figsize=(3 * 3, B * 2), squeeze=False)
    for i in range(B):
        img = images[i].permute(1, 2, 0).cpu()
        lf = lr_pca[i].permute(1, 2, 0).cpu()
        hf = hr_pca[i].permute(1, 2, 0).cpu()
        axes[i, 0].imshow(img); axes[i, 0].set_title(f"Image {i}")
        axes[i, 1].imshow(lf); axes[i, 1].set_title("Original Feats")
        axes[i, 2].imshow(hf); axes[i, 2].set_title("Upsampled Feats")
        remove_axes(axes[i])
    plt.tight_layout()
    plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)


def visualize_dissim_maps(dissim_maps: torch.Tensor, cmap: str = "jet"):
    """
    dissim_maps: [B, H, W]
    """
    import matplotlib.pyplot as plt
    dissim_maps = dissim_maps.cpu().numpy()
    Bm1 = dissim_maps.shape[0]
    fig, axes = plt.subplots(1, Bm1, figsize=(4 * Bm1, 4))
    if Bm1 == 1:
        axes = [axes]  # normalize to a list so enumeration works for B == 1
    for i, ax in enumerate(axes):
        im = ax.imshow(dissim_maps[i], cmap=cmap, vmin=0, vmax=1)  # dissimilarity: 0~2
        ax.set_title(f"View {i+1} vs Ref")
        ax.axis("off")
    fig.colorbar(im, ax=axes, fraction=0.015, pad=0.04)
    plt.show()


def visualize_dissim_maps_minmax(dissim_maps, cmap='jet'):
    """
    dissim_maps: [B, H, W]
    """
    dissim_maps = dissim_maps.cpu().numpy()
    B = dissim_maps.shape[0]
    fig, axes = plt.subplots(1, B, figsize=(4 * B, 4))
    if B == 1:
        axes = [axes]
    ims = []
    for i, ax in enumerate(axes):
        smap = dissim_maps[i]
        smap_min = smap.min()
        smap_max = smap.max()
        smap_norm = (smap - smap_min) / (smap_max - smap_min + 1e-8)  # Min-max normalization
        im = ax.imshow(smap_norm, cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"View {i+1}")
        ax.axis("off")
        ims.append(im)
    # Adjust fraction to change vertical length of the colorbar
    # Adjust aspect to control thickness vs length ratio
    cbar = fig.colorbar(
        ims[0],
        ax=axes,
        fraction=0.015,  # Smaller value → shorter vertical colorbar
        pad=0.04,
        aspect=10        # Smaller value → shorter & thicker bar
    )
    cbar.ax.tick_params(labelsize=8)
    plt.show()


class ToTensorWithoutScaling:
    """Convert PIL image or numpy array to a PyTorch tensor without scaling the values."""

    def __call__(self, pic):
        # Convert the PIL Image or numpy array to a tensor (without scaling).
        return TF.pil_to_tensor(pic).long()


class TorchPCA(object):
    """Minimal PCA built on torch.pca_lowrank, mirroring sklearn's fit/transform API."""

    def __init__(self, n_components):
        self.n_components = n_components

    def fit(self, X):
        """Fit on X [N, C]; stores mean_, components_ and singular_values_."""
        self.mean_ = X.mean(dim=0)
        unbiased = X - self.mean_.unsqueeze(0)
        U, S, V = torch.pca_lowrank(unbiased, q=self.n_components, center=False, niter=4)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        """Project X [N, C] onto the fitted components; returns [N, n_components]."""
        t0 = X - self.mean_.unsqueeze(0)
        projected = t0 @ self.components_.T
        return projected


# NOTE(review): these re-imports duplicate the top-of-file block; kept as-is
# because this section was authored as a self-contained unit.
import math
from typing import List, Tuple
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF


def load_and_preprocess_images(
    image_path_list: List[str],
    mode: str = "crop",
    k: int = 14,
    target_size: int = None,   # default: 37*k (518 when k=14)
    pad_value: float = 1.0,    # padding color in [0,1]
) -> torch.Tensor:
    """
    Load & preprocess images for model input.

    Args:
        image_path_list: list of image file paths.
        mode: "crop" or "pad".
            - "crop": set width to target_size (multiple of k), keep AR,
              then center-crop height to target_size if taller.
            - "pad": scale long side to target_size (keep AR), then pad to
              square target_size x target_size.
        k: multiple base (patch size). Both H and W are kept multiples of k.
        target_size: desired side length; default is 37*k (518 when k=14).
        pad_value: padding value in [0,1] (before normalization).

    Returns:
        Tensor of shape (N, 3, H, W). If inputs differ in shape, they are
        padded to common (maxH, maxW).
    """
    if not image_path_list:
        raise ValueError("At least 1 image is required")
    if mode not in {"crop", "pad"}:
        raise ValueError("Mode must be either 'crop' or 'pad'")

    if target_size is None:
        target_size = 37 * k  # 518 for k=14
    # ensure target_size is multiple of k
    target_size = (target_size // k) * k
    target_size = max(k, target_size)

    imgs = []
    shapes = set()
    for path in image_path_list:
        try:
            with Image.open(path) as im0:
                im = im0
                # alpha → white
                if im.mode == "RGBA":
                    bg = Image.new("RGBA", im.size, (255, 255, 255, 255))
                    im = Image.alpha_composite(bg, im)
                im = im.convert("RGB")
                w, h = im.size

                if mode == "pad":
                    if w >= h:
                        new_w = target_size
                        new_h = int(h * (new_w / max(1, w)))
                    else:
                        new_h = target_size
                        new_w = int(w * (new_h / max(1, h)))
                    # snap to multiples of k (floor) and clamp
                    new_w = max(k, (new_w // k) * k)
                    new_h = max(k, (new_h // k) * k)
                else:  # "crop"
                    new_w = target_size
                    new_h = int(h * (new_w / max(1, w)))
                    new_h = max(k, (new_h // k) * k)  # floor to multiple of k

                # resize with good downscale filter
                im = im.resize((new_w, new_h), Image.Resampling.LANCZOS)
                x = TF.to_tensor(im)  # [3, H, W] in [0,1]

                if mode == "crop" and new_h > target_size:
                    # center-crop height to target_size
                    start_y = (new_h - target_size) // 2
                    x = x[:, start_y:start_y + target_size, :]
                    new_h = target_size  # keep book-keeping in sync

                if mode == "pad":
                    # pad to square target_size x target_size
                    hp = target_size - new_h
                    wp = target_size - new_w
                    if hp > 0 or wp > 0:
                        top = hp // 2
                        bottom = hp - top
                        left = wp // 2
                        right = wp - left
                        x = F.pad(x, (left, right, top, bottom), value=pad_value)

                # collect
                H, W = x.shape[-2], x.shape[-1]
                # final safety: keep multiples of k
                assert H % k == 0 and W % k == 0, f"Not k-multiple: {(H, W)}"
                imgs.append(x)
                shapes.add((H, W))
        except Exception as e:
            print(f"skip {path}: {e}")

    if not imgs:
        return torch.empty(0, 3, target_size, target_size)

    # unify shapes if needed
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        maxH = max(h for h, _ in shapes)
        maxW = max(w for _, w in shapes)
        # ensure the common canvas is also multiples of k (it already is if each is)
        assert maxH % k == 0 and maxW % k == 0
        padded = []
        for x in imgs:
            hp = maxH - x.shape[-2]
            wp = maxW - x.shape[-1]
            if hp > 0 or wp > 0:
                top = hp // 2
                bottom = hp - top
                left = wp // 2
                right = wp - left
                x = F.pad(x, (left, right, top, bottom), value=pad_value)
            padded.append(x)
        imgs = padded

    return torch.stack(imgs)  # [N,3,H,W]


def to_multiple_of_k(img, k=14, mode="crop", k_down_steps=0):
    """
    Adjust a PIL image so that both width and height are multiples of k.

    Args:
        k (int): multiple base (e.g., 14).
        mode (str): "pad" -> pad to ceil multiple; "crop" -> center-crop to floor multiple.
        k_down_steps (int): when mode="crop", go 'steps' multiples below the floor multiple.
            0 -> floor (default, previous behavior)
            1 -> floor - k
            2 -> floor - 2k
            ...
    """
    w, h = img.size
    if mode == "pad":
        nw = math.ceil(w / k) * k
        nh = math.ceil(h / k) * k
        return ImageOps.expand(img, border=(0, 0, nw - w, nh - h), fill=0)

    # crop mode with optional extra-down steps
    fw = (w // k) * k
    fh = (h // k) * k
    # go additional steps down
    nw = max(k, fw - k * max(0, int(k_down_steps)))
    nh = max(k, fh - k * max(0, int(k_down_steps)))

    # if original is smaller than requested multiple, fall back to pad-to-k
    if nw <= 0 or nh <= 0:
        nw = max(k, nw); nh = max(k, nh)
        return ImageOps.expand(img, border=(0, 0, nw - w, nh - h), fill=0)

    left = (w - nw) // 2
    top = (h - nh) // 2
    return img.crop((left, top, left + nw, top + nh))


def load_image_batch(
    dir_path,
    size=224,
    recursive=False,
    resize=False,
    k=None,
    k_mode="crop",
    k_down_steps=0,
):
    """
    Load images from a directory as tensors, with optional resizing and optional
    alignment to multiples of k. When k_mode='crop', you can go to a smaller
    multiple via k_down_steps.

    Args:
        resize (bool): If True, first Resize/CenterCrop to [size, size].
        k (int|None): If set, enforce H,W to be multiples of k.
        k_mode (str): 'pad' or 'crop'.
        k_down_steps (int): extra multiples to step down when k_mode='crop'.
            Default 0 keeps previous behavior.

    Returns:
        (batch_or_list, kept_paths)
    """
    allowed = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    it = Path(dir_path).rglob('*') if recursive else Path(dir_path).glob('*')
    paths = [str(p) for p in it if p.is_file() and p.suffix.lower() in allowed]
    paths.sort(key=lambda s: s.lower())

    ops = []
    if resize:
        ops += [T.Resize(size, T.InterpolationMode.BILINEAR), T.CenterCrop(size)]
    if k is not None:
        ops += [T.Lambda(lambda im: to_multiple_of_k(im, k=k, mode=k_mode, k_down_steps=k_down_steps))]
    ops += [T.ToTensor()]
    transform = T.Compose(ops)

    imgs, kept = [], []
    for p in paths:
        try:
            with Image.open(p) as im:
                imgs.append(transform(im.convert("RGB")))
                kept.append(p)
        except Exception as e:
            print(f"skip {p}: {e}")

    if not imgs:
        if resize and k is None:
            return torch.empty(0, 3, size, size), kept
        return [], kept
    try:
        # images of mixed sizes cannot be stacked; fall back to a plain list
        return torch.stack(imgs), kept
    except RuntimeError:
        return imgs, kept


def _align_spatial(x, size_hw):
    """Resize x to (H,W) with bilinear (no align_corners)."""
    H, W = size_hw
    if x.shape[-2:] == (H, W):
        return x
    return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)


def _apply_mask_weights(x, mask):
    """
    x: [B, ..., H, W]; mask: [B,1,H,W] or None.
    Returns (x, weights) with weights summed per-batch later.
    """
    if mask is None:
        return x, None
    # broadcast: insert singleton dims after batch until ranks match
    # (the original `mask = mask` was a no-op and looped forever)
    while mask.dim() < x.dim():
        mask = mask.unsqueeze(1)
    return x * mask, mask


@torch.no_grad()
def cosine_map_and_score(f1, f2, mask=None, resize_to='f1', eps=1e-8, reduce='mean'):
    """
    Pixel-wise cosine similarity over channels (C) -> heatmap + aggregated score.
    f1,f2: [B,C,H,W]; mask: [B,1,H,W] (1=valid).
    resize_to: 'f1'|'f2'|None (spatial align if needed).
    reduce: 'mean'|'median'
    Returns: (sim_map [B,1,H,W], score [B])
    """
    assert f1.dim() == 4 and f2.dim() == 4, "inputs must be [B,C,H,W]"
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:], "spatial mismatch and resize_to=None"

    # L2-normalize across channels
    f1n = F.normalize(f1, p=2, dim=1, eps=eps)
    f2n = F.normalize(f2, p=2, dim=1, eps=eps)

    # cosine per-pixel = sum_C(f1n*f2n)
    sim = (f1n * f2n).sum(dim=1, keepdim=True)  # [B,1,H,W], in [-1,1]

    if mask is not None:
        mask = mask.to(sim.dtype)
        sim = sim * mask

    # aggregate
    if mask is None:
        if reduce == 'mean':
            score = sim.mean(dim=(-2, -1)).squeeze(1)  # [B]
        else:
            score = sim.flatten(2).median(dim=-1).values.squeeze(1)  # [B]
    else:
        denom = mask.sum(dim=(-2, -1)).clamp_min(1e-6)  # [B,1]
        if reduce == 'mean':
            score = (sim.sum(dim=(-2, -1)) / denom).squeeze(1)
        else:
            # masked median: fallback to mean over valid (simple)
            score = (sim.sum(dim=(-2, -1)) / denom).squeeze(1)
    return sim, score


@torch.no_grad()
def linear_cka(f1, f2, mask=None, resize_to='f1', eps=1e-8):
    """
    Linear CKA per image (Kornblith19). Range ~[0,1], higher = more similar.
    f1,f2: [B,C,H,W]. We treat spatial positions as samples (N=H*W).
    Returns: scores [B]
    """
    assert f1.dim() == 4 and f2.dim() == 4
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:]

    B, C1, H, W = f1.shape
    _, C2, _, _ = f2.shape
    assert C1 == C2, "channel dim must match for linear CKA"

    # reshape to X: [B,N,C], N=H*W
    X = f1.permute(0, 2, 3, 1).reshape(B, -1, C1).contiguous()
    Y = f2.permute(0, 2, 3, 1).reshape(B, -1, C1).contiguous()

    if mask is not None:
        # mask: [B,1,H,W] -> [B,N,1], select valid rows
        M = mask.reshape(B, -1, 1).bool()
        X = [x[m.expand_as(x)].view(-1, C1) for x, m in zip(X, M)]
        Y = [y[m.expand_as(y)].view(-1, C1) for y, m in zip(Y, M)]
    else:
        X = [X[i] for i in range(B)]
        Y = [Y[i] for i in range(B)]

    scores = []
    for x, y in zip(X, Y):
        # center columns
        x = x - x.mean(dim=0, keepdim=True)
        y = y - y.mean(dim=0, keepdim=True)
        # (C,C) cross-cov (up to scale): X^T Y
        XtY = x.T @ y  # [C,C]
        XtX = x.T @ x
        YtY = y.T @ y
        num = (XtY ** 2).sum()
        denom = (XtX ** 2).sum().sqrt() * (YtY ** 2).sum().sqrt()
        scores.append((num / denom.clamp_min(eps)).item())
    return torch.tensor(scores, dtype=torch.float32)


@torch.no_grad()
def norm_mse(f1, f2, mask=None, resize_to='f1', eps=1e-8):
    """
    MSE on L2-normalized features across channels (scale-invariant).
    Lower = more similar. Returns per-image scores [B].
    """
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:]

    f1n = F.normalize(f1, p=2, dim=1, eps=eps)
    f2n = F.normalize(f2, p=2, dim=1, eps=eps)
    diff2 = (f1n - f2n) ** 2  # [B,C,H,W]
    if mask is not None:
        mask = mask.to(diff2.dtype)
        diff2 = diff2 * mask
        # per-image valid-element count: masked positions × channel count
        # (previous code collapsed the batch into a single [1] value)
        denom = (mask.sum(dim=(1, 2, 3)) * f1n.shape[1]).clamp_min(eps)
        return diff2.sum(dim=(1, 2, 3)) / denom  # [B]
    return diff2.mean(dim=(1, 2, 3))


@torch.no_grad()
def save_similarity_maps(
    sim_map: torch.Tensor,          # [B,1,Hs,Ws]
    imgs = None,                    # [B,3,Hi,Wi] in [0,1] (optional, for overlay)
    out_dir: str = "sim_vis",
    prefix: str = "sim",
    vmin: float = -1.0,
    vmax: float = 1.0,
    cmap: str = "jet",
    alpha: float = 0.5,
    upsample_to_img: bool = True,
    # NEW: output sizing
    heatmap_size = None,            # (H_out, W_out)
    heatmap_scale = None,           # e.g., 2.0 → 2x
    match_img_size = False,         # match heatmap to the source image size
):
    """
    Saves heatmap and (optionally) image overlay, with controllable output size.

    Priority of output size (for heatmap image):
      1) match_img_size=True & imgs provided -> use that image's size
      2) heatmap_size=(H_out, W_out)
      3) heatmap_scale (relative to sim_map size)
      4) default: sim_map size
    """
    assert sim_map.dim() == 4 and sim_map.shape[1] == 1, "sim_map must be [B,1,H,W]"
    os.makedirs(out_dir, exist_ok=True)
    cm = plt.get_cmap(cmap)
    B, _, Hs, Ws = sim_map.shape

    for i in range(B):
        m = sim_map[i, 0]  # [Hs,Ws]
        # normalize to [0,1]
        m01 = (m.clamp(vmin, vmax) - vmin) / max(1e-8, (vmax - vmin))

        # ---- decide target size for HEATMAP file ----
        if match_img_size and imgs is not None:
            Ht, Wt = imgs[i].shape[-2:]
        elif heatmap_size is not None:
            Ht, Wt = heatmap_size
        elif heatmap_scale is not None:
            Ht, Wt = int(round(Hs * heatmap_scale)), int(round(Ws * heatmap_scale))
        else:
            Ht, Wt = Hs, Ws

        # resize map (for heatmap)
        if (Ht, Wt) != (Hs, Ws):
            m_for_heat = F.interpolate(
                m01.unsqueeze(0).unsqueeze(0), size=(Ht, Wt),
                mode="bilinear", align_corners=False
            )[0, 0]
        else:
            m_for_heat = m01

        # colorize and save heatmap with PIL (guarantees exact pixel size)
        m_rgb = (cm(m_for_heat.cpu().numpy())[..., :3] * 255).astype(np.uint8)  # HxWx3
        heat_path = os.path.join(out_dir, f"{prefix}_{i:03d}.png")
        Image.fromarray(m_rgb).save(heat_path)

        # ---- overlay ----
        if imgs is not None:
            img = imgs[i]  # [3,Hi,Wi] in [0,1]
            Hi, Wi = img.shape[-2:]

            # decide overlay target size
            if upsample_to_img:
                Ho, Wo = Hi, Wi
            else:
                Ho, Wo = (Ht, Wt)  # use heatmap size

            if (m01.shape[-2:] != (Ho, Wo)):
                m_for_overlay = F.interpolate(
                    m01.unsqueeze(0).unsqueeze(0), size=(Ho, Wo),
                    mode="bilinear", align_corners=False
                )[0, 0]
            else:
                m_for_overlay = m01

            # colorize & blend
            m_rgb = cm(m_for_overlay.cpu().numpy())[..., :3]  # [0,1]
            img_np = img.permute(1, 2, 0).cpu().clamp(0, 1).numpy()

            # if sizes mismatch, resize image to (Ho,Wo)
            if img_np.shape[:2] != (Ho, Wo):
                img_np = np.array(Image.fromarray((img_np * 255).astype(np.uint8)).resize((Wo, Ho), Image.BILINEAR)) / 255.0

            overlay = np.clip((1 - alpha) * img_np + alpha * m_rgb, 0, 1)
            overlay_path = os.path.join(out_dir, f"{prefix}_{i:03d}_overlay.png")
            Image.fromarray((overlay * 255).astype(np.uint8)).save(overlay_path)

    print(f"Saved to: {out_dir}")


@torch.no_grad()
def save_similarity_maps_normalized(
    sim_map: torch.Tensor,          # [B,1,Hs,Ws]
    imgs = None,                    # [B,3,Hi,Wi] in [0,1] (optional, for overlay)
    out_dir: str = "sim_vis",
    prefix: str = "sim",
    # --- normalization ---
    norm: str = "minmax",           # 'minmax' | 'zscore' | 'range'
    vmin: float = -1.0,
    vmax: float = 1.0,              # used when norm='range'
    # --- coloring/overlay ---
    cmap: str = "jet",
    alpha: float = 0.5,
    upsample_to_img: bool = True,
    # --- output size control (heatmap file) ---
    heatmap_size = None,            # (H_out, W_out)
    heatmap_scale = None,           # e.g., 2.0
    match_img_size = False,         # match heatmap to image size
    # --- optional extra outputs ---
    save_gray = False,              # save normalized grayscale PNG
    save_npy = False,               # save normalized numpy array
):
    """
    Save similarity maps with explicit normalization and controllable output size.

    Heatmap size priority:
      1) match_img_size=True & imgs provided -> use that image's size
      2) heatmap_size
      3) heatmap_scale
      4) original sim_map size
    """
    assert sim_map.dim() == 4 and sim_map.shape[1] == 1, "sim_map must be [B,1,H,W]"
    os.makedirs(out_dir, exist_ok=True)
    cm = plt.get_cmap(cmap)
    B, _, Hs, Ws = sim_map.shape

    for i in range(B):
        m = sim_map[i, 0]  # [Hs,Ws]

        # ---- normalization -> m01 in [0,1] ----
        if norm == "minmax":
            m01 = (m - m.min()) / (m.max() - m.min() + 1e-8)
        elif norm == "zscore":
            std = m.std(unbiased=False).clamp_min(1e-8)
            m01 = ((m - m.mean()) / std).sigmoid()
        elif norm == "range":
            m01 = (m.clamp(vmin, vmax) - vmin) / max(1e-8, (vmax - vmin))
        else:
            raise ValueError("norm must be 'minmax', 'zscore', or 'range'")

        # ---- decide HEATMAP output size ----
        if match_img_size and imgs is not None:
            Ht, Wt = imgs[i].shape[-2:]
        elif heatmap_size is not None:
            Ht, Wt = heatmap_size
        elif heatmap_scale is not None:
            Ht, Wt = int(round(Hs * heatmap_scale)), int(round(Ws * heatmap_scale))
        else:
            Ht, Wt = Hs, Ws

        # resize normalized map for heatmap save
        if (Ht, Wt) != (Hs, Ws):
            m_heat = F.interpolate(
                m01.unsqueeze(0).unsqueeze(0), size=(Ht, Wt),
                mode="bilinear", align_corners=False
            )[0, 0]
        else:
            m_heat = m01

        # colorize & save heatmap with PIL (exact pixel size)
        m_rgb = (cm(m_heat.cpu().numpy())[..., :3] * 255).astype(np.uint8)
        heat_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm.png")
        Image.fromarray(m_rgb).save(heat_path)

        # optional grayscale + npy (use the same resized map)
        if save_gray:
            gray_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm_gray.png")
            Image.fromarray((m_heat.cpu().numpy() * 255).astype(np.uint8)).save(gray_path)
        if save_npy:
            npy_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm.npy")
            np.save(npy_path, m_heat.cpu().numpy())

        # ---- overlay ----
        if imgs is not None:
            img = imgs[i]  # [3,Hi,Wi]
            Hi, Wi = img.shape[-2:]

            # overlay target size
            if upsample_to_img:
                Ho, Wo = Hi, Wi
            else:
                Ho, Wo = (Ht, Wt)

            if (m01.shape[-2:] != (Ho, Wo)):
                m_overlay = F.interpolate(
                    m01.unsqueeze(0).unsqueeze(0), size=(Ho, Wo),
                    mode="bilinear", align_corners=False
                )[0, 0]
            else:
                m_overlay = m01

            m_rgb = cm(m_overlay.cpu().numpy())[..., :3]
            img_np = img.permute(1, 2, 0).cpu().clamp(0, 1).numpy()
            if img_np.shape[:2] != (Ho, Wo):
                img_np = np.array(Image.fromarray((img_np * 255).astype(np.uint8)).resize((Wo, Ho), Image.BILINEAR)) / 255.0

            overlay = np.clip((1 - alpha) * img_np + alpha * m_rgb, 0, 1)
            overlay_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm_overlay.png")
            Image.fromarray((overlay * 255).astype(np.uint8)).save(overlay_path)

    print(f"Saved to: {out_dir}")


# ADE20K semantic segmentation palette: 150 categories, each with a display
# color, contiguous id, thing/stuff flag, and human-readable name.
ADE20K_150_CATEGORIES = [
    {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
    {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
    {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
    {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
    {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
    {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
    {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
    {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
    {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
    {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
    {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
    {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
    {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
    {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
    {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
    {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
    {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
    {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
    {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
    {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
    {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
    {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
    {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
    {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
    {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
    {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
    {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
    {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
    {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
    {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
    {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
    {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
    {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
    {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
    {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
    {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
    {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
    {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
    {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
    {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
    {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
    {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
    {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
    {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
    {
        "color": [6, 51, 255],
        "id": 44,
        "isthing": 1,
        "name": "chest of drawers, chest, bureau, dresser",
    },
    {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
    {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
    {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
    {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
    {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
    {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
    {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
    {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
    {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
    {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
    {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
    {
        "color": [255, 71, 0],
        "id": 56,
        "isthing": 1,
        "name": "pool table, billiard table, snooker table",
    },
    {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
    {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
    {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
    {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
    {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
    {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
    {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
    {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
    {
        "color": [0, 255, 133],
        "id": 65,
        "isthing": 1,
        "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
    },
    {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
    {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
    {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
    {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
    {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
    {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
    {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
    {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
    {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
    {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
    {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
    {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
    {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
    {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
    {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
    {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
    {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
    {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
    {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
    {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
    {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
    {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
    {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
    {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
    {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
    {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
    {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
    {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
    {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
    {
        "color": [0, 122, 255],
        "id": 95,
        "isthing": 1,
        "name": "bannister, banister, balustrade, balusters, handrail",
    },
    {
        "color": [0, 255, 163],
        "id": 96,
        "isthing": 0,
        "name": "escalator, moving staircase, moving stairway",
    },
    {
        "color": [255, 153, 0],
        "id": 97,
        "isthing": 1,
        "name": "ottoman, pouf, pouffe, puff, hassock",
    },
    {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
    {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
    {
        "color": [143, 255, 0],
        "id": 100,
        "isthing": 0,
        "name": "poster, posting, placard, notice, bill, card",
    },
    {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
    {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
    {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
    {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
    {
        "color": [133, 0, 255],
        "id": 105,
        "isthing": 0,
        "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
    },
    {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
    {
        "color": [184, 0, 255],
        "id": 107,
        "isthing": 1,
        "name": "washer, automatic washer, washing machine",
    },
    {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
    {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
    {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
    {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
    {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
    {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
    {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
    {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
    {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
    {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
    {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
    {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
    {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
    {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
    {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
    {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
    {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
    {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
    {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
    {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
    {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
    {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
    {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
    {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
    {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
    {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
    {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
    {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
    {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
    {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
    {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
    {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
    {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
    {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
    {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
    {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
    {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
    {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
    {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
    {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
    {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
    {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
]