# NOTE(review): the three lines below look like stray VCS/paste artifacts
# (author handle, "init", commit id). As bare names they would raise
# NameError at import time, so they are commented out here.
# phoebehxf
# init
# aff3c6f
import logging
import warnings
import numpy as np
import torch
from scipy.sparse import SparseEfficiencyWarning, csr_array
from tqdm import tqdm
from typing import List
# TODO fix circular import
# from .model import TrackingTransformer
# from trackastra.data import WRFeatures
# In-place fancy-indexed writes into CSR arrays (see predict_windows) emit
# SparseEfficiencyWarning on every window; the pattern is intentional, so the
# warning is silenced once for the whole process here.
warnings.simplefilter("ignore", SparseEfficiencyWarning)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def predict(batch, model):
"""Predict association scores between objects in a batch.
Args:
batch: Dictionary containing:
- features: Object features array
- coords: Object coordinates array
- timepoints: Time points array
model: TrackingTransformer model to use for prediction.
Returns:
Array of association scores between objects.
"""
feats = torch.from_numpy(batch["features"])
coords = torch.from_numpy(batch["coords"])
timepoints = torch.from_numpy(batch["timepoints"]).long()
# Hack that assumes that all parameters of a model are on the same device
device = next(model.parameters()).device
feats = feats.unsqueeze(0).to(device)
timepoints = timepoints.unsqueeze(0).to(device)
coords = coords.unsqueeze(0).to(device)
# Concat timepoints to coordinates
coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
with torch.no_grad():
A = model(coords, features=feats)
A = model.normalize_output(A, timepoints, coords)
# # Spatially far entries should not influence the causal normalization
# dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
# invalid = dist > model.config["spatial_pos_cutoff"]
# A[invalid] = -torch.inf
A = A.squeeze(0).detach().cpu().numpy()
return A
def predict_windows(
    windows: List[dict],
    # features: list[WRFeatures],
    # model: TrackingTransformer,
    features: list,
    model,
    intra_window_weight: float = 0,
    delta_t: int = 1,
    edge_threshold: float = 0.05,
    spatial_dim: int = 3,
    progbar_class=tqdm,
) -> dict:
    """Predict associations between objects across sliding windows.

    Each window is scored independently with ``predict``; scores for the same
    edge coming from overlapping windows are accumulated into sparse matrices
    and finally divided by the accumulated (per-window) weights, yielding one
    normalized weight per candidate edge.

    Args:
        windows: List of per-window dicts, each containing numpy arrays under
            "timepoints", "labels", "features" and "coords" (the format
            consumed by ``predict``).
        features: List of per-frame feature objects exposing ``.labels``,
            ``.timepoints`` and ``.coords`` array attributes.
        model: TrackingTransformer used for prediction; ``config["window"]``
            gives the temporal window length.
        intra_window_weight: Sharpness of the Gaussian that downweights edges
            far from a window's temporal center. 0 (default) weights all
            window contributions equally.
        delta_t: Maximum time difference between linked objects. Defaults to 1.
        edge_threshold: Minimum association score to keep an edge. Defaults to 0.05.
        spatial_dim: Dimensionality of the input masks. May be less than
            ``model.coord_dim``; only the trailing ``spatial_dim`` coordinate
            columns are used.
        progbar_class: Progress bar class to use. Defaults to tqdm.

    Returns:
        Dictionary containing:
            - nodes: list of node property dicts (id, coords, time, label)
            - weights: tuple of ((node_i, node_j), weight) pairs
    """
    # First collect all objects across frames and assign each a global integer
    # id, keyed by its (timepoint, label) pair (assumed unique per object).
    time_labels_to_id = dict()
    node_properties = list()
    # Total number of objects = side length of the square association matrices.
    max_id = np.sum([len(f.labels) for f in features])
    all_timepoints = np.concatenate([f.timepoints for f in features])
    all_labels = np.concatenate([f.labels for f in features])
    all_coords = np.concatenate([f.coords for f in features])
    # Keep only the trailing spatial columns (drops any leading extra columns).
    all_coords = all_coords[:, -spatial_dim:]
    for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)):
        time_labels_to_id[(t, la)] = i
        node_properties.append(
            dict(
                id=i,
                coords=tuple(c),
                time=t,
                # index=ix,
                label=la,
            )
        )
    # Sparse accumulators over global node ids:
    #   sp_weights[i, j] — sum over windows of (window weight * score)
    #   sp_accum[i, j]   — sum over windows of the window weights (denominator)
    sp_weights, sp_accum = (
        csr_array((max_id, max_id), dtype=np.float32),
        csr_array((max_id, max_id), dtype=np.float32),
    )
    for t in progbar_class(
        range(len(windows)),
        desc="Computing associations",
    ):
        # This assumes that the samples in the dataset are ordered by time and start at 0.
        batch = windows[t]
        timepoints = batch["timepoints"]
        labels = batch["labels"]
        A = predict(batch, model)
        # dt[i, j] = t_j - t_i; keep only forward-in-time edges within delta_t.
        dt = timepoints[None, :] - timepoints[:, None]
        time_mask = np.logical_and(dt <= delta_t, dt > 0)
        A[~time_mask] = 0
        ii, jj = np.where(A >= edge_threshold)
        if len(ii) == 0:
            # No edge in this window passes the threshold.
            continue
        labels_ii = labels[ii]
        labels_jj = labels[jj]
        ts_ii = timepoints[ii]
        ts_jj = timepoints[jj]
        # Map each edge endpoint's (timepoint, label) back to its global id.
        nodes_ii = np.array(
            tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii))
        )
        nodes_jj = np.array(
            tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj))
        )
        # Weight middle parts higher: Gaussian in the source timepoint's
        # distance from the window's temporal center. With the default
        # intra_window_weight == 0 every weight is exactly 1.
        t_middle = t + (model.config["window"] - 1) / 2
        ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
        window_weight = np.exp(-intra_window_weight * ddt**2)  # default is 1
        # window_weight = np.exp(4*A) # smooth max
        # In-place fancy-indexed CSR writes are slow but correct; the resulting
        # SparseEfficiencyWarning is silenced at module level.
        sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
        sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]
    sp_weights_coo = sp_weights.tocoo()
    sp_accum_coo = sp_accum.tocoo()
    # Both matrices were written at exactly the same positions, so their COO
    # index orderings must coincide entry-for-entry for the zip below.
    assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose(
        sp_weights_coo.row, sp_accum_coo.row
    )
    # Normalize weights by the number of times they were written from different sliding window positions
    weights = tuple(
        ((i, j), v / a)
        for i, j, v, a in zip(
            sp_weights_coo.row,
            sp_weights_coo.col,
            sp_weights_coo.data,
            sp_accum_coo.data,
        )
    )
    results = dict()
    results["nodes"] = node_properties
    results["weights"] = weights
    return results