"""Non-interactive chat eval for HYDRA. Runs a fixed set of prompts through the same chat template that `chat.py` uses, prints a markdown table with the response and coherence heuristics. Usage: python scripts/chat_eval.py # auto-select checkpoint python scripts/chat_eval.py --ckpt PATH python scripts/chat_eval.py --random python scripts/chat_eval.py --json out.json # also dump raw results python scripts/chat_eval.py --max 80 # cap new tokens per prompt """ from __future__ import annotations import argparse import json import os import re import sys import time from pathlib import Path _REPO_ROOT = Path(__file__).resolve().parent.parent if str(_REPO_ROOT) not in sys.path: sys.path.insert(0, str(_REPO_ROOT)) import torch # noqa: E402 from scripts.chat import ( # noqa: E402 ASSISTANT_TAG, END_TAG, USER_TAG, build_prompt, generate_stream, load_model_and_tokenizer, resolve_checkpoint, ) PROMPTS: list[str] = [ # Factual "What is the capital of France?", "Who wrote Romeo and Juliet?", "What is 2 plus 2?", "What color is the sky on a clear day?", # Completion "Once upon a time", "The cat sat on the", "In a hole in the ground there lived", # Instruction "Write one short sentence about rain.", "List three animals.", "Define the word 'library'.", # Conversational "Hello, how are you?", "Tell me a joke.", # Creative "Describe a sunset in one line.", "Give me a name for a pet robot.", "What is the meaning of friendship?", ] # Heuristic thresholds (printed, not enforced as pass/fail). 
THRESH_DISTINCT_2 = 0.30  # distinct-2 below this => row flagged "repetitive"
THRESH_SENT_MIN = 5       # mean sentence length below this => "sent_len"
THRESH_SENT_MAX = 30      # mean sentence length above this => "sent_len"
THRESH_EN_RATIO = 0.95    # english_char_ratio below this => "non_en"


# ---------------------------------------------------------------------------
# Coherence heuristics
# ---------------------------------------------------------------------------

# Compiled once at import time; the heuristics run for every response.
_TOKEN_RE = re.compile(r"[A-Za-z0-9']+")
_SENTENCE_SPLIT_RE = re.compile(r"[.!?]+")

# Punctuation and symbols that count as "English" in english_char_ratio.
_EN_PUNCT = frozenset(".,!?;:'\"-()[]{}/\\*#@&%+=_<>|$")


def _tokens(text: str) -> list[str]:
    """Return the runs of ASCII letters, digits and apostrophes in *text*."""
    return _TOKEN_RE.findall(text)


def distinct_2(text: str) -> float:
    """Return the fraction of adjacent token pairs in *text* that are unique.

    Returns 0.0 when *text* has fewer than two tokens (no bigram exists).
    """
    toks = _tokens(text)
    if len(toks) < 2:
        return 0.0
    bigrams = list(zip(toks, toks[1:]))
    return len(set(bigrams)) / len(bigrams)


def avg_sentence_len(text: str) -> float:
    """Return the mean number of tokens per sentence in *text*.

    Sentences are delimited by runs of ``.``, ``!`` or ``?``. Segments that
    contain no tokens are ignored; returns 0.0 when none remain.
    """
    # Tokenise each segment once, then drop the token-less ones.
    lens = [
        n
        for n in (len(_tokens(s)) for s in _SENTENCE_SPLIT_RE.split(text))
        if n
    ]
    if not lens:
        return 0.0
    return sum(lens) / len(lens)


def english_char_ratio(text: str) -> float:
    """Return the fraction of characters in *text* that look like English.

    A character counts when it is ASCII and is alphanumeric, whitespace, or
    one of the symbols in ``_EN_PUNCT``. The ASCII requirement matters
    because ``str.isalnum`` alone is also true for non-Latin letters and
    digits, which would make entirely non-English output score 1.0.

    Returns 0.0 for an empty string.
    """
    if not text:
        return 0.0
    allowed = sum(
        1
        for c in text
        if c.isascii() and (c.isalnum() or c.isspace() or c in _EN_PUNCT)
    )
    return allowed / len(text)


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

