Spaces:
Sleeping
Sleeping
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Literal, Union, Optional, Tuple | |
| import dask.array as da | |
| import numpy as np | |
| import tifffile | |
| import torch | |
| import yaml | |
| from tqdm import tqdm | |
| from ..data import build_windows, get_features, load_tiff_timeseries | |
| from ..tracking import TrackGraph, build_graph, track_greedy | |
| from ..utils import normalize | |
| from .model import TrackingTransformer | |
| from .predict import predict_windows | |
| from .pretrained import download_pretrained | |
# Module-wide logging setup.
# NOTE(review): basicConfig configures the *root* logger at import time, which
# makes this module's INFO messages visible by default. Per the stdlib docs it
# is a no-op if the root logger already has handlers.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class Trackastra:
    """A transformer-based tracking model for time-lapse data.

    Trackastra links segmented objects across time frames by predicting
    associations with a transformer model trained on diverse time-lapse videos.

    The model takes as input:
        - A sequence of images of shape (T,(Z),Y,X)
        - Corresponding instance segmentation masks of shape (T,(Z),Y,X)

    It supports multiple tracking modes:
        - greedy_nodiv: Fast greedy linking without division
        - greedy: Fast greedy linking with division
        - ilp: Integer Linear Programming based linking (more accurate but slower)

    Examples:
        >>> # Load example data
        >>> from trackastra.data import example_data_bacteria
        >>> imgs, masks = example_data_bacteria()
        >>>
        >>> # Load pretrained model and track
        >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
        >>> track_graph = model.track(imgs, masks, mode="greedy")
    """

    def __init__(
        self,
        transformer: TrackingTransformer,
        train_args: dict,
        device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
    ):
        """Initialize Trackastra model.

        Args:
            transformer: The underlying transformer model. It is moved to the
                resolved device.
            train_args: Training configuration arguments (stored unchanged).
            device: Device to run model on ("cuda", "mps", "cpu", "automatic"
                or None). "automatic"/None prefers cuda, then mps, then cpu.
                An explicitly requested "cuda"/"mps" that is not usable falls
                back to cpu.

        Raises:
            ValueError: If ``device`` is not one of the recognized values.
        """
        if device == "cuda":
            if torch.cuda.is_available():
                self.device = "cuda"
            else:
                logger.info("Cuda not available, falling back to cpu.")
                self.device = "cpu"
        elif device == "mps":
            if self._mps_usable():
                self.device = "mps"
            else:
                logger.info(
                    "Mps not available (or PYTORCH_ENABLE_MPS_FALLBACK not enabled),"
                    " falling back to cpu."
                )
                self.device = "cpu"
        elif device == "cpu":
            self.device = "cpu"
        elif device == "automatic" or device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
            elif self._mps_usable():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            raise ValueError(f"Device {device} not recognized.")

        logger.info(f"Using device {self.device}")
        self.transformer = transformer.to(self.device)
        self.train_args = train_args

    @staticmethod
    def _mps_usable() -> bool:
        """Return True if the MPS backend is available and CPU fallback is on.

        MPS is only selected when the environment variable
        ``PYTORCH_ENABLE_MPS_FALLBACK`` is set to a non-empty value other than
        "0". This is the PyTorch switch that lets ops unsupported on MPS run
        on the CPU. The same rule is used for explicit "mps" and for
        "automatic" selection.
        """
        fallback = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
        return (
            torch.backends.mps.is_available()
            and bool(fallback)
            and fallback != "0"
        )

    @classmethod
    def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
        """Load a Trackastra model from a local folder.

        Args:
            dir: Path to model folder containing:
                - model weights
                - train_config.yaml with training arguments
                ``~`` is expanded; both ``str`` and ``Path`` are accepted.
            device: Device to run model on.

        Returns:
            Trackastra model instance.
        """
        folder = Path(dir).expanduser()
        # Always load to cpu first; __init__ moves the model to the final device.
        transformer = TrackingTransformer.from_folder(folder, map_location="cpu")
        # NOTE(review): FullLoader is kept for compatibility with existing
        # configs. Prefer yaml.safe_load if configs can come from untrusted
        # sources.
        with open(folder / "train_config.yaml") as f:
            train_args = yaml.load(f, Loader=yaml.FullLoader)
        return cls(transformer=transformer, train_args=train_args, device=device)

    @classmethod
    def from_pretrained(
        cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
    ):
        """Load a pretrained Trackastra model.

        Available pretrained models are described in detail in pretrained.json.

        Args:
            name: Name of pretrained model (e.g. "general_2d").
            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
            download_dir: Directory to download model to (defaults to ~/.cache/trackastra).

        Returns:
            Trackastra model instance.
        """
        # Download the zip from github to location/name, then unzip.
        folder = download_pretrained(name, download_dir)
        return cls.from_folder(folder, device=device)

    def _predict(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        edge_threshold: float = 0.05,
        n_workers: int = 0,
        normalize_imgs: bool = True,
        progbar_class=tqdm,
    ):
        """Predict association weights for candidate links between objects.

        The images are optionally normalized (blockwise for dask arrays). Then
        per-object features are extracted from ``masks``/``imgs`` and grouped
        into windows of ``config["window"]`` frames, and the transformer is
        run on every window.

        Returns:
            The result of ``predict_windows``. ``_track_from_predictions``
            reads its ``"nodes"`` and ``"weights"`` entries.
        """
        logger.info("Predicting weights for candidate graph")
        if normalize_imgs:
            if isinstance(imgs, da.Array):
                imgs = imgs.map_blocks(normalize)
            else:
                imgs = normalize(imgs)

        self.transformer.eval()

        features = get_features(
            detections=masks,
            imgs=imgs,
            ndim=self.transformer.config["coord_dim"],
            n_workers=n_workers,
            progbar_class=progbar_class,
        )
        logger.info("Building windows")
        windows = build_windows(
            features,
            window_size=self.transformer.config["window"],
            progbar_class=progbar_class,
        )

        logger.info("Predicting windows")
        predictions = predict_windows(
            windows=windows,
            features=features,
            model=self.transformer,
            edge_threshold=edge_threshold,
            spatial_dim=masks.ndim - 1,
            progbar_class=progbar_class,
        )

        return predictions

    def _track_from_predictions(
        self,
        predictions,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        use_distance: bool = False,
        max_distance: int = 256,
        max_neighbors: int = 10,
        delta_t: int = 1,
        **kwargs,
    ):
        """Build the candidate graph from predictions and solve it with ``mode``.

        Args:
            predictions: Mapping with ``"nodes"`` and ``"weights"`` entries,
                as returned by ``_predict``.
            mode: "greedy", "greedy_nodiv" or "ilp".
            use_distance, max_distance, max_neighbors, delta_t: Passed to
                ``build_graph``.
            **kwargs: Passed to ``track_ilp`` (only used when ``mode="ilp"``).

        Returns:
            The tracking result of the selected solver.

        Raises:
            ValueError: If ``mode`` is not a known tracking mode. This is
                checked before the candidate graph is built.
        """
        if mode not in ("greedy", "greedy_nodiv", "ilp"):
            raise ValueError(f"Tracking mode {mode} does not exist.")

        logger.info(f"Running {mode} tracker")
        candidate_graph = build_graph(
            nodes=predictions["nodes"],
            weights=predictions["weights"],
            use_distance=use_distance,
            max_distance=max_distance,
            max_neighbors=max_neighbors,
            delta_t=delta_t,
        )

        if mode == "greedy":
            return track_greedy(candidate_graph)
        if mode == "greedy_nodiv":
            return track_greedy(candidate_graph, allow_divisions=False)
        # mode == "ilp". Imported lazily, presumably so the ILP solver
        # dependencies stay optional (TODO confirm).
        from trackastra.tracking.ilp import track_ilp

        return track_ilp(candidate_graph, ilp_config="gt", **kwargs)

    def track(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        progbar_class=tqdm,
        n_workers: int = 0,
        **kwargs,
    ) -> TrackGraph:
        """Track objects across time frames.

        This method links segmented objects across time frames using the specified
        tracking mode. No hyperparameters need to be chosen beyond the tracking mode.

        Args:
            imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array).
            masks: Instance segmentation masks of shape (T,(Z),Y,X).
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            progbar_class: Progress bar class to use.
            n_workers: Number of worker processes for feature extraction.
            **kwargs: Additional arguments passed to the candidate-graph
                construction / tracking algorithm.

        Returns:
            TrackGraph containing the tracking results.

        Raises:
            RuntimeError: If image and mask shapes differ, or their
                dimensionality does not match the model's ``coord_dim``.
        """
        if imgs.shape != masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        if imgs.ndim != self.transformer.config["coord_dim"] + 1:
            raise RuntimeError(
                f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
            )

        predictions = self._predict(
            imgs,
            masks,
            normalize_imgs=normalize_imgs,
            progbar_class=progbar_class,
            n_workers=n_workers,
        )
        track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
        return track_graph

    def track_from_disk(
        self,
        imgs_path: Union[Path, str],
        masks_path: Union[Path, str],
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        **kwargs,
    ) -> Tuple[TrackGraph, np.ndarray]:
        """Track objects directly from image and mask files on disk.

        This method supports both single tiff files and directories.

        Args:
            imgs_path: Path to input images (``str`` or ``Path``). Can be:
                - Directory containing numbered tiff files of shape (C),(Z),Y,X
                - Single tiff file with time series of shape T,(C),(Z),Y,X
            masks_path: Path to mask files (``str`` or ``Path``). Can be:
                - Directory containing numbered tiff files of shape (Z),Y,X
                - Single tiff file with time series of shape T,(Z),Y,X
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            **kwargs: Additional arguments passed to tracking algorithm.

        Returns:
            Tuple of (TrackGraph, masks). The masks are the input masks as
            loaded from disk; they are not relabelled by tracking.

        Raises:
            FileNotFoundError: If either path does not exist.
            RuntimeError: On frame-count or shape mismatch, or multi-channel images.
        """
        imgs_path = Path(imgs_path)
        masks_path = Path(masks_path)

        if not imgs_path.exists():
            raise FileNotFoundError(f"{imgs_path=} does not exist.")
        if not masks_path.exists():
            raise FileNotFoundError(f"{masks_path=} does not exist.")

        if imgs_path.is_dir():
            imgs = load_tiff_timeseries(imgs_path)
        else:
            imgs = tifffile.imread(imgs_path)

        if masks_path.is_dir():
            masks = load_tiff_timeseries(masks_path)
        else:
            masks = tifffile.imread(masks_path)

        if len(imgs) != len(masks):
            raise RuntimeError(
                f"#imgs and #masks do not match. Found {len(imgs)} images,"
                f" {len(masks)} masks."
            )

        # One extra image axis means a channel axis directly after time:
        # (T, C, (Z), Y, X). Only a singleton channel is supported; drop it
        # from the *images* so they line up with the masks.
        if imgs.ndim == masks.ndim + 1:
            if imgs.shape[1] == 1:
                logger.info(
                    "Found a channel dimension with a single channel. Removing dim."
                )
                imgs = np.squeeze(imgs, 1)
            else:
                raise RuntimeError(
                    "Trackastra currently only supports single channel images."
                )

        if imgs.shape != masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        return self.track(
            imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
        ), masks