Spaces:
Runtime error
Runtime error
File size: 6,126 Bytes
52c82e4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 | """FastAPI server exposing DASH-JSP inference.
Endpoints
---------
POST /solve — solve a complete JSP instance (returns makespan, tardiness, utilization and related metrics)
POST /dispatch_event — single-decision inference (stateless WMS integration)
GET /health — liveness probe
GET /version — version + arm names
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Optional
import numpy as np
from fastapi import FastAPI, HTTPException
import dash_jsp
from dash_jsp.api.schemas import (
DispatchEventRequest,
DispatchEventResponse,
JSPInstanceSchema,
SolveRequest,
SolveResponse,
)
from dash_jsp.bandit.linucb import LinUCBDispatcher
from dash_jsp.bandit.thompson import ThompsonDispatcher
from dash_jsp.bandit.ensemble import EnsembleDispatcher
from dash_jsp.benchmarks.format import JSPInstance
from dash_jsp.heuristics.rules import ALL_RULES, get_rule
from dash_jsp.simulator.jsp_sim import (
simulate,
simulate_with_dispatcher,
)
# ---------------------------------------------------------------------------
# Lazy model loading
# ---------------------------------------------------------------------------
# Process-wide cache of dispatcher objects, keyed by "linucb", "thompson" and
# "ensemble". Entries are created on first use by the _load_* helpers below.
_MODELS: dict = dict()
def _models_dir() -> Path:
    """Return the directory that holds the serialized bandit models.

    A non-empty ``DASH_JSP_MODELS_DIR`` environment variable takes precedence;
    otherwise the ``models`` directory three levels above this file is used.
    """
    override = os.environ.get("DASH_JSP_MODELS_DIR")
    if not override:
        here = Path(__file__).resolve()
        return here.parent.parent.parent / "models"
    return Path(override)
def _load_linucb() -> LinUCBDispatcher:
    """Return the cached LinUCB dispatcher, creating it on the first call.

    Loads ``bandit_linucb.npz`` from :func:`_models_dir` when that file
    exists; otherwise caches a freshly constructed ``LinUCBDispatcher()``.
    """
    key = "linucb"
    if key in _MODELS:
        return _MODELS[key]
    weights_file = _models_dir() / "bandit_linucb.npz"
    # Without saved weights, fall back to an untrained bandit: still usable,
    # but it will explore.
    _MODELS[key] = (
        LinUCBDispatcher.load(str(weights_file))
        if weights_file.exists()
        else LinUCBDispatcher()
    )
    return _MODELS[key]
def _load_thompson() -> ThompsonDispatcher:
    """Return the cached Thompson dispatcher, creating it on the first call.

    Loads ``bandit_thompson.npz`` from :func:`_models_dir` when that file
    exists; otherwise caches a freshly constructed ``ThompsonDispatcher()``.
    """
    if "thompson" in _MODELS:
        return _MODELS["thompson"]
    npz_path = _models_dir() / "bandit_thompson.npz"
    if npz_path.exists():
        model = ThompsonDispatcher.load(str(npz_path))
    else:
        model = ThompsonDispatcher()
    _MODELS["thompson"] = model
    return model
def _load_ensemble() -> EnsembleDispatcher:
    """Return the cached ensemble dispatcher, creating it on the first call.

    The ensemble wraps the same cached LinUCB and Thompson instances that
    ``_load_linucb`` / ``_load_thompson`` return (loading them if needed).
    """
    if "ensemble" not in _MODELS:
        lin = _load_linucb()
        ts = _load_thompson()
        _MODELS["ensemble"] = EnsembleDispatcher(linucb=lin, thompson=ts)
    return _MODELS["ensemble"]
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
# ASGI application object; referenced by main() as "dash_jsp.api.server:app".
app = FastAPI(
    description=(
        "Bandit-based dynamic dispatch-rule selection for job-shop "
        "scheduling. Public benchmarks (Taillard / Lawrence / "
        "Brandimarte / DMU) supported out of the box."
    ),
    title="DASH-JSP",
    version=dash_jsp.__version__,
)
def _instance_from_schema(s: JSPInstanceSchema) -> JSPInstance:
    """Convert an API ``JSPInstanceSchema`` into a simulator ``JSPInstance``.

    Each operation in ``s.ops`` is unpacked as a pair and re-packed as a
    2-tuple (presumably machine id and processing time — confirm against the
    schema). The resulting instance is tagged with ``family="api"``.
    """
    job_ops = [
        [(machine, proc_time) for machine, proc_time in job]
        for job in s.ops
    ]
    return JSPInstance(
        name=s.name,
        family="api",
        n_jobs=s.n_jobs,
        n_machines=s.n_machines,
        ops=job_ops,
        optimum=s.optimum,
        due_dates=s.due_dates,
        weights=s.weights,
    )
@app.get("/health")
def health() -> dict:
    """Liveness probe; always answers ``{"status": "ok"}`` without loading models."""
    return dict(status="ok")
@app.get("/version")
def version() -> dict:
    """Report the package version plus the LinUCB arm names and pull count.

    Triggers the lazy LinUCB load on first call.
    """
    linucb = _load_linucb()
    arm_names = [arm.name for arm in linucb.arms]
    return dict(
        dash_jsp_version=dash_jsp.__version__,
        linucb_arms=arm_names,
        linucb_pulls=linucb.n_pulls,
    )
@app.post("/solve", response_model=SolveResponse)
def solve(req: SolveRequest) -> SolveResponse:
    """Solve a complete JSP instance with a fixed heuristic or a bandit.

    ``req.method`` is matched case-insensitively, first against the names in
    ``ALL_RULES`` (a single fixed dispatch rule), then against the bandit
    dispatchers ``"linucb"``, ``"thompson"`` and ``"ensemble"``. Bandit runs
    record and return the rule-choice log produced by the simulator.

    Raises:
        HTTPException: status 400 if the method name is not recognised.
    """
    # Bandit method name -> lazy loader. Looked up only after ALL_RULES, so a
    # heuristic rule keeps precedence over a bandit of the same name.
    bandit_loaders = {
        "linucb": _load_linucb,
        "thompson": _load_thompson,
        "ensemble": _load_ensemble,
    }

    inst = _instance_from_schema(req.instance)
    method = req.method.lower()
    if method in ALL_RULES:
        result = simulate(inst, get_rule(method))
        log = None
    elif method in bandit_loaders:
        dispatcher = bandit_loaders[method]()
        result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)
        log = result.rule_choice_log
    else:
        raise HTTPException(400, f"Unknown method: {method!r}")

    # Percentage gap to the known optimum; None when no (non-zero) optimum is
    # provided, which also guards the division.
    gap = (
        100.0 * (result.makespan - inst.optimum) / inst.optimum
        if inst.optimum
        else None
    )
    return SolveResponse(
        method=method,
        makespan=result.makespan,
        total_tardiness=result.total_tardiness,
        avg_cycle_time=result.avg_cycle_time,
        machine_utilization=result.machine_utilization,
        n_dispatch_decisions=result.n_dispatch_decisions,
        runtime_ms=result.runtime_ms,
        optimality_gap_pct=gap,
        rule_choice_log=log,
    )
# Length of the context feature vector accepted by /dispatch_event.
_CONTEXT_DIM = 32


@app.post("/dispatch_event", response_model=DispatchEventResponse)
def dispatch_event(req: DispatchEventRequest) -> DispatchEventResponse:
    """Single-decision inference for live WMS integration.

    Validates the context vector, asks the selected bandit for an arm and
    returns that arm's index together with the name of its dispatch rule.

    Raises:
        HTTPException: status 400 if the context has the wrong length or
            contains non-finite values, or if the method name is unknown.
    """
    if len(req.context) != _CONTEXT_DIM:
        raise HTTPException(
            400, f"context must be {_CONTEXT_DIM}-dim, got {len(req.context)}"
        )
    ctx = np.asarray(req.context, dtype=np.float64)
    # Float parsing may admit NaN/inf; reject them here instead of letting
    # them propagate into the bandit's scoring.
    if not np.all(np.isfinite(ctx)):
        raise HTTPException(400, "context must contain only finite values")

    method = req.method.lower()
    if method == "linucb":
        selector = _load_linucb()
    elif method == "thompson":
        selector = _load_thompson()
    elif method == "ensemble":
        # Single decisions for the ensemble are delegated to its LinUCB
        # member only; the Thompson member is not consulted here.
        selector = _load_ensemble().linucb
    else:
        raise HTTPException(400, f"Unknown method: {method!r}")

    arm = selector.select(ctx)
    # Resolve the rule name on the same bandit that produced the index so the
    # returned name and index always agree.
    return DispatchEventResponse(
        chosen_rule=selector.arms[arm].name,
        arm_index=int(arm),
    )
def main() -> None:  # pragma: no cover
    """Serve the API with uvicorn.

    The bind address is read from ``DASH_JSP_HOST`` (default ``0.0.0.0``) and
    the port from ``DASH_JSP_PORT`` (default ``8000``). Auto-reload is off.
    """
    import uvicorn

    uvicorn.run(
        "dash_jsp.api.server:app",
        host=os.environ.get("DASH_JSP_HOST", "0.0.0.0"),
        port=int(os.environ.get("DASH_JSP_PORT", "8000")),
        reload=False,
    )
# Allow running this module directly; delegates to main(), which starts uvicorn
# on "dash_jsp.api.server:app".
if __name__ == "__main__":  # pragma: no cover
    main()
|