File size: 6,672 Bytes
bd67f06
 
 
 
 
 
 
384d994
bd67f06
 
 
 
 
 
384d994
bd67f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384d994
bd67f06
 
 
 
 
 
384d994
bd67f06
 
 
 
 
 
384d994
 
bd67f06
 
 
 
 
 
 
 
 
 
 
384d994
 
 
bd67f06
 
384d994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd67f06
384d994
 
bd67f06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
384d994
bd67f06
384d994
 
 
 
 
 
 
 
 
bd67f06
384d994
 
 
 
 
 
 
 
 
bd67f06
 
384d994
 
bd67f06
 
384d994
 
bd67f06
 
 
 
 
 
 
384d994
bd67f06
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""HR environment — wraps 4 tool servers with OpenEnv's reset/step/state contract."""

from __future__ import annotations

import json
import logging
import os
from typing import Any
from uuid import uuid4

import requests
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

from simlab_hr.evaluator import evaluate_episode
from simlab_hr.models import HRAction, HRObservation
from simlab_hr.tasks import BUNDLED_TASKS, get_task

logger = logging.getLogger(__name__)

# Upper bound on steps per episode; HREnvironment.step forces evaluation
# once the step counter reaches this value.
MAX_STEPS_PER_EPISODE = 30

# Tool-server name -> environment variable that may override its base URL.
TOOL_SERVER_ENV_MAP = dict(
    hrms="HRMS_TOOL_SERVER_URL",
    email="EMAIL_TOOL_SERVER_URL",
    calendar="CALENDAR_TOOL_SERVER_URL",
    rocketchat="ROCKETCHAT_TOOL_SERVER_URL",
)

# Tool-server name -> base URL used when the override variable is unset.
TOOL_SERVER_DEFAULTS = dict(
    hrms="http://localhost:8030",
    email="http://localhost:8040",
    calendar="http://localhost:8050",
    rocketchat="http://localhost:8060",
)


