# vinay-pepakayala's picture
# Upload folder using huggingface_hub
# 751aa40 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""Githubissuetriage Environment Client."""
import argparse
from contextlib import AbstractContextManager
from typing import Dict
from openenv.core import EnvClient
from openenv.core.client_types import StepResult
from openenv.core.env_server.types import State
from openenv.core.sync_client import SyncEnvClient
try:
from GitHubIssueTriage.models import Action, Observation
except ImportError: # pragma: no cover
from models import Action, Observation
class GithubissuetriageEnv(
    EnvClient[Action, Observation, State]
):
    """
    Client for the Githubissuetriage Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    On top of the ``EnvClient`` base class it provides:

    * ``session()`` - returns a ``Session`` context manager that enters the
      synchronous wrapper obtained from ``self.sync()`` and resets the
      environment on entry.
    * ``websocket_session()`` - returns the synchronous wrapper itself, without
      an automatic reset.
    * ``_step_payload`` / ``_parse_result`` / ``_parse_state`` - conversion
      between ``Action`` / ``Observation`` / ``State`` objects and plain dicts.

    Note:
        ``_step_payload`` rejects any action whose serialised form has no
        ``"type"`` key, so every action sent through this client must carry one.

    Example:
        >>> # Connect to a running server over websocket
        >>> with GithubissuetriageEnv(base_url="http://localhost:8000").session() as session:
        ...     result = session.initial_result
        ...     print(result.observation)
        ...
        ...     # ``action`` is an Action (or dict) with a "type" field
        ...     result = session.step(action)
        ...     print(result.reward, result.done)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = GithubissuetriageEnv.from_docker_image("GitHubIssueTriage-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(action)
        ... finally:
        ...     client.close()
    """

    def __init__(
        self,
        base_url: str,
        connect_timeout_s: float = 10.0,
        message_timeout_s: float = 60.0,
        max_message_size_mb: float = 100.0,
        provider=None,
        mode: str | None = None,
        **kwargs,
    ) -> None:
        """
        Create a client bound to ``base_url``.

        Args:
            base_url: Environment server URL, forwarded to ``EnvClient``.
            connect_timeout_s: Forwarded to ``EnvClient``.
            message_timeout_s: Forwarded to ``EnvClient``.
            max_message_size_mb: Forwarded to ``EnvClient``.
            provider: Forwarded to ``EnvClient``.
            mode: Forwarded to ``EnvClient``.
            **kwargs: ``episodes``, ``strict_mode`` and ``live_github`` are
                popped and stored as attributes of the same name (``None``
                when absent). Whatever remains is stored in
                ``extra_env_kwargs``. None of these are passed to the base
                class.
        """
        self.episodes = kwargs.pop("episodes", None)
        self.strict_mode = kwargs.pop("strict_mode", None)
        self.live_github = kwargs.pop("live_github", None)
        self.extra_env_kwargs = kwargs
        super().__init__(
            base_url=base_url,
            connect_timeout_s=connect_timeout_s,
            message_timeout_s=message_timeout_s,
            max_message_size_mb=max_message_size_mb,
            provider=provider,
            mode=mode,
        )

    class Session(AbstractContextManager["GithubissuetriageEnv.Session"]):
        """Context manager that opens a websocket connection and resets the env.

        The result of the reset performed on entry is available as
        ``initial_result`` (``None`` until the context has been entered).
        """

        def __init__(
            self,
            client: "GithubissuetriageEnv",
            *,
            task_id: str | None = None,
            difficulty: str | None = None,
            seed: int | None = None,
        ) -> None:
            """
            Args:
                client: Environment client whose ``sync()`` wrapper is used.
                task_id: Passed to ``reset`` on entry.
                difficulty: Passed to ``reset`` on entry.
                seed: Passed to ``reset`` on entry.
            """
            self._client = client
            self._task_id = task_id
            self._difficulty = difficulty
            self._seed = seed
            self._session = client.sync()
            self.initial_result = None

        def __enter__(self) -> "GithubissuetriageEnv.Session":
            """Enter the underlying session, reset the environment, return ``self``.

            Raises:
                Whatever the underlying ``reset`` raises; in that case the
                underlying session is exited before the error propagates.
            """
            self._session.__enter__()
            try:
                self.initial_result = self._session.reset(
                    task_id=self._task_id,
                    difficulty=self._difficulty,
                    seed=self._seed,
                )
            except BaseException as err:
                # A ``with`` statement never calls __exit__ when __enter__
                # raises, so the already-entered session must be released
                # here or it would stay open.
                self._session.__exit__(type(err), err, err.__traceback__)
                raise
            return self

        def __exit__(self, exc_type, exc, tb) -> bool | None:
            """Delegate teardown (and exception suppression) to the underlying session."""
            return self._session.__exit__(exc_type, exc, tb)

        def reset(self, **kwargs):
            """Reset the environment; ``kwargs`` are forwarded unchanged."""
            return self._session.reset(**kwargs)

        def step(self, action, **kwargs):
            """Send ``action`` to the environment; ``kwargs`` are forwarded unchanged."""
            return self._session.step(action, **kwargs)

        def close(self) -> None:
            """Close the underlying session."""
            self._session.close()

    def websocket_session(self) -> SyncEnvClient[Action, Observation, State]:
        """
        Return a synchronous websocket session wrapper for this environment client.

        This is the easiest way to use the environment from regular Python code:

            with GithubissuetriageEnv(base_url="http://localhost:8000").websocket_session() as client:
                result = client.reset()
        """
        return self.sync()

    def session(
        self,
        *,
        task_id: str | None = None,
        difficulty: str | None = None,
        seed: int | None = None,
    ) -> "GithubissuetriageEnv.Session":
        """Build a ``Session`` that resets the environment with these options on entry.

        Args:
            task_id: Passed to ``reset`` when the session is entered.
            difficulty: Passed to ``reset`` when the session is entered.
            seed: Passed to ``reset`` when the session is entered.

        Returns:
            A ``Session`` context manager bound to this client.
        """
        return GithubissuetriageEnv.Session(
            self,
            task_id=task_id,
            difficulty=difficulty,
            seed=seed,
        )

    def _step_payload(self, action: Action | Dict) -> Dict:
        """
        Convert Action to JSON payload for step message.

        Args:
            action: Action instance (anything exposing ``model_dump``) or dict.
                A dict is shallow-copied, so the caller's mapping is not mutated.

        Returns:
            Dictionary representation suitable for JSON encoding

        Raises:
            TypeError: ``action`` is neither a dict nor has ``model_dump``.
            ValueError: The resulting payload has no ``"type"`` key.
        """
        if isinstance(action, dict):
            action_dict = action.copy()
        elif hasattr(action, "model_dump"):
            action_dict = action.model_dump(exclude_none=True, mode="json")
        else:
            raise TypeError(f"Unsupported action type: {type(action).__name__}")

        # Ensure 'type' field exists and is a string
        if "type" not in action_dict:
            raise ValueError(f"Action missing 'type' field: {action_dict}")

        # Convert ActionType enum to string if needed
        if hasattr(action_dict["type"], "value"):
            action_dict["type"] = action_dict["type"].value

        # Log the exact payload being sent
        import json

        print(f"[CLIENT_STEP_PAYLOAD] {json.dumps(action_dict)}", flush=True)
        return action_dict

    def _parse_result(self, payload: Dict) -> StepResult[Observation]:
        """
        Parse server response into StepResult[Observation].

        The observation is read from ``payload["observation"]`` (falling back
        to ``payload`` itself), unwrapping one extra ``"observation"`` level
        when the outer dict has no ``"episode_id"``. ``reward`` and ``done``
        are taken from the top level first, then from the observation; a dict
        reward is reduced to its ``"total"`` entry.

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with Observation

        Raises:
            Exception: Any parsing/validation error is logged and re-raised
                unchanged.
        """
        try:
            print(f"[CLIENT_PARSE_RESULT] Received payload keys: {list(payload.keys())}", flush=True)

            # Extract observation data - server sends nested 'observation' object
            obs_data = payload.get("observation", payload)
            if not isinstance(obs_data, dict):
                obs_data = {}

            # Some servers wrap step data as {"observation": {"observation": ..., "info": ...}}
            if (
                "observation" in obs_data
                and isinstance(obs_data.get("observation"), dict)
                and "episode_id" not in obs_data
            ):
                obs_data = obs_data["observation"]

            raw_reward = payload.get("reward", obs_data.get("reward"))
            if isinstance(raw_reward, dict):
                reward_value = raw_reward.get("total")
            else:
                reward_value = raw_reward

            # Resolved once so the observation and the StepResult always agree.
            done = payload.get("done", obs_data.get("done", False))

            print(f"[CLIENT_PARSE_RESULT] Observation keys: {list(obs_data.keys())}", flush=True)

            # Build observation with all required fields
            observation = Observation.model_validate(
                {
                    **obs_data,
                    "reward": reward_value,
                    "done": done,
                }
            )
            return StepResult(
                observation=observation,
                reward=reward_value,
                done=done,
            )
        except Exception as e:
            # Log the error but provide helpful debugging info. Every access
            # below is type-guarded: the diagnostics must never raise
            # themselves, or they would mask the original error.
            print(f"[DEBUG] _parse_result failed: {e}", flush=True)
            if isinstance(payload, dict):
                print(f"[DEBUG] Payload structure: {list(payload.keys())}", flush=True)
                if "observation" in payload:
                    nested = payload["observation"]
                    if isinstance(nested, dict):
                        print(f"[DEBUG] Observation keys: {list(nested.keys())}", flush=True)
                    else:
                        print(f"[DEBUG] Observation type: {type(nested).__name__}", flush=True)
            else:
                print(f"[DEBUG] Payload type: {type(payload).__name__}", flush=True)
            raise

    def _parse_state(self, payload: Dict) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id (``None`` when absent) and
            step_count (``0`` when absent)
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
def main() -> None:
    """Command-line entry point: connect, reset once, print the first observation.

    Parses ``--base-url``, ``--task-id``, ``--difficulty`` and ``--seed``,
    opens a session against the server, and writes the observation returned
    by the initial reset to stdout as indented JSON.
    """
    # Flags are declared as data and registered in order, so the generated
    # ``--help`` output lists them in this sequence.
    option_specs = (
        (
            "--base-url",
            {
                "default": "http://localhost:8000",
                "help": "Environment server URL (http:// or ws://).",
            },
        ),
        (
            "--task-id",
            {"default": None, "help": "Optional task or episode id."},
        ),
        (
            "--difficulty",
            {
                "default": None,
                "help": "Optional difficulty filter (easy, medium, hard).",
            },
        ),
        (
            "--seed",
            {"type": int, "default": None, "help": "Optional reset seed."},
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    env = GithubissuetriageEnv(base_url=opts.base_url)
    reset_options = {
        "task_id": opts.task_id,
        "difficulty": opts.difficulty,
        "seed": opts.seed,
    }
    with env.session(**reset_options) as active:
        first_observation = active.initial_result.observation
        print(first_observation.model_dump_json(indent=2))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()