phoebehxf
init
aff3c6f
import logging
from itertools import chain
import networkx as nx
import numpy as np
import scipy
from tqdm import tqdm
from .track_graph import TrackGraph
from typing import Optional, Tuple
# from trackastra.tracking import graph_to_napari_tracks, graph_to_ctc
logger = logging.getLogger(__name__)
def copy_edge(edge: tuple, source: nx.DiGraph, target: nx.DiGraph):
    """Copy a single edge, its endpoints and all attributes into ``target``.

    Only the first two items of ``edge`` (tail and head node) are read; any
    further items, such as an attribute dict, are ignored. An endpoint that
    is missing from ``target`` is created with the attributes it carries in
    ``source``; an endpoint that already exists is left as it is. The edge
    itself is then added with the attributes it carries in ``source``.

    Args:
        edge: Tuple whose first two items are the endpoints of the edge.
        source: Graph holding the edge together with its node/edge attributes.
        target: Graph that receives the edge; modified in place.
    """
    tail, head = edge[0], edge[1]
    for endpoint in (tail, head):
        if endpoint not in target.nodes:
            target.add_node(endpoint, **source.nodes[endpoint])
    target.add_edge(tail, head, **source.edges[(tail, head)])
def track_greedy(
    candidate_graph,
    allow_divisions=True,
    threshold=0.5,
    edge_attr="weight",
):
    """Link nodes greedily, taking the best-scoring candidate edges first.

    All edges of ``candidate_graph`` are ranked globally by their
    ``edge_attr`` value in descending order. Walking down that ranking, an
    edge is copied into the solution unless its head node already has an
    incoming edge (no fusing) or its tail node already has the maximum number
    of outgoing edges. The walk stops at the first edge whose value is below
    ``threshold``.

    Args:
        candidate_graph: NetworkX graph of candidate links; every edge must
            carry ``edge_attr``.
        allow_divisions (bool, optional):
            Whether to model divisions. If True a node may keep up to two
            outgoing edges, otherwise only one. Defaults to True.
        threshold (float, optional):
            Edges with a value strictly below this are discarded.
            Defaults to 0.5.
        edge_attr (str, optional):
            Name of the edge attribute used as score. Defaults to "weight".

    Returns:
        solution_graph: NetworkX DiGraph holding the kept edges and their
            endpoints, with attributes copied from ``candidate_graph``.

    Raises:
        AssertionError: If an edge value above 1.0 is encountered before the
            threshold cut-off is reached.
    """
    logger.info("Running greedy tracker")

    solution_graph = nx.DiGraph()
    # TODO bring back
    # if args.gt_as_dets:
    #     solution_graph.add_nodes_from(candidate_graph.nodes(data=True))

    # A parent may keep two children when divisions are modelled, else one.
    max_children = 2 if allow_divisions else 1

    ranked_edges = list(candidate_graph.edges(data=True))
    ranked_edges.sort(key=lambda e: e[2][edge_attr], reverse=True)

    for candidate in tqdm(ranked_edges, desc="Greedily matched edges"):
        parent, child, attrs = candidate
        score = attrs[edge_attr]

        assert (
            score <= 1.0
        ), "Edge weights are assumed to be normalized to [0,1]"

        # Edges are sorted, so everything from here on is below threshold too.
        if score < threshold:
            break

        # Feasibility: the child must not already have a parent (no fusing) ...
        child_taken = (
            child in solution_graph.nodes and solution_graph.in_degree(child) > 0
        )
        if child_taken:
            continue

        # ... and the parent must still have a free outgoing slot.
        parent_full = (
            parent in solution_graph
            and solution_graph.out_degree(parent) >= max_children
        )
        if parent_full:
            continue

        copy_edge(candidate, candidate_graph, solution_graph)

    # df, masks = graph_to_ctc(solution_graph, masks_original)
    # tracks, tracks_graph, _ = graph_to_napari_tracks(solution_graph)
    return solution_graph
    # TODO this should all be in a tracker class
    # return df, masks, solution_graph, tracks_graph, tracks, candidate_graph