def _run_one(model, tokenizer, prompt: str, *, max_new_tokens: int,
             device: torch.device, max_seq_len: int,
             temperature: float, top_k: int, top_p: float,
             repetition_penalty: float) -> str:
    """Generate a response for a single *prompt* and return it as text.

    The prompt is wrapped with ``build_prompt`` (empty system message, no
    history), encoded with *tokenizer*, and passed to ``generate_stream``
    with ``END_TAG`` as the only stop string. The result is cut at the first
    ``END_TAG`` (if present) and stripped of surrounding whitespace.
    """
    prompt_text = build_prompt(system="", history=[], user_msg=prompt)
    prompt_ids = tokenizer.encode(prompt_text)
    stream = generate_stream(
        model, tokenizer, prompt_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_strings=(END_TAG,),
        max_seq_len=max_seq_len,
        device=device,
    )
    # Drive the generator by hand rather than with a for-loop so that its
    # return value (StopIteration.value) can be captured.
    collected: list[str] = []
    while True:
        try:
            collected.append(next(stream))
        except StopIteration as stop:
            # Prefer the generator's return value when it provides one;
            # otherwise fall back to the concatenated streamed pieces.
            text = stop.value if stop.value is not None else "".join(collected)
            break
    # Keep only what precedes the first END_TAG (the whole text if absent).
    return text.partition(END_TAG)[0].strip()


def _md_cell(s: str, n: int = 60) -> str:
    """Format *s* as a single-line markdown table cell of at most *n* chars.

    Line breaks become spaces, over-long text is cut to *n* visible
    characters ending in an ellipsis, and pipes are escaped last so that the
    escape backslashes neither count towards the width nor get split from
    the pipe they protect.
    """
    s = s.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
    if len(s) > n:
        s = s[: n - 1] + "…"
    return s.replace("|", "\\|")


def _row_flags(row: dict) -> list[str]:
    """Return the heuristic flags for one result row (empty list if none).

    Rows that recorded an error, or whose response is empty, get a single
    ``error`` / ``empty`` flag: their metrics are all zero, so the quality
    flags would say nothing useful about them.
    """
    if row.get("error"):
        return ["error"]
    if not row["response"]:
        return ["empty"]
    flags: list[str] = []
    if row["distinct_2"] < THRESH_DISTINCT_2:
        flags.append("repetitive")
    if not (THRESH_SENT_MIN <= row["avg_sentence_len"] <= THRESH_SENT_MAX):
        flags.append("sent_len")
    if row["en_ratio"] < THRESH_EN_RATIO:
        flags.append("non_en")
    return flags


