import logging
from pathlib import Path
from timeit import default_timer
from typing import List, Literal, Optional, Sequence, Tuple
import joblib
import lz4.frame
import networkx as nx
import numpy as np
import pandas as pd
import tifffile
import torch
from numba import njit
from scipy import ndimage as ndi
from scipy.spatial.distance import cdist
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from torch.utils.data import Dataset
from tqdm import tqdm
from . import wrfeat
from ._check_ctc import _check_ctc, _get_node_attributes
from .augmentations import (
AugmentationPipeline,
RandomCrop,
default_augmenter,
)
from .features import (
_PROPERTIES,
extract_features_patch,
extract_features_regionprops,
)
from .matching import matching
from ..utils import blockwise_sum, normalize
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def _filter_track_df(df, start_frame, end_frame, downscale):
"""Only keep tracklets that are present in the given time interval."""
    # only retain cells present in the interval (work on a copy to avoid chained-assignment warnings)
    df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)].copy()
    # shift start and end of each cell
    df["t1"] = df["t1"] - start_frame
    df["t2"] = df["t2"] - start_frame
    # clip start/end to the interval bounds
    df["t1"] = df["t1"].clip(0, end_frame - start_frame - 1)
    df["t2"] = df["t2"].clip(0, end_frame - start_frame - 1)
# set all parents to 0 that are not in the interval
df.loc[~df.parent.isin(df.label), "parent"] = 0
if downscale > 1:
if start_frame % downscale != 0:
raise ValueError("start_frame must be a multiple of downscale")
logger.info(f"Temporal downscaling of tracklet links by {downscale}")
# remove tracklets that have been fully deleted by temporal downsampling
mask = (
# (df["t2"] - df["t1"] < downscale - 1)
(df["t1"] % downscale != 0)
& (df["t2"] % downscale != 0)
& (df["t1"] // downscale == df["t2"] // downscale)
)
logger.info(
f"Remove {mask.sum()} tracklets that are fully deleted by downsampling"
)
logger.debug(f"Remove {df[mask]}")
df = df[~mask]
# set parent to 0 if it has been deleted
df.loc[~df.parent.isin(df.label), "parent"] = 0
df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int)
df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int)
        # sanity check: after downscaling, every tracklet must still satisfy t1 <= t2
        assert np.all(df["t1"] <= df["t2"])
return df
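# Illustrative example of the temporal downscaling above (assumed toy values):
# with start_frame=0, end_frame=10 and downscale=2, a tracklet with t1=3, t2=3 lies
# entirely between sampled frames and is dropped, while a tracklet with t1=3, t2=6
# is kept and becomes t1=ceil(3/2)=2, t2=floor(6/2)=3.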
class _CompressedArray:
"""a simple class to compress and decompress a numpy arrays using lz4."""
# dont compress float types
def __init__(self, data):
self._data = lz4.frame.compress(data)
self._dtype = data.dtype.type
self._shape = data.shape
def decompress(self):
s = lz4.frame.decompress(self._data)
data = np.frombuffer(s, dtype=self._dtype).reshape(self._shape)
return data
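# Minimal usage sketch for _CompressedArray (illustrative only, assumes a contiguous array):
#   arr = np.zeros((4, 64, 64), dtype=np.int32)
#   packed = _CompressedArray(arr)      # stores lz4-compressed bytes plus dtype/shape
#   restored = packed.decompress()      # round-trips back to the original array
#   assert np.array_equal(arr, restored)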
def debug_function(f):
def wrapper(*args, **kwargs):
try:
batch = f(*args, **kwargs)
except Exception as e:
logger.error(f"Error in {f.__name__}: {e}")
return None
logger.info(f"XXXX {len(batch['coords'])}")
return batch
return wrapper
class CTCData(Dataset):
def __init__(
self,
root: str = "",
ndim: int = 2,
use_gt: bool = True,
detection_folders: List[str] = ["TRA"],
window_size: int = 10,
max_tokens: Optional[int] = None,
slice_pct: tuple = (0.0, 1.0),
downscale_spatial: int = 1,
downscale_temporal: int = 1,
augment: int = 0,
features: Literal[
"none",
"regionprops",
"regionprops2",
"patch",
"patch_regionprops",
"wrfeat",
] = "wrfeat",
sanity_dist: bool = False,
crop_size: Optional[tuple] = None,
return_dense: bool = False,
compress: bool = False,
**kwargs,
) -> None:
"""_summary_.
Args:
root (str):
Folder containing the CTC TRA folder.
ndim (int):
Number of dimensions of the data. Defaults to 2d
(if ndim=3 and data is two dimensional, it will be cast to 3D)
detection_folders:
List of relative paths to folder with detections.
Defaults to ["TRA"], which uses the ground truth detections.
window_size (int):
Window size for transformer.
slice_pct (tuple):
Slice the dataset by percentages (from, to).
augment (int):
if 0, no data augmentation. if > 0, defines level of data augmentation.
features (str):
Types of features to use.
sanity_dist (bool):
Use euclidian distance instead of the association matrix as a target.
crop_size (tuple):
Size of the crops to use for augmentation. If None, no cropping is used.
return_dense (bool):
Return dense masks and images in the data samples.
compress (bool):
Compress elements/remove img if not needed to save memory for large datasets
"""
super().__init__()
self.root = Path(root)
self.name = self.root.name
self.use_gt = use_gt
self.slice_pct = slice_pct
if not 0 <= slice_pct[0] < slice_pct[1] <= 1:
raise ValueError(f"Invalid slice_pct {slice_pct}")
self.downscale_spatial = downscale_spatial
self.downscale_temporal = downscale_temporal
self.detection_folders = detection_folders
        self.ndim = ndim
        self.features = features
        # store augmentation settings so that alternative constructors (e.g. `from_arrays`) can reuse them
        self.augment = augment
        self.crop_size = crop_size
if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
raise ValueError(
f"'{features}' not one of the supported {ndim}D features"
f" {tuple(_PROPERTIES[ndim].keys())}"
)
logger.info(f"ROOT (config): \t{self.root}")
self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
logger.info(f"ROOT (guessed): \t{self.root}")
logger.info(f"GT TRA (guessed):\t{self.gt_tra_folder}")
if self.use_gt:
self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder)
else:
logger.info("Using dummy masks as GT")
self.gt_mask_folder = self._guess_det_folder(
self.root, self.detection_folders[0]
)
logger.info(f"GT MASK (guessed):\t{self.gt_mask_folder}")
# dont load image data if not needed
if features in ("none",):
self.img_folder = None
else:
self.img_folder = self._guess_img_folder(self.root)
logger.info(f"IMG (guessed):\t{self.img_folder}")
self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
ndim, features, augment, crop_size
)
        if window_size <= 1:
            raise ValueError("window_size must be > 1")
self.window_size = window_size
self.max_tokens = max_tokens
self.slice_pct = slice_pct
self.sanity_dist = sanity_dist
self.return_dense = return_dense
self.compress = compress
self.start_frame = 0
self.end_frame = None
start = default_timer()
if self.features == "wrfeat":
self.windows = self._load_wrfeat()
else:
self.windows = self._load()
self.n_divs = self._get_ndivs(self.windows)
if len(self.windows) > 0:
self.ndim = self.windows[0]["coords"].shape[1]
self.n_objects = tuple(len(t["coords"]) for t in self.windows)
logger.info(
f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
)
else:
self.n_objects = 0
logger.warning(f"Could not load any tracks from {self.root}")
if self.compress:
self._compress_data()
@classmethod
def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
self = cls(**train_args)
self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
self.ndim, self.features, self.augment, self.crop_size
)
start = default_timer()
if self.features == "wrfeat":
self.windows = self._load_wrfeat()
else:
self.windows = self._load()
self.n_divs = self._get_ndivs(self.windows)
if len(self.windows) > 0:
self.ndim = self.windows[0]["coords"].shape[1]
self.n_objects = tuple(len(t["coords"]) for t in self.windows)
logger.info(
f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
)
else:
self.n_objects = 0
logger.warning(f"Could not load any tracks from {self.root}")
if self.compress:
self._compress_data()
def _get_ndivs(self, windows):
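        """Return a proxy count of cell divisions per window, derived from the association
        matrix: detections that are linked to exactly two detections within a single
        timepoint block (one parent linked to two children)."""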
n_divs = []
for w in tqdm(windows, desc="Counting divisions", leave=False):
_n = (
(
blockwise_sum(
torch.from_numpy(w["assoc_matrix"]).float(),
torch.from_numpy(w["timepoints"]).long(),
).max(dim=0)[0]
== 2
)
.sum()
.item()
)
n_divs.append(_n)
return n_divs
def _setup_features_augs(
self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
):
if self.features == "wrfeat":
return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)
cropper = (
RandomCrop(
crop_size=crop_size,
ndim=ndim,
use_padding=False,
ensure_inside_points=True,
)
if crop_size is not None
else None
)
# Hack
if self.features == "none":
return 0, default_augmenter, cropper
if ndim == 2:
augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
feat_dim = {
"none": 0,
"regionprops": 7,
"regionprops2": 6,
"patch": 256,
"patch_regionprops": 256 + 5,
}[features]
elif ndim == 3:
augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
feat_dim = {
"none": 0,
"regionprops2": 11,
"patch_regionprops": 256 + 8,
}[features]
return feat_dim, augmenter, cropper
def _compress_data(self):
# compress masks and assoc_matrix
logger.info("Compressing masks and assoc_matrix to save memory")
for w in self.windows:
w["mask"] = _CompressedArray(w["mask"])
            # imgs have already been reduced by `_compress_img_mask_preproc`, so they can be compressed as well
            w["img"] = _CompressedArray(w["img"])
w["assoc_matrix"] = _CompressedArray(w["assoc_matrix"])
self.gt_masks = _CompressedArray(self.gt_masks)
self.det_masks = {k: _CompressedArray(v) for k, v in self.det_masks.items()}
        # imgs have already been reduced by `_compress_img_mask_preproc`, so they can be compressed as well
        self.imgs = _CompressedArray(self.imgs)
def _guess_root_and_gt_tra_folder(self, inp: Path):
"""Guesses the root and the ground truth folder from a given input path.
Args:
inp (Path): _description_
Returns:
Path: root folder,
"""
if inp.name == "TRA":
# 01_GT/TRA --> 01, 01_GT/TRA
root = inp.parent.parent / inp.parent.name.split("_")[0]
return root, inp
elif "ERR_SEG" in inp.name:
# 01_ERR_SEG --> 01, 01_GT/TRA. We know that the data is in CTC folder format
num = inp.name.split("_")[0]
return inp.parent / num, inp.parent / f"{num}_GT" / "TRA"
else:
ctc_tra = Path(f"{inp}_GT") / "TRA"
tra = ctc_tra if ctc_tra.exists() else inp / "TRA"
# 01 --> 01, 01_GT/TRA or 01/TRA
return inp, tra
def _guess_img_folder(self, root: Path):
"""Guesses the image folder corresponding to a root."""
if (root / "img").exists():
return root / "img"
else:
return root
def _guess_mask_folder(self, root: Path, gt_tra: Path):
"""Guesses the mask folder corresponding to a root.
In CTC format, we use silver truth segmentation masks.
"""
f = None
# first try CTC format
if gt_tra.parent.name.endswith("_GT"):
# We use the silver truth segmentation masks
f = root / str(gt_tra.parent.name).replace("_GT", "_ST") / "SEG"
# try our simpler 'img' format
if f is None or not f.exists():
f = gt_tra
if not f.exists():
raise ValueError(f"Could not find mask folder for {root}")
return f
@classmethod
def _guess_det_folder(cls, root: Path, suffix: str):
"""Checks for the annoying CTC format with dataset numbering as part of folder names."""
guesses = (
(root / suffix),
Path(f"{root}_{suffix}"),
Path(f"{root}_GT") / suffix,
)
for path in guesses:
if path.exists():
return path
logger.warning(f"Skipping non-existing detection folder {root / suffix}")
return None
def __len__(self):
return len(self.windows)
def _load_gt(self):
logger.info("Loading ground truth")
self.start_frame = int(
len(list(self.gt_mask_folder.glob("*.tif"))) * self.slice_pct[0]
)
self.end_frame = int(
len(list(self.gt_mask_folder.glob("*.tif"))) * self.slice_pct[1]
)
masks = self._load_tiffs(self.gt_mask_folder, dtype=np.int32)
masks = self._correct_gt_with_st(self.gt_mask_folder, masks, dtype=np.int32)
if self.use_gt:
track_df = self._load_tracklet_links(self.gt_tra_folder)
track_df = _filter_track_df(
track_df, self.start_frame, self.end_frame, self.downscale_temporal
)
else:
# create dummy track dataframe
logger.info("Using dummy track dataframe")
track_df = self._build_tracklets_without_gt(masks)
_check_ctc(track_df, _get_node_attributes(masks), masks)
# Build ground truth lineage graph
self.gt_labels, self.gt_timepoints, self.gt_graph = _ctc_lineages(
track_df, masks
)
return masks, track_df
def _correct_gt_with_st(
self, folder: Path, x: np.ndarray, dtype: Optional[str] = None
):
if str(folder).endswith("_GT/TRA"):
st_path = (
tuple(folder.parents)[1]
/ folder.parent.stem.replace("_GT", "_ST")
/ "SEG"
)
if not st_path.exists():
logger.debug("No _ST folder found, skipping correction")
else:
logger.info(f"ST MASK:\t\t{st_path} for correcting with ST masks")
st_masks = self._load_tiffs(st_path, dtype)
x = np.maximum(x, st_masks)
return x
def _load_tiffs(self, folder: Path, dtype=None):
assert isinstance(self.downscale_temporal, int)
logger.debug(f"Loading tiffs from {folder} as {dtype}")
logger.debug(
f"Temporal downscaling of {folder.name} by {self.downscale_temporal}"
)
x = np.stack([
tifffile.imread(f).astype(dtype)
for f in tqdm(
sorted(folder.glob("*.tif"))[
self.start_frame : self.end_frame : self.downscale_temporal
],
leave=False,
desc=f"Loading [{self.start_frame}:{self.end_frame}]",
)
])
# T, (Z), Y, X
assert isinstance(self.downscale_spatial, int)
if self.downscale_spatial > 1 or self.downscale_temporal > 1:
# TODO make safe for label arrays
logger.debug(
f"Spatial downscaling of {folder.name} by {self.downscale_spatial}"
)
slices = (
slice(None),
*tuple(
slice(None, None, self.downscale_spatial) for _ in range(x.ndim - 1)
),
)
x = x[slices]
logger.debug(f"Loaded array of shape {x.shape} from {folder}")
return x
def _masks2properties(self, masks):
"""Turn label masks into lists of properties, sorted (ascending) by time and label id.
Args:
masks (np.ndarray): T, (Z), H, W
Returns:
labels: List of labels
ts: List of timepoints
coords: List of coordinates
"""
# Get coordinates, timepoints, and labels of detections
labels = []
ts = []
coords = []
properties_by_time = dict()
assert len(self.imgs) == len(masks)
for _t, frame in tqdm(
enumerate(masks),
# total=len(detections),
leave=False,
desc="Loading masks and properties",
):
regions = regionprops(frame)
t_labels = []
t_ts = []
t_coords = []
for _r in regions:
t_labels.append(_r.label)
t_ts.append(_t)
centroid = np.array(_r.centroid).astype(int)
t_coords.append(centroid)
properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
labels.extend(t_labels)
ts.extend(t_ts)
coords.extend(t_coords)
labels = np.array(labels, dtype=int)
ts = np.array(ts, dtype=int)
coords = np.array(coords, dtype=int)
return labels, ts, coords, properties_by_time
def _load_tracklet_links(self, folder: Path) -> pd.DataFrame:
df = pd.read_csv(
folder / "man_track.txt",
delimiter=" ",
names=["label", "t1", "t2", "parent"],
dtype=int,
)
n_dets = (df.t2 - df.t1 + 1).sum()
logger.debug(f"{folder} has {n_dets} detections")
n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum()
logger.debug(f"{folder} has {n_divs} divisions")
return df
def _build_tracklets_without_gt(self, masks):
"""Create a dataframe with tracklets from masks."""
rows = []
for t, m in enumerate(masks):
for c in np.unique(m[m > 0]):
rows.append([c, t, t, 0])
df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"])
return df
def _check_dimensions(self, x: np.ndarray):
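        """Ensure arrays have shape (T, Y, X) for ndim=2 or (T, Z, Y, X) for ndim=3;
        e.g. a (T, Y, X) input with ndim=3 is expanded to (T, 1, Y, X)."""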
if self.ndim == 2 and not x.ndim == 3:
raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
elif self.ndim == 3:
# if ndim=3 and data is two dimensional, it will be cast to 3D
if x.ndim == 3:
x = np.expand_dims(x, axis=1)
elif x.ndim == 4:
pass
else:
raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
return x
def _load(self):
# Load ground truth
logger.info("Loading ground truth")
self.gt_masks, self.gt_track_df = self._load_gt()
self.gt_masks = self._check_dimensions(self.gt_masks)
# Load images
if self.img_folder is None:
self.imgs = np.zeros_like(self.gt_masks)
else:
logger.info("Loading images")
imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
self.imgs = np.stack([
normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
])
self.imgs = self._check_dimensions(self.imgs)
if self.compress:
# prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
self.imgs = np.stack([
_compress_img_mask_preproc(im, mask, self.features)
for im, mask in zip(self.imgs, self.gt_masks)
])
assert len(self.gt_masks) == len(self.imgs)
# Load each of the detection folders and create data samples with a sliding window
windows = []
self.properties_by_time = dict()
self.det_masks = dict()
for _f in self.detection_folders:
det_folder = self.root / _f
if det_folder == self.gt_mask_folder:
det_masks = self.gt_masks
logger.info("DET MASK:\tUsing GT masks")
(
det_labels,
det_ts,
det_coords,
det_properties_by_time,
) = self._masks2properties(det_masks)
det_gt_matching = {
t: {_l: _l for _l in det_properties_by_time[t]["labels"]}
for t in range(len(det_masks))
}
else:
det_folder = self._guess_det_folder(root=self.root, suffix=_f)
if det_folder is None:
continue
logger.info(f"DET MASK:\t{det_folder}")
det_masks = self._load_tiffs(det_folder, dtype=np.int32)
det_masks = self._correct_gt_with_st(
det_folder, det_masks, dtype=np.int32
)
det_masks = self._check_dimensions(det_masks)
(
det_labels,
det_ts,
det_coords,
det_properties_by_time,
) = self._masks2properties(det_masks)
# FIXME matching can be slow for big images
# raise NotImplementedError("Matching not implemented for 3d version")
det_gt_matching = {
t: {
_d: _gt
for _gt, _d in matching(
self.gt_masks[t],
det_masks[t],
threshold=0.3,
max_distance=16,
)
}
for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
}
self.properties_by_time[_f] = det_properties_by_time
self.det_masks[_f] = det_masks
_w = self._build_windows(
det_folder,
det_masks,
det_labels,
det_ts,
det_coords,
det_gt_matching,
)
windows.extend(_w)
return windows
def _build_windows(
self,
det_folder,
det_masks,
labels,
ts,
coords,
matching,
):
"""_summary_.
Args:
det_folder (_type_): _description_
det_masks (_type_): _description_
labels (_type_): _description_
ts (_type_): _description_
coords (_type_): _description_
matching (_type_): _description_
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
window_size = self.window_size
windows = []
# Creates the data samples with a sliding window
masks = self.gt_masks
for t1, t2 in tqdm(
zip(range(0, len(masks)), range(window_size, len(masks) + 1)),
total=len(masks) - window_size + 1,
leave=False,
desc="Building windows",
):
idx = (ts >= t1) & (ts < t2)
_ts = ts[idx]
_labels = labels[idx]
_coords = coords[idx]
# Use GT
# _labels = self.gt_labels[idx]
# _ts = self.gt_timepoints[idx]
if len(_labels) == 0:
# raise ValueError(f"No detections in sample {det_folder}:{t1}")
A = np.zeros((0, 0), dtype=bool)
_coords = np.zeros((0, masks.ndim - 1), dtype=int)
else:
if len(np.unique(_ts)) == 1:
logger.debug(
"Only detections from a single timepoint in sample"
f" {det_folder}:{t1}"
)
# build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
A = _ctc_assoc_matrix(
_labels,
_ts,
self.gt_graph,
matching,
)
if self.sanity_dist:
# # Sanity check: Can the model learn the euclidian distances?
# c = coords - coords.mean(axis=0, keepdims=True)
# c /= c.std(axis=0, keepdims=True)
# A = np.einsum('id,jd',c,c)
# A = 1 / (1 + np.exp(-A))
A = np.exp(-0.01 * cdist(_coords, _coords))
w = dict(
coords=_coords,
# TODO imgs and masks are unaltered here
t1=t1,
img=self.imgs[t1:t2],
mask=det_masks[t1:t2],
assoc_matrix=A,
labels=_labels,
timepoints=_ts,
)
windows.append(w)
logger.debug(f"Built {len(windows)} track windows from {det_folder}.\n")
return windows
def __getitem__(self, n: int, return_dense=None):
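        """Return a single training sample for window `n` as a dict with keys
        `features`, `coords0`, `coords`, `assoc_matrix`, `timepoints`, `labels`
        (plus `img`/`mask` if dense outputs are requested)."""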
# if not set, use default
if self.features == "wrfeat":
return self._getitem_wrfeat(n, return_dense)
if return_dense is None:
return_dense = self.return_dense
track = self.windows[n]
coords = track["coords"]
assoc_matrix = track["assoc_matrix"]
labels = track["labels"]
img = track["img"]
mask = track["mask"]
timepoints = track["timepoints"]
min_time = track["t1"]
if isinstance(mask, _CompressedArray):
mask = mask.decompress()
if isinstance(img, _CompressedArray):
img = img.decompress()
if isinstance(assoc_matrix, _CompressedArray):
assoc_matrix = assoc_matrix.decompress()
# cropping
if self.cropper is not None:
(img2, mask2, coords2), idx = self.cropper(img, mask, coords)
cropped_timepoints = timepoints[idx]
# at least one detection in each timepoint to accept the crop
if len(np.unique(cropped_timepoints)) == self.window_size:
# at least two total detections to accept the crop
# if len(idx) >= 2:
img, mask, coords = img2, mask2, coords2
labels = labels[idx]
timepoints = timepoints[idx]
assoc_matrix = assoc_matrix[idx][:, idx]
else:
logger.debug("disable cropping as no trajectories would be left")
if self.features == "none":
if self.augmenter is not None:
coords = self.augmenter(coords)
# Empty features
features = np.zeros((len(coords), 0))
elif self.features in ("regionprops", "regionprops2"):
if self.augmenter is not None:
(img2, mask2, coords2), idx = self.augmenter(
img, mask, coords, timepoints - min_time
)
if len(idx) > 0:
img, mask, coords = img2, mask2, coords2
labels = labels[idx]
timepoints = timepoints[idx]
assoc_matrix = assoc_matrix[idx][:, idx]
mask = mask.astype(int)
else:
logger.debug(
"disable augmentation as no trajectories would be left"
)
features = tuple(
extract_features_regionprops(
m, im, labels[timepoints == i + min_time], properties=self.features
)
for i, (m, im) in enumerate(zip(mask, img))
)
features = np.concatenate(features, axis=0)
# features = np.zeros((len(coords), self.feat_dim))
elif self.features == "patch":
if self.augmenter is not None:
(img2, mask2, coords2), idx = self.augmenter(
img, mask, coords, timepoints - min_time
)
if len(idx) > 0:
img, mask, coords = img2, mask2, coords2
labels = labels[idx]
timepoints = timepoints[idx]
assoc_matrix = assoc_matrix[idx][:, idx]
mask = mask.astype(int)
else:
print("disable augmentation as no trajectories would be left")
features = tuple(
extract_features_patch(
m,
im,
coords[timepoints == min_time + i],
labels[timepoints == min_time + i],
)
for i, (m, im) in enumerate(zip(mask, img))
)
features = np.concatenate(features, axis=0)
elif self.features == "patch_regionprops":
if self.augmenter is not None:
(img2, mask2, coords2), idx = self.augmenter(
img, mask, coords, timepoints - min_time
)
if len(idx) > 0:
img, mask, coords = img2, mask2, coords2
labels = labels[idx]
timepoints = timepoints[idx]
assoc_matrix = assoc_matrix[idx][:, idx]
mask = mask.astype(int)
else:
print("disable augmentation as no trajectories would be left")
features1 = tuple(
extract_features_patch(
m,
im,
coords[timepoints == min_time + i],
labels[timepoints == min_time + i],
)
for i, (m, im) in enumerate(zip(mask, img))
)
features2 = tuple(
extract_features_regionprops(
m,
im,
labels[timepoints == i + min_time],
properties=self.features,
)
for i, (m, im) in enumerate(zip(mask, img))
)
features = tuple(
np.concatenate((f1, f2), axis=-1)
for f1, f2 in zip(features1, features2)
)
features = np.concatenate(features, axis=0)
# remove temporal offset and add timepoints to coords
relative_timepoints = timepoints - track["t1"]
coords = np.concatenate((relative_timepoints[:, None], coords), axis=-1)
if self.max_tokens and len(timepoints) > self.max_tokens:
time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
timepoints = timepoints[:n_elems]
labels = labels[:n_elems]
coords = coords[:n_elems]
features = features[:n_elems]
assoc_matrix = assoc_matrix[:n_elems, :n_elems]
logger.info(
f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
)
coords0 = torch.from_numpy(coords).float()
features = torch.from_numpy(features).float()
assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
labels = torch.from_numpy(labels).long()
timepoints = torch.from_numpy(timepoints).long()
if self.augmenter is not None:
coords = coords0.clone()
coords[:, 1:] += torch.randint(0, 256, (1, self.ndim))
else:
coords = coords0.clone()
res = dict(
features=features,
coords0=coords0,
coords=coords,
assoc_matrix=assoc_matrix,
timepoints=timepoints,
labels=labels,
)
if return_dense:
if all([x is not None for x in img]):
img = torch.from_numpy(img).float()
res["img"] = img
mask = torch.from_numpy(mask.astype(int)).long()
res["mask"] = mask
return res
# wrfeat functions...
# TODO: refactor this as a subclass or make everything a class factory. *very* hacky this way
def _setup_features_augs_wrfeat(
self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
):
# FIXME: hardcoded
feat_dim = 7 if ndim == 2 else 12
if augment == 1:
augmenter = wrfeat.WRAugmentationPipeline([
wrfeat.WRRandomFlip(p=0.5),
wrfeat.WRRandomAffine(
p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
),
# wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)),
# wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
])
elif augment == 2:
augmenter = wrfeat.WRAugmentationPipeline([
wrfeat.WRRandomFlip(p=0.5),
wrfeat.WRRandomAffine(
p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
),
wrfeat.WRRandomBrightness(p=0.8),
wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
])
elif augment == 3:
augmenter = wrfeat.WRAugmentationPipeline([
wrfeat.WRRandomFlip(p=0.5),
wrfeat.WRRandomAffine(
p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
),
wrfeat.WRRandomBrightness(p=0.8),
wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3),
wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
])
else:
augmenter = None
cropper = (
wrfeat.WRRandomCrop(
crop_size=crop_size,
ndim=ndim,
)
if crop_size is not None
else None
)
return feat_dim, augmenter, cropper
def _load_wrfeat(self):
# Load ground truth
self.gt_masks, self.gt_track_df = self._load_gt()
self.gt_masks = self._check_dimensions(self.gt_masks)
# Load images
if self.img_folder is None:
if self.gt_masks is not None:
self.imgs = np.zeros_like(self.gt_masks)
else:
raise NotImplementedError("No images and no GT masks")
else:
logger.info("Loading images")
imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
self.imgs = np.stack([
normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
])
self.imgs = self._check_dimensions(self.imgs)
if self.compress:
# prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
self.imgs = np.stack([
_compress_img_mask_preproc(im, mask, self.features)
for im, mask in zip(self.imgs, self.gt_masks)
])
assert len(self.gt_masks) == len(self.imgs)
# Load each of the detection folders and create data samples with a sliding window
windows = []
self.properties_by_time = dict()
self.det_masks = dict()
logger.info("Loading detections")
for _f in self.detection_folders:
det_folder = self.root / _f
if det_folder == self.gt_mask_folder:
det_masks = self.gt_masks
logger.info("DET MASK:\tUsing GT masks")
# identity matching
det_gt_matching = {
t: {_l: _l for _l in set(np.unique(d)) - {0}}
for t, d in enumerate(det_masks)
}
else:
det_folder = self._guess_det_folder(root=self.root, suffix=_f)
if det_folder is None:
continue
logger.info(f"DET MASK (guessed):\t{det_folder}")
det_masks = self._load_tiffs(det_folder, dtype=np.int32)
det_masks = self._correct_gt_with_st(
det_folder, det_masks, dtype=np.int32
)
det_masks = self._check_dimensions(det_masks)
# FIXME matching can be slow for big images
# raise NotImplementedError("Matching not implemented for 3d version")
det_gt_matching = {
t: {
_d: _gt
for _gt, _d in matching(
self.gt_masks[t],
det_masks[t],
threshold=0.3,
max_distance=16,
)
}
for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
}
self.det_masks[_f] = det_masks
# build features
features = joblib.Parallel(n_jobs=8)(
joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
mask=mask[None], img=img[None], t_start=t
)
for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
)
properties_by_time = dict()
for _t, _feats in enumerate(features):
properties_by_time[_t] = dict(
coords=_feats.coords, labels=_feats.labels
)
self.properties_by_time[_f] = properties_by_time
_w = self._build_windows_wrfeat(
features,
det_masks,
det_gt_matching,
)
windows.extend(_w)
return windows
def _build_windows_wrfeat(
self,
features: Sequence[wrfeat.WRFeatures],
det_masks: np.ndarray,
matching: Tuple[dict],
):
assert len(self.imgs) == len(det_masks)
window_size = self.window_size
windows = []
# Creates the data samples with a sliding window
for t1, t2 in tqdm(
zip(range(0, len(det_masks)), range(window_size, len(det_masks) + 1)),
total=len(det_masks) - window_size + 1,
leave=False,
desc="Building windows",
):
img = self.imgs[t1:t2]
mask = det_masks[t1:t2]
feat = wrfeat.WRFeatures.concat(features[t1:t2])
labels = feat.labels
timepoints = feat.timepoints
coords = feat.coords
if len(feat) == 0:
A = np.zeros((0, 0), dtype=bool)
coords = np.zeros((0, feat.ndim), dtype=int)
else:
# build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it.
A = _ctc_assoc_matrix(
labels,
timepoints,
self.gt_graph,
matching,
)
w = dict(
coords=coords,
# TODO imgs and masks are unaltered here
t1=t1,
img=img,
mask=mask,
assoc_matrix=A,
labels=labels,
timepoints=timepoints,
wrfeat=feat,
)
windows.append(w)
logger.debug(f"Built {len(windows)} track windows.\n")
return windows
def _getitem_wrfeat(self, n: int, return_dense=None):
# if not set, use default
if return_dense is None:
return_dense = self.return_dense
track = self.windows[n]
# coords = track["coords"]
assoc_matrix = track["assoc_matrix"]
labels = track["labels"]
img = track["img"]
mask = track["mask"]
timepoints = track["timepoints"]
# track["t1"]
feat = track["wrfeat"]
if return_dense and isinstance(mask, _CompressedArray):
mask = mask.decompress()
if return_dense and isinstance(img, _CompressedArray):
img = img.decompress()
if isinstance(assoc_matrix, _CompressedArray):
assoc_matrix = assoc_matrix.decompress()
# cropping
if self.cropper is not None:
# Use only if there is at least one timepoint per detection
cropped_feat, cropped_idx = self.cropper(feat)
cropped_timepoints = timepoints[cropped_idx]
if len(np.unique(cropped_timepoints)) == self.window_size:
idx = cropped_idx
feat = cropped_feat
labels = labels[idx]
timepoints = timepoints[idx]
assoc_matrix = assoc_matrix[idx][:, idx]
else:
logger.debug("Skipping cropping")
if self.augmenter is not None:
feat = self.augmenter(feat)
coords0 = np.concatenate((feat.timepoints[:, None], feat.coords), axis=-1)
coords0 = torch.from_numpy(coords0).float()
assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32))
features = torch.from_numpy(feat.features_stacked).float()
labels = torch.from_numpy(feat.labels).long()
timepoints = torch.from_numpy(feat.timepoints).long()
if self.max_tokens and len(timepoints) > self.max_tokens:
time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
timepoints = timepoints[:n_elems]
labels = labels[:n_elems]
coords0 = coords0[:n_elems]
features = features[:n_elems]
assoc_matrix = assoc_matrix[:n_elems, :n_elems]
logger.debug(
f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
)
if self.augmenter is not None:
coords = coords0.clone()
coords[:, 1:] += torch.randint(0, 512, (1, self.ndim))
else:
coords = coords0.clone()
res = dict(
features=features,
coords0=coords0,
coords=coords,
assoc_matrix=assoc_matrix,
timepoints=timepoints,
labels=labels,
)
if return_dense:
if all([x is not None for x in img]):
img = torch.from_numpy(img).float()
res["img"] = img
mask = torch.from_numpy(mask.astype(int)).long()
res["mask"] = mask
return res
def _ctc_lineages(df, masks, t1=0, t2=None):
"""From a ctc dataframe, create a digraph that contains all sublineages
between t1 and t2 (exclusive t2).
Args:
df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt)
masks: List of masks. If t1 is not 0, then the masks are assumed to be already cropped accordingly.
t1: Start timepoint
t2: End timepoint (exclusive). If None, then t2 is set to len(masks)
Returns:
labels: List of label ids extracted from the masks, ordered by timepoint.
ts: List of corresponding timepoints
graph: The digraph of the lineages between t1 and t2.
"""
if t1 > 0:
assert t2 is not None
assert t2 - t1 == len(masks)
if t2 is None:
t2 = len(masks)
graph = nx.DiGraph()
labels = []
ts = []
    # get all objects that are present in the time interval (copy to avoid chained-assignment warnings)
    df_sub = df[(df.t1 < t2) & (df.t2 >= t1)].copy()
    # correct offset
    df_sub.loc[:, "t1"] -= t1
    df_sub.loc[:, "t2"] -= t1
# all_labels = df_sub.label.unique()
# TODO speed up by precalculating unique values once
# in_masks = set(np.where(np.bincount(np.stack(masks[t1:t2]).ravel()))[0]) - {0}
# all_labels = [l for l in all_labels if l in in_masks]
all_labels = set()
for t in tqdm(
range(0, t2 - t1), desc="Building and checking lineage graph", leave=False
):
# get all entities at timepoint
obs = df_sub[(df_sub.t1 <= t) & (df_sub.t2 >= t)]
in_t = set(np.where(np.bincount(masks[t].ravel()))[0]) - {0}
all_labels.update(in_t)
for row in obs.itertuples():
            label, parent = row.label, row.parent
if label not in in_t:
continue
labels.append(label)
ts.append(t)
# add label as node if not already in graph
if not graph.has_node(label):
graph.add_node(label)
# Parents have been added in previous timepoints
if parent in all_labels:
if not graph.has_node(parent):
graph.add_node(parent)
graph.add_edge(parent, label)
labels = np.array(labels)
ts = np.array(ts)
return labels, ts, graph
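# Illustrative example (assumed toy man_track.txt rows, and masks that actually contain
# these labels at the listed frames): rows `1 0 2 0`, `2 3 5 1`, `3 3 5 1` yield a
# lineage graph with edges 1 -> 2 and 1 -> 3, i.e. tracklet 1 divides into 2 and 3.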
@njit
def _assoc(A: np.ndarray, labels: np.ndarray, family: np.ndarray):
"""For each detection, associate with all detections that are."""
for i in range(len(labels)):
for j in range(len(labels)):
A[i, j] = family[i, labels[j]]
def _ctc_assoc_matrix(detections, ts, graph, matching):
"""Create the association matrix for a list of labels and a tracklet parent -> childrend graph.
Each detection is associated with all its ancestors and descendants, but not its siblings and their offspring.
Args:
detections: list of integer labels, ordered by timepoint
ts: list of timepoints corresponding to the detections
graph: networkx DiGraph with each ground truth tracklet id (spanning n timepoints) as a single node
and parent -> children relationships as edges.
matching: for each timepoint, a dictionary that maps from detection id to gt tracklet id
"""
assert 0 not in graph
matched_gt = []
for i, (label, t) in enumerate(zip(detections, ts)):
gt_tracklet_id = matching[t].get(label, 0)
matched_gt.append(gt_tracklet_id)
matched_gt = np.array(matched_gt, dtype=int)
# Now we have the subset of gt nodes that is matched to any detection in the current window
# relabel to reduce the size of lookup matrices
# offset 0 not allowed in skimage, which makes this very annoying
relabeled_gt, fwd_map, _inv_map = relabel_sequential(matched_gt, offset=1)
# dict is faster than arraymap
fwd_map = dict(zip(fwd_map.in_values, fwd_map.out_values))
# inv_map = dict(zip(inv_map.in_values, inv_map.out_values))
    # family relationships for each detection:
    # family[i, k] is True if relabeled GT tracklet k (1-indexed) belongs to the family of detection i (0-indexed)
family = np.zeros((len(detections), len(relabeled_gt) + 1), bool)
# Connects each tracklet id with its children and parent tracklets (according to man_track.txt)
for i, (label, t) in enumerate(zip(detections, ts)):
# Get the original label corresponding to the graph
gt_tracklet_id = matching[t].get(label, None)
if gt_tracklet_id is not None:
ancestors = []
descendants = []
# This iterates recursively through the graph
for n in nx.descendants(graph, gt_tracklet_id):
if n in fwd_map:
descendants.append(fwd_map[n])
for n in nx.ancestors(graph, gt_tracklet_id):
if n in fwd_map:
ancestors.append(fwd_map[n])
family[i, np.array([fwd_map[gt_tracklet_id], *ancestors, *descendants])] = (
True
)
else:
pass
# Now we match to nothing, so even the matrix diagonal will not be filled.
# This assures that matching to 0 is always false
assert family[:, 0].sum() == 0
# Create the detection-to-detection association matrix
A = np.zeros((len(detections), len(detections)), dtype=bool)
_assoc(A, relabeled_gt, family)
return A
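# Illustrative example (assumed toy values): for a lineage 1 -> {2, 3}, with
#   detections = [1, 2, 3], ts = [0, 1, 1] and identity matching,
# _ctc_assoc_matrix returns
#   [[1, 1, 1],
#    [1, 1, 0],
#    [1, 0, 1]]
# i.e. 2 and 3 both associate with their parent 1 (and themselves) but not with each other (siblings).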
def sigmoid(x):
return 1 / (1 + np.exp(-x))
def _compress_img_mask_preproc(img, mask, features):
"""Remove certain img pixels if not needed to save memory for large datasets."""
    # for patch features, keep the pixel values around each detection and zero the rest
    if features in ("patch", "patch_regionprops"):
        # clear img pixels outside of the patch_mask around each centroid
        patch_width = 16  # TODO: hardcoded, change this if needed
coords = tuple(np.array(r.centroid).astype(int) for r in regionprops(mask))
img2 = np.zeros_like(img)
if len(coords) > 0:
coords = np.stack(coords)
coords = np.clip(coords, 0, np.array(mask.shape)[None] - 1)
patch_mask = np.zeros_like(img, dtype=bool)
patch_mask[tuple(coords.T)] = True
# retain 3*patch_width+1 around center to be safe...
patch_mask = ndi.maximum_filter(patch_mask, 3 * patch_width + 1)
img2[patch_mask] = img[patch_mask]
else:
# otherwise set img value inside masks to mean
# FIXME: change when using other intensity based regionprops
img2 = np.zeros_like(img)
for reg in regionprops(mask, intensity_image=img):
m = mask[reg.slice] == reg.label
img2[reg.slice][m] = reg.mean_intensity
return img2
def pad_tensor(x, n_max: int, dim=0, value=0):
n = x.shape[dim]
if n_max < n:
raise ValueError(f"pad_tensor: n_max={n_max} must be larger than n={n} !")
pad_shape = list(x.shape)
pad_shape[dim] = n_max - n
# pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype).to(x.device)
pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype)
return torch.cat((x, pad), dim=dim)
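# Illustrative example: pad_tensor(torch.zeros(3, 2), n_max=5, dim=0, value=-1)
# returns a (5, 2) tensor whose last two rows are filled with -1.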
def collate_sequence_padding(batch):
"""Collate function that pads all sequences to the same length."""
lens = tuple(len(x["coords"]) for x in batch)
n_max_len = max(lens)
# print(tuple(len(x["coords"]) for x in batch))
# print(tuple(len(x["features"]) for x in batch))
# print(batch[0].keys())
tuple(batch[0].keys())
normal_keys = {
"coords": 0,
"features": 0,
"labels": 0, # Not needed, remove for speed.
"timepoints": -1, # There are real timepoints with t=0. -1 for distinction from that.
}
n_pads = tuple(n_max_len - s for s in lens)
batch_new = dict(
(
k,
torch.stack(
[pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0
),
)
for k, v in normal_keys.items()
)
batch_new["assoc_matrix"] = torch.stack(
[
pad_tensor(
pad_tensor(x["assoc_matrix"], n_max_len, dim=0), n_max_len, dim=1
)
for x in batch
],
dim=0,
)
# add boolean mask that signifies whether tokens are padded or not (such that they can be ignored later)
pad_mask = torch.zeros((len(batch), n_max_len), dtype=torch.bool)
for i, n_pad in enumerate(n_pads):
pad_mask[i, n_max_len - n_pad :] = True
batch_new["padding_mask"] = pad_mask.bool()
return batch_new
if __name__ == "__main__":
dummy_data = CTCData(
root="../../scripts/data/synthetic_cells/01",
ndim=2,
detection_folders=["TRA"],
window_size=4,
max_tokens=None,
augment=3,
features="none",
downscale_temporal=1,
downscale_spatial=1,
sanity_dist=False,
crop_size=(256, 256),
)
x = dummy_data[0]
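    # Hedged usage sketch: batch a few windows with the padding collate function above
    # (batch size and shuffle are arbitrary illustrative choices).
    loader = torch.utils.data.DataLoader(
        dummy_data, batch_size=2, shuffle=True, collate_fn=collate_sequence_padding
    )
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})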