Spaces:
Sleeping
Sleeping
| import logging | |
| from itertools import chain | |
| import networkx as nx | |
| import numpy as np | |
| import scipy | |
| from tqdm import tqdm | |
| from .track_graph import TrackGraph | |
| from typing import Optional, Tuple | |
| # from trackastra.tracking import graph_to_napari_tracks, graph_to_ctc | |
| logger = logging.getLogger(__name__) | |
def copy_edge(edge: tuple, source: nx.DiGraph, target: nx.DiGraph):
    """Replicate one edge of ``source`` in ``target``, attributes included.

    Only the first two entries of ``edge`` are read (tail and head node),
    so longer tuples such as ``(u, v, data)`` are accepted as well.
    An endpoint that ``target`` does not contain yet is created with the
    attributes it has in ``source``; an endpoint that already exists is left
    as it is. The edge itself is then added with its ``source`` attributes.

    Args:
        edge: Tuple whose first two items are the tail and head node.
        source: Graph the node and edge attributes are read from.
        target: Graph that is modified in place.
    """
    tail, head = edge[0], edge[1]
    for endpoint in (tail, head):
        if endpoint not in target.nodes:
            target.add_node(endpoint, **source.nodes[endpoint])
    target.add_edge(tail, head, **source.edges[(tail, head)])
def track_greedy(
    candidate_graph,
    allow_divisions=True,
    threshold=0.5,
    edge_attr="weight",
):
    """Pick tracking edges greedily, best-scoring candidates first.

    Every edge of ``candidate_graph`` is ranked globally by ``edge_attr`` in
    descending order. Walking down that ranking, an edge is accepted unless
    its head already has a parent in the solution (no merging) or its tail
    already has the maximum number of children. The walk ends at the first
    edge scoring below ``threshold``.

    Args:
        candidate_graph: Directed graph whose edges carry ``edge_attr``.
            Scores are asserted to be at most 1.
        allow_divisions (bool, optional):
            Whether a node may keep two outgoing edges instead of one.
            Defaults to True.
        threshold (float, optional):
            Minimum score an edge needs to be considered. Defaults to 0.5.
        edge_attr (str, optional):
            Name of the edge attribute holding the score.
            Defaults to "weight".

    Returns:
        solution_graph: NetworkX graph of tracks. It only contains nodes
        that are an endpoint of at least one accepted edge.
    """
    logger.info("Running greedy tracker")
    max_children = 2 if allow_divisions else 1

    solution_graph = nx.DiGraph()
    # TODO bring back
    # if args.gt_as_dets:
    #     solution_graph.add_nodes_from(candidate_graph.nodes(data=True))

    ranked = sorted(
        candidate_graph.edges(data=True),
        key=lambda item: item[2][edge_attr],
        reverse=True,
    )

    for candidate in tqdm(ranked, desc="Greedily matched edges"):
        parent, child, attrs = candidate
        score = attrs[edge_attr]
        assert (
            score <= 1.0
        ), "Edge weights are assumed to be normalized to [0,1]"

        # The ranking is descending, so nothing after this edge can qualify.
        if score < threshold:
            break

        # Reject merges: the head may have at most one incoming edge.
        child_taken = (
            child in solution_graph and solution_graph.in_degree(child) > 0
        )
        if child_taken:
            continue

        # Reject edges from a tail that already has all its children.
        parent_full = (
            parent in solution_graph
            and solution_graph.out_degree(parent) >= max_children
        )
        if parent_full:
            continue

        copy_edge(candidate, candidate_graph, solution_graph)

    # df, masks = graph_to_ctc(solution_graph, masks_original)
    # tracks, tracks_graph, _ = graph_to_napari_tracks(solution_graph)
    return solution_graph
