# PR-IQA / submodules / loftup / utils.py
# NOTE: file-hosting page artifacts ("2cu1001's picture", "Upload 349 files",
# "52d0a0e verified") removed — they were not valid Python source.
import matplotlib.pyplot as plt
# from pytorch_lightning import seed_everything
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision.transforms as T
import cv2
from PIL import Image, ImageOps
import numpy as np
from sklearn.decomposition import PCA
import os, glob
import math
from pathlib import Path
from contextlib import nullcontext
def _remove_axes(ax):
    """Strip all tick marks and tick labels from a single matplotlib axis."""
    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.set_major_formatter(plt.NullFormatter())
    ax.yaxis.set_major_formatter(plt.NullFormatter())
def remove_axes(axes):
    """Hide ticks/labels on every axis in a 1-D or 2-D array of matplotlib axes."""
    # Treat a 1-D axes array as a grid with a single row.
    rows = axes if len(axes.shape) == 2 else [axes]
    for row in rows:
        for ax in row:
            _remove_axes(ax)
def pca(image_feats_list, dim=3, fit_pca=None, use_torch_pca=True, max_samples=None):
    """Jointly PCA-project a list of feature maps down to `dim` channels.

    Args:
        image_feats_list: list of tensors, each [B, C, H, W] or [N, C].
        dim: number of principal components to keep.
        fit_pca: optional pre-fitted PCA object; if None, one is fitted here.
        use_torch_pca: use TorchPCA instead of sklearn's PCA when fitting.
        max_samples: optional cap on the number of samples used for fitting.

    Returns:
        (reduced_feats, fit_pca): the projected features, min-max scaled to
        [0, 1] per component, and the PCA object used for the projection.
    """
    device = image_feats_list[0].device

    def flatten(tensor, target_size=None):
        # 2-D inputs are already [N, C]; just detach and move to CPU.
        if len(tensor.shape) == 2:
            return tensor.detach().cpu()
        # Only resample during fitting so differently sized maps contribute
        # comparable numbers of samples.
        if target_size is not None and fit_pca is None:
            tensor = F.interpolate(tensor, (target_size, target_size), mode="bilinear")
        B, C, H, W = tensor.shape
        return tensor.permute(1, 0, 2, 3).reshape(C, B * H * W).permute(1, 0).detach().cpu()

    # When fitting on multiple 4-D inputs, align them to the first input's size.
    if len(image_feats_list) > 1 and fit_pca is None:
        if len(image_feats_list[0].shape) == 2:
            target_size = None
        else:
            target_size = image_feats_list[0].shape[2]
    else:
        target_size = None

    flattened_feats = []
    for feats in image_feats_list:
        flattened_feats.append(flatten(feats, target_size))
    x = torch.cat(flattened_feats, dim=0)

    # Subsample the data if max_samples is set and the number of samples exceeds it.
    if max_samples is not None and x.shape[0] > max_samples:
        indices = torch.randperm(x.shape[0])[:max_samples]
        x = x[indices]

    if fit_pca is None:
        if use_torch_pca:
            fit_pca = TorchPCA(n_components=dim).fit(x)
        else:
            fit_pca = PCA(n_components=dim).fit(x)

    reduced_feats = []
    for feats in image_feats_list:
        x_red = fit_pca.transform(flatten(feats))
        if isinstance(x_red, np.ndarray):
            x_red = torch.from_numpy(x_red)
        # Min-max scale each component to [0, 1]. Clamp the denominator so a
        # constant component yields zeros instead of NaNs (division-by-zero fix).
        x_red -= x_red.min(dim=0, keepdim=True).values
        x_red /= x_red.max(dim=0, keepdim=True).values.clamp_min(1e-8)
        if len(feats.shape) == 2:
            reduced_feats.append(x_red)  # 1D
        else:
            B, C, H, W = feats.shape
            reduced_feats.append(x_red.reshape(B, H, W, dim).permute(0, 3, 1, 2).to(device))  # 3D
    return reduced_feats, fit_pca
@torch.no_grad()
def plot_feats(image, lr, hr, save_name='feats.png'):
    """Save a side-by-side figure of an image, its low-res features, and its
    upsampled features (features are PCA-projected to 3 channels for display).

    Args:
        image: [3, H, W] image tensor in [0, 1].
        lr: [C, h, w] low-resolution feature map.
        hr: [C, H', W'] upsampled feature map.
        save_name: output path for the saved figure.
    """
    assert len(image.shape) == len(lr.shape) == len(hr.shape) == 3
    # seed_everything(0)
    [lr_feats_pca, hr_feats_pca], _ = pca([lr.unsqueeze(0), hr.unsqueeze(0)])
    fig, ax = plt.subplots(1, 3, figsize=(12, 4))
    ax[0].imshow(image.permute(1, 2, 0).detach().cpu())
    ax[0].set_title("Image")
    ax[1].imshow(lr_feats_pca[0].permute(1, 2, 0).detach().cpu())
    ax[1].set_title("Original Features")
    ax[2].imshow(hr_feats_pca[0].permute(1, 2, 0).detach().cpu())
    ax[2].set_title("Upsampled Features")
    remove_axes(ax)
    plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    # Close the figure so repeated calls don't accumulate open figures
    # (matches plot_feats_batch, which already closes its figure).
    plt.close(fig)
