File size: 4,355 Bytes
384d994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""Lightweight rubric-based LLM judge for the HR environment."""

from __future__ import annotations

import json
import logging
import os
from dataclasses import dataclass
from typing import Any

logger = logging.getLogger(__name__)

PASS_THRESHOLD = 0.6

SYSTEM_PROMPT = """\
You are an impartial evaluator assessing whether an AI agent successfully \
completed an HR task. Score accurately based on evidence from the action trace.

Scoring:
- 0.8-1.0: All requirements fully met with clear evidence.
- 0.6-0.8: Core requirements met with minor gaps. (0.6 = PASS)
- 0.4-0.6: Partial completion, significant gaps remain.
- 0.2-0.4: Minimal progress, most requirements failed.
- 0.0-0.2: No meaningful progress.

Respond with valid JSON (no markdown fences):
{"score": 0.0, "verdict": "PASS or FAIL", "evidence": ["..."], "failed_criteria": ["..."]}"""


@dataclass
class EvalResult:
    """Result from the rubric judge.

    Produced by evaluate_episode(); a score at or above PASS_THRESHOLD
    (0.6) corresponds to a PASS verdict.
    """

    score: float  # judge score clamped to [0.0, 1.0]; 0.0 when skipped or errored
    verdict: str  # "PASS", "FAIL", "SKIPPED", or "ERROR"
    evidence: list[str]  # judge-cited evidence supporting the score
    failed_criteria: list[str]  # rubric criteria the judge marked as unmet
    error: str | None = None  # set when the judge was skipped, failed, or unparseable


def evaluate_episode(
    *,
    task_instruction: str,
    rubric: list[str],
    action_history: list[dict[str, Any]],
) -> EvalResult:
    """Run the rubric judge on a completed episode. Returns EvalResult with 0.0-1.0 score."""

    def _env(name: str) -> str:
        # All judge configuration comes from the environment.
        return os.environ.get(name, "").strip()

    model = _env("VERIFIER_MODEL")
    api_key = _env("VERIFIER_API_KEY")

    # Without credentials the judge is skipped rather than treated as a failure.
    if not (model and api_key):
        return EvalResult(
            score=0.0,
            verdict="SKIPPED",
            evidence=[],
            failed_criteria=[],
            error="Set VERIFIER_MODEL and VERIFIER_API_KEY to enable evaluation",
        )

    provider = _env("VERIFIER_PROVIDER") or None
    base_url = _env("VERIFIER_BASE_URL") or None

    # Render the rubric as a bullet list for the prompt.
    if rubric:
        rubric_text = "\n".join(f"- {r}" for r in rubric)
    else:
        rubric_text = "No specific rubric provided."

    # Only the last 50 actions are sent, and the serialized trace is capped
    # so the prompt stays within reasonable context limits.
    trace = json.dumps(action_history[-50:], indent=2, ensure_ascii=False)
    if len(trace) > 40000:
        trace = trace[:40000] + "\n... [truncated]"

    user_prompt = f"""# Task
{task_instruction}

# Rubric Criteria
{rubric_text}

# Agent Action Trace
{trace}"""

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]

    try:
        import litellm

        # litellm routes on a "provider/model" prefix; add it only when the
        # model string does not already carry it.
        litellm_model = (
            f"{provider}/{model}"
            if provider and not model.startswith(f"{provider}/")
            else model
        )
        response = litellm.completion(
            model=litellm_model,
            messages=messages,
            api_key=api_key,
            base_url=base_url,
            temperature=0.2,
        )
        raw = response.choices[0].message.content or ""
    except Exception as exc:
        # Any judge-side failure (network, auth, model) degrades to an ERROR
        # result instead of propagating into the caller.
        logger.warning("Rubric judge LLM call failed: %s", exc)
        return EvalResult(
            score=0.0, verdict="ERROR", evidence=[], failed_criteria=[], error=str(exc)
        )

    return _parse_response(raw)


def _parse_response(raw: str) -> EvalResult:
    """Parse the judge's JSON response into an EvalResult.

    Tolerates markdown code fences and surrounding prose: if the full text is
    not valid JSON, falls back to the widest ``{...}`` span. The score is
    clamped to [0.0, 1.0]; a non-numeric score or a non-object top level
    (e.g. a bare list or string) yields an ERROR result instead of raising,
    which the old float()/dict.get() path did not guarantee.
    """
    text = raw.strip()
    # Strip optional markdown fences, e.g. ```json ... ```.
    if text.startswith("```"):
        text = text.strip("`\n")
        if text.lower().startswith("json"):
            text = text[4:].strip()

    data = _extract_json_object(text)
    if not isinstance(data, dict):
        # Unparseable text, or valid JSON with the wrong shape (list, string,
        # null) — the old code raised AttributeError on data.get() here.
        return _parse_error(raw)

    try:
        score = max(0.0, min(float(data.get("score", 0.0)), 1.0))
    except (TypeError, ValueError):
        # e.g. {"score": "high"} — a non-numeric score is unusable.
        return _parse_error(raw)

    # Default the verdict from the score when the judge omits it.
    verdict = data.get("verdict", "PASS" if score >= PASS_THRESHOLD else "FAIL")
    evidence = data.get("evidence", [])
    if isinstance(evidence, str):
        evidence = [evidence]
    failed = data.get("failed_criteria", [])
    if isinstance(failed, str):
        failed = [failed]

    return EvalResult(score=score, verdict=str(verdict), evidence=evidence, failed_criteria=failed)


def _parse_error(raw: str) -> EvalResult:
    """Build the standard ERROR result for an unusable judge response."""
    return EvalResult(
        score=0.0, verdict="ERROR", evidence=[], failed_criteria=[],
        error=f"Could not parse judge response: {raw[:300]}",
    )


def _extract_json_object(text: str) -> Any:
    """Return parsed JSON from *text*, or None when nothing parseable is found.

    Tries the whole text first, then the widest ``{...}`` span as a fallback
    for judges that wrap their JSON in prose.
    """
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    import re

    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        return None
    try:
        return json.loads(match.group(0))
    except json.JSONDecodeError:
        return None