# Author: Vittal-M
# Last commit 52c82e4: "Trainer Space: download -> train -> push -> sleep"
"""Command-line entry point: `dash-jsp <command>`."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import dash_jsp
def _load_bandit_dispatcher(dispatcher_cls, models_dir, artifact_name):
    """Build a bandit dispatcher, restoring saved state when an artifact exists.

    Args:
        dispatcher_cls: Dispatcher class exposing a ``load(path)`` classmethod
            and a no-argument constructor.
        models_dir: Directory that may hold saved model artifacts, or a
            falsy value (``None`` / ``""``) to skip loading.
        artifact_name: File name of the artifact inside ``models_dir``.

    Returns:
        ``dispatcher_cls.load(<models_dir>/<artifact_name>)`` when that file
        exists, otherwise a freshly constructed ``dispatcher_cls()``.
    """
    if models_dir:
        artifact = Path(models_dir) / artifact_name
        if artifact.exists():
            return dispatcher_cls.load(str(artifact))
    return dispatcher_cls()


def _cmd_solve(args: argparse.Namespace) -> int:
    """Solve a single instance file with the chosen method.

    Parses ``args.instance`` in ``args.format``, schedules it with either a
    dispatching rule from ``ALL_RULES`` or a bandit dispatcher ("linucb" /
    "thompson"), and prints a JSON object of schedule metrics to stdout.

    Returns:
        0 on success; 2 if ``args.method`` names no known rule or bandit.

    Raises:
        ValueError: if ``args.format`` is neither "taillard" nor "lawrence"
            (the argparse ``choices`` normally prevent this).
    """
    from dash_jsp.benchmarks.taillard import parse_taillard_text
    from dash_jsp.benchmarks.lawrence import parse_lawrence_text
    from dash_jsp.heuristics.rules import ALL_RULES, get_rule
    from dash_jsp.simulator.jsp_sim import simulate, simulate_with_dispatcher

    # Validate the method before any file I/O so a typo fails fast with a
    # clean message and exit code instead of after (possibly failing) parsing.
    method = args.method.lower()
    if method not in ALL_RULES and method not in ("linucb", "thompson"):
        print(f"Unknown method: {method}", file=sys.stderr)
        return 2

    instance_path = Path(args.instance)
    # Explicit encoding so decoding does not depend on the platform locale.
    text = instance_path.read_text(encoding="utf-8")
    name = instance_path.stem
    if args.format == "taillard":
        inst = parse_taillard_text(text, name)
    elif args.format == "lawrence":
        inst = parse_lawrence_text(text, name)
    else:
        raise ValueError(f"Unknown --format: {args.format}")

    if method in ALL_RULES:
        result = simulate(inst, get_rule(method))
    elif method == "linucb":
        from dash_jsp.bandit.linucb import LinUCBDispatcher
        dispatcher = _load_bandit_dispatcher(
            LinUCBDispatcher, args.models_dir, "bandit_linucb.npz"
        )
        result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)
    else:  # method == "thompson" (guaranteed by the validation above)
        from dash_jsp.bandit.thompson import ThompsonDispatcher
        dispatcher = _load_bandit_dispatcher(
            ThompsonDispatcher, args.models_dir, "bandit_thompson.npz"
        )
        result = simulate_with_dispatcher(inst, dispatcher, record_choices=True)

    out = {
        "instance": inst.name,
        "method": method,
        "makespan": result.makespan,
        "total_tardiness": result.total_tardiness,
        "avg_cycle_time": result.avg_cycle_time,
        "machine_utilization": result.machine_utilization,
        "n_dispatch_decisions": result.n_dispatch_decisions,
        "runtime_ms": result.runtime_ms,
        "optimum": inst.optimum,
        # A falsy optimum (None when unknown, or 0) yields no gap rather than
        # a division by zero.
        "optimality_gap_pct": (
            100.0 * (result.makespan - inst.optimum) / inst.optimum
            if inst.optimum else None
        ),
    }
    print(json.dumps(out, indent=2))
    return 0
def _cmd_serve(args: argparse.Namespace) -> int:  # pragma: no cover
    """Run the DASH-JSP FastAPI app under uvicorn on ``args.host``:``args.port``.

    Blocks until the server stops, then returns exit code 0. Auto-reload is
    always disabled.
    """
    import uvicorn

    app_target = "dash_jsp.api.server:app"
    uvicorn.run(app_target, host=args.host, port=args.port, reload=False)
    return 0
def _cmd_version(_args: argparse.Namespace) -> int:
    """Print ``dash-jsp <version>`` to stdout and return exit code 0."""
    installed = dash_jsp.__version__
    print(f"dash-jsp {installed}")
    return 0
def build_parser() -> argparse.ArgumentParser:
    """Construct the top-level ``dash-jsp`` argument parser.

    A subcommand is mandatory. Each subcommand stores its handler function
    as ``fn`` on the parsed namespace, which :func:`main` then invokes.
    """
    root = argparse.ArgumentParser(
        prog="dash-jsp",
        description="DASH-JSP: bandit-based dynamic heuristic switching for JSP",
    )
    commands = root.add_subparsers(dest="cmd", required=True)

    # `solve`: schedule one instance file with a rule or bandit policy.
    solve = commands.add_parser("solve", help="Solve a JSP instance file")
    solve.add_argument("instance", help="Path to instance .txt file", type=str)
    solve.add_argument(
        "--format",
        choices=["taillard", "lawrence"],
        default="taillard",
        help="Instance format",
    )
    # Not restricted via `choices`: the handler lower-cases and validates it.
    solve.add_argument(
        "--method",
        help="fifo|edd|cr|atc|wspt|slack|linucb|thompson",
        default="linucb",
    )
    solve.add_argument(
        "--models-dir",
        help="Directory containing bandit model artifacts",
        default=None,
    )
    solve.set_defaults(fn=_cmd_solve)

    # `serve`: start the REST API.
    serve = commands.add_parser("serve", help="Launch the REST API server")
    serve.add_argument("--host", default="0.0.0.0")
    serve.add_argument("--port", type=int, default=8000)
    serve.set_defaults(fn=_cmd_serve)

    # `version`: print the package version.
    commands.add_parser("version", help="Print version").set_defaults(
        fn=_cmd_version
    )

    return root
def main(argv=None) -> int:
    """Parse *argv* and run the selected subcommand.

    ``argv=None`` makes argparse read ``sys.argv[1:]``. Returns whatever exit
    code the subcommand handler returns.
    """
    namespace = build_parser().parse_args(argv)
    handler = namespace.fn
    return handler(namespace)
if __name__ == "__main__":  # pragma: no cover
    # Use the subcommand's return value as the process exit status.
    sys.exit(main())