File size: 1,909 Bytes
5000a45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
from __future__ import annotations

import argparse
import sys
from pathlib import Path
from typing import Any, Dict, Optional

# Directory two levels up from this file (the parent of the directory that
# contains it); presumably the repository root — confirm against the layout.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT on the import path so the project-local `models` and
# `server.softmax_surrogate_environment` imports below can resolve, presumably
# when this file is launched directly as a script rather than as a package.
# NOTE(review): `append` (not `insert(0, ...)`) means a same-named module that
# is already importable from an earlier sys.path entry would win — confirm
# that is intended.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from fastapi import FastAPI
from fastapi import HTTPException
from pydantic import BaseModel

from models import ResetResult, StepResult
from server.softmax_surrogate_environment import SoftmaxSurrogateEnvironment


# ASGI application; every route handler in this module is registered on it.
app = FastAPI(title="Autotune Benchmark OpenEnv Server")
# Single module-level environment instance shared by all route handlers
# (/reset, /step, /state) in this module.
# NOTE(review): the handlers are plain `def` functions, which FastAPI runs in a
# threadpool — confirm the environment tolerates concurrent requests.
env = SoftmaxSurrogateEnvironment()


class ResetRequest(BaseModel):
    """JSON request body for ``POST /reset``.

    Both fields are optional; the ``/reset`` handler forwards them unchanged
    as keyword arguments to ``env.reset``.
    """

    # Task to reset into; None is passed through to env.reset as-is.
    task: Optional[str] = None
    # Seed for the reset; None is passed through to env.reset as-is.
    seed: Optional[int] = None


class StepRequest(BaseModel):
    """JSON request body for ``POST /step``.

    A step is described by either a ``config_id`` or a vector ``x``. Neither
    field is required at the model level, so a body with neither validates
    here; the ``/step`` handler is responsible for rejecting it.
    """

    # Identifier of the configuration to evaluate; checked first by /step.
    config_id: Optional[int] = None
    # Alternative action given as a list of floats; used when config_id is None.
    x: Optional[list[float]] = None


@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: unconditionally answers with ``{"ok": "true"}``."""
    return dict(ok="true")


@app.post("/reset")
def reset(payload: ResetRequest) -> Dict[str, Any]:
    """Reset the shared environment.

    The optional ``task`` and ``seed`` from the request body are handed to
    ``env.reset`` as keyword arguments, and its result is returned verbatim.
    """
    return env.reset(seed=payload.seed, task=payload.task)


@app.post("/step")
def step(payload: StepRequest) -> Dict[str, Any]:
    """Advance the shared environment by one step.

    The action is taken from the request body: ``{"config_id": ...}`` when
    ``config_id`` is set, otherwise ``{"x": ...}``. If both are set,
    ``config_id`` wins and ``x`` is ignored. The result of ``env.step`` is
    returned verbatim.

    Raises:
        HTTPException: status 400 when neither ``config_id`` nor ``x`` is given.
    """
    if payload.config_id is not None:
        return env.step({"config_id": payload.config_id})
    if payload.x is not None:
        return env.step({"x": payload.x})
    # Neither action form was supplied; name both accepted fields so the client
    # is not misled into thinking config_id is the only option.
    raise HTTPException(status_code=400, detail="Missing config_id or x.")


@app.get("/state")
def state() -> Dict[str, Any]:
    """Report the shared environment's current state exactly as ``env.state()`` returns it."""
    snapshot = env.state()
    return snapshot


if __name__ == "__main__":
    # Command-line entry point: serve `app` with uvicorn on the given host/port.
    parser = argparse.ArgumentParser(description="Run softmax surrogate environment server.")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    # Guard only the import: errors raised while the server runs (e.g. the port
    # is already in use) must surface as-is, not as "uvicorn not available".
    try:
        import uvicorn
    except ImportError as err:  # pragma: no cover
        raise RuntimeError("uvicorn not available") from err

    # Pass the app object rather than the "app:app" import string. With
    # reload=False this is equivalent, but it avoids re-importing this file as a
    # second module (which would build a second `env`). It also keeps working
    # when the file is launched as a package module (`python -m ...`).
    uvicorn.run(app, host=args.host, port=args.port, reload=False)