import colorsys
import itertools
import logging
import random
import sys
import time
from pathlib import Path
from timeit import default_timer
from typing import Optional, Union

import dask.array as da
import matplotlib
import numpy as np
import torch

logger = logging.getLogger(__name__)


def _single_color_integer_cmap(color=(0.3, 0.4, 0.5)):
    """Build a colormap that paints every positive label with one fixed color.

    Parameters
    ----------
    color: tuple
        RGB or RGBA color with components in [0, 1]. If only RGB is given,
        an alpha of 1 is appended.

    Returns
    -------
    matplotlib.colors.Colormap
        Maps values > 0 to ``color``; all other values map to black with the
        alpha of ``color``.
    """
    from matplotlib.colors import Colormap

    assert len(color) in (3, 4)

    class BinaryMap(Colormap):
        def __init__(self, color):
            # Initialize the base class so the standard Colormap attributes
            # (name, N, ...) exist; the actual lookup is done in __call__.
            super().__init__("single_color", N=2)
            self.color = np.array(color)
            if len(self.color) == 3:
                self.color = np.concatenate([self.color, [1]])

        def __call__(self, X, alpha=None, bytes=False):
            res = np.zeros((*X.shape, 4), np.float32)
            # background keeps the alpha of the requested color
            res[..., -1] = self.color[-1]
            res[X > 0] = np.expand_dims(self.color, 0)
            if bytes:
                return np.clip(256 * res, 0, 255).astype(np.uint8)
            return res

    return BinaryMap(color)


def _resolve_cmap(cmap):
    """Return ``cmap`` as a colormap object, looking up names in matplotlib.

    Non-string inputs (Colormap instances or other callables) are returned
    unchanged. Names are resolved through ``matplotlib.colormaps``
    (matplotlib >= 3.5); older versions fall back to ``matplotlib.cm.get_cmap``,
    which was removed in matplotlib 3.9.

    Raises
    ------
    ValueError
        If ``cmap`` is a string that is not a registered colormap name.
    """
    if not isinstance(cmap, str):
        return cmap
    registry = getattr(matplotlib, "colormaps", None)
    if registry is None:
        from matplotlib import cm

        return cm.get_cmap(cmap)
    try:
        return registry[cmap]
    except KeyError as err:
        raise ValueError(f"{cmap!r} is not a valid colormap name") from err


def render_label(
    lbl,
    img=None,
    cmap=None,
    cmap_img="gray",
    alpha=0.5,
    alpha_boundary=None,
    normalize_img=True,
):
    """Render a label image, optionally overlaid on another image.

    Used for generating simple output images to assess the label quality.

    Parameters
    ----------
    lbl: np.ndarray of dtype np.uint16
        The 2D label image.
    img: np.ndarray
        The array to overlay the label image with (optional).
    cmap: string, tuple, or callable
        The label colormap. If given as rgb(a), only a single color is used.
        If None, a random colormap is used.
    cmap_img: string or callable
        The colormap of img (optional).
    alpha: float
        The alpha value of the overlay. Set alpha=1 to get fully opaque labels.
        Clipped to [0, 1].
    alpha_boundary: float
        The alpha value of the boundary. If None, the same value as for labels
        is used, i.e. no boundaries are visible. Clipped to [0, 1].
    normalize_img: bool
        If True, normalizes the img (if given).

    Returns
    -------
    img: np.ndarray
        The (m,n,4) RGBA image of the rendered label.

    Raises
    ------
    ValueError
        If img is neither 2 nor 3 dimensional.

    Example
    -------
    from scipy.ndimage import label, zoom
    img = zoom(np.random.uniform(0,1,(16,16)),(8,8),order=3)
    lbl,_ = label(img>.8)
    u1 = render_label(lbl, img = img, alpha = .7)
    u2 = render_label(lbl, img = img, alpha = 0, alpha_boundary =.8)
    plt.subplot(1,2,1);plt.imshow(u1)
    plt.subplot(1,2,2);plt.imshow(u2)
    """
    from skimage.segmentation import find_boundaries

    alpha = np.clip(alpha, 0, 1)
    if alpha_boundary is None:
        alpha_boundary = alpha
    else:
        alpha_boundary = np.clip(alpha_boundary, 0, 1)

    if cmap is None:
        cmap = random_label_cmap()
    elif isinstance(cmap, tuple):
        cmap = _single_color_integer_cmap(cmap)
    cmap = _resolve_cmap(cmap)
    cmap_img = _resolve_cmap(cmap_img)

    # render image if given
    if img is None:
        im_img = np.zeros((*lbl.shape, 4), np.float32)
        im_img[..., -1] = 1
    else:
        assert lbl.shape[:2] == img.shape[:2]
        img = normalize(img) if normalize_img else img
        if img.ndim == 2:
            im_img = cmap_img(img)
        elif img.ndim == 3:
            im_img = img[..., :4]
            if img.shape[-1] < 4:
                # pad the missing channels with ones (e.g. RGB -> RGBA, alpha 1)
                im_img = np.concatenate(
                    [img, np.ones(img.shape[:2] + (4 - img.shape[-1],))], axis=-1
                )
        else:
            raise ValueError("img should be 2 or 3 dimensional")

    # render label
    im_lbl = cmap(lbl)
    mask_lbl = lbl > 0
    mask_bound = np.bitwise_and(mask_lbl, find_boundaries(lbl, mode="thick"))

    # blend labels (and then their boundaries) over the image
    im = im_img.copy()
    im[mask_lbl] = alpha * im_lbl[mask_lbl] + (1 - alpha) * im_img[mask_lbl]
    im[mask_bound] = (
        alpha_boundary * im_lbl[mask_bound]
        + (1 - alpha_boundary) * im_img[mask_bound]
    )
    return im


