feather-runtime / overlay /scripts /chat_eval.py
Jackoatmon's picture
Update Feather H200 runtime: Nemotron streaming and HTM force-CPU canary fixes
c2bf4b6 verified
"""Non-interactive chat eval for HYDRA.
Runs a fixed set of prompts through the same chat template that `chat.py`
uses, prints a markdown table with the response and coherence heuristics.
Usage:
python scripts/chat_eval.py # auto-select checkpoint
python scripts/chat_eval.py --ckpt PATH
python scripts/chat_eval.py --random
python scripts/chat_eval.py --json out.json # also dump raw results
python scripts/chat_eval.py --max 80 # cap new tokens per prompt
"""
from __future__ import annotations
import argparse
import json
import os
import re
import sys
import time
from pathlib import Path
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
sys.path.insert(0, str(_REPO_ROOT))
import torch # noqa: E402
from scripts.chat import ( # noqa: E402
ASSISTANT_TAG, END_TAG, USER_TAG, build_prompt,
generate_stream, load_model_and_tokenizer, resolve_checkpoint,
)
PROMPTS: list[str] = [
# Factual
"What is the capital of France?",
"Who wrote Romeo and Juliet?",
"What is 2 plus 2?",
"What color is the sky on a clear day?",
# Completion
"Once upon a time",
"The cat sat on the",
"In a hole in the ground there lived",
# Instruction
"Write one short sentence about rain.",
"List three animals.",
"Define the word 'library'.",
# Conversational
"Hello, how are you?",
"Tell me a joke.",
# Creative
"Describe a sunset in one line.",
"Give me a name for a pet robot.",
"What is the meaning of friendship?",
]
# Heuristic thresholds (printed, not enforced as pass/fail).
THRESH_DISTINCT_2 = 0.30
THRESH_SENT_MIN = 5
THRESH_SENT_MAX = 30
THRESH_EN_RATIO = 0.95
# ---------------------------------------------------------------------------
# Coherence heuristics
# ---------------------------------------------------------------------------
def _tokens(text: str) -> list[str]:
return re.findall(r"[A-Za-z0-9']+", text)
def distinct_2(text: str) -> float:
toks = _tokens(text)
if len(toks) < 2:
return 0.0
bigrams = [(toks[i], toks[i + 1]) for i in range(len(toks) - 1)]
return len(set(bigrams)) / max(1, len(bigrams))
def avg_sentence_len(text: str) -> float:
sents = re.split(r"[.!?]+", text)
lens = [len(_tokens(s)) for s in sents if _tokens(s)]
if not lens:
return 0.0
return sum(lens) / len(lens)
def english_char_ratio(text: str) -> float:
if not text:
return 0.0
allowed = 0
for c in text:
if c.isalnum() or c.isspace() or c in ".,!?;:'\"-()[]{}/\\*#@&%+=_<>|$":
allowed += 1
return allowed / len(text)
# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------
def _run_one(model, tokenizer, prompt: str, *, max_new_tokens: int, device: torch.device,
max_seq_len: int, temperature: float, top_k: int, top_p: float,
repetition_penalty: float) -> str:
prompt_text = build_prompt(system="", history=[], user_msg=prompt)
prompt_ids = tokenizer.encode(prompt_text)
stream = generate_stream(
model, tokenizer, prompt_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=repetition_penalty,
stop_strings=(END_TAG,),
max_seq_len=max_seq_len,
device=device,
)
collected: list[str] = []
try:
while True:
collected.append(next(stream))
except StopIteration as stop:
if stop.value is not None:
text = stop.value
else:
text = "".join(collected)
if END_TAG in text:
text = text.split(END_TAG, 1)[0]
return text.strip()
def _render_markdown(rows: list[dict]) -> str:
lines = [
"| # | Prompt | Response | dist-2 | sent_len | en_ratio | flags |",
"|---|--------|----------|--------|----------|----------|-------|",
]
def _cell(s: str, n: int = 60) -> str:
s = s.replace("|", "\\|").replace("\n", " ")
if len(s) > n:
s = s[: n - 1] + "…"
return s
for i, r in enumerate(rows, 1):
flags = []
if r["distinct_2"] < THRESH_DISTINCT_2:
flags.append("repetitive")
if not (THRESH_SENT_MIN <= r["avg_sentence_len"] <= THRESH_SENT_MAX):
flags.append("sent_len")
if r["en_ratio"] < THRESH_EN_RATIO:
flags.append("non_en")
flag_str = ",".join(flags) or "ok"
lines.append(
f"| {i} | {_cell(r['prompt'], 40)} | {_cell(r['response'], 60)} | "
f"{r['distinct_2']:.2f} | {r['avg_sentence_len']:.1f} | "
f"{r['en_ratio']:.2f} | {flag_str} |"
)
return "\n".join(lines)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
p = argparse.ArgumentParser(description="HYDRA chat eval")
p.add_argument("--ckpt", type=str, default=None, help="Checkpoint path.")
p.add_argument("--sft", action="store_true", help="Prefer SFT checkpoint.")
p.add_argument("--random", action="store_true", help="Use random weights.")
p.add_argument("--max", dest="max_new_tokens", type=int, default=80)
p.add_argument("--temp", dest="temperature", type=float, default=0.8)
p.add_argument("--topk", dest="top_k", type=int, default=40)
p.add_argument("--topp", dest="top_p", type=float, default=0.9)
p.add_argument("--rep", dest="repetition_penalty", type=float, default=1.1)
p.add_argument("--json", dest="json_out", type=str, default=None,
help="Optional: dump raw results to this JSON path.")
p.add_argument("--device", type=str, default=None)
return p.parse_args(argv)
def main(argv: list[str] | None = None) -> int:
args = _parse_args(argv)
if args.device:
device = torch.device(args.device)
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
ckpt_path = None if args.random else resolve_checkpoint(args.ckpt, args.sft)
t0 = time.time()
model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
dt_load = time.time() - t0
print(f"[chat_eval] Loaded in {dt_load:.1f}s ckpt={meta['ckpt']}")
from prepare import MAX_SEQ_LEN
rows: list[dict] = []
t_gen = time.time()
for i, prompt in enumerate(PROMPTS, 1):
t_start = time.time()
try:
resp = _run_one(
model, tokenizer, prompt,
max_new_tokens=args.max_new_tokens,
device=device,
max_seq_len=MAX_SEQ_LEN,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
repetition_penalty=args.repetition_penalty,
)
err = None
except Exception as e: # noqa: BLE001 — eval must not abort mid-prompt.
resp = ""
err = repr(e)
print(f"[chat_eval] prompt {i} failed: {err}", file=sys.stderr)
rows.append({
"prompt": prompt,
"response": resp,
"distinct_2": distinct_2(resp),
"avg_sentence_len": avg_sentence_len(resp),
"en_ratio": english_char_ratio(resp),
"latency_s": round(time.time() - t_start, 2),
"error": err,
})
print(f"[chat_eval] {i:2d}/{len(PROMPTS)} {rows[-1]['latency_s']:.1f}s {resp!r}")
dt_gen = time.time() - t_gen
print()
print("## HYDRA chat_eval results")
print(f"- checkpoint: `{meta['ckpt']}`")
if meta.get("step") is not None:
print(f"- step: {meta['step']}")
if meta.get("val_bpb") is not None:
print(f"- val_bpb: {meta['val_bpb']}")
print(f"- prompts: {len(PROMPTS)}")
print(f"- load: {dt_load:.1f}s generation: {dt_gen:.1f}s")
print()
print(_render_markdown(rows))
print()
# Summary heuristics
any_empty = sum(1 for r in rows if not r["response"])
any_error = sum(1 for r in rows if r["error"])
mean_d2 = sum(r["distinct_2"] for r in rows) / max(1, len(rows))
mean_en = sum(r["en_ratio"] for r in rows) / max(1, len(rows))
print("### Aggregates")
print(f"- empty responses: {any_empty}/{len(rows)}")
print(f"- generation errors: {any_error}/{len(rows)}")
print(f"- mean distinct-2: {mean_d2:.3f} (target > {THRESH_DISTINCT_2})")
print(f"- mean en_ratio: {mean_en:.3f} (target > {THRESH_EN_RATIO})")
print()
print("_Quality at this model scale (~7.5M params) is NOT expected to meet thresholds; "
"this eval verifies the chat interface, not dialogue coherence._")
if args.json_out:
out = {
"meta": meta,
"settings": {
"max_new_tokens": args.max_new_tokens,
"temperature": args.temperature,
"top_k": args.top_k,
"top_p": args.top_p,
"repetition_penalty": args.repetition_penalty,
},
"rows": rows,
"aggregates": {
"empty": any_empty,
"errors": any_error,
"mean_distinct_2": mean_d2,
"mean_en_ratio": mean_en,
"load_s": dt_load,
"gen_s": dt_gen,
},
}
Path(args.json_out).write_text(json.dumps(out, indent=2))
print(f"[chat_eval] JSON written to {args.json_out}")
# Exit 0 if we loaded and generated *something* for each prompt (even if
# quality was poor). Exit 1 only on load failure (caught by main's exception
# propagation) or if ALL prompts returned empty strings — that signals a
# broken generation loop, not poor quality.
if any_empty == len(rows):
print("[chat_eval] ALL prompts returned empty — generation loop is broken.", file=sys.stderr)
return 1
return 0
if __name__ == "__main__":
sys.exit(main())