"""Client for the softmax surrogate environment.

The client either drives a ``SoftmaxSurrogateEnvironment`` in-process or
talks to a server over HTTP (``/reset``, ``/step``, ``/state``), depending
on whether a base URL is supplied. Running the module as a script performs
a single ``reset`` and prints the result.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any, Optional

import requests

# The directory one level above this file's directory is put on sys.path
# before the ``server`` import below, so that import also works when this
# file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from server.softmax_surrogate_environment import (  # noqa: E402
    DEFAULT_BUDGET,
    SoftmaxSurrogateEnvironment,
)


class SoftmaxSurrogateEnvClient:
    """Uniform ``reset`` / ``step`` / ``state`` facade over a local or remote env.

    When ``base_url`` is ``None`` a ``SoftmaxSurrogateEnvironment`` is created
    in-process and every call is delegated to it. Otherwise every call is an
    HTTP request against ``base_url`` and the decoded JSON body is returned.
    """

    # Per-endpoint HTTP timeouts, in seconds.
    _RESET_TIMEOUT = 60
    _STEP_TIMEOUT = 120
    _STATE_TIMEOUT = 60

    def __init__(
        self,
        base_url: Optional[str] = None,
        measurement_path: str = "data/autotune_measurements.csv",
        budget: int = DEFAULT_BUDGET,
        seed: int = 0,
    ) -> None:
        """Create the client.

        Args:
            base_url: Base URL of a remote server, or ``None`` to run the
                environment in-process. Trailing slashes are ignored.
            measurement_path: Forwarded to the local environment; unused in
                remote mode.
            budget: Forwarded to the local environment; unused in remote mode.
            seed: Forwarded to the local environment; unused in remote mode
                (pass ``seed`` to :meth:`reset` to seed a remote environment).
        """
        # Normalise once so "http://host/" and "http://host" build the same
        # endpoint URLs instead of producing "http://host//reset".
        self.base_url = base_url.rstrip("/") if base_url is not None else None
        self._local_env: Optional[SoftmaxSurrogateEnvironment] = None
        if base_url is None:
            self._local_env = SoftmaxSurrogateEnvironment(
                measurement_path=measurement_path,
                budget=budget,
                seed=seed,
            )

    def _request(
        self,
        method: str,
        endpoint: str,
        *,
        timeout: float,
        payload: Optional[dict] = None,
    ) -> dict:
        """Send one HTTP request to the remote server and decode its JSON body.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        resp = requests.request(
            method,
            f"{self.base_url}/{endpoint}",
            json=payload,
            timeout=timeout,
        )
        resp.raise_for_status()
        return resp.json()

    def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> dict:
        """Reset the environment and return whatever the backend returns.

        In remote mode only the arguments that are not ``None`` are included
        in the JSON payload posted to ``/reset``.
        """
        if self._local_env is not None:
            return self._local_env.reset(task=task, seed=seed)

        candidates = {"task": task, "seed": seed}
        payload = {key: value for key, value in candidates.items() if value is not None}
        return self._request(
            "POST", "reset", timeout=self._RESET_TIMEOUT, payload=payload
        )

    def step(self, action: Any) -> dict:
        """Apply ``action`` and return the backend's response.

        In remote mode a ``dict`` action is posted as-is; any other value is
        wrapped as ``{"x": action}`` before being posted to ``/step``.
        """
        if self._local_env is not None:
            return self._local_env.step(action)

        payload = action if isinstance(action, dict) else {"x": action}
        return self._request(
            "POST", "step", timeout=self._STEP_TIMEOUT, payload=payload
        )

    def state(self) -> dict:
        """Return the current environment state from the active backend."""
        if self._local_env is not None:
            return self._local_env.state()
        return self._request("GET", "state", timeout=self._STATE_TIMEOUT)


def parse_args() -> argparse.Namespace:
    """Parse the command-line options ``--remote``, ``--task`` and ``--seed``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--remote",
        default=None,
        help="Optional base URL (e.g. http://127.0.0.1:8000)",
    )
    parser.add_argument("--task", default=None)
    parser.add_argument("--seed", type=int, default=0)
    return parser.parse_args()


def main() -> None:
    """Build a client from the CLI options, reset once and print the result."""
    args = parse_args()
    client = SoftmaxSurrogateEnvClient(base_url=args.remote, seed=args.seed)
    # The constructor only applies the seed to an in-process environment, so
    # the seed is also sent with the reset; otherwise --seed would be silently
    # ignored whenever --remote is given.
    print(client.reset(task=args.task, seed=args.seed))


if __name__ == "__main__":
    main()