"""FastAPI server exposing DASH-JSP inference. Endpoints --------- POST /solve — solve a complete JSP instance (returns schedule + metrics) POST /dispatch_event — single-decision inference (stateless WMS integration) GET /health — liveness probe GET /version — version + arm names """ from __future__ import annotations import os from pathlib import Path from typing import Optional import numpy as np from fastapi import FastAPI, HTTPException import dash_jsp from dash_jsp.api.schemas import ( DispatchEventRequest, DispatchEventResponse, JSPInstanceSchema, SolveRequest, SolveResponse, ) from dash_jsp.bandit.linucb import LinUCBDispatcher from dash_jsp.bandit.thompson import ThompsonDispatcher from dash_jsp.bandit.ensemble import EnsembleDispatcher from dash_jsp.benchmarks.format import JSPInstance from dash_jsp.heuristics.rules import ALL_RULES, get_rule from dash_jsp.simulator.jsp_sim import ( simulate, simulate_with_dispatcher, ) # --------------------------------------------------------------------------- # Lazy model loading # --------------------------------------------------------------------------- _MODELS: dict = {} def _models_dir() -> Path: env = os.environ.get("DASH_JSP_MODELS_DIR") if env: return Path(env) return Path(__file__).resolve().parent.parent.parent / "models" def _load_linucb() -> LinUCBDispatcher: if "linucb" not in _MODELS: path = _models_dir() / "bandit_linucb.npz" if path.exists(): _MODELS["linucb"] = LinUCBDispatcher.load(str(path)) else: # Fall back to a fresh, untrained bandit (still usable but will explore) _MODELS["linucb"] = LinUCBDispatcher() return _MODELS["linucb"] def _load_thompson() -> ThompsonDispatcher: if "thompson" not in _MODELS: path = _models_dir() / "bandit_thompson.npz" if path.exists(): _MODELS["thompson"] = ThompsonDispatcher.load(str(path)) else: _MODELS["thompson"] = ThompsonDispatcher() return _MODELS["thompson"] def _load_ensemble() -> EnsembleDispatcher: if "ensemble" not in _MODELS: _MODELS["ensemble"] = EnsembleDispatcher( linucb=_load_linucb(), thompson=_load_thompson(), ) return _MODELS["ensemble"] # --------------------------------------------------------------------------- # FastAPI app # --------------------------------------------------------------------------- app = FastAPI( title="DASH-JSP", version=dash_jsp.__version__, description=( "Bandit-based dynamic dispatch-rule selection for job-shop scheduling. " "Public benchmarks (Taillard / Lawrence / Brandimarte / DMU) supported " "out of the box." ), ) def _instance_from_schema(s: JSPInstanceSchema) -> JSPInstance: return JSPInstance( name=s.name, family="api", n_jobs=s.n_jobs, n_machines=s.n_machines, ops=[[(m, p) for m, p in row] for row in s.ops], optimum=s.optimum, due_dates=s.due_dates, weights=s.weights, ) @app.get("/health") def health() -> dict: return {"status": "ok"} @app.get("/version") def version() -> dict: bandit = _load_linucb() return { "dash_jsp_version": dash_jsp.__version__, "linucb_arms": [a.name for a in bandit.arms], "linucb_pulls": bandit.n_pulls, } @app.post("/solve", response_model=SolveResponse) def solve(req: SolveRequest) -> SolveResponse: inst = _instance_from_schema(req.instance) method = req.method.lower() if method in ALL_RULES: result = simulate(inst, get_rule(method)) log = None elif method == "linucb": dispatcher = _load_linucb() result = simulate_with_dispatcher(inst, dispatcher, record_choices=True) log = result.rule_choice_log elif method == "thompson": dispatcher = _load_thompson() result = simulate_with_dispatcher(inst, dispatcher, record_choices=True) log = result.rule_choice_log elif method == "ensemble": dispatcher = _load_ensemble() result = simulate_with_dispatcher(inst, dispatcher, record_choices=True) log = result.rule_choice_log else: raise HTTPException(400, f"Unknown method: {method!r}") gap = ( 100.0 * (result.makespan - inst.optimum) / inst.optimum if inst.optimum else None ) return SolveResponse( method=method, makespan=result.makespan, total_tardiness=result.total_tardiness, avg_cycle_time=result.avg_cycle_time, machine_utilization=result.machine_utilization, n_dispatch_decisions=result.n_dispatch_decisions, runtime_ms=result.runtime_ms, optimality_gap_pct=gap, rule_choice_log=log, ) @app.post("/dispatch_event", response_model=DispatchEventResponse) def dispatch_event(req: DispatchEventRequest) -> DispatchEventResponse: """Single-decision inference for live WMS integration.""" if len(req.context) != 32: raise HTTPException(400, f"context must be 32-dim, got {len(req.context)}") ctx = np.asarray(req.context, dtype=np.float64) method = req.method.lower() if method == "linucb": b = _load_linucb() arm = b.select(ctx) elif method == "thompson": b = _load_thompson() arm = b.select(ctx) elif method == "ensemble": b = _load_ensemble() arm = b.linucb.select(ctx) # majority via LinUCB UCB margin else: raise HTTPException(400, f"Unknown method: {method!r}") return DispatchEventResponse( chosen_rule=b.arms[arm].name if hasattr(b, "arms") else b.linucb.arms[arm].name, arm_index=int(arm), ) def main() -> None: # pragma: no cover import uvicorn uvicorn.run( "dash_jsp.api.server:app", host="0.0.0.0", port=int(os.environ.get("DASH_JSP_PORT", "8000")), reload=False, ) if __name__ == "__main__": # pragma: no cover main()