Spaces:
Running
Running
File size: 2,368 Bytes
0917e8d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 | from signal import siginterrupt
import torch
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from typing import List
# https://discuss.pytorch.org/t/pytorch-tensor-to-device-for-a-list-of-dict/66283
def move_to(obj, device):
    """
    Recursively move every tensor inside ``obj`` to ``device``.

    Dicts, lists and tuples (including namedtuples) are traversed and
    rebuilt with each tensor leaf replaced by ``leaf.to(device)``.
    Dicts are rebuilt as plain ``dict`` objects.

    Args:
        obj: A tensor, or a (possibly nested) dict / list / tuple of tensors.
        device: Target passed straight to ``torch.Tensor.to``.

    Returns:
        A structure of the same shape with every tensor moved to ``device``.

    Raises:
        TypeError: if a leaf is neither a tensor nor a supported container.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # namedtuples take their fields positionally; plain tuples take an iterable
        return type(obj)(*moved) if hasattr(obj, "_fields") else tuple(moved)
    raise TypeError("Invalid type for move_to")
def apply_color_palette(
    map_like: np.ndarray, cmap_name: str = "viridis"
) -> np.ndarray:
    """
    Apply a matplotlib colormap to a 2D input array (H, W) and return
    an RGB image (H, W, 3).

    The input is min-max normalized to [0, 1] before the colormap is
    applied, so any value range is accepted (e.g. [-1, 1] outputs from
    models like ControlNet).

    Parameters:
    -----------
    map_like : np.ndarray
        A NumPy array (or array-like) holding feature-map / heatmap data.
    cmap_name : str, optional
        Name of a registered matplotlib colormap. Defaults to "viridis".

    Returns:
    --------
    colored_image : np.ndarray
        For a 2D input, an array of shape (H, W, 3) whose RGB values are
        in the [0, 1] range. A constant input maps every pixel to the
        colormap's lowest colour. 3D inputs take a special path (see the
        HACK note in the body).
    """
    import matplotlib  # only the ``cm`` submodule is bound at module level

    values = np.asarray(map_like, dtype=np.float32)
    min_val, max_val = np.min(values), np.max(values)
    span = max_val - min_val
    # NOTE: re-normalization is needed for models like ControlNet whose
    # outputs are not already in [0, 1]. A constant map has zero span, and
    # dividing by it would produce NaNs (drawn as the colormap's "bad"
    # colour), so such a map is sent to 0 instead.
    if span == 0:
        normalized_map = np.zeros_like(values)
    else:
        normalized_map = (values - min_val) / span

    # ``cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in 3.9.
    # Use the ``matplotlib.colormaps`` registry (3.5+) when it exists.
    registry = getattr(matplotlib, "colormaps", None)
    if registry is not None:
        cmap = registry[cmap_name]
    else:
        cmap = cm.get_cmap(cmap_name)
    colored_map = cmap(normalized_map)  # appends a trailing RGBA axis of size 4

    # HACK: when the input has 3 dims (e.g. (H, W, C)), the colormap output
    # has 4 dims. This branch keeps only the colormap's red component for
    # each input channel, giving shape (H, W, C). It is kept as-is for
    # existing callers.
    # NOTE(review): confirm this is the intended behaviour.
    if colored_map.ndim == 4:
        return colored_map[:, :, :, 0]
    # Drop the alpha channel.
    return colored_map[..., :3]
def convert_to_img_like(*args: torch.Tensor) -> List[np.ndarray]:
    """
    Convert one or more maps into colour-mapped images via
    ``apply_color_palette``.

    Tensor arguments are detached, moved to CPU and converted to float32
    NumPy arrays first. Non-tensor arguments (e.g. NumPy arrays) are
    passed to ``apply_color_palette`` unchanged.

    Returns:
        A list with one colour-mapped NumPy array per argument, in the
        same order as the arguments.
    """
    results = []
    for x in args:
        if isinstance(x, torch.Tensor):
            # Cast to float32 before ``.numpy()``: NumPy has no bfloat16, so
            # bf16 tensors would otherwise fail here. apply_color_palette
            # casts to float32 anyway, so no precision is lost downstream.
            x = x.detach().to(device="cpu", dtype=torch.float32).numpy()
        results.append(apply_color_palette(x))
    return results
def grayscale_to_2d(grayscale_like: torch.Tensor) -> torch.Tensor:
    """
    Collapse the last (channel) dimension of a 'grayscale' tensor by
    averaging across it, e.g. (128, 128, 3) -> (128, 128).

    Any shape (..., C) is accepted. Integer and bool inputs (such as uint8
    images) are promoted to float32 first, because ``torch.mean`` only
    supports floating-point and complex dtypes.
    """
    if not (grayscale_like.is_floating_point() or grayscale_like.is_complex()):
        grayscale_like = grayscale_like.to(torch.float32)
    return torch.mean(grayscale_like, dim=-1)
|