@torch.no_grad()
def plot_feats_batch(images, lr, hr, save_name='feats_batch.png'):
    """Save a grid visualizing up to 4 images alongside their low-res and
    upsampled PCA-projected feature maps.

    Args:
        images: [B, 3, H, W] image tensor in [0, 1].
        lr: [B, dim, h', w'] low-resolution features.
        hr: [B, dim, H, W] upsampled features.
        save_name: output path for the saved figure.

    Raises:
        ValueError: if the batch sizes of the three inputs differ.
    """
    b_i, b_lr, b_hr = images.shape[0], lr.shape[0], hr.shape[0]
    if not (b_i == b_lr == b_hr):
        raise ValueError(f"Batch size mismatch: images={b_i}, lr={b_lr}, hr={b_hr}")
    B = min(b_i, 4)
    [lr_pca, hr_pca], _ = pca([lr, hr])
    # squeeze=False keeps `axes` 2-D even when B == 1, so the axes[i, j]
    # indexing below no longer raises IndexError for single-image batches.
    fig, axes = plt.subplots(B, 3, figsize=(3 * 3, B * 2), squeeze=False)
    for i in range(B):
        img = images[i].permute(1, 2, 0).cpu()
        lf = lr_pca[i].permute(1, 2, 0).cpu()
        hf = hr_pca[i].permute(1, 2, 0).cpu()
        axes[i, 0].imshow(img); axes[i, 0].set_title(f"Image {i}")
        axes[i, 1].imshow(lf); axes[i, 1].set_title("Original Feats")
        axes[i, 2].imshow(hf); axes[i, 2].set_title("Upsampled Feats")
        remove_axes(axes[i])
    plt.tight_layout()
    plt.savefig(save_name, bbox_inches='tight', pad_inches=0.1)
    plt.close(fig)
def visualize_dissim_maps(dissim_maps: torch.Tensor, cmap: str = "jet"):
    """Display dissimilarity heatmaps with a shared colorbar on a fixed [0, 1] scale.

    Args:
        dissim_maps: [B, H, W] tensor of per-pixel dissimilarity values.
        cmap: matplotlib colormap name.
    """
    # (Removed redundant local `import matplotlib.pyplot` — plt is already
    # imported at module level.)
    dissim_maps = dissim_maps.cpu().numpy()
    Bm1 = dissim_maps.shape[0]
    fig, axes = plt.subplots(1, Bm1, figsize=(4 * Bm1, 4))
    if Bm1 == 1:
        axes = [axes]  # subplots() returns a bare Axes object when B == 1
    for i, ax in enumerate(axes):
        # Fixed display range [0, 1]; raw dissimilarity can extend up to 2,
        # so values above 1 saturate at the top of the colormap.
        im = ax.imshow(dissim_maps[i], cmap=cmap, vmin=0, vmax=1)
        ax.set_title(f"View {i+1} vs Ref")
        ax.axis("off")
    fig.colorbar(im, ax=axes, fraction=0.015, pad=0.04)
    plt.show()
def visualize_dissim_maps_minmax(dissim_maps, cmap='jet'):
    """Display per-map min-max normalized dissimilarity heatmaps.

    dissim_maps: [B, H, W]
    """
    maps_np = dissim_maps.cpu().numpy()
    n_maps = maps_np.shape[0]
    fig, axes = plt.subplots(1, n_maps, figsize=(4 * n_maps, 4))
    if n_maps == 1:
        axes = [axes]  # keep the loop uniform for a single map
    ims = []
    for idx, ax in enumerate(axes):
        raw = maps_np[idx]
        # Min-max normalization; the epsilon guards against a constant map.
        normed = (raw - raw.min()) / (raw.max() - raw.min() + 1e-8)
        ims.append(ax.imshow(normed, cmap=cmap, vmin=0, vmax=1))
        ax.set_title(f"View {idx + 1}")
        ax.axis("off")
    # fraction controls the colorbar length; aspect its thickness/length ratio.
    cbar = fig.colorbar(ims[0], ax=axes, fraction=0.015, pad=0.04, aspect=10)
    cbar.ax.tick_params(labelsize=8)
    plt.show()
class ToTensorWithoutScaling:
    """Convert a PIL image (or numpy array) to a long tensor, keeping raw values."""

    def __call__(self, pic):
        # pil_to_tensor preserves the original integer values (no /255 scaling).
        tensor = TF.pil_to_tensor(pic)
        return tensor.long()
class TorchPCA(object):
    """Minimal PCA built on torch.pca_lowrank, with an sklearn-like fit/transform API."""

    def __init__(self, n_components):
        # Number of principal components to retain.
        self.n_components = n_components

    def fit(self, X):
        """Learn the mean and principal directions from X of shape [N, D]."""
        self.mean_ = X.mean(dim=0)
        centered = X - self.mean_.unsqueeze(0)
        _, S, V = torch.pca_lowrank(centered, q=self.n_components, center=False, niter=4)
        self.components_ = V.T
        self.singular_values_ = S
        return self

    def transform(self, X):
        """Project X of shape [N, D] onto the learned components -> [N, n_components]."""
        centered = X - self.mean_.unsqueeze(0)
        return centered @ self.components_.T
