# Page-extraction residue preserved as comments (not part of the program):
#   Spaces: Running / Running
#   File size: 2,397 Bytes
#   Revision hashes as listed: 73b708a 248cbb9 4e608c3 73b708a 4e608c3 248cbb9 4e608c3 248cbb9 4e608c3 73b708a 4e608c3 248cbb9 4e608c3 73b708a 248cbb9 ffda6a0 4e608c3 554c891
#   Line-number gutter: 1-79
from fastapi import FastAPI, HTTPException, Body
from src.models import Action, TaskConfig, ResetRequest
from src.env import DesalEnv
from src.tasks import TASKS
import subprocess
from typing import Optional
# ASGI application object; the route decorators in this module register their handlers on it.
app = FastAPI(title="Advanced Municipal Desalination Plant Env")
# Single module-level environment instance read and mutated by every request handler.
env = DesalEnv()
@app.get("/")
def health_check():
    """Report that the service is up, along with a fixed list of feature names."""
    feature_names = ["weather", "salinity", "mechanics"]
    return {
        "status": "ok",
        "message": "Advanced DesalEnv is running",
        "features": feature_names,
    }
@app.post("/reset")
def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
    """Reset the shared environment to the selected task and return its first observation.

    The task can be named either by the ``task_id`` parameter or by ``req.task_id``.
    The value from ``req`` wins only when it differs from the default
    ``"easy_spring"``; otherwise the ``task_id`` parameter is used.

    Raises:
        HTTPException: 404 when the selected task id is not a key of ``TASKS``.
    """
    body_overrides = bool(req and req.task_id != "easy_spring")
    selected_id = req.task_id if body_overrides else task_id

    if selected_id in TASKS:
        first_observation = env.reset(TASKS[selected_id])
        return {"observation": first_observation.dict()}

    raise HTTPException(status_code=404, detail="Task not found")
@app.post("/step")
def step_env(action: Action):
    """Apply one action to the shared environment and return the step result as a dict.

    Args:
        action: The action to pass to ``env.step``.

    Returns:
        The value of ``result.dict()`` for the object returned by ``env.step``.

    Raises:
        HTTPException: 400 with the failure message when stepping or
            serialising the result raises any exception.
    """
    try:
        result = env.step(action)
        return result.dict()
    except Exception as exc:
        # Request boundary: every failure is reported to the client as a 400.
        # Chaining keeps the underlying traceback available in server logs.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/state")
def get_state():
    """Return the environment's current state as a dict under ``"observation"``.

    Raises:
        HTTPException: 400 when ``env.state`` is ``None``.
    """
    current_state = env.state
    if current_state is not None:
        return {"observation": current_state.dict()}
    raise HTTPException(status_code=400, detail="Environment not initialized")
@app.get("/tasks")
def list_tasks():
    """List the available task ids together with the schema of ``Action``."""
    task_ids = list(TASKS)
    action_schema = Action.schema()
    return {"tasks": task_ids, "action_schema": action_schema}
@app.get("/grader")
def grader():
    """Return the episode score, normalised and clamped to the range [0.001, 0.999].

    The raw score is ``(env.total_reward + max_steps * 1000.0) / (max_steps * 1500.0)``.
    The offset compensates for penalties so that typical returns fall inside
    the clamped range.

    Returns:
        ``{"score": s}`` with ``0.001 <= s <= 0.999``. The lower bound is
        returned when the environment has no state yet, when the raw score
        cannot be computed (for example a zero ``max_steps``), or when it is NaN.
    """
    import math  # function-scope import: math is not imported at module level

    min_score, max_score = 0.001, 0.999

    if env.state is None:
        return {"score": min_score}

    max_steps = env.config.max_steps
    baseline_offset = max_steps * 1000.0  # compensate for penalties
    scale_factor = max_steps * 1500.0

    try:
        raw_score = float(env.total_reward + baseline_offset) / scale_factor
    except Exception:
        # Best effort: grading must always yield a score. Catching Exception
        # rather than using a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        return {"score": min_score}

    # NaN must be handled explicitly: comparisons against NaN are always
    # False, so min()/max() would not clamp it to the lower bound.
    if math.isnan(raw_score):
        return {"score": min_score}

    return {"score": float(max(min_score, min(max_score, raw_score)))}
@app.post("/baseline")
def run_baseline():
    """Run ``src/baseline.py`` in a child process and report what it produced.

    Returns:
        A dict with the child's captured standard output under ``"output"``,
        plus its exit status under ``"returncode"`` and captured standard
        error under ``"stderr"`` so a failed run is distinguishable from a
        run that printed nothing.
    """
    import sys  # function-scope import: sys is not imported at module level

    # Use the interpreter running this server so the script sees the same
    # installed packages; fall back to PATH lookup if it is unavailable.
    interpreter = sys.executable or "python"
    completed = subprocess.run(
        [interpreter, "src/baseline.py"],
        capture_output=True,
        text=True,
    )
    return {
        "output": completed.stdout,
        "returncode": completed.returncode,
        "stderr": completed.stderr,
    }
def main():
    """Serve ``app`` with uvicorn on all interfaces, port 7860."""
    import uvicorn

    bind_host = "0.0.0.0"
    bind_port = 7860
    uvicorn.run(app, host=bind_host, port=bind_port)


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
# (stray "|" table delimiter from page extraction; not part of the program)