PRobe / agent /client.py
Thakur, Mahipal
refactor: remove legacy architecture, promote clean structure to repo root
85fab7b
"""PRobe Environment Client."""
from __future__ import annotations
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from .models import ProbeAction, ProbeObservation
class ProbeEnv(EnvClient[ProbeAction, ProbeObservation, State]):
    """
    Client for the PRobe environment.

    Maintains a persistent WebSocket connection to the server.

    Example::

        with ProbeEnv(base_url="http://localhost:8000") as env:
            result = env.reset()
            print(result.observation.task_description)

            action = ProbeAction(
                action_type="add_comment",
                line_number=4,
                comment="Off-by-one: range(len+1) causes IndexError",
                severity="error",
                category="bug",
            )
            result = env.step(action)
            print(result.reward)
    """

    def _step_payload(self, action: ProbeAction) -> dict:
        """Serialize *action* into a plain dict.

        ``action_type`` is always present; every other field is included
        only when it is not ``None``.

        Args:
            action: The action to serialize.

        Returns:
            A dict keyed by field name, with enum-like fields reduced to
            their ``.value``.
        """
        # getattr(..., "value", x) unwraps enum members but also tolerates a
        # field that already holds the raw value (e.g. a plain string).
        payload: dict = {
            "action_type": getattr(
                action.action_type, "value", action.action_type
            ),
        }
        if action.line_number is not None:
            payload["line_number"] = action.line_number
        if action.comment is not None:
            payload["comment"] = action.comment
        for field_name in ("severity", "category", "classification"):
            field_value = getattr(action, field_name)
            if field_value is not None:
                payload[field_name] = getattr(field_value, "value", field_value)
        return payload

    def _parse_result(
        self, payload: dict
    ) -> StepResult[ProbeObservation]:
        """Build a ``StepResult`` from a response payload.

        Args:
            payload: Dict with optional ``observation``, ``reward`` and
                ``done`` keys.

        Returns:
            A ``StepResult`` whose reward defaults to ``0.0`` and whose
            ``done`` flag defaults to ``False`` when absent.
        """
        # ``or {}`` treats an explicit null observation like a missing one,
        # mirroring the null handling applied to ``reward`` below.
        obs_data: dict = payload.get("observation") or {}
        # Use model_validate so new fields added to ProbeObservation
        # are picked up automatically without changing this method.
        observation = ProbeObservation.model_validate(obs_data)
        return StepResult(
            observation=observation,
            reward=float(payload.get("reward") or 0.0),
            done=bool(payload.get("done", False)),
        )

    def _parse_state(self, payload: dict) -> State:
        """Build a ``State`` from a response payload.

        Args:
            payload: Dict with optional ``episode_id`` and ``step_count``
                keys.

        Returns:
            A ``State``; ``step_count`` falls back to ``0`` when the key is
            missing or null.
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count") or 0,
        )