def _render_markdown(rows: list[dict]) -> str:
    """Render *rows* as a markdown table, one numbered line per row.

    Each row must provide ``prompt``, ``response``, ``distinct_2``,
    ``avg_sentence_len`` and ``en_ratio``; an ``error`` key is optional.
    The flags column shows ``ok`` when no flag applies.
    """
    lines = [
        "| # | Prompt | Response | dist-2 | sent_len | en_ratio | flags |",
        "|---|--------|----------|--------|----------|----------|-------|",
    ]
    for i, r in enumerate(rows, 1):
        flag_str = ",".join(_row_flags(r)) or "ok"
        lines.append(
            f"| {i} | {_md_cell(r['prompt'], 40)} | {_md_cell(r['response'], 60)} | "
            f"{r['distinct_2']:.2f} | {r['avg_sentence_len']:.1f} | "
            f"{r['en_ratio']:.2f} | {flag_str} |"
        )
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments (``sys.argv`` when *argv* is None)."""
    p = argparse.ArgumentParser(description="HYDRA chat eval")
    p.add_argument("--ckpt", type=str, default=None, help="Checkpoint path.")
    p.add_argument("--sft", action="store_true", help="Prefer SFT checkpoint.")
    p.add_argument("--random", action="store_true", help="Use random weights.")
    p.add_argument("--max", dest="max_new_tokens", type=int, default=80,
                   help="Maximum new tokens generated per prompt.")
    p.add_argument("--temp", dest="temperature", type=float, default=0.8,
                   help="Sampling temperature.")
    p.add_argument("--topk", dest="top_k", type=int, default=40,
                   help="Top-k sampling cutoff.")
    p.add_argument("--topp", dest="top_p", type=float, default=0.9,
                   help="Top-p (nucleus) sampling cutoff.")
    p.add_argument("--rep", dest="repetition_penalty", type=float, default=1.1,
                   help="Repetition penalty.")
    p.add_argument("--json", dest="json_out", type=str, default=None,
                   help="Optional: dump raw results to this JSON path.")
    p.add_argument("--device", type=str, default=None,
                   help="Torch device; defaults to cuda if available, else cpu.")
    return p.parse_args(argv)


def _select_device(name: str | None) -> torch.device:
    """Return the requested device, else CUDA when available, else CPU."""
    if name:
        return torch.device(name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def main(argv: list[str] | None = None) -> int:
    """Run every prompt in ``PROMPTS``, print a markdown report, return exit code.

    Returns 1 only when every prompt produced an empty response, 0 otherwise.
    Failures while resolving or loading the checkpoint are not caught here
    and propagate to the caller.
    """
    args = _parse_args(argv)
    device = _select_device(args.device)

    ckpt_path = None if args.random else resolve_checkpoint(args.ckpt, args.sft)
    # perf_counter is monotonic, so measured durations cannot be distorted by
    # wall-clock adjustments during the run.
    t0 = time.perf_counter()
    model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
    dt_load = time.perf_counter() - t0
    print(f"[chat_eval] Loaded in {dt_load:.1f}s ckpt={meta['ckpt']}")

    from prepare import MAX_SEQ_LEN

    rows: list[dict] = []
    t_gen = time.perf_counter()
    for i, prompt in enumerate(PROMPTS, 1):
        t_start = time.perf_counter()
        try:
            resp = _run_one(
                model, tokenizer, prompt,
                max_new_tokens=args.max_new_tokens,
                device=device,
                max_seq_len=MAX_SEQ_LEN,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            err = None
        except Exception as e:  # noqa: BLE001 — eval must not abort mid-prompt.
            resp = ""
            err = repr(e)
            print(f"[chat_eval] prompt {i} failed: {err}", file=sys.stderr)
        rows.append({
            "prompt": prompt,
            "response": resp,
            "distinct_2": distinct_2(resp),
            "avg_sentence_len": avg_sentence_len(resp),
            "en_ratio": english_char_ratio(resp),
            "latency_s": round(time.perf_counter() - t_start, 2),
            "error": err,
        })
        print(f"[chat_eval] {i:2d}/{len(PROMPTS)} {rows[-1]['latency_s']:.1f}s {resp!r}")
    dt_gen = time.perf_counter() - t_gen

    print()
    print("## HYDRA chat_eval results")
    print(f"- checkpoint: `{meta['ckpt']}`")
    if meta.get("step") is not None:
        print(f"- step: {meta['step']}")
    if meta.get("val_bpb") is not None:
        print(f"- val_bpb: {meta['val_bpb']}")
    print(f"- prompts: {len(PROMPTS)}")
    print(f"- load: {dt_load:.1f}s generation: {dt_gen:.1f}s")
    print()
    print(_render_markdown(rows))
    print()

    # Summary heuristics. The means run over all rows, so empty or failed
    # responses contribute 0.0 to both of them.
    n_empty = sum(1 for r in rows if not r["response"])
    n_errors = sum(1 for r in rows if r["error"])
    mean_d2 = sum(r["distinct_2"] for r in rows) / max(1, len(rows))
    mean_en = sum(r["en_ratio"] for r in rows) / max(1, len(rows))
    print("### Aggregates")
    print(f"- empty responses: {n_empty}/{len(rows)}")
    print(f"- generation errors: {n_errors}/{len(rows)}")
    print(f"- mean distinct-2: {mean_d2:.3f} (target > {THRESH_DISTINCT_2})")
    print(f"- mean en_ratio: {mean_en:.3f} (target > {THRESH_EN_RATIO})")
    print()
    print("_Quality at this model scale (~7.5M params) is NOT expected to meet thresholds; "
          "this eval verifies the chat interface, not dialogue coherence._")

    if args.json_out:
        # NOTE(review): assumes every value in `meta` is JSON-serialisable;
        # json.dumps raises TypeError otherwise — confirm against
        # load_model_and_tokenizer.
        out = {
            "meta": meta,
            "settings": {
                "max_new_tokens": args.max_new_tokens,
                "temperature": args.temperature,
                "top_k": args.top_k,
                "top_p": args.top_p,
                "repetition_penalty": args.repetition_penalty,
            },
            "rows": rows,
            "aggregates": {
                "empty": n_empty,
                "errors": n_errors,
                "mean_distinct_2": mean_d2,
                "mean_en_ratio": mean_en,
                "load_s": dt_load,
                "gen_s": dt_gen,
            },
        }
        # Explicit encoding so the file does not depend on the platform default.
        Path(args.json_out).write_text(json.dumps(out, indent=2), encoding="utf-8")
        print(f"[chat_eval] JSON written to {args.json_out}")

    # Exit 0 if we loaded and generated *something* for each prompt (even if
    # quality was poor). Exit 1 only on load failure (caught by main's exception
    # propagation) or if ALL prompts returned empty strings — that signals a
    # broken generation loop, not poor quality.
    if n_empty == len(rows):
        print("[chat_eval] ALL prompts returned empty — generation loop is broken.",
              file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())