Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import json | |
| import math | |
| from dataclasses import asdict, dataclass | |
| from pathlib import Path | |
| from tracking.trackio_client import TrackingClient | |
@dataclass
class PromptCase:
    """One prompt and expected answer for lightweight evaluation.

    Declared as a dataclass so instances can be constructed positionally or
    by keyword (``PromptCase("question", "answer")``); a plain class with
    bare annotations has no such constructor and raises ``TypeError``.
    """

    # Prompt text for this case.
    prompt: str
    # Expected answer for the prompt.
    expected: str
@dataclass
class EvalRow:
    """One evaluated response row.

    Declared as a dataclass so rows can be built by keyword and converted
    with ``dataclasses.asdict``; ``asdict`` rejects objects that are not
    dataclass instances.
    """

    # Prompt text of the evaluated case.
    prompt: str
    # Expected answer for the prompt.
    expected: str
    # Response that was compared against ``expected``.
    actual: str
    # Whether the comparison counted as an exact match.
    exact_match: bool
    # Free-form note; ``evaluate_responses`` stores "exact" or "review" here.
    notes: str
@dataclass
class EvalReport:
    """Aggregate evaluation report.

    Declared as a dataclass so it can be constructed by keyword and so the
    ``perplexity`` default of ``None`` applies when the argument is omitted.
    """

    # Per-response evaluation rows.
    rows: list[EvalRow]
    # Exact-match rate stored alongside the rows.
    exact_match_rate: float
    # Optional perplexity; ``None`` when it has not been attached.
    perplexity: float | None = None

    def as_table(self) -> list[list[str]]:
        """Return one list of cells per row.

        Each inner list holds prompt, expected, actual, the exact-match flag
        rendered with ``str()`` and the notes, in that order.
        """
        table = []
        for row in self.rows:
            cells = [
                row.prompt,
                row.expected,
                row.actual,
                str(row.exact_match),
                row.notes,
            ]
            table.append(cells)
        return table

    def as_dict(self) -> dict:
        """Return the report as a plain dict.

        Rows are converted with ``dataclasses.asdict``, so every element of
        ``rows`` must be a dataclass instance.
        """
        return {
            "exact_match_rate": self.exact_match_rate,
            "perplexity": self.perplexity,
            "rows": [asdict(row) for row in self.rows],
        }
@dataclass
class ComparisonReport:
    """Base-vs-tuned comparison using exact-match rates.

    Declared as a dataclass so it can be constructed by keyword, as
    ``compare_base_vs_tuned`` does.
    """

    # Report for the base model.
    base: EvalReport
    # Report for the tuned model.
    tuned: EvalReport
    # Stored difference in exact-match rate (``compare_base_vs_tuned`` passes
    # tuned minus base).
    delta: float

    def as_dict(self) -> dict:
        """Return both exact-match rates and the stored delta as a dict."""
        return {
            "base_exact_match_rate": self.base.exact_match_rate,
            "tuned_exact_match_rate": self.tuned.exact_match_rate,
            "delta": self.delta,
        }
def default_prompt_cases() -> list[PromptCase]:
    """Return the built-in list of three prompt cases."""
    # (prompt, expected) pairs, in the order they are returned.
    pairs = (
        ("Identify the correction target.", "field note"),
        ("What format should corrected training data use?", "jsonl"),
        ("Should model weights download on startup?", "no"),
    )
    return [PromptCase(question, answer) for question, answer in pairs]
def load_prompt_cases(path: str | Path) -> list[PromptCase]:
    """Load prompt cases from a JSONL file.

    Each record read by ``_read_jsonl`` must provide ``"prompt"`` and
    ``"expected"`` keys; both values are converted with ``str()``. A record
    missing either key raises ``KeyError``.
    """
    cases = []
    for record in _read_jsonl(path):
        case = PromptCase(
            prompt=str(record["prompt"]),
            expected=str(record["expected"]),
        )
        cases.append(case)
    return cases
