# workbench/training/evaluation.py
# GitHub Actions
# Initial ZeroGPU deployment with spaces shim
# 7f9dfed
# Raw
# History Blame Contribute Delete
# 4.39 kB
from __future__ import annotations
import json
import math
from dataclasses import asdict, dataclass
from pathlib import Path
from tracking.trackio_client import TrackingClient
@dataclass(frozen=True)
class PromptCase:
    """A single evaluation item: a prompt paired with its expected answer.

    Instances are immutable (frozen dataclass).
    """

    # Prompt text for this case.
    prompt: str
    # Reference answer the response is compared against.
    expected: str
@dataclass(frozen=True)
class EvalRow:
    """Outcome of grading one response against one prompt case.

    Instances are immutable (frozen dataclass).
    """

    # Prompt text of the graded case.
    prompt: str
    # Reference answer for the case.
    expected: str
    # Response that was graded.
    actual: str
    # True when the response was judged an exact match.
    exact_match: bool
    # Free-form grading note; evaluate_responses writes "exact" or "review".
    notes: str
@dataclass(frozen=True)
class EvalReport:
    """Result of one evaluation run: graded rows plus aggregate metrics.

    Instances are immutable (frozen dataclass).
    """

    # Graded rows, one per evaluated case.
    rows: list[EvalRow]
    # Aggregate exact-match score for the run.
    exact_match_rate: float
    # Optional perplexity; stays None unless one is supplied.
    perplexity: float | None = None

    def as_table(self) -> list[list[str]]:
        """Return the rows as a list of string cells, one inner list per row.

        Column order is prompt, expected, actual, exact_match, notes; the
        boolean ``exact_match`` is rendered with ``str()``.
        """
        table = []
        for entry in self.rows:
            cells = [entry.prompt, entry.expected, entry.actual]
            cells.append(str(entry.exact_match))
            cells.append(entry.notes)
            table.append(cells)
        return table

    def as_dict(self) -> dict:
        """Return the report as a plain dict, with each row expanded via ``asdict``."""
        payload = {
            "exact_match_rate": self.exact_match_rate,
            "perplexity": self.perplexity,
        }
        payload["rows"] = list(map(asdict, self.rows))
        return payload
@dataclass(frozen=True)
class ComparisonReport:
    """Side-by-side view of a base report and a tuned report.

    Instances are immutable (frozen dataclass).
    """

    # Report for the base run.
    base: EvalReport
    # Report for the tuned run.
    tuned: EvalReport
    # Difference between the two exact-match rates, as supplied by the creator
    # (compare_base_vs_tuned passes tuned minus base).
    delta: float

    def as_dict(self) -> dict:
        """Return the two exact-match rates and the delta as a plain dict."""
        base_rate = self.base.exact_match_rate
        tuned_rate = self.tuned.exact_match_rate
        return dict(
            base_exact_match_rate=base_rate,
            tuned_exact_match_rate=tuned_rate,
            delta=self.delta,
        )
def default_prompt_cases() -> list[PromptCase]:
    """Return the built-in prompt cases as a freshly built list."""
    question_answer_pairs = (
        ("Identify the correction target.", "field note"),
        ("What format should corrected training data use?", "jsonl"),
        ("Should model weights download on startup?", "no"),
    )
    return [
        PromptCase(prompt=question, expected=answer)
        for question, answer in question_answer_pairs
    ]
def load_prompt_cases(path: str | Path) -> list[PromptCase]:
    """Load prompt cases from a JSONL file.

    Each record must provide ``"prompt"`` and ``"expected"`` keys; both values
    are coerced with ``str()``.

    Args:
        path: Location of the JSONL file.

    Returns:
        One ``PromptCase`` per record, in file order.

    Raises:
        KeyError: If a record lacks ``"prompt"`` or ``"expected"``.
    """
    cases = []
    for record in _read_jsonl(path):
        prompt_text = str(record["prompt"])
        expected_text = str(record["expected"])
        cases.append(PromptCase(prompt=prompt_text, expected=expected_text))
    return cases
def evaluate_responses(cases: list[PromptCase], responses: list[str]) -> EvalReport:
    """Grade responses against prompt cases by normalized exact match.

    Cases and responses are paired positionally. Pairing stops at the shorter
    of the two lists (``zip(..., strict=False)``), so surplus cases or
    responses are neither graded nor counted in the rate.

    Args:
        cases: Prompt cases providing the expected answers.
        responses: Responses to grade, in the same order as ``cases``.

    Returns:
        A report holding one row per graded pair and the fraction of rows that
        matched; the rate is ``0.0`` when nothing was graded.
    """
    graded = []
    hits = 0
    for case, actual in zip(cases, responses, strict=False):
        matched = _normalize(case.expected) == _normalize(actual)
        if matched:
            hits += 1
            note = "exact"
        else:
            note = "review"
        graded.append(
            EvalRow(
                prompt=case.prompt,
                expected=case.expected,
                actual=actual,
                exact_match=matched,
                notes=note,
            )
        )

    rate = hits / len(graded) if graded else 0.0
    return EvalReport(rows=graded, exact_match_rate=rate)
