Spaces:
Runtime error
Runtime error
"""FastAPI server exposing DASH-JSP inference.

Endpoints
---------
POST /solve           — solve a complete JSP instance (returns makespan, tardiness,
                        cycle-time and utilization metrics, the optimality gap when
                        an optimum is known, and the rule-choice log for bandit methods)
POST /dispatch_event  — single-decision inference (stateless WMS integration)
GET  /health          — liveness probe
GET  /version         — package version, LinUCB arm names and LinUCB pull count
"""
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| from fastapi import FastAPI, HTTPException | |
| import dash_jsp | |
| from dash_jsp.api.schemas import ( | |
| DispatchEventRequest, | |
| DispatchEventResponse, | |
| JSPInstanceSchema, | |
| SolveRequest, | |
| SolveResponse, | |
| ) | |
| from dash_jsp.bandit.linucb import LinUCBDispatcher | |
| from dash_jsp.bandit.thompson import ThompsonDispatcher | |
| from dash_jsp.bandit.ensemble import EnsembleDispatcher | |
| from dash_jsp.benchmarks.format import JSPInstance | |
| from dash_jsp.heuristics.rules import ALL_RULES, get_rule | |
| from dash_jsp.simulator.jsp_sim import ( | |
| simulate, | |
| simulate_with_dispatcher, | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Lazy model loading | |
| # --------------------------------------------------------------------------- | |
| _MODELS: dict = {} | |
| def _models_dir() -> Path: | |
| env = os.environ.get("DASH_JSP_MODELS_DIR") | |
| if env: | |
| return Path(env) | |
| return Path(__file__).resolve().parent.parent.parent / "models" | |
def _load_linucb() -> LinUCBDispatcher:
    """Return the cached LinUCB dispatcher, creating it on the first call.

    The model is loaded from ``bandit_linucb.npz`` inside ``_models_dir()`` when
    that file exists. Otherwise a fresh, untrained ``LinUCBDispatcher()`` is
    cached instead. It is still usable, but it will explore.
    """
    if "linucb" in _MODELS:
        return _MODELS["linucb"]
    weights_file = _models_dir() / "bandit_linucb.npz"
    # A missing weights file is deliberately not an error.
    _MODELS["linucb"] = (
        LinUCBDispatcher.load(str(weights_file))
        if weights_file.exists()
        else LinUCBDispatcher()
    )
    return _MODELS["linucb"]
def _load_thompson() -> ThompsonDispatcher:
    """Return the cached Thompson-sampling dispatcher, creating it on the first call.

    The model is loaded from ``bandit_thompson.npz`` inside ``_models_dir()`` when
    that file exists; otherwise an untrained ``ThompsonDispatcher()`` is cached.
    """
    if "thompson" in _MODELS:
        return _MODELS["thompson"]
    weights_file = _models_dir() / "bandit_thompson.npz"
    _MODELS["thompson"] = (
        ThompsonDispatcher.load(str(weights_file))
        if weights_file.exists()
        else ThompsonDispatcher()
    )
    return _MODELS["thompson"]
def _load_ensemble() -> EnsembleDispatcher:
    """Return the cached ensemble dispatcher, creating it on the first call.

    The ensemble wraps the same cached LinUCB and Thompson instances that
    ``_load_linucb`` / ``_load_thompson`` hand out to the single-bandit methods.
    """
    if "ensemble" in _MODELS:
        return _MODELS["ensemble"]
    linucb_member = _load_linucb()
    thompson_member = _load_thompson()
    _MODELS["ensemble"] = EnsembleDispatcher(
        linucb=linucb_member,
        thompson=thompson_member,
    )
    return _MODELS["ensemble"]
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
# Application object; main() serves it by import string "dash_jsp.api.server:app".
app = FastAPI(
    title="DASH-JSP",
    description=(
        "Bandit-based dynamic dispatch-rule selection for job-shop scheduling. "
        "Public benchmarks (Taillard / Lawrence / Brandimarte / DMU) supported "
        "out of the box."
    ),
    version=dash_jsp.__version__,
)
def _instance_from_schema(s: JSPInstanceSchema) -> JSPInstance:
    """Convert an API ``JSPInstanceSchema`` into the simulator's ``JSPInstance``.

    Every row of ``s.ops`` is copied into a new list of 2-tuples. Each entry must
    unpack into exactly two values, presumably (machine, processing time); verify
    against the schema. The instance family is always set to ``"api"``.
    """
    job_ops = [[(machine, proc) for machine, proc in job] for job in s.ops]
    return JSPInstance(
        name=s.name,
        family="api",
        n_jobs=s.n_jobs,
        n_machines=s.n_machines,
        ops=job_ops,
        optimum=s.optimum,
        due_dates=s.due_dates,
        weights=s.weights,
    )
@app.get("/health")
def health() -> dict:
    """Liveness probe for ``GET /health``; always returns ``{"status": "ok"}``.

    The decorator is what registers this function as a route on ``app``.
    """
    return {"status": "ok"}
@app.get("/version")
def version() -> dict:
    """Handle ``GET /version``: report the package version and LinUCB bandit details.

    Calling this triggers the lazy load of the LinUCB dispatcher (``_load_linucb``)
    if it has not been loaded yet.

    Returns:
        A dict with ``dash_jsp_version``, ``linucb_arms`` (the ``name`` of each
        arm, in arm order) and ``linucb_pulls`` (the bandit's ``n_pulls``).
    """
    bandit = _load_linucb()
    return {
        "dash_jsp_version": dash_jsp.__version__,
        "linucb_arms": [a.name for a in bandit.arms],
        "linucb_pulls": bandit.n_pulls,
    }
@app.post("/solve", response_model=SolveResponse)
def solve(req: SolveRequest) -> SolveResponse:
    """Handle ``POST /solve``: simulate a complete JSP instance and report metrics.

    ``req.method`` is lower-cased and then resolved as either a fixed dispatching
    rule (a name contained in ``ALL_RULES``) or one of the bandit dispatchers
    ``"linucb"``, ``"thompson"`` or ``"ensemble"``.

    Args:
        req: Request carrying the instance schema and the method name.

    Returns:
        ``SolveResponse`` containing:
        - makespan, total tardiness, average cycle time and machine utilization;
        - number of dispatch decisions and runtime;
        - optimality gap in percent (``None`` when the instance's ``optimum`` is
          falsy);
        - for bandit methods only, the per-decision rule-choice log.

    Raises:
        HTTPException: 400 when the method name is not recognised.
    """
    inst = _instance_from_schema(req.instance)
    method = req.method.lower()
    # All bandit methods share one simulation path; only the dispatcher differs.
    bandit_loaders = {
        "linucb": _load_linucb,
        "thompson": _load_thompson,
        "ensemble": _load_ensemble,
    }

    if method in ALL_RULES:
        result = simulate(inst, get_rule(method))
        log = None
    elif method in bandit_loaders:
        dispatcher = bandit_loaders[method]()
        result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)
        log = result.rule_choice_log
    else:
        raise HTTPException(400, f"Unknown method: {method!r}")

    # With no known (non-zero) optimum the relative gap is undefined.
    gap = (
        100.0 * (result.makespan - inst.optimum) / inst.optimum
        if inst.optimum
        else None
    )
    return SolveResponse(
        method=method,
        makespan=result.makespan,
        total_tardiness=result.total_tardiness,
        avg_cycle_time=result.avg_cycle_time,
        machine_utilization=result.machine_utilization,
        n_dispatch_decisions=result.n_dispatch_decisions,
        runtime_ms=result.runtime_ms,
        optimality_gap_pct=gap,
        rule_choice_log=log,
    )
