Spaces:
Sleeping
Sleeping
File size: 12,096 Bytes
8f72b1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 |
import logging
import os
from pathlib import Path
from typing import Literal, Union, Optional, Tuple
import dask.array as da
import numpy as np
import tifffile
import torch
import yaml
from tqdm import tqdm
from ..data import build_windows, get_features, load_tiff_timeseries
from ..tracking import TrackGraph, build_graph, track_greedy
from ..utils import normalize
from .model import TrackingTransformer
from .predict import predict_windows
from .pretrained import download_pretrained
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class Trackastra:
    """A transformer-based tracking model for time-lapse data.

    Trackastra links segmented objects across time frames by predicting
    associations with a transformer model trained on diverse time-lapse videos.

    The model takes as input:
    - A sequence of images of shape (T,(Z),Y,X)
    - Corresponding instance segmentation masks of shape (T,(Z),Y,X)

    It supports multiple tracking modes:
    - greedy_nodiv: Fast greedy linking without division
    - greedy: Fast greedy linking with division
    - ilp: Integer Linear Programming based linking (more accurate but slower)

    Examples:
        >>> # Load example data
        >>> from trackastra.data import example_data_bacteria
        >>> imgs, masks = example_data_bacteria()
        >>>
        >>> # Load pretrained model and track
        >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
        >>> track_graph = model.track(imgs, masks, mode="greedy")
    """

    @staticmethod
    def _mps_usable() -> bool:
        """Return True if the MPS backend can be used.

        Requires both torch MPS availability and the PYTORCH_ENABLE_MPS_FALLBACK
        environment variable to be set (and not "0"), since some ops used by
        the model need the CPU fallback on MPS.
        """
        fallback = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
        return (
            torch.backends.mps.is_available()
            and fallback is not None
            and fallback != "0"
        )

    def __init__(
        self,
        transformer: TrackingTransformer,
        train_args: dict,
        device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
    ):
        """Initialize Trackastra model.

        Args:
            transformer: The underlying transformer model.
            train_args: Training configuration arguments.
            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
                "automatic" (or None) picks cuda if available, then mps, then cpu.
        """
        if device == "cuda":
            if torch.cuda.is_available():
                self.device = "cuda"
            else:
                logger.info("Cuda not available, falling back to cpu.")
                self.device = "cpu"
        elif device == "mps":
            if self._mps_usable():
                self.device = "mps"
            else:
                logger.info("Mps not available, falling back to cpu.")
                self.device = "cpu"
        elif device == "cpu":
            self.device = "cpu"
        elif device == "automatic" or device is None:
            # Single source of truth for MPS eligibility (_mps_usable), so the
            # automatic path agrees with the explicit "mps" branch above.
            if torch.cuda.is_available():
                self.device = "cuda"
            elif self._mps_usable():
                self.device = "mps"
            else:
                self.device = "cpu"
        else:
            raise ValueError(f"Device {device} not recognized.")
        logger.info(f"Using device {self.device}")

        self.transformer = transformer.to(self.device)
        self.train_args = train_args

    @classmethod
    def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
        """Load a Trackastra model from a local folder.

        Args:
            dir: Path to model folder containing:
                - model weights
                - train_config.yaml with training arguments
            device: Device to run model on.

        Returns:
            Trackastra model instance.
        """
        folder = Path(dir).expanduser()
        # Always load to cpu first; __init__ moves the model to `device`.
        transformer = TrackingTransformer.from_folder(folder, map_location="cpu")
        # BUGFIX: the original used `dir / "train_config.yaml"`, which raises
        # TypeError when `dir` is a plain str; use the normalized `folder`.
        # Also close the config file deterministically via a context manager.
        with open(folder / "train_config.yaml") as f:
            train_args = yaml.load(f, Loader=yaml.FullLoader)
        return cls(transformer=transformer, train_args=train_args, device=device)

    @classmethod
    def from_pretrained(
        cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
    ):
        """Load a pretrained Trackastra model.

        Available pretrained models are described in detail in pretrained.json.

        Args:
            name: Name of pretrained model (e.g. "general_2d").
            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
            download_dir: Directory to download model to (defaults to ~/.cache/trackastra).

        Returns:
            Trackastra model instance.
        """
        # download zip from github to location/name, then unzip
        folder = download_pretrained(name, download_dir)
        return cls.from_folder(folder, device=device)

    def _predict(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        edge_threshold: float = 0.05,
        n_workers: int = 0,
        normalize_imgs: bool = True,
        progbar_class=tqdm,
    ):
        """Predict association weights for the candidate graph.

        Args:
            imgs: Input images of shape (T,(Z),Y,X).
            masks: Instance segmentation masks of shape (T,(Z),Y,X).
            edge_threshold: Minimum predicted weight for an edge to be kept.
            n_workers: Number of worker processes for feature extraction.
            normalize_imgs: Whether to normalize the images first.
            progbar_class: Progress bar class to use.

        Returns:
            Dict with candidate "nodes" and edge "weights" (as produced by
            ``predict_windows``).
        """
        logger.info("Predicting weights for candidate graph")
        if normalize_imgs:
            # For dask arrays, normalize lazily per block.
            if isinstance(imgs, da.Array):
                imgs = imgs.map_blocks(normalize)
            else:
                imgs = normalize(imgs)

        self.transformer.eval()

        features = get_features(
            detections=masks,
            imgs=imgs,
            ndim=self.transformer.config["coord_dim"],
            n_workers=n_workers,
            progbar_class=progbar_class,
        )
        logger.info("Building windows")
        windows = build_windows(
            features,
            window_size=self.transformer.config["window"],
            progbar_class=progbar_class,
        )
        logger.info("Predicting windows")
        predictions = predict_windows(
            windows=windows,
            features=features,
            model=self.transformer,
            edge_threshold=edge_threshold,
            # Last axis count excludes the leading time axis.
            spatial_dim=masks.ndim - 1,
            progbar_class=progbar_class,
        )
        return predictions

    def _track_from_predictions(
        self,
        predictions,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        use_distance: bool = False,
        max_distance: int = 256,
        max_neighbors: int = 10,
        delta_t: int = 1,
        **kwargs,
    ):
        """Build the candidate graph from predictions and run the chosen tracker.

        Args:
            predictions: Dict with "nodes" and "weights" from ``_predict``.
            mode: Tracking mode ("greedy_nodiv", "greedy" or "ilp").
            use_distance: Whether to use spatial distance for candidate edges.
            max_distance: Maximum spatial distance for candidate edges.
            max_neighbors: Maximum number of neighbors per node.
            delta_t: Maximum frame gap for candidate edges.
            **kwargs: Extra arguments forwarded to the ILP tracker.

        Returns:
            The tracked graph.

        Raises:
            ValueError: If `mode` is not a known tracking mode.
        """
        # BUGFIX: the original always logged "Running greedy tracker",
        # even in ilp mode.
        logger.info(f"Running {mode} tracker")
        nodes = predictions["nodes"]
        weights = predictions["weights"]

        candidate_graph = build_graph(
            nodes=nodes,
            weights=weights,
            use_distance=use_distance,
            max_distance=max_distance,
            max_neighbors=max_neighbors,
            delta_t=delta_t,
        )
        if mode == "greedy":
            return track_greedy(candidate_graph)
        elif mode == "greedy_nodiv":
            return track_greedy(candidate_graph, allow_divisions=False)
        elif mode == "ilp":
            # Optional dependency: only import when ILP tracking is requested.
            from trackastra.tracking.ilp import track_ilp

            return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
        else:
            raise ValueError(f"Tracking mode {mode} does not exist.")

    def track(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        progbar_class=tqdm,
        n_workers: int = 0,
        **kwargs,
    ) -> TrackGraph:
        """Track objects across time frames.

        This method links segmented objects across time frames using the specified
        tracking mode. No hyperparameters need to be chosen beyond the tracking mode.

        Args:
            imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array)
            masks: Instance segmentation masks of shape (T,(Z),Y,X).
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            progbar_class: Progress bar class to use.
            n_workers: Number of worker processes for feature extraction.
            **kwargs: Additional arguments passed to tracking algorithm.

        Returns:
            TrackGraph containing the tracking results.

        Raises:
            RuntimeError: If image/mask shapes or dimensionality do not match
                the model configuration.
        """
        if not imgs.shape == masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )
        # coord_dim spatial axes + one leading time axis.
        if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
            raise RuntimeError(
                f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
            )
        predictions = self._predict(
            imgs,
            masks,
            normalize_imgs=normalize_imgs,
            progbar_class=progbar_class,
            n_workers=n_workers,
        )
        track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
        return track_graph

    def track_from_disk(
        self,
        imgs_path: Path,
        masks_path: Path,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        **kwargs,
    ) -> Tuple[TrackGraph, np.ndarray]:
        """Track objects directly from image and mask files on disk.

        This method supports both single tiff files and directories

        Args:
            imgs_path: Path to input images. Can be:
                - Directory containing numbered tiff files of shape (C),(Z),Y,X
                - Single tiff file with time series of shape T,(C),(Z),Y,X
            masks_path: Path to mask files. Can be:
                - Directory containing numbered tiff files of shape (Z),Y,X
                - Single tiff file with time series of shape T,(Z),Y,X
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            **kwargs: Additional arguments passed to tracking algorithm.

        Returns:
            Tuple of (TrackGraph, tracked masks).

        Raises:
            FileNotFoundError: If either input path does not exist.
            RuntimeError: If image/mask counts or shapes do not match, or the
                images have more than one channel.
        """
        if not imgs_path.exists():
            raise FileNotFoundError(f"{imgs_path=} does not exist.")
        if not masks_path.exists():
            raise FileNotFoundError(f"{masks_path=} does not exist.")

        if imgs_path.is_dir():
            imgs = load_tiff_timeseries(imgs_path)
        else:
            imgs = tifffile.imread(imgs_path)

        if masks_path.is_dir():
            masks = load_tiff_timeseries(masks_path)
        else:
            masks = tifffile.imread(masks_path)

        if len(imgs) != len(masks):
            raise RuntimeError(
                f"#imgs and #masks do not match. Found {len(imgs)} images,"
                f" {len(masks)} masks."
            )

        # imgs has exactly one extra axis relative to masks -> treat it as a
        # channel axis at position 1 (after time).
        if imgs.ndim - 1 == masks.ndim:
            # BUGFIX: the original tested `imgs[1] == 1` (an elementwise array
            # comparison, raising on truth evaluation) instead of the channel
            # size, and squeezed `masks` — which has no channel axis here —
            # instead of `imgs`.
            if imgs.shape[1] == 1:
                logger.info(
                    "Found a channel dimension with a single channel. Removing dim."
                )
                imgs = np.squeeze(imgs, 1)
            else:
                raise RuntimeError(
                    "Trackastra currently only supports single channel images."
                )

        if imgs.shape != masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        return self.track(
            imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
        ), masks
|