def random_label_cmap(n=2**16, h=(0, 1), lightness=(0.4, 1), s=(0.2, 0.8)):
    """Create a ListedColormap of ``n`` random HLS colors with black at index 0.

    Colors are drawn from the global ``np.random`` state (see ``seed``).

    Parameters
    ----------
    n: int
        Number of colors.
    h, lightness, s: tuple
        (low, high) ranges of hue, lightness and saturation to sample from.

    Returns
    -------
    matplotlib.colors.ListedColormap
    """
    h, lightness, s = (
        np.random.uniform(*h, n),
        np.random.uniform(*lightness, n),
        np.random.uniform(*s, n),
    )
    cols = np.stack(
        [colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(h, lightness, s)],
        axis=0,
    )
    # index 0 is the background label -> black
    cols[0] = 0
    return matplotlib.colors.ListedColormap(cols)


def _blockwise_sum_with_bounds(A: torch.Tensor, bounds: torch.Tensor, dim: int = 0):
    """Sum ``A`` along ``dim`` within contiguous blocks delimited by ``bounds``.

    Every slice in the block ``[bounds[k], bounds[k + 1])`` is replaced by the
    sum over that block, so the result has the same shape as ``A``.

    Args:
        A (torch.Tensor): input tensor.
        bounds (torch.Tensor): increasing 1D integer tensor of block boundaries,
            starting with 0 and ending with ``A.shape[dim]``.
        dim (int): dimension to sum over.
    """
    A = A.transpose(dim, 0)
    cum = torch.cumsum(A, dim=0)
    # Prepend a zero slice so that cum[j] - cum[i] is the sum of slices i..j-1.
    cum = torch.cat((torch.zeros_like(cum[:1]), cum), dim=0)
    B = torch.zeros_like(A)
    # consecutive (start, end) pairs of block boundaries
    for i, j in zip(bounds[:-1].tolist(), bounds[1:].tolist()):
        B[i:j] = cum[j] - cum[i]
    return B.transpose(0, dim)


def _bounds_from_timepoints(timepoints: torch.Tensor):
    """Return the boundaries of runs of equal consecutive values in ``timepoints``.

    The result starts with 0, contains every index where the value changes,
    and ends with ``len(timepoints)``.
    """
    assert timepoints.ndim == 1
    bounds = torch.cat((
        torch.tensor([0], device=timepoints.device),
        # torch.nonzero faster than torch.where
        torch.nonzero(timepoints[1:] - timepoints[:-1], as_tuple=False)[:, 0] + 1,
        torch.tensor([len(timepoints)], device=timepoints.device),
    ))
    return bounds


