Anirudhreddy19 committed on
Commit
4022eb2
·
verified ·
1 Parent(s): 4f3dd97

Add root endpoint (/) returning environment metadata

Browse files
Files changed (1) hide show
  1. server.py +196 -171
server.py CHANGED
@@ -1,171 +1,196 @@
1
- """FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Any
6
-
7
- from fastapi import FastAPI, HTTPException
8
- from pydantic import BaseModel
9
-
10
- from environment import SchedulingOptEnv
11
- from graders.grader_classification import ConflictGrader
12
- from graders.grader_detection import FeasibilityGrader
13
- from graders.grader_fix import RepairGrader
14
- from models import Action, Observation
15
-
16
- app = FastAPI(
17
- title="Scheduling Optimisation Environment",
18
- description=(
19
- "OpenEnv-compatible environment for training AI agents on combinatorial "
20
- "scheduling optimisation problems."
21
- ),
22
- version="1.0.0",
23
- )
24
-
25
- # Single shared environment instance.
26
- env = SchedulingOptEnv()
27
-
28
-
29
- # ---------------------------------------------------------------------------
30
- # Request / response schemas
31
- # ---------------------------------------------------------------------------
32
-
33
-
34
- class ResetRequest(BaseModel):
35
- task_id: str = "feasibility_check"
36
-
37
-
38
- class StepResponse(BaseModel):
39
- observation: Observation
40
- reward: float
41
- done: bool
42
- info: dict[str, Any]
43
-
44
-
45
- class GradeRequest(BaseModel):
46
- action: Action
47
- ground_truth: dict[str, Any]
48
-
49
-
50
- class GradeResponse(BaseModel):
51
- score: float
52
-
53
-
54
- # ---------------------------------------------------------------------------
55
- # Endpoints
56
- # ---------------------------------------------------------------------------
57
-
58
-
59
- @app.get("/health")
60
- def health() -> dict[str, str]:
61
- """Health check for Hugging Face Spaces liveness probe."""
62
- return {"status": "ok"}
63
-
64
-
65
- @app.post("/reset", response_model=Observation)
66
- def reset(req: ResetRequest) -> Observation:
67
- """Reset the environment and start a new episode.
68
-
69
- Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}
70
- """
71
- valid_tasks = {"feasibility_check", "conflict_classification", "schedule_repair"}
72
- if req.task_id not in valid_tasks:
73
- raise HTTPException(
74
- status_code=400,
75
- detail=f"Invalid task_id. Choose from: {sorted(valid_tasks)}",
76
- )
77
- return env.reset(task_id=req.task_id)
78
-
79
-
80
- @app.post("/step", response_model=StepResponse)
81
- def step(action: Action) -> StepResponse:
82
- """Submit an action and advance the environment by one step.
83
-
84
- Body: {"response": "<answer>", "task_id": "<task_id>"}
85
- """
86
- obs, reward, done, info = env.step(action)
87
- return StepResponse(observation=obs, reward=reward, done=done, info=info)
88
-
89
-
90
- @app.get("/state")
91
- def state() -> dict[str, Any]:
92
- """Return the full current environment state."""
93
- return env.state()
94
-
95
-
96
- @app.get("/tasks")
97
- def tasks() -> list[dict[str, Any]]:
98
- """List available tasks with their action schemas."""
99
- return [
100
- {
101
- "task_id": "feasibility_check",
102
- "name": "Feasibility Check",
103
- "difficulty": "easy",
104
- "max_steps": 3,
105
- "action_schema": {
106
- "response": "feasible | infeasible",
107
- "task_id": "feasibility_check",
108
- },
109
- },
110
- {
111
- "task_id": "conflict_classification",
112
- "name": "Conflict Classification",
113
- "difficulty": "medium",
114
- "max_steps": 5,
115
- "action_schema": {
116
- "response": (
117
- "resource_overload | deadline_violation | precedence_violation | "
118
- "availability_conflict | capacity_exceeded"
119
- ),
120
- "task_id": "conflict_classification",
121
- },
122
- },
123
- {
124
- "task_id": "schedule_repair",
125
- "name": "Schedule Repair",
126
- "difficulty": "hard",
127
- "max_steps": 8,
128
- "action_schema": {
129
- "response": '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
130
- "task_id": "schedule_repair",
131
- },
132
- },
133
- ]
134
-
135
-
136
- @app.post("/grader", response_model=GradeResponse)
137
- def grader(req: GradeRequest) -> GradeResponse:
138
- """Directly invoke a grader with an action and ground truth.
139
-
140
- Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}
141
- """
142
- task_id = req.action.task_id
143
- grader_map = {
144
- "feasibility_check": FeasibilityGrader(),
145
- "conflict_classification": ConflictGrader(),
146
- "schedule_repair": RepairGrader(),
147
- }
148
- g = grader_map.get(task_id)
149
- if g is None:
150
- raise HTTPException(
151
- status_code=400, detail=f"No grader for task_id={task_id}"
152
- )
153
- score = g.grade(req.action, req.ground_truth)
154
- return GradeResponse(score=max(0.0, min(1.0, score)))
155
-
156
-
157
- @app.get("/baseline")
158
- def baseline() -> dict[str, Any]:
159
- """Trigger the baseline inference agent and return per-task scores.
160
-
161
- Falls back to mock oracle responses when OPENAI_API_KEY is not set,
162
- so this endpoint always returns a valid result.
163
- """
164
- try:
165
- from baseline import run_baseline
166
- return run_baseline()
167
- except Exception as exc:
168
- raise HTTPException(
169
- status_code=500,
170
- detail=f"Baseline run failed: {exc}",
171
- ) from exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+
10
+ from environment import SchedulingOptEnv
11
+ from graders.grader_classification import ConflictGrader
12
+ from graders.grader_detection import FeasibilityGrader
13
+ from graders.grader_fix import RepairGrader
14
+ from models import Action, Observation
15
+
16
+ app = FastAPI(
17
+ title="Scheduling Optimisation Environment",
18
+ description=(
19
+ "OpenEnv-compatible environment for training AI agents on combinatorial "
20
+ "scheduling optimisation problems."
21
+ ),
22
+ version="1.0.0",
23
+ )
24
+
25
+ # Single shared environment instance.
26
+ env = SchedulingOptEnv()
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Request / response schemas
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
+ class ResetRequest(BaseModel):
35
+ task_id: str = "feasibility_check"
36
+
37
+
38
+ class StepResponse(BaseModel):
39
+ observation: Observation
40
+ reward: float
41
+ done: bool
42
+ info: dict[str, Any]
43
+
44
+
45
+ class GradeRequest(BaseModel):
46
+ action: Action
47
+ ground_truth: dict[str, Any]
48
+
49
+
50
+ class GradeResponse(BaseModel):
51
+ score: float
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Endpoints
56
+ # ---------------------------------------------------------------------------
57
+
58
+
59
+ @app.get("/health")
60
+ def health() -> dict[str, str]:
61
+ """Health check for Hugging Face Spaces liveness probe."""
62
+ return {"status": "ok"}
63
+
64
+
65
+ @app.post("/reset", response_model=Observation)
66
+ def reset(req: ResetRequest) -> Observation:
67
+ """Reset the environment and start a new episode.
68
+
69
+ Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}
70
+ """
71
+ valid_tasks = {"feasibility_check", "conflict_classification", "schedule_repair"}
72
+ if req.task_id not in valid_tasks:
73
+ raise HTTPException(
74
+ status_code=400,
75
+ detail=f"Invalid task_id. Choose from: {sorted(valid_tasks)}",
76
+ )
77
+ return env.reset(task_id=req.task_id)
78
+
79
+
80
+ @app.post("/step", response_model=StepResponse)
81
+ def step(action: Action) -> StepResponse:
82
+ """Submit an action and advance the environment by one step.
83
+
84
+ Body: {"response": "<answer>", "task_id": "<task_id>"}
85
+ """
86
+ obs, reward, done, info = env.step(action)
87
+ return StepResponse(observation=obs, reward=reward, done=done, info=info)
88
+
89
+
90
+ @app.get("/state")
91
+ def state() -> dict[str, Any]:
92
+ """Return the full current environment state."""
93
+ return env.state()
94
+
95
+
96
+ @app.get("/tasks")
97
+ def tasks() -> list[dict[str, Any]]:
98
+ """List available tasks with their action schemas."""
99
+ return [
100
+ {
101
+ "task_id": "feasibility_check",
102
+ "name": "Feasibility Check",
103
+ "difficulty": "easy",
104
+ "max_steps": 3,
105
+ "action_schema": {
106
+ "response": "feasible | infeasible",
107
+ "task_id": "feasibility_check",
108
+ },
109
+ },
110
+ {
111
+ "task_id": "conflict_classification",
112
+ "name": "Conflict Classification",
113
+ "difficulty": "medium",
114
+ "max_steps": 5,
115
+ "action_schema": {
116
+ "response": (
117
+ "resource_overload | deadline_violation | precedence_violation | "
118
+ "availability_conflict | capacity_exceeded"
119
+ ),
120
+ "task_id": "conflict_classification",
121
+ },
122
+ },
123
+ {
124
+ "task_id": "schedule_repair",
125
+ "name": "Schedule Repair",
126
+ "difficulty": "hard",
127
+ "max_steps": 8,
128
+ "action_schema": {
129
+ "response": '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
130
+ "task_id": "schedule_repair",
131
+ },
132
+ },
133
+ ]
134
+
135
+
136
+ @app.post("/grader", response_model=GradeResponse)
137
+ def grader(req: GradeRequest) -> GradeResponse:
138
+ """Directly invoke a grader with an action and ground truth.
139
+
140
+ Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}
141
+ """
142
+ task_id = req.action.task_id
143
+ grader_map = {
144
+ "feasibility_check": FeasibilityGrader(),
145
+ "conflict_classification": ConflictGrader(),
146
+ "schedule_repair": RepairGrader(),
147
+ }
148
+ g = grader_map.get(task_id)
149
+ if g is None:
150
+ raise HTTPException(
151
+ status_code=400, detail=f"No grader for task_id={task_id}"
152
+ )
153
+ score = g.grade(req.action, req.ground_truth)
154
+ return GradeResponse(score=max(0.0, min(1.0, score)))
155
+
156
+
157
+ @app.get("/baseline")
158
+ def baseline() -> dict[str, Any]:
159
+ """Trigger the baseline inference agent and return per-task scores.
160
+
161
+ Falls back to mock oracle responses when OPENAI_API_KEY is not set,
162
+ so this endpoint always returns a valid result.
163
+ """
164
+ try:
165
+ from baseline import run_baseline
166
+ return run_baseline()
167
+ except Exception as exc:
168
+ raise HTTPException(
169
+ status_code=500,
170
+ detail=f"Baseline run failed: {exc}",
171
+ ) from exc
172
+
173
+
174
+
175
+ # ---
176
+ # Root endpoint
177
+ # ---
178
+
179
+
180
+ @app.get("/")
181
+ def root() -> dict[str, Any]:
182
+ """Root endpoint returning environment metadata."""
183
+ return {
184
+ "name": "Scheduling Optimisation Environment",
185
+ "version": "1.0.0",
186
+ "description": "OpenEnv-compatible environment for training AI agents on combinatorial scheduling optimisation problems.",
187
+ "endpoints": {
188
+ "health": "GET /health",
189
+ "reset": "POST /reset",
190
+ "step": "POST /step",
191
+ "state": "GET /state",
192
+ "tasks": "GET /tasks",
193
+ "grader": "POST /grader",
194
+ "baseline": "GET /baseline",
195
+ },
196
+ }