Spaces:
Runtime error
Runtime error
| """HR environment — wraps 4 tool servers with OpenEnv's reset/step/state contract.""" | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| from typing import Any | |
| from uuid import uuid4 | |
| import requests | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from simlab_hr.evaluator import evaluate_episode | |
| from simlab_hr.models import HRAction, HRObservation | |
| from simlab_hr.tasks import BUNDLED_TASKS, get_task | |
logger = logging.getLogger(__name__)

# Hard cap on steps per episode; once reached, the episode is scored and ends.
MAX_STEPS_PER_EPISODE = 30

# Tool server name -> environment variable that may override its base URL.
TOOL_SERVER_ENV_MAP = {
    "hrms": "HRMS_TOOL_SERVER_URL",
    "email": "EMAIL_TOOL_SERVER_URL",
    "calendar": "CALENDAR_TOOL_SERVER_URL",
    "rocketchat": "ROCKETCHAT_TOOL_SERVER_URL",
}

# Tool server name -> base URL used when the override variable is unset.
# Each server listens on its own localhost port.
TOOL_SERVER_DEFAULTS = {
    server: f"http://localhost:{port}"
    for server, port in (
        ("hrms", 8030),
        ("email", 8040),
        ("calendar", 8050),
        ("rocketchat", 8060),
    )
}
class HREnvironment(Environment):
    """OpenEnv environment backed by SimLab's HR tool servers.

    ``reset`` picks a task via ``get_task`` (indexed by a running episode
    counter), refreshes the tool catalogue from every server and clears the
    action history. ``step`` either forwards a tool call to the selected
    server's ``POST /step`` endpoint or, for ``tool_name == "submit_task"``,
    scores the episode with ``evaluate_episode``. The episode is also scored
    automatically once ``MAX_STEPS_PER_EPISODE`` steps have been taken.

    Server base URLs are read once, at construction, from the environment
    variables in ``TOOL_SERVER_ENV_MAP``, falling back to
    ``TOOL_SERVER_DEFAULTS``.
    """

    # Maximum number of characters of each tool result stored in the action
    # history that is handed to the evaluator.
    _HISTORY_RESULT_CHARS = 2000

    def __init__(self) -> None:
        # Run the OpenEnv base initialiser so any attributes it defines exist
        # on this instance. NOTE(review): assumes the installed openenv
        # Environment.__init__ takes no required arguments — confirm.
        super().__init__()
        self._server_urls: dict[str, str] = {
            name: os.environ.get(env_var, TOOL_SERVER_DEFAULTS[name])
            for name, env_var in TOOL_SERVER_ENV_MAP.items()
        }
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Placeholder until the first reset() selects a task.
        self._current_task = BUNDLED_TASKS[0]
        # Server name -> tool names it exposes; populated by reset().
        self._tools: dict[str, list[str]] = {}
        self._episode_count = 0
        # One record per proxied tool call; passed to the evaluator.
        self._action_history: list[dict[str, Any]] = []

    def reset(self) -> HRObservation:
        """Start a new episode and return its initial observation.

        Selects the next task, assigns a fresh episode id with a zeroed step
        counter, re-discovers each server's tools (one ``GET /tools`` request
        per server) and clears the action history.
        """
        self._current_task = get_task(self._episode_count)
        self._episode_count += 1
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._tools = self._discover_all_tools()
        self._action_history = []
        return HRObservation(
            result=(
                "HR environment ready. You have access to 4 tool servers: "
                "hrms (employee records, leave, payroll), email (inbox), "
                "calendar (scheduling), and rocketchat (team messaging). "
                "When you've completed the task, call tool_name='submit_task' "
                "on any server to trigger evaluation and get your score."
            ),
            is_error=False,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=False,
            reward=0.0,
        )

    def step(self, action: HRAction) -> HRObservation:
        """Execute one agent action and return the resulting observation.

        ``submit_task`` (on any server) ends the episode immediately with the
        evaluator's score as reward. Any other action is forwarded to
        ``action.tool_server``. An unknown server name produces an error
        observation rather than an exception. Every forwarded call is
        appended to the action history. When the step budget is exhausted,
        the episode is scored and ended even without an explicit submission.

        NOTE(review): steps after a terminal observation are not rejected, so
        calling ``submit_task`` again re-scores the same episode.
        """
        self._state.step_count += 1

        if action.tool_name == "submit_task":
            return self._evaluate_and_finish()

        server_url = self._server_urls.get(action.tool_server)
        if server_url is None:
            # List the servers actually configured, so the hint stays in
            # sync with _server_urls.
            known = ", ".join(self._server_urls)
            result = f"Unknown tool server: '{action.tool_server}'. Use one of: {known}."
            is_error = True
        else:
            result, is_error = self._call_tool(server_url, action)

        self._action_history.append(
            {
                "step": self._state.step_count,
                "server": action.tool_server,
                "tool": action.tool_name,
                "parameters": action.parameters,
                # Only the evaluator's copy is truncated (presumably to bound
                # its input size); the agent receives the full result below.
                "result": result[: self._HISTORY_RESULT_CHARS],
                "is_error": is_error,
            }
        )

        if self._state.step_count >= MAX_STEPS_PER_EPISODE:
            return self._evaluate_and_finish()

        return HRObservation(
            result=result,
            is_error=is_error,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=False,
            reward=0.0,
        )

    @property
    def state(self) -> State:
        """Current episode metadata (episode id and step count).

        Declared as a property to match OpenEnv's ``Environment.state``,
        which is an abstract property. A plain method here would make
        ``env.state`` evaluate to a bound method instead of the ``State``.
        """
        return self._state

    def _call_tool(self, server_url: str, action: HRAction) -> tuple[str, bool]:
        """Forward a tool call to ``{server_url}/step``.

        Returns:
            ``(result, is_error)``. ``result`` is the response body,
            pretty-printed when it is a JSON object or array, stringified when
            it is another JSON value, and returned verbatim when it is not JSON.
            ``is_error`` is True for any non-2xx status or a transport failure
            (connection error, 30 s timeout, ...).
        """
        payload = {"action": {"tool_name": action.tool_name, "parameters": action.parameters}}
        try:
            # ``json=`` serialises the payload and sets Content-Type itself.
            resp = requests.post(f"{server_url}/step", json=payload, timeout=30)
        except requests.RequestException as exc:
            return f"Tool invocation failed on {action.tool_server}: {exc}", True

        is_error = not 200 <= resp.status_code < 300
        try:
            parsed = resp.json()
        except ValueError:
            # Body is not JSON (requests' JSONDecodeError subclasses
            # ValueError), so hand back the raw text.
            return resp.text, is_error
        if isinstance(parsed, (dict, list)):
            return json.dumps(parsed, indent=2), is_error
        return str(parsed), is_error

    def _evaluate_and_finish(self) -> HRObservation:
        """Score the episode with the rubric judge and return a terminal observation.

        The observation's ``result`` summarises the score and verdict,
        followed by any evidence, failed criteria and evaluator error note,
        each on its own line. ``reward`` is the evaluator's score and
        ``done`` is True.
        """
        eval_result = evaluate_episode(
            task_instruction=self._current_task.instruction,
            rubric=self._current_task.rubric,
            action_history=self._action_history,
        )

        lines = [f"Episode complete. Score: {eval_result.score:.2f} ({eval_result.verdict})"]
        if eval_result.evidence:
            lines.append("Evidence: " + "; ".join(eval_result.evidence))
        if eval_result.failed_criteria:
            lines.append("Failed: " + "; ".join(eval_result.failed_criteria))
        if eval_result.error:
            lines.append(f"Note: {eval_result.error}")

        return HRObservation(
            result="\n".join(lines),
            is_error=False,
            tools_available=self._tools,
            task_instruction=self._current_task.instruction,
            done=True,
            reward=eval_result.score,
        )

    def _discover_all_tools(self) -> dict[str, list[str]]:
        """Return ``{server_name: [tool names]}`` for every configured server.

        Also adds a ``"_meta"`` entry advertising the environment-level
        ``submit_task`` action, which is handled locally by ``step``.
        """
        all_tools = {
            name: self._discover_tools(name, url) for name, url in self._server_urls.items()
        }
        all_tools["_meta"] = ["submit_task"]
        return all_tools

    def _discover_tools(self, server_name: str, server_url: str) -> list[str]:
        """Fetch tool names from a single server's ``GET /tools`` endpoint.

        Expects a JSON object whose ``"tools"`` entry is a list of objects
        with a ``"name"`` key; entries that do not fit are skipped. Discovery
        is best-effort: any failure is logged and yields an empty list, so an
        unreachable server does not prevent ``reset``.
        """
        try:
            resp = requests.get(f"{server_url}/tools", timeout=15)
            resp.raise_for_status()
            data = resp.json()
            tools = data.get("tools", []) if isinstance(data, dict) else []
            return [t["name"] for t in tools if isinstance(t, dict) and "name" in t]
        except Exception as exc:  # deliberate best-effort boundary
            logger.warning("Could not discover tools from %s: %s", server_name, exc)
            return []