workbench / tests /unit /test_evaluation.py
GitHub Actions
Initial ZeroGPU deployment with spaces shim
7f9dfed
Raw
History Blame Contribute Delete
2.45 kB
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from training.evaluation import (
attach_perplexity,
compare_base_vs_tuned,
default_prompt_cases,
evaluate_responses,
load_prompt_cases,
log_eval_report,
perplexity_from_losses,
)
class EvaluationTest(unittest.TestCase):
    """Unit tests for the helpers imported from ``training.evaluation``."""

    def test_evaluates_exact_match_rate(self) -> None:
        """Score three responses, the second deliberately wrong.

        Expects an exact-match rate of 2/3 and the second row's notes to be
        ``"review"``.
        """
        cases = default_prompt_cases()
        report = evaluate_responses(cases, ["field note", "wrong", "no"])
        # The rate is a computed ratio, so compare with a tolerance instead
        # of relying on bit-for-bit float equality.
        self.assertAlmostEqual(report.exact_match_rate, 2 / 3)
        self.assertEqual(report.rows[1].notes, "review")

    def test_compares_base_vs_tuned_reports(self) -> None:
        """Compare an all-wrong base report against a tuned report.

        Expects the comparison's ``delta`` to be 1.0.
        """
        cases = default_prompt_cases()
        base = evaluate_responses(cases, ["wrong", "wrong", "wrong"])
        tuned = evaluate_responses(cases, ["field note", "jsonl", "no"])
        comparison = compare_base_vs_tuned(base, tuned)
        # delta is derived from two float rates; tolerate rounding noise.
        self.assertAlmostEqual(comparison.delta, 1.0)

    def test_loads_prompt_cases_from_jsonl(self) -> None:
        """Write a one-line JSONL file and read it back as prompt cases."""
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "cases.jsonl"
            path.write_text(
                json.dumps({"prompt": "Prompt", "expected": "Answer"}) + "\n",
                encoding="utf-8",
            )
            cases = load_prompt_cases(path)
            # Assert while the temporary directory still exists, so the
            # checks hold even if the loader reads the file on access.
            self.assertEqual(cases[0].prompt, "Prompt")
            self.assertEqual(cases[0].expected, "Answer")

    def test_logs_eval_report(self) -> None:
        """Log a report to a file and check the returned path and contents."""
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "eval.jsonl"
            report = evaluate_responses(default_prompt_cases(), ["field note"])
            saved = log_eval_report(report, path)
            self.assertEqual(saved, path)
            # Only the presence of the metric name is checked; the exact
            # serialisation format is not pinned by this test.
            self.assertIn("exact_match_rate", path.read_text(encoding="utf-8"))

    def test_calculates_perplexity_from_average_loss(self) -> None:
        """All-zero losses give perplexity 1.0; no losses give ``None``."""
        perplexity = perplexity_from_losses([0.0, 0.0])
        # Perplexity is a float result; compare with a tolerance.
        self.assertAlmostEqual(perplexity, 1.0)
        self.assertIsNone(perplexity_from_losses([]))

    def test_attaches_perplexity_to_eval_report(self) -> None:
        """Attaching perplexity sets it and keeps the exact-match rate."""
        report = evaluate_responses(default_prompt_cases(), ["field note"])
        updated = attach_perplexity(report, [0.0])
        # The rate is expected to be carried over unchanged, so exact
        # equality is the right check here.
        self.assertEqual(updated.exact_match_rate, report.exact_match_rate)
        self.assertAlmostEqual(updated.perplexity, 1.0)
if __name__ == "__main__":
    # Executed directly as a script: hand control to unittest's
    # command-line runner for the tests defined in this module.
    unittest.main()