File size: 2,397 Bytes
73b708a
 
248cbb9
4e608c3
 
73b708a
4e608c3
248cbb9
 
4e608c3
 
 
248cbb9
4e608c3
 
73b708a
 
 
 
 
4e608c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248cbb9
4e608c3
 
 
 
73b708a
248cbb9
 
 
ffda6a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e608c3
 
 
 
 
 
554c891
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from fastapi import FastAPI, HTTPException, Body
from src.models import Action, TaskConfig, ResetRequest
from src.env import DesalEnv
from src.tasks import TASKS
import subprocess
from typing import Optional

# FastAPI application object; the route decorators in this module register handlers on it.
app = FastAPI(title="Advanced Municipal Desalination Plant Env")
# Single module-level environment instance shared by every request handler.
env = DesalEnv()

@app.get("/")
def health_check():
    """Report that the service is up, along with a static list of feature names."""
    feature_names = ["weather", "salinity", "mechanics"]
    return {
        "status": "ok",
        "message": "Advanced DesalEnv is running",
        "features": feature_names,
    }

@app.post("/reset")
def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
    """Reset the shared environment to the task identified by ``task_id``.

    The task id may arrive either as the ``task_id`` parameter or inside the
    optional ``req`` body. The body value takes precedence only when it is
    something other than ``"easy_spring"``; otherwise ``task_id`` is used.

    Returns:
        dict: ``{"observation": ...}`` holding ``.dict()`` of the value
        returned by ``env.reset``.

    Raises:
        HTTPException: 404 when the resolved task id is not a key of ``TASKS``.
    """
    # A body carrying the default id is treated as "not specified".
    selected = req.task_id if req and req.task_id != "easy_spring" else task_id

    if selected not in TASKS:
        raise HTTPException(status_code=404, detail="Task not found")

    initial_obs = env.reset(TASKS[selected])
    return {"observation": initial_obs.dict()}

@app.post("/step")
def step_env(action: Action):
    """Apply ``action`` to the shared environment and return the step result.

    Returns:
        dict: ``.dict()`` of whatever ``env.step`` returns.

    Raises:
        HTTPException: 400 carrying ``str(exc)`` for any exception raised
        while stepping or serialising the result.
    """
    try:
        # Serialisation stays inside the try so its failures also map to 400.
        return env.step(action).dict()
    except Exception as exc:
        raise HTTPException(status_code=400, detail=str(exc))

@app.get("/state")
def get_state():
    """Return the environment's current state as an observation payload.

    Returns:
        dict: ``{"observation": env.state.dict()}``.

    Raises:
        HTTPException: 400 when ``env.state`` is ``None``.
    """
    if env.state is not None:
        return {"observation": env.state.dict()}
    raise HTTPException(status_code=400, detail="Environment not initialized")

@app.get("/tasks")
def list_tasks():
    """List the keys of ``TASKS`` together with the schema of ``Action``."""
    task_ids = list(TASKS.keys())
    action_schema = Action.schema()
    return {"tasks": task_ids, "action_schema": action_schema}

@app.get("/grader")
def grader():
    """Return the episode score normalised and clamped to ``[0.001, 0.999]``.

    The raw score is
    ``(env.total_reward + max_steps * 1000.0) / (max_steps * 1500.0)``,
    where ``max_steps`` is ``env.config.max_steps``. The additive offset
    compensates for penalties so that negative returns can still map above
    the floor.

    ``0.001`` is returned when ``env.state`` is ``None``, when the raw score
    is NaN, or when it cannot be computed (e.g. ``max_steps`` is 0 or
    ``total_reward`` is not numeric).

    Returns:
        dict: ``{"score": float}``.
    """
    import math

    floor, ceiling = 0.001, 0.999

    if env.state is None:
        return {"score": floor}

    max_steps = env.config.max_steps
    baseline_offset = max_steps * 1000.0  # Compensate for penalties
    scale_factor = max_steps * 1500.0

    try:
        raw_score = float(env.total_reward + baseline_offset) / scale_factor
    except (ArithmeticError, TypeError, ValueError):
        # Only arithmetic/conversion failures fall back to the floor; a bare
        # ``except`` would also swallow KeyboardInterrupt and SystemExit.
        return {"score": floor}

    if math.isnan(raw_score):
        return {"score": floor}

    # The clamp already guarantees floor <= score <= ceiling, so no further
    # range checks are needed afterwards.
    return {"score": max(floor, min(ceiling, raw_score))}

@app.post("/baseline")
def run_baseline():
    """Run ``src/baseline.py`` in a child process and return its standard output.

    The script is launched with the interpreter running this server
    (``sys.executable``) instead of whatever ``python`` resolves to on PATH,
    so it uses the same installation and packages as the server. If
    ``sys.executable`` is unavailable, the literal ``"python"`` is used.

    Returns:
        dict: ``{"output": str}`` with the captured stdout of the script.
    """
    import sys

    # sys.executable may be an empty string or None when the path is unknown.
    interpreter = sys.executable or "python"
    # NOTE: stderr and the exit code are captured but not reported to the caller.
    result = subprocess.run(
        [interpreter, "src/baseline.py"], capture_output=True, text=True
    )
    return {"output": result.stdout}

def main():
    """Serve ``app`` with uvicorn on host ``0.0.0.0``, port ``7860``."""
    from uvicorn import run as serve

    serve(app, host="0.0.0.0", port=7860)

# Start the server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()