import logging
import os
from pathlib import Path
from typing import Literal, Union, Optional, Tuple

import dask.array as da
import numpy as np
import tifffile
import torch
import yaml
from tqdm import tqdm

from ..data import build_windows, get_features, load_tiff_timeseries
from ..tracking import TrackGraph, build_graph, track_greedy
from ..utils import normalize
from .model import TrackingTransformer
from .predict import predict_windows
from .pretrained import download_pretrained

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class Trackastra:
    """A transformer-based tracking model for time-lapse data.

    Trackastra links segmented objects across time frames by predicting
    associations with a transformer model trained on diverse time-lapse videos.

    The model takes as input:
        - A sequence of images of shape (T,(Z),Y,X)
        - Corresponding instance segmentation masks of shape (T,(Z),Y,X)

    It supports multiple tracking modes:
        - greedy_nodiv: Fast greedy linking without division
        - greedy: Fast greedy linking with division
        - ilp: Integer Linear Programming based linking (more accurate but slower)

    Examples:
        >>> # Load example data
        >>> from trackastra.data import example_data_bacteria
        >>> imgs, masks = example_data_bacteria()
        >>>
        >>> # Load pretrained model and track
        >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
        >>> track_graph = model.track(imgs, masks, mode="greedy")
    """

    def __init__(
        self,
        transformer: TrackingTransformer,
        train_args: dict,
        device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
    ):
        """Initialize Trackastra model.

        Args:
            transformer: The underlying transformer model. It is moved to the
                resolved device.
            train_args: Training configuration arguments.
            device: Device to run model on ("cuda", "mps", "cpu", "automatic"
                or None). "automatic" and None pick cuda, then mps, then cpu.

        Raises:
            ValueError: If ``device`` is not one of the recognized values.
        """
        self.device = self._resolve_device(device)
        logger.info(f"Using device {self.device}")
        self.transformer = transformer.to(self.device)
        self.train_args = train_args

    @staticmethod
    def _mps_enabled() -> bool:
        """Return True if MPS is available and the CPU fallback env var is on.

        MPS is only used when ``PYTORCH_ENABLE_MPS_FALLBACK`` is set to
        something other than "0".
        """
        fallback = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
        return (
            torch.backends.mps.is_available()
            and fallback is not None
            and fallback != "0"
        )

    @classmethod
    def _resolve_device(cls, device: Optional[str]) -> str:
        """Map the requested device name to the device that is actually used.

        Explicit requests for an unavailable accelerator fall back to "cpu"
        (with an info log) instead of failing.
        """
        if device == "cuda":
            if torch.cuda.is_available():
                return "cuda"
            logger.info("Cuda not available, falling back to cpu.")
            return "cpu"

        if device == "mps":
            if cls._mps_enabled():
                return "mps"
            logger.info("Mps not available, falling back to cpu.")
            return "cpu"

        if device == "cpu":
            return "cpu"

        if device == "automatic" or device is None:
            if torch.cuda.is_available():
                return "cuda"
            # NOTE: automatic selection additionally requires the env var to be
            # non-empty (an empty string selects cpu here, unlike the explicit
            # "mps" request above). Kept as-is to not change device selection.
            if cls._mps_enabled() and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK"):
                return "mps"
            return "cpu"

        raise ValueError(f"Device {device} not recognized.")

    @classmethod
    def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
        """Load a Trackastra model from a local folder.

        Args:
            dir: Path to model folder containing:
                - model weights
                - train_config.yaml with training arguments
                May be a ``str`` or ``Path``; a leading ``~`` is expanded.
            device: Device to run model on.

        Returns:
            Trackastra model instance.
        """
        # Normalize once so that both the weights and the config are resolved
        # from the same expanded path, and so that plain strings are accepted.
        folder = Path(dir).expanduser()

        # Always load to cpu first
        transformer = TrackingTransformer.from_folder(folder, map_location="cpu")

        # NOTE: FullLoader is kept for compatibility with existing configs.
        # Only load model folders that come from a trusted source.
        with open(folder / "train_config.yaml") as config_file:
            train_args = yaml.load(config_file, Loader=yaml.FullLoader)

        return cls(transformer=transformer, train_args=train_args, device=device)

    @classmethod
    def from_pretrained(
        cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
    ):
        """Load a pretrained Trackastra model.

        Available pretrained models are described in detail in pretrained.json.

        Args:
            name: Name of pretrained model (e.g. "general_2d").
            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
            download_dir: Directory to download model to (defaults to ~/.cache/trackastra).

        Returns:
            Trackastra model instance.
        """
        folder = download_pretrained(name, download_dir)
        # download zip from github to location/name, then unzip
        return cls.from_folder(folder, device=device)

    def _predict(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        edge_threshold: float = 0.05,
        n_workers: int = 0,
        normalize_imgs: bool = True,
        progbar_class=tqdm,
    ):
        """Run the transformer on all windows and return its raw predictions.

        Args:
            imgs: Input images (numpy or dask array).
            masks: Instance segmentation masks matching ``imgs``.
            edge_threshold: Forwarded to ``predict_windows``.
            n_workers: Number of workers used for feature extraction.
            normalize_imgs: Whether to normalize the images first.
            progbar_class: Progress bar class to use.

        Returns:
            The output of ``predict_windows``; ``_track_from_predictions``
            reads its "nodes" and "weights" entries.
        """
        logger.info("Predicting weights for candidate graph")
        if normalize_imgs:
            if isinstance(imgs, da.Array):
                # Stay lazy for dask input: normalize block by block.
                imgs = imgs.map_blocks(normalize)
            else:
                imgs = normalize(imgs)

        self.transformer.eval()

        features = get_features(
            detections=masks,
            imgs=imgs,
            ndim=self.transformer.config["coord_dim"],
            n_workers=n_workers,
            progbar_class=progbar_class,
        )

        logger.info("Building windows")
        windows = build_windows(
            features,
            window_size=self.transformer.config["window"],
            progbar_class=progbar_class,
        )

        logger.info("Predicting windows")
        predictions = predict_windows(
            windows=windows,
            features=features,
            model=self.transformer,
            edge_threshold=edge_threshold,
            # First axis of the masks is time; the rest are spatial.
            spatial_dim=masks.ndim - 1,
            progbar_class=progbar_class,
        )

        return predictions

    def _track_from_predictions(
        self,
        predictions,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        use_distance: bool = False,
        max_distance: int = 256,
        max_neighbors: int = 10,
        delta_t: int = 1,
        **kwargs,
    ):
        """Build the candidate graph from predictions and solve the tracking.

        Args:
            predictions: Mapping with "nodes" and "weights" entries, as
                returned by ``_predict``.
            mode: Tracking mode ("greedy_nodiv", "greedy" or "ilp").
            use_distance: Forwarded to ``build_graph``.
            max_distance: Forwarded to ``build_graph``.
            max_neighbors: Forwarded to ``build_graph``.
            delta_t: Forwarded to ``build_graph``.
            **kwargs: Forwarded to ``track_ilp``; only used when ``mode="ilp"``.

        Returns:
            The result of the selected tracker.

        Raises:
            ValueError: If ``mode`` is not a known tracking mode.
        """
        # Report the mode that is actually run (the ilp solver is not greedy).
        logger.info(f"Running {mode} tracker")
        nodes = predictions["nodes"]
        weights = predictions["weights"]

        candidate_graph = build_graph(
            nodes=nodes,
            weights=weights,
            use_distance=use_distance,
            max_distance=max_distance,
            max_neighbors=max_neighbors,
            delta_t=delta_t,
        )

        if mode == "greedy":
            return track_greedy(candidate_graph)
        elif mode == "greedy_nodiv":
            return track_greedy(candidate_graph, allow_divisions=False)
        elif mode == "ilp":
            # Imported lazily so the ilp module is only needed for this mode.
            from trackastra.tracking.ilp import track_ilp

            return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
        else:
            raise ValueError(f"Tracking mode {mode} does not exist.")

    def track(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        progbar_class=tqdm,
        n_workers: int = 0,
        **kwargs,
    ) -> TrackGraph:
        """Track objects across time frames.

        This method links segmented objects across time frames using the specified
        tracking mode. No hyperparameters need to be chosen beyond the tracking mode.

        Args:
            imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
            masks: Instance segmentation masks of shape (T,(Z),Y,X).
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            progbar_class: Progress bar class to use.
            n_workers: Number of worker processes for feature extraction.
            **kwargs: Additional arguments passed to tracking algorithm.

        Returns:
            TrackGraph containing the tracking results.

        Raises:
            RuntimeError: If ``imgs`` and ``masks`` differ in shape, or if the
                number of dimensions does not match the model's ``coord_dim``.
        """
        if not imgs.shape == masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        # One leading time axis plus the model's spatial dimensions.
        if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
            raise RuntimeError(
                f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
            )

        predictions = self._predict(
            imgs,
            masks,
            normalize_imgs=normalize_imgs,
            progbar_class=progbar_class,
            n_workers=n_workers,
        )
        track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
        return track_graph

    def track_from_disk(
        self,
        imgs_path: Path,
        masks_path: Path,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        **kwargs,
    ) -> Tuple[TrackGraph, np.ndarray]:
        """Track objects directly from image and mask files on disk.

        This method supports both single tiff files and directories

        Args:
            imgs_path: Path to input images. Can be:
                - Directory containing numbered tiff files of shape (C),(Z),Y,X
                - Single tiff file with time series of shape T,(C),(Z),Y,X
            masks_path: Path to mask files. Can be:
                - Directory containing numbered tiff files of shape (Z),Y,X
                - Single tiff file with time series of shape T,(Z),Y,X
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            **kwargs: Additional arguments passed to tracking algorithm.

        Returns:
            Tuple of (TrackGraph, masks), where masks is the mask array as
            loaded from ``masks_path``.

        Raises:
            FileNotFoundError: If either path does not exist.
            RuntimeError: If the number of frames or the shapes of images and
                masks do not match, or if the images have more than one channel.
        """
        # Accept plain strings as well as Path objects.
        imgs_path = Path(imgs_path)
        masks_path = Path(masks_path)

        if not imgs_path.exists():
            raise FileNotFoundError(f"{imgs_path=} does not exist.")
        if not masks_path.exists():
            raise FileNotFoundError(f"{masks_path=} does not exist.")

        if imgs_path.is_dir():
            imgs = load_tiff_timeseries(imgs_path)
        else:
            imgs = tifffile.imread(imgs_path)

        if masks_path.is_dir():
            masks = load_tiff_timeseries(masks_path)
        else:
            masks = tifffile.imread(masks_path)

        if len(imgs) != len(masks):
            raise RuntimeError(
                f"#imgs and #masks do not match. Found {len(imgs)} images,"
                f" {len(masks)} masks."
            )

        # Images with one more axis than the masks carry a channel axis at
        # position 1 (T, C, ...). Only a single channel is supported, and it is
        # the *images* that need the extra axis removed.
        if imgs.ndim - 1 == masks.ndim:
            if imgs.shape[1] == 1:
                logger.info(
                    "Found a channel dimension with a single channel. Removing dim."
                )
                imgs = np.squeeze(imgs, 1)
            else:
                raise RuntimeError(
                    "Trackastra currently only supports single channel images."
                )

        if imgs.shape != masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        track_graph = self.track(
            imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
        )
        return track_graph, masks