| """Regionprops features and its augmentations. | |
| WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region. | |
| """ | |
| import itertools | |
| import logging | |
| from collections import OrderedDict | |
| from collections.abc import Iterable #, Sequence | |
| from functools import reduce | |
| from typing import Literal | |
| import joblib | |
| import numpy as np | |
| import pandas as pd | |
| from edt import edt | |
| from skimage.measure import regionprops, regionprops_table | |
| from tqdm import tqdm | |
| from typing import Tuple, Optional, Sequence, Union, List | |
| import typing | |
| try: | |
| from .utils import load_tiff_timeseries | |
| except: | |
| from utils import load_tiff_timeseries | |
| import torch | |
| logger = logging.getLogger(__name__) | |

_PROPERTIES = {
    "regionprops": (
        "area",
        "intensity_mean",
        "intensity_max",
        "intensity_min",
        "inertia_tensor",
    ),
    "regionprops2": (
        "equivalent_diameter_area",
        "intensity_mean",
        "inertia_tensor",
        "border_dist",
    ),
}

def _filter_points(
    points: np.ndarray,
    shape: Tuple[int, ...],
    origin: Optional[Tuple[int, ...]] = None,
) -> np.ndarray:
    """Return the indices of points that lie inside the extent `shape` anchored at `origin`."""
    ndim = points.shape[-1]
    if origin is None:
        origin = (0,) * ndim
    inside = tuple(
        np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
        for i in range(ndim)
    )
    idx = np.where(np.all(inside, axis=0))[0]
    return idx
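
# Example (sketch): keep only the points that fall inside a (64, 64) window
# whose origin (top-left corner) is at (10, 10).
#   pts = np.array([[12.0, 12.0], [100.0, 5.0]])
#   _filter_points(pts, shape=(64, 64), origin=(10, 10))  # -> array([0])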

def _border_dist(mask: np.ndarray, cutoff: float = 5):
    """Returns distance to border normalized to 0 (at least cutoff away) and 1 (at border)."""
    border = np.zeros_like(mask)
    # only apply to the last two dimensions
    ss = tuple(
        slice(None) if i < mask.ndim - 2 else slice(1, -1)
        for i in range(mask.ndim)
    )
    border[ss] = 1
    dist = 1 - np.minimum(edt(border) / cutoff, 1)
    return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
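
# _border_dist_fast below computes the same normalized border distance without an
# explicit distance transform, by taking the per-axis minimum of linear ramps to
# the image edges. One difference: it fades towards the borders of all axes,
# whereas _border_dist only considers the last two (y, x) dimensions.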
def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
    """Returns the per-label normalized border distance without computing an EDT."""
    cutoff = int(cutoff)
    border = np.ones(mask.shape, dtype=np.float32)
    ndim = len(mask.shape)
    for axis, size in enumerate(mask.shape):
        # Create fade values for the band [0, cutoff)
        band_vals = np.arange(cutoff, dtype=np.float32) / cutoff
        # Build slices for the low border
        low_slices = [slice(None)] * ndim
        low_slices[axis] = slice(0, cutoff)
        border_low = border[tuple(low_slices)]
        border_low_vals = np.minimum(
            border_low, band_vals[(...,) + (None,) * (ndim - axis - 1)]
        )
        border[tuple(low_slices)] = border_low_vals
        # Build slices for the high border
        high_slices = [slice(None)] * ndim
        high_slices[axis] = slice(size - cutoff, size)
        band_vals_rev = band_vals[::-1]
        border_high = border[tuple(high_slices)]
        border_high_vals = np.minimum(
            border_high, band_vals_rev[(...,) + (None,) * (ndim - axis - 1)]
        )
        border[tuple(high_slices)] = border_high_vals
    dist = 1 - border
    return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
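
# Example (sketch): a label touching the image border gets border_dist 1.0,
# a label at least `cutoff` pixels away from every border gets 0.0.
#   m = np.zeros((32, 32), np.uint8)
#   m[0:4, 0:4] = 1
#   m[15:20, 15:20] = 2
#   _border_dist_fast(m)  # -> (1.0, 0.0)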

class WRFeatures:
    """Regionprops features for a windowed track region."""

    def __init__(
        self,
        coords: np.ndarray,
        labels: np.ndarray,
        timepoints: np.ndarray,
        features: typing.OrderedDict[str, np.ndarray],
    ):
        self.ndim = coords.shape[-1]
        if self.ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        self.coords = coords
        self.labels = labels
        self.features = features.copy()
        self.timepoints = timepoints

    def __repr__(self):
        s = (
            f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
            f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
        )
        for k, v in self.features.items():
            s += f"{k:>20} -> {v.shape}\n"
        return s

    def features_stacked(self):
        return np.concatenate([v for k, v in self.features.items()], axis=-1)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, key):
        if key in self.features:
            return self.features[key]
        else:
            raise KeyError(f"Key {key} not found in features")

    @classmethod
    def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures":
        """Concatenate multiple WRFeatures into a single one."""
        if len(feats) == 0:
            raise ValueError("Cannot concatenate empty list of features")
        return reduce(lambda x, y: x + y, feats)

    def __add__(self, other: "WRFeatures") -> "WRFeatures":
        """Concatenate two WRFeatures."""
        if self.ndim != other.ndim:
            raise ValueError("Cannot concatenate features of different dimensions")
        if self.features.keys() != other.features.keys():
            raise ValueError("Cannot concatenate features with different properties")
        coords = np.concatenate([self.coords, other.coords], axis=0)
        labels = np.concatenate([self.labels, other.labels], axis=0)
        timepoints = np.concatenate([self.timepoints, other.timepoints], axis=0)
        features = OrderedDict(
            (k, np.concatenate([v, other.features[k]], axis=0))
            for k, v in self.features.items()
        )
        return WRFeatures(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )

    @classmethod
    def from_mask_img(
        cls,
        mask: np.ndarray,
        img: np.ndarray,
        properties: str = "regionprops2",
        t_start: int = 0,
    ) -> "WRFeatures":
        """Build a WRFeatures object from a (T, [Z,] Y, X) label mask and intensity image."""
        img = np.asarray(img)
        mask = np.asarray(mask)
        _ntime, ndim = mask.shape[0], mask.ndim - 1
        if ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        properties = tuple(_PROPERTIES[properties])
        if "label" in properties or "centroid" in properties:
            raise ValueError(
                f"label and centroid should not be in properties {properties}"
            )
        if "border_dist" in properties:
            use_border_dist = True
            # remove border_dist from properties, as it is computed separately
            properties = tuple(p for p in properties if p != "border_dist")
        else:
            use_border_dist = False
        df_properties = ("label", "centroid", *properties)
        dfs = []
        for i, (y, x) in enumerate(zip(mask, img)):
            _df = pd.DataFrame(
                regionprops_table(y, intensity_image=x, properties=df_properties)
            )
            _df["timepoint"] = i + t_start
            if use_border_dist:
                _df["border_dist"] = _border_dist_fast(y)
            dfs.append(_df)
        df = pd.concat(dfs)
        if use_border_dist:
            properties = (*properties, "border_dist")
        timepoints = df["timepoint"].values.astype(np.int32)
        labels = df["label"].values.astype(np.int32)
        coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32)
        features = OrderedDict(
            (
                p,
                np.stack(
                    [
                        df[c].values.astype(np.float32)
                        for c in df.columns
                        if c.startswith(p)
                    ],
                    axis=-1,
                ),
            )
            for p in properties
        )
        return cls(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )
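
# Example (sketch): build WRFeatures from a small synthetic (T, Y, X) movie
# containing a single labelled square.
#   y = np.zeros((1, 100, 100), np.uint8)
#   y[:, 20:40, 20:60] = 1
#   x = y + np.random.normal(0, 0.1, y.shape)
#   feats = WRFeatures.from_mask_img(mask=y, img=x)  # default "regionprops2" properties
#   print(feats)  # one region with its per-property feature shapes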

# augmentations


class WRRandomCrop:
    """Windowed region random crop augmentation."""

    def __init__(
        self,
        crop_size: Optional[Union[int, Tuple[int, ...]]] = None,
        ndim: int = 2,
    ) -> None:
        """crop_size: int or tuple of int.

        Can be a single int or a tuple of length 1 (same for all dimensions),
        of length ndim (y, x, ...), or of length 2*ndim (y1, y2, x1, x2, ...),
        where (y1, y2) etc. are the bounds the crop size is sampled from.
        """
        if isinstance(crop_size, int):
            crop_size = (crop_size,) * 2 * ndim
        elif isinstance(crop_size, Iterable):
            pass
        else:
            raise ValueError(f"{crop_size} has to be int or tuple of int")
        if len(crop_size) == 1:
            crop_size = (crop_size[0],) * 2 * ndim
        elif len(crop_size) == ndim:
            crop_size = tuple(itertools.chain(*tuple((c, c) for c in crop_size)))
        elif len(crop_size) == 2 * ndim:
            pass
        else:
            raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
        crop_size = np.array(crop_size)
        self._ndim = ndim
        self._crop_bounds = crop_size[::2], crop_size[1::2]
        self._rng = np.random.RandomState()
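
    # Example (sketch) of the accepted crop_size forms for ndim=2:
    #   WRRandomCrop(crop_size=64)                # fixed 64 x 64 crop
    #   WRRandomCrop(crop_size=(48, 96))          # fixed 48 (y) x 96 (x) crop
    #   WRRandomCrop(crop_size=(48, 64, 80, 96))  # y sampled in [48, 64], x in [80, 96]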

    def __call__(self, features: WRFeatures):
        crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
        points = features.coords
        if len(points) == 0:
            logger.warning("No points given, cannot ensure inside points")
            return features
        # sample an anchor point and a crop corner relative to it
        _idx = self._rng.randint(len(points))
        corner = (
            points[_idx]
            - crop_size
            + 1
            + self._rng.randint(crop_size // 4, 3 * crop_size // 4)
        )
        idx = _filter_points(points, shape=crop_size, origin=corner)
        return (
            WRFeatures(
                coords=points[idx],
                labels=features.labels[idx],
                timepoints=features.timepoints[idx],
                features=OrderedDict((k, v[idx]) for k, v in features.features.items()),
            ),
            idx,
        )

class WRBaseAugmentation:
    def __init__(self, p: float = 0.5) -> None:
        self._p = p
        self._rng = np.random.RandomState()

    def __call__(self, features: WRFeatures):
        if self._rng.rand() > self._p or len(features) == 0:
            return features
        return self._augment(features)

    def _augment(self, features: WRFeatures):
        raise NotImplementedError()

class WRRandomFlip(WRBaseAugmentation):
    def _augment(self, features: WRFeatures):
        ndim = features.ndim
        flip = self._rng.randint(0, 2, features.ndim)
        points = features.coords.copy()
        for i, f in enumerate(flip):
            if f == 1:
                points[:, ndim - i - 1] *= -1
        return WRFeatures(
            coords=points,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )

def _scale_matrix(sz: float, sy: float, sx: float):
    return np.diag([sz, sy, sx])


# def _scale_matrix(sy: float, sx: float):
#     return np.array([[1, 0, 0], [0, sy, 0], [0, 0, sx]])


def _shear_matrix(shy: float, shx: float):
    # det = (1 + shx * shy) - shx * shy = 1, i.e. the shear is volume-preserving
    return np.array([[1, 0, 0], [0, 1 + shx * shy, shy], [0, shx, 1]])


def _rotation_matrix(theta: float):
    # rotation in the (y, x) plane, leaving the first (z) axis fixed
    return np.array([
        [1, 0, 0],
        [0, np.cos(theta), -np.sin(theta)],
        [0, np.sin(theta), np.cos(theta)],
    ])

def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
    ndim = len(M)
    if k == "area":
        v = np.linalg.det(M) * v
    elif k == "equivalent_diameter_area":
        v = np.linalg.det(M) ** (1 / len(M)) * v
    elif k == "inertia_tensor":
        # v' = M * v * M^T
        v = v.reshape(-1, ndim, ndim)
        # v * M^T
        v = np.einsum("ijk, mk -> ijm", v, M)
        # M * v
        v = np.einsum("ij, kjm -> kim", M, v)
        v = v.reshape(-1, ndim * ndim)
    elif k in (
        "intensity_mean",
        "intensity_std",
        "intensity_max",
        "intensity_min",
        "border_dist",
    ):
        pass
    else:
        raise ValueError(f"Don't know how to affinely transform {k}")
    return v
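
# Worked example (sketch): for a pure scaling M = diag(2, 2) in 2D,
# _transform_affine scales "area" by det(M) = 4, "equivalent_diameter_area" by
# det(M) ** (1 / 2) = 2, and maps a flattened inertia tensor v to
# M @ v.reshape(2, 2) @ M.T, i.e. scales it by 4; intensity features and
# border_dist are left unchanged.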

class WRRandomAffine(WRBaseAugmentation):
    def __init__(
        self,
        degrees: float = 10,
        scale: Tuple[float, float] = (0.9, 1.1),
        shear: Tuple[float, float] = (0.1, 0.1),
        p: float = 0.5,
    ):
        super().__init__(p)
        self.degrees = degrees if degrees is not None else 0
        self.scale = scale if scale is not None else (1, 1)
        self.shear = shear if shear is not None else (0, 0)

    def _augment(self, features: WRFeatures):
        theta = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi
        scale = self._rng.uniform(*self.scale, 3)
        shy = self._rng.uniform(-self.shear[0], self.shear[0])
        shx = self._rng.uniform(-self.shear[1], self.shear[1])
        self._M = (
            _rotation_matrix(theta) @ _scale_matrix(*scale) @ _shear_matrix(shy, shx)
        )
        # M is built as a 3D matrix; drop the leading dimension for 2D data
        self._M = self._M[-features.ndim :, -features.ndim :]
        points = features.coords @ self._M.T
        feats = OrderedDict(
            (k, _transform_affine(k, v, self._M)) for k, v in features.features.items()
        )
        return WRFeatures(
            coords=points,
            labels=features.labels,
            timepoints=features.timepoints,
            features=feats,
        )

class WRRandomBrightness(WRBaseAugmentation):
    def __init__(
        self,
        scale: Tuple[float, float] = (0.5, 2.0),
        shift: Tuple[float, float] = (-0.1, 0.1),
        p: float = 0.5,
    ):
        super().__init__(p)
        self.scale = scale
        self.shift = shift

    def _augment(self, features: WRFeatures):
        scale = self._rng.uniform(*self.scale)
        shift = self._rng.uniform(*self.shift)
        key_vals = []
        for k, v in features.features.items():
            if "intensity" in k:
                v = v * scale + shift
            key_vals.append((k, v))
        feats = OrderedDict(key_vals)
        return WRFeatures(
            coords=features.coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=feats,
        )

class WRRandomOffset(WRBaseAugmentation):
    def __init__(self, offset: Tuple[float, float] = (-3, 3), p: float = 0.5):
        super().__init__(p)
        self.offset = offset

    def _augment(self, features: WRFeatures):
        offset = self._rng.uniform(*self.offset, features.coords.shape)
        coords = features.coords + offset
        return WRFeatures(
            coords=coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )

class WRRandomMovement(WRBaseAugmentation):
    """Random global linear shift."""

    def __init__(self, offset: Tuple[float, float] = (-10, 10), p: float = 0.5):
        super().__init__(p)
        self.offset = offset

    def _augment(self, features: WRFeatures):
        base_offset = self._rng.uniform(*self.offset, features.coords.shape[-1])
        tmin = features.timepoints.min()
        offset = (features.timepoints[:, None] - tmin) * base_offset[None]
        coords = features.coords + offset
        return WRFeatures(
            coords=coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )

class WRAugmentationPipeline:
    def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
        self.augmentations = augmentations

    def __call__(self, feats: WRFeatures):
        for aug in self.augmentations:
            feats = aug(feats)
        return feats
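
# Example (sketch): a typical augmentation pipeline built from the classes above;
# the parameters are purely illustrative.
#   augment = WRAugmentationPipeline([
#       WRRandomFlip(p=0.5),
#       WRRandomAffine(degrees=10, scale=(0.9, 1.1), shear=(0.1, 0.1), p=0.5),
#       WRRandomBrightness(p=0.5),
#       WRRandomOffset(p=0.5),
#   ])
#   feats_aug = augment(feats)  # feats is a WRFeatures instance
# WRRandomCrop is not included here, as it returns a (WRFeatures, idx) tuple
# rather than a WRFeatures and therefore does not compose directly.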

def get_features(
    detections: np.ndarray,
    imgs: Optional[np.ndarray] = None,
    features: Literal["none", "wrfeat"] = "wrfeat",
    ndim: int = 2,
    n_workers: int = 0,
    progbar_class=tqdm,
) -> List[WRFeatures]:
    detections = _check_dimensions(detections, ndim)
    imgs = _check_dimensions(imgs, ndim)
    logger.info(f"Extracting features from {len(detections)} detections")
    if n_workers > 0:
        logger.info(f"Using {n_workers} processes for feature extraction")
        feats = joblib.Parallel(n_jobs=n_workers, backend="loky")(
            joblib.delayed(WRFeatures.from_mask_img)(
                # New axis for the time component
                mask=mask[np.newaxis, ...].copy(),
                img=img[np.newaxis, ...].copy(),
                t_start=t,
            )
            for t, (mask, img) in progbar_class(
                enumerate(zip(detections, imgs)),
                total=len(imgs),
                desc="Extracting features",
            )
        )
    else:
        logger.info("Using single process for feature extraction")
        feats = list(
            WRFeatures.from_mask_img(
                mask=mask[np.newaxis, ...],
                img=img[np.newaxis, ...],
                t_start=t,
            )
            for t, (mask, img) in progbar_class(
                enumerate(zip(detections, imgs)),
                total=len(imgs),
                desc="Extracting features",
            )
        )
    return feats

def _check_dimensions(x: np.ndarray, ndim: int) -> np.ndarray:
    if ndim == 2 and not x.ndim == 3:
        raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
    elif ndim == 3:
        # if ndim=3 and the data is two-dimensional, cast it to 3D with a singleton z axis
        if x.ndim == 3:
            x = np.expand_dims(x, axis=1)
        elif x.ndim == 4:
            pass
        else:
            raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
    return x

def build_windows(
    features: List[WRFeatures], window_size: int, progbar_class=tqdm
) -> List[dict]:
    windows = []
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=len(features) - window_size + 1,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        labels = feat.labels
        timepoints = feat.timepoints
        coords = feat.coords
        if len(feat) == 0:
            coords = np.zeros((0, feat.ndim), dtype=int)
        w = dict(
            coords=coords,
            t1=t1,
            labels=labels,
            timepoints=timepoints,
            features=feat.features_stacked(),
        )
        windows.append(w)
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows
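
# Example (sketch): with window_size=4, window i holds all regions of timepoints
# [i, i + 4) together with their stacked feature matrix.
#   windows = build_windows(features, window_size=4)
#   sorted(windows[0].keys())  # -> ['coords', 'features', 'labels', 't1', 'timepoints']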

def build_windows_sd(
    features: List[WRFeatures],
    imgs_enc,
    imgs_stable,
    boxes,
    imgs,
    masks,
    window_size: int,
    progbar_class=tqdm,
) -> List[dict]:
    windows = []
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=len(features) - window_size + 1,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        labels = feat.labels
        timepoints = feat.timepoints
        coords = feat.coords
        if len(feat) == 0:
            coords = np.zeros((0, feat.ndim), dtype=int)
        w = dict(
            coords=coords,
            t1=t1,
            labels=labels,
            timepoints=timepoints,
            features=feat.features_stacked(),
            img_enc=imgs_enc[t1:t2],
            image_stable=imgs_stable[t1:t2],
            boxes=boxes,
            img=imgs[t1:t2],
            mask=masks[t1:t2],
            coords_t=torch.tensor(coords, dtype=torch.float32),
            labels_t=torch.tensor(labels, dtype=torch.int32),
            timepoints_t=torch.tensor(timepoints, dtype=torch.int64),
            features_t=torch.tensor(feat.features_stacked(), dtype=torch.float32),
            img_t=torch.tensor(imgs[t1:t2], dtype=torch.float32),
            mask_t=torch.tensor(masks[t1:t2], dtype=torch.int32),
        )
        windows.append(w)
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows

if __name__ == "__main__":
    imgs = load_tiff_timeseries(
        # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01",
        "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
    )
    masks = load_tiff_timeseries(
        # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA",
        "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
        dtype=int,
    )
    features = get_features(detections=masks, imgs=imgs, ndim=3)
    windows = build_windows(features, window_size=4)