import logging
import warnings
import numpy as np
import torch
from scipy.sparse import SparseEfficiencyWarning, csr_array
from tqdm import tqdm
from typing import List
# TODO fix circular import
# from .model import TrackingTransformer
# from trackastra.data import WRFeatures
warnings.simplefilter("ignore", SparseEfficiencyWarning)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
def predict(batch, model):
    """Run the model on one batch and return pairwise association scores.

    Args:
        batch: Dictionary containing:
            - features: Object features array
            - coords: Object coordinates array
            - timepoints: Time points array
        model: TrackingTransformer model to use for prediction.

    Returns:
        Numpy array of association scores between objects.
    """
    # Hack: assume every parameter of the model lives on the same device.
    device = next(model.parameters()).device

    feats = torch.from_numpy(batch["features"]).unsqueeze(0).to(device)
    coords = torch.from_numpy(batch["coords"]).unsqueeze(0).to(device)
    timepoints = torch.from_numpy(batch["timepoints"]).long().unsqueeze(0).to(device)

    # Prepend the time coordinate so each point becomes (t, *spatial).
    coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)

    with torch.no_grad():
        A = model(coords, features=feats)
        A = model.normalize_output(A, timepoints, coords)
        # NOTE(disabled): masking spatially distant pairs before the causal
        # normalization, kept for reference:
        #   dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
        #   invalid = dist > model.config["spatial_pos_cutoff"]
        #   A[invalid] = -torch.inf

    return A.squeeze(0).detach().cpu().numpy()
def predict_windows(
    windows: List[dict],
    # features: list[WRFeatures],
    # model: TrackingTransformer,
    features: list,
    model,
    intra_window_weight: float = 0,
    delta_t: int = 1,
    edge_threshold: float = 0.05,
    spatial_dim: int = 3,
    progbar_class=tqdm,
) -> dict:
    """Predict associations between objects across sliding windows.

    Processes a sequence of sliding windows and aggregates pairwise
    association scores into a single candidate-edge graph. Handles:

    - Object tracking across time
    - Weight normalization across windows
    - Edge thresholding
    - Time-based filtering

    Args:
        windows: List of window dictionaries containing:
            - timepoints: Array of time points
            - labels: Array of object labels
            - features: Object features
            - coords: Object coordinates
        features: List of feature objects containing:
            - labels: Object labels
            - timepoints: Time points
            - coords: Object coordinates
        model: TrackingTransformer model to use for prediction.
        intra_window_weight: Weight factor for objects in middle of window. Defaults to 0.
        delta_t: Maximum time difference between objects to consider. Defaults to 1.
        edge_threshold: Minimum association score to consider. Defaults to 0.05.
        spatial_dim: Dimensionality of input masks. May be less than model.coord_dim.
        progbar_class: Progress bar class to use. Defaults to tqdm.

    Returns:
        Dictionary containing:
            - nodes: List of node properties (id, coords, time, label)
            - weights: Tuple of ((node_i, node_j), weight) pairs
    """
    # Flatten all objects from all feature containers into one node table.
    all_timepoints = np.concatenate([f.timepoints for f in features])
    all_labels = np.concatenate([f.labels for f in features])
    all_coords = np.concatenate([f.coords for f in features])[:, -spatial_dim:]
    n_nodes = np.sum([len(f.labels) for f in features])

    # A (timepoint, label) pair uniquely identifies a node id.
    time_labels_to_id = {
        (t, lab): idx
        for idx, (t, lab) in enumerate(zip(all_timepoints, all_labels))
    }
    node_properties = [
        dict(
            id=idx,
            coords=tuple(c),
            time=t,
            label=lab,
        )
        for idx, (t, lab, c) in enumerate(
            zip(all_timepoints, all_labels, all_coords)
        )
    ]

    # Sparse accumulators for weighted scores and normalization counts.
    sp_weights = csr_array((n_nodes, n_nodes), dtype=np.float32)
    sp_accum = csr_array((n_nodes, n_nodes), dtype=np.float32)

    for win_idx in progbar_class(
        range(len(windows)),
        desc="Computing associations",
    ):
        # This assumes that the samples in the dataset are ordered by time and start at 0.
        batch = windows[win_idx]
        timepoints = batch["timepoints"]
        labels = batch["labels"]

        A = predict(batch, model)

        # Keep only forward-in-time pairs within delta_t frames.
        dt = timepoints[None, :] - timepoints[:, None]
        A[~np.logical_and(dt <= delta_t, dt > 0)] = 0

        ii, jj = np.where(A >= edge_threshold)
        if len(ii) == 0:
            continue

        nodes_ii = np.array(
            [time_labels_to_id[(t, lab)] for t, lab in zip(timepoints[ii], labels[ii])]
        )
        nodes_jj = np.array(
            [time_labels_to_id[(t, lab)] for t, lab in zip(timepoints[jj], labels[jj])]
        )

        # Emphasize pairs near the temporal center of the window.
        t_middle = win_idx + (model.config["window"] - 1) / 2
        ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
        window_weight = np.exp(-intra_window_weight * ddt**2)  # default is all ones
        # window_weight = np.exp(4*A) # smooth max

        sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
        sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]

    sp_weights_coo = sp_weights.tocoo()
    sp_accum_coo = sp_accum.tocoo()
    # Both matrices were written at identical positions, so sparsity must agree.
    assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose(
        sp_weights_coo.row, sp_accum_coo.row
    )

    # Normalize each edge by how often it was seen across window positions.
    weights = tuple(
        ((i, j), w / a)
        for i, j, w, a in zip(
            sp_weights_coo.row,
            sp_weights_coo.col,
            sp_weights_coo.data,
            sp_accum_coo.data,
        )
    )

    return {
        "nodes": node_properties,
        "weights": weights,
    }