# RL_Surrogate_ENV / client.py
# wlan0's picture
# Surrogate Discovery vs. Pytorch.compile vs. Triton.autotune
# 5000a45 unverified
from __future__ import annotations
import argparse
import sys
from pathlib import Path
from typing import Any, Optional
import requests
# ROOT is the directory one level above the folder that contains this file.
# It is appended to ``sys.path`` (only if not already present) before the
# ``server`` import below -- presumably so that import resolves when this file
# is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
from server.softmax_surrogate_environment import DEFAULT_BUDGET, SoftmaxSurrogateEnvironment
class SoftmaxSurrogateEnvClient:
    """Client for the softmax surrogate environment, in-process or over HTTP.

    When ``base_url`` is ``None`` the client builds a
    ``SoftmaxSurrogateEnvironment`` in this process and forwards ``reset``,
    ``step`` and ``state`` to it. Otherwise every call is sent to
    ``<base_url>/reset``, ``<base_url>/step`` or ``<base_url>/state`` and the
    decoded JSON response body is returned.
    """

    # Per-request timeouts, in seconds, used by the HTTP backend.
    _RESET_TIMEOUT = 60
    _STEP_TIMEOUT = 120
    _STATE_TIMEOUT = 60

    def __init__(
        self,
        base_url: Optional[str] = None,
        measurement_path: str = "data/autotune_measurements.csv",
        budget: int = DEFAULT_BUDGET,
        seed: int = 0,
    ) -> None:
        """Select the backend.

        Args:
            base_url: Base URL of a running environment server. ``None``
                selects the in-process environment.
            measurement_path: Forwarded to the in-process environment;
                unused in remote mode.
            budget: Forwarded to the in-process environment; unused in
                remote mode.
            seed: Forwarded to the in-process environment; unused in remote
                mode (pass ``seed`` to ``reset`` to seed a remote server).
        """
        self.base_url = base_url
        self._local_env = None
        if base_url is None:
            self._local_env = SoftmaxSurrogateEnvironment(
                measurement_path=measurement_path,
                budget=budget,
                seed=seed,
            )

    def _url(self, endpoint: str) -> str:
        """Return the absolute URL of ``endpoint`` on the remote server."""
        # Trailing slashes are stripped so that a base URL such as
        # "http://127.0.0.1:8000/" does not produce ".../8000//reset".
        return f"{self.base_url.rstrip('/')}/{endpoint}"

    def _post(self, endpoint: str, payload: Any, timeout: int) -> dict:
        """POST ``payload`` as JSON to ``endpoint`` and return the JSON reply.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        resp = requests.post(self._url(endpoint), json=payload, timeout=timeout)
        resp.raise_for_status()
        return resp.json()

    def reset(self, task: Optional[str] = None, seed: Optional[int] = None) -> dict:
        """Reset the environment and return the backend's reply.

        Args:
            task: Optional task identifier. In remote mode it is omitted
                from the request body when ``None``.
            seed: Optional seed. In remote mode it is omitted from the
                request body when ``None``.

        Raises:
            requests.HTTPError: In remote mode, on an error status.
        """
        if self._local_env is not None:
            return self._local_env.reset(task=task, seed=seed)

        # Only explicitly provided fields are sent to the server.
        payload = {}
        if task is not None:
            payload["task"] = task
        if seed is not None:
            payload["seed"] = seed
        return self._post("reset", payload, self._RESET_TIMEOUT)

    def step(self, action: Any) -> dict:
        """Apply ``action`` and return the backend's reply.

        In remote mode a ``dict`` action is sent as the JSON body unchanged;
        any other value is wrapped as ``{"x": action}``.

        Raises:
            requests.HTTPError: In remote mode, on an error status.
        """
        if self._local_env is not None:
            return self._local_env.step(action)

        payload = action if isinstance(action, dict) else {"x": action}
        return self._post("step", payload, self._STEP_TIMEOUT)

    def state(self) -> dict:
        """Return the backend's current state.

        Raises:
            requests.HTTPError: In remote mode, on an error status.
        """
        if self._local_env is not None:
            return self._local_env.state()

        resp = requests.get(self._url("state"), timeout=self._STATE_TIMEOUT)
        resp.raise_for_status()
        return resp.json()
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script.

    Recognised flags:
        --remote  base URL of an environment server (default ``None``)
        --task    task name (default ``None``)
        --seed    integer seed (default ``0``)

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--remote",
        default=None,
        help="Optional base URL (e.g. http://127.0.0.1:8000)",
    )
    cli.add_argument("--task", default=None)
    cli.add_argument("--seed", type=int, default=0)
    namespace = cli.parse_args()
    return namespace
def main() -> None:
    """Reset the environment once from the command line and print the reply.

    Builds a client from the parsed options (in-process unless ``--remote``
    is given), performs a single ``reset`` for ``--task`` and prints whatever
    the client returns.
    """
    args = parse_args()
    client = SoftmaxSurrogateEnvClient(base_url=args.remote, seed=args.seed)

    # The constructor only uses ``seed`` when it builds the in-process
    # environment, so in remote mode ``--seed`` has to travel with the reset
    # request or it is silently dropped. The local call is left unchanged.
    reset_seed = args.seed if args.remote is not None else None
    print(client.reset(task=args.task, seed=reset_seed))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()