# Author: Vittal-M
# Commit 52c82e4: "Trainer Space: download -> train -> push -> sleep"
"""FastAPI server exposing DASH-JSP inference.
Endpoints
---------
POST /solve — solve a complete JSP instance (returns schedule + metrics)
POST /dispatch_event — single-decision inference (stateless WMS integration)
GET /health — liveness probe
GET /version — version + arm names
"""
from __future__ import annotations
import os
from pathlib import Path
from typing import Optional
import numpy as np
from fastapi import FastAPI, HTTPException
import dash_jsp
from dash_jsp.api.schemas import (
DispatchEventRequest,
DispatchEventResponse,
JSPInstanceSchema,
SolveRequest,
SolveResponse,
)
from dash_jsp.bandit.linucb import LinUCBDispatcher
from dash_jsp.bandit.thompson import ThompsonDispatcher
from dash_jsp.bandit.ensemble import EnsembleDispatcher
from dash_jsp.benchmarks.format import JSPInstance
from dash_jsp.heuristics.rules import ALL_RULES, get_rule
from dash_jsp.simulator.jsp_sim import (
simulate,
simulate_with_dispatcher,
)
# ---------------------------------------------------------------------------
# Lazy model loading
# ---------------------------------------------------------------------------
# Process-wide cache of lazily constructed dispatchers, keyed by
# "linucb", "thompson" and "ensemble" (filled in by the _load_* helpers).
_MODELS: dict = dict()
def _models_dir() -> Path:
env = os.environ.get("DASH_JSP_MODELS_DIR")
if env:
return Path(env)
return Path(__file__).resolve().parent.parent.parent / "models"
def _load_linucb() -> LinUCBDispatcher:
if "linucb" not in _MODELS:
path = _models_dir() / "bandit_linucb.npz"
if path.exists():
_MODELS["linucb"] = LinUCBDispatcher.load(str(path))
else:
# Fall back to a fresh, untrained bandit (still usable but will explore)
_MODELS["linucb"] = LinUCBDispatcher()
return _MODELS["linucb"]
def _load_thompson() -> ThompsonDispatcher:
if "thompson" not in _MODELS:
path = _models_dir() / "bandit_thompson.npz"
if path.exists():
_MODELS["thompson"] = ThompsonDispatcher.load(str(path))
else:
_MODELS["thompson"] = ThompsonDispatcher()
return _MODELS["thompson"]
def _load_ensemble() -> EnsembleDispatcher:
if "ensemble" not in _MODELS:
_MODELS["ensemble"] = EnsembleDispatcher(
linucb=_load_linucb(),
thompson=_load_thompson(),
)
return _MODELS["ensemble"]
# ---------------------------------------------------------------------------
# FastAPI app
# ---------------------------------------------------------------------------
app = FastAPI(
title="DASH-JSP",
version=dash_jsp.__version__,
description=(
"Bandit-based dynamic dispatch-rule selection for job-shop scheduling. "
"Public benchmarks (Taillard / Lawrence / Brandimarte / DMU) supported "
"out of the box."
),
)
def _instance_from_schema(s: JSPInstanceSchema) -> JSPInstance:
return JSPInstance(
name=s.name,
family="api",
n_jobs=s.n_jobs,
n_machines=s.n_machines,
ops=[[(m, p) for m, p in row] for row in s.ops],
optimum=s.optimum,
due_dates=s.due_dates,
weights=s.weights,
)
@app.get("/health")
def health() -> dict:
return {"status": "ok"}
@app.get("/version")
def version() -> dict:
bandit = _load_linucb()
return {
"dash_jsp_version": dash_jsp.__version__,
"linucb_arms": [a.name for a in bandit.arms],
"linucb_pulls": bandit.n_pulls,
}
@app.post("/solve", response_model=SolveResponse)
def solve(req: SolveRequest) -> SolveResponse:
inst = _instance_from_schema(req.instance)
method = req.method.lower()
if method in ALL_RULES:
result = simulate(inst, get_rule(method))
log = None
elif method == "linucb":
dispatcher = _load_linucb()
result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)
log = result.rule_choice_log
elif method == "thompson":
dispatcher = _load_thompson()
result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)
log = result.rule_choice_log
elif method == "ensemble":
dispatcher = _load_ensemble()
result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)
log = result.rule_choice_log
else:
raise HTTPException(400, f"Unknown method: {method!r}")
gap = (
100.0 * (result.makespan - inst.optimum) / inst.optimum
if inst.optimum
else None
)
return SolveResponse(
method=method,
makespan=result.makespan,
total_tardiness=result.total_tardiness,
avg_cycle_time=result.avg_cycle_time,
machine_utilization=result.machine_utilization,
n_dispatch_decisions=result.n_dispatch_decisions,
runtime_ms=result.runtime_ms,
optimality_gap_pct=gap,
rule_choice_log=log,
)
@app.post("/dispatch_event", response_model=DispatchEventResponse)
def dispatch_event(req: DispatchEventRequest) -> DispatchEventResponse:
    """Single-decision inference for live WMS integration.

    Picks one dispatch rule for a single 32-dimensional context vector using
    the requested bandit, without running a simulation.

    Raises:
        HTTPException(400): the context is not 32-dimensional, it contains a
            non-finite value (NaN/Inf), or ``req.method`` (lower-cased) is not
            ``linucb``, ``thompson`` or ``ensemble``.
    """
    if len(req.context) != 32:
        raise HTTPException(400, f"context must be 32-dim, got {len(req.context)}")
    ctx = np.asarray(req.context, dtype=np.float64)
    # JSON parsing and pydantic floats can let NaN/Infinity through. Reject them
    # here instead of letting them silently skew the bandit's scores.
    if not np.isfinite(ctx).all():
        raise HTTPException(400, "context must contain only finite values")

    method = req.method.lower()
    if method == "linucb":
        selector = _load_linucb()
    elif method == "thompson":
        selector = _load_thompson()
    elif method == "ensemble":
        # NOTE(review): only the ensemble's LinUCB member is consulted here (the
        # original comment described this as "majority via LinUCB UCB margin").
        # /solve, by contrast, runs the full EnsembleDispatcher. Confirm intended.
        selector = _load_ensemble().linucb
    else:
        raise HTTPException(400, f"Unknown method: {method!r}")

    arm = selector.select(ctx)
    # Look the rule name up on the same bandit that produced the index. The
    # index is only meaningful relative to that bandit's own arm list.
    return DispatchEventResponse(
        chosen_rule=selector.arms[arm].name,
        arm_index=int(arm),
    )
def main() -> None: # pragma: no cover
import uvicorn
uvicorn.run(
"dash_jsp.api.server:app",
host="0.0.0.0",
port=int(os.environ.get("DASH_JSP_PORT", "8000")),
reload=False,
)
if __name__ == "__main__": # pragma: no cover
main()