import logging
from pathlib import Path
from timeit import default_timer
from typing import Literal, Optional, Sequence, Tuple

import joblib
import lz4.frame
import networkx as nx
import numpy as np
import pandas as pd
import tifffile
import torch
from numba import njit
from scipy import ndimage as ndi
from scipy.spatial.distance import cdist
from skimage.measure import regionprops
from skimage.segmentation import relabel_sequential
from torch.utils.data import Dataset
from tqdm import tqdm

from . import wrfeat
from ._check_ctc import _check_ctc, _get_node_attributes
from .augmentations import (
    AugmentationPipeline,
    RandomCrop,
    default_augmenter,
)
from .features import (
    _PROPERTIES,
    extract_features_patch,
    extract_features_regionprops,
)
from .matching import matching
from ..utils import blockwise_sum, normalize

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

def _filter_track_df(df, start_frame, end_frame, downscale):
    """Only keep tracklets that are present in the given time interval."""
    # only retain cells in the interval
    df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)].copy()
    # shift start and end of each cell
    df.t1 = df.t1 - start_frame
    df.t2 = df.t2 - start_frame
    # clip start/end to the interval
    df.t1 = df.t1.clip(0, end_frame - start_frame - 1)
    df.t2 = df.t2.clip(0, end_frame - start_frame - 1)
    # set all parents to 0 that are not in the interval
    df.loc[~df.parent.isin(df.label), "parent"] = 0

    if downscale > 1:
        if start_frame % downscale != 0:
            raise ValueError("start_frame must be a multiple of downscale")
        logger.info(f"Temporal downscaling of tracklet links by {downscale}")
        # remove tracklets that are fully deleted by temporal downsampling,
        # i.e. that do not touch any of the retained frames
        mask = (
            (df["t1"] % downscale != 0)
            & (df["t2"] % downscale != 0)
            & (df["t1"] // downscale == df["t2"] // downscale)
        )
        logger.info(
            f"Remove {mask.sum()} tracklets that are fully deleted by downsampling"
        )
        logger.debug(f"Remove {df[mask]}")
        df = df[~mask]
        # set parent to 0 if it has been deleted
        df.loc[~df.parent.isin(df.label), "parent"] = 0
        df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int)
        df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int)
        # sanity check, covers the edge case of single-frame tracklets
        assert np.all(df["t1"] <= df["t2"])
    return df
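
# Illustrative sketch of the filtering (toy values, using the CTC
# ``label t1 t2 parent`` column convention of this module):
#
#   df = pd.DataFrame(
#       [[1, 0, 5, 0], [2, 6, 9, 1], [3, 6, 9, 1]],
#       columns=["label", "t1", "t2", "parent"],
#   )
#   # keep frames [4, 8): tracklet 1 becomes t1=0, t2=1;
#   # tracklets 2 and 3 become t1=2, t2=3 and keep parent 1
#   df_sub = _filter_track_df(df, start_frame=4, end_frame=8, downscale=1)
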
class _CompressedArray:
    """A simple class to compress and decompress a numpy array using lz4."""

    # TODO: don't compress float dtypes

    def __init__(self, data):
        self._data = lz4.frame.compress(data)
        self._dtype = data.dtype.type
        self._shape = data.shape

    def decompress(self):
        s = lz4.frame.decompress(self._data)
        data = np.frombuffer(s, dtype=self._dtype).reshape(self._shape)
        return data
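
# Minimal round-trip sketch for _CompressedArray (illustrative only):
#
#   arr = np.arange(12, dtype=np.int32).reshape(3, 4)
#   carr = _CompressedArray(arr)
#   assert np.array_equal(carr.decompress(), arr)
#
# Note that ``decompress`` returns a read-only array (``np.frombuffer``);
# call ``.copy()`` on the result before writing to it.
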
def debug_function(f):
    """Decorator that logs exceptions and the number of coords in the returned batch."""

    def wrapper(*args, **kwargs):
        try:
            batch = f(*args, **kwargs)
        except Exception as e:
            logger.error(f"Error in {f.__name__}: {e}")
            return None
        logger.info(f"{f.__name__}: batch with {len(batch['coords'])} coords")
        return batch

    return wrapper
class CTCData(Dataset):
    def __init__(
        self,
        root: str = "",
        ndim: int = 2,
        use_gt: bool = True,
        detection_folders: Sequence[str] = ("TRA",),
        window_size: int = 10,
        max_tokens: Optional[int] = None,
        slice_pct: tuple = (0.0, 1.0),
        downscale_spatial: int = 1,
        downscale_temporal: int = 1,
        augment: int = 0,
        features: Literal[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ] = "wrfeat",
        sanity_dist: bool = False,
        crop_size: Optional[tuple] = None,
        return_dense: bool = False,
        compress: bool = False,
        **kwargs,
    ) -> None:
        """Dataset of sliding track windows over a Cell Tracking Challenge (CTC) folder.

        Args:
            root (str):
                Folder containing the CTC TRA folder.
            ndim (int):
                Number of dimensions of the data. Defaults to 2
                (if ndim=3 and data is two dimensional, it will be cast to 3D).
            use_gt (bool):
                Use the ground truth tracklet links. If False, dummy
                single-frame tracklets are built from the detections.
            detection_folders:
                List of relative paths to folders with detections.
                Defaults to ("TRA",), which uses the ground truth detections.
            window_size (int):
                Window size for the transformer.
            max_tokens (int):
                If set, clip each window to at most this many detections
                (at a timepoint boundary).
            slice_pct (tuple):
                Slice the dataset by percentages (from, to).
            downscale_spatial (int):
                Spatial downscaling factor.
            downscale_temporal (int):
                Temporal downscaling factor.
            augment (int):
                If 0, no data augmentation. If > 0, defines the level of data augmentation.
            features (str):
                Type of features to use.
            sanity_dist (bool):
                Use Euclidean distance instead of the association matrix as a target.
            crop_size (tuple):
                Size of the crops to use for augmentation. If None, no cropping is used.
            return_dense (bool):
                Return dense masks and images in the data samples.
            compress (bool):
                Compress elements/remove img if not needed, to save memory for large datasets.
        """
        super().__init__()
        self.root = Path(root)
        self.name = self.root.name
        self.use_gt = use_gt
        self.slice_pct = slice_pct
        if not 0 <= slice_pct[0] < slice_pct[1] <= 1:
            raise ValueError(f"Invalid slice_pct {slice_pct}")
        self.downscale_spatial = downscale_spatial
        self.downscale_temporal = downscale_temporal
        self.detection_folders = list(detection_folders)
        self.ndim = ndim
        self.features = features
        self.augment = augment
        self.crop_size = crop_size
        if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]:
            raise ValueError(
                f"'{features}' not one of the supported {ndim}D features"
                f" {tuple(_PROPERTIES[ndim].keys())}"
            )

        logger.info(f"ROOT (config): \t{self.root}")
        self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root)
        logger.info(f"ROOT (guessed): \t{self.root}")
        logger.info(f"GT TRA (guessed):\t{self.gt_tra_folder}")
        if self.use_gt:
            self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder)
        else:
            logger.info("Using dummy masks as GT")
            self.gt_mask_folder = self._guess_det_folder(
                self.root, self.detection_folders[0]
            )
        logger.info(f"GT MASK (guessed):\t{self.gt_mask_folder}")

        # don't load image data if not needed
        if features in ("none",):
            self.img_folder = None
        else:
            self.img_folder = self._guess_img_folder(self.root)
        logger.info(f"IMG (guessed):\t{self.img_folder}")

        self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
            ndim, features, augment, crop_size
        )

        if window_size <= 1:
            raise ValueError("window_size must be > 1")
        self.window_size = window_size
        self.max_tokens = max_tokens
        self.sanity_dist = sanity_dist
        self.return_dense = return_dense
        self.compress = compress
        self.start_frame = 0
        self.end_frame = None

        start = default_timer()
        if self.features == "wrfeat":
            self.windows = self._load_wrfeat()
        else:
            self.windows = self._load()
        self.n_divs = self._get_ndivs(self.windows)

        if len(self.windows) > 0:
            self.ndim = self.windows[0]["coords"].shape[1]
            self.n_objects = tuple(len(t["coords"]) for t in self.windows)
            logger.info(
                f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
                f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
            )
        else:
            self.n_objects = 0
            logger.warning(f"Could not load any tracks from {self.root}")

        if self.compress:
            self._compress_data()
    @classmethod
    def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict):
        # TODO: imgs and masks are currently unused; all data is loaded from
        # train_args["root"] by the regular constructor.
        self = cls(**train_args)

        self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs(
            self.ndim, self.features, self.augment, self.crop_size
        )

        start = default_timer()
        if self.features == "wrfeat":
            self.windows = self._load_wrfeat()
        else:
            self.windows = self._load()
        self.n_divs = self._get_ndivs(self.windows)

        if len(self.windows) > 0:
            self.ndim = self.windows[0]["coords"].shape[1]
            self.n_objects = tuple(len(t["coords"]) for t in self.windows)
            logger.info(
                f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track"
                f" windows from {self.root} ({default_timer() - start:.1f}s)\n"
            )
        else:
            self.n_objects = 0
            logger.warning(f"Could not load any tracks from {self.root}")

        if self.compress:
            self._compress_data()
        return self
    def _get_ndivs(self, windows):
        n_divs = []
        for w in tqdm(windows, desc="Counting divisions", leave=False):
            # a division shows up as an object that is associated with
            # two detections within a single timepoint block
            _n = (
                (
                    blockwise_sum(
                        torch.from_numpy(w["assoc_matrix"]).float(),
                        torch.from_numpy(w["timepoints"]).long(),
                    ).max(dim=0)[0]
                    == 2
                )
                .sum()
                .item()
            )
            n_divs.append(_n)
        return n_divs
    def _setup_features_augs(
        self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
    ):
        if self.features == "wrfeat":
            return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)

        cropper = (
            RandomCrop(
                crop_size=crop_size,
                ndim=ndim,
                use_padding=False,
                ensure_inside_points=True,
            )
            if crop_size is not None
            else None
        )

        # Hack: no image features -> only coordinate augmentation
        if self.features == "none":
            return 0, default_augmenter, cropper

        if ndim == 2:
            augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
            feat_dim = {
                "none": 0,
                "regionprops": 7,
                "regionprops2": 6,
                "patch": 256,
                "patch_regionprops": 256 + 5,
            }[features]
        elif ndim == 3:
            augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
            feat_dim = {
                "none": 0,
                "regionprops2": 11,
                "patch_regionprops": 256 + 8,
            }[features]
        return feat_dim, augmenter, cropper
    def _compress_data(self):
        # compress masks, images, and assoc_matrix
        logger.info("Compressing masks and assoc_matrix to save memory")
        for w in self.windows:
            w["mask"] = _CompressedArray(w["mask"])
            # imgs have already been reduced to the masked/patch regions by
            # _compress_img_mask_preproc, so they compress well
            w["img"] = _CompressedArray(w["img"])
            w["assoc_matrix"] = _CompressedArray(w["assoc_matrix"])
        self.gt_masks = _CompressedArray(self.gt_masks)
        self.det_masks = {k: _CompressedArray(v) for k, v in self.det_masks.items()}
        self.imgs = _CompressedArray(self.imgs)
    def _guess_root_and_gt_tra_folder(self, inp: Path):
        """Guesses the root and the ground truth TRA folder from a given input path.

        Args:
            inp (Path): Dataset folder, its TRA folder, or an ERR_SEG folder.

        Returns:
            Tuple[Path, Path]: root folder, ground truth TRA folder.
        """
        if inp.name == "TRA":
            # 01_GT/TRA --> 01, 01_GT/TRA
            root = inp.parent.parent / inp.parent.name.split("_")[0]
            return root, inp
        elif "ERR_SEG" in inp.name:
            # 01_ERR_SEG --> 01, 01_GT/TRA. We know that the data is in CTC folder format
            num = inp.name.split("_")[0]
            return inp.parent / num, inp.parent / f"{num}_GT" / "TRA"
        else:
            # 01 --> 01, 01_GT/TRA or 01/TRA
            ctc_tra = Path(f"{inp}_GT") / "TRA"
            tra = ctc_tra if ctc_tra.exists() else inp / "TRA"
            return inp, tra
    def _guess_img_folder(self, root: Path):
        """Guesses the image folder corresponding to a root."""
        if (root / "img").exists():
            return root / "img"
        else:
            return root
    def _guess_mask_folder(self, root: Path, gt_tra: Path):
        """Guesses the mask folder corresponding to a root.

        In CTC format, we use silver truth segmentation masks.
        """
        f = None
        # first try CTC format
        if gt_tra.parent.name.endswith("_GT"):
            # We use the silver truth segmentation masks
            f = root / str(gt_tra.parent.name).replace("_GT", "_ST") / "SEG"
        # fall back to the TRA folder itself
        if f is None or not f.exists():
            f = gt_tra
        if not f.exists():
            raise ValueError(f"Could not find mask folder for {root}")
        return f
    def _guess_det_folder(self, root: Path, suffix: str):
        """Checks for the annoying CTC format with dataset numbering as part of folder names."""
        guesses = (
            (root / suffix),
            Path(f"{root}_{suffix}"),
            Path(f"{root}_GT") / suffix,
        )
        for path in guesses:
            if path.exists():
                return path
        logger.warning(f"Skipping non-existing detection folder {root / suffix}")
        return None
    def __len__(self):
        return len(self.windows)
    def _load_gt(self):
        logger.info("Loading ground truth")
        n_frames = len(list(self.gt_mask_folder.glob("*.tif")))
        self.start_frame = int(n_frames * self.slice_pct[0])
        self.end_frame = int(n_frames * self.slice_pct[1])
        masks = self._load_tiffs(self.gt_mask_folder, dtype=np.int32)
        masks = self._correct_gt_with_st(self.gt_mask_folder, masks, dtype=np.int32)
        if self.use_gt:
            track_df = self._load_tracklet_links(self.gt_tra_folder)
            track_df = _filter_track_df(
                track_df, self.start_frame, self.end_frame, self.downscale_temporal
            )
        else:
            # create dummy track dataframe
            logger.info("Using dummy track dataframe")
            track_df = self._build_tracklets_without_gt(masks)
        _check_ctc(track_df, _get_node_attributes(masks), masks)
        # Build ground truth lineage graph
        self.gt_labels, self.gt_timepoints, self.gt_graph = _ctc_lineages(
            track_df, masks
        )
        return masks, track_df
    def _correct_gt_with_st(self, folder: Path, x: np.ndarray, dtype=None):
        """If available, complement the TRA markers with the silver truth (_ST) masks."""
        if str(folder).endswith("_GT/TRA"):
            st_path = (
                tuple(folder.parents)[1]
                / folder.parent.stem.replace("_GT", "_ST")
                / "SEG"
            )
            if not st_path.exists():
                logger.debug("No _ST folder found, skipping correction")
            else:
                logger.info(f"ST MASK:\t\t{st_path} for correcting with ST masks")
                st_masks = self._load_tiffs(st_path, dtype)
                x = np.maximum(x, st_masks)
        return x
    def _load_tiffs(self, folder: Path, dtype=None):
        assert isinstance(self.downscale_temporal, int)
        logger.debug(f"Loading tiffs from {folder} as {dtype}")
        logger.debug(
            f"Temporal downscaling of {folder.name} by {self.downscale_temporal}"
        )
        x = np.stack([
            tifffile.imread(f).astype(dtype)
            for f in tqdm(
                sorted(folder.glob("*.tif"))[
                    self.start_frame : self.end_frame : self.downscale_temporal
                ],
                leave=False,
                desc=f"Loading [{self.start_frame}:{self.end_frame}]",
            )
        ])
        # T, (Z), Y, X
        assert isinstance(self.downscale_spatial, int)
        if self.downscale_spatial > 1:
            # TODO make safe for label arrays
            logger.debug(
                f"Spatial downscaling of {folder.name} by {self.downscale_spatial}"
            )
            slices = (
                slice(None),
                *tuple(
                    slice(None, None, self.downscale_spatial) for _ in range(x.ndim - 1)
                ),
            )
            x = x[slices]
        logger.debug(f"Loaded array of shape {x.shape} from {folder}")
        return x
    def _masks2properties(self, masks):
        """Turn label masks into lists of properties, sorted (ascending) by time and label id.

        Args:
            masks (np.ndarray): T, (Z), H, W

        Returns:
            labels: List of labels
            ts: List of timepoints
            coords: List of coordinates
            properties_by_time: Dict mapping each timepoint to its coords and labels
        """
        # Get coordinates, timepoints, and labels of detections
        labels = []
        ts = []
        coords = []
        properties_by_time = dict()
        assert len(self.imgs) == len(masks)
        for _t, frame in tqdm(
            enumerate(masks),
            total=len(masks),
            leave=False,
            desc="Loading masks and properties",
        ):
            regions = regionprops(frame)
            t_labels = []
            t_ts = []
            t_coords = []
            for _r in regions:
                t_labels.append(_r.label)
                t_ts.append(_t)
                centroid = np.array(_r.centroid).astype(int)
                t_coords.append(centroid)
            properties_by_time[_t] = dict(coords=t_coords, labels=t_labels)
            labels.extend(t_labels)
            ts.extend(t_ts)
            coords.extend(t_coords)
        labels = np.array(labels, dtype=int)
        ts = np.array(ts, dtype=int)
        coords = np.array(coords, dtype=int)
        return labels, ts, coords, properties_by_time
    def _load_tracklet_links(self, folder: Path) -> pd.DataFrame:
        df = pd.read_csv(
            folder / "man_track.txt",
            delimiter=" ",
            names=["label", "t1", "t2", "parent"],
            dtype=int,
        )
        n_dets = (df.t2 - df.t1 + 1).sum()
        logger.debug(f"{folder} has {n_dets} detections")
        n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum()
        logger.debug(f"{folder} has {n_divs} divisions")
        return df
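
    # ``man_track.txt`` is whitespace-separated with one tracklet per row
    # (CTC convention, matching the column names above), e.g.:
    #
    #   1 0 10 0    -> tracklet 1 spans frames 0..10, no parent
    #   2 11 20 1   -> tracklet 2 spans frames 11..20, child of tracklet 1
    #   3 11 18 1   -> tracklet 3 is the sibling of tracklet 2
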
    def _build_tracklets_without_gt(self, masks):
        """Create a dataframe of single-frame tracklets from masks."""
        rows = []
        for t, m in enumerate(masks):
            for c in np.unique(m[m > 0]):
                rows.append([c, t, t, 0])
        df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"])
        return df
    def _check_dimensions(self, x: np.ndarray):
        if self.ndim == 2 and not x.ndim == 3:
            raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
        elif self.ndim == 3:
            # if ndim=3 and data is two dimensional, it will be cast to 3D
            if x.ndim == 3:
                x = np.expand_dims(x, axis=1)
            elif x.ndim == 4:
                pass
            else:
                raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
        return x
    def _load(self):
        # Load ground truth
        logger.info("Loading ground truth")
        self.gt_masks, self.gt_track_df = self._load_gt()
        self.gt_masks = self._check_dimensions(self.gt_masks)

        # Load images
        if self.img_folder is None:
            self.imgs = np.zeros_like(self.gt_masks)
        else:
            logger.info("Loading images")
            imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
            self.imgs = np.stack([
                normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
            ])
        self.imgs = self._check_dimensions(self.imgs)
        if self.compress:
            # prepare images to be compressed later
            # (e.g. removing non-masked parts for regionprops features)
            self.imgs = np.stack([
                _compress_img_mask_preproc(im, mask, self.features)
                for im, mask in zip(self.imgs, self.gt_masks)
            ])
        assert len(self.gt_masks) == len(self.imgs)

        # Load each of the detection folders and create data samples with a sliding window
        windows = []
        self.properties_by_time = dict()
        self.det_masks = dict()
        for _f in self.detection_folders:
            det_folder = self.root / _f
            if det_folder == self.gt_mask_folder:
                det_masks = self.gt_masks
                logger.info("DET MASK:\tUsing GT masks")
                (
                    det_labels,
                    det_ts,
                    det_coords,
                    det_properties_by_time,
                ) = self._masks2properties(det_masks)
                # identity matching between detections and GT
                det_gt_matching = {
                    t: {_l: _l for _l in det_properties_by_time[t]["labels"]}
                    for t in range(len(det_masks))
                }
            else:
                det_folder = self._guess_det_folder(root=self.root, suffix=_f)
                if det_folder is None:
                    continue
                logger.info(f"DET MASK:\t{det_folder}")
                det_masks = self._load_tiffs(det_folder, dtype=np.int32)
                det_masks = self._correct_gt_with_st(
                    det_folder, det_masks, dtype=np.int32
                )
                det_masks = self._check_dimensions(det_masks)
                (
                    det_labels,
                    det_ts,
                    det_coords,
                    det_properties_by_time,
                ) = self._masks2properties(det_masks)
                # FIXME matching can be slow for big images
                det_gt_matching = {
                    t: {
                        _d: _gt
                        for _gt, _d in matching(
                            self.gt_masks[t],
                            det_masks[t],
                            threshold=0.3,
                            max_distance=16,
                        )
                    }
                    for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
                }
            self.properties_by_time[_f] = det_properties_by_time
            self.det_masks[_f] = det_masks
            _w = self._build_windows(
                det_folder,
                det_masks,
                det_labels,
                det_ts,
                det_coords,
                det_gt_matching,
            )
            windows.extend(_w)
        return windows
    def _build_windows(
        self,
        det_folder,
        det_masks,
        labels,
        ts,
        coords,
        matching,
    ):
        """Create data samples from detections with a sliding window.

        Args:
            det_folder (Path): folder the detections were loaded from (for logging).
            det_masks (np.ndarray): detection masks of shape T, (Z), H, W.
            labels (np.ndarray): detection labels, sorted by time.
            ts (np.ndarray): corresponding timepoints.
            coords (np.ndarray): corresponding centroid coordinates.
            matching (dict): per-timepoint mapping from detection id to GT tracklet id.

        Returns:
            list[dict]: one dict per sliding window.
        """
        window_size = self.window_size
        windows = []
        # Creates the data samples with a sliding window
        masks = self.gt_masks
        for t1, t2 in tqdm(
            zip(range(0, len(masks)), range(window_size, len(masks) + 1)),
            total=len(masks) - window_size + 1,
            leave=False,
            desc="Building windows",
        ):
            idx = (ts >= t1) & (ts < t2)
            _ts = ts[idx]
            _labels = labels[idx]
            _coords = coords[idx]
            if len(_labels) == 0:
                A = np.zeros((0, 0), dtype=bool)
                _coords = np.zeros((0, masks.ndim - 1), dtype=int)
            else:
                if len(np.unique(_ts)) == 1:
                    logger.debug(
                        "Only detections from a single timepoint in sample"
                        f" {det_folder}:{t1}"
                    )
                # build the matrix from possibly incomplete labels, but the full
                # lineage graph; missing labels are skipped
                A = _ctc_assoc_matrix(
                    _labels,
                    _ts,
                    self.gt_graph,
                    matching,
                )
                if self.sanity_dist:
                    # Sanity check: can the model learn plain Euclidean distances?
                    A = np.exp(-0.01 * cdist(_coords, _coords))
            w = dict(
                coords=_coords,
                # TODO imgs and masks are unaltered here
                t1=t1,
                img=self.imgs[t1:t2],
                mask=det_masks[t1:t2],
                assoc_matrix=A,
                labels=_labels,
                timepoints=_ts,
            )
            windows.append(w)
        logger.debug(f"Built {len(windows)} track windows from {det_folder}.\n")
        return windows
    def __getitem__(self, n: int, return_dense=None):
        if self.features == "wrfeat":
            return self._getitem_wrfeat(n, return_dense)
        # if not set, use the default
        if return_dense is None:
            return_dense = self.return_dense

        track = self.windows[n]
        coords = track["coords"]
        assoc_matrix = track["assoc_matrix"]
        labels = track["labels"]
        img = track["img"]
        mask = track["mask"]
        timepoints = track["timepoints"]
        min_time = track["t1"]

        if isinstance(mask, _CompressedArray):
            mask = mask.decompress()
        if isinstance(img, _CompressedArray):
            img = img.decompress()
        if isinstance(assoc_matrix, _CompressedArray):
            assoc_matrix = assoc_matrix.decompress()

        # cropping
        if self.cropper is not None:
            (img2, mask2, coords2), idx = self.cropper(img, mask, coords)
            cropped_timepoints = timepoints[idx]
            # at least one detection in each timepoint to accept the crop
            if len(np.unique(cropped_timepoints)) == self.window_size:
                img, mask, coords = img2, mask2, coords2
                labels = labels[idx]
                timepoints = timepoints[idx]
                assoc_matrix = assoc_matrix[idx][:, idx]
            else:
                logger.debug("disable cropping as no trajectories would be left")

        if self.features == "none":
            if self.augmenter is not None:
                coords = self.augmenter(coords)
            # Empty features
            features = np.zeros((len(coords), 0))
        elif self.features in ("regionprops", "regionprops2"):
            if self.augmenter is not None:
                (img2, mask2, coords2), idx = self.augmenter(
                    img, mask, coords, timepoints - min_time
                )
                if len(idx) > 0:
                    img, mask, coords = img2, mask2, coords2
                    labels = labels[idx]
                    timepoints = timepoints[idx]
                    assoc_matrix = assoc_matrix[idx][:, idx]
                    mask = mask.astype(int)
                else:
                    logger.debug(
                        "disable augmentation as no trajectories would be left"
                    )
            features = tuple(
                extract_features_regionprops(
                    m, im, labels[timepoints == i + min_time], properties=self.features
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features = np.concatenate(features, axis=0)
        elif self.features == "patch":
            if self.augmenter is not None:
                (img2, mask2, coords2), idx = self.augmenter(
                    img, mask, coords, timepoints - min_time
                )
                if len(idx) > 0:
                    img, mask, coords = img2, mask2, coords2
                    labels = labels[idx]
                    timepoints = timepoints[idx]
                    assoc_matrix = assoc_matrix[idx][:, idx]
                    mask = mask.astype(int)
                else:
                    logger.debug(
                        "disable augmentation as no trajectories would be left"
                    )
            features = tuple(
                extract_features_patch(
                    m,
                    im,
                    coords[timepoints == min_time + i],
                    labels[timepoints == min_time + i],
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features = np.concatenate(features, axis=0)
        elif self.features == "patch_regionprops":
            if self.augmenter is not None:
                (img2, mask2, coords2), idx = self.augmenter(
                    img, mask, coords, timepoints - min_time
                )
                if len(idx) > 0:
                    img, mask, coords = img2, mask2, coords2
                    labels = labels[idx]
                    timepoints = timepoints[idx]
                    assoc_matrix = assoc_matrix[idx][:, idx]
                    mask = mask.astype(int)
                else:
                    logger.debug(
                        "disable augmentation as no trajectories would be left"
                    )
            features1 = tuple(
                extract_features_patch(
                    m,
                    im,
                    coords[timepoints == min_time + i],
                    labels[timepoints == min_time + i],
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features2 = tuple(
                extract_features_regionprops(
                    m,
                    im,
                    labels[timepoints == i + min_time],
                    properties=self.features,
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features = tuple(
                np.concatenate((f1, f2), axis=-1)
                for f1, f2 in zip(features1, features2)
            )
            features = np.concatenate(features, axis=0)

        # remove the temporal offset and prepend timepoints to coords
        relative_timepoints = timepoints - track["t1"]
        coords = np.concatenate((relative_timepoints[:, None], coords), axis=-1)

        if self.max_tokens and len(timepoints) > self.max_tokens:
            # clip at the last timepoint boundary before max_tokens
            time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
            n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
            timepoints = timepoints[:n_elems]
            labels = labels[:n_elems]
            coords = coords[:n_elems]
            features = features[:n_elems]
            assoc_matrix = assoc_matrix[:n_elems, :n_elems]
            logger.info(
                f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
            )

        coords0 = torch.from_numpy(coords).float()
        features = torch.from_numpy(features).float()
        assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
        labels = torch.from_numpy(labels).long()
        timepoints = torch.from_numpy(timepoints).long()

        if self.augmenter is not None:
            coords = coords0.clone()
            # random global spatial offset as a translation augmentation
            coords[:, 1:] += torch.randint(0, 256, (1, self.ndim))
        else:
            coords = coords0.clone()

        res = dict(
            features=features,
            coords0=coords0,
            coords=coords,
            assoc_matrix=assoc_matrix,
            timepoints=timepoints,
            labels=labels,
        )
        if return_dense:
            if all(x is not None for x in img):
                img = torch.from_numpy(img).float()
                res["img"] = img
            mask = torch.from_numpy(mask.astype(int)).long()
            res["mask"] = mask
        return res
    # wrfeat functions...
    # TODO: refactor this as a subclass or make everything a class factory. *very* hacky this way
    def _setup_features_augs_wrfeat(
        self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
    ):
        # FIXME: hardcoded feature dimensions
        feat_dim = 7 if ndim == 2 else 12
        if augment == 1:
            augmenter = wrfeat.WRAugmentationPipeline([
                wrfeat.WRRandomFlip(p=0.5),
                wrfeat.WRRandomAffine(
                    p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
                ),
            ])
        elif augment == 2:
            augmenter = wrfeat.WRAugmentationPipeline([
                wrfeat.WRRandomFlip(p=0.5),
                wrfeat.WRRandomAffine(
                    p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
                ),
                wrfeat.WRRandomBrightness(p=0.8),
                wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
            ])
        elif augment == 3:
            augmenter = wrfeat.WRAugmentationPipeline([
                wrfeat.WRRandomFlip(p=0.5),
                wrfeat.WRRandomAffine(
                    p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
                ),
                wrfeat.WRRandomBrightness(p=0.8),
                wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3),
                wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
            ])
        else:
            augmenter = None
        cropper = (
            wrfeat.WRRandomCrop(
                crop_size=crop_size,
                ndim=ndim,
            )
            if crop_size is not None
            else None
        )
        return feat_dim, augmenter, cropper
    def _load_wrfeat(self):
        # Load ground truth
        self.gt_masks, self.gt_track_df = self._load_gt()
        self.gt_masks = self._check_dimensions(self.gt_masks)

        # Load images
        if self.img_folder is None:
            if self.gt_masks is not None:
                self.imgs = np.zeros_like(self.gt_masks)
            else:
                raise NotImplementedError("No images and no GT masks")
        else:
            logger.info("Loading images")
            imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
            self.imgs = np.stack([
                normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
            ])
        self.imgs = self._check_dimensions(self.imgs)
        if self.compress:
            # prepare images to be compressed later
            # (e.g. removing non-masked parts for regionprops features)
            self.imgs = np.stack([
                _compress_img_mask_preproc(im, mask, self.features)
                for im, mask in zip(self.imgs, self.gt_masks)
            ])
        assert len(self.gt_masks) == len(self.imgs)

        # Load each of the detection folders and create data samples with a sliding window
        windows = []
        self.properties_by_time = dict()
        self.det_masks = dict()
        logger.info("Loading detections")
        for _f in self.detection_folders:
            det_folder = self.root / _f
            if det_folder == self.gt_mask_folder:
                det_masks = self.gt_masks
                logger.info("DET MASK:\tUsing GT masks")
                # identity matching
                det_gt_matching = {
                    t: {_l: _l for _l in set(np.unique(d)) - {0}}
                    for t, d in enumerate(det_masks)
                }
            else:
                det_folder = self._guess_det_folder(root=self.root, suffix=_f)
                if det_folder is None:
                    continue
                logger.info(f"DET MASK (guessed):\t{det_folder}")
                det_masks = self._load_tiffs(det_folder, dtype=np.int32)
                det_masks = self._correct_gt_with_st(
                    det_folder, det_masks, dtype=np.int32
                )
                det_masks = self._check_dimensions(det_masks)
                # FIXME matching can be slow for big images
                det_gt_matching = {
                    t: {
                        _d: _gt
                        for _gt, _d in matching(
                            self.gt_masks[t],
                            det_masks[t],
                            threshold=0.3,
                            max_distance=16,
                        )
                    }
                    for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
                }
            self.det_masks[_f] = det_masks

            # build features
            features = joblib.Parallel(n_jobs=8)(
                joblib.delayed(wrfeat.WRFeatures.from_mask_img)(
                    mask=mask[None], img=img[None], t_start=t
                )
                for t, (mask, img) in enumerate(zip(det_masks, self.imgs))
            )
            properties_by_time = dict()
            for _t, _feats in enumerate(features):
                properties_by_time[_t] = dict(
                    coords=_feats.coords, labels=_feats.labels
                )
            self.properties_by_time[_f] = properties_by_time
            _w = self._build_windows_wrfeat(
                features,
                det_masks,
                det_gt_matching,
            )
            windows.extend(_w)
        return windows
    def _build_windows_wrfeat(
        self,
        features: Sequence[wrfeat.WRFeatures],
        det_masks: np.ndarray,
        matching: dict,
    ):
        assert len(self.imgs) == len(det_masks)
        window_size = self.window_size
        windows = []
        # Creates the data samples with a sliding window
        for t1, t2 in tqdm(
            zip(range(0, len(det_masks)), range(window_size, len(det_masks) + 1)),
            total=len(det_masks) - window_size + 1,
            leave=False,
            desc="Building windows",
        ):
            img = self.imgs[t1:t2]
            mask = det_masks[t1:t2]
            feat = wrfeat.WRFeatures.concat(features[t1:t2])
            labels = feat.labels
            timepoints = feat.timepoints
            coords = feat.coords
            if len(feat) == 0:
                A = np.zeros((0, 0), dtype=bool)
                coords = np.zeros((0, feat.ndim), dtype=int)
            else:
                # build the matrix from possibly incomplete labels, but the full
                # lineage graph; missing labels are skipped
                A = _ctc_assoc_matrix(
                    labels,
                    timepoints,
                    self.gt_graph,
                    matching,
                )
            w = dict(
                coords=coords,
                # TODO imgs and masks are unaltered here
                t1=t1,
                img=img,
                mask=mask,
                assoc_matrix=A,
                labels=labels,
                timepoints=timepoints,
                wrfeat=feat,
            )
            windows.append(w)
        logger.debug(f"Built {len(windows)} track windows.\n")
        return windows
    def _getitem_wrfeat(self, n: int, return_dense=None):
        # if not set, use the default
        if return_dense is None:
            return_dense = self.return_dense

        track = self.windows[n]
        assoc_matrix = track["assoc_matrix"]
        labels = track["labels"]
        img = track["img"]
        mask = track["mask"]
        timepoints = track["timepoints"]
        feat = track["wrfeat"]

        if return_dense and isinstance(mask, _CompressedArray):
            mask = mask.decompress()
        if return_dense and isinstance(img, _CompressedArray):
            img = img.decompress()
        if isinstance(assoc_matrix, _CompressedArray):
            assoc_matrix = assoc_matrix.decompress()

        # cropping
        if self.cropper is not None:
            # accept the crop only if every timepoint retains at least one detection
            cropped_feat, cropped_idx = self.cropper(feat)
            cropped_timepoints = timepoints[cropped_idx]
            if len(np.unique(cropped_timepoints)) == self.window_size:
                idx = cropped_idx
                feat = cropped_feat
                labels = labels[idx]
                timepoints = timepoints[idx]
                assoc_matrix = assoc_matrix[idx][:, idx]
            else:
                logger.debug("Skipping cropping")

        if self.augmenter is not None:
            feat = self.augmenter(feat)

        coords0 = np.concatenate((feat.timepoints[:, None], feat.coords), axis=-1)
        coords0 = torch.from_numpy(coords0).float()
        assoc_matrix = torch.from_numpy(assoc_matrix.astype(np.float32))
        features = torch.from_numpy(feat.features_stacked).float()
        labels = torch.from_numpy(feat.labels).long()
        timepoints = torch.from_numpy(feat.timepoints).long()

        if self.max_tokens and len(timepoints) > self.max_tokens:
            # clip at the last timepoint boundary before max_tokens
            time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
            n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
            timepoints = timepoints[:n_elems]
            labels = labels[:n_elems]
            coords0 = coords0[:n_elems]
            features = features[:n_elems]
            assoc_matrix = assoc_matrix[:n_elems, :n_elems]
            logger.debug(
                f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
            )

        if self.augmenter is not None:
            coords = coords0.clone()
            # random global spatial offset as a translation augmentation
            coords[:, 1:] += torch.randint(0, 512, (1, self.ndim))
        else:
            coords = coords0.clone()

        res = dict(
            features=features,
            coords0=coords0,
            coords=coords,
            assoc_matrix=assoc_matrix,
            timepoints=timepoints,
            labels=labels,
        )
        if return_dense:
            if all(x is not None for x in img):
                img = torch.from_numpy(img).float()
                res["img"] = img
            mask = torch.from_numpy(mask.astype(int)).long()
            res["mask"] = mask
        return res
def _ctc_lineages(df, masks, t1=0, t2=None):
    """From a CTC dataframe, create a digraph that contains all sublineages
    between t1 and t2 (exclusive t2).

    Args:
        df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt)
        masks: List of masks. If t1 is not 0, then the masks are assumed to be already cropped accordingly.
        t1: Start timepoint
        t2: End timepoint (exclusive). If None, then t2 is set to len(masks)

    Returns:
        labels: List of label ids extracted from the masks, ordered by timepoint.
        ts: List of corresponding timepoints
        graph: The digraph of the lineages between t1 and t2.
    """
    if t1 > 0:
        assert t2 is not None
        assert t2 - t1 == len(masks)
    if t2 is None:
        t2 = len(masks)
    graph = nx.DiGraph()
    labels = []
    ts = []
    # get all objects that are present in the time interval
    df_sub = df[(df.t1 < t2) & (df.t2 >= t1)].copy()
    # Correct offset
    df_sub.loc[:, "t1"] -= t1
    df_sub.loc[:, "t2"] -= t1
    all_labels = set()
    for t in tqdm(
        range(0, t2 - t1), desc="Building and checking lineage graph", leave=False
    ):
        # get all entities at timepoint t
        obs = df_sub[(df_sub.t1 <= t) & (df_sub.t2 >= t)]
        in_t = set(np.where(np.bincount(masks[t].ravel()))[0]) - {0}
        all_labels.update(in_t)
        for row in obs.itertuples():
            # NB: don't unpack row.t1/row.t2 here, they would shadow the function arguments
            label, parent = row.label, row.parent
            if label not in in_t:
                continue
            labels.append(label)
            ts.append(t)
            # add label as node if not already in graph
            if not graph.has_node(label):
                graph.add_node(label)
            # Parents have been added in previous timepoints
            if parent in all_labels:
                if not graph.has_node(parent):
                    graph.add_node(parent)
                graph.add_edge(parent, label)
    labels = np.array(labels)
    ts = np.array(ts)
    return labels, ts, graph
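
# Illustrative sketch (hypothetical ``df``/``masks``): for a dataframe with
# tracklet 1 (frames 0-5) dividing into tracklets 2 and 3 (frames 6-9), and
# masks containing those labels, the returned digraph has edges 1 -> 2, 1 -> 3:
#
#   labels, ts, graph = _ctc_lineages(df, masks)
#   assert set(graph.successors(1)) == {2, 3}
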
# jit-compiled, as this is a tight double loop over all detection pairs
@njit
def _assoc(A: np.ndarray, labels: np.ndarray, family: np.ndarray):
    """For each detection i, mark every detection j whose (relabeled) gt id
    belongs to i's family (itself, its ancestors, and its descendants)."""
    for i in range(len(labels)):
        for j in range(len(labels)):
            A[i, j] = family[i, labels[j]]
def _ctc_assoc_matrix(detections, ts, graph, matching):
    """Create the association matrix for a list of labels and a tracklet parent -> children graph.

    Each detection is associated with all its ancestors and descendants, but not
    with its siblings and their offspring.

    Args:
        detections: list of integer labels, ordered by timepoint
        ts: list of timepoints corresponding to the detections
        graph: networkx DiGraph with each ground truth tracklet id (spanning n timepoints) as a single node
            and parent -> children relationships as edges.
        matching: for each timepoint, a dictionary that maps from detection id to gt tracklet id
    """
    assert 0 not in graph
    matched_gt = []
    for label, t in zip(detections, ts):
        gt_tracklet_id = matching[t].get(label, 0)
        matched_gt.append(gt_tracklet_id)
    matched_gt = np.array(matched_gt, dtype=int)

    # Now we have the subset of gt nodes that is matched to any detection in the
    # current window. Relabel to reduce the size of the lookup matrices.
    # Offset 0 is not allowed in skimage, which makes this very annoying.
    relabeled_gt, fwd_map, _inv_map = relabel_sequential(matched_gt, offset=1)
    # dict lookup is faster than skimage's ArrayMap
    fwd_map = dict(zip(fwd_map.in_values, fwd_map.out_values))

    # The family relationships for each ground truth detection.
    # Maps from local detection number (0-indexed) to global gt tracklet id (1-indexed).
    family = np.zeros((len(detections), len(relabeled_gt) + 1), bool)
    # Connects each tracklet id with its children and parent tracklets (according to man_track.txt)
    for i, (label, t) in enumerate(zip(detections, ts)):
        # Get the original label corresponding to the graph
        gt_tracklet_id = matching[t].get(label, None)
        if gt_tracklet_id is not None:
            ancestors = []
            descendants = []
            # This iterates recursively through the graph
            for n in nx.descendants(graph, gt_tracklet_id):
                if n in fwd_map:
                    descendants.append(fwd_map[n])
            for n in nx.ancestors(graph, gt_tracklet_id):
                if n in fwd_map:
                    ancestors.append(fwd_map[n])
            family[i, np.array([fwd_map[gt_tracklet_id], *ancestors, *descendants])] = (
                True
            )
        else:
            # Unmatched detections associate with nothing, so not even the matrix
            # diagonal is filled for them. This assures that matching to 0 is always false.
            pass
    assert family[:, 0].sum() == 0

    # Create the detection-to-detection association matrix
    A = np.zeros((len(detections), len(detections)), dtype=bool)
    _assoc(A, relabeled_gt, family)
    return A
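
# Illustrative sketch (hypothetical toy values): three detections in two frames,
# where GT tracklet 1 divides into tracklets 2 and 3. Detection 0 (tracklet 1,
# t=0) is associated with detections 1 and 2 (its children, t=1), but 1 and 2
# are siblings and hence not associated with each other:
#
#   toy_graph = nx.DiGraph([(1, 2), (1, 3)])
#   toy_matching = {0: {1: 1}, 1: {2: 2, 3: 3}}
#   A = _ctc_assoc_matrix([1, 2, 3], [0, 1, 1], toy_graph, toy_matching)
#   # A[0, 1] == A[0, 2] == True, A[1, 2] == False
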
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
def _compress_img_mask_preproc(img, mask, features):
    """Remove image pixels that are not needed, to save memory for large datasets."""
    if features in ("patch", "patch_regionprops"):
        # patch features need the image around each detection:
        # keep only pixels within a neighborhood of the centroids
        patch_width = 16  # TODO: hardcoded, change this if needed
        coords = tuple(np.array(r.centroid).astype(int) for r in regionprops(mask))
        img2 = np.zeros_like(img)
        if len(coords) > 0:
            coords = np.stack(coords)
            coords = np.clip(coords, 0, np.array(mask.shape)[None] - 1)
            patch_mask = np.zeros_like(img, dtype=bool)
            patch_mask[tuple(coords.T)] = True
            # retain 3*patch_width+1 around each center, to be safe
            patch_mask = ndi.maximum_filter(patch_mask, 3 * patch_width + 1)
            img2[patch_mask] = img[patch_mask]
    else:
        # otherwise set the img value inside each mask to its mean intensity
        # FIXME: change when using other intensity-based regionprops
        img2 = np.zeros_like(img)
        for reg in regionprops(mask, intensity_image=img):
            m = mask[reg.slice] == reg.label
            img2[reg.slice][m] = reg.mean_intensity
    return img2
def pad_tensor(x, n_max: int, dim=0, value=0):
    """Pad tensor x along dimension dim to length n_max with the given fill value."""
    n = x.shape[dim]
    if n_max < n:
        raise ValueError(f"pad_tensor: n_max={n_max} must be larger than n={n}!")
    pad_shape = list(x.shape)
    pad_shape[dim] = n_max - n
    pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype)
    return torch.cat((x, pad), dim=dim)
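
# Quick sketch of the padding behavior (illustrative only):
#
#   x = torch.ones(3, 2)
#   y = pad_tensor(x, n_max=5, dim=0, value=0)
#   assert y.shape == (5, 2) and y[3:].sum() == 0
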
def collate_sequence_padding(batch):
    """Collate function that pads all sequences to the same length."""
    lens = tuple(len(x["coords"]) for x in batch)
    n_max_len = max(lens)
    normal_keys = {
        "coords": 0,
        "features": 0,
        "labels": 0,  # Not needed, remove for speed.
        "timepoints": -1,  # There are real timepoints with t=0. -1 for distinction from that.
    }
    n_pads = tuple(n_max_len - s for s in lens)
    batch_new = dict(
        (
            k,
            torch.stack(
                [pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0
            ),
        )
        for k, v in normal_keys.items()
    )
    # the association matrix is padded along both dimensions
    batch_new["assoc_matrix"] = torch.stack(
        [
            pad_tensor(
                pad_tensor(x["assoc_matrix"], n_max_len, dim=0), n_max_len, dim=1
            )
            for x in batch
        ],
        dim=0,
    )
    # add a boolean mask that signifies whether tokens are padded or not
    # (such that they can be ignored later)
    pad_mask = torch.zeros((len(batch), n_max_len), dtype=torch.bool)
    for i, n_pad in enumerate(n_pads):
        pad_mask[i, n_max_len - n_pad :] = True
    batch_new["padding_mask"] = pad_mask.bool()
    return batch_new
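
# Typical usage with a DataLoader (a sketch; assumes a CTCData instance ``data``):
#
#   from torch.utils.data import DataLoader
#
#   loader = DataLoader(
#       data, batch_size=4, shuffle=True, collate_fn=collate_sequence_padding
#   )
#   batch = next(iter(loader))
#   # batch["padding_mask"][i, j] is True where token j of sample i is padding
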
if __name__ == "__main__":
    dummy_data = CTCData(
        root="../../scripts/data/synthetic_cells/01",
        ndim=2,
        detection_folders=["TRA"],
        window_size=4,
        max_tokens=None,
        augment=3,
        features="none",
        downscale_temporal=1,
        downscale_spatial=1,
        sanity_dist=False,
        crop_size=(256, 256),
    )
    x = dummy_data[0]