# DNSArenaEnv / client.py
# PraneshkumarR's picture
# DNS-Env: DNS Zone File Debugging Environment for OpenEnv
# 0d58c6d
"""HTTP client for the DNS-Env OpenEnv environment."""
from __future__ import annotations
import requests
from typing import Any
class DNSEnvClient:
"""Simple HTTP client for the DNS environment server.
Wraps the REST endpoints exposed by the FastAPI server so callers
can interact with the environment from plain Python without manually
assembling HTTP requests.
Parameters
----------
base_url:
Root URL of the environment server (default ``http://localhost:7860``).
session_id:
Identifies the caller's session. The server maintains independent
environment state per session.
timeout:
HTTP request timeout in seconds.
"""
def __init__(
self,
base_url: str = "http://localhost:7860",
session_id: str = "default",
timeout: float = 30.0,
):
self.base_url = base_url.rstrip("/")
self.session_id = session_id
self.timeout = timeout
# ------------------------------------------------------------------
# Endpoints
# ------------------------------------------------------------------
def health(self) -> dict:
"""Liveness check. Returns ``{"status": "ok"}`` when the server is up."""
resp = requests.get(f"{self.base_url}/health", timeout=self.timeout)
resp.raise_for_status()
return resp.json()
def reset(
self,
task_id: str | None = None,
seed: int | None = None,
episode_id: str | None = None,
) -> dict:
"""Reset the environment and start a new episode.
Parameters
----------
task_id:
One of ``fix_single_record``, ``configure_mail``,
``debug_delegation``. When *None* the server cycles tasks.
seed:
Optional RNG seed for reproducibility.
episode_id:
Optional caller-supplied episode identifier.
Returns
-------
dict
Observation JSON with keys: ``output``, ``task_description``,
``zone_names``, ``available_commands``, ``done``, ``reward``,
``metadata``.
"""
body: dict[str, Any] = {"session_id": self.session_id}
options: dict[str, Any] = {}
if task_id:
options["task_id"] = task_id
if options:
body["options"] = options
if seed is not None:
body["seed"] = seed
if episode_id is not None:
body["episode_id"] = episode_id
resp = requests.post(
f"{self.base_url}/reset", json=body, timeout=self.timeout
)
resp.raise_for_status()
return resp.json()
def step(self, command: str, **args: Any) -> dict:
"""Execute one action in the environment.
Parameters
----------
command:
One of the available commands (``view_zone``, ``add_record``,
``edit_record``, ``delete_record``, ``check_zone``, ``dig``,
``submit``).
**args:
Keyword arguments forwarded as the action's ``args`` dict.
Returns
-------
dict
Observation JSON.
"""
body = {
"session_id": self.session_id,
"action": {"command": command, "args": args},
}
resp = requests.post(
f"{self.base_url}/step", json=body, timeout=self.timeout
)
resp.raise_for_status()
return resp.json()
def state(self) -> dict:
"""Return the current episode state (step count, task id, etc.)."""
resp = requests.get(
f"{self.base_url}/state",
params={"session_id": self.session_id},
timeout=self.timeout,
)
resp.raise_for_status()
return resp.json()
def tasks(self) -> list[str]:
"""List available task identifiers."""
resp = requests.get(f"{self.base_url}/tasks", timeout=self.timeout)
resp.raise_for_status()
return resp.json()["tasks"]