"""Interactive eval: run test samples through the local GGUF model. Requires llama-server running on port 8080: llama-server -m finetune/models/.gguf -ngl 99 --port 8080 --ctx-size 4096 --log-disable Uses the /v1/chat/completions endpoint with a messages list. The Qwen3 GGUF embeds its chat template in metadata, so llama-server applies it automatically. Usage ----- uv run finetune/eval_cli.py # prompts for index uv run finetune/eval_cli.py 5 # run sample at index 5 uv run finetune/eval_cli.py 5 12 20 # run multiple samples Use --task places for place extraction: uv run finetune/eval_cli.py --task places 0 5 Override run directory: uv run finetune/eval_cli.py --run-dir dataset/output/runs/v1 0 """ from __future__ import annotations import argparse import json import sys import urllib.error import urllib.request from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from pathlib import Path SERVER_URL = "http://localhost:9000" MAX_TOKENS = 2048 TEMPERATURE = 0.6 DEFAULT_RUN_DIR = Path("dataset/output/runs/v3") def postprocess_sql(text: str) -> str: cleaned = text.strip() if "```sql" in cleaned: cleaned = cleaned.split("```sql", 1)[1] if cleaned.startswith("```"): cleaned = cleaned[3:] if "```" in cleaned: cleaned = cleaned.split("```", 1)[0] return cleaned.strip() def check_server() -> bool: try: urllib.request.urlopen(f"{SERVER_URL}/health", timeout=2) return True except Exception: return False def chat_complete(messages: list[dict]) -> str: """Call llama-server /v1/chat/completions with a messages list.""" payload = json.dumps({ "messages": messages, "n_predict": MAX_TOKENS, "temperature": TEMPERATURE, "chat_template_kwargs": {"enable_thinking": False}, }).encode() req = urllib.request.Request( f"{SERVER_URL}/v1/chat/completions", data=payload, headers={"Content-Type": "application/json"}, ) with urllib.request.urlopen(req, timeout=60) as resp: return json.loads(resp.read())["choices"][0]["message"]["content"] def load_samples(run_dir: Path, task: str, split: str = "val") -> list[dict]: path = run_dir / task / f"{split}.jsonl" if not path.exists(): print(f"Error: {path} not found") sys.exit(1) print(f"Loading {task} samples from: {path}") with path.open() as f: return [json.loads(line) for line in f if line.strip()] def build_raw_prompt(sample: dict) -> str: """Reconstruct the plain prompt string from messages format (all turns except assistant).""" return "\n\n".join(m["content"] for m in sample["messages"][:-1]) def eval_sample(sample: dict, task: str) -> dict: """Run a single sample through the server and return a result dict.""" expected = sample["messages"][-1]["content"] messages = sample["messages"][:-1] user_content = sample["messages"][-2]["content"] if "" in user_content: question = user_content.split("")[-1].split("")[0].strip() else: question = user_content[:120] raw = chat_complete(messages) predicted = postprocess_sql(raw) if task == "sql" else raw.strip() return { "question": question, "expected": expected, "predicted": predicted, "exact_match": predicted.strip() == expected.strip(), } def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None: user_content = sample["messages"][-2]["content"] if "" in user_content: question = user_content.split("")[-1].split("")[0].strip() else: question = user_content[:120] header = f" Sample {index}/{total-1} | {task} " print(f"\n{'━' * 60}") print(f"{'━' * ((60 - len(header)) // 2)}{header}{'━' * ((60 - len(header)) // 2)}") print(f"{'━' * 60}") print(f"\nQuestion: {question}\n") if verbose: prompt = build_raw_prompt(sample) print(f"{'─' * 60}") print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())} words):") print(f"{'─' * 60}") print(prompt) result = eval_sample(sample, task) print(f"{'─' * 60}") print("Expected:") print(f"{'─' * 60}") print(result["expected"]) print(f"\n{'─' * 60}") print("Generated:") print(f"{'─' * 60}") print(result["predicted"]) print(f"\n{'─' * 60}") print(f"Match: {'YES' if result['exact_match'] else 'NO'}") def run_batch( samples: list[dict], task: str, label: str, output_path: Path, workers: int = 8, ) -> None: """Run all samples concurrently and save results to a JSON file.""" total = len(samples) results = [None] * total completed = 0 with ThreadPoolExecutor(max_workers=workers) as executor: futures = {executor.submit(eval_sample, s, task): i for i, s in enumerate(samples)} for future in as_completed(futures): i = futures[future] result = future.result() results[i] = {"index": i, **result} completed += 1 if completed % 50 == 0 or completed == total: print(f"{completed}/{total} done", flush=True) matches = sum(1 for r in results if r["exact_match"]) exact_match_rate = matches / total if total else 0 output = { "summary": { "label": label, "task": task, "num_samples": total, "exact_matches": matches, "exact_match_rate": exact_match_rate, "timestamp": datetime.now().isoformat(), }, "results": results, } output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w") as f: json.dump(output, f, indent=2) print(f"\n{'=' * 60}") print(f"[{label}] {matches}/{total} exact matches ({100 * exact_match_rate:.1f}%)") print(f"Results saved to {output_path}") print(f"{'=' * 60}") def main() -> None: parser = argparse.ArgumentParser(description="Interactive eval against llama-server") parser.add_argument("indices", nargs="*", type=int, help="Sample indices to evaluate") parser.add_argument("--task", default="sql", choices=["sql", "places"]) parser.add_argument( "--run-dir", type=Path, default=DEFAULT_RUN_DIR, help="Run directory containing {task}/{split}.jsonl files", ) parser.add_argument("--split", default="val", choices=["val", "test"], help="Dataset split") parser.add_argument("--verbose", "-v", action="store_true", help="Print full prompt sent to the model") parser.add_argument("--all", dest="run_all", action="store_true", help="Run all samples in batch mode") parser.add_argument("--max-samples", type=int, default=None, help="Limit number of samples (batch mode)") parser.add_argument("--label", default="local-gguf", help="Label for batch output file") parser.add_argument("--output", type=Path, default=None, help="Output JSON path (batch mode)") parser.add_argument("--workers", type=int, default=4, help="Concurrent requests; match llama-server --parallel (default 4)") args = parser.parse_args() if not check_server(): print("llama-server not running. Start it with:") print("llama-server -m finetune/models/.gguf -ngl 99 --port 9000 --ctx-size 2048 --log-disable") sys.exit(1) samples = load_samples(args.run_dir, args.task, args.split) total = len(samples) if args.run_all: if args.max_samples: samples = samples[: args.max_samples] output_path = args.output or Path(f"eval-{args.label}-{args.task}.json") print(f"Running batch eval: {len(samples)} samples, {args.workers} workers") run_batch(samples, args.task, args.label, output_path, workers=args.workers) return if not args.indices: print(f"Test set has {total} {args.task} samples (0-{total-1})") raw = input("Enter index (or press Enter for 0): ").strip() indices = [int(raw) if raw else 0] else: indices = args.indices for idx in indices: if not (0 <= idx < total): print(f"Index {idx} out of range (0-{total-1}), skipping") continue run_sample(samples[idx], args.task, total, idx, verbose=args.verbose) print(f"\n{'━' * 60}\n") if __name__ == "__main__": main()