File size: 5,953 Bytes
94f1300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
"""골든 질문 자동 평가 (`kpaa eval`).

`tests/eval_questions.yaml`의 10개 질문을 라이브 파이프라인에 던지고,
각 질문의 `expected_phrases`(모두 매칭)와 `forbidden_phrases`(하나라도 매칭 시 실패),
면책 문구 부착 여부를 검사한다.

라이브 LLM 추론이 들어가므로 10건 × ~30~60초 = 5~10분 소요.
빠른 검증은 `--limit N`로 일부만.

CI에서는 LLM 호출이 너무 무거우므로 이 스크립트는 사용자 로컬 검증용.
"""
from __future__ import annotations

import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import yaml

from kpaa.pipeline import generate

# Disclaimer detector used by _check(): a "※" marker followed, after at most
# 400 characters of anything (newlines included, via [\s\S]), by "법률 자문"
# ("legal advice"), with optional whitespace between the two words.
_DISCLAIMER_RE = re.compile(r"※[\s\S]{0,400}법률\s*자문")


@dataclass
class CaseResult:
    """Outcome of evaluating one golden question against the live pipeline."""

    id: int  # question id from the YAML (run() falls back to the 1-based position)
    question: str  # the question text sent to the pipeline
    answer: str  # final answer text; empty when generation raised
    elapsed_s: float  # wall-clock seconds spent generating the answer
    missed_expected: list[str]  # expected patterns that did not match the answer
    hit_forbidden: list[str]  # forbidden patterns that did match the answer
    has_disclaimer: bool  # whether the legal-advice disclaimer was detected

    @property
    def passed(self) -> bool:
        """True iff nothing expected was missed, nothing forbidden hit, and the disclaimer is present."""
        if self.missed_expected or self.hit_forbidden:
            return False
        return self.has_disclaimer


def _load_eval(path: Path | None = None) -> list[dict[str, Any]]:
    """Load the golden questions from the eval YAML file.

    Args:
        path: YAML file to read. ``None`` means ``tests/eval_questions.yaml``
            under ``Path(__file__).resolve().parents[2]`` (presumably the
            repository root — confirm against the package layout).

    Returns:
        A new list of question mappings; an empty YAML document yields ``[]``.

    Raises:
        FileNotFoundError: if the file does not exist.
        ValueError: if the document is not a list of mappings. This is checked
            up front so a malformed file fails immediately instead of midway
            through a slow evaluation run.
    """
    p = path or (Path(__file__).resolve().parents[2] / "tests" / "eval_questions.yaml")
    if not p.exists():
        raise FileNotFoundError(f"eval yaml not found: {p}")
    # safe_load: never construct arbitrary Python objects from the file.
    raw = yaml.safe_load(p.read_text(encoding="utf-8"))
    if raw is None:  # empty document
        return []
    if not isinstance(raw, list):
        # Previously a mapping silently became its list of keys.
        raise ValueError(
            f"eval yaml must be a top-level list, got {type(raw).__name__}: {p}"
        )
    bad = [i for i, entry in enumerate(raw, 1) if not isinstance(entry, dict)]
    if bad:
        raise ValueError(
            f"eval yaml entries must be mappings (bad positions {bad}): {p}"
        )
    return list(raw)


def _check(answer: str, item: dict[str, Any]) -> tuple[list[str], list[str], bool]:
    """Match an answer against an item's expected/forbidden regexes and check the disclaimer.

    Args:
        answer: Final answer text produced by the pipeline.
        item: One entry from the eval YAML. ``expected_phrases`` and
            ``forbidden_phrases`` may each be a list of regex patterns, a single
            pattern string, missing, or null.

    Returns:
        ``(missed, hit, has_disclaimer)``: expected patterns that did not match,
        forbidden patterns that did match, and whether ``_DISCLAIMER_RE`` matched.

    Raises:
        re.error: if any pattern in the item is not a valid regular expression.
    """

    def patterns(key: str) -> list[str]:
        raw = item.get(key) or []
        # A bare string must count as one pattern; iterating it would turn
        # every single character into its own pattern.
        return [raw] if isinstance(raw, str) else list(raw)

    missed = [pat for pat in patterns("expected_phrases") if not re.search(pat, answer)]
    hit = [pat for pat in patterns("forbidden_phrases") if re.search(pat, answer)]
    has_disclaimer = bool(_DISCLAIMER_RE.search(answer))
    return missed, hit, has_disclaimer


async def _generate_answer(query: str) -> tuple[str, float]:
    """Run the full pipeline for *query* and return ``(final_answer, elapsed_seconds)``.

    Consumes the event stream from ``generate``. ``"token"`` events contribute
    their ``delta`` text, and the ``answer`` of the last ``"done"`` event takes
    precedence. When that answer is missing or empty, the concatenated token
    deltas are returned instead.
    """
    started = time.monotonic()
    pieces: list[str] = []
    answer = ""
    async for event in generate(query):
        kind = event["event"]
        if kind == "done":
            answer = event["answer"]
        elif kind == "token":
            pieces.append(event["delta"])
    return (answer or "".join(pieces)), time.monotonic() - started


