sparse-cafm / src /util /torch_helpers.py
leharris3's picture
Minimal HF Space deployment with gradio 5.x fix
0917e8d
from signal import siginterrupt
import torch
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from typing import List
# https://discuss.pytorch.org/t/pytorch-tensor-to-device-for-a-list-of-dict/66283
def move_to(obj, device):
    """
    Recursively move every tensor inside a (possibly nested) container to ``device``.

    Parameters
    ----------
    obj : torch.Tensor | dict | list | tuple
        A tensor, or a dict / list / tuple whose leaves are tensors.
        Containers may be nested arbitrarily.
    device : torch.device | str
        Target device, passed unchanged to ``Tensor.to``.

    Returns
    -------
    A new structure mirroring ``obj``. Dicts become new dicts with the same
    keys, lists become new lists, and tuples become tuples (namedtuples keep
    their concrete type). Each tensor leaf is replaced by ``leaf.to(device)``.

    Raises
    ------
    TypeError
        If ``obj`` (or any nested leaf) is not a tensor, dict, list, or tuple.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # namedtuples take their fields positionally, not as a single iterable.
        if hasattr(obj, "_fields"):
            return type(obj)(*moved)
        return tuple(moved)
    raise TypeError(f"Invalid type for move_to: {type(obj).__name__}")
def apply_color_palette(map_like: np.ndarray) -> np.ndarray:
    """
    Applies the "viridis" color palette to an input array and returns
    an RGB image.

    The input is min-max normalized to [0, 1] before the colormap is
    applied, so inputs in any value range are accepted.

    Parameters:
    -----------
    map_like : np.ndarray
        A NumPy array representing some feature map or heatmap data,
        typically of shape (H, W).

    Returns:
    --------
    colored_image : np.ndarray
        For a 2D (H, W) input, an array of shape (H, W, 3) where each pixel
        has RGB values in the [0, 1] range (the colormap's alpha channel is
        dropped). For a 3D input, see the HACK branch below.
    """
    map_like = map_like.astype(np.float32)
    min_val, max_val = np.min(map_like), np.max(map_like)
    value_range = max_val - min_val

    # Min-max normalize, e.g. [-1, 1] -> [0, 1].
    # NOTE: this re-normalization step is needed for models like ControlNet
    if value_range == 0:
        # Constant input: (x - min) / 0 would yield NaN (plus a RuntimeWarning),
        # which the colormap renders as its "bad" (black) color. Map to 0 instead.
        normalized_map = np.zeros_like(map_like)
    else:
        normalized_map = (map_like - min_val) / value_range

    # `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9; the
    # `matplotlib.colormaps` registry (Matplotlib >= 3.5) is the supported API.
    import matplotlib

    try:
        cmap = matplotlib.colormaps["viridis"]
    except AttributeError:  # Matplotlib < 3.5
        cmap = cm.get_cmap("viridis")

    # The colormap appends an RGBA axis: (...,) -> (..., 4).
    colored_map = cmap(normalized_map)

    # HACK: a 3D input yields a 4D RGBA array; this keeps only the first
    # (red) component of each element, giving back the input's 3D shape.
    # NOTE(review): intent of this branch is unclear — confirm with callers.
    if len(colored_map.shape) == 4:
        colored_image = colored_map[:, :, :, 0]
    else:
        # Drop the alpha channel to get RGB.
        colored_image = colored_map[..., :3]
    return colored_image
def convert_to_img_like(*args: torch.Tensor) -> List[np.ndarray]:
    """
    Colorize one or more tensors (or NumPy arrays) into RGB images.

    Each torch tensor is detached, moved to the CPU, and converted to a
    NumPy array. Every argument is then passed through
    ``apply_color_palette``, which min-max normalizes it to [0, 1] and
    applies a color palette.

    Parameters:
    -----------
    *args : torch.Tensor | np.ndarray
        The maps to colorize. Non-tensor arguments are passed to
        ``apply_color_palette`` unchanged.

    Returns:
    --------
    List[np.ndarray]
        One colored image per argument, in input order.
    """
    results = []
    for x in args:
        if isinstance(x, torch.Tensor):
            # `.float()` lets dtypes NumPy cannot represent (e.g. bfloat16)
            # convert; apply_color_palette casts to float32 immediately anyway.
            x = x.detach().cpu().float().numpy()
        results.append(apply_color_palette(x))
    return results
def grayscale_to_2d(grayscale_like: torch.Tensor) -> torch.Tensor:
    """
    Collapse the trailing channel dimension of a 'grayscale' tensor by
    averaging across it.

    For example, a (128, 128, 3) tensor becomes (128, 128). More generally,
    any leading shape is preserved: (..., C) -> (...).

    Integer/bool inputs (e.g. uint8 images) are cast to float32 first,
    because ``torch.mean`` only supports floating-point and complex dtypes.
    """
    if not (grayscale_like.is_floating_point() or grayscale_like.is_complex()):
        grayscale_like = grayscale_like.float()
    return grayscale_like.mean(dim=-1)