voldemort6996 committed on
Commit
30bf3bb
·
1 Parent(s): 61203a1

fix: move info prints to stderr and use comma-separated rewards in [END] tag for validator compliance

Browse files
Files changed (1) hide show
  1. inference.py +19 -13
inference.py CHANGED
@@ -79,8 +79,9 @@ def log_end(**kwargs):
79
  """Emit [END] log with key-value pairs."""
80
  payload = []
81
  for k, v in kwargs.items():
82
- if isinstance(v, (list, np.ndarray)):
83
- v_str = json.dumps(list(v))
 
84
  else:
85
  v_str = str(v)
86
  payload.append(f"{k}={v_str}")
@@ -97,7 +98,7 @@ def _start_watchdog(timeout_seconds: int) -> None:
97
  def _watchdog():
98
  time.sleep(timeout_seconds)
99
  print(f"\n[TIMEOUT] Global timeout of {timeout_seconds}s reached. Exiting.", flush=True)
100
- log_end(success=False, steps=0, score=0.0, rewards=[0, 0, 0], reason="global_timeout")
101
  os._exit(1)
102
 
103
  t = threading.Thread(target=_watchdog, daemon=True)
@@ -243,20 +244,26 @@ def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
243
 
244
  agent = build_agent(mode, model_path)
245
 
246
- print(f"\n{'=' * 60}", flush=True)
247
- print(" OpenEnv Bus Routing - Inference", flush=True)
248
- print(f"{'=' * 60}", flush=True)
249
- print(f" Mode : {mode}", flush=True)
250
- print(f" Episodes : {episodes}", flush=True)
251
- print(f" Timeout : {GLOBAL_TIMEOUT}s", flush=True)
252
- print(f"{'=' * 60}\n", flush=True)
253
 
254
  t0 = time.time()
255
 
256
  all_rewards = []
257
  total_steps = 0
258
  results = {}
259
- task_keys = [("task_1", "easy"), ("task_2", "medium"), ("task_3", "hard"), ("task_4", "medium"), ("task_5", "hard")]
 
 
 
 
 
 
260
 
261
  # Use try...finally to guarantee [END] log
262
  try:
@@ -318,8 +325,7 @@ def run_inference(mode: str, model_path: Optional[str], episodes: int) -> Dict:
318
  log_end(
319
  success="true" if success else "false",
320
  steps=total_steps,
321
- score=f"{final_score:.4f}",
322
- rewards=f"{sum(all_rewards):.2f}"
323
  )
324
 
325
  elapsed = time.time() - t0
 
79
  """Emit [END] log with key-value pairs."""
80
  payload = []
81
  for k, v in kwargs.items():
82
+ if isinstance(v, (list, np.ndarray, tuple)):
83
+ # Format as comma-separated list WITHOUT brackets/quotes for the validator
84
+ v_str = ",".join(f"{x:.2f}" if isinstance(x, (float, np.float32)) else str(x) for x in v)
85
  else:
86
  v_str = str(v)
87
  payload.append(f"{k}={v_str}")
 
98
  def _watchdog():
99
  time.sleep(timeout_seconds)
100
  print(f"\n[TIMEOUT] Global timeout of {timeout_seconds}s reached. Exiting.", flush=True)
101
+ log_end(success="false", steps=0, rewards=[0.0], reason="global_timeout")
102
  os._exit(1)
103
 
104
  t = threading.Thread(target=_watchdog, daemon=True)
 
244
 
245
  agent = build_agent(mode, model_path)
246
 
247
+ dprint(f"\n{'=' * 60}")
248
+ dprint(" OpenEnv Bus Routing - Inference")
249
+ dprint(f"{'=' * 60}")
250
+ dprint(f" Mode : {mode}")
251
+ dprint(f" Episodes : {episodes}")
252
+ dprint(f" Timeout : {GLOBAL_TIMEOUT}s")
253
+ dprint(f"{'=' * 60}\n")
254
 
255
  t0 = time.time()
256
 
257
  all_rewards = []
258
  total_steps = 0
259
  results = {}
260
+ task_keys = [
261
+ ("task_1", "easy"),
262
+ ("task_2", "medium"),
263
+ ("task_3", "hard"),
264
+ ("task_4", "medium"),
265
+ ("task_5", "hard")
266
+ ]
267
 
268
  # Use try...finally to guarantee [END] log
269
  try:
 
325
  log_end(
326
  success="true" if success else "false",
327
  steps=total_steps,
328
+ rewards=all_rewards
 
329
  )
330
 
331
  elapsed = time.time() - t0