File size: 5,484 Bytes
26225c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 | import torch
import numpy as np
from colorhash import ColorHash
from plotly.colors import sample_colorscale, get_colorscale
# Names exported by `from <module> import *`.
# NOTE(review): `min_max_normalize` is defined in this file but is not listed
# here - confirm whether keeping it out of the public API is intentional.
__all__ = [
    'to_float_rgb', 'to_byte_rgb', 'rgb_to_plotly_rgb', 'int_to_plotly_rgb',
    'hex_to_tensor', 'feats_to_plotly_rgb', 'identity_PCA']
def to_float_rgb(rgb):
    """Convert a color tensor to float RGB values in [0, 1].

    The scale of the input is guessed from its content: if any value is
    larger than 1, the colors are assumed to be expressed in [0, 255]
    and are divided by 255. The output is always clamped to [0, 1].

    :param rgb: torch.Tensor of colors, of any dtype
    :return: float torch.Tensor of the same shape, with values in [0, 1]
    """
    rgb = rgb.float()

    # `Tensor.max()` raises on empty tensors, so only probe the scale
    # when there is at least one value to look at
    if rgb.numel() > 0 and rgb.max() > 1:
        rgb = rgb / 255

    return rgb.clamp(min=0, max=1)
def to_byte_rgb(rgb):
    """Convert a color tensor to uint8 RGB values in [0, 255].

    Floating point inputs whose maximum does not exceed 1 are assumed to
    be expressed in [0, 1] and are multiplied by 255. Any other input is
    assumed to already be in [0, 255]. The output is always clamped to
    [0, 255] before being cast to uint8 (fractional parts are
    truncated by the cast).

    :param rgb: torch.Tensor of colors, of any dtype
    :return: torch.uint8 Tensor of the same shape
    """
    # `Tensor.max()` raises on empty tensors, so only probe the scale
    # when there is at least one value to look at
    if rgb.is_floating_point() and rgb.numel() > 0 and rgb.max() <= 1:
        rgb = rgb * 255

    return rgb.clamp(min=0, max=255).byte()
def rgb_to_plotly_rgb(rgb, alpha=None):
    """Convert a torch.Tensor of RGB colors to a plotly-friendly numpy
    array of integer RGB values in [0, 255]. If alpha is provided, the
    output will be expressed in RGBA format.

    :param rgb: torch.Tensor of shape (3,) or (N, 3). Either an integer
        tensor (uint8, int32 or int64) holding values assumed to be in
        [0, 255], or a floating point tensor holding values in [0, 1]
    :param alpha: optional opacity. Either a single int or float applied
        to all colors, or a 1D torch.Tensor or np.ndarray of N values
    :return: np.ndarray of shape (N, 3) if alpha is None, of shape
        (N, 4) otherwise. In the latter case, the dtype is promoted to
        the common type of the colors and alpha
    :raises ValueError: if the scale of `rgb` cannot be inferred from
        its dtype and maximum value
    """
    assert isinstance(rgb, torch.Tensor)
    assert rgb.dim() <= 2

    if rgb.dim() == 1:
        rgb = rgb.unsqueeze(0)

    # `.numpy()` only works on CPU tensors which do not require grad
    rgb = rgb.detach().cpu()

    if rgb.dtype in [torch.uint8, torch.int, torch.long]:
        rgb = rgb.long().numpy()
    elif rgb.is_floating_point() and rgb.max() <= 1:
        rgb = (rgb * 255).long().numpy()
    else:
        raise ValueError(
            f'Not sure how to deal with RGB of dtype={rgb.dtype} and '
            f'max={rgb.max()}')

    if alpha is None:
        # Copy, because `.numpy()` may share memory with the input tensor
        return rgb.copy()

    if isinstance(alpha, (int, float)):
        alpha = np.full(rgb.shape[0], alpha)
    elif isinstance(alpha, torch.Tensor):
        alpha = alpha.detach().cpu().numpy()

    assert isinstance(alpha, np.ndarray)
    assert alpha.ndim == 1
    assert alpha.shape[0] == rgb.shape[0]

    # Stack the R, G and B channels with alpha. Columns are addressed
    # explicitly so that colors with fewer than 3 channels still raise
    return np.column_stack((rgb[:, 0], rgb[:, 1], rgb[:, 2], alpha))
def int_to_plotly_rgb(x):
    """Convert 1D torch.Tensor of int into plotly-friendly RGB format.
    This operation is deterministic on the int values: each value is
    mapped to the color produced by `ColorHash` for that value.

    :param x: 1D integer torch.Tensor
    :return: np.ndarray of shape (N, 3) holding one RGB color per item
        of `x`
    """
    assert isinstance(x, torch.Tensor)
    assert x.dim() == 1
    assert not x.is_floating_point()

    x = x.cpu().long().numpy()

    # Nothing to hash for an empty input
    if x.size == 0:
        return np.empty((0, 3), dtype=int)

    # Only hash the values actually present in x. This avoids building a
    # palette for every int in [0, x.max()], and makes sure negative
    # values get their own color rather than indexing the palette from
    # its end
    values, inverse = np.unique(x, return_inverse=True)
    palette = np.array([ColorHash(int(v)).rgb for v in values])

    return palette[inverse]
def hex_to_tensor(h):
    """Convert a hexadecimal color string to a float RGB tensor with
    values in [0, 1].

    :param h: color string such as '#ff8800', 'ff8800' or the 3-digit
        shorthand '#f80'. Leading '#' characters are ignored
    :return: float torch.Tensor of shape (3,)
    :raises ValueError: if the red, green or blue component cannot be
        parsed as hexadecimal
    """
    h = h.lstrip('#')

    # Expand the 3-digit shorthand by doubling each digit: 'f80' -> 'ff8800'
    if len(h) == 3:
        h = ''.join(c * 2 for c in h)

    rgb = [int(h[i:i + 2], 16) for i in (0, 2, 4)]

    # The parsed channels are known to be bytes in [0, 255], so scale
    # them explicitly. Guessing the scale from the values would treat
    # dark colors such as '#010101' as already being in [0, 1]
    return torch.tensor(rgb, dtype=torch.float) / 255
