File size: 5,155 Bytes
85d3923
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
"""FastAPI server exposing the Scheduling Optimisation Environment as an HTTP API."""

from __future__ import annotations

import logging
import math
from typing import Any

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from environment import SchedulingOptEnv
from graders.grader_classification import ConflictGrader
from graders.grader_detection import FeasibilityGrader
from graders.grader_fix import RepairGrader
from models import Action, Observation

app = FastAPI(
    version="1.0.0",
    title="Scheduling Optimisation Environment",
    description=(
        "OpenEnv-compatible environment for training AI agents "
        "on combinatorial scheduling optimisation problems."
    ),
)

# One process-wide environment used by the /reset, /step and /state handlers.
# NOTE(review): FastAPI runs plain ``def`` endpoints in a threadpool, so
# concurrent requests touch this same object — confirm SchedulingOptEnv
# tolerates that.
env = SchedulingOptEnv()


# ---------------------------------------------------------------------------
# Request / response schemas
# ---------------------------------------------------------------------------


class ResetRequest(BaseModel):
    """Request body for ``POST /reset``."""

    # Task to start an episode for. This model accepts any string; the
    # /reset handler rejects unknown task ids with HTTP 400.
    task_id: str = "feasibility_check"


class StepResponse(BaseModel):
    """Response body for ``POST /step``.

    Names the four values unpacked from ``env.step(action)``.
    """

    # Observation returned by env.step (presumably the post-action state — confirm).
    observation: Observation
    # Reward reported by env.step for this action.
    reward: float
    # Episode-termination flag as reported by env.step.
    done: bool
    # Extra data from env.step; its keys are not defined in this module.
    info: dict[str, Any]


class GradeRequest(BaseModel):
    """Request body for ``POST /grader``."""

    # Action to score; its ``task_id`` selects which grader the handler runs.
    action: Action
    # Reference answer handed straight to the grader's ``grade`` method; the
    # expected keys depend on the grader and are not validated here.
    ground_truth: dict[str, Any]


class GradeResponse(BaseModel):
    """Response body for ``POST /grader``."""

    # Grader score; the /grader handler clamps it to [0.0, 1.0].
    score: float


# ---------------------------------------------------------------------------
# Endpoints
# ---------------------------------------------------------------------------


@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe (e.g. for Hugging Face Spaces); always reports ``ok``."""
    return dict(status="ok")


@app.post("/reset", response_model=Observation)
def reset(req: ResetRequest) -> Observation:
    """Begin a new episode for the requested task.

    Body: {"task_id": "feasibility_check" | "conflict_classification" | "schedule_repair"}

    Returns the observation produced by ``env.reset``.

    Raises:
        HTTPException: 400 when ``task_id`` is not a supported task.
    """
    known_tasks = frozenset(
        ("feasibility_check", "conflict_classification", "schedule_repair")
    )
    requested = req.task_id
    if requested in known_tasks:
        return env.reset(task_id=requested)
    raise HTTPException(
        status_code=400,
        detail=f"Invalid task_id. Choose from: {sorted(known_tasks)}",
    )


@app.post("/step", response_model=StepResponse)
def step(action: Action) -> StepResponse:
    """Apply one agent action to the shared environment.

    Body: {"response": "<answer>", "task_id": "<task_id>"}

    Returns the ``(observation, reward, done, info)`` result of ``env.step``
    packaged as a ``StepResponse``.
    """
    outcome = env.step(action)
    observation, reward, done, info = outcome
    return StepResponse(
        observation=observation,
        reward=reward,
        done=done,
        info=info,
    )


@app.get("/state")
def state() -> dict[str, Any]:
    """Return whatever ``env.state()`` reports for the shared environment."""
    current = env.state()
    return current


@app.get("/tasks")
def tasks() -> list[dict[str, Any]]:
    """Describe each supported task.

    Every entry carries the task id, display name, difficulty label, step
    budget and the expected shape of an ``Action`` for that task.
    """

    def describe(
        task_id: str,
        name: str,
        difficulty: str,
        max_steps: int,
        response_format: str,
    ) -> dict[str, Any]:
        # Key order matches the JSON layout clients already receive.
        return {
            "task_id": task_id,
            "name": name,
            "difficulty": difficulty,
            "max_steps": max_steps,
            "action_schema": {
                "response": response_format,
                "task_id": task_id,
            },
        }

    catalogue = (
        ("feasibility_check", "Feasibility Check", "easy", 3, "feasible | infeasible"),
        (
            "conflict_classification",
            "Conflict Classification",
            "medium",
            5,
            "resource_overload | deadline_violation | precedence_violation | "
            "availability_conflict | capacity_exceeded",
        ),
        (
            "schedule_repair",
            "Schedule Repair",
            "hard",
            8,
            '{"assignments": [{"job_id": "J1", "machine_id": "M1", "start_time": 0}, ...]}',
        ),
    )
    return [describe(*row) for row in catalogue]


@app.post("/grader", response_model=GradeResponse)
def grader(req: GradeRequest) -> GradeResponse:
    """Score an action against caller-supplied ground truth.

    Body: {"action": {"response": "...", "task_id": "..."}, "ground_truth": {...}}

    The grader is chosen by ``req.action.task_id``. Its raw score is clamped
    to ``[0.0, 1.0]``, and a NaN score is reported as ``0.0``.

    Raises:
        HTTPException: 400 if no grader exists for the action's ``task_id``.
    """
    task_id = req.action.task_id
    # Map to classes rather than instances so that only the selected grader
    # is constructed per request (previously all three were built every time).
    grader_classes = {
        "feasibility_check": FeasibilityGrader,
        "conflict_classification": ConflictGrader,
        "schedule_repair": RepairGrader,
    }
    grader_cls = grader_classes.get(task_id)
    if grader_cls is None:
        raise HTTPException(
            status_code=400, detail=f"No grader for task_id={task_id}"
        )
    score = grader_cls().grade(req.action, req.ground_truth)
    # min/max alone would map NaN to 1.0: ``min(1.0, nan)`` keeps 1.0 because
    # every comparison with NaN is False, which would award a perfect score.
    if math.isnan(score):
        score = 0.0
    return GradeResponse(score=max(0.0, min(1.0, score)))


@app.get("/baseline")
def baseline() -> dict[str, Any]:
    """Run the baseline inference agent and return its per-task scores.

    ``baseline.run_baseline`` is imported lazily inside the handler, so a
    missing or broken baseline module fails only this endpoint, not server
    start-up. The baseline is expected to fall back to mock oracle responses
    when OPENAI_API_KEY is not set (NOTE(review): that behaviour lives in
    baseline.py — confirm there).

    Raises:
        HTTPException: 500 wrapping any error raised while importing or
            running the baseline.
    """
    try:
        from baseline import run_baseline

        return run_baseline()
    except Exception as exc:
        # FastAPI turns HTTPException into a response without logging the
        # chained cause, so record the traceback here or it is lost.
        logging.getLogger(__name__).exception("Baseline run failed")
        raise HTTPException(
            status_code=500,
            detail=f"Baseline run failed: {exc}",
        ) from exc