| """ |
| Evaluation module: compute SHD, F1, Precision, Recall between predicted and true CPDAGs. |
| """ |
| import numpy as np |
| import logging |
|
|
# Module-level logger keyed on this module's dotted name (not referenced by
# the functions below).
logger = logging.getLogger(name=__name__)
|
|
|
|
def compute_shd(pred_adj, true_adj):
    """Compute the Structural Hamming Distance between two CPDAGs/DAGs.

    Both inputs are square adjacency matrices using the encoding:

    * ``adj[i, j] != 0`` and ``adj[j, i] == 0`` -> directed edge ``i -> j``
    * ``adj[i, j] != 0`` and ``adj[j, i] != 0`` -> undirected edge ``i -- j``

    Any non-zero entry marks an edge, so 0/1, boolean and weighted matrices
    are all accepted. This is the same edge test used by the other metrics in
    this module.

    Each unordered node pair contributes at most 1 to the distance. A missing
    edge, an extra edge, and an edge of the wrong type (reversed, or directed
    vs. undirected) each count once.

    Args:
        pred_adj: predicted adjacency matrix (array-like, shape ``(n, n)``).
        true_adj: ground-truth adjacency matrix with the same shape.

    Returns:
        int: the SHD, between 0 and ``n * (n - 1) // 2``.

    Raises:
        ValueError: if the shapes differ or the matrices are not square.
    """
    pred = np.asarray(pred_adj) != 0
    true = np.asarray(true_adj) != 0
    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if pred.shape != true.shape:
        raise ValueError(
            f"Adjacency matrices must have same shape, got {pred.shape} and {true.shape}"
        )
    if pred.ndim != 2 or pred.shape[0] != pred.shape[1]:
        raise ValueError(f"Adjacency matrices must be square, got shape {pred.shape}")

    # Visit every unordered pair (i < j) once. A pair's edge "type" is the
    # ordered tuple (adj[i, j], adj[j, i]). The pair counts toward the SHD
    # whenever either entry differs between prediction and truth, which covers
    # missing, extra and mis-oriented edges alike.
    rows, cols = np.triu_indices(pred.shape[0], k=1)
    differs = (pred[rows, cols] != true[rows, cols]) | (pred[cols, rows] != true[cols, rows])
    return int(np.count_nonzero(differs))
|
|
|
|
def _precision_recall_f1(tp, fp, fn):
    """Return ``(precision, recall, f1)`` from counts, using 0.0 for undefined ratios."""
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if precision + recall > 0 else 0.0)
    return precision, recall, f1


def compute_edge_metrics(pred_adj, true_adj):
    """Compute precision, recall and F1 on edges (skeleton-level and directed).

    An entry ``adj[i, j] != 0`` marks an edge endpoint. ``i -> j`` is encoded
    as ``adj[i, j] != 0, adj[j, i] == 0`` and ``i -- j`` as both entries
    non-zero.

    Args:
        pred_adj: predicted adjacency matrix (array-like, shape ``(n, n)``).
        true_adj: ground-truth adjacency matrix with the same shape.

    Returns:
        dict with:
            - skeleton_precision, skeleton_recall, skeleton_f1: ignoring
              direction; each unordered adjacent pair counts once.
            - directed_precision, directed_recall, directed_f1: comparing
              every off-diagonal entry, so an undirected edge contributes
              two entries, ``(i, j)`` and ``(j, i)``.
            - shd: structural Hamming distance (from ``compute_shd``).
            - n_true_edges, n_pred_edges: number of adjacent pairs.

    Raises:
        ValueError: if the shapes differ or the matrices are not square.
    """
    pred = np.asarray(pred_adj) != 0
    true = np.asarray(true_adj) != 0
    # Validate before any vectorized op: mismatched shapes could otherwise
    # broadcast silently (e.g. a 1x1 matrix against an n x n one).
    if pred.shape != true.shape:
        raise ValueError(
            f"Adjacency matrices must have same shape, got {pred.shape} and {true.shape}"
        )
    if pred.ndim != 2 or pred.shape[0] != pred.shape[1]:
        raise ValueError(f"Adjacency matrices must be square, got shape {pred.shape}")
    n = pred.shape[0]

    # Skeleton: pair {i, j} is adjacent if either direction is set. Keep only
    # the strict upper triangle so each unordered pair is counted once.
    upper = np.triu(np.ones((n, n), dtype=bool), k=1)
    true_skel = (true | true.T) & upper
    pred_skel = (pred | pred.T) & upper
    skel_precision, skel_recall, skel_f1 = _precision_recall_f1(
        int(np.count_nonzero(true_skel & pred_skel)),
        int(np.count_nonzero(pred_skel & ~true_skel)),
        int(np.count_nonzero(true_skel & ~pred_skel)),
    )

    # Directed: compare every off-diagonal entry individually.
    off_diag = ~np.eye(n, dtype=bool)
    true_dir = true & off_diag
    pred_dir = pred & off_diag
    dir_precision, dir_recall, dir_f1 = _precision_recall_f1(
        int(np.count_nonzero(true_dir & pred_dir)),
        int(np.count_nonzero(pred_dir & ~true_dir)),
        int(np.count_nonzero(true_dir & ~pred_dir)),
    )

    return {
        'shd': compute_shd(pred_adj, true_adj),
        'skeleton_precision': skel_precision,
        'skeleton_recall': skel_recall,
        'skeleton_f1': skel_f1,
        'directed_precision': dir_precision,
        'directed_recall': dir_recall,
        'directed_f1': dir_f1,
        'n_true_edges': int(np.count_nonzero(true_skel)),
        'n_pred_edges': int(np.count_nonzero(pred_skel)),
    }
