# phoebehxf init aff3c6f
"""Regionprops features and its augmentations.
WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region.
"""
import itertools
import logging
from collections import OrderedDict
from collections.abc import Iterable #, Sequence
from functools import reduce
from typing import Literal
import joblib
import numpy as np
import pandas as pd
from edt import edt
from skimage.measure import regionprops, regionprops_table
from tqdm import tqdm
from typing import Tuple, Optional, Sequence, Union, List
import typing
# Support running both as a package module and as a standalone script:
# prefer the package-relative import, fall back to a top-level one.
try:
    from .utils import load_tiff_timeseries
except ImportError:  # narrowed from bare `except:` — don't mask unrelated errors
    from utils import load_tiff_timeseries
import torch
logger = logging.getLogger(__name__)
# Named presets of per-region properties to extract.
# All entries except "border_dist" are skimage regionprops property names;
# "border_dist" is a custom property computed separately (see _border_dist_fast
# usage in WRFeatures.from_mask_img).
_PROPERTIES = {
    "regionprops": (
        "area",
        "intensity_mean",
        "intensity_max",
        "intensity_min",
        "inertia_tensor",
    ),
    "regionprops2": (
        "equivalent_diameter_area",
        "intensity_mean",
        "inertia_tensor",
        "border_dist",
    ),
}
def _filter_points(
points: np.ndarray, shape: Tuple[int], origin: Optional[Tuple[int]] = None
) -> np.ndarray:
"""Returns indices of points that are inside the shape extent and given origin."""
ndim = points.shape[-1]
if origin is None:
origin = (0,) * ndim
idx = tuple(
np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i])
for i in range(ndim)
)
idx = np.where(np.all(idx, axis=0))[0]
return idx
def _border_dist(mask: np.ndarray, cutoff: float = 5):
    """Returns distance to border normalized to 0 (at least cutoff away) and 1 (at border)."""
    # mark the interior of the last two (spatial) axes; leading axes untouched
    interior = tuple(
        slice(1, -1) if ax >= mask.ndim - 2 else slice(None)
        for ax in range(mask.ndim)
    )
    inside = np.zeros_like(mask)
    inside[interior] = 1
    # edt gives the distance to the nearest zero (the 1px border frame);
    # clip at `cutoff` and invert so the border maps to 1 and far pixels to 0
    dist = 1 - np.minimum(edt(inside) / cutoff, 1)
    # per-label maximum of the border-proximity map
    return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
def _border_dist_fast(mask: np.ndarray, cutoff: float = 5):
    """Fast variant of `_border_dist` using per-axis linear ramps instead of an EDT.

    Builds a proximity map that is 1 at the array border and fades linearly
    to 0 over `cutoff` pixels, then returns the per-label maximum of that map.

    NOTE(review): unlike `_border_dist`, the ramps here are applied along *all*
    axes (including a leading z axis for 3D frames), not only the last two —
    confirm this difference is intended.
    """
    cutoff = int(cutoff)
    # start as "far from border" everywhere; values are lowered near the edges
    border = np.ones(mask.shape, dtype=np.float32)
    ndim = len(mask.shape)
    for axis, size in enumerate(mask.shape):
        # Create fade values for the band [0, cutoff)
        band_vals = np.arange(cutoff, dtype=np.float32) / cutoff
        # Build slices for the low border
        low_slices = [slice(None)] * ndim
        low_slices[axis] = slice(0, cutoff)
        border_low = border[tuple(low_slices)]
        # trailing None axes make band_vals broadcast along the current axis
        border_low_vals = np.minimum(
            border_low, band_vals[(...,) + (None,) * (ndim - axis - 1)]
        )
        border[tuple(low_slices)] = border_low_vals
        # Build slices for the high border
        high_slices = [slice(None)] * ndim
        high_slices[axis] = slice(size - cutoff, size)
        band_vals_rev = band_vals[::-1]
        border_high = border[tuple(high_slices)]
        border_high_vals = np.minimum(
            border_high, band_vals_rev[(...,) + (None,) * (ndim - axis - 1)]
        )
        border[tuple(high_slices)] = border_high_vals
    # invert: 1 at the border, 0 at >= cutoff pixels away (matches _border_dist)
    dist = 1 - border
    return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist))
