Spaces:
Running
Running
"""Graph-level evaluation: node matching, then edge matching, then P/R/F1.

Matching strategy (skeleton — greedy by descending similarity):

1. Score every (predicted_node, gt_node) pair with `node_similarity`.
2. Greedily pick the highest-scoring pair until no pair clears the
   threshold or all nodes on one side are consumed.
3. Map predicted edges through the resulting id alignment and intersect
   with GT edges. Undirected matching is optional.

TODO: swap greedy for Hungarian (scipy.optimize.linear_sum_assignment)
once we have a feel for the data. Greedy is good enough for a first
pass and keeps dependencies small.
"""
| from __future__ import annotations | |
| from typing import Optional | |
| def _norm(s: Optional[str]) -> str: | |
| return (s or "").strip().lower().replace("-", "_").replace(" ", "_") | |
def node_similarity(pred: dict, gt: dict) -> float:
    """Score a predicted node against a GT node in [0, 1].

    Both `type` match and `label` match contribute. If either side has no
    label, we fall back to type-only matching.
    """
    pred_type = _norm(pred.get("type"))
    gt_type = _norm(gt.get("type"))
    pred_label = _norm(pred.get("label"))
    gt_label = _norm(gt.get("label"))

    # An empty type never matches, even against another empty type.
    type_score = float(bool(pred_type) and pred_type == gt_type)

    # Labels only contribute when both sides carry one; otherwise rely on
    # type alone.
    if not (pred_label and gt_label):
        return type_score
    label_score = float(pred_label == gt_label)
    return 0.5 * type_score + 0.5 * label_score
def match_nodes(
    pred_nodes: list[dict],
    gt_nodes: list[dict],
    threshold: float = 0.5,
) -> dict[str, str]:
    """Greedy 1:1 matching. Returns `{pred_id: gt_id}` for matched pairs."""
    # Score all cross pairs, keeping only those that clear the threshold.
    candidates = [
        (score, p["id"], g["id"])
        for p in pred_nodes
        for g in gt_nodes
        if (score := node_similarity(p, g)) >= threshold
    ]
    # Highest similarity first; tuple order makes ties deterministic.
    candidates.sort(reverse=True)

    mapping: dict[str, str] = {}
    taken_gt: set[str] = set()
    for _score, pred_id, gt_id in candidates:
        # Each side may be matched at most once (1:1 constraint).
        if pred_id not in mapping and gt_id not in taken_gt:
            mapping[pred_id] = gt_id
            taken_gt.add(gt_id)
    return mapping
| def _prf(tp: int, fp: int, fn: int) -> dict: | |
| precision = tp / (tp + fp) if (tp + fp) else 0.0 | |
| recall = tp / (tp + fn) if (tp + fn) else 0.0 | |
| f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0 | |
| return { | |
| "precision": precision, | |
| "recall": recall, | |
| "f1": f1, | |
| "tp": tp, | |
| "fp": fp, | |
| "fn": fn, | |
| } | |
def evaluate(pred: dict, gt: dict, directed: bool = True, match_threshold: float = 0.5) -> dict:
    """Compare a single predicted graph against a single GT graph.

    Args:
        pred: `{nodes, edges}` as emitted by `GraphOut.to_dict()`.
        gt: `{nodes, edges}` as emitted by `gt_loader.load_graphml()`.
        directed: if False, (u, v) and (v, u) are treated as the same edge.
        match_threshold: minimum similarity to accept a node pairing.

    Returns:
        Dict with `nodes` and `edges` metric dicts (precision/recall/f1 +
        counts), plus the raw sizes `n_matched_nodes`, `n_pred_nodes`,
        `n_gt_nodes`, `n_pred_edges`, and `n_gt_edges`.
    """
    # --- nodes ------------------------------------------------------------
    node_map = match_nodes(pred["nodes"], gt["nodes"], threshold=match_threshold)
    tp_n = len(node_map)
    fp_n = len(pred["nodes"]) - tp_n
    fn_n = len(gt["nodes"]) - tp_n
    node_metrics = _prf(tp_n, fp_n, fn_n)

    # --- edges ------------------------------------------------------------
    # Translate predicted edges through the node alignment. Unmatched
    # endpoints make the edge uncountable (it becomes a guaranteed FP).
    def canon(u: str, v: str) -> tuple[str, str]:
        return (u, v) if directed else tuple(sorted((u, v)))  # type: ignore[return-value]

    pred_edges_canon: set[tuple[str, str]] = set()
    unmappable_pred_edges = 0
    for e in pred["edges"]:
        if e["source"] in node_map and e["target"] in node_map:
            pred_edges_canon.add(canon(node_map[e["source"]], node_map[e["target"]]))
        else:
            unmappable_pred_edges += 1

    gt_edges_canon: set[tuple[str, str]] = {
        canon(e["source"], e["target"]) for e in gt["edges"]
    }

    tp_e = len(pred_edges_canon & gt_edges_canon)
    fp_e = (len(pred_edges_canon) - tp_e) + unmappable_pred_edges
    fn_e = len(gt_edges_canon) - tp_e
    edge_metrics = _prf(tp_e, fp_e, fn_e)

    return {
        "nodes": node_metrics,
        "edges": edge_metrics,
        # Bug fix: the docstring promised `n_matched_nodes` but the payload
        # never carried it. Adding the key is backward-compatible.
        "n_matched_nodes": tp_n,
        "n_pred_nodes": len(pred["nodes"]),
        "n_gt_nodes": len(gt["nodes"]),
        "n_pred_edges": len(pred["edges"]),
        "n_gt_edges": len(gt["edges"]),
    }
def aggregate(per_sample: list[dict]) -> dict:
    """Micro-average across samples by summing TP/FP/FN."""

    def micro(metric_key: str) -> dict:
        # Pool raw counts across every sample, then recompute P/R/F1 once.
        totals = {
            count: sum(sample[metric_key][count] for sample in per_sample)
            for count in ("tp", "fp", "fn")
        }
        return _prf(totals["tp"], totals["fp"], totals["fn"])

    return {
        "n_samples": len(per_sample),
        "nodes_micro": micro("nodes"),
        "edges_micro": micro("edges"),
    }