def feats_to_plotly_rgb(feats, normalize=False, colorscale='Agsunset'):
    """Convert features of the format M x N with N>=1 to plotly-friendly
    colors for visualization.

    :param feats: torch.Tensor of shape (M,) or (M, N) with N>=1
    :param normalize: whether the features should be min-max normalized
        before being converted to colors
    :param colorscale: name of the plotly colorscale used when N=1. If
        None, single-channel features are rendered as grayscale instead
    :return: np.ndarray holding M colors. When N=1 and a colorscale is
        provided, these are the colors sampled from the plotly
        colorscale (plotly RGB string format). Otherwise, this is the
        output of `rgb_to_plotly_rgb` for the M x 3 colors
    :raises NotImplementedError: if `feats` has more than 2 dimensions
    :raises ValueError: if `feats` has no feature channel
    """
    if feats.dim() == 1:
        feats = feats.unsqueeze(1)
    elif feats.dim() > 2:
        raise NotImplementedError

    num_channels = feats.shape[1]
    if num_channels == 0:
        raise ValueError(
            f'Expected features with at least 1 channel, got '
            f'shape={tuple(feats.shape)}')

    # If only 1 feature is found and a colorscale was provided, sample
    # the colorscale and return the colors as is, in the format produced
    # by plotly
    if num_channels == 1 and colorscale is not None:
        colorscale = get_colorscale(colorscale)
        # Only squeeze the channel dimension, so that a single-row input
        # still yields a 1D array of sample points
        samples = min_max_normalize(feats).squeeze(1).detach().cpu().numpy()
        return np.array(sample_colorscale(colorscale, samples))

    is_normalized = False
    if num_channels == 1:
        # Convert to a 3-channel repetition for grayscale visualization
        color = feats.repeat_interleave(3, 1)
    elif num_channels == 2:
        # If 2 features are found, add an extra channel. It is created
        # on the same device as the features so they can be concatenated
        ones = torch.ones(feats.shape[0], 1, device=feats.device)
        color = torch.cat([feats, ones], 1)
    elif num_channels == 3:
        color = feats
    else:
        # If more than 3 features are found, project features to a
        # 3-dimensional space using N-simplex PCA. Heuristics for
        # clamping:
        #   - most features live in [0, 1]
        #   - most n-simplex PCA features live in [-0.5, 0.6]
        color = identity_PCA(feats, dim=3, normalize=normalize)
        color = (torch.clamp(color, -0.5, 0.6) + 0.5) / 1.1
        is_normalized = True

    if normalize and not is_normalized:
        color = min_max_normalize(color)

    # Convert to RGB-255 plotly-friendly numpy format
    return rgb_to_plotly_rgb(color.detach().cpu())
def min_max_normalize(x):
    """Rescale a tensor of features to [0, 1] using the minimum and
    maximum found along the first dimension. Typically useful for
    visualizing float features with colors.

    The output is a float tensor of the same shape as the input. Values
    which cannot be normalized (e.g. a constant column yields 0 / 0) are
    set to 0.
    """
    x = x.float()

    # Extrema along the first dimension
    lowest = x.min(dim=0).values
    span = x.max(dim=0).values - lowest

    scaled = (x - lowest) / span

    # Replace NaN and infinite values by 0
    return torch.where(
        torch.isfinite(scaled), scaled, torch.zeros_like(scaled))
def identity_PCA(x, dim=3, normalize=False):
    """Reduce dimension of x based on PCA on the union of the n-simplex.
    This is a way of reducing the dimension of x while treating all
    input dimensions with the same importance, independently of the
    input distribution in x.

    :param x: torch.Tensor of shape (M, N)
    :param dim: number of output dimensions. The eigenvectors returned
        last by `torch.linalg.eigh` are used for the projection
    :param normalize: whether each column of x should be min-max
        normalized before being projected
    :return: torch.Tensor of shape (M, min(dim, N)), on the same device
        as x
    """
    assert x.dim() == 2, f"Expected x.dim()=2, got x.dim()={x.dim()} instead"

    # Work in a floating point dtype at least as wide as float32, so
    # that integer, half and double inputs can all be projected
    dtype = torch.promote_types(x.dtype, torch.float)
    x = x.to(dtype)

    # Create z the union of the N-simplex
    input_dim = x.shape[1]
    z = torch.eye(input_dim)

    # PCA on z. This is computed on CPU in float32 regardless of x, so
    # the projection basis does not depend on the device or dtype of x
    z_offset = z.mean(dim=0)
    z_centered = z - z_offset
    cov_matrix = z_centered.T.mm(z_centered) / len(z_centered)
    _, eigenvectors = torch.linalg.eigh(cov_matrix)

    # `mm` requires both operands to share the same device and dtype
    z_offset = z_offset.to(device=x.device, dtype=dtype)
    basis = eigenvectors[:, -dim:].to(device=x.device, dtype=dtype)

    # Normalize x if need be. Skipped when x has no rows, since the
    # extrema of an empty tensor are undefined
    if normalize and x.shape[0] > 0:
        high = x.max(dim=0).values
        low = x.min(dim=0).values
        x = (x - low) / (high - low)
        # Constant columns yield 0 / 0, set those to 0
        x[x.isnan() | x.isinf()] = 0

    # Apply the PCA on x
    return (x - z_offset).mm(basis)
|