Spaces:
Running
Running
Commit ·
30bf3bb
1
Parent(s): 61203a1
fix: move info prints to stderr and use comma-separated rewards in [END] tag for validator compliance
Browse files- inference.py +19 -13
inference.py
CHANGED
|
@@ -79,8 +79,9 @@ def log_end(**kwargs):
|
|
| 79 |
"""Emit [END] log with key-value pairs."""
|
| 80 |
payload = []
|
| 81 |
for k, v in kwargs.items():
|
| 82 |
-
if isinstance(v, (list, np.ndarray)):
|
| 83 |
-
|
|
|
|
| 84 |
else:
|
| 85 |
v_str = str(v)
|
| 86 |
payload.append(f"{k}={v_str}")
|
|
@@ -97,7 +98,7 @@ def _start_watchdog(timeout_seconds: int) -> None:
|
|
| 97 |
def _watchdog():
|
| 98 |
time.sleep(timeout_seconds)
|
| 99 |
print(f"\n[TIMEOUT] Global timeout of {timeout_seconds}s reached. Exiting.", flush=True)
|
| 100 |
-
log_end(success=
|
| 101 |
os._exit(1)
|
| 102 |
|
| 103 |
t = threading.Thread(target=_watchdog, daemon=True)
|
|
@@ -243,20 +244,26 @@ def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
|
|
| 243 |
|
| 244 |
agent = build_agent(mode, model_path)
|
| 245 |
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
| 251 |
-
|
| 252 |
-
|
| 253 |
|
| 254 |
t0 = time.time()
|
| 255 |
|
| 256 |
all_rewards = []
|
| 257 |
total_steps = 0
|
| 258 |
results = {}
|
| 259 |
-
task_keys = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 260 |
|
| 261 |
# Use try...finally to guarantee [END] log
|
| 262 |
try:
|
|
@@ -318,8 +325,7 @@ def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
|
|
| 318 |
log_end(
|
| 319 |
success="true" if success else "false",
|
| 320 |
steps=total_steps,
|
| 321 |
-
|
| 322 |
-
rewards=f"{sum(all_rewards):.2f}"
|
| 323 |
)
|
| 324 |
|
| 325 |
elapsed = time.time() - t0
|
|
|
|
| 79 |
"""Emit [END] log with key-value pairs."""
|
| 80 |
payload = []
|
| 81 |
for k, v in kwargs.items():
|
| 82 |
+
if isinstance(v, (list, np.ndarray, tuple)):
|
| 83 |
+
# Format as comma-separated list WITHOUT brackets/quotes for the validator
|
| 84 |
+
v_str = ",".join(f"{x:.2f}" if isinstance(x, (float, np.float32)) else str(x) for x in v)
|
| 85 |
else:
|
| 86 |
v_str = str(v)
|
| 87 |
payload.append(f"{k}={v_str}")
|
|
|
|
| 98 |
def _watchdog():
|
| 99 |
time.sleep(timeout_seconds)
|
| 100 |
print(f"\n[TIMEOUT] Global timeout of {timeout_seconds}s reached. Exiting.", flush=True)
|
| 101 |
+
log_end(success="false", steps=0, rewards=[0.0], reason="global_timeout")
|
| 102 |
os._exit(1)
|
| 103 |
|
| 104 |
t = threading.Thread(target=_watchdog, daemon=True)
|
|
|
|
| 244 |
|
| 245 |
agent = build_agent(mode, model_path)
|
| 246 |
|
| 247 |
+
dprint(f"\n{'=' * 60}")
|
| 248 |
+
dprint(" OpenEnv Bus Routing - Inference")
|
| 249 |
+
dprint(f"{'=' * 60}")
|
| 250 |
+
dprint(f" Mode : {mode}")
|
| 251 |
+
dprint(f" Episodes : {episodes}")
|
| 252 |
+
dprint(f" Timeout : {GLOBAL_TIMEOUT}s")
|
| 253 |
+
dprint(f"{'=' * 60}\n")
|
| 254 |
|
| 255 |
t0 = time.time()
|
| 256 |
|
| 257 |
all_rewards = []
|
| 258 |
total_steps = 0
|
| 259 |
results = {}
|
| 260 |
+
task_keys = [
|
| 261 |
+
("task_1", "easy"),
|
| 262 |
+
("task_2", "medium"),
|
| 263 |
+
("task_3", "hard"),
|
| 264 |
+
("task_4", "medium"),
|
| 265 |
+
("task_5", "hard")
|
| 266 |
+
]
|
| 267 |
|
| 268 |
# Use try...finally to guarantee [END] log
|
| 269 |
try:
|
|
|
|
| 325 |
log_end(
|
| 326 |
success="true" if success else "false",
|
| 327 |
steps=total_steps,
|
| 328 |
+
rewards=all_rewards
|
|
|
|
| 329 |
)
|
| 330 |
|
| 331 |
elapsed = time.time() - t0
|