# Provenance (upload header, not Python — kept as a comment so the file parses):
# deepkick — "Initial commit: PID2Graph × Claude VLM evaluation + Gradio demo"
# rev 59fa244
"""Graph-level evaluation: node matching, then edge matching, then P/R/F1.
Matching strategy (skeleton — greedy by descending similarity):
1. Score every (predicted_node, gt_node) pair with `node_similarity`.
2. Greedily pick the highest-scoring pair until no pair clears the
threshold or all nodes on one side are consumed.
3. Map predicted edges through the resulting id alignment and intersect
with GT edges. Undirected matching is optional.
TODO: swap greedy for Hungarian (scipy.optimize.linear_sum_assignment)
once we have a feel for the data. Greedy is good enough for a first
pass and keeps dependencies small.
"""
from __future__ import annotations
from typing import Optional
def _norm(s: Optional[str]) -> str:
return (s or "").strip().lower().replace("-", "_").replace(" ", "_")
def node_similarity(pred: dict, gt: dict) -> float:
    """Score a predicted node against a GT node in [0, 1].

    When both sides carry a label, type agreement and label agreement each
    contribute half of the score. Otherwise the score is type agreement alone.
    """
    pred_type = _norm(pred.get("type"))
    gt_type = _norm(gt.get("type"))
    pred_label = _norm(pred.get("label"))
    gt_label = _norm(gt.get("label"))

    # A missing/empty type never counts as a match, even against another empty.
    type_score = float(bool(pred_type) and pred_type == gt_type)
    if not (pred_label and gt_label):
        # At most one side has a label — rely on type alone.
        return type_score
    label_score = float(pred_label == gt_label)
    return 0.5 * type_score + 0.5 * label_score
def match_nodes(
    pred_nodes: list[dict],
    gt_nodes: list[dict],
    threshold: float = 0.5,
) -> dict[str, str]:
    """Greedy 1:1 matching. Returns `{pred_id: gt_id}` for matched pairs."""
    # Score the full cross product; ids ride along for tie-breaking.
    candidates = [
        (node_similarity(p, g), p["id"], g["id"])
        for p in pred_nodes
        for g in gt_nodes
    ]
    alignment: dict[str, str] = {}
    taken_gt: set[str] = set()
    # Highest score first (ties fall back to id ordering via tuple compare).
    for score, pred_id, gt_id in sorted(candidates, reverse=True):
        if score < threshold:
            break  # sorted descending — nothing further can clear the bar
        if pred_id in alignment or gt_id in taken_gt:
            continue  # one endpoint already consumed by a better pair
        alignment[pred_id] = gt_id
        taken_gt.add(gt_id)
    return alignment
def _prf(tp: int, fp: int, fn: int) -> dict:
precision = tp / (tp + fp) if (tp + fp) else 0.0
recall = tp / (tp + fn) if (tp + fn) else 0.0
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
return {
"precision": precision,
"recall": recall,
"f1": f1,
"tp": tp,
"fp": fp,
"fn": fn,
}
def evaluate(pred: dict, gt: dict, directed: bool = True, match_threshold: float = 0.5) -> dict:
    """Compare a single predicted graph against a single GT graph.

    Args:
        pred: `{nodes, edges}` as emitted by `GraphOut.to_dict()`.
        gt: `{nodes, edges}` as emitted by `gt_loader.load_graphml()`.
        directed: if False, (u, v) and (v, u) are treated as the same edge.
        match_threshold: minimum similarity to accept a node pairing.

    Returns:
        Dict with `nodes` and `edges` P/R/F1 sub-dicts (from `_prf`), plus
        the raw counts `n_matched_nodes`, `n_pred_nodes`, `n_gt_nodes`,
        `n_pred_edges`, and `n_gt_edges`.
    """
    # --- nodes ------------------------------------------------------------
    node_map = match_nodes(pred["nodes"], gt["nodes"], threshold=match_threshold)
    tp_n = len(node_map)
    fp_n = len(pred["nodes"]) - tp_n
    fn_n = len(gt["nodes"]) - tp_n
    node_metrics = _prf(tp_n, fp_n, fn_n)

    # --- edges ------------------------------------------------------------
    # Translate predicted edges through the node alignment. Unmatched
    # endpoints make the edge uncountable (it becomes a guaranteed FP).
    def canon(u: str, v: str) -> tuple[str, str]:
        return (u, v) if directed else tuple(sorted((u, v)))  # type: ignore[return-value]

    pred_edges_canon: set[tuple[str, str]] = set()
    unmappable_pred_edges = 0
    for e in pred["edges"]:
        if e["source"] in node_map and e["target"] in node_map:
            pred_edges_canon.add(canon(node_map[e["source"]], node_map[e["target"]]))
        else:
            unmappable_pred_edges += 1
    # NOTE: sets collapse duplicates, so a repeated predicted (or GT) edge is
    # counted once; `n_pred_edges`/`n_gt_edges` below still report raw counts.
    gt_edges_canon: set[tuple[str, str]] = {
        canon(e["source"], e["target"]) for e in gt["edges"]
    }
    tp_e = len(pred_edges_canon & gt_edges_canon)
    fp_e = (len(pred_edges_canon) - tp_e) + unmappable_pred_edges
    fn_e = len(gt_edges_canon) - tp_e
    edge_metrics = _prf(tp_e, fp_e, fn_e)

    return {
        "nodes": node_metrics,
        "edges": edge_metrics,
        # Fix: the docstring always promised `n_matched_nodes`; now provided.
        "n_matched_nodes": tp_n,
        "n_pred_nodes": len(pred["nodes"]),
        "n_gt_nodes": len(gt["nodes"]),
        "n_pred_edges": len(pred["edges"]),
        "n_gt_edges": len(gt["edges"]),
    }
def aggregate(per_sample: list[dict]) -> dict:
    """Micro-average across samples by summing TP/FP/FN."""

    def micro(metric: str) -> dict:
        # Pool the raw counts over all samples, then recompute P/R/F1 once.
        totals = {
            key: sum(sample[metric][key] for sample in per_sample)
            for key in ("tp", "fp", "fn")
        }
        return _prf(totals["tp"], totals["fp"], totals["fn"])

    return {
        "n_samples": len(per_sample),
        "nodes_micro": micro("nodes"),
        "edges_micro": micro("edges"),
    }