def blockwise_sum(
    A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum"
):
    """Reduce ``A`` along ``dim`` within groups of equal timepoints.

    Every slice of ``A`` along ``dim`` is replaced by the reduction over all
    slices that share its timepoint, so the result has the same shape as ``A``.
    Negative timepoints (-1 marks padded/invalid entries) are grouped together
    into one extra block.

    ``reduce`` is forwarded to ``torch.scatter_reduce`` into a zero-initialized
    buffer with the default ``include_self=True``, so 0 always takes part in the
    reduction: ``"sum"`` is unaffected and ``"amax"`` yields ``max(0, block max)``
    (which keeps blocks that are entirely ``-inf`` finite). Other reductions such
    as ``"prod"`` or ``"mean"`` are affected by this zero accordingly.

    Args:
        A (torch.Tensor): input tensor of any rank.
        timepoints (torch.Tensor): 1D tensor with one (integer) timepoint per
            slice of ``A`` along ``dim``.
        dim (int): dimension to reduce over.
        reduce (str): reduction passed to ``torch.scatter_reduce``.

    Returns:
        torch.Tensor: tensor of the same shape as ``A``. If ``timepoints`` is
        empty or contains only negative values, a zero tensor is returned.

    Raises:
        ValueError: if ``A.shape[dim]`` differs from ``len(timepoints)``.
    """
    if A.shape[dim] != len(timepoints):
        raise ValueError(
            f"Dimension {dim} of A ({A.shape[dim]}) must match length of timepoints"
            f" ({len(timepoints)})"
        )

    if len(timepoints) == 0:
        logger.warning("Empty timepoints in block_sum. Returning zero tensor.")
        return torch.zeros_like(A)

    # -1 is the filling value for padded/invalid timepoints
    valid_t = timepoints[timepoints >= 0]
    if len(valid_t) == 0:
        logger.warning("All timepoints are -1 in block_sum. Returning zero tensor.")
        return torch.zeros_like(A)

    A = A.transpose(dim, 0)
    # Valid timepoints are mapped to blocks 1, 2, ...; padding goes to block 0.
    # scatter_reduce requires an int64 index.
    ts = torch.clamp(timepoints - valid_t.min() + 1, min=0).long()
    # The index must have the same shape as A: broadcast each slice's block id
    # over all remaining dimensions.
    index = ts.view(-1, *([1] * (A.ndim - 1))).expand_as(A)
    n_blocks = int(ts.max()) + 1
    out = torch.zeros((n_blocks, *A.shape[1:]), device=A.device, dtype=A.dtype)
    out = torch.scatter_reduce(out, 0, index, A, reduce=reduce)
    # gather each slice's block result back to its original position
    return out[ts].transpose(0, dim)


