# File: gazet/finetune/eval_cli.py
# Commit e98676b (srmsoumya): chore: Run eval on v3 dataset
"""Interactive eval: run test samples through the local GGUF model.
Requires llama-server running on port 9000 (the port SERVER_URL points at):
    llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 4096 --log-disable
Uses the /v1/chat/completions endpoint with a messages list. The Qwen3 GGUF
embeds its chat template in metadata, so llama-server applies it automatically.
Usage
-----
uv run finetune/eval_cli.py # prompts for index
uv run finetune/eval_cli.py 5 # run sample at index 5
uv run finetune/eval_cli.py 5 12 20 # run multiple samples
Use --task places for place extraction:
uv run finetune/eval_cli.py --task places 0 5
Override run directory:
uv run finetune/eval_cli.py --run-dir dataset/output/runs/v1 0
"""
from __future__ import annotations
import argparse
import json
import sys
import urllib.error
import urllib.request
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
# Base URL of the local llama-server; the /health and /v1/chat/completions
# endpoints are requested relative to it.
SERVER_URL = "http://localhost:9000"
# Generation cap sent to the server as "n_predict" with every request.
MAX_TOKENS = 2048
# Sampling temperature sent with every chat-completion request.
TEMPERATURE = 0.6
# Default for --run-dir; samples are read from <run_dir>/<task>/<split>.jsonl.
DEFAULT_RUN_DIR = Path("dataset/output/runs/v3")
def postprocess_sql(text: str) -> str:
    """Strip Markdown code fences from a model response and return the bare SQL.

    Handling, in order:

    1. If the text contains a ```sql fence tag (matched case-insensitively, so
       ```SQL and ```Sql work too), everything up to and including the first
       such tag is discarded.
    2. A leading bare ``` fence is removed.
    3. Everything from the next ``` onwards is discarded.

    Args:
        text: Raw completion text returned by the model.

    Returns:
        The extracted SQL with surrounding whitespace removed. Text without any
        fence is returned unchanged apart from stripping.
    """
    import re  # local import keeps this block self-contained

    cleaned = text.strip()

    # Models do not always lowercase the language tag; a case-sensitive match
    # would leave "SQL" glued to the front of the query.
    tagged_fence = re.search(r"```sql", cleaned, flags=re.IGNORECASE)
    if tagged_fence:
        cleaned = cleaned[tagged_fence.end():]

    if cleaned.startswith("```"):
        cleaned = cleaned[3:]

    # partition() returns the whole string as the head when no fence remains.
    body, _, _ = cleaned.partition("```")
    return body.strip()
def check_server() -> bool:
    """Report whether llama-server answers on its health endpoint.

    Sends a GET to ``{SERVER_URL}/health`` with a 2-second timeout.

    Returns:
        True if the request succeeds; False on any failure (connection refused,
        timeout, HTTP error status, ...). Nothing is raised to the caller.
    """
    try:
        # The context manager closes the response (and its socket) immediately
        # instead of leaving that to garbage collection.
        with urllib.request.urlopen(f"{SERVER_URL}/health", timeout=2):
            return True
    except Exception:
        # Deliberately broad: this is a best-effort probe, and any failure
        # simply means "server not usable".
        return False