|
|
|
|
def dag_to_cpdag(dag_adjmat):
    """Convert a DAG adjacency matrix into its CPDAG.

    This is a thin wrapper that delegates to
    ``causal_selection.data.generator.dag_to_cpdag``. The import is deferred
    to call time to avoid a circular dependency between this module and the
    data generator.
    """
    from causal_selection.data.generator import dag_to_cpdag as generator_dag_to_cpdag

    cpdag = generator_dag_to_cpdag(dag_adjmat)
    return cpdag
|
|
|
|
def evaluate_algorithm_result(result, true_cpdag):
    """Evaluate a single algorithm result against the ground-truth CPDAG.

    Args:
        result: dict from ``run_algorithm``. It must contain ``'status'``,
            ``'adjmat'`` and ``'runtime'``. ``'output_type'`` is also read
            when the run succeeded: ``'dag'`` outputs are converted to their
            CPDAG before scoring, and any other value is scored as-is.
        true_cpdag: ground-truth CPDAG adjacency matrix (array-like, square).

    Returns:
        dict with the keys produced by ``compute_edge_metrics``, plus
        ``'normalized_shd'`` (SHD divided by ``n * (n - 1) // 2``),
        ``'runtime'`` and ``'status'``. If the run did not succeed or returned
        no matrix, penalty metrics are returned instead. Those are the maximal
        SHD, ``normalized_shd == 1.0``, and 0.0 for every
        precision/recall/F1 score.
    """
    # Accept any array-like ground truth (e.g. nested lists), not just ndarrays.
    true_cpdag = np.asarray(true_cpdag)
    n = true_cpdag.shape[0]
    # Each unordered pair adds at most 1 to the SHD, so this is its upper bound.
    max_possible_shd = n * (n - 1) // 2

    if result['status'] != 'success' or result['adjmat'] is None:
        # Count adjacent pairs with the same non-zero edge test the metrics use.
        true_edges = true_cpdag != 0
        n_true_edges = int(np.count_nonzero(np.triu(true_edges | true_edges.T, k=1)))
        return {
            'shd': max_possible_shd,
            'normalized_shd': 1.0,
            'skeleton_precision': 0.0,
            'skeleton_recall': 0.0,
            'skeleton_f1': 0.0,
            'directed_precision': 0.0,
            'directed_recall': 0.0,
            'directed_f1': 0.0,
            'n_true_edges': n_true_edges,
            'n_pred_edges': 0,
            'runtime': result['runtime'],
            'status': result['status'],
        }

    pred_adj = result['adjmat']
    # Compare like with like: a DAG output is reduced to its equivalence class.
    if result['output_type'] == 'dag':
        pred_cpdag = dag_to_cpdag(pred_adj)
    else:
        pred_cpdag = pred_adj

    metrics = compute_edge_metrics(pred_cpdag, true_cpdag)
    metrics['normalized_shd'] = (
        metrics['shd'] / max_possible_shd if max_possible_shd > 0 else 0.0
    )
    metrics['runtime'] = result['runtime']
    metrics['status'] = result['status']
    return metrics
|
|
|
|
if __name__ == '__main__':
    # Smoke run: sample 1000 rows from the ASIA network, run every algorithm
    # in the pool and print one summary line per algorithm.
    from causal_selection.data.generator import (
        dag_to_cpdag as gen_dag_to_cpdag,
        get_true_dag_adjmat,
        load_bn_model,
        sample_dataset,
    )
    from causal_selection.discovery.algorithms import ALGORITHM_POOL, run_algorithm
    import warnings

    warnings.filterwarnings('ignore')

    asia = load_bn_model('asia')
    asia_dag, _asia_nodes = get_true_dag_adjmat(asia)
    asia_cpdag = gen_dag_to_cpdag(asia_dag)
    samples = sample_dataset(asia, 1000, seed=0)

    true_edge_count = int(((asia_cpdag + asia_cpdag.T) > 0).sum() // 2)
    print(f"ASIA (N=1000) - True edges: {true_edge_count}")
    print(f"{'Algorithm':15s} {'SHD':>5s} {'nSHD':>6s} {'Skel_F1':>8s} {'Dir_F1':>7s} {'Runtime':>8s} {'Status'}")
    print("-" * 70)

    for algorithm in ALGORITHM_POOL:
        outcome = run_algorithm(algorithm, samples, timeout_sec=60)
        scores = evaluate_algorithm_result(outcome, asia_cpdag)
        print(
            f"{algorithm:15s} {scores['shd']:5d} {scores['normalized_shd']:6.3f} "
            f"{scores['skeleton_f1']:8.3f} {scores['directed_f1']:7.3f} "
            f"{scores['runtime']:7.2f}s {scores['status']}"
        )
|
|