phoebehxf
init
aff3c6f
import logging
import numpy as np
import pandas as pd
from skimage.measure import label, regionprops_table
logger = logging.getLogger(__name__)
# from https://github.com/Janelia-Trackathon-2023/traccuracy/blob/main/src/traccuracy/loaders/_ctc.py
def _check_ctc(tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray):
    """Sanity checks for valid CTC format.

    Hard checks (throws exception):
    - Tracklet IDs in tracks file must be unique and positive
    - Parent tracklet IDs must exist in the tracks file
    - Intertracklet edges must be directed forward in time.
    - In each time point, the set of segmentation IDs present in the detections must equal
      the set of tracklet IDs in the tracks file that overlap this time point. Every frame
      from the earliest track start to the latest track end (inclusive) is checked, as well
      as any frame that contains detections.

    Soft checks (logs a warning):
    - No duplicate tracklet IDs (non-connected pixels with same ID) in a single timepoint.

    Args:
        tracks (pd.DataFrame): Tracks in CTC format with columns Cell_ID, Start, End,
            Parent_ID (columns are renamed positionally). Start and End are treated as
            inclusive frame indices.
        detections (pd.DataFrame): Detections extracted from masks, containing columns
            segmentation_id, t.
        masks (np.ndarray): Set of masks with time in the first axis.

    Raises:
        ValueError: If any of the hard checks fail.
    """
    logger.debug("Running CTC format checks")
    tracks = tracks.copy()
    tracks.columns = ["Cell_ID", "Start", "End", "Parent_ID"]

    if tracks["Cell_ID"].min() < 1:
        raise ValueError("Cell_IDs in tracks file must be positive integers.")
    # Duplicated IDs would make the parent lookup below ambiguous.
    if not tracks["Cell_ID"].is_unique:
        raise ValueError("Cell_IDs in tracks file must be unique integers.")

    # IDs are unique here, so this mapping is well-defined. It avoids scanning the whole
    # tracks table once per row.
    end_by_id = dict(zip(tracks["Cell_ID"], tracks["End"]))
    for _, row in tracks.iterrows():
        if row["Parent_ID"] != 0:
            if row["Parent_ID"] not in end_by_id:
                raise ValueError(
                    f"Parent_ID {row['Parent_ID']} is not present in tracks."
                )
            parent_end = end_by_id[row["Parent_ID"]]
            # A daughter must start strictly after its parent's last frame.
            if parent_end >= row["Start"]:
                raise ValueError(
                    f"Invalid tracklet connection: Daughter tracklet with ID {row['Cell_ID']} "
                    f"starts at t={row['Start']}, "
                    f"but parent tracklet with ID {row['Parent_ID']} only ends at t={parent_end}."
                )

    # Group detection IDs by frame once instead of re-filtering the table per frame.
    det_ids_by_t = {
        t: set(ids) for t, ids in detections.groupby("t")["segmentation_id"]
    }
    if tracks.empty:
        track_frames = set()
    else:
        # End is inclusive (see the overlap filter below), hence the +1.
        track_frames = set(
            range(int(tracks["Start"].min()), int(tracks["End"].max()) + 1)
        )
    # Frames that only contain detections must be checked too, otherwise detections
    # outside the tracked time span would go unnoticed.
    for t in sorted(track_frames | set(det_ids_by_t)):
        track_ids = set(
            tracks[(tracks["Start"] <= t) & (tracks["End"] >= t)]["Cell_ID"]
        )
        det_ids = det_ids_by_t.get(t, set())
        if not track_ids.issubset(det_ids):
            raise ValueError(f"Missing IDs in masks at t={t}: {track_ids - det_ids}")
        if not det_ids.issubset(track_ids):
            raise ValueError(
                f"IDs {det_ids - track_ids} at t={t} not represented in tracks file."
            )

    # Soft check: more connected components than detections in a frame means some ID
    # is split across several non-connected regions.
    n_dets_per_t = detections["t"].value_counts()
    for t, frame in enumerate(masks):
        _, n_components = label(frame, return_num=True)
        n_labels = int(n_dets_per_t.get(t, 0))
        if n_labels < n_components:
            logger.warning(
                "%d non-connected masks at t=%d.", n_components - n_labels, t
            )
def _get_node_attributes(masks):
    """Build a per-detection table (ID, centroid coordinates, frame) for a whole movie.

    Each frame ``masks[i]`` is processed independently and the per-frame tables are
    stacked. Centroid axes are renamed by position: centroid-0 -> ``x``,
    centroid-1 -> ``y``, centroid-2 -> ``z`` (the latter only present for 3D frames),
    so ``x`` refers to the first spatial axis of each frame.

    Args:
        masks (np.ndarray): Set of masks with time in the first axis.

    Returns:
        pd.DataFrame: One detection per row, with columns segmentation_id, x, y, (z), t.
    """
    per_frame_tables = [
        _detections_from_image(masks, frame_idx)
        for frame_idx in range(masks.shape[0])
    ]
    axis_names = {
        "label": "segmentation_id",
        "centroid-0": "x",
        "centroid-1": "y",
        "centroid-2": "z",
    }
    nodes = (
        pd.concat(per_frame_tables)
        .reset_index(drop=True)
        .rename(columns=axis_names)
    )
    for int_column in ("segmentation_id", "t"):
        nodes[int_column] = nodes[int_column].astype(int)
    return nodes
def _detections_from_image(stack, idx):
    """Tabulate label, centroid and frame index for every labelled region in one frame.

    Args:
        stack (np.ndarray): Stack of masks, indexed along the first axis.
        idx (int): Index of the frame to analyse.

    Returns:
        pd.DataFrame: One row per region with columns label, centroid-<axis>..., t,
        where t is ``idx`` for every row.
    """
    frame = np.asarray(stack[idx, ...])
    table = regionprops_table(frame, properties=("label", "centroid"))
    region_count_shape = table["label"].shape
    table["t"] = np.full(region_count_shape, idx)
    return pd.DataFrame(table)