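# NOTE: the next six lines appear to be web-page residue captured together
# with this file; they are kept verbatim as comments so the module can parse.
# Dar3devil's picture
# Initial customer support OpenEnv upload
# 2b73c16 verified
# Raw
# History Blame Contribute Delete
# 3.51 kB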
from __future__ import annotations
import json
from typing import Any
from urllib import parse, request
from .env import SupportTicketEnvironment
from .fixtures import list_task_ids
from .models import SupportTicketAction, SupportTicketStepResult
class SupportTicketEnv:
    """Client facade for the support-ticket environment.

    The instance works in one of two modes, fixed at construction time:

    * Local mode (no ``base_url``): every call is delegated to an in-process
      ``SupportTicketEnvironment``.
    * HTTP mode (``base_url`` given): calls are sent as JSON to
      ``{base_url}/reset``, ``{base_url}/step`` and ``{base_url}/state``. The
      ``session_id`` found in the ``info`` mapping of a step result is stored
      on the instance and sent back with later requests.

    The object can be used as a context manager; leaving the ``with`` block
    calls :meth:`close`.
    """

    def __init__(self, base_url: str | None = None, task_id: str | None = None) -> None:
        """Create a client.

        Args:
            base_url: Root URL of a remote server. Trailing slashes are
                stripped. A falsy value (``None`` or ``""``) selects local
                mode.
            task_id: Default task used by :meth:`reset` when it is called
                without an explicit ``task_id``.
        """
        self.base_url = base_url.rstrip("/") if base_url else None
        self.task_id = task_id
        self.session_id: str | None = None
        # The in-process environment is only built when there is no server URL.
        self._local_env = SupportTicketEnvironment(task_id=task_id) if not self.base_url else None

    @classmethod
    def from_docker_image(
        cls,
        image_name: str,
        base_url: str = "http://localhost:8000",
        task_id: str | None = None,
    ) -> "SupportTicketEnv":
        """Build an HTTP-mode client pointing at ``base_url``.

        ``image_name`` is accepted but discarded: this method does not start
        or inspect any container, it only constructs the client.
        """
        del image_name
        return cls(base_url=base_url, task_id=task_id)

    @classmethod
    def from_env(
        cls,
        repo_id: str,
        base_url: str,
        task_id: str | None = None,
    ) -> "SupportTicketEnv":
        """Build a client for ``base_url``; ``repo_id`` is accepted but unused."""
        del repo_id
        return cls(base_url=base_url, task_id=task_id)

    def reset(self, task_id: str | None = None) -> SupportTicketStepResult:
        """Start a new episode and return the initial step result.

        Args:
            task_id: Task to load. When falsy, the ``task_id`` given to the
                constructor is used instead.

        Returns:
            The result produced by the local environment, or the validated
            response of ``POST /reset`` in HTTP mode.
        """
        effective_task_id = task_id or self.task_id
        if self._local_env is not None:
            return self._local_env.reset(effective_task_id)

        # Only keys that actually carry a value are sent to the server.
        payload: dict[str, Any] = {}
        if effective_task_id:
            payload["task_id"] = effective_task_id
        if self.session_id:
            payload["session_id"] = self.session_id

        result = self._post_json("/reset", payload)
        # No fallback here: a reset response without a session id clears the
        # stored one.
        self.session_id = result.info.get("session_id")
        return result

    def step(self, action: SupportTicketAction | dict[str, Any]) -> SupportTicketStepResult:
        """Apply ``action`` and return the resulting step result.

        Args:
            action: Either an object exposing ``model_dump`` (serialised with
                ``mode="json"`` for HTTP transport) or a plain dict that is
                sent as-is.
        """
        if self._local_env is not None:
            return self._local_env.step(action)

        payload = {
            "session_id": self.session_id,
            "action": action.model_dump(mode="json") if hasattr(action, "model_dump") else action,
        }
        result = self._post_json("/step", payload)
        # Keep the current session id when the response does not carry one.
        self.session_id = result.info.get("session_id", self.session_id)
        return result

    def state(self) -> dict[str, Any]:
        """Return the current environment state.

        Raises:
            RuntimeError: In HTTP mode when no session id is stored yet, i.e.
                :meth:`reset` has not established a session.
        """
        if self._local_env is not None:
            return self._local_env.state()
        if not self.session_id:
            raise RuntimeError("reset() must be called before state() when using HTTP mode.")

        query = parse.urlencode({"session_id": self.session_id})
        with request.urlopen(f"{self.base_url}/state?{query}") as response:
            return json.loads(response.read().decode("utf-8"))

    def close(self) -> None:
        """Forget the stored session id; no request is sent."""
        self.session_id = None

    def __enter__(self) -> "SupportTicketEnv":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None:
        """Call :meth:`close`; returning ``None`` lets exceptions propagate."""
        self.close()
        return None

    @staticmethod
    def list_tasks() -> list[str]:
        """Return the task ids reported by ``list_task_ids()``."""
        return list_task_ids()

    def _post_json(self, path: str, payload: dict[str, Any]) -> SupportTicketStepResult:
        """POST ``payload`` as JSON to ``base_url + path`` and validate the reply.

        The response body is decoded as UTF-8 JSON and passed to
        ``SupportTicketStepResult.model_validate``.
        """
        body = json.dumps(payload).encode("utf-8")
        req = request.Request(
            f"{self.base_url}{path}",
            data=body,
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        with request.urlopen(req) as response:
            data = json.loads(response.read().decode("utf-8"))
        return SupportTicketStepResult.model_validate(data)