"""FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API."""
from __future__ import annotations
from typing import Any
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from environment import SchedulingOptEnv
from graders.grader_classification import ConflictGrader
from graders.grader_detection import FeasibilityGrader
from graders.grader_fix import RepairGrader
from models import Action, Observation
# ASGI application object on which every route in this module is registered.
# The title, description and version are the metadata handed to FastAPI,
# which surfaces them in its generated API documentation.
app = FastAPI(
    title="Scheduling Optimisation Environment",
    description=(
        "OpenEnv-compatible environment for training AI agents on combinatorial "
        "scheduling optimisation problems."
    ),
    version="1.0.0",
)
# Single shared environment instance.
# Created once at import time and used by the /reset, /step and /state
# handlers in this module, so all requests operate on the same object.
# NOTE(review): no locking is visible in this file — confirm that concurrent
# requests are either not expected or are safe inside SchedulingOptEnv.
env = SchedulingOptEnv()
# ---------------------------------------------------------------------------
# Request / response schemas
# ---------------------------------------------------------------------------
# Request body accepted by POST /reset.
class ResetRequest(BaseModel):
    # Identifier of the task to start. Defaults to "feasibility_check" when the
    # client omits the field; the /reset handler rejects ids it does not know.
    task_id: str = "feasibility_check"
# Response body returned by POST /step; it carries the four values that
# env.step() returns, in the same order.
class StepResponse(BaseModel):
    # Observation produced by the environment after the action was applied.
    observation: Observation
    # Reward reported by the environment for this step.
    reward: float
    # Flag reported by the environment as the third element of its step result
    # (presumably "episode finished" — verify against SchedulingOptEnv).
    done: bool
    # Free-form extra data passed through from the environment unchanged.
    info: dict[str, Any]
# Request body accepted by POST /grader.
class GradeRequest(BaseModel):
    # Action to score; its task_id selects which grader is used.
    action: Action
    # Reference data handed to the grader as-is. Its expected keys depend on
    # the grader implementation and are not defined in this module.
    ground_truth: dict[str, Any]
# Response body returned by POST /grader.
class GradeResponse(BaseModel):
    # Grader score; the /grader handler clamps it to the range [0.0, 1.0]
    # before building this model.
    score: float
# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe endpoint (used by Hugging Face Spaces).

    Always answers with a constant ``{"status": "ok"}`` payload; it does not
    touch the environment.
    """
    payload = {"status": "ok"}
    return payload
@app.post("/reset", response_model=Observation)
def reset(req: ResetRequest) -> Observation:
    """Begin a new episode for the requested task and return its first observation.

    Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}

    Responds with HTTP 400 when ``task_id`` is not one of the ids listed above.
    """
    allowed_ids = {"schedule_repair", "conflict_classification", "feasibility_check"}

    # Happy path first: a known task id goes straight to the shared environment.
    if req.task_id in allowed_ids:
        return env.reset(task_id=req.task_id)

    # Unknown id: tell the client which ids are accepted (sorted for a stable message).
    raise HTTPException(
        status_code=400,
        detail=f"Invalid task_id. Choose from: {sorted(allowed_ids)}",
    )
@app.post("/step", response_model=StepResponse)
def step(action: Action) -> StepResponse:
    """Apply one action to the shared environment and report the outcome.

    Body: {"response": "<answer>", "task_id": "<task_id>"}

    The response bundles the observation, reward, done flag and info dict
    returned by the environment for this step.
    """
    # env.step() yields exactly four values; unpacking keeps that contract explicit.
    next_observation, step_reward, is_done, extra = env.step(action)
    return StepResponse(
        observation=next_observation,
        reward=step_reward,
        done=is_done,
        info=extra,
    )
@app.get("/state")
def state() -> dict[str, Any]:
    """Expose the complete current state of the shared environment."""
    snapshot = env.state()
    return snapshot
@app.get("/tasks")
def tasks() -> list[dict[str, Any]]:
    """Describe every available task together with its action schema.

    Each entry contains ``task_id``, ``name``, ``difficulty``, ``max_steps``
    and an ``action_schema`` showing the expected ``response`` format.
    """
    # One row per task:
    # (task_id, display name, difficulty, max steps, accepted ``response`` format)
    catalogue = (
        (
            "feasibility_check",
            "Feasibility Check",
            "easy",
            3,
            "feasible | infeasible",
        ),
        (
            "conflict_classification",
            "Conflict Classification",
            "medium",
            5,
            "resource_overload | deadline_violation | precedence_violation | "
            "availability_conflict | capacity_exceeded",
        ),
        (
            "schedule_repair",
            "Schedule Repair",
            "hard",
            8,
            '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
        ),
    )

    # Expand each row into the public dict shape; key order matches the API contract.
    return [
        {
            "task_id": task_id,
            "name": display_name,
            "difficulty": difficulty,
            "max_steps": step_budget,
            "action_schema": {
                "response": response_format,
                "task_id": task_id,
            },
        }
        for task_id, display_name, difficulty, step_budget, response_format in catalogue
    ]
@app.post("/grader", response_model=GradeResponse)
def grader(req: GradeRequest) -> GradeResponse:
    """Score an action against supplied ground truth, bypassing the environment.

    Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}

    The grader is chosen from ``action.task_id``. The returned score is clamped
    to the range [0.0, 1.0]. Responds with HTTP 400 when no grader exists for
    the given ``task_id``.
    """
    task_id = req.action.task_id

    # Map task ids to grader *classes* rather than instances, so that only the
    # grader actually needed is constructed, and none at all for an unknown id.
    grader_classes = {
        "feasibility_check": FeasibilityGrader,
        "conflict_classification": ConflictGrader,
        "schedule_repair": RepairGrader,
    }
    grader_cls = grader_classes.get(task_id)
    if grader_cls is None:
        raise HTTPException(
            status_code=400, detail=f"No grader for task_id={task_id}"
        )

    score = grader_cls().grade(req.action, req.ground_truth)
    # Clamp defensively: the response contract promises a value in [0, 1].
    return GradeResponse(score=max(0.0, min(1.0, score)))
@app.get("/baseline")
def baseline() -> dict[str, Any]:
    """Run the baseline inference agent and return its per-task scores.

    ``run_baseline`` is documented as falling back to mock oracle responses
    when OPENAI_API_KEY is not set, so this endpoint is intended to always
    return a valid result. Any exception raised while importing or running
    the baseline is reported as HTTP 500 with the error text in ``detail``.
    """
    try:
        # Imported lazily so the baseline module is only loaded when this
        # endpoint is called; an import failure is reported like a run failure.
        from baseline import run_baseline

        scores = run_baseline()
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline run failed: {exc}",
        ) from exc
    else:
        return scores
|