kpaa / src /kpaa /cli_eval.py
scvcoder's picture
Initial backend code: src/kpaa, runtime data, requirements
94f1300 verified
raw
history blame
5.95 kB
"""골든 질문 자동 평가 (`kpaa eval`).
`tests/eval_questions.yaml`의 10개 질문을 라이브 파이프라인에 던지고,
각 질문의 `expected_phrases`(모두 매칭)와 `forbidden_phrases`(하나라도 매칭 시 실패),
면책 문구 부착 여부를 검사한다.
라이브 LLM 추론이 들어가므로 10건 × ~30~60초 = 5~10분 소요.
빠른 검증은 `--limit N`로 일부만.
CI에서는 LLM 호출이 너무 무거우므로 이 스크립트는 사용자 로컬 검증용.
"""
from __future__ import annotations
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
from kpaa.pipeline import generate
# Disclaimer detector: a "※" marker followed, within at most 400 characters of
# any kind (newlines included, via [\s\S]), by "법률" + optional whitespace +
# "자문" (Korean for "legal advice").
_DISCLAIMER_RE = re.compile(r"※[\s\S]{0,400}법률\s*자문")
@dataclass
class CaseResult:
    """Outcome of evaluating one golden question."""

    # Identifier of the question.
    id: int
    # Question text that was evaluated.
    question: str
    # Answer text that was checked (empty when generation raised).
    answer: str
    # Seconds spent generating the answer.
    elapsed_s: float
    # Expected patterns that did NOT match the answer.
    missed_expected: list[str]
    # Forbidden patterns that DID match the answer.
    hit_forbidden: list[str]
    # Whether the disclaimer pattern was found in the answer.
    has_disclaimer: bool

    @property
    def passed(self) -> bool:
        """Tell whether the case passed.

        A case passes only when no expected pattern was missed, no forbidden
        pattern was hit, and the disclaimer is present.
        """
        if self.missed_expected or self.hit_forbidden:
            return False
        return self.has_disclaimer
def _load_eval(path: Path | None = None) -> list[dict[str, Any]]:
    """Load the evaluation items from a YAML file.

    Args:
        path: YAML file to read. When omitted, ``tests/eval_questions.yaml``
            under ``parents[2]`` of this module's resolved path is used.

    Returns:
        The items found at the top level of the YAML document, as a new list.
        An empty document (or any other falsy root) yields an empty list.

    Raises:
        FileNotFoundError: If the YAML file does not exist.
        ValueError: If the document's root is not a list. Without this check
            a mapping would be turned into a list of its keys and a string
            into a list of characters, which only fails later with an
            unrelated-looking error when the items are indexed.
    """
    yaml_path = path or (
        Path(__file__).resolve().parents[2] / "tests" / "eval_questions.yaml"
    )
    if not yaml_path.exists():
        raise FileNotFoundError(f"eval yaml not found: {yaml_path}")

    raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
    if not raw:
        # Empty file / empty root: nothing to evaluate.
        return []
    if not isinstance(raw, list):
        raise ValueError(
            f"eval yaml must contain a top-level list, "
            f"got {type(raw).__name__}: {yaml_path}"
        )
    return list(raw)
def _check(answer: str, item: dict[str, Any]) -> tuple[list[str], list[str], bool]:
    """Match an answer against an item's regex expectations.

    Args:
        answer: Answer text to inspect.
        item: Evaluation item; its optional ``expected_phrases`` and
            ``forbidden_phrases`` entries are lists of regex patterns
            (a missing or falsy entry is treated as an empty list).

    Returns:
        A tuple ``(missed, hit, has_disclaimer)``: the expected patterns that
        did not match, the forbidden patterns that did match, and whether the
        module's disclaimer regex matched the answer.
    """
    expected_patterns = item.get("expected_phrases", []) or []
    unmatched = [
        pattern
        for pattern in expected_patterns
        if re.search(pattern, answer) is None
    ]

    forbidden_patterns = item.get("forbidden_phrases", []) or []
    triggered = [
        pattern
        for pattern in forbidden_patterns
        if re.search(pattern, answer) is not None
    ]

    disclaimer_found = _DISCLAIMER_RE.search(answer) is not None
    return unmatched, triggered, disclaimer_found
async def _generate_answer(query: str) -> tuple[str, float]:
    """Run the pipeline end to end for one query.

    Consumes the events yielded by ``generate(query)``: ``"token"`` events
    contribute their ``"delta"`` text and a ``"done"`` event supplies the
    final ``"answer"``. If no non-empty final answer was received, the
    concatenated deltas are used instead.

    Returns:
        A tuple ``(answer, elapsed_seconds)`` measured with a monotonic clock.
    """
    started_at = time.monotonic()
    done_answer = ""
    deltas: list[str] = []

    async for event in generate(query):
        kind = event["event"]
        if kind == "done":
            done_answer = event["answer"]
        elif kind == "token":
            deltas.append(event["delta"])

    # Fall back to the streamed text when "done" gave nothing usable.
    answer = done_answer if done_answer else "".join(deltas)
    return answer, time.monotonic() - started_at