class HREnvironment(Environment):
    """OpenEnv environment backed by SimLab's HR tool servers.

    Each episode pairs one task (chosen by ``get_task`` from a running
    episode counter) with proxied access to the tool servers listed in
    ``TOOL_SERVER_ENV_MAP``. Tool calls are forwarded to each server's
    ``POST /step`` endpoint and recorded in an action history. The episode
    ends when the agent calls ``submit_task`` or reaches
    ``MAX_STEPS_PER_EPISODE``. At that point ``evaluate_episode`` scores the
    history and the final observation carries the reward.
    """

    def __init__(self) -> None:
        # Each server URL may be overridden via its environment variable,
        # otherwise the localhost default is used.
        self._server_urls: dict[str, str] = {
            name: os.environ.get(env_var, TOOL_SERVER_DEFAULTS[name])
            for name, env_var in TOOL_SERVER_ENV_MAP.items()
        }

        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task = BUNDLED_TASKS[0]
        self._tools: dict[str, list[str]] = {}
        self._episode_count = 0
        self._action_history: list[dict[str, Any]] = []
        # True once the current episode has been scored; cleared by reset().
        # Prevents re-running the evaluator (and hitting the tool servers)
        # when a caller keeps stepping after done=True.
        self._done = False

    def reset(self) -> HRObservation:
        """Start a new episode.

        Picks the next task, assigns a fresh episode id, and re-queries every
        tool server for its tool list. It also clears the action history.

        Returns:
            An introductory observation with ``done=False`` and ``reward=0.0``.
        """
        self._current_task = get_task(self._episode_count)
        self._episode_count += 1
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._tools = self._discover_all_tools()
        self._action_history = []
        self._done = False

        return HRObservation(
            result=(
                "HR environment ready. You have access to 4 tool servers: "
                "hrms (employee records, leave, payroll), email (inbox), "
                "calendar (scheduling), and rocketchat (team messaging). "
                "When you've completed the task, call tool_name='submit_task' "
                "on any server to trigger evaluation and get your score."
            ),
            is_error=False,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=False,
            reward=0.0,
        )

    def step(self, action: HRAction) -> HRObservation:
        """Execute one agent action.

        ``submit_task`` (on any server) ends the episode immediately and
        returns the scored observation. Any other action is proxied to the
        named tool server and recorded in the history. Reaching
        ``MAX_STEPS_PER_EPISODE`` also ends and scores the episode.

        Once the episode has ended, further calls do not touch the tool
        servers or the evaluator. They return a terminal error observation
        with zero reward until ``reset()`` is called.
        """
        if self._done:
            return self._episode_over_observation()

        self._state.step_count += 1

        if action.tool_name == "submit_task":
            return self._evaluate_and_finish()

        server_url = self._server_urls.get(action.tool_server)
        if server_url is None:
            known_servers = ", ".join(self._server_urls)
            result = f"Unknown tool server: '{action.tool_server}'. Use one of: {known_servers}."
            is_error = True
        else:
            result, is_error = self._call_tool(server_url, action)

        self._action_history.append({
            "step": self._state.step_count,
            "server": action.tool_server,
            "tool": action.tool_name,
            "parameters": action.parameters,
            # Truncated to keep the history handed to the evaluator bounded.
            "result": result[:2000],
            "is_error": is_error,
        })

        if self._state.step_count >= MAX_STEPS_PER_EPISODE:
            return self._evaluate_and_finish()

        return HRObservation(
            result=result,
            is_error=is_error,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=False,
            reward=0.0,
        )

    @property
    def state(self) -> State:
        """Current episode state (episode id and step count)."""
        return self._state

    def _call_tool(self, server_url: str, action: HRAction) -> tuple[str, bool]:
        """Proxy a tool call to ``{server_url}/step``.

        Returns:
            ``(result, is_error)``. ``result`` is the response body, pretty-printed
            when it is a JSON object or array. ``is_error`` is True for any
            status other than 200 and for transport failures.
        """
        payload = {"action": {"tool_name": action.tool_name, "parameters": action.parameters}}
        try:
            resp = requests.post(
                f"{server_url}/step",
                json=payload,
                headers={"Content-Type": "application/json"},
                timeout=30,
            )
        except requests.RequestException as exc:
            return f"Tool invocation failed on {action.tool_server}: {exc}", True

        result = resp.text
        is_error = resp.status_code != 200
        try:
            parsed = resp.json()
        except ValueError:
            # Non-JSON body: keep the raw text. json.JSONDecodeError and
            # requests' JSONDecodeError both subclass ValueError.
            pass
        else:
            result = json.dumps(parsed, indent=2) if isinstance(parsed, (dict, list)) else str(parsed)
        return result, is_error

    def _evaluate_and_finish(self) -> HRObservation:
        """Score the episode with the rubric judge and return the final observation.

        Marks the episode as finished, so later ``step`` calls short-circuit
        until ``reset()``.
        """
        eval_result = evaluate_episode(
            task_instruction=self._current_task.instruction,
            rubric=self._current_task.rubric,
            action_history=self._action_history,
        )
        self._done = True

        verdict_msg = (
            f"Episode complete. Score: {eval_result.score:.2f} ({eval_result.verdict})"
        )
        if eval_result.evidence:
            verdict_msg += "\nEvidence: " + "; ".join(eval_result.evidence)
        if eval_result.failed_criteria:
            verdict_msg += "\nFailed: " + "; ".join(eval_result.failed_criteria)
        if eval_result.error:
            verdict_msg += f"\nNote: {eval_result.error}"

        return HRObservation(
            result=verdict_msg,
            is_error=False,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=True,
            reward=eval_result.score,
        )

    def _episode_over_observation(self) -> HRObservation:
        """Terminal observation returned for steps taken after the episode ended."""
        return HRObservation(
            result="Episode already finished. Call reset() to start a new episode.",
            is_error=True,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=True,
            reward=0.0,
        )

    def _discover_all_tools(self) -> dict[str, list[str]]:
        """Fetch tool names from every server, plus the virtual ``submit_task``.

        ``submit_task`` is listed under the ``_meta`` key because ``step``
        handles it locally rather than forwarding it to a server.
        """
        all_tools: dict[str, list[str]] = {
            name: self._discover_tools(name, url)
            for name, url in self._server_urls.items()
        }
        all_tools.setdefault("_meta", []).append("submit_task")
        return all_tools

    def _discover_tools(self, server_name: str, server_url: str) -> list[str]:
        """Fetch tool names from a single server's ``GET /tools`` endpoint.

        Best-effort: any failure is logged and yields an empty list, so one
        unreachable server does not break ``reset()``.
        """
        try:
            resp = requests.get(f"{server_url}/tools", timeout=15)
            resp.raise_for_status()
            data = resp.json()
            tools = data.get("tools", []) if isinstance(data, dict) else []
            return [t["name"] for t in tools if isinstance(t, dict) and "name" in t]
        except Exception as exc:
            logger.warning("Could not discover tools from %s: %s", server_name, exc)
            return []