class WRFeatures:
    """regionprops features for a windowed track region.

    Holds per-region centroids (`coords`), integer labels, timepoints, and an
    ordered mapping of named feature arrays, all aligned along the first axis.
    """

    def __init__(
        self,
        coords: np.ndarray,
        labels: np.ndarray,
        timepoints: np.ndarray,
        features: typing.OrderedDict[str, np.ndarray],
    ):
        """
        Args:
            coords: (nregions, ndim) centroid coordinates.
            labels: (nregions,) integer region labels.
            timepoints: (nregions,) timepoint of each region.
            features: mapping feature name -> (nregions, feature_dim) array.
        """
        # spatial dimensionality is inferred from the coordinate array
        self.ndim = coords.shape[-1]
        if self.ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        self.coords = coords
        self.labels = labels
        # shallow copy so mutating this dict doesn't alias the caller's
        self.features = features.copy()
        self.timepoints = timepoints

    def __repr__(self):
        s = (
            f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
            f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
        )
        for k, v in self.features.items():
            s += f"{k:>20} -> {v.shape}\n"
        return s

    @property
    def features_stacked(self):
        """All feature arrays concatenated along the last axis: (nregions, total_dim)."""
        return np.concatenate([v for k, v in self.features.items()], axis=-1)

    def __len__(self):
        # number of regions
        return len(self.labels)

    def __getitem__(self, key):
        """Return the feature array stored under `key`; raise KeyError otherwise."""
        if key in self.features:
            return self.features[key]
        else:
            raise KeyError(f"Key {key} not found in features")

    @classmethod
    def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures":
        """Concatenate multiple WRFeatures into a single one."""
        if len(feats) == 0:
            raise ValueError("Cannot concatenate empty list of features")
        # pairwise __add__ performs the actual concatenation and validation
        return reduce(lambda x, y: x + y, feats)

    def __add__(self, other: "WRFeatures") -> "WRFeatures":
        """Concatenate two WRFeatures; they must share ndim and feature keys."""
        if self.ndim != other.ndim:
            raise ValueError("Cannot concatenate features of different dimensions")
        if self.features.keys() != other.features.keys():
            raise ValueError("Cannot concatenate features with different properties")
        coords = np.concatenate([self.coords, other.coords], axis=0)
        labels = np.concatenate([self.labels, other.labels], axis=0)
        timepoints = np.concatenate([self.timepoints, other.timepoints], axis=0)
        features = OrderedDict(
            (k, np.concatenate([v, other.features[k]], axis=0))
            for k, v in self.features.items()
        )
        return WRFeatures(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )

    @classmethod
    def from_mask_img(
        cls,
        mask: np.ndarray,
        img: np.ndarray,
        properties="regionprops2",
        t_start: int = 0,
    ):
        """Build WRFeatures from a (T, ...) label-mask stack and matching images.

        Args:
            mask: (T, Y, X) or (T, Z, Y, X) integer label masks.
            img: intensity images, same leading shape as `mask`.
            properties: name of a preset in `_PROPERTIES`.
            t_start: timepoint offset assigned to the first frame.
        """
        img = np.asarray(img)
        mask = np.asarray(mask)
        _ntime, ndim = mask.shape[0], mask.ndim - 1
        if ndim not in (2, 3):
            raise ValueError("Only 2D or 3D data is supported")
        properties = tuple(_PROPERTIES[properties])
        if "label" in properties or "centroid" in properties:
            raise ValueError(
                f"label and centroid should not be in properties {properties}"
            )
        # "border_dist" is computed by _border_dist_fast, not by regionprops_table
        if "border_dist" in properties:
            use_border_dist = True
            # remove border_dist from properties
            properties = tuple(p for p in properties if p != "border_dist")
        else:
            use_border_dist = False
        df_properties = ("label", "centroid", *properties)
        dfs = []
        # one regionprops table per timepoint, then concatenated
        for i, (y, x) in enumerate(zip(mask, img)):
            _df = pd.DataFrame(
                regionprops_table(y, intensity_image=x, properties=df_properties)
            )
            _df["timepoint"] = i + t_start
            if use_border_dist:
                _df["border_dist"] = _border_dist_fast(y)
            dfs.append(_df)
        df = pd.concat(dfs)
        # re-append so the border_dist column is picked up below
        if use_border_dist:
            properties = (*properties, "border_dist")
        timepoints = df["timepoint"].values.astype(np.int32)
        labels = df["label"].values.astype(np.int32)
        coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32)
        # group multi-column properties (e.g. "inertia_tensor-0-0", ...) by
        # their name prefix into one (nregions, feature_dim) array each
        features = OrderedDict(
            (
                p,
                np.stack(
                    [
                        df[c].values.astype(np.float32)
                        for c in df.columns
                        if c.startswith(p)
                    ],
                    axis=-1,
                ),
            )
            for p in properties
        )
        return cls(
            coords=coords, labels=labels, timepoints=timepoints, features=features
        )