# TODO allow for batch dimension. Should be faster than looping
def blockwise_causal_norm(
    A: torch.Tensor,
    timepoints: torch.Tensor,
    mode: str = "quiet_softmax",
    mask_invalid: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
):
    """Normalization over the causal dimension of A.

    For each block of constant timepoints, normalize the corresponding block of
    A such that the sum over the causal dimension is 1.

    Args:
        A (torch.Tensor): square 2D input tensor.
        timepoints (torch.Tensor): timepoints for each element in the causal
            dimension.
        mode: normalization mode.
            `linear`: Apply sigmoid to A, then normalize linearly.
            `softmax`: Apply exp to A before normalization.
            `quiet_softmax`: Apply exp to A before normalization, and add 1 to
                the denominator of each row/column.
        mask_invalid: boolean mask (same shape as A) of values that should not
            influence the normalization.
        eps (float, optional): epsilon for numerical stability.

    Returns:
        torch.Tensor: normalized tensor of the same shape as A, clamped to [0, 1].

    Raises:
        NotImplementedError: for an unknown ``mode``.
    """
    assert A.ndim == 2 and A.shape[0] == A.shape[1]
    A = A.clone()

    if mode in ("softmax", "quiet_softmax"):
        # Subtract the blockwise max for numerical stability
        # https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning
        # TODO test without this subtraction
        if mask_invalid is not None:
            assert mask_invalid.shape == A.shape
            A[mask_invalid] = -torch.inf
        # TODO set to min, then to 0 after exp

        # Blockwise max (clamped below at 0 by blockwise_sum's zero buffer,
        # so fully masked blocks give exp(-inf - 0) = 0 instead of NaN).
        with torch.no_grad():
            ma0 = blockwise_sum(A, timepoints, dim=0, reduce="amax")
            ma1 = blockwise_sum(A, timepoints, dim=1, reduce="amax")

        u0 = torch.exp(A - ma0)
        u1 = torch.exp(A - ma1)
    elif mode == "linear":
        A = torch.sigmoid(A)
        if mask_invalid is not None:
            assert mask_invalid.shape == A.shape
            A[mask_invalid] = 0

        u0, u1 = A, A
        ma0 = ma1 = 0
    else:
        raise NotImplementedError(f"Mode {mode} not implemented")

    # normalize within blocks of equal timepoints
    u0_sum = blockwise_sum(u0, timepoints, dim=0) + eps
    u1_sum = blockwise_sum(u1, timepoints, dim=1) + eps

    if mode == "quiet_softmax":
        # Add 1 to the denominator of the softmax. With this, the softmax
        # outputs can be all 0 if the logits are all negative. If the logits
        # are positive, the softmax outputs will sum to 1.
        # Trick: with maximum subtraction, adding exp(-max) is equivalent to
        # adding 1 to the denominator of the unshifted softmax.
        u0_sum += torch.exp(-ma0)
        u1_sum += torch.exp(-ma1)

    # mask0[i, j] is True where timepoints[j] > timepoints[i]; those entries use
    # the dim=0 normalization. All others (including equal timepoints, i.e. the
    # blockwise diagonal) use the dim=1 normalization. Per the original note,
    # entries with equal timepoints are masked out in the final loss.
    mask0 = timepoints.unsqueeze(0) > timepoints.unsqueeze(1)
    mask1 = ~mask0

    res = mask0 * u0 / u0_sum + mask1 * u1 / u1_sum
    return torch.clamp(res, 0, 1)


def normalize_tensor(
    x: torch.Tensor, dim: Optional[Union[int, tuple]] = None, eps: float = 1e-8
):
    """Min-max normalize ``x`` to [0, 1] over ``dim`` (all dimensions if None)."""
    if dim is None:
        dim = tuple(range(x.ndim))

    mi, ma = torch.amin(x, dim=dim, keepdim=True), torch.amax(x, dim=dim, keepdim=True)
    return (x - mi) / (ma - mi + eps)


def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
    """Percentile normalize the image (1st percentile -> 0, 99.8th -> 1).

    If subsample is not None, calculate the percentile values over a subsampled
    image (last two axes), which is way faster for large images. Subsampling is
    only applied if both of the last two axes are longer than 64 * subsample.

    Returns a float32 copy; the input is not modified.
    """
    x = x.astype(np.float32)
    if subsample is not None and all(s > 64 * subsample for s in x.shape[-2:]):
        y = x[..., ::subsample, ::subsample]
    else:
        y = x

    mi, ma = np.percentile(y, (1, 99.8)).astype(np.float32)
    x -= mi
    x /= ma - mi + 1e-8
    return x