| # TODO this should all be in a tracker class | |
| # return df, masks, solution_graph, tracks_graph, tracks, candidate_graph | |
def build_graph(
    nodes: dict,
    weights: Optional[tuple] = None,
    use_distance: bool = False,
    max_distance: Optional[int] = None,
    max_neighbors: Optional[int] = None,
    delta_t=1,
):
    """Build a directed candidate graph linking detections across frames.

    Each entry of ``nodes`` becomes a graph node. For every frame pair
    ``(t, t + d)`` with ``1 <= d <= delta_t``, candidate edges are added from
    nodes in frame ``t`` to nodes in frame ``t + d``, visiting targets from
    nearest to farthest (Euclidean distance between ``coords``).

    Edge weights come from one of two sources:

    * ``weights`` given and ``use_distance`` false: an edge is only added if
      its ``(source_id, target_id)`` pair appears in ``weights``, and it gets
      that weight.
    * otherwise: every pair within ``max_distance`` gets the weight
      ``1 - dist / max_distance``.

    Args:
        nodes: Iterable of mappings with the keys ``"id"``, ``"time"``,
            ``"label"`` and ``"coords"`` (the ``dict`` annotation is
            historical; the argument is iterated, not indexed by key).
        weights: Iterable of ``((source_id, target_id), weight)`` pairs.
        use_distance: If True, ignore ``weights`` and use distance-based
            weights.
        max_distance: Largest distance at which an edge is allowed. Also the
            normaliser for distance-based weights, so it is required in that
            mode.
        max_neighbors: Maximum number of edges added per source node within
            one frame pair. ``None`` or ``0`` means no limit.
        delta_t: Largest frame gap an edge may span.

    Returns:
        The candidate graph as a ``networkx.DiGraph``. Nodes carry ``time``,
        ``label``, ``coords`` and ``weight=1``; edges carry ``weight``.

    Raises:
        ValueError: If a distance-based weight has to be computed while
            ``max_distance`` is None.
    """
    # Imported explicitly: a bare ``import scipy`` does not make
    # ``scipy.spatial`` available on older SciPy releases.
    from scipy.spatial.distance import cdist

    logger.info(f"Build candidate graph with {delta_t=}")
    G = nx.DiGraph()
    for node in nodes:
        G.add_node(
            node["id"],
            time=node["time"],
            label=node["label"],
            coords=node["coords"],
            # index=node["index"],
            weight=1,
        )

    if use_distance:
        weights = None
    if weights:
        # Edge-keyed lookup: (source_id, target_id) -> weight.
        weights = {w[0]: w[1] for w in weights}

    graph = TrackGraph(G, frame_attribute="time")

    # All (t, t + d) pairs, grouped by gap d. Materialised so the progress
    # bar shows the real number of pairs.
    # NOTE(review): assumes graph.t_end is exclusive - confirm in TrackGraph.
    frame_pairs = [
        (t, t + d)
        for d in range(1, delta_t + 1)
        for t in range(graph.t_begin, graph.t_end - d)
    ]
    iterator = tqdm(frame_pairs, total=len(frame_pairs), leave=False)

    for t_begin, t_end in iterator:
        n_edges_t = len(G.edges)
        ni = list(graph.nodes_by_frame(t_begin))
        nj = list(graph.nodes_by_frame(t_end))

        # np.stack / cdist cannot handle an empty frame; such a pair simply
        # contributes no edges and is reported by the warning below.
        if ni and nj:
            pi = np.stack([np.array(G.nodes[_n]["coords"]) for _n in ni])
            pj = np.stack([np.array(G.nodes[_n]["coords"]) for _n in nj])
            dists = cdist(pi, pj)

            for _i, _ni in enumerate(ni):
                neighbors = 0
                # Visit targets from nearest to farthest.
                for _j in np.argsort(dists[_i]):
                    if max_neighbors and neighbors >= max_neighbors:
                        break
                    # Index the list, not an ndarray copy, so node ids keep
                    # their native Python type as edge endpoints.
                    _nj = nj[_j]
                    dist = dists[_i, _j]
                    if max_distance is None or dist <= max_distance:
                        if weights is None:
                            if max_distance is None:
                                raise ValueError(
                                    "max_distance must be set when edge "
                                    "weights are derived from distances"
                                )
                            G.add_edge(_ni, _nj, weight=1 - dist / max_distance)
                            neighbors += 1
                        elif (_ni, _nj) in weights:
                            G.add_edge(_ni, _nj, weight=weights[(_ni, _nj)])
                            neighbors += 1

        e_added = len(G.edges) - n_edges_t
        if e_added == 0:
            logger.warning(f"No candidate edges in frame {t_begin}")
        iterator.set_description(
            f"{e_added} edges in frame {t_begin} Total edges: {len(G.edges)}"
        )

    logger.info(f"Added {len(G.nodes)} vertices, {len(G.edges)} edges")
    return G