Spaces:
Sleeping
Sleeping
| """ | |
| client.py — Python client for the Traffic Flow environment. | |
| This file lets anyone connect to a running Traffic Flow server | |
| (local, Docker, or HF Space) using a clean Python API — no raw | |
| HTTP calls needed. | |
| The OpenEnv spec requires this file to exist in the environment | |
| package so training frameworks (TRL, SkyRL, etc.) can import it. | |
Usage:
    # Connect to a local server
    from traffic_env.client import TrafficEnvClient

    client = TrafficEnvClient("http://localhost:7860")
    obs = client.reset(seed=42)
    while not obs["done"]:
        action = {"action_type": "extend_green", "intersection_id": 0}
        obs = client.step(action)
    print("Episode reward:", client.state()["cumulative_reward"])

    # Connect to a HF Space
    client = TrafficEnvClient("https://YOUR-USERNAME-traffic-env.hf.space")
| """ | |
| import requests | |
| from typing import Optional, Dict, Any | |
class TrafficEnvClient:
    """
    Thin HTTP client for the Traffic Flow OpenEnv server.

    Wraps the /reset, /step, /state, and /health endpoints so calling code
    never needs to deal with raw HTTP. Every request goes through a single
    ``requests.Session``, so the connection to the server is reused across
    calls. The module-level ``requests.get``/``requests.post`` helpers, by
    contrast, build and discard a new session, and hence a new TCP/TLS
    connection, on every reset/step.

    Use it as a context manager to release the connection pool when done::

        with TrafficEnvClient("http://localhost:7860") as client:
            obs = client.reset(seed=42)

    NOTE: requests does not guarantee that a Session is thread-safe; give
    each thread its own client.
    """

    # Default per-request timeouts in seconds; overridable via ``timeout``.
    _QUERY_TIMEOUT = 10  # read-only endpoints: /health, /state
    _ACTION_TIMEOUT = 30  # episode-mutating endpoints: /reset, /step

    def __init__(
        self,
        base_url: str = "http://localhost:7860",
        timeout: Optional[float] = None,
        session: Optional["requests.Session"] = None,
    ):
        """
        Args:
            base_url: URL of the running server.
                Local:  "http://localhost:7860"
                Docker: "http://localhost:7860"
                HF:     "https://YOUR-USERNAME-traffic-env.hf.space"
            timeout: Seconds to wait on every request. ``None`` keeps the
                per-endpoint defaults (10 s for /health and /state, 30 s for
                /reset and /step).
            session: Optional pre-configured ``requests.Session`` (or any
                object with a compatible ``request(method, url, **kwargs)``
                method). When omitted, the client creates and owns its own
                session. A caller-supplied session is never closed by
                this client.
        """
        # Strip trailing slash so URLs always look clean
        self.base_url = base_url.rstrip("/")
        self._timeout = timeout
        self._owns_session = session is None
        self._session = requests.Session() if session is None else session

    def _request(
        self, method: str, path: str, default_timeout: float, **kwargs: Any
    ) -> Dict[str, Any]:
        """Send one request to ``base_url + path`` and return the decoded JSON body.

        Args:
            method: HTTP method, e.g. "GET" or "POST".
            path: Endpoint path starting with "/".
            default_timeout: Timeout used when no client-wide override is set.
            **kwargs: Forwarded to ``Session.request`` (e.g. ``params``, ``json``).

        Raises:
            requests.HTTPError: if the server answers with a 4xx/5xx status.
        """
        timeout = default_timeout if self._timeout is None else self._timeout
        resp = self._session.request(
            method, f"{self.base_url}{path}", timeout=timeout, **kwargs
        )
        resp.raise_for_status()
        return resp.json()

    def health(self) -> Dict[str, Any]:
        """Check if the server is running. Returns {"status": "healthy"}."""
        return self._request("GET", "/health", self._QUERY_TIMEOUT)

    def reset(self, seed: Optional[int] = None) -> Dict[str, Any]:
        """
        Start a new episode.

        Args:
            seed: Integer seed for reproducibility.
                  Same seed -> same episode every time.
                  Sent as a query parameter; omitted entirely when None.

        Returns:
            Initial observation as a dict.
        """
        params = {"seed": seed} if seed is not None else {}
        return self._request("POST", "/reset", self._ACTION_TIMEOUT, params=params)

    def step(self, action: Dict[str, Any]) -> Dict[str, Any]:
        """
        Apply one action to the environment.

        Args:
            action: Dict with keys (sent as the JSON request body):
                - action_type: "extend_green" or "next_phase"
                - intersection_id: int (0-indexed)

        Returns:
            New observation as a dict (includes reward and done flag).

        Example:
            obs = client.step({"action_type": "extend_green", "intersection_id": 0})
            print(obs["reward"])   # e.g. +0.23
            print(obs["done"])     # False until episode ends
        """
        return self._request("POST", "/step", self._ACTION_TIMEOUT, json=action)

    def state(self) -> Dict[str, Any]:
        """
        Get episode-level metadata.

        Returns:
            Dict with episode_id, step_count, cumulative_reward, etc.
        """
        return self._request("GET", "/state", self._QUERY_TIMEOUT)

    def run_episode(
        self,
        policy_fn,
        seed: Optional[int] = None,
        max_steps: Optional[int] = None,
    ) -> Dict[str, Any]:
        """
        Convenience method: run a full episode with a given policy function.

        Args:
            policy_fn: Callable that takes an observation dict and returns
                an action dict. Example:
                    def my_policy(obs):
                        return {"action_type": "next_phase",
                                "intersection_id": 0}
            seed: Optional seed for reproducibility.
            max_steps: Optional cap on the number of ``step`` calls. This guards
                against an endless loop if the server never reports done; an
                observation without a "done" key counts as not done. ``None``
                (default) keeps stepping until the server reports done.

        Returns:
            Final episode state (cumulative_reward, throughput, etc.)
        """
        obs = self.reset(seed=seed)
        steps_taken = 0
        while not obs.get("done", False):
            if max_steps is not None and steps_taken >= max_steps:
                break
            obs = self.step(policy_fn(obs))
            steps_taken += 1
        return self.state()

    def close(self) -> None:
        """Close the underlying HTTP session if this client created it."""
        if self._owns_session:
            self._session.close()

    def __enter__(self) -> "TrafficEnvClient":
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        self.close()
# Quick smoke test ----------------------------
if __name__ == "__main__":
    import sys

    # An optional first CLI argument overrides the default local server URL.
    target = "http://localhost:7860"
    if len(sys.argv) > 1:
        target = sys.argv[1]

    smoke_client = TrafficEnvClient(target)
    print(f"Connecting to {target}...")
    health_info = smoke_client.health()
    print(f"Health: {health_info}")

    print("Resetting episode (seed=42)...")
    first_obs = smoke_client.reset(seed=42)
    waiting = first_obs['total_waiting_vehicles']
    print(f"Initial obs: {waiting} waiting, done={first_obs['done']}")

    print("Taking one action (extend_green)...")
    next_obs = smoke_client.step({"action_type": "extend_green", "intersection_id": 0})
    print(f"After step: reward={next_obs['reward']:.4f}, throughput={next_obs['throughput_last_step']}")

    snapshot = smoke_client.state()
    print(f"State: step={snapshot['step_count']}, cumulative_reward={snapshot['cumulative_reward']}")
    print("Client smoke test passed.")