nihalaninihal Claude Opus 4.6 committed on
Commit fcf34b9 · 1 Parent(s): af292c9

Add run_demo_episode wrapper to demo.py for dict-based episode results

The verification test suite expects run_demo_episode(seed, trained) to
return a dict with 'scores' and 'trajectory' keys. The existing
run_episode() returned a plain tuple. Added run_demo_episode() as a thin
wrapper that calls run_episode() and repacks the result into the expected
dict format, enabling tests to access r['scores'] and r['trajectory']
without changing the internals of run_episode or run_comparison.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1)
  1. sentinelops_arena/demo.py +25 -0
sentinelops_arena/demo.py CHANGED
@@ -431,6 +431,31 @@ def run_episode(
     return replay_log, final_scores
 
 
+def run_demo_episode(
+    trained: bool = False,
+    seed: int = 42,
+    attacker_type: str = "randomized",
+) -> Dict:
+    """Run a single demo episode and return a dict with ``scores`` and ``trajectory``.
+
+    This is a convenience wrapper around :func:`run_episode` that returns a
+    dictionary instead of a tuple so callers can use ``r["scores"]`` and
+    ``r["trajectory"]`` directly.
+
+    Args:
+        trained: Whether the worker agent uses trained (resilient) behaviour.
+        seed: Random seed for the environment and the randomised attacker.
+        attacker_type: ``"randomized"`` (default) or ``"scripted"`` (legacy).
+
+    Returns:
+        dict with keys:
+            - ``"scores"`` – final per-agent score dict
+            - ``"trajectory"`` – list of step dicts (the replay log)
+    """
+    trajectory, scores = run_episode(trained=trained, seed=seed, attacker_type=attacker_type)
+    return {"scores": scores, "trajectory": trajectory}
+
+
 def run_comparison(seed: int = 42, attacker_type: str = "randomized") -> Dict:
     """Run untrained vs trained worker comparison.
 
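A minimal usage sketch of the wrapper pattern in this hunk. The body of run_episode is not visible in the diff, so the version below is a hypothetical stand-in that only mirrors the (replay_log, final_scores) return shape seen in the context lines; the "worker" score key and the step dict fields are likewise invented for illustration. One point worth noting: the commit message writes the call as run_demo_episode(seed, trained), while the signature orders the parameters trained, seed, so passing keyword arguments is probably the safer way to call it.

```python
from typing import Dict, List, Tuple


# Hypothetical stand-in for sentinelops_arena.demo.run_episode. Only the
# return shape (replay_log, final_scores) comes from the diff context; the
# contents of both values are made up for this sketch.
def run_episode(
    trained: bool = False,
    seed: int = 42,
    attacker_type: str = "randomized",
) -> Tuple[List[dict], Dict[str, float]]:
    replay_log = [{"step": 0, "seed": seed, "trained": trained}]
    final_scores = {"worker": 1.0 if trained else 0.0}
    return replay_log, final_scores


def run_demo_episode(
    trained: bool = False,
    seed: int = 42,
    attacker_type: str = "randomized",
) -> Dict:
    # Same repacking as the wrapper added in this commit: tuple in, dict out.
    trajectory, scores = run_episode(trained=trained, seed=seed, attacker_type=attacker_type)
    return {"scores": scores, "trajectory": trajectory}


# Keyword arguments keep the call independent of parameter order.
r = run_demo_episode(seed=7, trained=True)
assert set(r) == {"scores", "trajectory"}
assert r["trajectory"][0]["seed"] == 7

# A positional call in (seed, trained) order binds 7 to `trained` and
# True to `seed`, which is not what the caller intended.
swapped = run_demo_episode(7, True)
assert swapped["trajectory"][0]["trained"] == 7
assert swapped["trajectory"][0]["seed"] is True
```

If the verification suite does call the function positionally, either the parameter order or the test calls would need to change; the diff alone does not show which form the tests use.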