File size: 2,368 Bytes
0917e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from signal import siginterrupt
import torch
import numpy as np
import matplotlib.cm as cm
import matplotlib.colors as mcolors

from typing import List


# https://discuss.pytorch.org/t/pytorch-tensor-to-device-for-a-list-of-dict/66283
def move_to(obj, device):
    """
    Recursively move every tensor inside ``obj`` to ``device``.

    Parameters:
    -----------
    obj : torch.Tensor | dict | list | tuple
        A tensor, or an arbitrarily nested dict / list / tuple whose leaves
        are tensors.
    device : torch.device | str
        Target device, passed straight to ``Tensor.to``.

    Returns:
    --------
    The same structure with every tensor moved. Dicts and lists are rebuilt
    as plain ``dict`` / ``list`` (key order preserved); tuples stay tuples,
    and namedtuples keep their own type.

    Raises:
    -------
    TypeError
        If any leaf is not a tensor, dict, list or tuple.
    """
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # namedtuples take their fields positionally; plain tuples take one iterable
        return type(obj)(*moved) if hasattr(obj, "_fields") else tuple(moved)
    raise TypeError(f"Invalid type for move_to: {type(obj).__name__}")


def apply_color_palette(
    map_like: np.ndarray, cmap_name: str = "viridis"
) -> np.ndarray:
    """
    Min-max normalize an array to [0, 1] and map it through a colormap.

    Parameters:
    -----------
    map_like : np.ndarray
        A NumPy array representing some feature map or heatmap data, in any
        numeric range. Normally 2D (H, W); 3D input takes the special-case
        branch described under Returns.
    cmap_name : str, optional
        Name of a registered Matplotlib colormap. Defaults to "viridis".

    Returns:
    --------
    colored_image : np.ndarray
        For 3D input (A, B, C): an array of shape (A, B, C) holding only the
        red channel of the colormapped values.
        Otherwise (e.g. 2D (H, W) input): the input shape plus a trailing
        RGB axis of size 3, values in [0, 1] (the alpha channel is dropped).
    """
    import matplotlib  # local: the file only binds the `cm` submodule

    map_like = map_like.astype(np.float32)
    min_val, max_val = np.min(map_like), np.max(map_like)
    value_range = max_val - min_val

    # Rescale from the data's own [min, max] to [0, 1]. The input range is not
    # fixed (e.g. models like ControlNet produce [-1, 1]), so this is needed.
    if value_range == 0:
        # Constant map: 0/0 would give NaN, which the colormap renders with
        # its "bad" color; map everything to the low end instead.
        normalized_map = np.zeros_like(map_like)
    else:
        normalized_map = (map_like - min_val) / value_range

    # `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9; the
    # `matplotlib.colormaps` registry (3.5+) replaces it.
    registry = getattr(matplotlib, "colormaps", None)
    cmap = registry[cmap_name] if registry is not None else cm.get_cmap(cmap_name)
    colored_map = cmap(normalized_map)  # RGBA appended as a new last axis

    # HACK: 3D input yields a 4D RGBA result; only the red channel (index 0 of
    # the last axis) is kept so the output stays 3D.
    # NOTE(review): the intent may have been RGB of a single slice — confirm
    # with callers before changing.
    if colored_map.ndim == 4:
        return colored_map[..., 0]
    return colored_map[..., :3]


def convert_to_img_like(*args: torch.Tensor) -> List[np.ndarray]:
    """
    Colorize one or more maps for visualization.

    Each argument is turned into a NumPy array first. Tensors are detached,
    copied to the CPU and cast to float32; NumPy arrays are used as-is.
    The array is then passed through ``apply_color_palette``, which min-max
    normalizes it to [0, 1] and maps it to RGB.

    Parameters:
    -----------
    *args : torch.Tensor or np.ndarray
        Maps in any numeric range.

    Returns:
    --------
    List[np.ndarray]
        One colorized float array per argument, in the same order.
    """
    results = []
    for x in args:
        if isinstance(x, torch.Tensor):
            # .float() before .numpy(): NumPy has no bfloat16, so .numpy()
            # raises for bf16 tensors. apply_color_palette casts to float32
            # anyway, so nothing is lost.
            x = x.detach().cpu().float().numpy()
        results.append(apply_color_palette(x))
    return results


def grayscale_to_2d(grayscale_like: torch.Tensor) -> torch.Tensor:
    """
    Collapse the trailing channel axis of a 'grayscale' tensor by averaging.

    For example, a (128, 128, 3) tensor becomes (128, 128); in general a
    (..., C) tensor becomes (...).

    Integer and bool tensors (such as uint8 images) are promoted to float32
    first, because ``torch.mean`` only accepts floating-point or complex
    dtypes. Floating/complex inputs keep their dtype.
    """
    if not (grayscale_like.is_floating_point() or grayscale_like.is_complex()):
        grayscale_like = grayscale_like.float()
    return grayscale_like.mean(dim=-1)