| """FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Any |
|
|
| from fastapi import FastAPI, HTTPException |
| from pydantic import BaseModel |
|
|
| from environment import SchedulingOptEnv |
| from graders.grader_classification import ConflictGrader |
| from graders.grader_detection import FeasibilityGrader |
| from graders.grader_fix import RepairGrader |
| from models import Action, Observation |
|
|
# ASGI application object; the route decorators in this module register on it.
app = FastAPI(
    version="1.0.0",
    title="Scheduling Optimisation Environment",
    description=(
        "OpenEnv-compatible environment for training AI agents "
        "on combinatorial scheduling optimisation problems."
    ),
)


# One environment instance is created at import time and shared by the
# /reset, /step and /state handlers.
env = SchedulingOptEnv()
|
|
|
|
| |
| |
| |
|
|
|
|
# Request body for POST /reset.
# (Documented with comments rather than a docstring so the generated JSON
# schema for this model is left untouched.)
class ResetRequest(BaseModel):
    # Task to start; falls back to the feasibility task when omitted.
    task_id: str = "feasibility_check"
|
|
|
|
# Response body for POST /step: the four values unpacked from env.step().
# (Documented with comments rather than a docstring so the generated JSON
# schema for this model is left untouched.)
class StepResponse(BaseModel):
    # Observation produced by the environment after the action.
    observation: Observation
    # Reward reported by the environment for this step.
    reward: float
    # Episode-finished flag reported by the environment.
    done: bool
    # Free-form extra data returned alongside the step result.
    info: dict[str, Any]
|
|
|
|
# Request body for POST /grader.
# (Documented with comments rather than a docstring so the generated JSON
# schema for this model is left untouched.)
class GradeRequest(BaseModel):
    # Action to score; its task_id selects which grader is used.
    action: Action
    # Reference data handed to the grader unchanged.
    ground_truth: dict[str, Any]
|
|
|
|
# Response body for POST /grader.
# (Documented with comments rather than a docstring so the generated JSON
# schema for this model is left untouched.)
class GradeResponse(BaseModel):
    # Grader score as returned by the /grader handler.
    score: float
|
|
|
|
| |
| |
| |
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Health check for Hugging Face Spaces liveness probe."""
    # Constant payload: no environment state is read, so this answers even
    # before any episode has been started.
    payload = {"status": "ok"}
    return payload
|
|
|
|
@app.post("/reset", response_model=Observation)
def reset(req: ResetRequest) -> Observation:
    """Reset the environment and start a new episode.

    Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}
    """
    allowed = {"feasibility_check", "conflict_classification", "schedule_repair"}

    # Happy path first: a recognised task id goes straight to the shared env.
    if req.task_id in allowed:
        return env.reset(task_id=req.task_id)

    # Anything else is a client error; list the choices in sorted order so the
    # message is stable regardless of set iteration order.
    choices = sorted(allowed)
    raise HTTPException(
        status_code=400,
        detail=f"Invalid task_id. Choose from: {choices}",
    )
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(action: Action) -> StepResponse:
    """Submit an action and advance the environment by one step.

    Body: {"response": "<answer>", "task_id": "<task_id>"}
    """
    # env.step() is expected to yield exactly four values; the unpacking below
    # raises if it does not.
    outcome = env.step(action)
    observation, reward, finished, extra = outcome
    return StepResponse(
        observation=observation,
        reward=reward,
        done=finished,
        info=extra,
    )
|
|
|
|
@app.get("/state")
def state() -> dict[str, Any]:
    """Return the full current environment state."""
    # Pure pass-through: whatever the shared environment reports is returned
    # to the client unmodified.
    snapshot = env.state()
    return snapshot
|
|
|
|
@app.get("/tasks")
def tasks() -> list[dict[str, Any]]:
    """List available tasks with their action schemas."""

    def describe(
        task_id: str,
        name: str,
        difficulty: str,
        max_steps: int,
        response: str,
    ) -> dict[str, Any]:
        # Every entry has the same shape; the task id appears both at the top
        # level and inside the action schema.
        return {
            "task_id": task_id,
            "name": name,
            "difficulty": difficulty,
            "max_steps": max_steps,
            "action_schema": {"response": response, "task_id": task_id},
        }

    # Accepted answers for the classification task, rendered as "a | b | c".
    conflict_labels = " | ".join(
        (
            "resource_overload",
            "deadline_violation",
            "precedence_violation",
            "availability_conflict",
            "capacity_exceeded",
        )
    )
    # Example of the JSON document expected as the repair task's response.
    repair_example = (
        '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}'
    )

    return [
        describe(
            "feasibility_check",
            "Feasibility Check",
            "easy",
            3,
            "feasible | infeasible",
        ),
        describe(
            "conflict_classification",
            "Conflict Classification",
            "medium",
            5,
            conflict_labels,
        ),
        describe(
            "schedule_repair",
            "Schedule Repair",
            "hard",
            8,
            repair_example,
        ),
    ]
|
|
|
|
@app.post("/grader", response_model=GradeResponse)
def grader(req: GradeRequest) -> GradeResponse:
    """Directly invoke a grader with an action and ground truth.

    Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}
    """
    # Local import keeps this handler self-contained; only needed for the NaN
    # guard below.
    import math

    task_id = req.action.task_id

    # Map task ids to grader classes (not instances) so that only the grader
    # actually requested is constructed for this call.
    grader_classes = {
        "feasibility_check": FeasibilityGrader,
        "conflict_classification": ConflictGrader,
        "schedule_repair": RepairGrader,
    }
    grader_cls = grader_classes.get(task_id)
    if grader_cls is None:
        # Unknown task id is a client error, reported as HTTP 400.
        raise HTTPException(
            status_code=400, detail=f"No grader for task_id={task_id}"
        )

    score = grader_cls().grade(req.action, req.ground_truth)

    # NaN compares false against everything, so min(1.0, nan) evaluates to 1.0
    # and a plain clamp would turn an undefined score into a perfect one.
    # Report it as the lowest score instead.
    if math.isnan(score):
        return GradeResponse(score=0.0)

    # Clamp into the closed interval [0.0, 1.0].
    return GradeResponse(score=max(0.0, min(1.0, score)))
|
|
|
|
@app.get("/baseline")
def baseline() -> dict[str, Any]:
    """Trigger the baseline inference agent and return per-task scores.

    Falls back to mock oracle responses when OPENAI_API_KEY is not set,
    so this endpoint always returns a valid result.
    """
    try:
        # Imported lazily so a missing or broken baseline module only affects
        # this endpoint; an import failure is reported like any other error.
        from baseline import run_baseline

        scores = run_baseline()
    except Exception as exc:
        # Any failure (import or run) is surfaced to the client as HTTP 500
        # with the original exception chained for server-side tracebacks.
        message = f"Baseline run failed: {exc}"
        raise HTTPException(status_code=500, detail=message) from exc
    return scores
|
|