Other
English
File size: 5,484 Bytes
26225c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
import numpy as np
from colorhash import ColorHash
from plotly.colors import sample_colorscale, get_colorscale


# Names exported by a star-import of this module. Note that
# `min_max_normalize` is defined in this file but is not listed here.
__all__ = [
    'to_float_rgb', 'to_byte_rgb', 'rgb_to_plotly_rgb', 'int_to_plotly_rgb',
    'hex_to_tensor', 'feats_to_plotly_rgb', 'identity_PCA']


def to_float_rgb(rgb):
    """Convert a tensor of colors to float values in [0, 1].

    The input is cast to float. If its largest value exceeds 1, the
    values are divided by 255. The result is then clamped to [0, 1].
    """
    as_float = rgb.float()
    # Heuristic: any value above 1 means the colors are on a 0-255 scale
    rescaled = as_float / 255 if as_float.max() > 1 else as_float
    return rescaled.clamp(min=0, max=1)


def to_byte_rgb(rgb):
    """Convert a tensor of colors to uint8 values in [0, 255].

    Floating point inputs whose largest value is at most 1 are first
    multiplied by 255. All inputs are then clamped to [0, 255] and
    cast to byte.
    """
    is_unit_float = rgb.is_floating_point() and rgb.max() <= 1
    scaled = rgb * 255 if is_unit_float else rgb
    return scaled.clamp(min=0, max=255).byte()


def rgb_to_plotly_rgb(rgb, alpha=None):
    """Convert a torch.Tensor of RGB values to a numpy array of integer
    RGB rows. If alpha is provided, each row is expressed as RGBA.

    rgb: 1D tensor (a single color) or 2D tensor (one color per row).
        Integer tensors (uint8, int, long) are used as is. Floating
        point tensors must have all values <= 1 and are multiplied by
        255 then truncated to integers.
    alpha: optional opacity. Either a single int or float applied to
        every row, or a 1D torch.Tensor or numpy array holding one
        value per row.

    Returns a numpy array with one row per input color. Without alpha,
    rows are [R, G, B] integers. With alpha, rows are [R, G, B, A].

    Raises ValueError if rgb is neither an accepted integer dtype nor a
    floating point tensor with max <= 1.
    """
    assert isinstance(rgb, torch.Tensor)
    assert rgb.dim() <= 2
    if rgb.dim() == 1:
        rgb = rgb.unsqueeze(0)

    # Move to CPU before the numpy conversion, so that tensors living
    # on another device are supported too
    if rgb.dtype in [torch.uint8, torch.int, torch.long]:
        rgb = rgb.long().cpu().numpy()
    elif rgb.is_floating_point() and rgb.max() <= 1:
        rgb = (rgb * 255).long().cpu().numpy()
    else:
        raise ValueError(
            f'Not sure how to deal with RGB of dtype={rgb.dtype} and '
            f'max={rgb.max()}')

    if alpha is None:
        return np.array(list(rgb))

    if isinstance(alpha, (int, float)):
        alpha = np.array([alpha] * rgb.shape[0])
    elif isinstance(alpha, torch.Tensor):
        alpha = alpha.detach().cpu().numpy()
    assert isinstance(alpha, np.ndarray)
    assert alpha.ndim == 1
    assert alpha.shape[0] == rgb.shape[0]

    # Each output row is R, G, B followed by the row's alpha
    return np.array([
        [x[0], x[1], x[2], a] for x, a in zip(rgb, alpha)])


def int_to_plotly_rgb(x):
    """Convert 1D torch.Tensor of int into an array of RGB rows, one
    per element of x. This operation is deterministic on the int
    values: equal ints always map to the same color.

    x: 1D, non floating point torch.Tensor.

    Returns a numpy array with one row per element of x, each row
    holding the `ColorHash(value).rgb` triplet of the corresponding
    value.
    """
    assert isinstance(x, torch.Tensor)
    assert x.dim() == 1
    assert not x.is_floating_point()
    x = x.cpu().long().numpy()

    # Only hash the values actually present in x, rather than every
    # integer in [0, x.max()]. This keeps the cost proportional to the
    # number of distinct values, and supports empty inputs as well as
    # negative values
    values, inverse = np.unique(x, return_inverse=True)
    palette = np.array([ColorHash(int(v)).rgb for v in values])

    # The reshape keeps a well-formed (0, 3) palette for empty inputs
    return palette.reshape(-1, 3)[inverse]


def hex_to_tensor(h):
    """Convert a hexadecimal color string to a float RGB tensor with
    values in [0, 1].

    h: color string such as '#ff8000' or 'ff8000'. Leading '#' are
        optional. The 3-digit shorthand (e.g. '#f80') is expanded to
        its 6-digit form. Only the first 6 hexadecimal digits are read.

    Returns a float tensor of shape (3,).

    Raises ValueError if the string cannot be parsed as hexadecimal.
    """
    h = h.lstrip('#')
    if len(h) == 3:
        # Shorthand notation: each digit stands for a doubled digit
        h = ''.join(c * 2 for c in h)
    rgb = [int(h[i:i + 2], 16) for i in (0, 2, 4)]

    # Hex channels are always expressed in [0, 255], so always divide by
    # 255. Guessing the scale from the max value would leave dark
    # colors such as '#010101' unscaled
    return torch.tensor(rgb, dtype=torch.float) / 255


