"""Graph-level evaluation: node matching, then edge matching, then P/R/F1.
Matching strategy (skeleton — greedy by descending similarity):
1. Score every (predicted_node, gt_node) pair with `node_similarity`.
2. Greedily pick the highest-scoring pair until no pair clears the
threshold or all nodes on one side are consumed.
3. Map predicted edges through the resulting id alignment and intersect
with GT edges. Undirected matching is optional.
TODO: swap greedy for Hungarian (scipy.optimize.linear_sum_assignment)
once we have a feel for the data. Greedy is good enough for a first
pass and keeps dependencies small.
"""
from __future__ import annotations
from typing import Optional
def _norm(s: Optional[str]) -> str:
return (s or "").strip().lower().replace("-", "_").replace(" ", "_")
def node_similarity(pred: dict, gt: dict) -> float:
    """Score a predicted node against a GT node in [0, 1].

    Type and label agreement each contribute half of the score. When a
    label is missing on either side, the type comparison decides alone.
    """
    pred_type, gt_type = _norm(pred.get("type")), _norm(gt.get("type"))
    pred_label, gt_label = _norm(pred.get("label")), _norm(gt.get("label"))
    same_type = 1.0 if pred_type and pred_type == gt_type else 0.0
    if not pred_label or not gt_label:
        # Only one side (or neither) carries a label — rely on type alone.
        return same_type
    same_label = 1.0 if pred_label == gt_label else 0.0
    return (same_type + same_label) / 2
def match_nodes(
    pred_nodes: list[dict],
    gt_nodes: list[dict],
    threshold: float = 0.5,
) -> dict[str, str]:
    """Greedy 1:1 matching. Returns `{pred_id: gt_id}` for matched pairs."""
    # Candidate pairings that clear the threshold, as (score, pred_id, gt_id).
    candidates = [
        (node_similarity(p, g), p["id"], g["id"])
        for p in pred_nodes
        for g in gt_nodes
        if node_similarity(p, g) >= threshold
    ]
    alignment: dict[str, str] = {}
    claimed_gt: set[str] = set()
    # Highest score first; full-tuple ordering keeps ties deterministic.
    for _score, pred_id, gt_id in sorted(candidates, reverse=True):
        if pred_id not in alignment and gt_id not in claimed_gt:
            alignment[pred_id] = gt_id
            claimed_gt.add(gt_id)
    return alignment
def _prf(tp: int, fp: int, fn: int) -> dict:
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
return {
"precision": precision,
"recall": recall,
"f1": f1,
"tp": tp,
"fp": fp,
"fn": fn,
}
def evaluate(pred: dict, gt: dict, directed: bool = True, match_threshold: float = 0.5) -> dict:
    """Compare a single predicted graph against a single GT graph.

    Args:
        pred: `{nodes, edges}` as emitted by `GraphOut.to_dict()`.
        gt: `{nodes, edges}` as emitted by `gt_loader.load_graphml()`.
        directed: if False, (u, v) and (v, u) are treated as the same edge.
        match_threshold: minimum similarity to accept a node pairing.

    Returns:
        Dict with `nodes` and `edges` metric dicts, `n_matched_nodes`, and
        raw node/edge counts for both sides.
    """
    # --- nodes ------------------------------------------------------------
    node_map = match_nodes(pred["nodes"], gt["nodes"], threshold=match_threshold)
    tp_n = len(node_map)
    fp_n = len(pred["nodes"]) - tp_n
    fn_n = len(gt["nodes"]) - tp_n
    node_metrics = _prf(tp_n, fp_n, fn_n)
    # --- edges ------------------------------------------------------------
    # Translate predicted edges through the node alignment. Unmatched
    # endpoints make the edge uncountable (it becomes a guaranteed FP).
    def canon(u: str, v: str) -> tuple[str, str]:
        return (u, v) if directed else tuple(sorted((u, v)))  # type: ignore[return-value]
    pred_edges_canon: set[tuple[str, str]] = set()
    unmappable_pred_edges = 0
    for e in pred["edges"]:
        src, dst = e["source"], e["target"]
        if src in node_map and dst in node_map:
            pred_edges_canon.add(canon(node_map[src], node_map[dst]))
        else:
            unmappable_pred_edges += 1
    gt_edges_canon: set[tuple[str, str]] = {
        canon(e["source"], e["target"]) for e in gt["edges"]
    }
    tp_e = len(pred_edges_canon & gt_edges_canon)
    fp_e = (len(pred_edges_canon) - tp_e) + unmappable_pred_edges
    fn_e = len(gt_edges_canon) - tp_e
    edge_metrics = _prf(tp_e, fp_e, fn_e)
    return {
        "nodes": node_metrics,
        "edges": edge_metrics,
        # Previously the docstring promised this key but it was never emitted;
        # adding it is backward-compatible and fulfils the documented contract.
        "n_matched_nodes": tp_n,
        "n_pred_nodes": len(pred["nodes"]),
        "n_gt_nodes": len(gt["nodes"]),
        "n_pred_edges": len(pred["edges"]),
        "n_gt_edges": len(gt["edges"]),
    }
def aggregate(per_sample: list[dict]) -> dict:
    """Micro-average across samples by summing TP/FP/FN."""
    def _micro(metric: str) -> dict:
        # Pool raw counts over all samples, then recompute P/R/F1 once.
        totals = {
            key: sum(sample[metric][key] for sample in per_sample)
            for key in ("tp", "fp", "fn")
        }
        return _prf(totals["tp"], totals["fp"], totals["fn"])
    return {
        "n_samples": len(per_sample),
        "nodes_micro": _micro("nodes"),
        "edges_micro": _micro("edges"),
    }