from __future__ import annotations

import json
from typing import Any
from urllib import parse, request

from .env import SupportTicketEnvironment
from .fixtures import list_task_ids
from .models import SupportTicketAction, SupportTicketStepResult


class SupportTicketEnv:
    """Client for the support-ticket environment, either in-process or over HTTP.

    When ``base_url`` is falsy, every call is delegated to a local
    ``SupportTicketEnvironment``. Otherwise calls are sent as JSON requests to
    the server at ``base_url`` and the ``session_id`` reported by the server is
    tracked between calls.
    """

    def __init__(
        self,
        base_url: str | None = None,
        task_id: str | None = None,
        *,
        timeout: float | None = None,
    ) -> None:
        """Create a client.

        Args:
            base_url: Server root URL; a trailing ``/`` is stripped. ``None`` or
                ``""`` selects local (in-process) mode.
            task_id: Default task used by :meth:`reset` when none is passed.
            timeout: Optional per-request timeout in seconds for HTTP mode.
                ``None`` (the default) passes no timeout to ``urlopen``, which
                keeps urllib's default behavior.
        """
        self.base_url = base_url.rstrip("/") if base_url else None
        self.task_id = task_id
        self.timeout = timeout
        self.session_id: str | None = None
        # The local environment exists only when no server URL was given.
        self._local_env: SupportTicketEnvironment | None = (
            SupportTicketEnvironment(task_id=task_id) if not self.base_url else None
        )

    @classmethod
    def from_docker_image(
        cls,
        image_name: str,
        base_url: str = "http://localhost:8000",
        task_id: str | None = None,
        *,
        timeout: float | None = None,
    ) -> "SupportTicketEnv":
        """Return an HTTP-mode client pointed at ``base_url``.

        ``image_name`` is accepted for interface compatibility but ignored: no
        container is started here, so the server must already be reachable.
        """
        del image_name  # unused; see docstring
        return cls(base_url=base_url, task_id=task_id, timeout=timeout)

    @classmethod
    def from_env(
        cls,
        repo_id: str,
        base_url: str,
        task_id: str | None = None,
        *,
        timeout: float | None = None,
    ) -> "SupportTicketEnv":
        """Return an HTTP-mode client pointed at ``base_url``.

        ``repo_id`` is accepted for interface compatibility but ignored.
        """
        del repo_id  # unused; see docstring
        return cls(base_url=base_url, task_id=task_id, timeout=timeout)

    def reset(self, task_id: str | None = None) -> SupportTicketStepResult:
        """Start a new episode and return the initial step result.

        Args:
            task_id: Task to load; falls back to the ``task_id`` given at
                construction when ``None`` or empty.

        In HTTP mode the current ``session_id`` (if any) is included in the
        request, presumably so the server can reuse that session. The stored
        ``session_id`` is then replaced by the one in the response's ``info``;
        it becomes ``None`` if the response has none.
        """
        effective_task_id = task_id or self.task_id
        if self._local_env is not None:
            return self._local_env.reset(effective_task_id)

        payload: dict[str, Any] = {}
        if effective_task_id:
            payload["task_id"] = effective_task_id
        if self.session_id:
            payload["session_id"] = self.session_id
        result = self._post_json("/reset", payload)
        self.session_id = result.info.get("session_id")
        return result

    def step(self, action: SupportTicketAction | dict[str, Any]) -> SupportTicketStepResult:
        """Apply ``action`` and return the resulting step result.

        ``action`` may be a model object or a plain dict. In HTTP mode, objects
        exposing ``model_dump`` are serialized with ``model_dump(mode="json")``;
        anything else is sent as-is. The stored ``session_id`` is updated from
        the response's ``info`` and left unchanged if the response omits it.
        """
        if self._local_env is not None:
            return self._local_env.step(action)

        # NOTE(review): if reset() was never called, session_id is None and is
        # sent as JSON null; how the server handles that is not visible here.
        serialized = action.model_dump(mode="json") if hasattr(action, "model_dump") else action
        payload = {"session_id": self.session_id, "action": serialized}
        result = self._post_json("/step", payload)
        self.session_id = result.info.get("session_id", self.session_id)
        return result

    def state(self) -> dict[str, Any]:
        """Return the environment's current state.

        Raises:
            RuntimeError: In HTTP mode, if there is no session yet (call
                :meth:`reset` first).
        """
        if self._local_env is not None:
            return self._local_env.state()
        if not self.session_id:
            raise RuntimeError("reset() must be called before state() when using HTTP mode.")
        query = parse.urlencode({"session_id": self.session_id})
        return self._request_json(f"{self.base_url}/state?{query}")

    def close(self) -> None:
        """Forget the current session id.

        This does not contact the server and does not touch the local
        environment.
        """
        self.session_id = None

    def __enter__(self) -> "SupportTicketEnv":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()
        # Returning None lets any exception raised in the with-body propagate.
        return None

    @staticmethod
    def list_tasks() -> list[str]:
        """Return the available task ids (delegates to ``list_task_ids``)."""
        return list_task_ids()

    def _urlopen(self, target: str | request.Request) -> Any:
        """Open ``target``, passing ``self.timeout`` only when one was configured."""
        if self.timeout is None:
            return request.urlopen(target)
        return request.urlopen(target, timeout=self.timeout)

    def _request_json(self, target: str | request.Request) -> Any:
        """Send ``target`` and decode the response body as UTF-8 JSON.

        Exceptions raised by ``urlopen`` or ``json.loads`` propagate unchanged.
        """
        with self._urlopen(target) as response:
            return json.loads(response.read().decode("utf-8"))

    def _post_json(self, path: str, payload: dict[str, Any]) -> SupportTicketStepResult:
        """POST ``payload`` as JSON to ``base_url + path`` and validate the reply."""
        req = request.Request(
            f"{self.base_url}{path}",
            data=json.dumps(payload).encode("utf-8"),
            headers={"Content-Type": "application/json"},
            method="POST",
        )
        return SupportTicketStepResult.model_validate(self._request_json(req))