File size: 6,484 Bytes
e181764
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Reward/evaluation rubrics for HR Onboarding/Offboarding tasks.

Each task has a set of rubric criteria. This module evaluates agent action logs
against those criteria to compute rewards.
"""

import re
from typing import Any
try:
    from .tasks import Task
except ImportError:
    from tasks import Task


class RubricEvaluator:
    """Evaluates agent performance against task rubric criteria.

    Each criterion's ``check`` string has the form ``"<type>:<args>"``. ``<type>``
    selects a checker registered in ``__init__``; ``<args>`` (everything after
    the first colon) is handed to that checker unchanged. Action-log entries
    are dicts read through the keys ``"tool"``, ``"params"`` and ``"result"``.
    """

    # Spec grammar shared by param_value / param_contains: ``tool.param=value``.
    # Compiled once here rather than on every check call.
    _PARAM_SPEC_RE = re.compile(r"(\w+)\.(\w+)=(.+)")
    # Spec grammar for tool_count: ``tool>=N`` (only the prefix must match).
    _COUNT_SPEC_RE = re.compile(r"(\w+)>=(\d+)")

    def __init__(self):
        # Dispatch table: check-type prefix -> checker(args, action_log) -> bool.
        self._checkers = {
            "tool_used": self._check_tool_used,
            "tool_not_used": self._check_tool_not_used,
            "tool_used_any": self._check_tool_used_any,
            "param_value": self._check_param_value,
            "param_contains": self._check_param_contains,
            "tool_order": self._check_tool_order,
            "tool_count": self._check_tool_count,
            "result_contains": self._check_result_contains,
        }

    def evaluate(self, task: "Task", action_log: list[dict]) -> dict:
        """Evaluate action log against task rubric criteria.

        Args:
            task: Task whose ``task_id`` and ``rubric_criteria`` are read; each
                criterion is a dict with ``name``, ``description`` and
                ``check`` keys.
            action_log: Recorded agent actions in call order.

        Returns:
            {
                "task_id": str,
                "criteria_results": list of {name, description, passed},
                "score": float (0.0-1.0), fraction of criteria passed,
                "passed": bool (all criteria satisfied),
                "passed_count": int,
                "total_criteria": int,
            }
        """
        criteria_results = [
            {
                "name": criterion["name"],
                "description": criterion["description"],
                "passed": self._evaluate_criterion(criterion["check"], action_log),
            }
            for criterion in task.rubric_criteria
        ]

        total = len(criteria_results)
        passed_count = sum(1 for c in criteria_results if c["passed"])
        # NOTE(review): with no criteria the score is 0.0 while "passed" is True
        # (all() of an empty list) — confirm this asymmetry is intended.
        score = passed_count / total if total > 0 else 0.0

        return {
            "task_id": task.task_id,
            "criteria_results": criteria_results,
            "score": score,
            "passed": all(c["passed"] for c in criteria_results),
            "passed_count": passed_count,
            "total_criteria": total,
        }

    def _evaluate_criterion(self, check_str: str, action_log: list[dict]) -> bool:
        """Parse and evaluate a single ``"<type>:<args>"`` check string.

        Returns False for a string without a colon or with an unknown type.
        """
        check_type, sep, check_args = check_str.partition(":")
        if not sep:
            return False

        checker = self._checkers.get(check_type)
        if checker is None:
            return False

        return checker(check_args, action_log)

    @staticmethod
    def _tool_of(action: dict) -> Any:
        """Return the tool name recorded on an action, or None if absent."""
        return action.get("tool")

    @staticmethod
    def _params_of(action: dict) -> dict:
        """Return an action's params dict; a missing or non-dict value yields {}."""
        params = action.get("params")
        return params if isinstance(params, dict) else {}

    def _check_tool_used(self, tool_name: str, action_log: list[dict]) -> bool:
        """Check if a specific tool was used at least once."""
        return any(self._tool_of(a) == tool_name for a in action_log)

    def _check_tool_not_used(self, tool_name: str, action_log: list[dict]) -> bool:
        """Check that a specific tool was NOT used."""
        return not any(self._tool_of(a) == tool_name for a in action_log)

    def _check_tool_used_any(self, tools_csv: str, action_log: list[dict]) -> bool:
        """Check if any of the comma-separated tools were used."""
        tool_names = {t.strip() for t in tools_csv.split(",")}
        return any(self._tool_of(a) in tool_names for a in action_log)

    def _check_param_value(self, spec: str, action_log: list[dict]) -> bool:
        """Check if a tool was called with a specific parameter value.

        Format: tool_name.param_name=expected_value

        The value is compared as a string, either directly in the call's params
        or nested inside a ``params["updates"]`` dict. None values never match.
        """
        match = self._PARAM_SPEC_RE.match(spec)
        if not match:
            return False
        tool_name, param_name, expected_value = match.groups()

        def _matches(value: Any) -> bool:
            return value is not None and str(value) == expected_value

        for action in action_log:
            if self._tool_of(action) != tool_name:
                continue
            params = self._params_of(action)
            if _matches(params.get(param_name)):
                return True
            # The value may also be nested in an "updates" dict. Non-dict
            # "updates" values (None, str, list) are skipped: a membership test
            # on a str is a substring test, and indexing them raises TypeError.
            updates = params.get("updates")
            if isinstance(updates, dict) and _matches(updates.get(param_name)):
                return True
        return False

    def _check_param_contains(self, spec: str, action_log: list[dict]) -> bool:
        """Check if a tool parameter contains a substring (case-insensitive).

        Format: tool_name.param_name=substring
        """
        match = self._PARAM_SPEC_RE.match(spec)
        if not match:
            return False
        tool_name, param_name, substring = match.groups()
        needle = substring.lower()

        for action in action_log:
            if self._tool_of(action) != tool_name:
                continue
            actual = self._params_of(action).get(param_name)
            # A missing or None parameter contains nothing; stringifying None
            # would otherwise let substrings like "on" match "None".
            if actual is not None and needle in str(actual).lower():
                return True
        return False

    def _check_tool_order(self, spec: str, action_log: list[dict]) -> bool:
        """Check that tool A was called before tool B.

        Format: tool_a<tool_b (whitespace around either name is ignored)

        Only the first occurrence of each tool is compared. If tool_a equals
        tool_b, both indices are the same and the check fails.
        """
        parts = spec.split("<")
        if len(parts) != 2:
            return False
        tool_a, tool_b = (p.strip() for p in parts)

        idx_a = idx_b = None
        for i, action in enumerate(action_log):
            tool = self._tool_of(action)
            if idx_a is None and tool == tool_a:
                idx_a = i
            if idx_b is None and tool == tool_b:
                idx_b = i
            if idx_a is not None and idx_b is not None:
                break

        return idx_a is not None and idx_b is not None and idx_a < idx_b

    def _check_tool_count(self, spec: str, action_log: list[dict]) -> bool:
        """Check that a tool was called at least N times.

        Format: tool_name>=N
        """
        match = self._COUNT_SPEC_RE.match(spec)
        if not match:
            return False
        tool_name, min_count = match.group(1), int(match.group(2))

        count = sum(1 for a in action_log if self._tool_of(a) == tool_name)
        return count >= min_count

    def _check_result_contains(self, substring: str, action_log: list[dict]) -> bool:
        """Check if any tool result contains a substring (case-insensitive)."""
        needle = substring.lower()
        for action in action_log:
            result = action.get("result")
            # Missing and None results are both treated as empty text.
            result_str = "" if result is None else str(result)
            if needle in result_str.lower():
                return True
        return False


def compute_reward(task: Task, action_log: list[dict]) -> float:
    """Return the rubric score for ``task`` given ``action_log``.

    Thin convenience wrapper: builds a fresh ``RubricEvaluator``, evaluates the
    log against the task's criteria, and returns only the ``"score"`` field of
    the resulting report (the fraction of criteria passed).
    """
    report = RubricEvaluator().evaluate(task, action_log)
    return report["score"]