promptstat / eval /step_d.py
xxixx1028's picture
Deploy PromptStat — UI shell + MiniCPM4.1-8B + 4-LoRA hybrid (Modal)
dc9f530 verified
Raw
History Blame Contribute Delete
7.16 kB
"""STEP D — model swap to MiniCPM-V-4.6-Thinking (:8004) for the reasoning-heavy axes (Focus, Interaction;
also reusable for Phase-3 Critical). Results are compared against the base-8B A+B best κ (focus 0.433, interaction 0.320).
The Thinking model emits a reasoning block then JSON, sometimes with a dangling `</think>` (no opening
tag), so we parse robustly: take text after the last `</think>`, then json/regex-extract the field.
Separate cache namespace (think_8004.jsonl) so identical prompt text doesn't collide with base-8B cache.
"""
from __future__ import annotations

import concurrent.futures as cf
import hashlib
import json
import os
import re
import sys
import threading
import time

import requests

from eval import kappa as K
from eval.prompts_v3 import build_interaction_prompt
from eval.prompts_v4 import build_focus_boundary_prompt
# JSONL response cache for the Thinking endpoint, stored next to this file under _cache/.
# A dedicated file keeps these entries apart from the base-8B cache.
CACHE = os.path.join(
    os.path.dirname(__file__),
    "_cache",
    "think_8004.jsonl",
)
# Model id sent in every chat-completions request body.
MODEL = "MiniCPM-V-4.6-Thinking"
# Strips a leading ```lang fence and a trailing ``` fence (whole-string anchors, no re.M).
_FENCE_RE = re.compile(r"^```[a-zA-Z0-9]*\n?|\n?```$")
# A brace-delimited object with no nested braces; used to find JSON embedded in prose.
_FLAT_OBJ_RE = re.compile(r"\{[^{}]*\}")
_BOOL_WORDS = {"true": True, "false": False}


def _coerce_scalar(v):
    """Map the strings "true"/"false" to bools so the JSON and regex paths agree.

    The caller applies ``bool(...)`` to the result, so a string "false" would
    otherwise count as a positive label.
    """
    if isinstance(v, str) and v in _BOOL_WORDS:
        return _BOOL_WORDS[v]
    return v


def parse_think(text, field):
    """Extract ``field`` from a Thinking-model reply.

    The reply may contain a reasoning block, a dangling ``</think>`` with no
    opening tag, code fences, or a JSON object embedded in prose.

    Strategy:
    1. keep only the text after the last ``</think>``, minus code fences;
    2. try ``json.loads`` on that whole text, then on each flat ``{...}``
       object found in it, last one first;
    3. fall back to a ``"field": value`` regex, taking the LAST occurrence.
       When no closing tag is present, earlier matches may come from the
       reasoning rather than the final answer.

    Args:
        text: raw message content; ``None`` is treated as empty.
        field: JSON key to extract.

    Returns:
        ``{field: value}``, or ``None`` if the field cannot be found.
        The strings "true"/"false" are returned as bools. An unquoted
        ``null`` in the regex path is returned as ``None``.
    """
    t = text or ""
    if "</think>" in t:
        t = t.rsplit("</think>", 1)[-1]
    t = _FENCE_RE.sub("", t.strip()).strip()

    # Whole remainder first, then embedded objects, latest first.
    candidates = [t, *reversed(_FLAT_OBJ_RE.findall(t))]
    for cand in candidates:
        try:
            d = json.loads(cand)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(d, dict) and field in d:
            return {field: _coerce_scalar(d[field])}

    # Regex fallback: "field": value. The closing quote is optional so a reply
    # truncated mid-value still yields the value.
    matches = re.findall(rf'"{re.escape(field)}"\s*:\s*("?[a-zA-Z_]+"?)', t)
    if not matches:
        return None
    raw = matches[-1]
    if raw == "null":
        return {field: None}
    return {field: _coerce_scalar(raw.strip('"'))}