def normalize_01(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
    """Min-max normalize the image to [0, 1] using its global min and max.

    ``subsample`` is accepted for signature compatibility with ``normalize``
    but is not used: min and max are always computed over the full image so
    the output is guaranteed to lie in [0, 1].

    Returns a float32 copy; the input is not modified.
    """
    x = x.astype(np.float32)
    mi = x.min()
    ma = x.max()
    x -= mi
    x /= ma - mi + 1e-8
    return x


def batched(x, batch_size, device):
    """Repeat ``x`` along a new leading batch dimension and move it to ``device``.

    Uses ``expand``, so the batch entries share memory until ``.to`` copies them.
    """
    return x.unsqueeze(0).expand(batch_size, *((-1,) * x.ndim)).to(device)


def preallocate_memory(dataset, model_lightning, batch_size, max_tokens, device):
    """Run one forward/backward pass on a maximal dummy batch to preallocate memory.

    https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#preallocate-memory-in-case-of-variable-input-length

    Args:
        dataset: indexable dataset; ``dataset[0]`` must be a dict with the
            tensors "features", "coords", "assoc_matrix" and "timepoints"
            (only their trailing shapes and dtypes are used).
        model_lightning: model whose ``_common_step(batch)`` returns a dict
            with a "loss" tensor.
        batch_size: number of samples in the dummy batch.
        max_tokens: sequence length of the dummy batch. If None, nothing is
            preallocated.
        device: target device (``torch.device`` or anything it accepts).
    """
    start = default_timer()

    if max_tokens is None:
        # Finding the largest training sample automatically is not implemented.
        logger.warning(
            "Preallocating memory without specifying max_tokens not implemented."
        )
        return

    max_len = max_tokens
    x = dataset[0]
    # all-zero dummy batch of the maximal length, with the dataset's dtypes
    batch = dict(
        features=batched(
            torch.zeros(
                (max_len,) + x["features"].shape[1:], dtype=x["features"].dtype
            ),
            batch_size,
            device,
        ),
        coords=batched(
            torch.zeros((max_len,) + x["coords"].shape[1:], dtype=x["coords"].dtype),
            batch_size,
            device,
        ),
        assoc_matrix=batched(
            torch.zeros((max_len, max_len), dtype=x["assoc_matrix"].dtype),
            batch_size,
            device,
        ),
        timepoints=batched(
            torch.zeros(max_len, dtype=x["timepoints"].dtype), batch_size, device
        ),
        padding_mask=batched(torch.zeros(max_len, dtype=bool), batch_size, device),
    )

    loss = model_lightning._common_step(batch)["loss"]
    loss.backward()
    model_lightning.zero_grad()

    logger.info(
        f"Preallocated memory for largest training batch (length {max_len}) in"
        f" {default_timer() - start:.02f} s"
    )

    # torch.device() also accepts strings such as "cuda:0"
    if torch.device(device).type == "cuda":
        logger.info(
            "Memory allocated for model:"
            f" {torch.cuda.max_memory_allocated() / 1024**3:.02f} GB"
        )


def seed(s=None):
    """Seed random number generators.

    Seeds Python's ``random``, ``numpy`` and (if imported) ``torch``.
    Defaults to the unix timestamp (in whole seconds) of the function call.

    Args:
        s (``int``): Manual seed.

    Returns:
        int: the seed that was used.
    """
    if s is None:
        # time.time() is the unix timestamp; timeit.default_timer is a
        # performance counter with an undefined reference point.
        s = int(time.time())

    random.seed(s)
    logger.debug(f"Seed `random` rng with {s}.")
    np.random.seed(s)
    logger.debug(f"Seed `numpy` rng with {s}.")
    if "torch" in sys.modules:
        torch.manual_seed(s)
        logger.debug(f"Seed `torch` rng with {s}.")

    return s


def str2bool(x: str) -> bool:
    """Cast string to boolean (case-insensitive "true"/"t"/"1" or "false"/"f"/"0").

    Useful for parsing command line arguments.

    Raises:
        TypeError: if ``x`` is not a string.
        ValueError: if ``x`` is not one of the accepted spellings.
    """
    if not isinstance(x, str):
        raise TypeError("String expected.")
    elif x.lower() in ("true", "t", "1"):
        return True
    elif x.lower() in ("false", "f", "0"):
        return False
    else:
        raise ValueError(f"'{x}' does not seem to be boolean.")


def str2path(x: str) -> Path:
    """Cast string to resolved absolute path (``~`` is expanded).

    Useful for parsing command line arguments.

    Raises:
        TypeError: if ``x`` is not a string.
    """
    if not isinstance(x, str):
        raise TypeError("String expected.")
    else:
        return Path(x).expanduser().resolve()


if __name__ == "__main__":
    A = torch.rand(50, 50)
    idx = torch.tensor([0, 10, 20, A.shape[0]])

    A = torch.eye(50)

    B = _blockwise_sum_with_bounds(A, idx)

    tps = torch.repeat_interleave(torch.arange(5), 10)

    C = blockwise_causal_norm(A, tps)