Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from typing import Any | |
| from urllib import parse, request | |
| from .env import SupportTicketEnvironment | |
| from .fixtures import list_task_ids | |
| from .models import SupportTicketAction, SupportTicketStepResult | |
class SupportTicketEnv:
    """Client for the support-ticket environment, in-process or over HTTP.

    When ``base_url`` is set, calls are forwarded to a server:
    ``POST /reset``, ``POST /step`` and ``GET /state``. The server-issued
    ``session_id`` is tracked on the instance. When ``base_url`` is falsy,
    calls are delegated to a local ``SupportTicketEnvironment`` instead.
    """

    def __init__(
        self,
        base_url: str | None = None,
        task_id: str | None = None,
        *,
        timeout: float | None = None,
    ) -> None:
        """Create a client.

        Args:
            base_url: Root URL of the environment server. Trailing ``/`` are
                stripped. ``None`` or an empty string selects local
                (in-process) mode.
            task_id: Default task id used by ``reset()`` when none is passed.
            timeout: Optional socket timeout, in seconds, for each HTTP
                request. ``None`` leaves urllib's default timeout behaviour
                unchanged.
        """
        self.base_url = base_url.rstrip("/") if base_url else None
        self.task_id = task_id
        self.timeout = timeout
        self.session_id: str | None = None
        # Only build the in-process environment when no server URL is configured.
        self._local_env = SupportTicketEnvironment(task_id=task_id) if not self.base_url else None

    @classmethod
    def from_docker_image(
        cls,
        image_name: str,
        base_url: str = "http://localhost:8000",
        task_id: str | None = None,
    ) -> "SupportTicketEnv":
        """Return an HTTP-mode client pointed at ``base_url``.

        ``image_name`` is accepted but unused: no container is started here.
        The client simply talks to whatever is serving ``base_url``.
        """
        del image_name  # accepted for signature compatibility, not used
        return cls(base_url=base_url, task_id=task_id)

    @classmethod
    def from_env(
        cls,
        repo_id: str,
        base_url: str,
        task_id: str | None = None,
    ) -> "SupportTicketEnv":
        """Return an HTTP-mode client pointed at ``base_url``.

        ``repo_id`` is accepted but unused by this implementation.
        """
        del repo_id  # accepted for signature compatibility, not used
        return cls(base_url=base_url, task_id=task_id)

    def reset(self, task_id: str | None = None) -> SupportTicketStepResult:
        """Start a new episode and return the initial step result.

        Args:
            task_id: Task to load. Falls back to the ``task_id`` given at
                construction.

        In HTTP mode the ``session_id`` reported by the server (in
        ``result.info``) is stored for later ``step()`` / ``state()`` calls.
        """
        effective_task_id = task_id or self.task_id
        if self._local_env is not None:
            return self._local_env.reset(effective_task_id)
        payload: dict[str, Any] = {}
        if effective_task_id:
            payload["task_id"] = effective_task_id
        if self.session_id:
            # Send the current session id so the server can associate the reset
            # with it (presumably reusing the session — confirm against server).
            payload["session_id"] = self.session_id
        result = self._post_json("/reset", payload)
        # NOTE(review): unlike step(), this overwrites session_id with None if the
        # server omits it from the response.
        self.session_id = result.info.get("session_id")
        return result

    def step(self, action: SupportTicketAction | dict[str, Any]) -> SupportTicketStepResult:
        """Apply ``action`` to the current episode and return the step result.

        In HTTP mode a pydantic action is serialized with
        ``model_dump(mode="json")``; a plain dict is sent as-is.
        """
        if self._local_env is not None:
            return self._local_env.step(action)
        # NOTE(review): if reset() was never called this sends session_id=None;
        # unlike state() there is no client-side guard — confirm server behaviour.
        payload = {
            "session_id": self.session_id,
            "action": action.model_dump(mode="json") if hasattr(action, "model_dump") else action,
        }
        result = self._post_json("/step", payload)
        # Keep the known session id if the server does not echo one back.
        self.session_id = result.info.get("session_id", self.session_id)
        return result

    def state(self) -> dict[str, Any]:
        """Return the environment state as a dict.

        Raises:
            RuntimeError: In HTTP mode, if ``reset()`` has not established a
                session yet.
        """
        if self._local_env is not None:
            return self._local_env.state()
        if not self.session_id:
            raise RuntimeError("reset() must be called before state() when using HTTP mode.")
        query = parse.urlencode({"session_id": self.session_id})
        return self._request_json(f"{self.base_url}/state?{query}")

    def close(self) -> None:
        """Forget the current session id.

        No request is sent to the server, and the local environment is left
        untouched.
        """
        self.session_id = None

    def __enter__(self) -> "SupportTicketEnv":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always clear the session; returning None lets exceptions propagate.
        self.close()
        return None

    @staticmethod
    def list_tasks() -> list[str]:
        """Return the task ids provided by ``list_task_ids`` from the package fixtures.

        Note: this reads the local fixtures module even in HTTP mode.
        """
        return list_task_ids()

    def _request_json(self, req: str | request.Request) -> Any:
        """Open ``req`` (URL string or Request) and decode the body as UTF-8 JSON."""
        # Only pass ``timeout`` when configured, so urllib's default is kept otherwise.
        kwargs: dict[str, Any] = {} if self.timeout is None else {"timeout": self.timeout}
        with request.urlopen(req, **kwargs) as response:
            return json.loads(response.read().decode("utf-8"))

    def _post_json(self, path: str, payload: dict[str, Any]) -> SupportTicketStepResult:
        """POST ``payload`` as JSON to ``base_url + path`` and validate the reply as a step result."""
        body = json.dumps(payload).encode("utf-8")
        req = request.Request(
            f"{self.base_url}{path}",
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        data = self._request_json(req)
        return SupportTicketStepResult.model_validate(data)