class ThinkClient:
    """Thread-pooled, disk-cached client for the Thinking endpoint's chat-completions API.

    Responses are memoized under a SHA-1 of ``"THINK::" + prompt`` and appended
    as ``{"k": key, "v": content}`` lines to the JSONL file at ``CACHE``.
    Re-runs only issue requests for prompts that are not cached yet.
    """

    def __init__(self, base, token, workers=8):
        """Store endpoint settings and preload any existing on-disk cache.

        Args:
            base: server base URL; a trailing slash is stripped.
            token: bearer token sent with every request.
            workers: thread-pool size used by ``run_all``.
        """
        self.base = base.rstrip("/")
        self.token = token
        self.workers = workers
        self._lock = threading.Lock()  # guards self.cache and appends to CACHE
        self.cache = {}
        if os.path.exists(CACHE):
            with open(CACHE, encoding="utf-8") as f:
                for line in f:
                    try:
                        d = json.loads(line)
                        self.cache[d["k"]] = d["v"]
                    except (ValueError, KeyError, TypeError):
                        # Best-effort: skip blank, truncated or malformed lines
                        # (e.g. from an interrupted append).
                        continue

    @staticmethod
    def _key(p):
        """Cache key for prompt ``p``, namespaced so it never collides with the base-8B cache."""
        return hashlib.sha1(("THINK::" + p).encode()).hexdigest()

    def _call(self, prompt, retries=4):
        """POST one prompt and return the first choice's message content.

        Any ``requests`` exception (connection error, timeout, HTTP error from
        ``raise_for_status``) is retried with linear backoff (2s, 4s, ...).
        No sleep follows the final attempt.

        Args:
            prompt: user message text.
            retries: total number of attempts; must be >= 1.

        Raises:
            ValueError: if ``retries < 1``.
            requests.exceptions.RequestException: the last error once all attempts fail.
        """
        if retries < 1:
            raise ValueError("retries must be >= 1")
        body = {"model": MODEL, "messages": [{"role": "user", "content": prompt}], "temperature": 0,
                "max_tokens": 1536, "chat_template_kwargs": {"enable_thinking": True}}
        headers = {"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"}
        last = None
        for attempt in range(retries):
            try:
                r = requests.post(f"{self.base}/v1/chat/completions",
                                  headers=headers, json=body, timeout=240)
                r.raise_for_status()
                return r.json()["choices"][0]["message"]["content"]
            except requests.exceptions.RequestException as e:
                last = e  # transient endpoint hiccup — back off and retry
                if attempt + 1 < retries:
                    time.sleep(2 * (attempt + 1))
        raise last

    def run_all(self, prompts):
        """Resolve every prompt, calling the endpoint only for uncached ones.

        Duplicate prompts are sent once. Each fresh response is appended to
        the cache file as soon as it arrives, so an interrupted run keeps its
        progress.

        Returns:
            dict mapping each distinct prompt to its (possibly cached) content.
        """
        uniq = list(dict.fromkeys(prompts))
        todo = [p for p in uniq if self._key(p) not in self.cache]

        def work(p):
            v = self._call(p)
            k = self._key(p)
            line = json.dumps({"k": k, "v": v}, ensure_ascii=False) + "\n"
            with self._lock:
                self.cache[k] = v
                with open(CACHE, "a", encoding="utf-8") as f:
                    f.write(line)
            return p

        if todo:
            # The _cache directory may not exist on a fresh checkout.
            os.makedirs(os.path.dirname(CACHE), exist_ok=True)
            with cf.ThreadPoolExecutor(max_workers=self.workers) as ex:
                for i, _ in enumerate(ex.map(work, todo), 1):
                    if i % 50 == 0:
                        print(f" ... {i}/{len(todo)} thinking calls", flush=True)
        return {p: self.cache[self._key(p)] for p in uniq}
# Embedder-cosine thresholds swept for the focus axis. The embedder gate that
# decides which boundaries get a model call uses max() of these, so every T
# in the sweep sees model labels for all boundaries below it.
_FOCUS_THRESHOLDS = (0.40, 0.45, 0.50, 0.55, 0.60, 0.65, 0.70)


def _fmt(x, spec):
    """Format an optional metric, rendering ``None`` as "n/a" instead of raising TypeError."""
    return "n/a" if x is None else format(x, spec)


def _interaction_tasks(gt, convs):
    """Build one interaction prompt and one gold 0/1 refinement label per GT row.

    A row's ``turn`` ("U<n>") is paired with the user turn before it.
    Rows with no previous turn (U1 or lower) or that point past the
    conversation are skipped and counted. Previously U1 silently wrapped
    to the conversation's last turn via index -1.

    Returns:
        (prompts, labels, skipped_count)
    """
    prompts, labels, skipped = [], [], 0
    for r in gt:
        ut = K.user_turns(convs[r["id"]])
        for row in r["interaction"]:
            i = int(row["turn"][1:]) - 1
            if not 1 <= i < len(ut):
                skipped += 1
                continue
            prompts.append(build_interaction_prompt(ut[i - 1], ut[i]))
            labels.append(int(bool(row["refinement"])))
    return prompts, labels, skipped


def _focus_tasks(gt, convs, embedder, gate):
    """Score consecutive user-turn boundaries with the embedder and queue low-cosine ones.

    ``cos`` is the raw dot product of the two turn embeddings. This is a true
    cosine only if ``embedder.embed`` returns unit-normalized vectors
    (assumed — confirm against K.FastEmbedder).

    Returns:
        (tasks, geom): ``tasks`` is a list of (conv_id, "U<i>->U<i+1>", prompt)
        for boundaries with ``cos < gate``. ``geom`` maps conv_id to
        [(boundary_index, cos), ...] for every boundary.
    """
    import numpy as np

    tasks, geom = [], {}
    for r in gt:
        ut = K.user_turns(convs[r["id"]])
        if len(ut) < 2:
            continue
        vecs = embedder.embed(ut)
        g = []
        for i in range(len(ut) - 1):
            cos = float(np.dot(vecs[i], vecs[i + 1]))
            g.append((i, cos))
            if cos < gate:
                tasks.append((r["id"], f"U{i+1}->U{i+2}", build_focus_boundary_prompt(ut[i], ut[i + 1])))
        geom[r["id"]] = g
    return tasks, geom


