File size: 5,097 Bytes
59fa244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""Graph-level evaluation: node matching, then edge matching, then P/R/F1.

Matching strategy (skeleton — greedy by descending similarity):
  1. Score every (predicted_node, gt_node) pair with `node_similarity`.
  2. Greedily pick the highest-scoring pair until no pair clears the
     threshold or all nodes on one side are consumed.
  3. Map predicted edges through the resulting id alignment and intersect
     with GT edges. Undirected matching is optional.

TODO: swap greedy for Hungarian (scipy.optimize.linear_sum_assignment)
      once we have a feel for the data. Greedy is good enough for a first
      pass and keeps dependencies small.
"""

from __future__ import annotations

from typing import Optional


def _norm(s: Optional[str]) -> str:
    return (s or "").strip().lower().replace("-", "_").replace(" ", "_")


def node_similarity(pred: dict, gt: dict) -> float:
    """Score a predicted node against a GT node in [0, 1].

    When both sides carry a label, `type` and `label` each contribute half
    the score. When either label is missing, the type comparison alone
    decides the score.
    """

    def canon(value: Optional[str]) -> str:
        # Same normalization as the module-level `_norm` helper.
        return (value or "").strip().lower().replace("-", "_").replace(" ", "_")

    pred_type, gt_type = canon(pred.get("type")), canon(gt.get("type"))
    pred_label, gt_label = canon(pred.get("label")), canon(gt.get("label"))

    # An empty predicted type never matches, even against an empty GT type.
    type_score = 1.0 if pred_type and pred_type == gt_type else 0.0

    if not (pred_label and gt_label):
        # At least one side lacks a label — rely on type alone.
        return type_score

    label_score = 1.0 if pred_label == gt_label else 0.0
    return (type_score + label_score) / 2.0


def match_nodes(
    pred_nodes: list[dict],
    gt_nodes: list[dict],
    threshold: float = 0.5,
) -> dict[str, str]:
    """Greedy 1:1 matching. Returns `{pred_id: gt_id}` for matched pairs.

    Every (pred, gt) pair at or above `threshold` is a candidate; candidates
    are consumed best-score-first, and each node on either side is used at
    most once.
    """
    candidates = sorted(
        (
            (score, p["id"], g["id"])
            for p in pred_nodes
            for g in gt_nodes
            if (score := node_similarity(p, g)) >= threshold
        ),
        reverse=True,
    )

    alignment: dict[str, str] = {}
    taken_gt: set[str] = set()
    for _score, pred_id, gt_id in candidates:
        if pred_id not in alignment and gt_id not in taken_gt:
            alignment[pred_id] = gt_id
            taken_gt.add(gt_id)
    return alignment


def _prf(tp: int, fp: int, fn: int) -> dict:
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": tp,
        "fp": fp,
        "fn": fn,
    }


def evaluate(pred: dict, gt: dict, directed: bool = True, match_threshold: float = 0.5) -> dict:
    """Compare a single predicted graph against a single GT graph.

    Args:
        pred: `{nodes, edges}` as emitted by `GraphOut.to_dict()`.
        gt:   `{nodes, edges}` as emitted by `gt_loader.load_graphml()`.
        directed: if False, (u, v) and (v, u) are treated as the same edge.
        match_threshold: minimum similarity to accept a node pairing.

    Returns:
        Dict with `nodes` and `edges` metric dicts, `n_matched_nodes`, and
        the raw node/edge counts of both graphs.
    """
    # --- nodes ------------------------------------------------------------
    node_map = match_nodes(pred["nodes"], gt["nodes"], threshold=match_threshold)
    tp_n = len(node_map)
    fp_n = len(pred["nodes"]) - tp_n  # predicted nodes left unmatched
    fn_n = len(gt["nodes"]) - tp_n    # GT nodes left unmatched
    node_metrics = _prf(tp_n, fp_n, fn_n)

    # --- edges ------------------------------------------------------------
    # Translate predicted edges through the node alignment. Unmatched
    # endpoints make the edge uncountable (it becomes a guaranteed FP).
    def canon(u: str, v: str) -> tuple[str, str]:
        return (u, v) if directed else tuple(sorted((u, v)))  # type: ignore[return-value]

    pred_edges_canon: set[tuple[str, str]] = set()
    unmappable_pred_edges = 0
    for e in pred["edges"]:
        if e["source"] in node_map and e["target"] in node_map:
            pred_edges_canon.add(canon(node_map[e["source"]], node_map[e["target"]]))
        else:
            unmappable_pred_edges += 1

    gt_edges_canon: set[tuple[str, str]] = set(
        canon(e["source"], e["target"]) for e in gt["edges"]
    )

    tp_e = len(pred_edges_canon & gt_edges_canon)
    # Unmappable edges can never intersect GT, so they are counted as FPs.
    fp_e = (len(pred_edges_canon) - tp_e) + unmappable_pred_edges
    fn_e = len(gt_edges_canon) - tp_e
    edge_metrics = _prf(tp_e, fp_e, fn_e)

    return {
        "nodes": node_metrics,
        "edges": edge_metrics,
        # Promised by the docstring but previously missing from the payload.
        "n_matched_nodes": tp_n,
        "n_pred_nodes": len(pred["nodes"]),
        "n_gt_nodes": len(gt["nodes"]),
        "n_pred_edges": len(pred["edges"]),
        "n_gt_edges": len(gt["edges"]),
    }


def aggregate(per_sample: list[dict]) -> dict:
    """Micro-average across samples by summing TP/FP/FN before scoring."""

    def micro(metric: str) -> dict:
        # Sum raw counts over all samples, then score the pooled totals.
        totals = {
            count: sum(sample[metric][count] for sample in per_sample)
            for count in ("tp", "fp", "fn")
        }
        return _prf(totals["tp"], totals["fp"], totals["fn"])

    return {
        "n_samples": len(per_sample),
        "nodes_micro": micro("nodes"),
        "edges_micro": micro("edges"),
    }