def evaluate_responses(cases: list[PromptCase], responses: list[str]) -> EvalReport:
    """Score responses against their cases by normalized exact match.

    Cases and responses are paired positionally. Pairing stops at the shorter
    of the two lists (``strict=False``), so unpaired items are ignored. The
    exact-match rate is the fraction of paired rows that matched, or ``0.0``
    when nothing was paired.
    """
    evaluated = []
    hits = 0
    for case, response in zip(cases, responses, strict=False):
        matched = _normalize(case.expected) == _normalize(response)
        if matched:
            hits += 1
            note = "exact"
        else:
            note = "review"
        evaluated.append(
            EvalRow(
                prompt=case.prompt,
                expected=case.expected,
                actual=response,
                exact_match=matched,
                notes=note,
            )
        )

    rate = hits / len(evaluated) if evaluated else 0.0
    return EvalReport(rows=evaluated, exact_match_rate=rate)
def perplexity_from_losses(losses: list[float]) -> float | None:
    """Return ``exp(mean(losses))``, or ``None`` when there are no losses.

    Args:
        losses: Loss values to average.

    Returns:
        ``None`` for an empty list; otherwise the exponential of the mean
        loss. When the mean is too large for ``math.exp`` to represent as a
        float, ``math.inf`` is returned instead of raising ``OverflowError``.
    """
    if not losses:
        return None
    average_loss = sum(losses) / len(losses)
    try:
        return math.exp(average_loss)
    except OverflowError:
        # math.exp raises once the result exceeds the float range (mean loss
        # above roughly 709.78). Report that as infinity so one diverged run
        # does not abort the whole evaluation.
        return math.inf
def attach_perplexity(report: EvalReport, losses: list[float]) -> EvalReport:
    """Return a new report with perplexity computed from ``losses``.

    The new report reuses the same ``rows`` object and exact-match rate as
    ``report``; ``report`` itself is not modified.
    """
    fields = {
        "rows": report.rows,
        "exact_match_rate": report.exact_match_rate,
        "perplexity": perplexity_from_losses(losses),
    }
    return EvalReport(**fields)
def compare_base_vs_tuned(base: EvalReport, tuned: EvalReport) -> ComparisonReport:
    """Build a comparison whose delta is tuned minus base exact-match rate."""
    improvement = tuned.exact_match_rate - base.exact_match_rate
    return ComparisonReport(base=base, tuned=tuned, delta=improvement)
def log_eval_report(report: EvalReport, path: str | Path = "data/eval_results.jsonl") -> Path:
    """Append the report as one JSON line to ``path`` and return that path.

    Missing parent directories are created. The file is opened in append
    mode as UTF-8, and non-ASCII characters are written unescaped.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("a", encoding="utf-8") as handle:
        record = json.dumps(report.as_dict(), ensure_ascii=False)
        handle.write(f"{record}\n")
    return destination
def log_eval_metrics(
    report: EvalReport,
    client: TrackingClient | None = None,
) -> Path:
    """Log summary metrics for ``report`` through a tracking client.

    Uses ``client`` when it is truthy, otherwise a freshly constructed
    ``TrackingClient``. The client is initialised with run name
    "evaluation", then the exact-match rate, perplexity and row count are
    logged under "training_metrics". Returns whatever the client's ``log``
    call returns.
    """
    tracker = client if client else TrackingClient()
    tracker.init(run_name="evaluation")
    metrics = {
        "exact_match_rate": report.exact_match_rate,
        "perplexity": report.perplexity,
        "rows": len(report.rows),
    }
    return tracker.log("training_metrics", metrics)
def _read_jsonl(path: str | Path) -> list[dict]:
    """Parse a UTF-8 JSONL file into a list, one item per non-blank line.

    Lines that are empty or whitespace-only are skipped; every other line is
    decoded with ``json.loads``.
    """
    with Path(path).open(encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
def _normalize(value: str) -> str:
    """Case-fold ``value`` and collapse whitespace runs to single spaces.

    Leading and trailing whitespace is dropped as well.
    """
    # split() with no arguments already discards leading/trailing whitespace.
    words = value.casefold().split()
    return " ".join(words)