def build_graph(
    nodes: dict,
    weights: Optional[tuple] = None,
    use_distance: bool = False,
    max_distance: Optional[int] = None,
    max_neighbors: Optional[int] = None,
    delta_t=1,
):
    """Build the candidate graph that links nodes across frames.

    Every entry of ``nodes`` becomes a graph node. For each pair of frames
    ``(t, t + d)`` with ``1 <= d <= delta_t``, every node of the earlier frame
    is offered edges to nodes of the later frame, closest first (distance as
    computed by ``scipy.spatial.distance.cdist`` on the ``coords``).

    Edge weights come from one of two sources:

    * ``weights`` given and ``use_distance`` False: only pairs present in
      ``weights`` become edges, carrying the supplied weight.
    * otherwise: every pair within ``max_distance`` becomes an edge with
      weight ``1 - dist / max_distance``.

    Args:
        nodes: Iterable of per-node mappings with the keys ``"id"``,
            ``"time"``, ``"label"`` and ``"coords"``. (The ``dict`` annotation
            is kept for interface compatibility; the argument is iterated,
            and each item is indexed by those keys.)
        weights: Iterable of ``(edge, weight)`` pairs where ``edge`` is a
            ``(source_id, target_id)`` tuple. Ignored if ``use_distance``.
        use_distance: Force distance-based weights even if ``weights`` is
            given.
        max_distance: Largest distance at which two nodes may be linked.
            ``None`` disables the limit, which is only possible together with
            explicit ``weights``.
        max_neighbors: Maximum number of edges added per source node and
            frame pair. ``None`` or 0 means unlimited.
        delta_t: Largest frame gap that an edge may span.

    Returns:
        ``nx.DiGraph`` whose nodes carry ``time``, ``label``, ``coords`` and
        ``weight=1`` and whose edges carry ``weight``.

    Raises:
        ValueError: If distance-based weights are needed for a frame pair but
            ``max_distance`` is ``None``.
    """
    logger.info(f"Build candidate graph with {delta_t=}")
    G = nx.DiGraph()

    for node in nodes:
        G.add_node(
            node["id"],
            time=node["time"],
            label=node["label"],
            coords=node["coords"],
            # index=node["index"],
            weight=1,
        )

    # ``None`` selects distance-based weights. The explicit ``is None`` test
    # (instead of truthiness) keeps array-like ``weights`` from raising.
    if use_distance or weights is None:
        edge_weights = None
    else:
        edge_weights = {w[0]: w[1] for w in weights}

    graph = TrackGraph(G, frame_attribute="time")

    # All (earlier, later) frame pairs, ordered by gap first and time second.
    # Materialised so that the progress bar gets the exact number of pairs.
    frame_pairs = [
        (t, t + d)
        for d in range(1, delta_t + 1)
        for t in range(graph.t_begin, graph.t_end - d)
    ]
    iterator = tqdm(frame_pairs, total=len(frame_pairs), leave=False)

    for t_begin, t_end in iterator:
        n_edges_t = G.number_of_edges()
        ni = list(graph.nodes_by_frame(t_begin))
        nj = list(graph.nodes_by_frame(t_end))

        # A frame without nodes cannot be linked (and np.stack would raise on
        # an empty list); such a pair falls through to the warning below.
        if ni and nj:
            if edge_weights is None and max_distance is None:
                # Raised lazily so that inputs without any linkable frame
                # pair still return a node-only graph.
                raise ValueError(
                    "max_distance must be given when edge weights are derived "
                    "from distances (use_distance=True or weights=None)"
                )

            pi = np.stack([np.array(G.nodes[_n]["coords"]) for _n in ni])
            pj = np.stack([np.array(G.nodes[_n]["coords"]) for _n in nj])
            dists = scipy.spatial.distance.cdist(pi, pj)

            for _i, _ni in enumerate(ni):
                neighbors = 0
                # Candidates in order of increasing distance (NaNs sort last).
                for _j in np.argsort(dists[_i]):
                    if max_neighbors and neighbors >= max_neighbors:
                        break
                    dist = dists[_i, _j]
                    if not (max_distance is None or dist <= max_distance):
                        # Sorted ascending: all remaining ones are too far.
                        break
                    # Index the id list directly so that edge endpoints keep
                    # the original id objects instead of NumPy scalars.
                    _nj = nj[_j]
                    if edge_weights is None:
                        G.add_edge(_ni, _nj, weight=1 - dist / max_distance)
                    elif (_ni, _nj) in edge_weights:
                        G.add_edge(_ni, _nj, weight=edge_weights[(_ni, _nj)])
                    else:
                        continue
                    neighbors += 1

        e_added = G.number_of_edges() - n_edges_t
        if e_added == 0:
            logger.warning(f"No candidate edges in frame {t_begin}")
        iterator.set_description(
            f"{e_added} edges in frame {t_begin} Total edges: {len(G.edges)}"
        )

    logger.info(f"Added {len(G.nodes)} vertices, {len(G.edges)} edges")
    return G