def _sweep_focus(gt, geom, rel, thresholds):
    """Pick the threshold that maximizes (kappa, gate recall) for topic-shift detection.

    A boundary is predicted ``topic_shift`` iff ``cos < T`` AND the model
    labelled it ``topic_shift``. Gate recall is the fraction of gold
    topic shifts with ``cos < T``. An undefined kappa (``None``) ranks below
    any real value. On ties the earliest threshold wins.

    Returns:
        (best_T, kappa, recall); kappa and recall may be ``None``.
    """
    # Gold relations do not depend on T; build them once.
    gold = [(r["id"], {f"{c['a']}->{c['b']}": c["relation"] for c in r["focus"]}) for r in gt]
    best = None
    for T in thresholds:
        yt, yp = [], []
        hit = tot = 0
        for cid, gtrel in gold:
            for i, cos in geom.get(cid, []):
                bk = f"U{i+1}->U{i+2}"
                isgt = gtrel.get(bk) == "topic_shift"
                yt.append(int(isgt))
                yp.append(int(cos < T and rel.get((cid, bk)) == "topic_shift"))
                if isgt:
                    tot += 1
                    hit += int(cos < T)
        k = K.cohen_kappa(yt, yp)
        rec = hit / tot if tot else None
        cand = (k if k is not None else -9, rec or 0)
        if best is None or cand > best[0]:
            best = (cand, T, k, rec)
    return best[1], best[2], best[3]


def main():
    """Run STEP D: score the Thinking model on the interaction and focus axes.

    Credentials come from the OPENBMB_BASE_URL / OPENBMB_TOKEN environment
    variables; the process exits with status 2 if either is missing.
    Prints a comparison against the base-8B bests and writes the metrics
    to ``_cache/step_d.json`` next to this file.
    """
    base = os.environ.get("OPENBMB_BASE_URL")
    token = os.environ.get("OPENBMB_TOKEN")
    if not base or not token:
        print("ERROR: creds", file=sys.stderr)
        sys.exit(2)
    # NOTE(review): plain substring swap from the base-8B port to the Thinking
    # port; assumes "8001" appears in the URL only as that port — confirm.
    base = base.replace("8001", "8004")
    gt = K.load_gt()
    convs = K.load_convs()
    embedder = K.FastEmbedder()
    client = ThinkClient(base, token, workers=8)

    int_prompts, int_rows, int_skipped = _interaction_tasks(gt, convs)
    if int_skipped:
        print(f"[step_d] WARNING: skipped {int_skipped} interaction rows with no previous user turn",
              file=sys.stderr)
    foc_tasks, geom = _focus_tasks(gt, convs, embedder, gate=max(_FOCUS_THRESHOLDS))

    print(f"[step_d] thinking calls: interaction {len(int_prompts)} + focus {len(foc_tasks)}", flush=True)
    resp = client.run_all(int_prompts + [t[2] for t in foc_tasks])

    # interaction kappa — parse each response once
    int_parsed = [parse_think(resp[p], "refinement_attempt") for p in int_prompts]
    iyp = [int(bool((d or {}).get("refinement_attempt"))) for d in int_parsed]
    ifail = sum(d is None for d in int_parsed)
    ik = K.cohen_kappa(int_rows, iyp)

    # focus: per-boundary relation predictions, then sweep T
    foc_parsed = [(cid, bk, parse_think(resp[p], "relation")) for cid, bk, p in foc_tasks]
    rel = {(cid, bk): (d or {}).get("relation") for cid, bk, d in foc_parsed}
    ffail = sum(d is None for _, _, d in foc_parsed)
    ft, fk, frec = _sweep_focus(gt, geom, rel, _FOCUS_THRESHOLDS)

    print("\n=== STEP D (MiniCPM-V-4.6-Thinking, :8004) vs base-8B best ===")
    print(f" interaction κ={_fmt(ik, '+.3f')} parse_fail={ifail}/{len(int_prompts)} (base-8B best v3: +0.320)")
    print(f" focus κ={_fmt(fk, '+.3f')} T={ft} recall={_fmt(frec, '.2f')} parse_fail={ffail}/{len(foc_tasks)} (base-8B best v4: +0.433)")

    out_dir = os.path.join(os.path.dirname(__file__), "_cache")
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "step_d.json"), "w", encoding="utf-8") as f:
        json.dump({"interaction_think": ik, "focus_think": fk, "focus_T": ft,
                   "int_parsefail": ifail, "foc_parsefail": ffail},
                  f, indent=1)


if __name__ == "__main__":
    main()