# augmentations
class WRRandomCrop:
    """windowed region random crop augmentation.

    Samples a crop window per call and returns the features restricted to it,
    together with the indices of the kept regions.
    """

    def __init__(
        self,
        crop_size: Optional[Union[int, Tuple[int]]] = None,
        ndim: int = 2,
    ) -> None:
        """crop_size: tuple of int
        can be tuple of length 1 (all dimensions)
        of length ndim (y,x,...)
        of length 2*ndim (y1,y2, x1,x2, ...).
        """
        if isinstance(crop_size, int):
            crop_size = (crop_size,) * 2 * ndim
        elif isinstance(crop_size, Iterable):
            pass
        else:
            raise ValueError(f"{crop_size} has to be int or tuple of int")
        if len(crop_size) == 1:
            crop_size = (crop_size[0],) * 2 * ndim
        elif len(crop_size) == ndim:
            # duplicate each per-axis size into a (min, max) pair
            crop_size = tuple(itertools.chain(*tuple((c, c) for c in crop_size)))
        elif len(crop_size) == 2 * ndim:
            pass
        else:
            raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")
        crop_size = np.array(crop_size)
        self._ndim = ndim
        # per-dimension (min, max) bounds for the sampled crop size
        self._crop_bounds = crop_size[::2], crop_size[1::2]
        self._rng = np.random.RandomState()

    def __call__(self, features: WRFeatures):
        """Crop a random window; returns (cropped_features, kept_indices)."""
        crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)
        points = features.coords
        if len(points) == 0:
            # BUGFIX: use the logger instead of print, and return the same
            # (features, indices) pair as the normal path so callers can
            # always unpack the result.
            logger.warning("No points given, cannot ensure inside points")
            return features, np.zeros(0, dtype=int)
        # sample an anchor point and place the crop corner relative to it,
        # so that the anchor lies inside the crop
        # BUGFIX: use the instance RNG (self._rng), not the global np.random,
        # so seeding this augmentation is reproducible.
        _idx = self._rng.randint(len(points))
        corner = (
            points[_idx]
            - crop_size
            + 1
            + self._rng.randint(crop_size // 4, 3 * crop_size // 4)
        )
        idx = _filter_points(points, shape=crop_size, origin=corner)
        return (
            WRFeatures(
                coords=points[idx],
                labels=features.labels[idx],
                timepoints=features.timepoints[idx],
                features=OrderedDict((k, v[idx]) for k, v in features.features.items()),
            ),
            idx,
        )
class WRBaseAugmentation:
    """Base class for probabilistic augmentations.

    With probability `p` applies `_augment` (implemented by subclasses);
    otherwise — or when the input is empty — returns the features unchanged.
    """

    def __init__(self, p: float = 0.5) -> None:
        self._p = p
        self._rng = np.random.RandomState()

    def __call__(self, features: WRFeatures):
        skip = self._rng.rand() > self._p or len(features) == 0
        return features if skip else self._augment(features)

    def _augment(self, features: WRFeatures):
        raise NotImplementedError()
class WRRandomFlip(WRBaseAugmentation):
    """Randomly mirrors coordinates by flipping the sign of individual axes."""

    def _augment(self, features: WRFeatures):
        ndim = features.ndim
        coords = features.coords.copy()
        # independently decide for each axis whether to mirror it
        for axis, do_flip in enumerate(self._rng.randint(0, 2, ndim)):
            if do_flip:
                coords[:, ndim - axis - 1] = -coords[:, ndim - axis - 1]
        return WRFeatures(
            coords=coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )
def _scale_matrix(sz: float, sy: float, sx: float):
return np.diag([sz, sy, sx])
# def _scale_matrix(sy: float, sx: float):
# return np.array([[1, 0, 0], [0, sy, 0], [0, 0, sx]])
def _shear_matrix(shy: float, shx: float):
return np.array([[1, 0, 0], [0, 1 + shx * shy, shy], [0, shx, 1]])
def _rotation_matrix(theta: float):
return np.array([
[1, 0, 0],
[0, np.cos(theta), -np.sin(theta)],
[0, np.sin(theta), np.cos(theta)],
])
def _transform_affine(k: str, v: np.ndarray, M: np.ndarray):
ndim = len(M)
if k == "area":
v = np.linalg.det(M) * v
elif k == "equivalent_diameter_area":
v = np.linalg.det(M) ** (1 / len(M)) * v
elif k == "inertia_tensor":
# v' = M * v * M^T
v = v.reshape(-1, ndim, ndim)
# v * M^T
v = np.einsum("ijk, mk -> ijm", v, M)
# M * v
v = np.einsum("ij, kjm -> kim", M, v)
v = v.reshape(-1, ndim * ndim)
elif k in (
"intensity_mean",
"intensity_std",
"intensity_max",
"intensity_min",
"border_dist",
):
pass
else:
raise ValueError(f"Don't know how to affinely transform {k}")
return v
class WRRandomAffine(WRBaseAugmentation):
    """Random rotation, scale, and shear of coordinates and equivariant features."""

    def __init__(
        self,
        degrees: float = 10,
        scale: Tuple[float, float] = (0.9, 1.1),
        shear: Tuple[float, float] = (0.1, 0.1),
        p: float = 0.5,
    ):
        """
        Args:
            degrees: maximum absolute rotation angle (in degrees).
            scale: (min, max) range for per-axis scale factors.
            shear: maximum absolute shear along y and x respectively.
            p: probability of applying the augmentation.
        """
        super().__init__(p)
        self.degrees = degrees if degrees is not None else 0
        self.scale = scale if scale is not None else (1, 1)
        self.shear = shear if shear is not None else (0, 0)

    def _augment(self, features: WRFeatures):
        # sample transform parameters (angle converted to radians)
        degrees = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi
        scale = self._rng.uniform(*self.scale, 3)
        shy = self._rng.uniform(-self.shear[0], self.shear[0])
        shx = self._rng.uniform(-self.shear[1], self.shear[1])
        # compose rotation * scale * shear as a full 3D matrix
        self._M = (
            _rotation_matrix(degrees) @ _scale_matrix(*scale) @ _shear_matrix(shy, shx)
        )
        # M is by default 3D , we need to remove the last dimension for 2D
        self._M = self._M[-features.ndim :, -features.ndim :]
        points = features.coords @ self._M.T
        # transform each feature according to its equivariance rule
        feats = OrderedDict(
            (k, _transform_affine(k, v, self._M)) for k, v in features.features.items()
        )
        return WRFeatures(
            coords=points,
            labels=features.labels,
            timepoints=features.timepoints,
            features=feats,
        )
class WRRandomBrightness(WRBaseAugmentation):
    """Randomly rescales and shifts all intensity-derived features."""

    def __init__(
        self,
        scale: Tuple[float] = (0.5, 2.0),
        shift: Tuple[float] = (-0.1, 0.1),
        p: float = 0.5,
    ):
        super().__init__(p)
        self.scale = scale
        self.shift = shift

    def _augment(self, features: WRFeatures):
        gain = self._rng.uniform(*self.scale)
        offset = self._rng.uniform(*self.shift)
        # apply the same affine intensity change to every intensity feature,
        # leaving all other features untouched
        feats = OrderedDict(
            (k, v * gain + offset if "intensity" in k else v)
            for k, v in features.features.items()
        )
        return WRFeatures(
            coords=features.coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=feats,
        )
class WRRandomOffset(WRBaseAugmentation):
    """Adds independent random jitter to every coordinate."""

    def __init__(self, offset: Tuple[float, float] = (-3, 3), p: float = 0.5):
        """
        Args:
            offset: (min, max) range of the uniform per-coordinate jitter.
            p: probability of applying the augmentation.
        """
        super().__init__(p)
        self.offset = offset

    def _augment(self, features: WRFeatures):
        # independent uniform offset per region and per axis
        offset = self._rng.uniform(*self.offset, features.coords.shape)
        coords = features.coords + offset
        return WRFeatures(
            coords=coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )
class WRRandomMovement(WRBaseAugmentation):
    """random global linear shift.

    All regions at timepoint t are shifted by (t - tmin) * base_offset,
    simulating a constant global drift over time.
    """

    def __init__(self, offset: Tuple[float, float] = (-10, 10), p: float = 0.5):
        """
        Args:
            offset: (min, max) range of the per-axis drift per timestep.
            p: probability of applying the augmentation.
        """
        super().__init__(p)
        self.offset = offset

    def _augment(self, features: WRFeatures):
        # one drift vector shared by all regions, scaled by elapsed time
        base_offset = self._rng.uniform(*self.offset, features.coords.shape[-1])
        tmin = features.timepoints.min()
        offset = (features.timepoints[:, None] - tmin) * base_offset[None]
        coords = features.coords + offset
        return WRFeatures(
            coords=coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )
class WRAugmentationPipeline:
    """Composes a sequence of augmentations, applied in order."""

    def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
        self.augmentations = augmentations

    def __call__(self, feats: WRFeatures):
        # fold the feature object through every augmentation in turn
        return reduce(lambda acc, aug: aug(acc), self.augmentations, feats)
def get_features(
    detections: np.ndarray,
    imgs: Optional[np.ndarray] = None,
    features: Literal["none", "wrfeat"] = "wrfeat",
    ndim: int = 2,
    n_workers=0,
    progbar_class=tqdm,
) -> List[WRFeatures]:
    """Extract per-timepoint WRFeatures from label masks and intensity images.

    Args:
        detections: (T, ...) stack of integer label masks, one per timepoint.
        imgs: (T, ...) intensity images matching `detections`; required.
        features: feature preset selector (kept for API compatibility;
            not consumed in this function).
        ndim: spatial dimensionality (2 or 3); 2D frames are promoted for ndim=3.
        n_workers: if > 0, extract in parallel with joblib/loky processes.
        progbar_class: tqdm-compatible progress bar factory.

    Returns:
        A list of WRFeatures, one per timepoint.

    Raises:
        ValueError: if `imgs` is missing or the dimensionality checks fail.
    """
    if imgs is None:
        # BUGFIX: fail with a clear message instead of an AttributeError
        # inside _check_dimensions below
        raise ValueError("imgs must be provided to extract features")
    detections = _check_dimensions(detections, ndim)
    imgs = _check_dimensions(imgs, ndim)
    logger.info(f"Extracting features from {len(detections)} detections")
    frames = progbar_class(
        enumerate(zip(detections, imgs)),
        total=len(imgs),
        desc="Extracting features",
    )
    if n_workers > 0:
        logger.info(f"Using {n_workers} processes for feature extraction")
        # .copy() avoids pickling views of the full stacks to worker processes
        results = joblib.Parallel(n_jobs=n_workers, backend="loky")(
            joblib.delayed(WRFeatures.from_mask_img)(
                # New axis for time component
                mask=mask[np.newaxis, ...].copy(),
                img=img[np.newaxis, ...].copy(),
                t_start=t,
            )
            for t, (mask, img) in frames
        )
    else:
        logger.info("Using single process for feature extraction")
        results = [
            WRFeatures.from_mask_img(
                mask=mask[np.newaxis, ...],
                img=img[np.newaxis, ...],
                t_start=t,
            )
            for t, (mask, img) in frames
        ]
    # BUGFIX: previously the serial branch returned a tuple while the parallel
    # one returned a list; always return a list to match the annotation.
    return list(results)
def _check_dimensions(x: np.ndarray, ndim: int):
if ndim == 2 and not x.ndim == 3:
raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
elif ndim == 3:
# if ndim=3 and data is two dimensional, it will be cast to 3D
if x.ndim == 3:
x = np.expand_dims(x, axis=1)
elif x.ndim == 4:
pass
else:
raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
return x
def build_windows(
    features: List[WRFeatures], window_size: int, progbar_class=tqdm
) -> List[dict]:
    """Concatenate consecutive per-timepoint features into sliding windows.

    Each window covers `window_size` consecutive timepoints starting at t1.
    """
    windows = []
    n_windows = len(features) - window_size + 1
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=n_windows,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        # empty windows still need a correctly-shaped coords array
        coords = feat.coords if len(feat) else np.zeros((0, feat.ndim), dtype=int)
        windows.append(
            dict(
                coords=coords,
                t1=t1,
                labels=feat.labels,
                timepoints=feat.timepoints,
                features=feat.features_stacked,
            )
        )
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows
def build_windows_sd(
    features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
) -> List[dict]:
    """Like `build_windows`, but also bundles image/mask slices and torch tensors per window."""
    windows = []
    n_windows = len(features) - window_size + 1
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=n_windows,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        labels = feat.labels
        timepoints = feat.timepoints
        # empty windows still need a correctly-shaped coords array
        coords = feat.coords if len(feat) else np.zeros((0, feat.ndim), dtype=int)
        stacked = feat.features_stacked
        windows.append(
            dict(
                coords=coords,
                t1=t1,
                labels=labels,
                timepoints=timepoints,
                features=stacked,
                img_enc=imgs_enc[t1:t2],
                image_stable=imgs_stable[t1:t2],
                boxes=boxes,
                img=imgs[t1:t2],
                mask=masks[t1:t2],
                # tensor copies of the numpy payloads for direct model consumption
                coords_t=torch.tensor(coords, dtype=torch.float32),
                labels_t=torch.tensor(labels, dtype=torch.int32),
                timepoints_t=torch.tensor(timepoints, dtype=torch.int64),
                features_t=torch.tensor(stacked, dtype=torch.float32),
                img_t=torch.tensor(imgs[t1:t2], dtype=torch.float32),
                mask_t=torch.tensor(masks[t1:t2], dtype=torch.int32),
            )
        )
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows
if __name__ == "__main__":
    # Smoke test: extract features from a Cell Tracking Challenge dataset
    # and build sliding windows.
    # NOTE(review): paths are machine-specific. Fluo-N2DL-HeLa is a 2D
    # dataset, but ndim=3 is passed below (frames get a singleton z axis
    # via _check_dimensions) — confirm this is intended.
    imgs = load_tiff_timeseries(
        # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01",
        "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
    )
    masks = load_tiff_timeseries(
        # "/scratch0/data/celltracking/ctc_2024/Fluo-C3DL-MDA231/train/01_GT/TRA",
        "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
        dtype=int,  # masks are integer label images
    )
    features = get_features(detections=masks, imgs=imgs, ndim=3)
    windows = build_windows(features, window_size=4)
    # if __name__ == "__main__":
    #     y = np.zeros((1, 100, 100), np.uint8)
    #     y[:, 20:40, 20:60] = 1
    #     x = y + np.random.normal(0, 0.1, y.shape)
    #     f = WRFeatures.from_mask_img(y, x, properties=("intensity_mean", "area"))