def chat_complete(messages: list[dict]) -> str:
    """POST *messages* to llama-server's /v1/chat/completions and return the reply text.

    The request carries MAX_TOKENS as ``n_predict``, TEMPERATURE as
    ``temperature``, and ``chat_template_kwargs`` with ``enable_thinking`` set
    to False. The call uses a 60-second timeout.

    Args:
        messages: Chat messages to send, passed through unchanged.

    Returns:
        The ``content`` of the first choice's message in the JSON response.
    """
    request_body = {
        "messages": messages,
        "n_predict": MAX_TOKENS,
        "temperature": TEMPERATURE,
        "chat_template_kwargs": {"enable_thinking": False},
    }
    request = urllib.request.Request(
        f"{SERVER_URL}/v1/chat/completions",
        data=json.dumps(request_body).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        reply = json.loads(response.read())
    return reply["choices"][0]["message"]["content"]
def load_samples(run_dir: Path, task: str, split: str = "val") -> list[dict]:
    """Load every sample from ``<run_dir>/<task>/<split>.jsonl``.

    Args:
        run_dir: Run directory containing one sub-directory per task.
        task: Task sub-directory name.
        split: Dataset split; selects the ``<split>.jsonl`` file.

    Returns:
        One parsed JSON object per non-blank line, in file order.

    Raises:
        SystemExit: With code 1 (after printing an error) if the file is missing.
    """
    path = run_dir / task / f"{split}.jsonl"
    if not path.exists():
        print(f"Error: {path} not found")
        sys.exit(1)
    print(f"Loading {task} samples from: {path}")
    # Read as UTF-8 explicitly so non-ASCII samples do not depend on the
    # platform's locale encoding.
    with path.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
def build_raw_prompt(sample: dict) -> str:
    """Flatten a sample's prompt turns into one plain string.

    Every message except the last one is included; their ``content`` fields
    are joined with a blank line between them.
    """
    prompt_turns = sample["messages"][:-1]
    contents = [turn["content"] for turn in prompt_turns]
    return "\n\n".join(contents)
def eval_sample(sample: dict, task: str) -> dict:
    """Send one sample to the server and compare the reply with the reference.

    The last message of ``sample["messages"]`` is the reference answer; all
    preceding messages are sent to the model. For the ``"sql"`` task the reply
    is passed through ``postprocess_sql``; otherwise it is only stripped.

    Args:
        sample: Dict with a ``"messages"`` list of ``{"content": ...}`` entries.
        task: Task name; ``"sql"`` enables SQL post-processing.

    Returns:
        Dict with ``question`` (text inside ``<USER_QUERY>`` tags of the last
        prompt message, or its first 120 characters when the tag is absent),
        ``expected``, ``predicted`` and ``exact_match``.
    """
    turns = sample["messages"]
    reference = turns[-1]["content"]
    prompt_turns = turns[:-1]
    last_prompt = turns[-2]["content"]

    open_tag, close_tag = "<USER_QUERY>", "</USER_QUERY>"
    if open_tag not in last_prompt:
        question = last_prompt[:120]
    else:
        # Text after the last opening tag, up to the first closing tag.
        after_open = last_prompt.rpartition(open_tag)[2]
        question = after_open.partition(close_tag)[0].strip()

    completion = chat_complete(prompt_turns)
    if task == "sql":
        predicted = postprocess_sql(completion)
    else:
        predicted = completion.strip()

    is_match = predicted.strip() == reference.strip()
    return {
        "question": question,
        "expected": reference,
        "predicted": predicted,
        "exact_match": is_match,
    }
def run_sample(sample: dict, task: str, total: int, index: int, verbose: bool = False) -> None:
    """Evaluate one sample and print a human-readable report to stdout.

    Prints a banner, the question, optionally the full prompt, then the
    expected answer, the generated answer and whether they match exactly.

    Args:
        sample: Dict with a ``"messages"`` list (see ``eval_sample``).
        task: Task name, shown in the banner and forwarded to ``eval_sample``.
        total: Number of samples in the loaded split (banner shows ``total-1``).
        index: Index of this sample, shown in the banner.
        verbose: When True, also print the reconstructed prompt text.
    """
    width = 60
    heavy_rule = "━" * width
    light_rule = "─" * width

    user_content = sample["messages"][-2]["content"]
    if "<USER_QUERY>" in user_content:
        question = user_content.split("<USER_QUERY>")[-1].split("</USER_QUERY>")[0].strip()
    else:
        question = user_content[:120]

    header = f" Sample {index}/{total-1} | {task} "
    print(f"\n{heavy_rule}")
    # str.center pads to exactly `width`; padding each side with
    # (width - len) // 2 came out one character short for odd-length headers.
    print(header.center(width, "━"))
    print(heavy_rule)
    print(f"\nQuestion: {question}\n")

    if verbose:
        prompt = build_raw_prompt(sample)
        print(light_rule)
        print(f"Full prompt ({len(prompt)} chars, ~{len(prompt.split())} words):")
        print(light_rule)
        print(prompt)

    result = eval_sample(sample, task)

    print(light_rule)
    print("Expected:")
    print(light_rule)
    print(result["expected"])
    print(f"\n{light_rule}")
    print("Generated:")
    print(light_rule)
    print(result["predicted"])
    print(f"\n{light_rule}")
    print(f"Match: {'YES' if result['exact_match'] else 'NO'}")
def run_batch(
    samples: list[dict],
    task: str,
    label: str,
    output_path: Path,
    workers: int = 8,
) -> None:
    """Evaluate all samples concurrently and write the results to a JSON file.

    Each sample is run through ``eval_sample`` on a thread pool. A sample whose
    evaluation raises is recorded as a non-match with an ``"error"`` field
    instead of aborting the batch, so completed work is never discarded.

    Args:
        samples: Samples to evaluate; result order follows this list.
        task: Task name forwarded to ``eval_sample``.
        label: Free-form label stored in the summary and printed at the end.
        output_path: Destination JSON file; parent directories are created.
        workers: Maximum number of concurrent requests.
    """
    total = len(samples)
    results = [None] * total
    failures = 0

    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = {executor.submit(eval_sample, s, task): i for i, s in enumerate(samples)}
        for completed, future in enumerate(as_completed(futures), start=1):
            i = futures[future]
            try:
                result = future.result()
            except Exception as exc:
                # One timed-out or malformed sample must not cost the whole
                # run: keep a placeholder row and carry on.
                failures += 1
                print(f"Sample {i} failed: {exc!r}", file=sys.stderr, flush=True)
                result = {
                    "question": None,
                    "expected": None,
                    "predicted": None,
                    "exact_match": False,
                    "error": repr(exc),
                }
            results[i] = {"index": i, **result}
            if completed % 50 == 0 or completed == total:
                print(f"{completed}/{total} done", flush=True)

    matches = sum(1 for r in results if r["exact_match"])
    exact_match_rate = matches / total if total else 0

    output = {
        "summary": {
            "label": label,
            "task": task,
            "num_samples": total,
            "exact_matches": matches,
            "exact_match_rate": exact_match_rate,
            "num_errors": failures,
            "timestamp": datetime.now().isoformat(),
        },
        "results": results,
    }

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w") as f:
        json.dump(output, f, indent=2)

    print(f"\n{'=' * 60}")
    print(f"[{label}] {matches}/{total} exact matches ({100 * exact_match_rate:.1f}%)")
    if failures:
        print(f"{failures} sample(s) failed and were counted as non-matches")
    print(f"Results saved to {output_path}")
    print(f"{'=' * 60}")
def main() -> None:
    """Command-line entry point.

    Parses arguments, verifies llama-server is reachable, loads the requested
    split, then either:

    * ``--all``: evaluates every sample (optionally capped by
      ``--max-samples``) via ``run_batch`` and writes a JSON report, or
    * otherwise: evaluates the given indices one by one via ``run_sample``,
      prompting for a single index when none were passed.

    Exits with status 1 if the server is not reachable or the interactively
    entered index is not an integer.
    """
    parser = argparse.ArgumentParser(description="Interactive eval against llama-server")
    parser.add_argument("indices", nargs="*", type=int, help="Sample indices to evaluate")
    parser.add_argument("--task", default="sql", choices=["sql", "places"])
    parser.add_argument(
        "--run-dir",
        type=Path,
        default=DEFAULT_RUN_DIR,
        help="Run directory containing {task}/{split}.jsonl files",
    )
    parser.add_argument("--split", default="val", choices=["val", "test"], help="Dataset split")
    parser.add_argument("--verbose", "-v", action="store_true", help="Print full prompt sent to the model")
    parser.add_argument("--all", dest="run_all", action="store_true", help="Run all samples in batch mode")
    parser.add_argument("--max-samples", type=int, default=None, help="Limit number of samples (batch mode)")
    parser.add_argument("--label", default="local-gguf", help="Label for batch output file")
    parser.add_argument("--output", type=Path, default=None, help="Output JSON path (batch mode)")
    parser.add_argument("--workers", type=int, default=4, help="Concurrent requests; match llama-server --parallel (default 4)")
    args = parser.parse_args()

    if not check_server():
        print("llama-server not running. Start it with:")
        print("llama-server -m finetune/models/<model>.gguf -ngl 99 --port 9000 --ctx-size 2048 --log-disable")
        sys.exit(1)

    samples = load_samples(args.run_dir, args.task, args.split)
    total = len(samples)

    if args.run_all:
        if args.max_samples:
            samples = samples[: args.max_samples]
        output_path = args.output or Path(f"eval-{args.label}-{args.task}.json")
        print(f"Running batch eval: {len(samples)} samples, {args.workers} workers")
        run_batch(samples, args.task, args.label, output_path, workers=args.workers)
        return

    if not args.indices:
        # Name the split that was actually loaded; the default is "val", so a
        # hard-coded "Test set" was misleading.
        print(f"{args.split.capitalize()} set has {total} {args.task} samples (0-{total-1})")
        raw = input("Enter index (or press Enter for 0): ").strip()
        try:
            indices = [int(raw) if raw else 0]
        except ValueError:
            # Give a readable message instead of a traceback on a typo.
            print(f"Invalid index {raw!r}: expected an integer")
            sys.exit(1)
    else:
        indices = args.indices

    for idx in indices:
        if not (0 <= idx < total):
            print(f"Index {idx} out of range (0-{total-1}), skipping")
            continue
        run_sample(samples[idx], args.task, total, idx, verbose=args.verbose)

    # Closing rule after the last report.
    print(f"\n{'━' * 60}\n")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()