import logging
import warnings

import numpy as np
import torch
from scipy.sparse import SparseEfficiencyWarning, csr_array
from tqdm import tqdm
from typing import List

# TODO fix circular import
# from .model import TrackingTransformer
# from trackastra.data import WRFeatures

# Item assignment into CSR arrays (used in predict_windows) triggers this warning.
warnings.simplefilter("ignore", SparseEfficiencyWarning)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def predict(batch, model):
    """Predict association scores between objects in a batch.

    Args:
        batch: Dictionary containing:
            - features: Object features array
            - coords: Object coordinates array
            - timepoints: Time points array
            All three are converted with ``torch.from_numpy`` and therefore
            must be NumPy arrays.
        model: TrackingTransformer model to use for prediction. It is called as
            ``model(coords, features=...)`` and must provide
            ``normalize_output(A, timepoints, coords)``.

    Returns:
        NumPy array of association scores between objects (the model output
        with the leading batch axis squeezed out).
    """
    feats = torch.from_numpy(batch["features"])
    coords = torch.from_numpy(batch["coords"])
    timepoints = torch.from_numpy(batch["timepoints"]).long()

    # Hack that assumes that all parameters of a model are on the same device
    device = next(model.parameters()).device

    # Add a leading batch axis of size 1 and move everything to the model device.
    feats = feats.unsqueeze(0).to(device)
    timepoints = timepoints.unsqueeze(0).to(device)
    coords = coords.unsqueeze(0).to(device)

    # Concat timepoints to coordinates (time becomes the first coordinate column)
    coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)

    with torch.no_grad():
        A = model(coords, features=feats)
        A = model.normalize_output(A, timepoints, coords)

        # # Spatially far entries should not influence the causal normalization
        # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
        # invalid = dist > model.config["spatial_pos_cutoff"]
        # A[invalid] = -torch.inf

        A = A.squeeze(0).detach().cpu().numpy()

    return A


def predict_windows(
    windows: List[dict],
    # features: list[WRFeatures],
    # model: TrackingTransformer,
    features: list,
    model,
    intra_window_weight: float = 0,
    delta_t: int = 1,
    edge_threshold: float = 0.05,
    spatial_dim: int = 3,
    progbar_class=tqdm,
) -> dict:
    """Predict associations between objects across sliding windows.

    This function processes a sequence of sliding windows to predict
    associations between objects across time frames. It handles:
    - Object tracking across time
    - Weight normalization across windows
    - Edge thresholding
    - Time-based filtering

    Args:
        windows: List of window dictionaries containing:
            - timepoints: Array of time points
            - labels: Array of object labels
            - features: Object features
            - coords: Object coordinates
        features: List of feature objects containing:
            - labels: Object labels
            - timepoints: Time points
            - coords: Object coordinates
        model: TrackingTransformer model to use for prediction. Its
            ``config["window"]`` entry is read to locate the window centre.
        intra_window_weight: Weight factor for objects in middle of window.
            Defaults to 0.
        delta_t: Maximum time difference between objects to consider.
            Defaults to 1.
        edge_threshold: Minimum association score to consider.
            Defaults to 0.05.
        spatial_dim: Dimensionality of input masks. May be less than
            model.coord_dim. Only the last ``spatial_dim`` coordinate columns
            are kept for the returned nodes.
        progbar_class: Progress bar class to use. Defaults to tqdm. It is
            called as ``progbar_class(iterable, desc=...)``.

    Returns:
        Dictionary containing:
            - nodes: List of node properties (id, coords, time, label)
            - weights: Tuple of ((node_i, node_j), weight) pairs
        If ``features`` is empty, ``nodes`` is an empty list and ``weights``
        an empty tuple.

    Raises:
        KeyError: If a window refers to a (timepoint, label) pair that does
            not occur in ``features``.
    """
    if len(features) == 0:
        # np.concatenate cannot handle an empty list; without any objects there
        # is nothing to associate, so return an empty graph instead of crashing.
        return {"nodes": [], "weights": ()}

    # first get all objects/coords
    time_labels_to_id = {}
    node_properties = []
    # Plain Python int: total number of objects, used as the sparse matrix size.
    max_id = int(sum(len(f.labels) for f in features))

    all_timepoints = np.concatenate([f.timepoints for f in features])
    all_labels = np.concatenate([f.labels for f in features])
    all_coords = np.concatenate([f.coords for f in features])
    all_coords = all_coords[:, -spatial_dim:]

    # Node ids are simply the position in the concatenated feature arrays.
    for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)):
        time_labels_to_id[(t, la)] = i
        node_properties.append(
            dict(
                id=i,
                coords=tuple(c),
                time=t,
                # index=ix,
                label=la,
            )
        )

    # create assoc matrix between ids:
    # sp_weights accumulates weighted scores, sp_accum the weights themselves.
    sp_weights = csr_array((max_id, max_id), dtype=np.float32)
    sp_accum = csr_array((max_id, max_id), dtype=np.float32)

    for window_idx in progbar_class(
        range(len(windows)),
        desc="Computing associations",
    ):
        # This assumes that the samples in the dataset are ordered by time and
        # start at 0 (the window index is used as the window's first timepoint).
        batch = windows[window_idx]
        timepoints = batch["timepoints"]
        labels = batch["labels"]

        A = predict(batch, model)

        # Keep only forward-in-time pairs that are at most delta_t frames apart.
        dt = timepoints[None, :] - timepoints[:, None]
        time_mask = np.logical_and(dt <= delta_t, dt > 0)
        A[~time_mask] = 0
        ii, jj = np.where(A >= edge_threshold)

        if ii.size == 0:
            continue

        # Translate window-local indices into global node ids.
        nodes_ii = np.array(
            [
                time_labels_to_id[(ts, lab)]
                for ts, lab in zip(timepoints[ii], labels[ii])
            ]
        )
        nodes_jj = np.array(
            [
                time_labels_to_id[(ts, lab)]
                for ts, lab in zip(timepoints[jj], labels[jj])
            ]
        )

        # weight middle parts higher
        t_middle = window_idx + (model.config["window"] - 1) / 2
        ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
        window_weight = np.exp(-intra_window_weight * ddt**2)  # default is 1
        # window_weight = np.exp(4*A) # smooth max

        sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
        sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]

    sp_weights_coo = sp_weights.tocoo()
    sp_accum_coo = sp_accum.tocoo()
    # Both matrices are written at identical positions, so their sparsity
    # patterns must match exactly for the element-wise division below.
    assert np.array_equal(sp_weights_coo.col, sp_accum_coo.col) and np.array_equal(
        sp_weights_coo.row, sp_accum_coo.row
    )

    # Normalize weights by the number of times they were written from different
    # sliding window positions
    weights = tuple(
        ((i, j), v / a)
        for i, j, v, a in zip(
            sp_weights_coo.row,
            sp_weights_coo.col,
            sp_weights_coo.data,
            sp_accum_coo.data,
        )
    )

    results = {}
    results["nodes"] = node_properties
    results["weights"] = weights
    return results