@app.post("/dispatch_event", response_model=DispatchEventResponse)
def dispatch_event(req: DispatchEventRequest) -> DispatchEventResponse:
    """Single-decision inference for live WMS integration (``POST /dispatch_event``).

    Only ``select`` is called on the chosen bandit; no reward/update is fed back here.

    Args:
        req: Request with ``context`` (exactly 32 finite numbers) and ``method``
            (``"linucb"``, ``"thompson"`` or ``"ensemble"``, case-insensitive).

    Returns:
        ``DispatchEventResponse`` with the chosen rule's name and its arm index.

    Raises:
        HTTPException: 400 when the context has the wrong length or contains
            NaN/Infinity, or when the method name is not recognised.
    """
    if len(req.context) != 32:
        raise HTTPException(400, f"context must be 32-dim, got {len(req.context)}")
    ctx = np.asarray(req.context, dtype=np.float64)
    # Python's JSON parser accepts the non-standard NaN/Infinity tokens; such values
    # would silently corrupt the bandit's arm scores, so reject them up front.
    if not np.all(np.isfinite(ctx)):
        raise HTTPException(400, "context must contain only finite values")

    method = req.method.lower()
    if method == "linucb":
        selector = _load_linucb()
    elif method == "thompson":
        selector = _load_thompson()
    elif method == "ensemble":
        # The ensemble's decision is delegated to its LinUCB member's select().
        selector = _load_ensemble().linucb
    else:
        raise HTTPException(400, f"Unknown method: {method!r}")

    arm = selector.select(ctx)
    # Resolve the name from the same bandit that produced the index so the two agree.
    return DispatchEventResponse(
        chosen_rule=selector.arms[arm].name,
        arm_index=int(arm),
    )
def main() -> None:  # pragma: no cover
    """Serve ``dash_jsp.api.server:app`` with uvicorn.

    Binds host ``0.0.0.0``. The port is read from ``DASH_JSP_PORT`` (default
    ``"8000"``) and auto-reload is disabled.
    """
    import uvicorn  # local import: only needed when actually serving

    listen_port = int(os.environ.get("DASH_JSP_PORT", "8000"))
    uvicorn.run(
        "dash_jsp.api.server:app",
        host="0.0.0.0",
        port=listen_port,
        reload=False,
    )
if __name__ == "__main__":  # pragma: no cover
    # Executing this file directly starts the server.
    main()