File size: 6,416 Bytes
aff3c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import logging
import warnings

import numpy as np
import torch
from scipy.sparse import SparseEfficiencyWarning, csr_array
from tqdm import tqdm
from typing import List

# TODO fix circular import
# from .model import TrackingTransformer
# from trackastra.data import WRFeatures

warnings.simplefilter("ignore", SparseEfficiencyWarning)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def predict(batch, model):
    """Score pairwise associations between the objects of a single window.

    Args:
        batch: Dictionary with numpy arrays:
            - features: per-object feature vectors
            - coords: per-object spatial coordinates
            - timepoints: per-object integer time points
        model: TrackingTransformer used for scoring.

    Returns:
        Numpy array of association scores between all object pairs.
    """
    # Hack: assume every parameter of the model sits on the same device.
    device = next(model.parameters()).device

    features = torch.from_numpy(batch["features"]).unsqueeze(0).to(device)
    positions = torch.from_numpy(batch["coords"]).unsqueeze(0).to(device)
    timepoints = torch.from_numpy(batch["timepoints"]).long().unsqueeze(0).to(device)

    # Prepend the time coordinate so each position becomes (t, *spatial).
    positions = torch.cat((timepoints.unsqueeze(2).float(), positions), dim=2)

    with torch.no_grad():
        scores = model(positions, features=features)
        scores = model.normalize_output(scores, timepoints, positions)
        return scores.squeeze(0).detach().cpu().numpy()


def predict_windows(
    windows: List[dict],
    # features: list[WRFeatures],
    # model: TrackingTransformer,
    features: list,
    model,
    intra_window_weight: float = 0,
    delta_t: int = 1,
    edge_threshold: float = 0.05,
    spatial_dim: int = 3,
    progbar_class=tqdm,
) -> dict:
    """Predict associations between objects across sliding windows.

    For every sliding window the model scores all object pairs; scores from
    overlapping windows are accumulated in sparse matrices and finally
    averaged (normalized by how often each pair was written).

    Args:
        windows: List of window dicts with keys ``timepoints``, ``labels``,
            ``features`` and ``coords``.
        features: List of feature objects exposing ``labels``, ``timepoints``
            and ``coords`` attributes for all objects of the sequence.
        model: TrackingTransformer used for scoring.
        intra_window_weight: Gaussian emphasis for objects near the window
            center. Defaults to 0 (i.e. uniform weight 1).
        delta_t: Maximum forward time difference between linked objects.
            Defaults to 1.
        edge_threshold: Minimum association score kept as an edge.
            Defaults to 0.05.
        spatial_dim: Dimensionality of the input masks; may be less than
            ``model.coord_dim``.
        progbar_class: Progress bar class. Defaults to tqdm.

    Returns:
        Dict with:
            - nodes: list of node property dicts (id, coords, time, label)
            - weights: tuple of ((node_i, node_j), weight) pairs
    """
    # Flatten all objects of the sequence into one node table.
    n_nodes = np.sum([len(f.labels) for f in features])

    all_timepoints = np.concatenate([f.timepoints for f in features])
    all_labels = np.concatenate([f.labels for f in features])
    all_coords = np.concatenate([f.coords for f in features])[:, -spatial_dim:]

    time_labels_to_id = dict()
    node_properties = list()
    for node_id, (t, lab, xyz) in enumerate(
        zip(all_timepoints, all_labels, all_coords)
    ):
        time_labels_to_id[(t, lab)] = node_id
        node_properties.append(
            dict(
                id=node_id,
                coords=tuple(xyz),
                time=t,
                label=lab,
            )
        )

    # Sparse accumulators for weighted scores and for the accumulated weights.
    sp_weights = csr_array((n_nodes, n_nodes), dtype=np.float32)
    sp_accum = csr_array((n_nodes, n_nodes), dtype=np.float32)

    for w_idx in progbar_class(
        range(len(windows)),
        desc="Computing associations",
    ):
        # Assumes the windows are ordered by time and start at t=0.
        window = windows[w_idx]
        ts = window["timepoints"]
        labs = window["labels"]

        A = predict(window, model)

        # Keep only forward-in-time edges spanning at most delta_t frames.
        dt = ts[None, :] - ts[:, None]
        A[~np.logical_and(dt > 0, dt <= delta_t)] = 0
        ii, jj = np.where(A >= edge_threshold)
        if len(ii) == 0:
            continue

        src = np.array(
            [time_labels_to_id[(t, lab)] for t, lab in zip(ts[ii], labs[ii])]
        )
        dst = np.array(
            [time_labels_to_id[(t, lab)] for t, lab in zip(ts[jj], labs[jj])]
        )

        # Emphasize detections close to the window center (weight 1 by default).
        t_middle = w_idx + (model.config["window"] - 1) / 2
        ddt = ts[:, None] - t_middle * np.ones_like(dt)
        window_weight = np.exp(-intra_window_weight * ddt**2)
        sp_weights[src, dst] += window_weight[ii, jj] * A[ii, jj]
        sp_accum[src, dst] += window_weight[ii, jj]

    w_coo = sp_weights.tocoo()
    a_coo = sp_accum.tocoo()
    assert np.allclose(w_coo.col, a_coo.col) and np.allclose(w_coo.row, a_coo.row)

    # Average each edge over the number of window positions that wrote it.
    weights = tuple(
        ((r, c), v / a)
        for r, c, v, a in zip(w_coo.row, w_coo.col, w_coo.data, a_coo.data)
    )

    return dict(nodes=node_properties, weights=weights)