import math
from typing import List, Tuple
from PIL import Image
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
def load_and_preprocess_images(
    image_path_list: List[str],
    mode: str = "crop",
    k: int = 14,
    target_size: int = None,  # default: 37*k (518 when k=14)
    pad_value: float = 1.0,  # padding color in [0,1]
) -> torch.Tensor:
    """
    Load & preprocess images for model input.
    Args:
        image_path_list: list of image file paths.
        mode: "crop" or "pad".
            - "crop": set width to target_size (multiple of k), keep AR, then center-crop height to target_size if taller.
            - "pad": scale long side to target_size (keep AR), then pad to square target_size x target_size.
        k: multiple base (patch size). Both H and W are kept multiples of k.
        target_size: desired side length; default is 37*k (518 when k=14).
        pad_value: padding value in [0,1] (before normalization).
    Returns:
        Tensor of shape (N, 3, H, W). If inputs differ in shape, they are padded to common (maxH, maxW).
    Raises:
        ValueError: if the path list is empty or mode is invalid.
    """
    if not image_path_list:
        raise ValueError("At least 1 image is required")
    if mode not in {"crop", "pad"}:
        raise ValueError("Mode must be either 'crop' or 'pad'")
    if target_size is None:
        target_size = 37 * k  # 518 for k=14
    # ensure target_size is a multiple of k (floored), but never below k itself
    target_size = (target_size // k) * k
    target_size = max(k, target_size)
    imgs = []
    shapes = set()
    for path in image_path_list:
        try:
            with Image.open(path) as im0:
                im = im0
                # alpha → white: composite RGBA images over a white background
                if im.mode == "RGBA":
                    bg = Image.new("RGBA", im.size, (255, 255, 255, 255))
                    im = Image.alpha_composite(bg, im)
                im = im.convert("RGB")
                w, h = im.size
                if mode == "pad":
                    # scale the LONG side to target_size, keeping aspect ratio
                    if w >= h:
                        new_w = target_size
                        new_h = int(h * (new_w / max(1, w)))
                    else:
                        new_h = target_size
                        new_w = int(w * (new_h / max(1, h)))
                    # snap to multiples of k (floor) and clamp
                    new_w = max(k, (new_w // k) * k)
                    new_h = max(k, (new_h // k) * k)
                else:  # "crop": fix the WIDTH to target_size, keep aspect ratio
                    new_w = target_size
                    new_h = int(h * (new_w / max(1, w)))
                    new_h = max(k, (new_h // k) * k)  # floor to multiple of k
                # resize with good downscale filter (requires Pillow >= 9 for Image.Resampling)
                im = im.resize((new_w, new_h), Image.Resampling.LANCZOS)
                x = TF.to_tensor(im)  # [3, H, W] in [0,1]
                if mode == "crop" and new_h > target_size:
                    # center-crop height to target_size
                    start_y = (new_h - target_size) // 2
                    x = x[:, start_y:start_y + target_size, :]
                    new_h = target_size  # keep book-keeping in sync
                if mode == "pad":
                    # pad to square target_size x target_size, centering the image
                    hp = target_size - new_h
                    wp = target_size - new_w
                    if hp > 0 or wp > 0:
                        top = hp // 2
                        bottom = hp - top
                        left = wp // 2
                        right = wp - left
                        x = F.pad(x, (left, right, top, bottom), value=pad_value)
                # collect
                H, W = x.shape[-2], x.shape[-1]
                # final safety: keep multiples of k
                assert H % k == 0 and W % k == 0, f"Not k-multiple: {(H, W)}"
                imgs.append(x)
                shapes.add((H, W))
        except Exception as e:
            # NOTE(review): broad catch — unreadable/corrupt files are skipped with a log line
            print(f"skip {path}: {e}")
    if not imgs:
        # every path failed to load: return an empty batch of the target shape
        return torch.empty(0, 3, target_size, target_size)
    # unify shapes if needed (crop mode can produce shorter-than-target heights)
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        maxH = max(h for h, _ in shapes)
        maxW = max(w for _, w in shapes)
        # ensure the common canvas is also multiples of k (it already is if each is)
        assert maxH % k == 0 and maxW % k == 0
        padded = []
        for x in imgs:
            hp = maxH - x.shape[-2]
            wp = maxW - x.shape[-1]
            if hp > 0 or wp > 0:
                top = hp // 2
                bottom = hp - top
                left = wp // 2
                right = wp - left
                x = F.pad(x, (left, right, top, bottom), value=pad_value)
            padded.append(x)
        imgs = padded
    return torch.stack(imgs)  # [N,3,H,W]
def to_multiple_of_k(img, k=14, mode="crop", k_down_steps=0):
    """
    Adjust a PIL image so that both width and height are multiples of k.
    Args:
        k (int): multiple base (e.g., 14).
        mode (str): "pad" -> pad to ceil multiple; "crop" -> center-crop to floor multiple.
        k_down_steps (int): when mode="crop", go 'steps' multiples below the floor multiple.
            0 -> floor (default, previous behavior)
            1 -> floor - k
            2 -> floor - 2k
            ...
    Returns:
        PIL.Image whose width and height are both multiples of k.
    """
    w, h = img.size
    if mode == "pad":
        nw = math.ceil(w / k) * k
        nh = math.ceil(h / k) * k
        return ImageOps.expand(img, border=(0, 0, nw - w, nh - h), fill=0)
    # crop mode with optional extra-down steps
    fw = (w // k) * k
    fh = (h // k) * k
    nw = max(k, fw - k * max(0, int(k_down_steps)))
    nh = max(k, fh - k * max(0, int(k_down_steps)))
    # If the image is smaller than the requested multiple (e.g. w < k), the
    # crop below would use negative offsets and produce garbage borders.
    # Pad first so the crop is always valid. (The previous guard tested
    # `nw <= 0 or nh <= 0`, which could never fire because nw/nh are clamped
    # to at least k above — this dead branch masked the real failure case.)
    if nw > w or nh > h:
        img = ImageOps.expand(
            img, border=(0, 0, max(0, nw - w), max(0, nh - h)), fill=0)
        w, h = img.size
    left = (w - nw) // 2
    top = (h - nh) // 2
    return img.crop((left, top, left + nw, top + nh))
def load_image_batch(
    dir_path,
    size=224,
    recursive=False,
    resize=False,
    k=None,
    k_mode="crop",
    k_down_steps=0,
):
    """Load every image in a directory as a tensor.

    Optionally Resize/CenterCrop to [size, size] and/or force H, W to be
    multiples of k (via `to_multiple_of_k`, where k_mode='crop' can step
    down to a smaller multiple with k_down_steps).

    Args:
        resize (bool): if True, Resize then CenterCrop to [size, size] first.
        k (int|None): if set, enforce H, W to be multiples of k.
        k_mode (str): 'pad' or 'crop'.
        k_down_steps (int): extra multiples to step down when k_mode='crop';
            0 keeps the floor multiple.

    Returns:
        (batch_or_list, kept_paths): a stacked [N,3,H,W] tensor when every
        image ends up the same shape, otherwise a list of tensors; plus the
        paths that loaded successfully.
    """
    allowed = {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}
    root = Path(dir_path)
    candidates = root.rglob('*') if recursive else root.glob('*')
    paths = sorted(
        (str(p) for p in candidates if p.is_file() and p.suffix.lower() in allowed),
        key=str.lower,
    )

    steps = []
    if resize:
        steps.append(T.Resize(size, T.InterpolationMode.BILINEAR))
        steps.append(T.CenterCrop(size))
    if k is not None:
        steps.append(T.Lambda(
            lambda im: to_multiple_of_k(im, k=k, mode=k_mode, k_down_steps=k_down_steps)))
    steps.append(T.ToTensor())
    transform = T.Compose(steps)

    imgs, kept = [], []
    for p in paths:
        try:
            with Image.open(p) as im:
                imgs.append(transform(im.convert("RGB")))
                kept.append(p)
        except Exception as e:
            print(f"skip {p}: {e}")

    if not imgs:
        # empty result: shape is only known when resize fixed it
        if resize and k is None:
            return torch.empty(0, 3, size, size), kept
        return [], kept
    try:
        return torch.stack(imgs), kept
    except RuntimeError:
        # mixed shapes cannot be stacked; hand back the list
        return imgs, kept
def _align_spatial(x, size_hw):
"""Resize x to (H,W) with bilinear (no align_corners)."""
H, W = size_hw
if x.shape[-2:] == (H, W):
return x
return F.interpolate(x, size=(H, W), mode='bilinear', align_corners=False)
def _apply_mask_weights(x, mask):
"""
x: [B, ..., H, W]; mask: [B,1,H,W] or None.
Returns (x, weights) with weights summed per-batch later.
"""
if mask is None:
return x, None
# broadcast
while mask.dim() < x.dim():
mask = mask
return x * mask, mask
@torch.no_grad()
def cosine_map_and_score(f1, f2, mask=None, resize_to='f1', eps=1e-8, reduce='mean'):
    """Per-pixel channel-wise cosine similarity plus an aggregated score.

    Args:
        f1, f2: [B, C, H, W] feature maps.
        mask: optional [B, 1, H, W] validity mask (1 = valid).
        resize_to: 'f1' | 'f2' | None — which map the other is aligned to.
        eps: numerical floor used in the L2 normalization.
        reduce: 'mean' | 'median' aggregation over pixels.

    Returns:
        (sim_map, score): sim_map is [B, 1, H, W] in [-1, 1]; score is [B].
    """
    assert f1.dim()==4 and f2.dim()==4, "inputs must be [B,C,H,W]"
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:], "spatial mismatch and resize_to=None"

    # Per-pixel cosine = dot product of channel-normalized features.
    sim = (F.normalize(f1, p=2, dim=1, eps=eps) *
           F.normalize(f2, p=2, dim=1, eps=eps)).sum(dim=1, keepdim=True)  # [B,1,H,W]

    if mask is None:
        if reduce == 'mean':
            score = sim.mean(dim=(-2, -1)).squeeze(1)  # [B]
        else:
            score = sim.flatten(2).median(dim=-1).values.squeeze(1)  # [B]
        return sim, score

    mask = mask.to(sim.dtype)
    sim = sim * mask
    valid = mask.sum(dim=(-2, -1)).clamp_min(1e-6)  # [B,1]
    # Masked aggregation: both 'mean' and 'median' use the masked mean
    # (a true masked median is not implemented; simple fallback).
    score = (sim.sum(dim=(-2, -1)) / valid).squeeze(1)
    return sim, score
@torch.no_grad()
def linear_cka(f1, f2, mask=None, resize_to='f1', eps=1e-8):
    """Per-image linear CKA similarity (Kornblith et al., 2019).

    Spatial positions act as samples (N = H*W) and channels as features.
    f1, f2: [B, C, H, W] with matching channel counts.
    Returns a [B] float32 tensor; range roughly [0, 1], higher = more similar.
    """
    assert f1.dim()==4 and f2.dim()==4
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:]

    B, C1, H, W = f1.shape
    assert f2.shape[1] == C1, "channel dim must match for linear CKA"

    # [B, N, C] with N = H*W
    feats_x = f1.permute(0, 2, 3, 1).reshape(B, -1, C1).contiguous()
    feats_y = f2.permute(0, 2, 3, 1).reshape(B, -1, C1).contiguous()

    if mask is not None:
        # mask: [B,1,H,W] -> [B,N,1]; keep only valid spatial samples
        valid = mask.reshape(B, -1, 1).bool()
        pairs = [
            (x[m.expand_as(x)].view(-1, C1), y[m.expand_as(y)].view(-1, C1))
            for x, y, m in zip(feats_x, feats_y, valid)
        ]
    else:
        pairs = list(zip(feats_x, feats_y))

    scores = []
    for x, y in pairs:
        # center columns (features)
        x = x - x.mean(dim=0, keepdim=True)
        y = y - y.mean(dim=0, keepdim=True)
        # CKA = ||X^T Y||_F^2 / (||X^T X||_F * ||Y^T Y||_F)
        num = ((x.T @ y) ** 2).sum()
        denom = ((x.T @ x) ** 2).sum().sqrt() * ((y.T @ y) ** 2).sum().sqrt()
        scores.append((num / denom.clamp_min(eps)).item())
    return torch.tensor(scores, dtype=torch.float32)
@torch.no_grad()
def norm_mse(f1, f2, mask=None, resize_to='f1', eps=1e-8):
    """
    MSE on L2-normalized features across channels (scale-invariant).
    Lower = more similar.

    Args:
        f1, f2: [B, C, H, W] feature maps.
        mask: optional [B, 1, H, W] validity mask (1 = valid).
        resize_to: 'f1' | 'f2' | None spatial alignment.
    Returns:
        [B] tensor of per-image scores.
    """
    if resize_to == 'f1':
        f2 = _align_spatial(f2, f1.shape[-2:])
    elif resize_to == 'f2':
        f1 = _align_spatial(f1, f2.shape[-2:])
    elif resize_to is None:
        assert f1.shape[-2:] == f2.shape[-2:]
    f1n = F.normalize(f1, p=2, dim=1, eps=eps)
    f2n = F.normalize(f2, p=2, dim=1, eps=eps)
    diff2 = (f1n - f2n) ** 2
    if mask is None:
        return diff2.mean(dim=(1, 2, 3))
    mask = mask.to(diff2.dtype)
    # Per-image masked mean. The denominator counts valid ELEMENTS
    # (valid pixels x channel count) so a full mask reproduces the unmasked
    # mean. Previously the masked branch collapsed the whole batch into a
    # single [1] value (docstring promised [B]) and omitted the channel
    # count from the denominator, inflating the result by a factor of C.
    valid = mask.sum(dim=(1, 2, 3)) * f1.shape[1]
    return (diff2 * mask).sum(dim=(1, 2, 3)) / valid.clamp_min(eps)
@torch.no_grad()
def save_similarity_maps(
    sim_map: torch.Tensor,  # [B,1,Hs,Ws]
    imgs = None,  # [B,3,Hi,Wi] in [0,1] (optional, for overlay)
    out_dir: str = "sim_vis",
    prefix: str = "sim",
    vmin: float = -1.0, vmax: float = 1.0,
    cmap: str = "jet",
    alpha: float = 0.5,
    upsample_to_img: bool = True,
    # NEW: output sizing
    heatmap_size = None,  # (H_out, W_out)
    heatmap_scale = None,  # e.g., 2.0 → 2x
    match_img_size = False,  # render the heatmap at the original image size
):
    """
    Saves heatmap and (optionally) image overlay, with controllable output size.
    Priority of output size (for heatmap image):
        1) match_img_size=True & imgs provided -> use that image's size
        2) heatmap_size=(H_out, W_out)
        3) heatmap_scale (relative to sim_map size)
        4) default: sim_map size
    Files are written to out_dir as "{prefix}_{i:03d}.png" (heatmap) and
    "{prefix}_{i:03d}_overlay.png" (overlay, only when imgs is given).
    """
    assert sim_map.dim() == 4 and sim_map.shape[1] == 1, "sim_map must be [B,1,H,W]"
    os.makedirs(out_dir, exist_ok=True)
    cm = plt.get_cmap(cmap)
    B, _, Hs, Ws = sim_map.shape
    for i in range(B):
        m = sim_map[i, 0]  # [Hs,Ws]
        # normalize to [0,1] using the fixed [vmin, vmax] range
        m01 = (m.clamp(vmin, vmax) - vmin) / max(1e-8, (vmax - vmin))
        # ---- decide target size for HEATMAP file ----
        if match_img_size and imgs is not None:
            Ht, Wt = imgs[i].shape[-2:]
        elif heatmap_size is not None:
            Ht, Wt = heatmap_size
        elif heatmap_scale is not None:
            Ht, Wt = int(round(Hs * heatmap_scale)), int(round(Ws * heatmap_scale))
        else:
            Ht, Wt = Hs, Ws
        # resize map (for heatmap)
        if (Ht, Wt) != (Hs, Ws):
            m_for_heat = F.interpolate(
                m01.unsqueeze(0).unsqueeze(0), size=(Ht, Wt),
                mode="bilinear", align_corners=False
            )[0,0]
        else:
            m_for_heat = m01
        # colorize and save heatmap with PIL (guarantees the exact pixel size)
        m_rgb = (cm(m_for_heat.cpu().numpy())[...,:3] * 255).astype(np.uint8)  # HxWx3
        heat_path = os.path.join(out_dir, f"{prefix}_{i:03d}.png")
        Image.fromarray(m_rgb).save(heat_path)
        # ---- overlay ----
        if imgs is not None:
            img = imgs[i]  # [3,Hi,Wi] in [0,1]
            Hi, Wi = img.shape[-2:]
            # decide overlay target size
            if upsample_to_img:
                Ho, Wo = Hi, Wi
            else:
                Ho, Wo = (Ht, Wt)  # use heatmap size
            if (m01.shape[-2:] != (Ho, Wo)):
                m_for_overlay = F.interpolate(
                    m01.unsqueeze(0).unsqueeze(0), size=(Ho, Wo),
                    mode="bilinear", align_corners=False
                )[0,0]
            else:
                m_for_overlay = m01
            # colorize & blend
            m_rgb = cm(m_for_overlay.cpu().numpy())[...,:3]  # [0,1]
            img_np = img.permute(1,2,0).cpu().clamp(0,1).numpy()
            # if sizes mismatch, resize image to (Ho,Wo)
            if img_np.shape[:2] != (Ho, Wo):
                img_np = np.array(Image.fromarray((img_np*255).astype(np.uint8)).resize((Wo,Ho), Image.BILINEAR)) / 255.0
            overlay = np.clip((1 - alpha) * img_np + alpha * m_rgb, 0, 1)
            overlay_path = os.path.join(out_dir, f"{prefix}_{i:03d}_overlay.png")
            Image.fromarray((overlay * 255).astype(np.uint8)).save(overlay_path)
    print(f"Saved to: {out_dir}")
@torch.no_grad()
def save_similarity_maps_normalized(
    sim_map: torch.Tensor,  # [B,1,Hs,Ws]
    imgs = None,  # [B,3,Hi,Wi] in [0,1] (optional, for overlay)
    out_dir: str = "sim_vis",
    prefix: str = "sim",
    # --- normalization ---
    norm: str = "minmax",  # 'minmax' | 'zscore' | 'range'
    vmin: float = -1.0, vmax: float = 1.0,  # used when norm='range'
    # --- coloring/overlay ---
    cmap: str = "jet",
    alpha: float = 0.5,
    upsample_to_img: bool = True,
    # --- output size control (heatmap file) ---
    heatmap_size = None,  # (H_out, W_out)
    heatmap_scale = None,  # e.g., 2.0
    match_img_size = False,  # match heatmap to image size
    # --- optional extra outputs ---
    save_gray = False,  # save normalized grayscale PNG
    save_npy = False,  # save normalized numpy array
):
    """
    Save similarity maps with explicit normalization and controllable output size.
    Heatmap size priority:
        1) match_img_size=True & imgs provided -> use that image's size
        2) heatmap_size
        3) heatmap_scale
        4) original sim_map size
    Raises:
        ValueError: if `norm` is not one of 'minmax', 'zscore', 'range'.
    """
    assert sim_map.dim() == 4 and sim_map.shape[1] == 1, "sim_map must be [B,1,H,W]"
    os.makedirs(out_dir, exist_ok=True)
    cm = plt.get_cmap(cmap)
    B, _, Hs, Ws = sim_map.shape
    for i in range(B):
        m = sim_map[i, 0]  # [Hs,Ws]
        # ---- normalization -> m01 in [0,1] ----
        if norm == "minmax":
            # per-map min-max scaling (epsilon guards a constant map)
            m01 = (m - m.min()) / (m.max() - m.min() + 1e-8)
        elif norm == "zscore":
            # standardize, then squash to (0,1) with a sigmoid
            std = m.std(unbiased=False).clamp_min(1e-8)
            m01 = ((m - m.mean()) / std).sigmoid()
        elif norm == "range":
            # fixed-range clipping to [vmin, vmax]
            m01 = (m.clamp(vmin, vmax) - vmin) / max(1e-8, (vmax - vmin))
        else:
            raise ValueError("norm must be 'minmax', 'zscore', or 'range'")
        # ---- decide HEATMAP output size ----
        if match_img_size and imgs is not None:
            Ht, Wt = imgs[i].shape[-2:]
        elif heatmap_size is not None:
            Ht, Wt = heatmap_size
        elif heatmap_scale is not None:
            Ht, Wt = int(round(Hs * heatmap_scale)), int(round(Ws * heatmap_scale))
        else:
            Ht, Wt = Hs, Ws
        # resize normalized map for heatmap save
        if (Ht, Wt) != (Hs, Ws):
            m_heat = F.interpolate(
                m01.unsqueeze(0).unsqueeze(0), size=(Ht, Wt),
                mode="bilinear", align_corners=False
            )[0,0]
        else:
            m_heat = m01
        # colorize & save heatmap with PIL (exact pixel size)
        m_rgb = (cm(m_heat.cpu().numpy())[...,:3] * 255).astype(np.uint8)
        heat_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm.png")
        Image.fromarray(m_rgb).save(heat_path)
        # optional grayscale + npy (use the same resized map)
        if save_gray:
            gray_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm_gray.png")
            Image.fromarray((m_heat.cpu().numpy() * 255).astype(np.uint8)).save(gray_path)
        if save_npy:
            npy_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm.npy")
            np.save(npy_path, m_heat.cpu().numpy())
        # ---- overlay ----
        if imgs is not None:
            img = imgs[i]  # [3,Hi,Wi]
            Hi, Wi = img.shape[-2:]
            # overlay target size
            if upsample_to_img:
                Ho, Wo = Hi, Wi
            else:
                Ho, Wo = (Ht, Wt)
            if (m01.shape[-2:] != (Ho, Wo)):
                m_overlay = F.interpolate(
                    m01.unsqueeze(0).unsqueeze(0), size=(Ho, Wo),
                    mode="bilinear", align_corners=False
                )[0,0]
            else:
                m_overlay = m01
            m_rgb = cm(m_overlay.cpu().numpy())[...,:3]
            img_np = img.permute(1,2,0).cpu().clamp(0,1).numpy()
            # if sizes mismatch, resize image to (Ho,Wo)
            if img_np.shape[:2] != (Ho, Wo):
                img_np = np.array(Image.fromarray((img_np*255).astype(np.uint8)).resize((Wo,Ho), Image.BILINEAR)) / 255.0
            overlay = np.clip((1 - alpha) * img_np + alpha * m_rgb, 0, 1)
            overlay_path = os.path.join(out_dir, f"{prefix}_{i:03d}_norm_overlay.png")
            Image.fromarray((overlay * 255).astype(np.uint8)).save(overlay_path)
    print(f"Saved to: {out_dir}")
# ADE20K-150 semantic segmentation palette: one entry per class, carrying its
# visualization RGB "color", contiguous class "id" (0..149), a thing/stuff
# flag ("isthing": 1 = countable object, 0 = stuff/background), and "name".
ADE20K_150_CATEGORIES = [
    {"color": [120, 120, 120], "id": 0, "isthing": 0, "name": "wall"},
    {"color": [180, 120, 120], "id": 1, "isthing": 0, "name": "building"},
    {"color": [6, 230, 230], "id": 2, "isthing": 0, "name": "sky"},
    {"color": [80, 50, 50], "id": 3, "isthing": 0, "name": "floor"},
    {"color": [4, 200, 3], "id": 4, "isthing": 0, "name": "tree"},
    {"color": [120, 120, 80], "id": 5, "isthing": 0, "name": "ceiling"},
    {"color": [140, 140, 140], "id": 6, "isthing": 0, "name": "road, route"},
    {"color": [204, 5, 255], "id": 7, "isthing": 1, "name": "bed"},
    {"color": [230, 230, 230], "id": 8, "isthing": 1, "name": "window "},
    {"color": [4, 250, 7], "id": 9, "isthing": 0, "name": "grass"},
    {"color": [224, 5, 255], "id": 10, "isthing": 1, "name": "cabinet"},
    {"color": [235, 255, 7], "id": 11, "isthing": 0, "name": "sidewalk, pavement"},
    {"color": [150, 5, 61], "id": 12, "isthing": 1, "name": "person"},
    {"color": [120, 120, 70], "id": 13, "isthing": 0, "name": "earth, ground"},
    {"color": [8, 255, 51], "id": 14, "isthing": 1, "name": "door"},
    {"color": [255, 6, 82], "id": 15, "isthing": 1, "name": "table"},
    {"color": [143, 255, 140], "id": 16, "isthing": 0, "name": "mountain, mount"},
    {"color": [204, 255, 4], "id": 17, "isthing": 0, "name": "plant"},
    {"color": [255, 51, 7], "id": 18, "isthing": 1, "name": "curtain"},
    {"color": [204, 70, 3], "id": 19, "isthing": 1, "name": "chair"},
    {"color": [0, 102, 200], "id": 20, "isthing": 1, "name": "car"},
    {"color": [61, 230, 250], "id": 21, "isthing": 0, "name": "water"},
    {"color": [255, 6, 51], "id": 22, "isthing": 1, "name": "painting, picture"},
    {"color": [11, 102, 255], "id": 23, "isthing": 1, "name": "sofa"},
    {"color": [255, 7, 71], "id": 24, "isthing": 1, "name": "shelf"},
    {"color": [255, 9, 224], "id": 25, "isthing": 0, "name": "house"},
    {"color": [9, 7, 230], "id": 26, "isthing": 0, "name": "sea"},
    {"color": [220, 220, 220], "id": 27, "isthing": 1, "name": "mirror"},
    {"color": [255, 9, 92], "id": 28, "isthing": 0, "name": "rug"},
    {"color": [112, 9, 255], "id": 29, "isthing": 0, "name": "field"},
    {"color": [8, 255, 214], "id": 30, "isthing": 1, "name": "armchair"},
    {"color": [7, 255, 224], "id": 31, "isthing": 1, "name": "seat"},
    {"color": [255, 184, 6], "id": 32, "isthing": 1, "name": "fence"},
    {"color": [10, 255, 71], "id": 33, "isthing": 1, "name": "desk"},
    {"color": [255, 41, 10], "id": 34, "isthing": 0, "name": "rock, stone"},
    {"color": [7, 255, 255], "id": 35, "isthing": 1, "name": "wardrobe, closet, press"},
    {"color": [224, 255, 8], "id": 36, "isthing": 1, "name": "lamp"},
    {"color": [102, 8, 255], "id": 37, "isthing": 1, "name": "tub"},
    {"color": [255, 61, 6], "id": 38, "isthing": 1, "name": "rail"},
    {"color": [255, 194, 7], "id": 39, "isthing": 1, "name": "cushion"},
    {"color": [255, 122, 8], "id": 40, "isthing": 0, "name": "base, pedestal, stand"},
    {"color": [0, 255, 20], "id": 41, "isthing": 1, "name": "box"},
    {"color": [255, 8, 41], "id": 42, "isthing": 1, "name": "column, pillar"},
    {"color": [255, 5, 153], "id": 43, "isthing": 1, "name": "signboard, sign"},
    {
        "color": [6, 51, 255],
        "id": 44,
        "isthing": 1,
        "name": "chest of drawers, chest, bureau, dresser",
    },
    {"color": [235, 12, 255], "id": 45, "isthing": 1, "name": "counter"},
    {"color": [160, 150, 20], "id": 46, "isthing": 0, "name": "sand"},
    {"color": [0, 163, 255], "id": 47, "isthing": 1, "name": "sink"},
    {"color": [140, 140, 140], "id": 48, "isthing": 0, "name": "skyscraper"},
    {"color": [250, 10, 15], "id": 49, "isthing": 1, "name": "fireplace"},
    {"color": [20, 255, 0], "id": 50, "isthing": 1, "name": "refrigerator, icebox"},
    {"color": [31, 255, 0], "id": 51, "isthing": 0, "name": "grandstand, covered stand"},
    {"color": [255, 31, 0], "id": 52, "isthing": 0, "name": "path"},
    {"color": [255, 224, 0], "id": 53, "isthing": 1, "name": "stairs"},
    {"color": [153, 255, 0], "id": 54, "isthing": 0, "name": "runway"},
    {"color": [0, 0, 255], "id": 55, "isthing": 1, "name": "case, display case, showcase, vitrine"},
    {
        "color": [255, 71, 0],
        "id": 56,
        "isthing": 1,
        "name": "pool table, billiard table, snooker table",
    },
    {"color": [0, 235, 255], "id": 57, "isthing": 1, "name": "pillow"},
    {"color": [0, 173, 255], "id": 58, "isthing": 1, "name": "screen door, screen"},
    {"color": [31, 0, 255], "id": 59, "isthing": 0, "name": "stairway, staircase"},
    {"color": [11, 200, 200], "id": 60, "isthing": 0, "name": "river"},
    {"color": [255, 82, 0], "id": 61, "isthing": 0, "name": "bridge, span"},
    {"color": [0, 255, 245], "id": 62, "isthing": 1, "name": "bookcase"},
    {"color": [0, 61, 255], "id": 63, "isthing": 0, "name": "blind, screen"},
    {"color": [0, 255, 112], "id": 64, "isthing": 1, "name": "coffee table"},
    {
        "color": [0, 255, 133],
        "id": 65,
        "isthing": 1,
        "name": "toilet, can, commode, crapper, pot, potty, stool, throne",
    },
    {"color": [255, 0, 0], "id": 66, "isthing": 1, "name": "flower"},
    {"color": [255, 163, 0], "id": 67, "isthing": 1, "name": "book"},
    {"color": [255, 102, 0], "id": 68, "isthing": 0, "name": "hill"},
    {"color": [194, 255, 0], "id": 69, "isthing": 1, "name": "bench"},
    {"color": [0, 143, 255], "id": 70, "isthing": 1, "name": "countertop"},
    {"color": [51, 255, 0], "id": 71, "isthing": 1, "name": "stove"},
    {"color": [0, 82, 255], "id": 72, "isthing": 1, "name": "palm, palm tree"},
    {"color": [0, 255, 41], "id": 73, "isthing": 1, "name": "kitchen island"},
    {"color": [0, 255, 173], "id": 74, "isthing": 1, "name": "computer"},
    {"color": [10, 0, 255], "id": 75, "isthing": 1, "name": "swivel chair"},
    {"color": [173, 255, 0], "id": 76, "isthing": 1, "name": "boat"},
    {"color": [0, 255, 153], "id": 77, "isthing": 0, "name": "bar"},
    {"color": [255, 92, 0], "id": 78, "isthing": 1, "name": "arcade machine"},
    {"color": [255, 0, 255], "id": 79, "isthing": 0, "name": "hovel, hut, hutch, shack, shanty"},
    {"color": [255, 0, 245], "id": 80, "isthing": 1, "name": "bus"},
    {"color": [255, 0, 102], "id": 81, "isthing": 1, "name": "towel"},
    {"color": [255, 173, 0], "id": 82, "isthing": 1, "name": "light"},
    {"color": [255, 0, 20], "id": 83, "isthing": 1, "name": "truck"},
    {"color": [255, 184, 184], "id": 84, "isthing": 0, "name": "tower"},
    {"color": [0, 31, 255], "id": 85, "isthing": 1, "name": "chandelier"},
    {"color": [0, 255, 61], "id": 86, "isthing": 1, "name": "awning, sunshade, sunblind"},
    {"color": [0, 71, 255], "id": 87, "isthing": 1, "name": "street lamp"},
    {"color": [255, 0, 204], "id": 88, "isthing": 1, "name": "booth"},
    {"color": [0, 255, 194], "id": 89, "isthing": 1, "name": "tv"},
    {"color": [0, 255, 82], "id": 90, "isthing": 1, "name": "plane"},
    {"color": [0, 10, 255], "id": 91, "isthing": 0, "name": "dirt track"},
    {"color": [0, 112, 255], "id": 92, "isthing": 1, "name": "clothes"},
    {"color": [51, 0, 255], "id": 93, "isthing": 1, "name": "pole"},
    {"color": [0, 194, 255], "id": 94, "isthing": 0, "name": "land, ground, soil"},
    {
        "color": [0, 122, 255],
        "id": 95,
        "isthing": 1,
        "name": "bannister, banister, balustrade, balusters, handrail",
    },
    {
        "color": [0, 255, 163],
        "id": 96,
        "isthing": 0,
        "name": "escalator, moving staircase, moving stairway",
    },
    {
        "color": [255, 153, 0],
        "id": 97,
        "isthing": 1,
        "name": "ottoman, pouf, pouffe, puff, hassock",
    },
    {"color": [0, 255, 10], "id": 98, "isthing": 1, "name": "bottle"},
    {"color": [255, 112, 0], "id": 99, "isthing": 0, "name": "buffet, counter, sideboard"},
    {
        "color": [143, 255, 0],
        "id": 100,
        "isthing": 0,
        "name": "poster, posting, placard, notice, bill, card",
    },
    {"color": [82, 0, 255], "id": 101, "isthing": 0, "name": "stage"},
    {"color": [163, 255, 0], "id": 102, "isthing": 1, "name": "van"},
    {"color": [255, 235, 0], "id": 103, "isthing": 1, "name": "ship"},
    {"color": [8, 184, 170], "id": 104, "isthing": 1, "name": "fountain"},
    {
        "color": [133, 0, 255],
        "id": 105,
        "isthing": 0,
        "name": "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
    },
    {"color": [0, 255, 92], "id": 106, "isthing": 0, "name": "canopy"},
    {
        "color": [184, 0, 255],
        "id": 107,
        "isthing": 1,
        "name": "washer, automatic washer, washing machine",
    },
    {"color": [255, 0, 31], "id": 108, "isthing": 1, "name": "plaything, toy"},
    {"color": [0, 184, 255], "id": 109, "isthing": 0, "name": "pool"},
    {"color": [0, 214, 255], "id": 110, "isthing": 1, "name": "stool"},
    {"color": [255, 0, 112], "id": 111, "isthing": 1, "name": "barrel, cask"},
    {"color": [92, 255, 0], "id": 112, "isthing": 1, "name": "basket, handbasket"},
    {"color": [0, 224, 255], "id": 113, "isthing": 0, "name": "falls"},
    {"color": [112, 224, 255], "id": 114, "isthing": 0, "name": "tent"},
    {"color": [70, 184, 160], "id": 115, "isthing": 1, "name": "bag"},
    {"color": [163, 0, 255], "id": 116, "isthing": 1, "name": "minibike, motorbike"},
    {"color": [153, 0, 255], "id": 117, "isthing": 0, "name": "cradle"},
    {"color": [71, 255, 0], "id": 118, "isthing": 1, "name": "oven"},
    {"color": [255, 0, 163], "id": 119, "isthing": 1, "name": "ball"},
    {"color": [255, 204, 0], "id": 120, "isthing": 1, "name": "food, solid food"},
    {"color": [255, 0, 143], "id": 121, "isthing": 1, "name": "step, stair"},
    {"color": [0, 255, 235], "id": 122, "isthing": 0, "name": "tank, storage tank"},
    {"color": [133, 255, 0], "id": 123, "isthing": 1, "name": "trade name"},
    {"color": [255, 0, 235], "id": 124, "isthing": 1, "name": "microwave"},
    {"color": [245, 0, 255], "id": 125, "isthing": 1, "name": "pot"},
    {"color": [255, 0, 122], "id": 126, "isthing": 1, "name": "animal"},
    {"color": [255, 245, 0], "id": 127, "isthing": 1, "name": "bicycle"},
    {"color": [10, 190, 212], "id": 128, "isthing": 0, "name": "lake"},
    {"color": [214, 255, 0], "id": 129, "isthing": 1, "name": "dishwasher"},
    {"color": [0, 204, 255], "id": 130, "isthing": 1, "name": "screen"},
    {"color": [20, 0, 255], "id": 131, "isthing": 0, "name": "blanket, cover"},
    {"color": [255, 255, 0], "id": 132, "isthing": 1, "name": "sculpture"},
    {"color": [0, 153, 255], "id": 133, "isthing": 1, "name": "hood, exhaust hood"},
    {"color": [0, 41, 255], "id": 134, "isthing": 1, "name": "sconce"},
    {"color": [0, 255, 204], "id": 135, "isthing": 1, "name": "vase"},
    {"color": [41, 0, 255], "id": 136, "isthing": 1, "name": "traffic light"},
    {"color": [41, 255, 0], "id": 137, "isthing": 1, "name": "tray"},
    {"color": [173, 0, 255], "id": 138, "isthing": 1, "name": "trash can"},
    {"color": [0, 245, 255], "id": 139, "isthing": 1, "name": "fan"},
    {"color": [71, 0, 255], "id": 140, "isthing": 0, "name": "pier"},
    {"color": [122, 0, 255], "id": 141, "isthing": 0, "name": "crt screen"},
    {"color": [0, 255, 184], "id": 142, "isthing": 1, "name": "plate"},
    {"color": [0, 92, 255], "id": 143, "isthing": 1, "name": "monitor"},
    {"color": [184, 255, 0], "id": 144, "isthing": 1, "name": "bulletin board"},
    {"color": [0, 133, 255], "id": 145, "isthing": 0, "name": "shower"},
    {"color": [255, 214, 0], "id": 146, "isthing": 1, "name": "radiator"},
    {"color": [25, 194, 194], "id": 147, "isthing": 1, "name": "glass, drinking glass"},
    {"color": [102, 255, 0], "id": 148, "isthing": 1, "name": "clock"},
    {"color": [92, 0, 255], "id": 149, "isthing": 1, "name": "flag"},
]