File size: 3,760 Bytes
d29b763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Benchmark suite for baseline vs upgraded DDI models."""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict

import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score, confusion_matrix, f1_score, precision_recall_fscore_support, recall_score, roc_auc_score

from training.calibration import expected_calibration_error


def evaluate(y_true: np.ndarray, y_pred: np.ndarray, y_proba: np.ndarray) -> Dict:
    """Compute classification and calibration metrics for one model's predictions.

    The class set is taken from the width of ``y_proba``: classes are the
    integers ``0 .. y_proba.shape[1] - 1``. The highest-index class is used for
    ``severe_recall`` (assumes the label encoding puts the most severe
    interaction class last -- confirm against the label mapping).

    Args:
        y_true: Integer ground-truth labels, shape ``(n_samples,)``.
        y_pred: Integer predicted labels, shape ``(n_samples,)``.
        y_proba: Per-class scores, shape ``(n_samples, num_classes)``.

    Returns:
        A JSON-serialisable dict with accuracy, macro precision/recall/F1
        (unweighted means of the per-class values over all ``num_classes``
        classes), recall of the last class, macro one-vs-rest AUROC and AUPRC,
        expected calibration error (15 bins), the confusion matrix, and
        per-class precision/recall/F1/support.

        AUROC is averaged only over classes that have both positive and
        negative samples in ``y_true``; AUPRC only over classes with at least
        one positive sample. If no class qualifies, the value is ``nan``.

    Raises:
        ValueError: if the array shapes disagree or a label in ``y_true`` or
            ``y_pred`` lies outside ``[0, num_classes)``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_proba = np.asarray(y_proba)
    if y_proba.ndim != 2:
        raise ValueError(
            f'y_proba must have shape (n_samples, num_classes), got {y_proba.shape}'
        )
    n_samples, num_classes = y_proba.shape
    if y_true.shape != (n_samples,) or y_pred.shape != (n_samples,):
        raise ValueError(
            'y_true, y_pred and y_proba must describe the same samples; got shapes '
            f'{y_true.shape}, {y_pred.shape} and {y_proba.shape}'
        )
    # Out-of-range labels would otherwise be silently mis-encoded: negative
    # indices wrap around in np.eye(...)[y_true], and confusion_matrix drops
    # samples whose label is not in ``labels``.
    for name, label_arr in (('y_true', y_true), ('y_pred', y_pred)):
        if label_arr.size and (label_arr.min() < 0 or label_arr.max() >= num_classes):
            raise ValueError(f'{name} contains labels outside [0, {num_classes})')

    labels = list(range(num_classes))
    major_idx = num_classes - 1

    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, zero_division=0
    )

    # One-vs-rest indicator matrix: column k is 1 where the true label is k.
    y_true_ovr = np.eye(num_classes, dtype=np.int64)[y_true]
    n_pos = y_true_ovr.sum(axis=0)
    # ROC AUC is undefined for a column containing a single class (sklearn
    # raises ValueError), and average precision is degenerate without
    # positives, so each macro average covers only classes where it is defined.
    roc_classes = [k for k in labels if 0 < n_pos[k] < n_samples]
    ap_classes = [k for k in labels if n_pos[k] > 0]
    auroc = (
        float(np.mean([roc_auc_score(y_true_ovr[:, k], y_proba[:, k]) for k in roc_classes]))
        if roc_classes
        else float('nan')
    )
    auprc = (
        float(np.mean([average_precision_score(y_true_ovr[:, k], y_proba[:, k]) for k in ap_classes]))
        if ap_classes
        else float('nan')
    )
    ece = float(expected_calibration_error(y_true, y_proba, n_bins=15))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    return {
        'accuracy': float(accuracy_score(y_true, y_pred)),
        # All macro metrics are plain means over the same full label set, so
        # they agree with the 'per_class' entries reported below.
        'macro_precision': float(np.mean(precision)),
        'macro_recall': float(np.mean(recall)),
        'macro_f1': float(np.mean(f1)),
        'severe_recall': float(recall[major_idx]),
        'auroc_macro_ovr': auroc,
        'auprc_macro': auprc,
        'ece': ece,
        'confusion_matrix': cm.tolist(),
        'per_class': {
            str(i): {
                'precision': float(precision[i]),
                'recall': float(recall[i]),
                'f1': float(f1[i]),
                'support': int(support[i]),
            }
            for i in range(num_classes)
        },
    }


def _load_predictions(path: str, label: str) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read one predictions JSON file and return ``(y_true, y_pred, y_proba)`` arrays.

    The file must contain the keys ``y_true``, ``y_pred`` and ``y_proba``
    (a missing key raises ``KeyError``). ``label`` only prefixes error messages.

    Raises:
        ValueError: if ``y_proba`` is not 2-D or the three arrays disagree on
            the number of samples.
    """
    payload = json.loads(Path(path).read_text(encoding='utf-8'))
    y_true = np.array(payload['y_true'], dtype=np.int64)
    y_pred = np.array(payload['y_pred'], dtype=np.int64)
    y_proba = np.array(payload['y_proba'], dtype=np.float32)
    if y_proba.ndim != 2:
        raise ValueError(
            f'{label}: y_proba must be a 2-D list of per-class scores, got shape {y_proba.shape}'
        )
    n_samples = y_proba.shape[0]
    if y_true.shape != (n_samples,) or y_pred.shape != (n_samples,):
        raise ValueError(
            f'{label}: y_true, y_pred and y_proba must have the same number of samples; '
            f'got shapes {y_true.shape}, {y_pred.shape} and {y_proba.shape}'
        )
    return y_true, y_pred, y_proba


def main() -> None:
    """CLI entry point: score baseline and upgraded predictions and write a JSON report.

    The report has three keys: ``baseline`` and ``upgraded`` (the full
    ``evaluate`` output for each model) and ``delta`` (upgraded minus baseline
    for the headline scalar metrics).

    Raises:
        ValueError: if the two files disagree on ``y_true`` or on the number of
            classes, or if either file is internally inconsistent.
    """
    parser = argparse.ArgumentParser(description='Benchmark baseline vs upgraded predictions')
    parser.add_argument('--baseline-json', type=str, required=True, help='JSON with y_true,y_pred,y_proba for baseline')
    parser.add_argument('--upgraded-json', type=str, required=True, help='JSON with y_true,y_pred,y_proba for upgraded model')
    parser.add_argument('--out-json', type=str, required=True)
    args = parser.parse_args()

    y_true, base_pred, base_proba = _load_predictions(args.baseline_json, 'baseline')
    up_true, up_pred, up_proba = _load_predictions(args.upgraded_json, 'upgraded')

    # np.array_equal also returns False on a shape mismatch.
    if not np.array_equal(y_true, up_true):
        raise ValueError('baseline and upgraded y_true must match exactly for fair benchmark')
    # Both models must score the same class set; otherwise per-class metrics and
    # severe_recall (taken from the last class index) would compare different classes.
    if base_proba.shape[1] != up_proba.shape[1]:
        raise ValueError(
            'baseline and upgraded y_proba must have the same number of classes; '
            f'got {base_proba.shape[1]} and {up_proba.shape[1]}'
        )

    b = evaluate(y_true, base_pred, base_proba)
    u = evaluate(y_true, up_pred, up_proba)

    delta_keys = ('accuracy', 'macro_f1', 'severe_recall', 'auroc_macro_ovr', 'auprc_macro')
    out = {
        'baseline': b,
        'upgraded': u,
        'delta': {key: u[key] - b[key] for key in delta_keys},
    }

    out_path = Path(args.out_json)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(out, indent=2), encoding='utf-8')


if __name__ == '__main__':
    main()