File size: 5,477 Bytes
a9141f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Generate the three infographics for reports/evaluation_report.md.

Inputs:
  - eval/results/results-guarded-scored.summary.json   (required)
  - eval/results/results-raw-scored.summary.json       (optional; supplies the guardrails-OFF refusal rates for figure 3)

Outputs (PNG):
  - reports/figures/scores_by_axis.png
  - reports/figures/latency_cost.png
  - reports/figures/refusal_matrix.png
"""
from __future__ import annotations

import argparse
import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

ROOT = Path(__file__).resolve().parent.parent
FIG_DIR = ROOT / "reports" / "figures"
# Created at import time so every fig_* function can save without checking first.
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Public list price as of 2026-01, in CENTS per 1k tokens (the cost panel is
# labelled "¢ / turn"). Update if pricing changes.
# Llama-3.2-1B is self-hosted on CPU — we report compute-amortized as $0 marginal.
# If the OpenAI call fell back to Groq the actual marginal cost is lower than this
# (Groq's free tier is $0 within quota); these numbers assume the primary served.
COST_PER_1K = {
    "openai": {"in": 0.200, "out": 0.800},   # gpt-4.1 approx (¢/1k tokens, i.e. $2 / $8 per 1M)
    "llama":  {"in": 0.000, "out": 0.000},
}
# Scoring axes plotted by fig_scores, in x-axis order; keys into summary["axis_pct"][model].
AXES = ["hallucination", "content_safety", "bias"]
# Bar colour per model key; the plotting code falls back to grey ("#888") for unknown models.
COLOURS = {"openai": "#10a37f", "llama": "#f0b429"}


def _load(p: Path):
    return json.loads(p.read_text(encoding="utf-8")) if p.exists() else None


def fig_scores(summary: dict) -> Path:
    """Draw grouped bars of per-axis quality scores, one bar per model within each axis.

    Args:
        summary: scored-results summary; reads ``summary["axis_pct"]``, a mapping
            ``model -> {axis: score}`` for the axes in AXES. Scores are drawn on a
            0–105 y axis labelled "% of max". Missing or null scores are drawn as 0.

    Returns:
        Path of the written PNG (``FIG_DIR / "scores_by_axis.png"``).
    """
    axis_pct = summary["axis_pct"]
    models = list(axis_pct.keys())
    x = np.arange(len(AXES))
    # Share 80% of each axis slot between the models; max(1, …) avoids /0 with no models.
    w = 0.8 / max(1, len(models))
    fig, ax = plt.subplots(figsize=(7, 4.2))
    for i, m in enumerate(models):
        # `or 0` also maps an explicit null score to 0 (as the latency/refusal figures
        # do), which keeps the f"{v:.0f}" label below from failing on None.
        vals = [(axis_pct.get(m) or {}).get(a) or 0 for a in AXES]
        # The group spans x-0.4 .. x+0.4; bar i is centred in its w-wide sub-slot.
        bars = ax.bar(x + i*w - 0.4 + w/2, vals, width=w, label=m, color=COLOURS.get(m, "#888"))
        for b, v in zip(bars, vals):
            ax.text(b.get_x()+b.get_width()/2, v+1, f"{v:.0f}", ha="center", va="bottom", fontsize=9)
    ax.set_xticks(x); ax.set_xticklabels([a.replace("_"," ").title() for a in AXES])
    ax.set_ylim(0, 105); ax.set_ylabel("Score (% of max)")
    ax.set_title("Quality by axis — higher is better")
    ax.legend(loc="lower right"); ax.grid(axis="y", alpha=0.3)
    p = FIG_DIR / "scores_by_axis.png"
    fig.tight_layout(); fig.savefig(p, dpi=150); plt.close(fig)
    return p


def _cost_cents_per_turn(tokens: dict, price: dict) -> float:
    """Return the estimated mean cost of one turn in cents, rounded to 3 decimals.

    Formula: mean_in/1k * price["in"] + mean_out/1k * price["out"].

    Args:
        tokens: per-model token stats; reads the optional "mean_in" / "mean_out"
            keys (mean input/output tokens per turn). Missing or null values count as 0.
        price: ``{"in": ..., "out": ...}`` rates in cents per 1k tokens
            (same shape as the COST_PER_1K entries).
    """
    cents_in = (tokens.get("mean_in") or 0) / 1000 * price["in"]
    cents_out = (tokens.get("mean_out") or 0) / 1000 * price["out"]
    return round(cents_in + cents_out, 3)


def fig_latency_cost(summary: dict) -> Path:
    """Draw a two-panel figure: latency p50/p95 (left) and mean cost per turn (right).

    Args:
        summary: scored-results summary; reads ``summary["latency_ms"]``
            (``model -> {"p50": ms, "p95": ms}``) and ``summary["tokens"]``
            (``model -> {"mean_in": n, "mean_out": n}``). Missing or null values plot as 0.
            Models absent from COST_PER_1K are priced at 0.

    Returns:
        Path of the written PNG (``FIG_DIR / "latency_cost.png"``).
    """
    latency = summary["latency_ms"]
    models = list(latency.keys())
    p50 = [(latency[m] or {}).get("p50") or 0 for m in models]
    p95 = [(latency[m] or {}).get("p95") or 0 for m in models]
    cost = [
        _cost_cents_per_turn(summary["tokens"].get(m) or {},
                             COST_PER_1K.get(m, {"in": 0, "out": 0}))
        for m in models
    ]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
    x = np.arange(len(models))
    ax1.bar(x-0.2, p50, 0.4, label="p50", color="#6aa6ff")
    ax1.bar(x+0.2, p95, 0.4, label="p95", color="#1f78b4")
    ax1.set_xticks(x); ax1.set_xticklabels(models)
    ax1.set_ylabel("Latency (ms)"); ax1.set_title("Latency per turn"); ax1.legend(); ax1.grid(axis="y", alpha=0.3)
    for i, (a, b) in enumerate(zip(p50, p95)):
        # Whole milliseconds, so float percentiles don't render long decimal tails.
        ax1.text(i-0.2, a+10, f"{a:.0f}", ha="center", fontsize=9)
        ax1.text(i+0.2, b+10, f"{b:.0f}", ha="center", fontsize=9)

    ax2.bar(x, cost, 0.55, color=[COLOURS.get(m, "#888") for m in models])
    ax2.set_xticks(x); ax2.set_xticklabels(models)
    ax2.set_ylabel("¢ / turn (mean)"); ax2.set_title("Cost per turn — Llama self-hosted = $0 marginal")
    # Lift each value label by 2% of the tallest bar; computed once (0 with no models).
    pad = max(cost, default=0) * 0.02
    for i, v in enumerate(cost):
        ax2.text(i, v + pad, f"{v:.3f}¢", ha="center", fontsize=9)
    ax2.grid(axis="y", alpha=0.3)
    p = FIG_DIR / "latency_cost.png"
    fig.tight_layout(); fig.savefig(p, dpi=150); plt.close(fig)
    return p


def fig_refusal_matrix(guarded: dict, raw: dict | None) -> Path:
    """Draw per-model refusal and guardrail-block rates as grouped bars.

    Args:
        guarded: summary from the guardrails-ON run; reads
            ``guarded["refusals"][model]["refusal_rate"]`` and ``["block_rate"]``.
        raw: summary from the guardrails-OFF run, or None. When given, its per-model
            ``refusal_rate`` is drawn as an extra "guardrails OFF" bar. When absent,
            that series is omitted rather than drawn as a misleading 0%.

    Missing or null rates plot as 0. The y axis is labelled "% of prompts".

    Returns:
        Path of the written PNG (``FIG_DIR / "refusal_matrix.png"``).
    """
    refusals = guarded["refusals"]
    models = list(refusals.keys())
    fig, ax = plt.subplots(figsize=(7, 3.8))
    x = np.arange(len(models))

    # (legend label, per-model values, colour), drawn left-to-right within each group.
    series = []
    if raw:
        raw_refusals = raw.get("refusals") or {}
        series.append((
            "refusal (guardrails OFF)",
            [(raw_refusals.get(m) or {}).get("refusal_rate") or 0 for m in models],
            "#8a93a6",
        ))
    series.append((
        "refusal (guardrails ON)",
        [(refusals[m] or {}).get("refusal_rate") or 0 for m in models],
        "#6aa6ff",
    ))
    series.append((
        "output blocked by filter",
        [(refusals[m] or {}).get("block_rate") or 0 for m in models],
        "#f06464",
    ))

    w = 0.27
    # Centre each group on its tick: offsets are (-w, 0, +w) for three series,
    # (-w/2, +w/2) for two.
    offsets = (np.arange(len(series)) - (len(series) - 1) / 2) * w
    for (label, vals, colour), off in zip(series, offsets):
        ax.bar(x + off, vals, w, label=label, color=colour)
    ax.set_xticks(x); ax.set_xticklabels(models); ax.set_ylabel("% of prompts")
    ax.set_title("Refusal & guardrail block rates"); ax.legend(); ax.grid(axis="y", alpha=0.3)
    p = FIG_DIR / "refusal_matrix.png"
    fig.tight_layout(); fig.savefig(p, dpi=150); plt.close(fig)
    return p


def main():
    """CLI entry point: load the scored summaries and write all three figures.

    Relative ``--guarded`` / ``--raw`` paths are resolved against ROOT, not the
    current working directory. Exits with an error if the guarded summary is missing.
    If the raw summary is missing, the figures are still written, but the refusal
    figure has no guardrails-OFF data; a note is printed to stderr so an explicitly
    passed but wrong ``--raw`` path is not silently ignored.
    """
    import sys

    ap = argparse.ArgumentParser(
        description="Generate the evaluation-report figures into reports/figures/.")
    ap.add_argument("--guarded", default="eval/results/results-guarded-scored.summary.json",
                    help="scored summary with guardrails ON (required; relative to the project root)")
    ap.add_argument("--raw",     default="eval/results/results-raw-scored.summary.json",
                    help="scored summary with guardrails OFF (optional; relative to the project root)")
    args = ap.parse_args()
    g = _load(ROOT / args.guarded)
    r = _load(ROOT / args.raw)
    if g is None:
        raise SystemExit(f"missing {args.guarded} — run score.py first")
    if r is None:
        print(f"note: {args.raw} not found — refusal figure will have no guardrails-OFF data",
              file=sys.stderr)
    paths = [fig_scores(g), fig_latency_cost(g), fig_refusal_matrix(g, r)]
    for p in paths:
        print(f"wrote {p.relative_to(ROOT)}")


if __name__ == "__main__":
    # Run the CLI; main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())