def perplexity_from_losses(losses: list[float]) -> float | None:
    """Convert a list of losses into a perplexity value.

    Perplexity is computed as ``exp(mean(losses))``.

    Args:
        losses: Loss values to average.

    Returns:
        ``None`` when ``losses`` is empty, otherwise the exponential of the
        mean loss. A mean too large for ``math.exp`` to represent yields
        ``math.inf`` instead of raising.
    """
    if not losses:
        return None
    mean_loss = sum(losses) / len(losses)
    try:
        return math.exp(mean_loss)
    except OverflowError:
        # math.exp raises once the result exceeds the float range (mean loss
        # above ~709.78), while exp(inf) already returns inf. Saturate so a
        # diverged run reports an infinite perplexity rather than crashing.
        return math.inf
def attach_perplexity(report: EvalReport, losses: list[float]) -> EvalReport:
    """Return a new report with perplexity derived from ``losses``.

    The rows and exact-match rate are carried over from ``report``; the input
    report itself is not modified.

    Args:
        report: Report to copy rows and exact-match rate from.
        losses: Loss values passed to ``perplexity_from_losses``.

    Returns:
        A new ``EvalReport`` carrying the computed perplexity.
    """
    carried_over = {
        "rows": report.rows,
        "exact_match_rate": report.exact_match_rate,
    }
    carried_over["perplexity"] = perplexity_from_losses(losses)
    return EvalReport(**carried_over)
def compare_base_vs_tuned(base: EvalReport, tuned: EvalReport) -> ComparisonReport:
    """Build a comparison of two reports.

    Args:
        base: Report for the base run.
        tuned: Report for the tuned run.

    Returns:
        A ``ComparisonReport`` whose ``delta`` is the tuned exact-match rate
        minus the base exact-match rate (positive when tuned scores higher).
    """
    improvement = tuned.exact_match_rate - base.exact_match_rate
    return ComparisonReport(base=base, tuned=tuned, delta=improvement)
def log_eval_report(report: EvalReport, path: str | Path = "data/eval_results.jsonl") -> Path:
    """Append ``report`` as one JSON line to a results file.

    Missing parent directories are created. The file is opened in append mode,
    so existing lines are kept. Non-ASCII characters are written as-is.

    Args:
        report: Report to serialize via its ``as_dict()`` method.
        path: Destination file.

    Returns:
        The destination as a ``Path``.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("a", encoding="utf-8") as handle:
        serialized = json.dumps(report.as_dict(), ensure_ascii=False)
        handle.write(f"{serialized}\n")
    return destination
def log_eval_metrics(
    report: EvalReport,
    client: TrackingClient | None = None,
) -> Path:
    """Log the headline metrics of ``report`` through a tracking client.

    The tracker is initialised with run name ``"evaluation"`` and receives the
    exact-match rate, perplexity and row count under ``"training_metrics"``.

    Args:
        report: Report whose metrics are logged.
        client: Tracker to use. A new ``TrackingClient()`` is created only
            when this is ``None``.

    Returns:
        Whatever ``tracker.log`` returns (annotated as ``Path``; presumably the
        file the tracker wrote -- confirm against ``TrackingClient``).
    """
    # Test identity against None rather than truthiness: a supplied client
    # that evaluates falsy (e.g. one defining __bool__ or __len__) must still
    # be used, not silently replaced by a default client.
    tracker = TrackingClient() if client is None else client
    tracker.init(run_name="evaluation")
    metrics = {
        "exact_match_rate": report.exact_match_rate,
        "perplexity": report.perplexity,
        "rows": len(report.rows),
    }
    return tracker.log("training_metrics", metrics)
def _read_jsonl(path: str | Path) -> list[dict]:
    """Read a JSON Lines file into a list of parsed records.

    Lines that are empty or whitespace-only are skipped. A leading UTF-8
    byte-order mark is tolerated.

    Args:
        path: Location of the JSONL file.

    Returns:
        The parsed value of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON; the message names
            the file and the 1-based line number.
    """
    records = []
    # "utf-8-sig" decodes exactly like "utf-8" but drops a leading BOM, which
    # json.loads would otherwise reject on the first line.
    with Path(path).open(encoding="utf-8-sig") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Keep the exception type, but say which file line is bad: the
                # position json reports is relative to this single line.
                raise json.JSONDecodeError(
                    f"{path}, line {line_number}: {err.msg}", err.doc, err.pos
                ) from err
    return records
def _normalize(value: str) -> str:
    """Return ``value`` case-folded with whitespace collapsed.

    Leading and trailing whitespace is removed and every internal run of
    whitespace becomes a single space.
    """
    # split() with no argument already discards leading/trailing whitespace
    # and treats any whitespace run as one separator.
    words = value.casefold().split()
    return " ".join(words)