Spaces:
Configuration error
Configuration error
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
# Project root: two levels up from this file (the parent of its directory).
ROOT = Path(__file__).resolve().parents[1]

# Make project-local packages (`models`, `server`) importable when this file
# is executed directly; the entry is added only once.
if not any(entry == str(ROOT) for entry in sys.path):
    sys.path.append(str(ROOT))
| from fastapi import FastAPI | |
| from fastapi import HTTPException | |
| from pydantic import BaseModel | |
| from models import ResetResult, StepResult | |
| from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment | |
# Single environment instance shared by all request handlers in this module
# (reset/step/state all call methods on it).
env = SoftmaxSurrogateEnvironment()

# ASGI application object served by uvicorn.
app = FastAPI(title="Autotune Benchmark OpenEnv Server")
class ResetRequest(BaseModel):
    # Request body for the reset handler. Both fields are optional and are
    # forwarded unchanged to ``env.reset(task=..., seed=...)``.
    task: Optional[str] = None  # task identifier passed as ``task=``
    seed: Optional[int] = None  # seed passed as ``seed=``; None is forwarded as-is
class StepRequest(BaseModel):
    # Request body for the step handler. The action is supplied either as an
    # integer ``config_id`` or as a float vector ``x``; the step handler
    # checks ``config_id`` first, so it wins when both are present.
    config_id: Optional[int] = None  # forwarded as {"config_id": ...}
    x: Optional[list[float]] = None  # forwarded as {"x": ...}
def health() -> Dict[str, str]:
    """Report that the server process is up.

    Returns:
        The fixed mapping ``{"ok": "true"}``; note the value is the string
        ``"true"``, not a boolean.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # copy; confirm it is registered on ``app`` or FastAPI will not serve it.
    return dict(ok="true")
def reset(payload: ResetRequest) -> Dict[str, Any]:
    """Reset the shared environment and return what ``env.reset`` returns.

    Args:
        payload: Request body; its ``task`` and ``seed`` are passed through
            to ``env.reset`` as keyword arguments.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # copy; confirm it is registered on ``app`` or FastAPI will not serve it.
    reset_kwargs = {"task": payload.task, "seed": payload.seed}
    return env.reset(**reset_kwargs)
def step(payload: StepRequest) -> Dict[str, Any]:
    """Apply one action to the shared environment.

    The action is either an integer ``config_id`` or a float vector ``x``.
    When both are supplied, ``config_id`` takes precedence and ``x`` is
    ignored.

    Args:
        payload: Request body carrying ``config_id`` and/or ``x``.

    Returns:
        Whatever ``env.step`` returns for the chosen action.

    Raises:
        HTTPException: 400 if neither ``config_id`` nor ``x`` is provided.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # copy; confirm it is registered on ``app`` or FastAPI will not serve it.
    if payload.config_id is not None:
        action: Dict[str, Any] = {"config_id": payload.config_id}
    elif payload.x is not None:
        action = {"x": payload.x}
    else:
        # Either field is a valid action, so the error must name both.
        raise HTTPException(status_code=400, detail="Missing config_id or x.")
    return env.step(action)
def state() -> Dict[str, Any]:
    """Return the shared environment's current state, as given by ``env.state()``."""
    # NOTE(review): no route decorator is visible on this handler in this
    # copy; confirm it is registered on ``app`` or FastAPI will not serve it.
    snapshot = env.state()
    return snapshot
if __name__ == "__main__":
    # Command-line entry point: parse host/port and serve ``app`` with uvicorn.
    parser = argparse.ArgumentParser(description="Run softmax surrogate environment server.")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    # Only the import is guarded: errors raised while the server runs
    # (e.g. the port is already in use) must surface as themselves rather
    # than being misreported as a missing dependency.
    try:
        import uvicorn
    except ImportError as err:  # pragma: no cover
        raise RuntimeError("uvicorn not available") from err

    # Serve the already-constructed ``app`` object instead of the "app:app"
    # import string. The string form imports a separate module named ``app``
    # (a second copy of this file, with a second environment, when the file
    # is app.py run as a script). It also only resolves if that module is on
    # sys.path. Passing the object is supported when reload is disabled.
    uvicorn.run(app, host=args.host, port=args.port, reload=False)