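# Source path: promptstat/eval/cascade.py (5.45 kB)
# Commit dc9f530 (verified), shown with user "xxixx1028":
#   Deploy PromptStat — UI shell + MiniCPM4.1-8B + 4-LoRA hybrid (Modal)
# NOTE: the lines above the module docstring were file-viewer page chrome
# (Raw / History / Blame / Contribute / Delete), kept here only as comments.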
"""Cascade runner: measure base-8B per axis under different prompt/inference strategies, compare to
baseline, apply the decision gate. Reuses eval.kappa scoring; the 8B cache makes unchanged prompts free.
Usage:
python -m eval.cascade baseline v2 # measure these versions, print before/after κ table
Versions are registered in VERSIONS (a builders module per key). Results persist to
eval/_cache/cascade_results.json so later steps accumulate.
"""
from __future__ import annotations
import importlib
import json
import os
import sys
import time
from eval import kappa as K
from prompt_card.scoring import observable_axes as OA
# JSON file in the `_cache` directory beside this module; `main` loads prior results
# from it and writes the merged results back.
RESULTS = os.path.join(os.path.dirname(__file__), "_cache", "cascade_results.json")
# Version key -> builders source: either an already-imported module object (OA) or a
# dotted module path string that `_builders` imports on demand.
VERSIONS = {"baseline": OA, "v2": "eval.prompts_v2", "v3": "eval.prompts_v3", "v4": "eval.prompts_v4"}
def _builders(v):
    """Resolve version key *v* to its builders module.

    Entries of ``VERSIONS`` stored as strings are treated as dotted module paths and
    imported; any other entry is returned as-is. An unknown key raises ``KeyError``.
    """
    entry = VERSIONS[v]
    if isinstance(entry, str):
        return importlib.import_module(entry)
    return entry
def measure(builders, gt, convs, embedder, client):
    """Return per-axis headline κ + detail. Cache-served calls are free.

    Builds prompts with ``K.build_prompts`` (using *builders*), runs them through
    ``client.run_all`` and scores four axes. The returned dict contains:

    * ``"_new_calls"``: how much ``client.misses`` grew during ``run_all``.
    * ``"technique"`` / ``"input_quality"``: ``headline`` (mean κ over categories whose
      first label vector has at least one positive and whose κ is defined, else
      ``None``), ``axis_any`` (κ of the "any category set" vectors), ``cats``
      (per category ``(κ, binary counts, positives)``) and ``parse_fail``.
    * ``"interaction"``: ``headline`` κ, ``counts``, ``pos``, ``n`` and ``parse_fail``.
    * ``"focus"``: ``headline`` κ, threshold ``T`` and ``recall`` of the best threshold
      found in the sweep.
    """
    prompts, plan, geom = K.build_prompts(gt, convs, embedder, builders=builders)

    # Group plan items by their first element; the grouped mapping is what the
    # K.score_* helpers receive.
    pi = {}
    for item in plan:
        pi.setdefault(item[0], []).append(item)

    misses_before = client.misses
    responses = client.run_all(prompts)
    out = {"_new_calls": client.misses - misses_before}

    # technique / input_quality (per-category + axis-level "any")
    for axis, fields in (("technique", K.TECH), ("input_quality", K.IQ)):
        per, fail = K.score_binary_axis(gt, responses, pi, axis, fields)
        cats = {f: K.cohen_kappa(*per[f]) for f in fields}
        # Collapse the per-category vectors into one "any category set" flag per sample.
        # Assumes every category's vectors are aligned sample-by-sample.
        any_true = [int(any(row)) for row in zip(*(per[f][0] for f in fields))]
        any_pred = [int(any(row)) for row in zip(*(per[f][1] for f in fields))]
        # Only categories with positives count toward the headline; a κ of None
        # (undefined) is skipped too, otherwise sum() would raise TypeError.
        feat_ks = [cats[f] for f in fields if sum(per[f][0]) > 0 and cats[f] is not None]
        headline = sum(feat_ks) / len(feat_ks) if feat_ks else None
        out[axis] = {
            "headline": headline,
            "axis_any": K.cohen_kappa(any_true, any_pred),
            "cats": {f: (cats[f], K.binary_counts(*per[f]), sum(per[f][0])) for f in fields},
            "parse_fail": fail,
        }

    # interaction
    yt, yp, fail = K.score_interaction(gt, responses, pi)
    out["interaction"] = {
        "headline": K.cohen_kappa(yt, yp),
        "counts": K.binary_counts(yt, yp),
        "pos": sum(yt),
        "n": len(yt),
        "parse_fail": fail,
    }

    # focus (sweep best T): rank by κ first, recall second; an undefined κ sorts below
    # any real value via the -9 sentinel, and ties keep the earliest threshold.
    best = None
    for T in (0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70):
        fyt, fyp, info = K.score_focus(gt, responses, pi, geom, T)
        k = K.cohen_kappa(fyt, fyp)
        rank = (k if k is not None else -9, info["recall"] or 0)
        if best is None or rank > best[0]:
            best = (rank, T, k, info["recall"])
    out["focus"] = {"headline": best[2], "T": best[1], "recall": best[3]}
    return out
def gate(k):
    """Map a κ value to its decision-gate label.

    ``None`` gives ``"N/A"``; otherwise the first threshold that *k* reaches decides:
    0.6 -> ``"SOLID"``, 0.4 -> ``"OK"``, 0.2 -> ``"try-next"``, and anything lower
    gives ``"must-next"``.
    """
    if k is None:
        return "N/A"
    # Thresholds are checked from highest to lowest, so the first match wins.
    for floor, label in ((0.6, "SOLID"), (0.4, "OK"), (0.2, "try-next")):
        if k >= floor:
            return label
    return "must-next"
def _fmt_kappa(k):
    """Format a κ value as a signed 3-decimal string, or "N/A" when it is None."""
    return f"{k:+.3f}" if k is not None else "N/A"


def main(version_keys):
    """Measure each version in *version_keys*, persist the results and print a κ table.

    Requires the ``OPENBMB_BASE_URL`` and ``OPENBMB_TOKEN`` environment variables;
    when either is missing an error is printed to stderr and the process exits with
    status 2. Results from earlier runs stored in ``RESULTS`` are loaded first and
    entries for the measured versions are added or overwritten.

    After the comparison table (one row per axis, one column per version, plus the
    gate of the last version) a per-category breakdown of the last version is printed.
    """
    base_url = os.environ.get("OPENBMB_BASE_URL")
    token = os.environ.get("OPENBMB_TOKEN")
    if not base_url or not token:
        print("ERROR: set OPENBMB_BASE_URL and OPENBMB_TOKEN", file=sys.stderr)
        sys.exit(2)
    # Imported only after the credential check, so a misconfigured run exits first.
    from prompt_card.llm.minicpm import MiniCPMClient

    gt = K.load_gt()
    convs = K.load_convs()
    embedder = K.FastEmbedder()
    client = K.CachedClient(MiniCPMClient(base_url, token), workers=8)

    results = {}
    if os.path.exists(RESULTS):
        with open(RESULTS, encoding="utf-8") as fh:
            results = dict(json.load(fh))

    for v in version_keys:
        t0 = time.time()
        print(f"\n=== measuring '{v}' ===", flush=True)
        r = measure(_builders(v), gt, convs, embedder, client)
        r["_secs"] = round(time.time() - t0, 1)
        results[v] = r
        print(f" new 8B calls: {r['_new_calls']} · {r['_secs']}s", flush=True)
        # Persist after every version so a failure in a later version keeps the
        # measurements already taken; the file handle is closed deterministically.
        os.makedirs(os.path.dirname(RESULTS), exist_ok=True)
        with open(RESULTS, "w", encoding="utf-8") as fh:
            json.dump(results, fh, indent=1, default=str)

    axes = ["technique", "input_quality", "interaction", "focus"]
    print("\n================ κ comparison (headline per axis) ================")
    head = "axis".ljust(16) + "".join(v.ljust(12) for v in version_keys) + "gate(last)"
    print(head)
    for ax in axes:
        row = ax.ljust(16)
        last_k = None
        for v in version_keys:
            last_k = results[v][ax]["headline"]
            row += _fmt_kappa(last_k).ljust(12)
        print(row + gate(last_k))

    if not version_keys:
        # Nothing was measured, so there is no "last version" to detail.
        return

    # per-category technique/IQ detail for the last version
    last = version_keys[-1]
    print(f"\n--- per-category detail ({last}) ---")
    for ax in ("technique", "input_quality"):
        for f, (k, c, npos) in results[last][ax]["cats"].items():
            ks = _fmt_kappa(k)
            print(f" {ax[:4]}.{f:22} κ={ks} pos={npos} [TN {c['tn']} FP {c['fp']} FN {c['fn']} TP {c['tp']}]")

    fi = results[last]["interaction"]
    c = fi["counts"]
    # κ may be None (undefined); format it defensively instead of crashing on ":+.3f".
    print(f" interaction.refinement κ={_fmt_kappa(fi['headline'])} pos={fi['pos']}/{fi['n']} "
          f"[TN {c['tn']} FP {c['fp']} FN {c['fn']} TP {c['tp']}]")

    ff = results[last]["focus"]
    recall = f"{ff['recall']:.2f}" if ff["recall"] is not None else "N/A"
    print(f" focus.topic_shift κ={_fmt_kappa(ff['headline'])} T={ff['T']} recall={recall}")
if __name__ == "__main__":
    # With no CLI arguments, fall back to comparing the baseline against v2.
    main(sys.argv[1:] or ["baseline", "v2"])