Vittal-M commited on
Commit
cbe9d96
·
verified ·
1 Parent(s): 5e5933e

Upload server.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server.py +171 -0
server.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+
10
+ from environment import SchedulingOptEnv
11
+ from graders.grader_classification import ConflictGrader
12
+ from graders.grader_detection import FeasibilityGrader
13
+ from graders.grader_fix import RepairGrader
14
+ from models import Action, Observation
15
+
16
+ app = FastAPI(
17
+ title="Scheduling Optimisation Environment",
18
+ description=(
19
+ "OpenEnv-compatible environment for training AI agents on combinatorial "
20
+ "scheduling optimisation problems."
21
+ ),
22
+ version="1.0.0",
23
+ )
24
+
25
+ # Single shared environment instance.
26
+ env = SchedulingOptEnv()
27
+
28
+
29
+ # ---------------------------------------------------------------------------
30
+ # Request / response schemas
31
+ # ---------------------------------------------------------------------------
32
+
33
+
34
+ class ResetRequest(BaseModel):
35
+ task_id: str = "feasibility_check"
36
+
37
+
38
+ class StepResponse(BaseModel):
39
+ observation: Observation
40
+ reward: float
41
+ done: bool
42
+ info: dict[str, Any]
43
+
44
+
45
+ class GradeRequest(BaseModel):
46
+ action: Action
47
+ ground_truth: dict[str, Any]
48
+
49
+
50
+ class GradeResponse(BaseModel):
51
+ score: float
52
+
53
+
54
+ # ---------------------------------------------------------------------------
55
+ # Endpoints
56
+ # ---------------------------------------------------------------------------
57
+
58
+
59
+ @app.get("/health")
60
+ def health() -> dict[str, str]:
61
+ """Health check for Hugging Face Spaces liveness probe."""
62
+ return {"status": "ok"}
63
+
64
+
65
+ @app.post("/reset", response_model=Observation)
66
+ def reset(req: ResetRequest) -> Observation:
67
+ """Reset the environment and start a new episode.
68
+
69
+ Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}
70
+ """
71
+ valid_tasks = {"feasibility_check", "conflict_classification", "schedule_repair"}
72
+ if req.task_id not in valid_tasks:
73
+ raise HTTPException(
74
+ status_code=400,
75
+ detail=f"Invalid task_id. Choose from: {sorted(valid_tasks)}",
76
+ )
77
+ return env.reset(task_id=req.task_id)
78
+
79
+
80
+ @app.post("/step", response_model=StepResponse)
81
+ def step(action: Action) -> StepResponse:
82
+ """Submit an action and advance the environment by one step.
83
+
84
+ Body: {"response": "<answer>", "task_id": "<task_id>"}
85
+ """
86
+ obs, reward, done, info = env.step(action)
87
+ return StepResponse(observation=obs, reward=reward, done=done, info=info)
88
+
89
+
90
+ @app.get("/state")
91
+ def state() -> dict[str, Any]:
92
+ """Return the full current environment state."""
93
+ return env.state()
94
+
95
+
96
+ @app.get("/tasks")
97
+ def tasks() -> list[dict[str, Any]]:
98
+ """List available tasks with their action schemas."""
99
+ return [
100
+ {
101
+ "task_id": "feasibility_check",
102
+ "name": "Feasibility Check",
103
+ "difficulty": "easy",
104
+ "max_steps": 3,
105
+ "action_schema": {
106
+ "response": "feasible | infeasible",
107
+ "task_id": "feasibility_check",
108
+ },
109
+ },
110
+ {
111
+ "task_id": "conflict_classification",
112
+ "name": "Conflict Classification",
113
+ "difficulty": "medium",
114
+ "max_steps": 5,
115
+ "action_schema": {
116
+ "response": (
117
+ "resource_overload | deadline_violation | precedence_violation | "
118
+ "availability_conflict | capacity_exceeded"
119
+ ),
120
+ "task_id": "conflict_classification",
121
+ },
122
+ },
123
+ {
124
+ "task_id": "schedule_repair",
125
+ "name": "Schedule Repair",
126
+ "difficulty": "hard",
127
+ "max_steps": 8,
128
+ "action_schema": {
129
+ "response": '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
130
+ "task_id": "schedule_repair",
131
+ },
132
+ },
133
+ ]
134
+
135
+
136
+ @app.post("/grader", response_model=GradeResponse)
137
+ def grader(req: GradeRequest) -> GradeResponse:
138
+ """Directly invoke a grader with an action and ground truth.
139
+
140
+ Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}
141
+ """
142
+ task_id = req.action.task_id
143
+ grader_map = {
144
+ "feasibility_check": FeasibilityGrader(),
145
+ "conflict_classification": ConflictGrader(),
146
+ "schedule_repair": RepairGrader(),
147
+ }
148
+ g = grader_map.get(task_id)
149
+ if g is None:
150
+ raise HTTPException(
151
+ status_code=400, detail=f"No grader for task_id={task_id}"
152
+ )
153
+ score = g.grade(req.action, req.ground_truth)
154
+ return GradeResponse(score=max(0.0, min(1.0, score)))
155
+
156
+
157
+ @app.get("/baseline")
158
+ def baseline() -> dict[str, Any]:
159
+ """Trigger the baseline inference agent and return per-task scores.
160
+
161
+ Falls back to mock oracle responses when OPENAI_API_KEY is not set,
162
+ so this endpoint always returns a valid result.
163
+ """
164
+ try:
165
+ from baseline import run_baseline
166
+ return run_baseline()
167
+ except Exception as exc:
168
+ raise HTTPException(
169
+ status_code=500,
170
+ detail=f"Baseline run failed: {exc}",
171
+ ) from exc