graphstrike / client.py
Pandago's picture
Upload folder using huggingface_hub
50f71a7 verified
"""Python client for the Fake Gang Detection OpenEnv environment."""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Dict, Optional
try:
import requests
except ImportError:
requests = None # type: ignore
from models import (
AccountProfile,
FakeGangAction,
FakeGangObservation,
FakeGangState,
ActionType,
)
# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------
@dataclass
class StepResult:
    """Result of a single ``reset`` or ``step`` round-trip, as built by the client."""

    # Observation rebuilt from the response's "observation" object.
    observation: FakeGangObservation
    # Top-level "done" flag of the response; the parser falls back to False.
    done: bool
    # Top-level "reward" of the response; None when the response has none.
    reward: Optional[float]
    # Top-level "message" of the response; the parser falls back to "".
    message: str
# ---------------------------------------------------------------------------
# Sync HTTP client
# ---------------------------------------------------------------------------
class FakeGangEnvClient:
    """Synchronous HTTP client for the Fake Gang Detection environment.

    Wraps a ``requests.Session`` and exposes the server's ``/reset``,
    ``/step``, ``/state`` and ``/health`` endpoints as methods.  The client
    can be used as a context manager; the underlying session is closed on
    exit (or explicitly via :meth:`close`).
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        *,
        timeout: Optional[float] = None,
    ) -> None:
        """Create a client bound to *base_url*.

        Args:
            base_url: Root URL of the environment server.  Trailing slashes
                are stripped so endpoint paths can be appended directly.
            timeout: Timeout in seconds handed to every ``requests`` call.
                ``None`` (the default) keeps the previous behaviour of
                waiting indefinitely for the server.

        Raises:
            ImportError: If the optional ``requests`` dependency is missing.
        """
        if requests is None:
            raise ImportError("Install 'requests' to use FakeGangEnvClient.")
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self._session = requests.Session()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def reset(
        self,
        task: str = "easy",
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
    ) -> StepResult:
        """POST to ``/reset`` and return the parsed result.

        ``seed`` and ``episode_id`` are only included in the request body
        when they are not ``None``.
        """
        # Explicit annotation: the values are not all strings (seed is int).
        payload: Dict[str, Any] = {"task": task}
        if seed is not None:
            payload["seed"] = seed
        if episode_id is not None:
            payload["episode_id"] = episode_id
        return self._parse_result(self._post("/reset", payload))

    def step(self, action: FakeGangAction) -> StepResult:
        """POST *action* to ``/step`` and return the parsed result."""
        # mode="json" asks the model for JSON-native values (e.g. enum members
        # become their values) so the body is always serialisable by requests.
        # NOTE(review): assumes FakeGangAction is a pydantic v2 model — confirm.
        body = action.model_dump(mode="json")
        return self._parse_result(self._post("/step", body))

    def state(self) -> FakeGangState:
        """GET ``/state`` and build a ``FakeGangState`` from the JSON body."""
        return FakeGangState(**self._get("/state"))

    def health(self) -> Dict[str, str]:
        """GET ``/health`` and return the decoded JSON body."""
        return self._get("/health")

    def close(self) -> None:
        """Close the underlying HTTP session."""
        self._session.close()

    # ------------------------------------------------------------------
    # Convenience shortcuts
    # ------------------------------------------------------------------
    def inspect(self, account_id: str) -> StepResult:
        """Send an ``INSPECT`` action for *account_id*."""
        return self.step(
            FakeGangAction(action_type=ActionType.INSPECT, account_id=account_id)
        )

    def investigate_network(self, account_id: str) -> StepResult:
        """Send an ``INVESTIGATE_NETWORK`` action for *account_id*."""
        return self.step(
            FakeGangAction(
                action_type=ActionType.INVESTIGATE_NETWORK, account_id=account_id
            )
        )

    def flag(self, account_id: str) -> StepResult:
        """Send a ``FLAG`` action for *account_id*."""
        return self.step(
            FakeGangAction(action_type=ActionType.FLAG, account_id=account_id)
        )

    def unflag(self, account_id: str) -> StepResult:
        """Send an ``UNFLAG`` action for *account_id*."""
        return self.step(
            FakeGangAction(action_type=ActionType.UNFLAG, account_id=account_id)
        )

    def submit(self) -> StepResult:
        """Send a ``SUBMIT`` action (no account id)."""
        return self.step(FakeGangAction(action_type=ActionType.SUBMIT))

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    def _get(self, path: str) -> Any:
        """GET ``base_url + path``; raise on HTTP errors; return decoded JSON."""
        resp = self._session.get(f"{self.base_url}{path}", timeout=self.timeout)
        resp.raise_for_status()
        return resp.json()

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST *payload* as JSON to ``base_url + path``.

        Raises whatever ``raise_for_status`` raises on a non-success status,
        otherwise returns the decoded JSON body.
        """
        resp = self._session.post(
            f"{self.base_url}{path}", json=payload, timeout=self.timeout
        )
        resp.raise_for_status()
        return resp.json()

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a ``/reset`` or ``/step`` response body into a ``StepResult``.

        Raises:
            KeyError: If *payload* has no ``"observation"`` entry.
        """
        obs_data = payload["observation"]
        profiles = [AccountProfile(**p) for p in obs_data.get("visible_accounts", [])]
        # Missing observation fields fall back to neutral defaults rather
        # than failing, so partial responses still parse.
        obs = FakeGangObservation(
            done=obs_data.get("done", False),
            reward=obs_data.get("reward"),
            visible_accounts=profiles,
            visible_account_ids=obs_data.get("visible_account_ids", []),
            flagged_ids=obs_data.get("flagged_ids", []),
            inspected_ids=obs_data.get("inspected_ids", []),
            graph_edges=obs_data.get("graph_edges", {}),
            steps_remaining=obs_data.get("steps_remaining", 0),
            evasion_triggered=obs_data.get("evasion_triggered", False),
            evasion_count=obs_data.get("evasion_count", 0),
            task=obs_data.get("task", "easy"),
            message=obs_data.get("message", ""),
            suspect_ids=obs_data.get("suspect_ids", []),
        )
        # done / reward / message on the StepResult come from the top level
        # of the response, independently of the copies inside "observation".
        return StepResult(
            observation=obs,
            done=payload.get("done", False),
            reward=payload.get("reward"),
            message=payload.get("message", ""),
        )

    # ------------------------------------------------------------------
    # Context manager support
    # ------------------------------------------------------------------
    def __enter__(self) -> "FakeGangEnvClient":
        """Return the client itself for use in a ``with`` statement."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Close the session when the ``with`` block ends."""
        self.close()