Spaces:
Sleeping
Sleeping
| """FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API.""" | |
| from __future__ import annotations | |
| from typing import Any | |
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from environment import SchedulingOptEnv | |
| from graders.grader_classification import ConflictGrader | |
| from graders.grader_detection import FeasibilityGrader | |
| from graders.grader_fix import RepairGrader | |
| from models import Action, Observation | |
# ASGI application object; the metadata below is passed to FastAPI as the
# title, description and version of the service.
# NOTE(review): no route decorators are visible on the handler functions in
# this view of the file — confirm they are registered on `app`.
app = FastAPI(
    title="Scheduling Optimisation Environment",
    description="OpenEnv-compatible environment for training AI agents on combinatorial scheduling optimisation problems.",
    version="1.0.0",
)

# Single shared environment instance: every handler in this module reads and
# mutates this one object.
env = SchedulingOptEnv()
| # --------------------------------------------------------------------------- | |
| # Request / response schemas | |
| # --------------------------------------------------------------------------- | |
class ResetRequest(BaseModel):
    """Request body accepted by :func:`reset`."""

    # Task to start; when the field is omitted the feasibility task is used.
    task_id: str = "feasibility_check"
class StepResponse(BaseModel):
    """Response body built by :func:`step` from the tuple ``env.step`` returns."""

    # First element of the step result.
    observation: Observation
    # Second element of the step result.
    reward: float
    # Third element of the step result.
    done: bool
    # Fourth element of the step result: free-form string-keyed mapping.
    info: dict[str, Any]
class GradeRequest(BaseModel):
    """Request body accepted by :func:`grader`."""

    # Action to score; its ``task_id`` selects which grader is used.
    action: Action
    # Reference data handed to the selected grader unchanged.
    ground_truth: dict[str, Any]
class GradeResponse(BaseModel):
    """Response body returned by :func:`grader`."""

    # Grader output; :func:`grader` clamps it to the range [0.0, 1.0].
    score: float
| # --------------------------------------------------------------------------- | |
| # Endpoints | |
| # --------------------------------------------------------------------------- | |
def health() -> dict[str, str]:
    """Liveness check for the Hugging Face Spaces probe.

    Returns:
        The constant payload ``{"status": "ok"}``.
    """
    payload = {"status": "ok"}
    return payload
def reset(req: ResetRequest) -> Observation:
    """Start a new episode on the shared environment.

    Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}

    Args:
        req: Request carrying the ``task_id`` of the task to start.

    Returns:
        Whatever ``env.reset`` returns for the requested task.

    Raises:
        HTTPException: 400 when ``task_id`` is not one of the three known tasks.
    """
    known_tasks = {"feasibility_check", "conflict_classification", "schedule_repair"}
    requested = req.task_id
    if requested in known_tasks:
        return env.reset(task_id=requested)

    # Unknown task: list the accepted ids (sorted for a stable message).
    raise HTTPException(
        status_code=400,
        detail=f"Invalid task_id. Choose from: {sorted(known_tasks)}",
    )
def step(action: Action) -> StepResponse:
    """Forward one action to the shared environment and wrap its result.

    Body: {"response": "<answer>", "task_id": "<task_id>"}

    Args:
        action: The agent's action, passed to ``env.step`` unchanged.

    Returns:
        A ``StepResponse`` holding the four values ``env.step`` returned.
    """
    outcome = env.step(action)
    # env.step must yield exactly four values; unpacking raises otherwise.
    observation, reward, finished, extra = outcome
    return StepResponse(
        observation=observation,
        reward=reward,
        done=finished,
        info=extra,
    )
def state() -> dict[str, Any]:
    """Return the shared environment's current state.

    Returns:
        The value of ``env.state()``, passed through untouched.
    """
    snapshot = env.state()
    return snapshot
def tasks() -> list[dict[str, Any]]:
    """List the available tasks together with their action schemas.

    Returns:
        A freshly built list of three dicts (feasibility check, conflict
        classification, schedule repair), each with the keys ``task_id``,
        ``name``, ``difficulty``, ``max_steps`` and ``action_schema``.
    """
    # One row per task: (task_id, name, difficulty, max_steps, response format).
    catalogue = (
        (
            "feasibility_check",
            "Feasibility Check",
            "easy",
            3,
            "feasible | infeasible",
        ),
        (
            "conflict_classification",
            "Conflict Classification",
            "medium",
            5,
            "resource_overload | deadline_violation | precedence_violation | "
            "availability_conflict | capacity_exceeded",
        ),
        (
            "schedule_repair",
            "Schedule Repair",
            "hard",
            8,
            '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
        ),
    )

    listing = []
    for task_id, name, difficulty, max_steps, response_format in catalogue:
        listing.append(
            {
                "task_id": task_id,
                "name": name,
                "difficulty": difficulty,
                "max_steps": max_steps,
                # The schema repeats the task id the action must carry.
                "action_schema": {
                    "response": response_format,
                    "task_id": task_id,
                },
            }
        )
    return listing
def grader(req: GradeRequest) -> GradeResponse:
    """Score an action against a ground truth with the matching grader.

    Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}

    Args:
        req: Request carrying the action to score and the ground truth to
            score it against. ``req.action.task_id`` selects the grader.

    Returns:
        A ``GradeResponse`` whose score is clamped to ``[0.0, 1.0]``.

    Raises:
        HTTPException: 400 when no grader exists for the action's ``task_id``.
    """
    import math

    task_id = req.action.task_id

    # Map to the classes rather than to instances so that only the grader
    # actually needed is constructed, and none at all for an unknown task.
    grader_classes = {
        "feasibility_check": FeasibilityGrader,
        "conflict_classification": ConflictGrader,
        "schedule_repair": RepairGrader,
    }
    grader_cls = grader_classes.get(task_id)
    if grader_cls is None:
        raise HTTPException(
            status_code=400, detail=f"No grader for task_id={task_id}"
        )

    score = grader_cls().grade(req.action, req.ground_truth)

    # A NaN would slip through the clamp below as 1.0, because every
    # comparison with NaN is false; report it as the lowest score instead.
    if math.isnan(score):
        return GradeResponse(score=0.0)
    return GradeResponse(score=max(0.0, min(1.0, score)))
def baseline() -> dict[str, Any]:
    """Run the baseline inference agent and return its result.

    The ``baseline`` module is imported lazily, inside the call, so a failure
    to import it is reported the same way as a failure while running it.

    NOTE(review): the original documentation states that the run falls back
    to mock oracle responses when OPENAI_API_KEY is not set, so the endpoint
    always returns a valid result; that logic lives in ``run_baseline`` and
    is not visible here — confirm.

    Returns:
        Whatever ``run_baseline()`` returns (documented as per-task scores).

    Raises:
        HTTPException: 500 wrapping any exception raised by the import or
            by the run itself.
    """
    try:
        from baseline import run_baseline

        scores = run_baseline()
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"Baseline run failed: {err}",
        ) from err
    return scores