def feats_to_plotly_rgb(feats, normalize=False, colorscale='Agsunset'):
    """Convert features of the format M x N with N>=1 to M colors for
    visualization.

    feats: 1D tensor (treated as M x 1) or 2D M x N tensor.
    normalize: whether the colors should be min-max normalized. For
        N > 3, this flag is forwarded to `identity_PCA` instead.
    colorscale: name of the plotly colorscale used when N == 1. If
        None, the single feature is repeated on 3 channels instead.

    Returns a numpy array holding one color per row of feats. When
    N == 1 and a colorscale is provided, the colors are those sampled
    from the plotly colorscale. Otherwise, they are the output of
    `rgb_to_plotly_rgb`.

    Raises NotImplementedError if feats has more than 2 dimensions, and
    ValueError if feats has no feature column.
    """
    is_normalized = False
    is_plotly_rgb_string_format = False

    if feats.dim() == 1:
        feats = feats.unsqueeze(1)
    elif feats.dim() > 2:
        raise NotImplementedError

    num_feats = feats.shape[1]

    if num_feats == 0:
        raise ValueError(
            'Expected features with at least 1 column, got shape '
            f'{tuple(feats.shape)}')

    if num_feats == 3:
        color = feats

    elif num_feats == 1:
        # If only 1 feature is found convert to a 3-channel
        # repetition for grayscale visualization or to plotly RGB string
        # format if a colorscale was provided
        if colorscale is None:
            color = feats.repeat_interleave(3, 1)
        else:
            colorscale = get_colorscale(colorscale)
            # Only squeeze the feature dimension, so that a single-row
            # input still yields a 1D array of sample points
            feats = min_max_normalize(feats).squeeze(1).cpu().numpy()
            color = np.array(sample_colorscale(colorscale, feats))
            is_plotly_rgb_string_format = True

    elif num_feats == 2:
        # If 2 features are found, add an extra channel.
        ones = torch.ones(feats.shape[0], 1, device=feats.device)
        color = torch.cat([feats, ones], 1)

    else:
        # If more than 3 features are found, project features to
        # a 3-dimensional space using N-simplex PCA. Heuristics for
        # clamping:
        #   - most features live in [0, 1]
        #   - most n-simplex PCA features live in [-0.5, 0.6]
        color = identity_PCA(feats, dim=3, normalize=normalize)
        color = (torch.clamp(color, -0.5, 0.6) + 0.5) / 1.1
        is_normalized = True

    if normalize and not is_normalized and not is_plotly_rgb_string_format:
        color = min_max_normalize(color)

    # Convert to RGB-255 plotly-friendly numpy format
    if not is_plotly_rgb_string_format:
        color = rgb_to_plotly_rgb(color)

    return color


def min_max_normalize(x):
    """Min-max normalize a tensor of features to [0, 1] along its first
    dimension. Typically useful for visualizing float features with
    colors.

    The min and max are computed along dim 0. Entries for which the
    normalization produces NaN or inf (e.g. when min and max are equal)
    are set to 0.
    """
    lower = x.min(dim=0).values.float()
    span = x.max(dim=0).values.float() - lower
    ratio = (x.float() - lower) / span

    # Replace NaN and inf, which arise from a zero span, by 0
    return torch.where(
        torch.isfinite(ratio), ratio, torch.zeros_like(ratio))


def identity_PCA(x, dim=3, normalize=False):
    """Reduce dimension of x based on PCA on the union of the n-simplex.
    This is a way of reducing the dimension of x while treating all
    input dimensions with the same importance, independently of the
    input distribution in x.

    x: 2D tensor of shape M x N.
    dim: number of output dimensions. The last `dim` eigenvectors
        returned by `torch.linalg.eigh` are used for the projection.
    normalize: whether x should be min-max normalized along dim 0
        before the projection. NaN and inf values produced by the
        normalization are set to 0.

    Returns the projected tensor, on the same device as x. The output
    is float64 if x is float64, float32 otherwise.
    """
    assert x.dim() == 2, f"Expected x.dim()=2, got x.dim()={x.dim()} instead"

    # Create z the union of the N-simplex. The basis is built on the
    # device of x, and in float64 if x is float64, so that the final
    # matrix product does not fail on mismatching dtypes or devices
    input_dim = x.shape[1]
    work_dtype = x.dtype if x.dtype == torch.float64 else torch.float32
    z = torch.eye(input_dim, dtype=work_dtype, device=x.device)

    # PCA on z
    # NOTE(review): the covariance of the centered identity appears to
    # have a repeated non-zero eigenvalue, so the eigenvectors picked
    # within that eigenspace depend on the eigh backend - confirm this
    # is acceptable for the intended use
    z_offset = z.mean(dim=0)
    z_centered = z - z_offset
    cov_matrix = z_centered.T.mm(z_centered) / len(z_centered)
    _, eigenvectors = torch.linalg.eigh(cov_matrix)

    # Normalize x if need be. The division creates a new tensor, so the
    # in-place masking below does not modify the caller's input
    if normalize:
        high = x.max(dim=0).values
        low = x.min(dim=0).values
        x = (x - low) / (high - low)
        x[x.isnan() | x.isinf()] = 0

    # Apply the PCA on x
    x = x.to(work_dtype)
    x_reduced = (x - z_offset).mm(eigenvectors[:, -dim:])

    return x_reduced