# Extracted from a Hugging Face Space (running on Zero); file size: 2,445 bytes.
# Revision id shown by the hosting page: 7f9dfed
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from training.evaluation import (
attach_perplexity,
compare_base_vs_tuned,
default_prompt_cases,
evaluate_responses,
load_prompt_cases,
log_eval_report,
perplexity_from_losses,
)
class EvaluationTest(unittest.TestCase):
    """Unit tests for the helpers imported from ``training.evaluation``.

    Each test drives one public helper (or a pair of them) with small
    hand-written inputs and checks the fields the helper exposes on its
    result object.
    """

    def test_evaluates_exact_match_rate(self) -> None:
        """Expect a 2/3 exact-match rate when only the middle response is wrong."""
        cases = default_prompt_cases()
        report = evaluate_responses(cases, ["field note", "wrong", "no"])
        # 2 / 3 has no exact binary floating-point representation, so compare
        # with a tolerance instead of relying on bit-identical arithmetic.
        self.assertAlmostEqual(report.exact_match_rate, 2 / 3)
        # The row for the mismatching (second) response is expected to be
        # flagged with the note "review".
        self.assertEqual(report.rows[1].notes, "review")

    def test_compares_base_vs_tuned_reports(self) -> None:
        """Expect a delta of 1.0 between an all-wrong and an all-right report."""
        cases = default_prompt_cases()
        base = evaluate_responses(cases, ["wrong", "wrong", "wrong"])
        tuned = evaluate_responses(cases, ["field note", "jsonl", "no"])
        comparison = compare_base_vs_tuned(base, tuned)
        # The delta is a float metric; a tolerant comparison keeps the test
        # independent of how the difference is computed.
        self.assertAlmostEqual(comparison.delta, 1.0)

    def test_loads_prompt_cases_from_jsonl(self) -> None:
        """Round-trip a single prompt case through a JSONL file on disk."""
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "cases.jsonl"
            path.write_text(
                json.dumps({"prompt": "Prompt", "expected": "Answer"}) + "\n",
                encoding="utf-8",
            )
            cases = load_prompt_cases(path)
        self.assertEqual(cases[0].prompt, "Prompt")
        self.assertEqual(cases[0].expected, "Answer")

    def test_logs_eval_report(self) -> None:
        """Expect ``log_eval_report`` to return its path and write the rate key."""
        with tempfile.TemporaryDirectory() as tmp:
            path = Path(tmp) / "eval.jsonl"
            # NOTE(review): only one response is supplied here while the other
            # tests pass three for the default cases; how evaluate_responses
            # treats the shorter list is not visible from this file.
            report = evaluate_responses(default_prompt_cases(), ["field note"])
            saved = log_eval_report(report, path)
            self.assertEqual(saved, path)
            # Read the file while the temporary directory still exists.
            self.assertIn("exact_match_rate", path.read_text(encoding="utf-8"))

    def test_calculates_perplexity_from_average_loss(self) -> None:
        """Expect perplexity 1.0 for all-zero losses and ``None`` for no losses."""
        perplexity = perplexity_from_losses([0.0, 0.0])
        self.assertAlmostEqual(perplexity, 1.0)
        self.assertIsNone(perplexity_from_losses([]))

    def test_attaches_perplexity_to_eval_report(self) -> None:
        """Expect ``attach_perplexity`` to add perplexity and keep the match rate."""
        report = evaluate_responses(default_prompt_cases(), ["field note"])
        updated = attach_perplexity(report, [0.0])
        # The match rate should be carried over unchanged, so exact equality
        # against the source report is intentional here.
        self.assertEqual(updated.exact_match_rate, report.exact_match_rate)
        self.assertAlmostEqual(updated.perplexity, 1.0)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()