async def run(
    *,
    limit: int | None = None,
    eval_path: Path | None = None,
    show_answers: bool = False,
    out_path: Path | None = None,
) -> int:
    """Evaluate the golden questions and print a pass/fail report.

    Each item is sent through ``_generate_answer`` one at a time and the
    answer is checked with ``_check``. A question whose generation raises is
    recorded as a failure (empty answer, every expected phrase counted as
    missed) and the run continues with the next question.

    Args:
        limit: If given, only ``items[:limit]`` are evaluated.
        eval_path: YAML file with the questions; forwarded to ``_load_eval``.
        show_answers: Print the answer for every question, not only for the
            failed ones.
        out_path: If given, the per-question results are written there as
            UTF-8 JSON. Missing parent directories are created so that the
            results of a long run are not lost on a bad path.

    Returns:
        ``0`` when every evaluated question passed (also when there were no
        questions), otherwise ``1``.
    """
    items = _load_eval(eval_path)
    if limit is not None:
        items = items[:limit]

    print(f"▶ kpaa eval — 골든 질문 {len(items)}건 평가 시작")
    # The estimate assumes 45 seconds per question.
    print(f" (각 질문 LLM 추론 ~30–60초, 총 {len(items) * 45 / 60:.1f}분 소요 예상)\n")

    results: list[CaseResult] = []
    for i, item in enumerate(items, 1):
        q = item["question"]
        qid = item.get("id", i)  # fall back to the 1-based position
        print(f"[{i}/{len(items)}] #{qid} {q}")

        try:
            answer, secs = await _generate_answer(q)
        except Exception as e:
            # Boundary handler: one failing question must not abort the run.
            print(f" ✗ ERROR: {type(e).__name__}: {e}\n")
            results.append(
                CaseResult(
                    id=qid,
                    question=q,
                    answer="",
                    elapsed_s=0.0,
                    # Copy so the result never aliases the loaded item's list.
                    missed_expected=list(item.get("expected_phrases", []) or []),
                    hit_forbidden=[],
                    has_disclaimer=False,
                )
            )
            continue

        missed, hit, has_dc = _check(answer, item)
        r = CaseResult(
            id=qid,
            question=q,
            answer=answer,
            elapsed_s=secs,
            missed_expected=missed,
            hit_forbidden=hit,
            has_disclaimer=has_dc,
        )
        results.append(r)

        flag = "✅" if r.passed else "❌"
        print(f" {flag} ({secs:.1f}s) "
              f"missed={len(missed)} forbidden={len(hit)} disclaimer={has_dc}")
        if missed:
            print(f" missed: {missed}")
        if hit:
            print(f" forbidden hit: {hit}")
        if not has_dc:
            print(" disclaimer absent")
        if show_answers or not r.passed:
            # Keep continuation lines of the answer aligned under the header.
            preview = answer.replace("\n", "\n ")
            print(f" ─── 답변 ─────────────\n {preview}")
        print()

    # Overall summary.
    passed = sum(1 for r in results if r.passed)
    total = len(results)
    print("─" * 60)
    print(f"결과: {passed}/{total} 통과 ({100*passed/total if total else 0:.0f}%)")
    print("─" * 60)

    if passed < total:
        print("\n실패한 질문:")
        for r in results:
            if r.passed:
                continue
            print(f" #{r.id}: {r.question}")
            if r.missed_expected:
                print(f" missing: {r.missed_expected}")
            if r.hit_forbidden:
                print(f" forbidden: {r.hit_forbidden}")
            if not r.has_disclaimer:
                print(" disclaimer absent")

    if out_path is not None:
        import json

        report = [
            {
                "id": r.id,
                "question": r.question,
                "answer": r.answer,
                "elapsed_s": round(r.elapsed_s, 2),
                "passed": r.passed,
                "missed_expected": r.missed_expected,
                "hit_forbidden": r.hit_forbidden,
                "has_disclaimer": r.has_disclaimer,
            }
            for r in results
        ]
        # Create the target directory first: failing here would discard the
        # results of a run that took minutes of LLM inference.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        out_path.write_text(
            json.dumps(report, ensure_ascii=False, indent=2),
            encoding="utf-8",
        )
        print(f"\n결과 저장: {out_path}")

    return 0 if passed == total else 1
# Names exported by `from ... import *`; the underscore-prefixed helpers are
# intentionally left out.
__all__ = ["run", "CaseResult"]