async def run(
    *,
    limit: int | None = None,
    eval_path: Path | None = None,
    show_answers: bool = False,
    out_path: Path | None = None,
) -> int:
    """Evaluate the golden questions against the live pipeline and report the results.

    Args:
        limit: If given, evaluate only the first ``limit`` questions.
        eval_path: Question YAML to load; ``None`` uses ``_load_eval``'s default path.
        show_answers: Print every answer, not only the answers of failing cases.
        out_path: If given, also write per-case results as UTF-8 JSON to this
            path, creating missing parent directories.

    Returns:
        Process exit code: ``0`` if every evaluated case passed (this includes
        the case where nothing was evaluated), ``1`` otherwise.
    """
    items = _load_eval(eval_path)
    if limit is not None:
        items = items[:limit]
    print(f"▶ kpaa eval — 골든 질문 {len(items)}건 평가 시작")
    print(f"  (각 질문 LLM 추론 ~30–60초, 총 {len(items) * 45 / 60:.1f}분 소요 예상)\n")

    results: list[CaseResult] = []
    # Parallel to `results`: None for a clean run, else why the case errored.
    errors: list[str | None] = []
    for i, item in enumerate(items, 1):
        result, error = await _run_case(i, len(items), item, show_answers=show_answers)
        results.append(result)
        errors.append(error)

    passed = sum(1 for r in results if r.passed)
    total = len(results)
    _print_summary(results, errors, passed, total)

    if out_path is not None:
        _write_results(out_path, results, errors)
        print(f"\n결과 저장: {out_path}")

    return 0 if passed == total else 1


async def _run_case(
    i: int, count: int, item: dict[str, Any], *, show_answers: bool
) -> tuple[CaseResult, str | None]:
    """Evaluate one question, print its progress lines, and return ``(result, error)``.

    ``error`` is ``None`` on a clean run. Otherwise it describes why the case
    errored: either generation raised, or a pattern in the item is not a valid
    regex. An errored case is recorded as failed.
    """
    q = item["question"]
    qid = item.get("id", i)
    print(f"[{i}/{count}] #{qid} {q}")

    raw_expected = item.get("expected_phrases", []) or []
    # A bare string is one pattern, not a sequence of one-character patterns.
    expected = [raw_expected] if isinstance(raw_expected, str) else list(raw_expected)

    t0 = time.monotonic()
    try:
        answer, secs = await _generate_answer(q)
    except Exception as e:
        # Broad on purpose: one failing question must not abort the remaining
        # (slow) cases. The message is kept for the summary and JSON output.
        error = f"{type(e).__name__}: {e}"
        print(f"  ✗ ERROR: {error}\n")
        failed = CaseResult(
            id=qid, question=q, answer="", elapsed_s=time.monotonic() - t0,
            missed_expected=expected, hit_forbidden=[], has_disclaimer=False,
        )
        return failed, error

    try:
        missed, hit, has_dc = _check(answer, item)
    except re.error as e:
        # A malformed YAML pattern fails this case instead of crashing the
        # whole run after the expensive generation step.
        error = f"invalid pattern {e.pattern!r}: {e}"
        print(f"  ✗ ERROR: {error}\n")
        failed = CaseResult(
            id=qid, question=q, answer=answer, elapsed_s=secs,
            missed_expected=expected, hit_forbidden=[], has_disclaimer=False,
        )
        return failed, error

    result = CaseResult(
        id=qid, question=q, answer=answer, elapsed_s=secs,
        missed_expected=missed, hit_forbidden=hit, has_disclaimer=has_dc,
    )
    flag = "✅" if result.passed else "❌"
    print(f"  {flag} ({secs:.1f}s) "
          f"missed={len(missed)} forbidden={len(hit)} disclaimer={has_dc}")
    if missed:
        print(f"     missed: {missed}")
    if hit:
        print(f"     forbidden hit: {hit}")
    if not has_dc:
        print("     disclaimer absent")
    if show_answers or not result.passed:
        preview = answer.replace("\n", "\n     ")
        print(f"     ─── 답변 ─────────────\n     {preview}")
    print()
    return result, None


def _print_summary(
    results: list[CaseResult], errors: list[str | None], passed: int, total: int
) -> None:
    """Print the pass-rate banner and, if anything failed, why each failing case failed."""
    print("─" * 60)
    print(f"결과: {passed}/{total} 통과 ({100*passed/total if total else 0:.0f}%)")
    print("─" * 60)
    if passed == total:
        return
    print("\n실패한 질문:")
    for r, error in zip(results, errors):
        if r.passed:
            continue
        print(f"  #{r.id}: {r.question}")
        if error is not None:
            # The pattern/disclaimer fields of an errored case are placeholders;
            # report the real cause instead of "missing everything".
            print(f"     error: {error}")
            continue
        if r.missed_expected:
            print(f"     missing: {r.missed_expected}")
        if r.hit_forbidden:
            print(f"     forbidden: {r.hit_forbidden}")
        if not r.has_disclaimer:
            print("     disclaimer absent")


def _write_results(
    out_path: Path, results: list[CaseResult], errors: list[str | None]
) -> None:
    """Write per-case results to *out_path* as pretty-printed UTF-8 JSON."""
    import json

    records = [
        {
            "id": r.id,
            "question": r.question,
            "answer": r.answer,
            "elapsed_s": round(r.elapsed_s, 2),
            "passed": r.passed,
            "missed_expected": r.missed_expected,
            "hit_forbidden": r.hit_forbidden,
            "has_disclaimer": r.has_disclaimer,
            "error": error,
        }
        for r, error in zip(results, errors)
    ]
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(
        json.dumps(records, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )


# Public API: the async `run` entry point and its per-case result type.
__all__ = ["run", "CaseResult"]