| """Non-interactive chat eval for HYDRA. | |
| Runs a fixed set of prompts through the same chat template that `chat.py` | |
| uses, prints a markdown table with the response and coherence heuristics. | |
| Usage: | |
| python scripts/chat_eval.py # auto-select checkpoint | |
| python scripts/chat_eval.py --ckpt PATH | |
| python scripts/chat_eval.py --random | |
| python scripts/chat_eval.py --json out.json # also dump raw results | |
| python scripts/chat_eval.py --max 80 # cap new tokens per prompt | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from pathlib import Path | |
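
# Make the repo root importable so `python scripts/chat_eval.py` works from
# any working directory; the `scripts.chat` import below relies on it.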
_REPO_ROOT = Path(__file__).resolve().parent.parent
if str(_REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(_REPO_ROOT))

import torch  # noqa: E402

from scripts.chat import (  # noqa: E402
    END_TAG, build_prompt, generate_stream,
    load_model_and_tokenizer, resolve_checkpoint,
)

PROMPTS: list[str] = [
    # Factual
    "What is the capital of France?",
    "Who wrote Romeo and Juliet?",
    "What is 2 plus 2?",
    "What color is the sky on a clear day?",
    # Completion
    "Once upon a time",
    "The cat sat on the",
    "In a hole in the ground there lived",
    # Instruction
    "Write one short sentence about rain.",
    "List three animals.",
    "Define the word 'library'.",
    # Conversational
    "Hello, how are you?",
    "Tell me a joke.",
    # Creative
    "Describe a sunset in one line.",
    "Give me a name for a pet robot.",
    "What is the meaning of friendship?",
]

# Heuristic thresholds (printed, not enforced as pass/fail).
THRESH_DISTINCT_2 = 0.30
THRESH_SENT_MIN = 5
THRESH_SENT_MAX = 30
THRESH_EN_RATIO = 0.95

# ---------------------------------------------------------------------------
# Coherence heuristics
# ---------------------------------------------------------------------------

def _tokens(text: str) -> list[str]:
    return re.findall(r"[A-Za-z0-9']+", text)


def distinct_2(text: str) -> float:
    toks = _tokens(text)
    if len(toks) < 2:
        return 0.0
    bigrams = [(toks[i], toks[i + 1]) for i in range(len(toks) - 1)]
    return len(set(bigrams)) / max(1, len(bigrams))
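

# Worked example: "the cat the cat" tokenizes to [the, cat, the, cat], giving
# bigrams (the, cat), (cat, the), (the, cat): 2 unique out of 3, i.e. ~0.67.
# Low values mean heavy bigram repetition and trigger the "repetitive" flag.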


def avg_sentence_len(text: str) -> float:
    sents = re.split(r"[.!?]+", text)
    lens = [len(toks) for toks in (_tokens(s) for s in sents) if toks]
    if not lens:
        return 0.0
    return sum(lens) / len(lens)
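

# Worked example: "Hello there. Hi." splits into ["Hello there", " Hi", ""];
# token counts [2, 1] average to 1.5, which is below THRESH_SENT_MIN and
# would raise the "sent_len" flag.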


def english_char_ratio(text: str) -> float:
    if not text:
        return 0.0
    punct = ".,!?;:'\"-()[]{}/\\*#@&%+=_<>|$"
    allowed = 0
    for c in text:
        # Bare isalnum() accepts any Unicode letter, so non-English output
        # would never trip the non_en flag; restrict to ASCII alphanumerics.
        if (c.isascii() and c.isalnum()) or c.isspace() or c in punct:
            allowed += 1
    return allowed / len(text)
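

# Worked example: "héllo" has 4 ASCII-alphanumeric chars out of 5, ratio 0.8;
# that is below THRESH_EN_RATIO (0.95), so it would be flagged "non_en".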


# ---------------------------------------------------------------------------
# Runner
# ---------------------------------------------------------------------------

def _run_one(model, tokenizer, prompt: str, *, max_new_tokens: int, device: torch.device,
             max_seq_len: int, temperature: float, top_k: int, top_p: float,
             repetition_penalty: float) -> str:
    prompt_text = build_prompt(system="", history=[], user_msg=prompt)
    prompt_ids = tokenizer.encode(prompt_text)
    stream = generate_stream(
        model, tokenizer, prompt_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        stop_strings=(END_TAG,),
        max_seq_len=max_seq_len,
        device=device,
    )
    collected: list[str] = []
    try:
        while True:
            collected.append(next(stream))
    except StopIteration as stop:
        # A generator's return value arrives as StopIteration.value; prefer it
        # over the concatenated chunks when available (see the sketch below).
        if stop.value is not None:
            text = stop.value
        else:
            text = "".join(collected)
    if END_TAG in text:
        text = text.split(END_TAG, 1)[0]
    return text.strip()
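

# The StopIteration.value pattern in _run_one is standard Python: a
# generator's `return x` statement surfaces as StopIteration(x). A minimal
# sketch, assuming nothing about generate_stream beyond being a generator:
#
#     def _toy_stream():
#         yield "Par"
#         yield "is"
#         return "Paris"  # delivered as StopIteration("Paris").value
#
# Draining it with next() yields "Par" and "is", then a .value of "Paris",
# which are exactly the two reconstruction paths _run_one reconciles.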


def _render_markdown(rows: list[dict]) -> str:
    lines = [
        "| # | Prompt | Response | dist-2 | sent_len | en_ratio | flags |",
        "|---|--------|----------|--------|----------|----------|-------|",
    ]

    def _cell(s: str, n: int = 60) -> str:
        s = s.replace("|", "\\|").replace("\n", " ")
        if len(s) > n:
            s = s[: n - 1] + "…"
        return s

    for i, r in enumerate(rows, 1):
        flags = []
        if r["distinct_2"] < THRESH_DISTINCT_2:
            flags.append("repetitive")
        if not (THRESH_SENT_MIN <= r["avg_sentence_len"] <= THRESH_SENT_MAX):
            flags.append("sent_len")
        if r["en_ratio"] < THRESH_EN_RATIO:
            flags.append("non_en")
        flag_str = ",".join(flags) or "ok"
        lines.append(
            f"| {i} | {_cell(r['prompt'], 40)} | {_cell(r['response'], 60)} | "
            f"{r['distinct_2']:.2f} | {r['avg_sentence_len']:.1f} | "
            f"{r['en_ratio']:.2f} | {flag_str} |"
        )
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="HYDRA chat eval")
    p.add_argument("--ckpt", type=str, default=None, help="Checkpoint path.")
    p.add_argument("--sft", action="store_true", help="Prefer SFT checkpoint.")
    p.add_argument("--random", action="store_true", help="Use random weights.")
    p.add_argument("--max", dest="max_new_tokens", type=int, default=80)
    p.add_argument("--temp", dest="temperature", type=float, default=0.8)
    p.add_argument("--topk", dest="top_k", type=int, default=40)
    p.add_argument("--topp", dest="top_p", type=float, default=0.9)
    p.add_argument("--rep", dest="repetition_penalty", type=float, default=1.1)
    p.add_argument("--json", dest="json_out", type=str, default=None,
                   help="Optional: dump raw results to this JSON path.")
    p.add_argument("--device", type=str, default=None)
    return p.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    args = _parse_args(argv)
    if args.device:
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    ckpt_path = None if args.random else resolve_checkpoint(args.ckpt, args.sft)
    t0 = time.time()
    model, tokenizer, meta = load_model_and_tokenizer(ckpt_path, device)
    dt_load = time.time() - t0
    print(f"[chat_eval] Loaded in {dt_load:.1f}s ckpt={meta['ckpt']}")

    from prepare import MAX_SEQ_LEN  # deferred import; only needed for generation

    rows: list[dict] = []
    t_gen = time.time()
    for i, prompt in enumerate(PROMPTS, 1):
        t_start = time.time()
        try:
            resp = _run_one(
                model, tokenizer, prompt,
                max_new_tokens=args.max_new_tokens,
                device=device,
                max_seq_len=MAX_SEQ_LEN,
                temperature=args.temperature,
                top_k=args.top_k,
                top_p=args.top_p,
                repetition_penalty=args.repetition_penalty,
            )
            err = None
        except Exception as e:  # noqa: BLE001 — eval must not abort mid-prompt.
            resp = ""
            err = repr(e)
            print(f"[chat_eval] prompt {i} failed: {err}", file=sys.stderr)
        rows.append({
            "prompt": prompt,
            "response": resp,
            "distinct_2": distinct_2(resp),
            "avg_sentence_len": avg_sentence_len(resp),
            "en_ratio": english_char_ratio(resp),
            "latency_s": round(time.time() - t_start, 2),
            "error": err,
        })
        print(f"[chat_eval] {i:2d}/{len(PROMPTS)} {rows[-1]['latency_s']:.1f}s {resp!r}")
    dt_gen = time.time() - t_gen

    print()
    print("## HYDRA chat_eval results")
    print(f"- checkpoint: `{meta['ckpt']}`")
    if meta.get("step") is not None:
        print(f"- step: {meta['step']}")
    if meta.get("val_bpb") is not None:
        print(f"- val_bpb: {meta['val_bpb']}")
    print(f"- prompts: {len(PROMPTS)}")
    print(f"- load: {dt_load:.1f}s generation: {dt_gen:.1f}s")
    print()
    print(_render_markdown(rows))
    print()

    # Summary heuristics
    n_empty = sum(1 for r in rows if not r["response"])
    n_errors = sum(1 for r in rows if r["error"])
    mean_d2 = sum(r["distinct_2"] for r in rows) / max(1, len(rows))
    mean_en = sum(r["en_ratio"] for r in rows) / max(1, len(rows))
    print("### Aggregates")
    print(f"- empty responses: {n_empty}/{len(rows)}")
    print(f"- generation errors: {n_errors}/{len(rows)}")
    print(f"- mean distinct-2: {mean_d2:.3f} (target > {THRESH_DISTINCT_2})")
    print(f"- mean en_ratio: {mean_en:.3f} (target > {THRESH_EN_RATIO})")
    print()
    print("_Quality at this model scale (~7.5M params) is NOT expected to meet thresholds; "
          "this eval verifies the chat interface, not dialogue coherence._")

    if args.json_out:
        out = {
            "meta": meta,
            "settings": {
                "max_new_tokens": args.max_new_tokens,
                "temperature": args.temperature,
                "top_k": args.top_k,
                "top_p": args.top_p,
                "repetition_penalty": args.repetition_penalty,
            },
            "rows": rows,
            "aggregates": {
                "empty": n_empty,
                "errors": n_errors,
                "mean_distinct_2": mean_d2,
                "mean_en_ratio": mean_en,
                "load_s": dt_load,
                "gen_s": dt_gen,
            },
        }
        Path(args.json_out).write_text(json.dumps(out, indent=2))
        print(f"[chat_eval] JSON written to {args.json_out}")

    # Exit 0 if we loaded and generated *something* for each prompt, even if
    # quality was poor. A load failure propagates as an uncaught exception
    # (non-zero exit); return 1 explicitly only when ALL prompts came back
    # empty, which signals a broken generation loop rather than poor quality.
    if n_empty == len(rows):
        print("[chat_eval] ALL prompts returned empty — generation loop is broken.", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())