Spaces:
Running
Running
Commit ·
126110b
1
Parent(s): ca3950c
fix: default to DQN mode in inference.py to prevent 30min timeout - Switch default from llm to dqn (0.18s vs 30min+) - Add 25-minute watchdog safety net - Fix corrupted bytes in requirements.txt - LLM mode still available via --mode llm
Browse files- inference.py +113 -106
- requirements.txt +1 -0
inference.py
CHANGED
|
@@ -1,23 +1,28 @@
|
|
| 1 |
"""
|
| 2 |
OpenEnv baseline inference script.
|
| 3 |
|
| 4 |
-
Runs an
|
| 5 |
-
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
-
#
|
| 9 |
-
set OPENAI_API_KEY=sk-...
|
| 10 |
python inference.py
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
python inference.py
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
# Use
|
| 16 |
-
python inference.py --mode
|
| 17 |
|
| 18 |
Environment variables:
|
| 19 |
-
OPENAI_API_KEY —
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
from __future__ import annotations
|
|
@@ -25,19 +30,20 @@ from __future__ import annotations
|
|
| 25 |
import argparse
|
| 26 |
import json
|
| 27 |
import os
|
|
|
|
| 28 |
import sys
|
|
|
|
| 29 |
import time
|
| 30 |
from typing import Callable, Dict, Optional
|
| 31 |
|
| 32 |
import numpy as np
|
| 33 |
|
| 34 |
-
# ---
|
| 35 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://openrouter.ai/api/v1")
|
| 36 |
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:free")
|
| 37 |
-
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
|
| 38 |
-
# Optional - if you use from_docker_image():
|
| 39 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 40 |
-
#
|
| 41 |
|
| 42 |
from environment import BusRoutingEnv, Observation, Action
|
| 43 |
from tasks import TASKS, TaskConfig, get_task
|
|
@@ -45,7 +51,7 @@ from grader import grade_all_tasks, grade_task_1, grade_task_2, grade_task_3
|
|
| 45 |
|
| 46 |
|
| 47 |
# ---------------------------------------------------------------------------
|
| 48 |
-
#
|
| 49 |
# ---------------------------------------------------------------------------
|
| 50 |
|
| 51 |
def log_start(**kwargs):
|
|
@@ -56,15 +62,12 @@ def log_start(**kwargs):
|
|
| 56 |
|
| 57 |
def log_step(**kwargs):
|
| 58 |
"""Emit [STEP] log with key-value pairs."""
|
| 59 |
-
# Convert potential None or complex types to strings
|
| 60 |
vals = " ".join(f"{k}={v if v is not None else 'null'}" for k, v in kwargs.items())
|
| 61 |
print(f"[STEP] {vals}", flush=True)
|
| 62 |
|
| 63 |
|
| 64 |
def log_end(**kwargs):
|
| 65 |
"""Emit [END] log with key-value pairs."""
|
| 66 |
-
import json
|
| 67 |
-
# Special handling for rewards list to keep it as a JSON string in the log
|
| 68 |
payload = []
|
| 69 |
for k, v in kwargs.items():
|
| 70 |
if isinstance(v, (list, np.ndarray)):
|
|
@@ -77,84 +80,71 @@ def log_end(**kwargs):
|
|
| 77 |
|
| 78 |
|
| 79 |
# ---------------------------------------------------------------------------
|
| 80 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
# ---------------------------------------------------------------------------
|
| 82 |
|
| 83 |
class MockLLMAgent:
|
| 84 |
-
"""
|
| 85 |
-
A deterministic heuristic agent that mimics what a reasonable LLM
|
| 86 |
-
would output given the observation description. Used as a fallback
|
| 87 |
-
when ``OPENAI_API_KEY`` is not set.
|
| 88 |
-
"""
|
| 89 |
|
| 90 |
def __init__(self, seed: int = 42):
|
| 91 |
self.rng = np.random.default_rng(seed)
|
| 92 |
|
| 93 |
def __call__(self, obs: np.ndarray) -> int:
|
| 94 |
-
# obs = [pos, fuel, onboard, q0, q1, q2, time]
|
| 95 |
fuel = float(obs[1])
|
| 96 |
q0, q1, q2 = float(obs[3]), float(obs[4]), float(obs[5])
|
| 97 |
-
|
| 98 |
-
# If fuel is critically low, wait (cheapest action)
|
| 99 |
if fuel < 10.0:
|
| 100 |
return 2
|
| 101 |
-
|
| 102 |
-
# Serve the largest nearby queue
|
| 103 |
if q0 >= max(q1, q2) and q0 > 2:
|
| 104 |
-
return 2
|
| 105 |
if q1 >= q2:
|
| 106 |
-
return 0
|
| 107 |
-
return 0
|
| 108 |
|
| 109 |
|
| 110 |
# ---------------------------------------------------------------------------
|
| 111 |
-
# OpenAI LLM agent
|
| 112 |
# ---------------------------------------------------------------------------
|
| 113 |
|
| 114 |
class OpenAIAgent:
|
| 115 |
-
"""
|
| 116 |
-
Agent that queries the OpenAI Chat Completions API to decide actions.
|
| 117 |
-
|
| 118 |
-
The prompt describes the observation space, valid actions, and asks the
|
| 119 |
-
model to return a JSON object ``{"action": 0|1|2}``.
|
| 120 |
-
"""
|
| 121 |
|
| 122 |
SYSTEM_PROMPT = (
|
| 123 |
-
"
|
| 124 |
-
"
|
| 125 |
-
"
|
| 126 |
-
"
|
| 127 |
-
"queue_at_current_stop, queue_at_next_stop, queue_at_stop_after_next, "
|
| 128 |
-
"time_step]\n\n"
|
| 129 |
-
"ACTIONS:\n"
|
| 130 |
-
" 0 = move to next stop AND pick up passengers\n"
|
| 131 |
-
" 1 = move to next stop but SKIP pickup\n"
|
| 132 |
-
" 2 = wait at current stop AND pick up passengers\n\n"
|
| 133 |
-
"GOALS:\n"
|
| 134 |
-
" - Minimise passenger wait time\n"
|
| 135 |
-
" - Maximise passengers picked up\n"
|
| 136 |
-
" - Conserve fuel (moving costs 1.0, waiting costs 0.2)\n"
|
| 137 |
-
" - Visit all stops evenly (don't camp at one stop)\n\n"
|
| 138 |
-
"Respond ONLY with a JSON object: {\"action\": <0, 1, or 2>}"
|
| 139 |
)
|
| 140 |
|
| 141 |
-
def __init__(
|
| 142 |
-
self,
|
| 143 |
-
temperature: float = 0.0,
|
| 144 |
-
):
|
| 145 |
try:
|
| 146 |
from openai import OpenAI
|
| 147 |
except ImportError:
|
| 148 |
-
raise ImportError(
|
| 149 |
-
|
| 150 |
-
)
|
| 151 |
-
# All LLM calls use the OpenAI client configured via these variables
|
| 152 |
self.client = OpenAI(
|
| 153 |
base_url=API_BASE_URL,
|
| 154 |
api_key=HF_TOKEN,
|
| 155 |
)
|
| 156 |
self.model = MODEL_NAME
|
| 157 |
self.temperature = temperature
|
|
|
|
| 158 |
|
| 159 |
def __call__(self, obs: np.ndarray) -> int:
|
| 160 |
user_msg = (
|
|
@@ -170,6 +160,7 @@ class OpenAIAgent:
|
|
| 170 |
],
|
| 171 |
temperature=self.temperature,
|
| 172 |
max_tokens=20,
|
|
|
|
| 173 |
)
|
| 174 |
text = response.choices[0].message.content.strip()
|
| 175 |
data = json.loads(text)
|
|
@@ -178,71 +169,87 @@ class OpenAIAgent:
|
|
| 178 |
action = 0
|
| 179 |
return action
|
| 180 |
except Exception as e:
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
return 0
|
| 184 |
|
| 185 |
|
| 186 |
# ---------------------------------------------------------------------------
|
| 187 |
-
#
|
| 188 |
# ---------------------------------------------------------------------------
|
| 189 |
|
| 190 |
def build_agent(mode: str, model_path: Optional[str] = None) -> Callable[[np.ndarray], int]:
|
| 191 |
"""
|
| 192 |
-
Build the agent callable
|
| 193 |
|
| 194 |
Modes:
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
"""
|
| 199 |
if mode == "dqn":
|
| 200 |
from agent import DQNAgent
|
| 201 |
|
| 202 |
if model_path is None:
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
agent = DQNAgent.load(model_path)
|
| 208 |
return lambda obs: agent.act(obs, greedy=True)
|
| 209 |
|
| 210 |
if mode == "llm":
|
| 211 |
if HF_TOKEN or API_BASE_URL != "<your-active-api-url>":
|
| 212 |
-
print("[INFO] Using
|
| 213 |
return OpenAIAgent()
|
| 214 |
else:
|
| 215 |
-
print("[WARN]
|
| 216 |
return MockLLMAgent()
|
| 217 |
|
| 218 |
# Default: mock
|
| 219 |
-
print("[INFO] Using mock (heuristic) agent.")
|
| 220 |
return MockLLMAgent()
|
| 221 |
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
|
| 224 |
"""Run inference across all three tasks and return the grade report."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
agent = build_agent(mode, model_path)
|
| 226 |
-
|
| 227 |
-
print("
|
| 228 |
-
print(
|
| 229 |
-
print(f"
|
| 230 |
-
print(f"
|
| 231 |
-
print(f"
|
|
|
|
|
|
|
| 232 |
|
| 233 |
t0 = time.time()
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
# We run the report and log its high-level outcome in the END block
|
| 239 |
-
# Note: the sample script logs every step during a simulation,
|
| 240 |
-
# but since our grader runs multiple episodes, we will log the aggregate results.
|
| 241 |
report = grade_all_tasks(agent, episodes=episodes)
|
| 242 |
|
| 243 |
-
# Simplified step log for aggregate progress
|
| 244 |
log_step(step=episodes, action="evaluate_all", reward=report["aggregate_score"], done="true", error="null")
|
| 245 |
-
|
| 246 |
log_end(
|
| 247 |
success=bool(report["aggregate_score"] > 0.7),
|
| 248 |
steps=episodes,
|
|
@@ -255,20 +262,20 @@ def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
|
|
| 255 |
# Pretty print
|
| 256 |
for task_key in ("task_easy", "task_medium", "task_hard"):
|
| 257 |
tr = report[task_key]
|
| 258 |
-
print(f"{'-' * 55}")
|
| 259 |
-
print(f" {tr['task']} ({tr['difficulty']}) -> score: {tr['score']:.4f}")
|
| 260 |
-
print(f"{'-' * 55}")
|
| 261 |
for section in ("rl_agent", "baseline_greedy"):
|
| 262 |
-
print(f" [{section}]")
|
| 263 |
for k, v in tr[section].items():
|
| 264 |
-
print(f" {k}: {v:.4f}")
|
| 265 |
-
print()
|
| 266 |
|
| 267 |
-
print(f"{'=' * 55}")
|
| 268 |
-
print(f" AGGREGATE SCORE : {report['aggregate_score']:.4f}")
|
| 269 |
-
print(f" Task weights : {report['weights']}")
|
| 270 |
-
print(f" Time elapsed : {elapsed:.2f}s")
|
| 271 |
-
print(f"{'=' * 55}")
|
| 272 |
|
| 273 |
return report
|
| 274 |
|
|
@@ -284,8 +291,8 @@ def main() -> None:
|
|
| 284 |
p.add_argument(
|
| 285 |
"--mode",
|
| 286 |
choices=["llm", "mock", "dqn"],
|
| 287 |
-
default="
|
| 288 |
-
help="Agent mode: '
|
| 289 |
)
|
| 290 |
p.add_argument(
|
| 291 |
"--model-path",
|
|
@@ -296,7 +303,7 @@ def main() -> None:
|
|
| 296 |
p.add_argument(
|
| 297 |
"--episodes",
|
| 298 |
type=int,
|
| 299 |
-
default=
|
| 300 |
help="Number of evaluation episodes per task.",
|
| 301 |
)
|
| 302 |
args = p.parse_args()
|
|
|
|
| 1 |
"""
|
| 2 |
OpenEnv baseline inference script.
|
| 3 |
|
| 4 |
+
Runs an agent on all three task difficulty tiers and prints reproducible
|
| 5 |
+
scores with structured logging.
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
+
# Default: use pre-trained DQN model (completes in ~30 seconds):
|
|
|
|
| 9 |
python inference.py
|
| 10 |
|
| 11 |
+
# Explicitly use DQN with a specific checkpoint:
|
| 12 |
+
python inference.py --mode dqn --model-path models/dqn_bus_v6_best.pt
|
| 13 |
+
|
| 14 |
+
# Use LLM via API (requires API key, slower):
|
| 15 |
+
python inference.py --mode llm
|
| 16 |
|
| 17 |
+
# Use deterministic mock heuristic:
|
| 18 |
+
python inference.py --mode mock
|
| 19 |
|
| 20 |
Environment variables:
|
| 21 |
+
OPENAI_API_KEY — API key for LLM mode (optional)
|
| 22 |
+
MODEL_NAME — LLM model name (default: openai/gpt-oss-120b:free)
|
| 23 |
+
API_BASE_URL — API endpoint (default: https://openrouter.ai/api/v1)
|
| 24 |
+
MAX_EVAL_EPISODES — Episodes per task (default: 2)
|
| 25 |
+
EVAL_TIMEOUT — Global timeout in seconds (default: 1500 = 25 min)
|
| 26 |
"""
|
| 27 |
|
| 28 |
from __future__ import annotations
|
|
|
|
| 30 |
import argparse
|
| 31 |
import json
|
| 32 |
import os
|
| 33 |
+
import signal
|
| 34 |
import sys
|
| 35 |
+
import threading
|
| 36 |
import time
|
| 37 |
from typing import Callable, Dict, Optional
|
| 38 |
|
| 39 |
import numpy as np
|
| 40 |
|
| 41 |
+
# --- Configuration ---
|
| 42 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://openrouter.ai/api/v1")
|
| 43 |
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b:free")
|
| 44 |
+
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY")
|
|
|
|
| 45 |
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
|
| 46 |
+
GLOBAL_TIMEOUT = int(os.getenv("EVAL_TIMEOUT", "1500")) # 25 minutes
|
| 47 |
|
| 48 |
from environment import BusRoutingEnv, Observation, Action
|
| 49 |
from tasks import TASKS, TaskConfig, get_task
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
# ---------------------------------------------------------------------------
|
| 54 |
+
# Structured Logging (Mandatory Hackathon Requirement)
|
| 55 |
# ---------------------------------------------------------------------------
|
| 56 |
|
| 57 |
def log_start(**kwargs):
|
|
|
|
| 62 |
|
| 63 |
def log_step(**kwargs):
|
| 64 |
"""Emit [STEP] log with key-value pairs."""
|
|
|
|
| 65 |
vals = " ".join(f"{k}={v if v is not None else 'null'}" for k, v in kwargs.items())
|
| 66 |
print(f"[STEP] {vals}", flush=True)
|
| 67 |
|
| 68 |
|
| 69 |
def log_end(**kwargs):
|
| 70 |
"""Emit [END] log with key-value pairs."""
|
|
|
|
|
|
|
| 71 |
payload = []
|
| 72 |
for k, v in kwargs.items():
|
| 73 |
if isinstance(v, (list, np.ndarray)):
|
|
|
|
| 80 |
|
| 81 |
|
| 82 |
# ---------------------------------------------------------------------------
|
| 83 |
+
# Watchdog timer — kills process if evaluation exceeds global timeout
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
|
| 86 |
+
def _start_watchdog(timeout_seconds: int) -> None:
|
| 87 |
+
"""Start a background thread that kills the process after timeout."""
|
| 88 |
+
def _watchdog():
|
| 89 |
+
time.sleep(timeout_seconds)
|
| 90 |
+
print(f"\n[TIMEOUT] Global timeout of {timeout_seconds}s reached. Exiting.", flush=True)
|
| 91 |
+
log_end(success=False, steps=0, score=0.0, rewards=[0, 0, 0], reason="global_timeout")
|
| 92 |
+
os._exit(1)
|
| 93 |
+
|
| 94 |
+
t = threading.Thread(target=_watchdog, daemon=True)
|
| 95 |
+
t.start()
|
| 96 |
+
print(f"[INFO] Watchdog armed: {timeout_seconds}s global deadline.", flush=True)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# ---------------------------------------------------------------------------
|
| 100 |
+
# Mock LLM agent (deterministic fallback)
|
| 101 |
# ---------------------------------------------------------------------------
|
| 102 |
|
| 103 |
class MockLLMAgent:
|
| 104 |
+
"""Deterministic heuristic agent — fallback when API is unavailable."""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
def __init__(self, seed: int = 42):
|
| 107 |
self.rng = np.random.default_rng(seed)
|
| 108 |
|
| 109 |
def __call__(self, obs: np.ndarray) -> int:
|
|
|
|
| 110 |
fuel = float(obs[1])
|
| 111 |
q0, q1, q2 = float(obs[3]), float(obs[4]), float(obs[5])
|
|
|
|
|
|
|
| 112 |
if fuel < 10.0:
|
| 113 |
return 2
|
|
|
|
|
|
|
| 114 |
if q0 >= max(q1, q2) and q0 > 2:
|
| 115 |
+
return 2
|
| 116 |
if q1 >= q2:
|
| 117 |
+
return 0
|
| 118 |
+
return 0
|
| 119 |
|
| 120 |
|
| 121 |
# ---------------------------------------------------------------------------
|
| 122 |
+
# OpenAI LLM agent (with strict per-call timeout)
|
| 123 |
# ---------------------------------------------------------------------------
|
| 124 |
|
| 125 |
class OpenAIAgent:
|
| 126 |
+
"""Agent that queries an LLM API — used only when --mode llm is explicit."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
|
| 128 |
SYSTEM_PROMPT = (
|
| 129 |
+
"RL bus agent. Obs: [pos (0-11), fuel (0-100), pax_onboard, q_curr, q_next, q_after, step].\n"
|
| 130 |
+
"Actions: 0=move+pickup, 1=move+skip, 2=wait+pickup.\n"
|
| 131 |
+
"Goals: Max pickups, min wait, save fuel.\n"
|
| 132 |
+
"Respond ONLY: {\"action\": 0|1|2}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
)
|
| 134 |
|
| 135 |
+
def __init__(self, temperature: float = 0.0):
|
|
|
|
|
|
|
|
|
|
| 136 |
try:
|
| 137 |
from openai import OpenAI
|
| 138 |
except ImportError:
|
| 139 |
+
raise ImportError("openai package not installed. Run: pip install openai")
|
| 140 |
+
|
|
|
|
|
|
|
| 141 |
self.client = OpenAI(
|
| 142 |
base_url=API_BASE_URL,
|
| 143 |
api_key=HF_TOKEN,
|
| 144 |
)
|
| 145 |
self.model = MODEL_NAME
|
| 146 |
self.temperature = temperature
|
| 147 |
+
self._fallback = MockLLMAgent()
|
| 148 |
|
| 149 |
def __call__(self, obs: np.ndarray) -> int:
|
| 150 |
user_msg = (
|
|
|
|
| 160 |
],
|
| 161 |
temperature=self.temperature,
|
| 162 |
max_tokens=20,
|
| 163 |
+
timeout=8.0, # Strict 8s timeout per call
|
| 164 |
)
|
| 165 |
text = response.choices[0].message.content.strip()
|
| 166 |
data = json.loads(text)
|
|
|
|
| 169 |
action = 0
|
| 170 |
return action
|
| 171 |
except Exception as e:
|
| 172 |
+
print(f"[WARN] LLM call failed ({type(e).__name__}), using heuristic fallback", flush=True)
|
| 173 |
+
return self._fallback(obs)
|
|
|
|
| 174 |
|
| 175 |
|
| 176 |
# ---------------------------------------------------------------------------
|
| 177 |
+
# Agent builder
|
| 178 |
# ---------------------------------------------------------------------------
|
| 179 |
|
| 180 |
def build_agent(mode: str, model_path: Optional[str] = None) -> Callable[[np.ndarray], int]:
|
| 181 |
"""
|
| 182 |
+
Build the agent callable.
|
| 183 |
|
| 184 |
Modes:
|
| 185 |
+
dqn — Pre-trained DQN checkpoint (DEFAULT — fast, local, reliable)
|
| 186 |
+
llm — OpenAI-compatible API
|
| 187 |
+
mock — Deterministic heuristic
|
| 188 |
"""
|
| 189 |
if mode == "dqn":
|
| 190 |
from agent import DQNAgent
|
| 191 |
|
| 192 |
if model_path is None:
|
| 193 |
+
# Try multiple known model paths
|
| 194 |
+
candidates = [
|
| 195 |
+
"models/dqn_bus_v6_best.pt",
|
| 196 |
+
"models/dqn_bus_v6.pt",
|
| 197 |
+
"models/dqn_bus.pt",
|
| 198 |
+
]
|
| 199 |
+
for candidate in candidates:
|
| 200 |
+
if os.path.isfile(candidate):
|
| 201 |
+
model_path = candidate
|
| 202 |
+
break
|
| 203 |
+
|
| 204 |
+
if model_path is None or not os.path.isfile(model_path):
|
| 205 |
+
print(f"[WARN] No DQN model found. Falling back to mock agent.", flush=True)
|
| 206 |
+
return MockLLMAgent()
|
| 207 |
+
|
| 208 |
+
print(f"[INFO] Loading DQN model from '{model_path}'", flush=True)
|
| 209 |
agent = DQNAgent.load(model_path)
|
| 210 |
return lambda obs: agent.act(obs, greedy=True)
|
| 211 |
|
| 212 |
if mode == "llm":
|
| 213 |
if HF_TOKEN or API_BASE_URL != "<your-active-api-url>":
|
| 214 |
+
print("[INFO] Using LLM API agent.", flush=True)
|
| 215 |
return OpenAIAgent()
|
| 216 |
else:
|
| 217 |
+
print("[WARN] No API key set — falling back to mock agent.", flush=True)
|
| 218 |
return MockLLMAgent()
|
| 219 |
|
| 220 |
# Default: mock
|
| 221 |
+
print("[INFO] Using mock (heuristic) agent.", flush=True)
|
| 222 |
return MockLLMAgent()
|
| 223 |
|
| 224 |
|
| 225 |
+
# ---------------------------------------------------------------------------
|
| 226 |
+
# Inference runner
|
| 227 |
+
# ---------------------------------------------------------------------------
|
| 228 |
+
|
| 229 |
def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
|
| 230 |
"""Run inference across all three tasks and return the grade report."""
|
| 231 |
+
|
| 232 |
+
# Start the watchdog timer
|
| 233 |
+
_start_watchdog(GLOBAL_TIMEOUT)
|
| 234 |
+
|
| 235 |
agent = build_agent(mode, model_path)
|
| 236 |
+
|
| 237 |
+
print(f"\n{'=' * 60}", flush=True)
|
| 238 |
+
print(" OpenEnv Bus Routing - Inference", flush=True)
|
| 239 |
+
print(f"{'=' * 60}", flush=True)
|
| 240 |
+
print(f" Mode : {mode}", flush=True)
|
| 241 |
+
print(f" Episodes : {episodes}", flush=True)
|
| 242 |
+
print(f" Timeout : {GLOBAL_TIMEOUT}s", flush=True)
|
| 243 |
+
print(f"{'=' * 60}\n", flush=True)
|
| 244 |
|
| 245 |
t0 = time.time()
|
| 246 |
|
| 247 |
+
log_start(task=mode, env="rl-bus-optimization", model=MODEL_NAME if mode == "llm" else f"dqn-local")
|
| 248 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
report = grade_all_tasks(agent, episodes=episodes)
|
| 250 |
|
|
|
|
| 251 |
log_step(step=episodes, action="evaluate_all", reward=report["aggregate_score"], done="true", error="null")
|
| 252 |
+
|
| 253 |
log_end(
|
| 254 |
success=bool(report["aggregate_score"] > 0.7),
|
| 255 |
steps=episodes,
|
|
|
|
| 262 |
# Pretty print
|
| 263 |
for task_key in ("task_easy", "task_medium", "task_hard"):
|
| 264 |
tr = report[task_key]
|
| 265 |
+
print(f"{'-' * 55}", flush=True)
|
| 266 |
+
print(f" {tr['task']} ({tr['difficulty']}) -> score: {tr['score']:.4f}", flush=True)
|
| 267 |
+
print(f"{'-' * 55}", flush=True)
|
| 268 |
for section in ("rl_agent", "baseline_greedy"):
|
| 269 |
+
print(f" [{section}]", flush=True)
|
| 270 |
for k, v in tr[section].items():
|
| 271 |
+
print(f" {k}: {v:.4f}", flush=True)
|
| 272 |
+
print(flush=True)
|
| 273 |
|
| 274 |
+
print(f"{'=' * 55}", flush=True)
|
| 275 |
+
print(f" AGGREGATE SCORE : {report['aggregate_score']:.4f}", flush=True)
|
| 276 |
+
print(f" Task weights : {report['weights']}", flush=True)
|
| 277 |
+
print(f" Time elapsed : {elapsed:.2f}s", flush=True)
|
| 278 |
+
print(f"{'=' * 55}", flush=True)
|
| 279 |
|
| 280 |
return report
|
| 281 |
|
|
|
|
| 291 |
p.add_argument(
|
| 292 |
"--mode",
|
| 293 |
choices=["llm", "mock", "dqn"],
|
| 294 |
+
default="dqn", # DEFAULT: DQN — fast, local, no API needed
|
| 295 |
+
help="Agent mode: 'dqn' (pre-trained model, DEFAULT), 'llm' (API), or 'mock' (heuristic).",
|
| 296 |
)
|
| 297 |
p.add_argument(
|
| 298 |
"--model-path",
|
|
|
|
| 303 |
p.add_argument(
|
| 304 |
"--episodes",
|
| 305 |
type=int,
|
| 306 |
+
default=int(os.getenv("MAX_EVAL_EPISODES", 2)),
|
| 307 |
help="Number of evaluation episodes per task.",
|
| 308 |
)
|
| 309 |
args = p.parse_args()
|
requirements.txt
CHANGED
|
@@ -9,3 +9,4 @@ pandas>=2.0
|
|
| 9 |
uvicorn>=0.20.0
|
| 10 |
openenv-core>=0.2.0
|
| 11 |
huggingface-hub>=0.20.0
|
|
|
|
|
|
| 9 |
uvicorn>=0.20.0
|
| 10 |
openenv-core>=0.2.0
|
| 11 |
huggingface-hub>=0.20.0
|
| 12 |
+
python-dotenv
|