Spaces:
Sleeping
Sleeping
feat(phase10): ABScriptEnv, ContrastiveReward, A/B rollout fn, 25 tests PASS, gate PASS
Browse files- demo/run_demo.py +115 -0
- docs/progress.md +11 -0
- session/context.md +11 -10
- session/phase-log.md +1 -0
- viral_script_engine/environment/ab_env.py +256 -0
- viral_script_engine/environment/trajectory.py +146 -0
- viral_script_engine/rewards/contrastive_reward.py +67 -0
- viral_script_engine/scripts/run_ab_episode.py +201 -0
- viral_script_engine/tests/test_phase10.py +418 -0
- viral_script_engine/training/rollout_function.py +91 -0
demo/run_demo.py
CHANGED
|
@@ -622,11 +622,122 @@ def run_interactive():
|
|
| 622 |
# Entry point
|
| 623 |
# ---------------------------------------------------------------------------
|
| 624 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
def main():
|
| 626 |
parser = argparse.ArgumentParser(description="Viral Script Debugging Engine — 5-Act Demo")
|
| 627 |
parser.add_argument("--script", default="S03", help="Script ID to demo (default: S03)")
|
| 628 |
parser.add_argument("--compare", action="store_true", help="Show untrained vs trained arbitrator side-by-side")
|
| 629 |
parser.add_argument("--interactive", action="store_true", help="Human acts as Arbitrator")
|
|
|
|
| 630 |
args = parser.parse_args()
|
| 631 |
|
| 632 |
console.print(Panel(
|
|
@@ -638,6 +749,10 @@ def main():
|
|
| 638 |
|
| 639 |
if args.interactive:
|
| 640 |
run_interactive()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 641 |
else:
|
| 642 |
run_compare(args.script)
|
| 643 |
|
|
|
|
| 622 |
# Entry point
|
| 623 |
# ---------------------------------------------------------------------------
|
| 624 |
|
| 625 |
+
def run_ab_mode(script_id: str):
|
| 626 |
+
"""
|
| 627 |
+
Act 4 — 'Two Paths': run both A/B trajectories in parallel and show
|
| 628 |
+
the contrastive reward at the end. Phase 10 addition.
|
| 629 |
+
"""
|
| 630 |
+
from viral_script_engine.environment.ab_env import ABScriptEnv
|
| 631 |
+
from viral_script_engine.rewards.contrastive_reward import ContrastiveReward
|
| 632 |
+
|
| 633 |
+
console.print(Rule(
|
| 634 |
+
"[bold yellow]ACT 4 — TWO PATHS (A/B Mode)[/bold yellow]", style="yellow"
|
| 635 |
+
))
|
| 636 |
+
console.print(
|
| 637 |
+
"[dim]Running two parallel trajectories from the same script...[/dim]\n"
|
| 638 |
+
)
|
| 639 |
+
|
| 640 |
+
difficulty_map = {
|
| 641 |
+
"S01": "easy", "S02": "easy", "S03": "easy", "S04": "easy",
|
| 642 |
+
"S05": "medium", "S06": "medium", "S07": "medium",
|
| 643 |
+
"S08": "hard", "S09": "hard", "S10": "hard",
|
| 644 |
+
}
|
| 645 |
+
difficulty = difficulty_map.get(script_id, "hard")
|
| 646 |
+
|
| 647 |
+
ab_env = ABScriptEnv(
|
| 648 |
+
scripts_path=_SCRIPTS_PATH,
|
| 649 |
+
cultural_kb_path=_CULTURAL_KB_PATH,
|
| 650 |
+
max_steps=4,
|
| 651 |
+
difficulty=difficulty,
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
try:
|
| 655 |
+
state = ab_env.reset_from_script_id(script_id, _SCRIPTS_PATH)
|
| 656 |
+
except Exception as exc:
|
| 657 |
+
console.print(f"[red]A/B reset failed: {exc}[/red]")
|
| 658 |
+
return
|
| 659 |
+
|
| 660 |
+
# Show step 1 forced actions
|
| 661 |
+
forced_a = ab_env._forced_action_a or {}
|
| 662 |
+
forced_b = ab_env._forced_action_b or {}
|
| 663 |
+
traj_a = state["trajectory_a"]
|
| 664 |
+
traj_b = state["trajectory_b"]
|
| 665 |
+
|
| 666 |
+
table = Table(box=box.SIMPLE_HEAD, show_header=True, padding=(0, 1))
|
| 667 |
+
table.add_column("", style="yellow", min_width=22)
|
| 668 |
+
table.add_column("Trajectory A (Critic-first)", style="cyan", min_width=30)
|
| 669 |
+
table.add_column("Trajectory B (Defender-first)", style="green", min_width=30)
|
| 670 |
+
|
| 671 |
+
table.add_row(
|
| 672 |
+
"Step 1 action",
|
| 673 |
+
forced_a.get("action_type", "?"),
|
| 674 |
+
forced_b.get("action_type", "?"),
|
| 675 |
+
)
|
| 676 |
+
table.add_row(
|
| 677 |
+
"Cumulative reward",
|
| 678 |
+
f"{traj_a.get('cumulative_reward', 0.0):.3f}",
|
| 679 |
+
f"{traj_b.get('cumulative_reward', 0.0):.3f}",
|
| 680 |
+
)
|
| 681 |
+
console.print(Panel(table, title="[yellow]STEP 1 — FORCED[/yellow]", border_style="yellow"))
|
| 682 |
+
|
| 683 |
+
# Run one free step with a simple baseline action
|
| 684 |
+
baseline = BaselineArbitratorAgent()
|
| 685 |
+
obs_for_arb = {
|
| 686 |
+
"current_script": traj_a.get("current_script", ""),
|
| 687 |
+
"debate_history": traj_a.get("debate_history", []),
|
| 688 |
+
"reward_components": traj_a.get("reward_components", {}),
|
| 689 |
+
}
|
| 690 |
+
free_action = baseline.act(obs_for_arb)
|
| 691 |
+
|
| 692 |
+
try:
|
| 693 |
+
state, _, terminated, _, _ = ab_env.step(free_action)
|
| 694 |
+
except Exception as exc:
|
| 695 |
+
console.print(f"[dim]Free step failed: {exc}[/dim]")
|
| 696 |
+
|
| 697 |
+
# Episode end — contrastive reward
|
| 698 |
+
contrastive_result = ab_env.contrastive_reward_calc.compute(
|
| 699 |
+
ab_env._traj_a, ab_env._traj_b
|
| 700 |
+
)
|
| 701 |
+
traj_a_f = state["trajectory_a"]
|
| 702 |
+
traj_b_f = state["trajectory_b"]
|
| 703 |
+
|
| 704 |
+
winner_label = {
|
| 705 |
+
"A": "[cyan]A (critic-first)[/cyan]",
|
| 706 |
+
"B": "[green]B (defender-first)[/green]",
|
| 707 |
+
"tie": "[dim]tie[/dim]",
|
| 708 |
+
}.get(contrastive_result.winning_trajectory, contrastive_result.winning_trajectory)
|
| 709 |
+
|
| 710 |
+
lesson_map = {
|
| 711 |
+
"critic_first": "Act on the Critic's highest-severity claim first for maximum early gains.",
|
| 712 |
+
"defender_first": "Preserve the Defender's core voice first on culturally-rich scripts.",
|
| 713 |
+
"tie": "Both orderings performed similarly — action type matters more than sequence here.",
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
summary_body = (
|
| 717 |
+
f"Trajectory A final cumulative: {traj_a_f.get('cumulative_reward', 0.0):.3f}\n"
|
| 718 |
+
f"Trajectory B final cumulative: {traj_b_f.get('cumulative_reward', 0.0):.3f}\n\n"
|
| 719 |
+
f"Winner: {winner_label}\n"
|
| 720 |
+
f"Delta (A−B): {contrastive_result.delta:+.3f}\n"
|
| 721 |
+
f"Base reward: {contrastive_result.base_reward:.4f}\n"
|
| 722 |
+
f"Contrast bonus: {contrastive_result.contrast_bonus:+.4f}\n"
|
| 723 |
+
f"[bold]Contrastive reward: {contrastive_result.final_reward:.4f}[/bold]\n\n"
|
| 724 |
+
f"[italic dim]Lesson: {lesson_map.get(contrastive_result.winning_trajectory_type, '')}[/italic dim]"
|
| 725 |
+
)
|
| 726 |
+
console.print(Panel(
|
| 727 |
+
summary_body,
|
| 728 |
+
title="[yellow]EPISODE END — CONTRASTIVE REWARD[/yellow]",
|
| 729 |
+
border_style="yellow",
|
| 730 |
+
padding=(1, 2),
|
| 731 |
+
))
|
| 732 |
+
console.print()
|
| 733 |
+
|
| 734 |
+
|
| 735 |
def main():
|
| 736 |
parser = argparse.ArgumentParser(description="Viral Script Debugging Engine — 5-Act Demo")
|
| 737 |
parser.add_argument("--script", default="S03", help="Script ID to demo (default: S03)")
|
| 738 |
parser.add_argument("--compare", action="store_true", help="Show untrained vs trained arbitrator side-by-side")
|
| 739 |
parser.add_argument("--interactive", action="store_true", help="Human acts as Arbitrator")
|
| 740 |
+
parser.add_argument("--ab-mode", action="store_true", help="Phase 10: run A/B two-path demo")
|
| 741 |
args = parser.parse_args()
|
| 742 |
|
| 743 |
console.print(Panel(
|
|
|
|
| 749 |
|
| 750 |
if args.interactive:
|
| 751 |
run_interactive()
|
| 752 |
+
elif args.ab_mode:
|
| 753 |
+
script = _load_script(args.script)
|
| 754 |
+
act1_raw_script(script)
|
| 755 |
+
run_ab_mode(args.script)
|
| 756 |
else:
|
| 757 |
run_compare(args.script)
|
| 758 |
|
docs/progress.md
CHANGED
|
@@ -128,6 +128,17 @@ Do not read entire codebase to understand progress — read this file.
|
|
| 128 |
✅ scripts/run_dummy_episode.py — LLM-stubbed gate check, Phase 9 GATE: PASS
|
| 129 |
✅ scripts/run_platform_comparison.py — cross-platform comparison, R1/R2/R9 diverge on S03, GATE: PASS
|
| 130 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
## Blocked Items
|
| 132 |
❌ GRPOConfig test — blocked by: pyarrow DLL blocked by Windows App Control (works on Linux/Colab)
|
| 133 |
❌ Full GRPO training — blocked by: no local GPU (requires Colab or cloud compute)
|
|
|
|
| 128 |
✅ scripts/run_dummy_episode.py — LLM-stubbed gate check, Phase 9 GATE: PASS
|
| 129 |
✅ scripts/run_platform_comparison.py — cross-platform comparison, R1/R2/R9 diverge on S03, GATE: PASS
|
| 130 |
|
| 131 |
+
## Phase 10 — A/B Testing Environment Layer
|
| 132 |
+
✅ Trajectory + TrajectoryType — pydantic model; forced first-action logic (critic_first / defender_first)
|
| 133 |
+
✅ ABScriptEnv — two parallel ViralScriptEnvs; forced step 1; free steps 2+; state() with delta
|
| 134 |
+
✅ ContrastiveReward — delta-based reward: base_reward + tanh(delta*3)*0.2, clipped to [0,1]
|
| 135 |
+
✅ ContrastiveRewardResult — pydantic result with final_reward, contrast_bonus, winning_trajectory
|
| 136 |
+
✅ training/rollout_function.py — build_ab_rollout_fn() with dual-trajectory prompt format added
|
| 137 |
+
✅ scripts/run_ab_episode.py — gate check script; side-by-side step output; lesson printed at end
|
| 138 |
+
✅ demo/run_demo.py — --ab-mode flag; Act 4 "Two Paths" shows both trajectories + contrastive reward
|
| 139 |
+
✅ test_phase10.py — 25 tests, all passing
|
| 140 |
+
✅ Phase 10 gate — PHASE 10 GATE: PASS, delta=-0.078, contrastive reward active
|
| 141 |
+
|
| 142 |
## Blocked Items
|
| 143 |
❌ GRPOConfig test — blocked by: pyarrow DLL blocked by Windows App Control (works on Linux/Colab)
|
| 144 |
❌ Full GRPO training — blocked by: no local GPU (requires Colab or cloud compute)
|
session/context.md
CHANGED
|
@@ -1,40 +1,41 @@
|
|
| 1 |
# Context — Carry Over for Next Session
|
| 2 |
|
| 3 |
## Current Phase
|
| 4 |
-
Phase:
|
| 5 |
-
Prompt file: prompts/phase-
|
| 6 |
Status: complete
|
| 7 |
|
| 8 |
---
|
| 9 |
|
| 10 |
## Currently Working On
|
| 11 |
-
Feature: Phase
|
| 12 |
File(s): N/A
|
| 13 |
-
Status:
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
## Open Questions
|
| 18 |
-
|
| 19 |
-
Should full GRPO training run before Phase 5?
|
| 20 |
|
| 21 |
---
|
| 22 |
|
| 23 |
## Known Blockers
|
| 24 |
pyarrow DLL blocked on Windows — all training must run on Linux/Colab
|
| 25 |
Escalation mastery requires trained model (r4 >= 0.8 x3 consecutive) — untrained baseline won't trigger
|
|
|
|
| 26 |
|
| 27 |
---
|
| 28 |
|
| 29 |
## Last Commit Message
|
| 30 |
-
feat(
|
| 31 |
|
| 32 |
---
|
| 33 |
|
| 34 |
## Do Not Forget
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
| 38 |
|
| 39 |
---
|
| 40 |
|
|
|
|
| 1 |
# Context — Carry Over for Next Session
|
| 2 |
|
| 3 |
## Current Phase
|
| 4 |
+
Phase: 10
|
| 5 |
+
Prompt file: prompts/phase-10.md
|
| 6 |
Status: complete
|
| 7 |
|
| 8 |
---
|
| 9 |
|
| 10 |
## Currently Working On
|
| 11 |
+
Feature: Phase 10 complete. Awaiting user confirmation to proceed to next phase (if any).
|
| 12 |
File(s): N/A
|
| 13 |
+
Status: All 25 tests pass. Gate script prints PHASE 10 GATE: PASS.
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
## Open Questions
|
| 18 |
+
Is there a Phase 11? Check if prompts/phase-11.md exists.
|
|
|
|
| 19 |
|
| 20 |
---
|
| 21 |
|
| 22 |
## Known Blockers
|
| 23 |
pyarrow DLL blocked on Windows — all training must run on Linux/Colab
|
| 24 |
Escalation mastery requires trained model (r4 >= 0.8 x3 consecutive) — untrained baseline won't trigger
|
| 25 |
+
Full GRPO training requires Colab or cloud GPU
|
| 26 |
|
| 27 |
---
|
| 28 |
|
| 29 |
## Last Commit Message
|
| 30 |
+
feat(phase10): ABScriptEnv, ContrastiveReward, A/B rollout, 25 tests PASS, gate PASS
|
| 31 |
|
| 32 |
---
|
| 33 |
|
| 34 |
## Do Not Forget
|
| 35 |
+
ABScriptEnv.reset() runs forced step 1 automatically — step 2+ are free choice
|
| 36 |
+
Contrastive reward formula: base_reward + tanh(delta*3)*0.2, clipped [0,1]
|
| 37 |
+
Cumulative reward is sum of per-step totals — clips to 1.0 with 4+ steps at high score
|
| 38 |
+
Gate check: python scripts/run_ab_episode.py --script S08 --steps 4 --verbose
|
| 39 |
|
| 40 |
---
|
| 41 |
|
session/phase-log.md
CHANGED
|
@@ -28,6 +28,7 @@ ROLLED BACK — changes reverted, reason in line
|
|
| 28 |
[2026-04-26] [Phase 7] COMPLETE — ReasoningParser, ProcessVerifier, ProcessReward, 21 tests PASS, gate PASS
|
| 29 |
[2026-04-26] [Phase 8] COMPLETE — CreatorProfile, ProfileGenerator, R8 PersonaFit, 25 tests PASS, gate PASS
|
| 30 |
[2026-04-26] [Phase 9] COMPLETE — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
|
|
|
|
| 31 |
|
| 32 |
---
|
| 33 |
|
|
|
|
| 28 |
[2026-04-26] [Phase 7] COMPLETE — ReasoningParser, ProcessVerifier, ProcessReward, 21 tests PASS, gate PASS
|
| 29 |
[2026-04-26] [Phase 8] COMPLETE — CreatorProfile, ProfileGenerator, R8 PersonaFit, 25 tests PASS, gate PASS
|
| 30 |
[2026-04-26] [Phase 9] COMPLETE — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
|
| 31 |
+
[2026-04-26] [Phase 10] COMPLETE — ABScriptEnv, ContrastiveReward, A/B rollout fn, 25 tests PASS, gate PASS
|
| 32 |
|
| 33 |
---
|
| 34 |
|
viral_script_engine/environment/ab_env.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import json
|
| 3 |
+
import random
|
| 4 |
+
import uuid
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
from viral_script_engine.environment.env import ViralScriptEnv
|
| 8 |
+
from viral_script_engine.environment.trajectory import Trajectory, TrajectoryType
|
| 9 |
+
from viral_script_engine.rewards.contrastive_reward import ContrastiveReward
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ABScriptEnv:
|
| 13 |
+
"""
|
| 14 |
+
A/B Testing wrapper around ViralScriptEnv.
|
| 15 |
+
|
| 16 |
+
Each episode runs TWO parallel trajectories from the same starting script:
|
| 17 |
+
- Trajectory A (critic_first): forced to act on Critic's top claim in step 1
|
| 18 |
+
- Trajectory B (defender_first): forced to act on Defender's concern in step 1
|
| 19 |
+
- Steps 2+ are free — the Arbitrator makes its own decisions in both
|
| 20 |
+
|
| 21 |
+
The Arbitrator observes BOTH trajectories in the state() output.
|
| 22 |
+
The contrastive reward fires at episode end based on the delta.
|
| 23 |
+
|
| 24 |
+
This teaches the Arbitrator: "I could have done X first or Y first.
|
| 25 |
+
One led to a better outcome. Learn which one."
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
scripts_path: str = "data/test_scripts/scripts.json",
|
| 31 |
+
cultural_kb_path: str = "data/cultural_kb.json",
|
| 32 |
+
max_steps: int = 5,
|
| 33 |
+
difficulty: str = "easy",
|
| 34 |
+
):
|
| 35 |
+
self.env_a = ViralScriptEnv(
|
| 36 |
+
scripts_path=scripts_path,
|
| 37 |
+
cultural_kb_path=cultural_kb_path,
|
| 38 |
+
max_steps=max_steps,
|
| 39 |
+
difficulty=difficulty,
|
| 40 |
+
use_escalation=False,
|
| 41 |
+
use_anti_gaming=False,
|
| 42 |
+
)
|
| 43 |
+
self.env_b = ViralScriptEnv(
|
| 44 |
+
scripts_path=scripts_path,
|
| 45 |
+
cultural_kb_path=cultural_kb_path,
|
| 46 |
+
max_steps=max_steps,
|
| 47 |
+
difficulty=difficulty,
|
| 48 |
+
use_escalation=False,
|
| 49 |
+
use_anti_gaming=False,
|
| 50 |
+
)
|
| 51 |
+
self.contrastive_reward_calc = ContrastiveReward()
|
| 52 |
+
self._traj_a: Optional[Trajectory] = None
|
| 53 |
+
self._traj_b: Optional[Trajectory] = None
|
| 54 |
+
self._episode_id: Optional[str] = None
|
| 55 |
+
self._step_num: int = 0
|
| 56 |
+
self._forced_action_a: Optional[dict] = None
|
| 57 |
+
self._forced_action_b: Optional[dict] = None
|
| 58 |
+
|
| 59 |
+
# ------------------------------------------------------------------
|
| 60 |
+
# Public API
|
| 61 |
+
# ------------------------------------------------------------------
|
| 62 |
+
|
| 63 |
+
def reset(self, seed=None, options=None) -> dict:
|
| 64 |
+
"""
|
| 65 |
+
Reset BOTH environments with the SAME script and seed.
|
| 66 |
+
Run step 1 automatically with the forced actions.
|
| 67 |
+
Return the state after forced step 1.
|
| 68 |
+
"""
|
| 69 |
+
if seed is None:
|
| 70 |
+
seed = random.randint(0, 2 ** 31)
|
| 71 |
+
|
| 72 |
+
self._episode_id = str(uuid.uuid4())
|
| 73 |
+
self._step_num = 0
|
| 74 |
+
|
| 75 |
+
obs_a, _ = self.env_a.reset(seed=seed)
|
| 76 |
+
obs_b, _ = self.env_b.reset(seed=seed)
|
| 77 |
+
|
| 78 |
+
return self._run_forced_step_1(obs_a, obs_b)
|
| 79 |
+
|
| 80 |
+
def reset_from_script_id(self, script_id: str, scripts_path: str) -> dict:
|
| 81 |
+
"""Reset both environments to a specific script by ID."""
|
| 82 |
+
with open(scripts_path) as f:
|
| 83 |
+
all_scripts = json.load(f)
|
| 84 |
+
script = next((s for s in all_scripts if s["script_id"] == script_id), None)
|
| 85 |
+
if script is None:
|
| 86 |
+
raise ValueError(f"Script {script_id!r} not found in {scripts_path}")
|
| 87 |
+
|
| 88 |
+
self._episode_id = str(uuid.uuid4())
|
| 89 |
+
self._step_num = 0
|
| 90 |
+
|
| 91 |
+
episode_config = {
|
| 92 |
+
"script_id": script["script_id"],
|
| 93 |
+
"script_text": script["script_text"],
|
| 94 |
+
"region": script["region"],
|
| 95 |
+
"platform": script["platform"],
|
| 96 |
+
"niche": script["niche"],
|
| 97 |
+
"difficulty": script.get("difficulty", "hard"),
|
| 98 |
+
}
|
| 99 |
+
obs_a, _ = self.env_a.reset_from_config(episode_config)
|
| 100 |
+
obs_b, _ = self.env_b.reset_from_config(episode_config)
|
| 101 |
+
|
| 102 |
+
return self._run_forced_step_1(obs_a, obs_b)
|
| 103 |
+
|
| 104 |
+
def step(self, action: dict) -> Tuple[dict, float, bool, bool, dict]:
|
| 105 |
+
"""
|
| 106 |
+
Execute the action in BOTH environments simultaneously (step 2+).
|
| 107 |
+
Same action applied to both trajectories.
|
| 108 |
+
Returns combined observation with both trajectory states.
|
| 109 |
+
Terminated when BOTH trajectories have reached max_steps.
|
| 110 |
+
"""
|
| 111 |
+
if self._traj_a is None or self._traj_b is None:
|
| 112 |
+
raise RuntimeError("Call reset() before step()")
|
| 113 |
+
|
| 114 |
+
if not self._traj_a.terminated:
|
| 115 |
+
obs_a, r_a, done_a, _, info_a = self.env_a.step(action)
|
| 116 |
+
self._traj_a.current_script = obs_a.get(
|
| 117 |
+
"current_script", self._traj_a.current_script
|
| 118 |
+
)
|
| 119 |
+
self._traj_a.cumulative_reward += r_a
|
| 120 |
+
self._traj_a.step_count += 1
|
| 121 |
+
self._traj_a.terminated = done_a
|
| 122 |
+
self._traj_a.final_reward_components = info_a.get("reward_components")
|
| 123 |
+
|
| 124 |
+
if not self._traj_b.terminated:
|
| 125 |
+
obs_b, r_b, done_b, _, info_b = self.env_b.step(action)
|
| 126 |
+
self._traj_b.current_script = obs_b.get(
|
| 127 |
+
"current_script", self._traj_b.current_script
|
| 128 |
+
)
|
| 129 |
+
self._traj_b.cumulative_reward += r_b
|
| 130 |
+
self._traj_b.step_count += 1
|
| 131 |
+
self._traj_b.terminated = done_b
|
| 132 |
+
self._traj_b.final_reward_components = info_b.get("reward_components")
|
| 133 |
+
|
| 134 |
+
self._step_num += 1
|
| 135 |
+
terminated = self._traj_a.terminated and self._traj_b.terminated
|
| 136 |
+
|
| 137 |
+
episode_reward = 0.0
|
| 138 |
+
if terminated:
|
| 139 |
+
result = self.contrastive_reward_calc.compute(self._traj_a, self._traj_b)
|
| 140 |
+
episode_reward = result.final_reward
|
| 141 |
+
|
| 142 |
+
return self.state(), episode_reward, terminated, False, {}
|
| 143 |
+
|
| 144 |
+
def state(self) -> dict:
|
| 145 |
+
"""
|
| 146 |
+
Returns state showing both trajectories:
|
| 147 |
+
{
|
| 148 |
+
"trajectory_a": { current_script, reward_components, debate_history,
|
| 149 |
+
cumulative_reward, step_count, terminated, trajectory_type },
|
| 150 |
+
"trajectory_b": { ... },
|
| 151 |
+
"delta": traj_a.cumulative_reward - traj_b.cumulative_reward,
|
| 152 |
+
"leading_trajectory": "A" or "B",
|
| 153 |
+
"step_num": current step,
|
| 154 |
+
"episode_id": ...
|
| 155 |
+
}
|
| 156 |
+
"""
|
| 157 |
+
if self._traj_a is None or self._traj_b is None:
|
| 158 |
+
return {}
|
| 159 |
+
|
| 160 |
+
delta = self._traj_a.cumulative_reward - self._traj_b.cumulative_reward
|
| 161 |
+
leading = "A" if delta >= 0 else "B"
|
| 162 |
+
|
| 163 |
+
return {
|
| 164 |
+
"trajectory_a": self._traj_state(self.env_a, self._traj_a),
|
| 165 |
+
"trajectory_b": self._traj_state(self.env_b, self._traj_b),
|
| 166 |
+
"delta": delta,
|
| 167 |
+
"leading_trajectory": leading,
|
| 168 |
+
"step_num": self._step_num,
|
| 169 |
+
"episode_id": self._episode_id,
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
def reward(self) -> float:
|
| 173 |
+
"""Called at episode end — returns the contrastive reward."""
|
| 174 |
+
if self._traj_a is None or self._traj_b is None:
|
| 175 |
+
return 0.0
|
| 176 |
+
result = self.contrastive_reward_calc.compute(self._traj_a, self._traj_b)
|
| 177 |
+
return result.final_reward
|
| 178 |
+
|
| 179 |
+
# ------------------------------------------------------------------
|
| 180 |
+
# Internal helpers
|
| 181 |
+
# ------------------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
def _run_forced_step_1(self, obs_a: dict, obs_b: dict) -> dict:
|
| 184 |
+
"""
|
| 185 |
+
After both envs are reset, run step 1 with forced actions and
|
| 186 |
+
initialise the Trajectory objects.
|
| 187 |
+
"""
|
| 188 |
+
initial_script = obs_a.get("current_script", "")
|
| 189 |
+
region = obs_a.get("region", "pan_india_english")
|
| 190 |
+
platform = obs_a.get("platform", "Reels")
|
| 191 |
+
niche = obs_a.get("niche", "personal finance")
|
| 192 |
+
|
| 193 |
+
self._traj_a = Trajectory(
|
| 194 |
+
trajectory_id=f"{self._episode_id}_A",
|
| 195 |
+
trajectory_type=TrajectoryType.CRITIC_FIRST,
|
| 196 |
+
initial_script=initial_script,
|
| 197 |
+
current_script=initial_script,
|
| 198 |
+
)
|
| 199 |
+
self._traj_b = Trajectory(
|
| 200 |
+
trajectory_id=f"{self._episode_id}_B",
|
| 201 |
+
trajectory_type=TrajectoryType.DEFENDER_FIRST,
|
| 202 |
+
initial_script=initial_script,
|
| 203 |
+
current_script=initial_script,
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
# Run critic and defender once to determine forced actions
|
| 207 |
+
critique = self.env_a.critic.critique(
|
| 208 |
+
script=initial_script,
|
| 209 |
+
region=region,
|
| 210 |
+
platform=platform,
|
| 211 |
+
niche=niche,
|
| 212 |
+
)
|
| 213 |
+
defender_out = self.env_a.defender.defend(
|
| 214 |
+
script=initial_script,
|
| 215 |
+
critic_claims=critique.claims,
|
| 216 |
+
region=region,
|
| 217 |
+
platform=platform,
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
forced_a = self._traj_a.get_forced_first_action(critique.claims, defender_out)
|
| 221 |
+
forced_b = self._traj_b.get_forced_first_action(critique.claims, defender_out)
|
| 222 |
+
|
| 223 |
+
self._forced_action_a = forced_a
|
| 224 |
+
self._forced_action_b = forced_b
|
| 225 |
+
|
| 226 |
+
# Execute forced step 1 in each environment
|
| 227 |
+
obs_a_new, r_a, done_a, _, info_a = self.env_a.step(forced_a)
|
| 228 |
+
obs_b_new, r_b, done_b, _, info_b = self.env_b.step(forced_b)
|
| 229 |
+
|
| 230 |
+
self._traj_a.current_script = obs_a_new.get("current_script", initial_script)
|
| 231 |
+
self._traj_a.cumulative_reward = r_a
|
| 232 |
+
self._traj_a.step_count = 1
|
| 233 |
+
self._traj_a.terminated = done_a
|
| 234 |
+
self._traj_a.final_reward_components = info_a.get("reward_components")
|
| 235 |
+
|
| 236 |
+
self._traj_b.current_script = obs_b_new.get("current_script", initial_script)
|
| 237 |
+
self._traj_b.cumulative_reward = r_b
|
| 238 |
+
self._traj_b.step_count = 1
|
| 239 |
+
self._traj_b.terminated = done_b
|
| 240 |
+
self._traj_b.final_reward_components = info_b.get("reward_components")
|
| 241 |
+
|
| 242 |
+
self._step_num = 1
|
| 243 |
+
|
| 244 |
+
return self.state()
|
| 245 |
+
|
| 246 |
+
def _traj_state(self, env: ViralScriptEnv, traj: Trajectory) -> dict:
|
| 247 |
+
s = env.state()
|
| 248 |
+
return {
|
| 249 |
+
"current_script": traj.current_script,
|
| 250 |
+
"reward_components": s.get("reward_components", {}),
|
| 251 |
+
"debate_history": s.get("debate_history", []),
|
| 252 |
+
"cumulative_reward": traj.cumulative_reward,
|
| 253 |
+
"step_count": traj.step_count,
|
| 254 |
+
"terminated": traj.terminated,
|
| 255 |
+
"trajectory_type": traj.trajectory_type,
|
| 256 |
+
}
|
viral_script_engine/environment/trajectory.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Any, List, Optional
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.environment.observations import DebateRound, RewardComponents
|
| 6 |
+
|
| 7 |
+
_SEVERITY_ORDER = {"high": 3, "medium": 2, "low": 1}
|
| 8 |
+
|
| 9 |
+
_CRITIQUE_TO_ACTION = {
|
| 10 |
+
"hook_weakness": "hook_rewrite",
|
| 11 |
+
"pacing_issue": "section_reorder",
|
| 12 |
+
"cultural_mismatch": "cultural_ref_sub",
|
| 13 |
+
"cta_buried": "cta_placement",
|
| 14 |
+
"coherence_break": "section_reorder",
|
| 15 |
+
"retention_risk": "hook_rewrite",
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
_ACTION_TO_TARGET = {
|
| 19 |
+
"hook_rewrite": "hook",
|
| 20 |
+
"section_reorder": "body",
|
| 21 |
+
"cultural_ref_sub": "body",
|
| 22 |
+
"cta_placement": "cta",
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class TrajectoryType:
|
| 27 |
+
CRITIC_FIRST = "critic_first" # Trajectory A: act on Critic's top claim first
|
| 28 |
+
DEFENDER_FIRST = "defender_first" # Trajectory B: act on Defender's concern first
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Trajectory(BaseModel):
|
| 32 |
+
trajectory_id: str
|
| 33 |
+
trajectory_type: str
|
| 34 |
+
initial_script: str
|
| 35 |
+
current_script: str
|
| 36 |
+
steps: List[Any] = []
|
| 37 |
+
cumulative_reward: float = 0.0
|
| 38 |
+
final_reward_components: Optional[Any] = None
|
| 39 |
+
terminated: bool = False
|
| 40 |
+
step_count: int = 0
|
| 41 |
+
|
| 42 |
+
def get_forced_first_action(
|
| 43 |
+
self,
|
| 44 |
+
critic_claims: List[Any],
|
| 45 |
+
defender_output: Any,
|
| 46 |
+
) -> dict:
|
| 47 |
+
"""
|
| 48 |
+
Returns the forced first action based on trajectory type.
|
| 49 |
+
|
| 50 |
+
CRITIC_FIRST: pick the action that addresses the highest-severity CritiqueClaim.
|
| 51 |
+
DEFENDER_FIRST: pick the action that preserves the core_strength_quote.
|
| 52 |
+
If core_strength is in hook → hook_rewrite is risky → pick cta_placement first.
|
| 53 |
+
"""
|
| 54 |
+
if self.trajectory_type == TrajectoryType.CRITIC_FIRST:
|
| 55 |
+
return self._critic_first_action(critic_claims)
|
| 56 |
+
return self._defender_first_action(critic_claims, defender_output)
|
| 57 |
+
|
| 58 |
+
def _critic_first_action(self, critic_claims: List[Any]) -> dict:
|
| 59 |
+
if not critic_claims:
|
| 60 |
+
return _fallback_action("C1")
|
| 61 |
+
sorted_claims = sorted(
|
| 62 |
+
critic_claims,
|
| 63 |
+
key=lambda c: _SEVERITY_ORDER.get(getattr(c, "severity", "low"), 0),
|
| 64 |
+
reverse=True,
|
| 65 |
+
)
|
| 66 |
+
top = sorted_claims[0]
|
| 67 |
+
action_type = _CRITIQUE_TO_ACTION.get(
|
| 68 |
+
getattr(top, "critique_class", ""), "hook_rewrite"
|
| 69 |
+
)
|
| 70 |
+
return {
|
| 71 |
+
"action_type": action_type,
|
| 72 |
+
"target_section": _ACTION_TO_TARGET.get(action_type, "hook"),
|
| 73 |
+
"instruction": (
|
| 74 |
+
f"Address the top critic concern: "
|
| 75 |
+
f"{getattr(top, 'claim_text', '')[:100]}"
|
| 76 |
+
),
|
| 77 |
+
"critique_claim_id": getattr(top, "claim_id", "C1"),
|
| 78 |
+
"reasoning": (
|
| 79 |
+
f"CRITIC_FIRST: targeting highest-severity "
|
| 80 |
+
f"{getattr(top, 'critique_class', '')} claim ({getattr(top, 'severity', '')})."
|
| 81 |
+
),
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
def _defender_first_action(self, critic_claims: List[Any], defender_output: Any) -> dict:
|
| 85 |
+
core_quote = ""
|
| 86 |
+
flagged: set = set()
|
| 87 |
+
|
| 88 |
+
if defender_output is not None:
|
| 89 |
+
if hasattr(defender_output, "core_strength_quote"):
|
| 90 |
+
core_quote = defender_output.core_strength_quote or ""
|
| 91 |
+
flagged = set(getattr(defender_output, "flagged_critic_claims", []))
|
| 92 |
+
elif isinstance(defender_output, dict):
|
| 93 |
+
core_quote = defender_output.get("core_strength_quote", "")
|
| 94 |
+
flagged = set(defender_output.get("flagged_critic_claims", []))
|
| 95 |
+
|
| 96 |
+
# Core strength is "in the hook" if its first 20 chars appear in the leading 100 chars
|
| 97 |
+
hook_portion = self.current_script[:100].lower()
|
| 98 |
+
core_in_hook = bool(core_quote) and core_quote.lower()[:20] in hook_portion
|
| 99 |
+
|
| 100 |
+
if core_in_hook:
|
| 101 |
+
# Hook is precious — choose a safe non-hook action first
|
| 102 |
+
action_type = "cta_placement"
|
| 103 |
+
target = "cta"
|
| 104 |
+
instruction = (
|
| 105 |
+
"Improve CTA positioning to boost completion rate "
|
| 106 |
+
"without altering the hook."
|
| 107 |
+
)
|
| 108 |
+
claim_id = (
|
| 109 |
+
getattr(critic_claims[0], "claim_id", "C1")
|
| 110 |
+
if critic_claims else "C1"
|
| 111 |
+
)
|
| 112 |
+
else:
|
| 113 |
+
# Core is in body — safe to improve the hook
|
| 114 |
+
action_type = "hook_rewrite"
|
| 115 |
+
target = "hook"
|
| 116 |
+
instruction = (
|
| 117 |
+
"Rewrite the hook for stronger attention capture "
|
| 118 |
+
"while preserving the core body voice."
|
| 119 |
+
)
|
| 120 |
+
unflagged = [
|
| 121 |
+
c for c in critic_claims
|
| 122 |
+
if getattr(c, "claim_id", "") not in flagged
|
| 123 |
+
]
|
| 124 |
+
claim = unflagged[0] if unflagged else (critic_claims[0] if critic_claims else None)
|
| 125 |
+
claim_id = getattr(claim, "claim_id", "C1") if claim else "C1"
|
| 126 |
+
|
| 127 |
+
return {
|
| 128 |
+
"action_type": action_type,
|
| 129 |
+
"target_section": target,
|
| 130 |
+
"instruction": instruction,
|
| 131 |
+
"critique_claim_id": claim_id,
|
| 132 |
+
"reasoning": (
|
| 133 |
+
"DEFENDER_FIRST: preserving Defender's core strength "
|
| 134 |
+
"and regional voice before addressing critic claims."
|
| 135 |
+
),
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def _fallback_action(claim_id: str = "C1") -> dict:
|
| 140 |
+
return {
|
| 141 |
+
"action_type": "hook_rewrite",
|
| 142 |
+
"target_section": "hook",
|
| 143 |
+
"instruction": "Rewrite the hook to open with a strong immediate claim.",
|
| 144 |
+
"critique_claim_id": claim_id,
|
| 145 |
+
"reasoning": "Fallback: no critic claims available.",
|
| 146 |
+
}
|
viral_script_engine/rewards/contrastive_reward.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import math
|
| 3 |
+
from typing import TYPE_CHECKING
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
if TYPE_CHECKING:
|
| 8 |
+
from viral_script_engine.environment.trajectory import Trajectory
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ContrastiveRewardResult(BaseModel):
|
| 12 |
+
final_reward: float
|
| 13 |
+
base_reward: float
|
| 14 |
+
contrast_bonus: float
|
| 15 |
+
delta: float
|
| 16 |
+
winning_trajectory: str # "A" | "B" | "tie"
|
| 17 |
+
winning_trajectory_type: str # "critic_first" | "defender_first" | "tie"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ContrastiveReward:
|
| 21 |
+
"""
|
| 22 |
+
Computes a reward based on the delta between two parallel trajectories.
|
| 23 |
+
|
| 24 |
+
The key insight: the Arbitrator is rewarded not just for doing well,
|
| 25 |
+
but for doing BETTER than the counterfactual alternative.
|
| 26 |
+
|
| 27 |
+
Reward formula:
|
| 28 |
+
- delta = traj_a.cumulative_reward - traj_b.cumulative_reward
|
| 29 |
+
- base_reward = max(traj_a.cumulative_reward, traj_b.cumulative_reward)
|
| 30 |
+
(reward the better trajectory's absolute performance)
|
| 31 |
+
- contrast_bonus = tanh(delta * 3) * 0.2
|
| 32 |
+
(add up to +0.2 bonus when one trajectory clearly dominates)
|
| 33 |
+
- final = base_reward + contrast_bonus, clipped to [0, 1]
|
| 34 |
+
|
| 35 |
+
When delta is near zero, contrast_bonus → 0 — no extra credit for
|
| 36 |
+
a coin-flip decision. When delta is large, contrast_bonus is maximised —
|
| 37 |
+
this is the signal that matters most for learning action ordering.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def compute(
|
| 41 |
+
self,
|
| 42 |
+
traj_a: "Trajectory",
|
| 43 |
+
traj_b: "Trajectory",
|
| 44 |
+
) -> ContrastiveRewardResult:
|
| 45 |
+
delta = traj_a.cumulative_reward - traj_b.cumulative_reward
|
| 46 |
+
base_reward = max(traj_a.cumulative_reward, traj_b.cumulative_reward)
|
| 47 |
+
contrast_bonus = math.tanh(delta * 3) * 0.2
|
| 48 |
+
final = max(0.0, min(1.0, base_reward + contrast_bonus))
|
| 49 |
+
|
| 50 |
+
if abs(delta) < 1e-6:
|
| 51 |
+
winning = "tie"
|
| 52 |
+
winning_type = "tie"
|
| 53 |
+
elif delta > 0:
|
| 54 |
+
winning = "A"
|
| 55 |
+
winning_type = traj_a.trajectory_type
|
| 56 |
+
else:
|
| 57 |
+
winning = "B"
|
| 58 |
+
winning_type = traj_b.trajectory_type
|
| 59 |
+
|
| 60 |
+
return ContrastiveRewardResult(
|
| 61 |
+
final_reward=final,
|
| 62 |
+
base_reward=base_reward,
|
| 63 |
+
contrast_bonus=contrast_bonus,
|
| 64 |
+
delta=delta,
|
| 65 |
+
winning_trajectory=winning,
|
| 66 |
+
winning_trajectory_type=winning_type,
|
| 67 |
+
)
|
viral_script_engine/scripts/run_ab_episode.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A/B Episode Runner — Phase 10 Gate Check Script
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python scripts/run_ab_episode.py --script S08 --steps 4 --verbose
|
| 6 |
+
python scripts/run_ab_episode.py --script S03 --steps 3
|
| 7 |
+
"""
|
| 8 |
+
import argparse
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
if hasattr(sys.stdout, "reconfigure"):
|
| 13 |
+
sys.stdout.reconfigure(encoding="utf-8", errors="replace")
|
| 14 |
+
if hasattr(sys.stderr, "reconfigure"):
|
| 15 |
+
sys.stderr.reconfigure(encoding="utf-8", errors="replace")
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 18 |
+
|
| 19 |
+
from dotenv import load_dotenv
|
| 20 |
+
load_dotenv(dotenv_path=Path(__file__).parent.parent / ".env")
|
| 21 |
+
load_dotenv(dotenv_path=Path(__file__).parent.parent.parent / ".env", override=False)
|
| 22 |
+
|
| 23 |
+
from viral_script_engine.environment.ab_env import ABScriptEnv
|
| 24 |
+
from viral_script_engine.rewards.contrastive_reward import ContrastiveReward
|
| 25 |
+
from viral_script_engine.agents.baseline_arbitrator import BaselineArbitratorAgent
|
| 26 |
+
|
| 27 |
+
_ROOT = Path(__file__).parent.parent
|
| 28 |
+
_SCRIPTS_PATH = str(_ROOT / "data" / "test_scripts" / "scripts.json")
|
| 29 |
+
_CULTURAL_KB_PATH = str(_ROOT / "data" / "cultural_kb.json")
|
| 30 |
+
|
| 31 |
+
_DIFFICULTY_FOR_SCRIPT = {
|
| 32 |
+
"S01": "easy", "S02": "easy", "S03": "easy", "S04": "easy",
|
| 33 |
+
"S05": "medium", "S06": "medium", "S07": "medium",
|
| 34 |
+
"S08": "hard", "S09": "hard", "S10": "hard",
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
SEP = "═" * 70
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _rc_row(label: str, before: float, after: float) -> str:
|
| 41 |
+
delta = after - before
|
| 42 |
+
sign = "+" if delta >= 0 else ""
|
| 43 |
+
warn = " ⚠" if delta < -0.05 else ""
|
| 44 |
+
return f" {label}: {before:.2f} → {after:.2f} ({sign}{delta:.2f}){warn}"
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _traj_summary(traj: dict, label: str) -> str:
|
| 48 |
+
rc = traj.get("reward_components") or {}
|
| 49 |
+
r1 = rc.get("r1_hook_strength") or 0.0
|
| 50 |
+
r3 = rc.get("r3_cultural_alignment") or 0.0
|
| 51 |
+
total = rc.get("total") or traj.get("cumulative_reward", 0.0)
|
| 52 |
+
return (
|
| 53 |
+
f" [{label}] script[:60]: {traj.get('current_script', '')[:60]!r}\n"
|
| 54 |
+
f" R1={r1:.2f} R3={r3:.2f} Cumulative={traj.get('cumulative_reward', 0.0):.3f}"
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def run_ab_episode(script_id: str, num_steps: int, verbose: bool):
|
| 59 |
+
difficulty = _DIFFICULTY_FOR_SCRIPT.get(script_id, "hard")
|
| 60 |
+
ab_env = ABScriptEnv(
|
| 61 |
+
scripts_path=_SCRIPTS_PATH,
|
| 62 |
+
cultural_kb_path=_CULTURAL_KB_PATH,
|
| 63 |
+
max_steps=num_steps + 1, # +1 because step 1 is forced
|
| 64 |
+
difficulty=difficulty,
|
| 65 |
+
)
|
| 66 |
+
arbitrator = BaselineArbitratorAgent()
|
| 67 |
+
|
| 68 |
+
print(f"\n{SEP}")
|
| 69 |
+
print(f" A/B EPISODE — Script: {script_id} Steps: {num_steps} Difficulty: {difficulty}")
|
| 70 |
+
print(SEP)
|
| 71 |
+
|
| 72 |
+
# Reset — forced step 1 runs automatically
|
| 73 |
+
state = ab_env.reset_from_script_id(script_id, _SCRIPTS_PATH)
|
| 74 |
+
|
| 75 |
+
traj_a = state["trajectory_a"]
|
| 76 |
+
traj_b = state["trajectory_b"]
|
| 77 |
+
forced_a = ab_env._forced_action_a
|
| 78 |
+
forced_b = ab_env._forced_action_b
|
| 79 |
+
|
| 80 |
+
print(f"\n{SEP}")
|
| 81 |
+
print(" STEP 1 (FORCED)")
|
| 82 |
+
print(SEP)
|
| 83 |
+
col_w = 34
|
| 84 |
+
print(
|
| 85 |
+
f" {'TRAJECTORY A (Critic-first)':<{col_w}}"
|
| 86 |
+
f" {'TRAJECTORY B (Defender-first)'}"
|
| 87 |
+
)
|
| 88 |
+
print(
|
| 89 |
+
f" Action: {forced_a.get('action_type','?'):<{col_w-8}}"
|
| 90 |
+
f" Action: {forced_b.get('action_type','?')}"
|
| 91 |
+
)
|
| 92 |
+
print(
|
| 93 |
+
f" Cumulative: {traj_a['cumulative_reward']:.3f}{'':<{col_w-20}}"
|
| 94 |
+
f" Cumulative: {traj_b['cumulative_reward']:.3f}"
|
| 95 |
+
)
|
| 96 |
+
if verbose:
|
| 97 |
+
print(f" Reasoning A: {forced_a.get('reasoning','')[:60]}")
|
| 98 |
+
print(f" Reasoning B: {forced_b.get('reasoning','')[:60]}")
|
| 99 |
+
|
| 100 |
+
print(f"\n Delta after step 1: {state['delta']:+.3f} (leading: Trajectory {state['leading_trajectory']})")
|
| 101 |
+
|
| 102 |
+
# Free steps (2+)
|
| 103 |
+
for step_idx in range(2, num_steps + 1):
|
| 104 |
+
if traj_a.get("terminated") and traj_b.get("terminated"):
|
| 105 |
+
break
|
| 106 |
+
|
| 107 |
+
# Arbitrator acts based on current trajectory_a state (simplification for demo)
|
| 108 |
+
obs_for_arb = {
|
| 109 |
+
"current_script": traj_a.get("current_script", ""),
|
| 110 |
+
"debate_history": traj_a.get("debate_history", []),
|
| 111 |
+
"reward_components": traj_a.get("reward_components", {}),
|
| 112 |
+
}
|
| 113 |
+
action = arbitrator.act(obs_for_arb)
|
| 114 |
+
|
| 115 |
+
print(f"\n{SEP}")
|
| 116 |
+
print(f" STEP {step_idx} (FREE CHOICE)")
|
| 117 |
+
print(SEP)
|
| 118 |
+
print(f" Arbitrator action: {action.get('action_type')} → {action.get('critique_claim_id')}")
|
| 119 |
+
|
| 120 |
+
prev_a_cum = traj_a["cumulative_reward"]
|
| 121 |
+
prev_b_cum = traj_b["cumulative_reward"]
|
| 122 |
+
|
| 123 |
+
state, ep_reward, terminated, _, _ = ab_env.step(action)
|
| 124 |
+
traj_a = state["trajectory_a"]
|
| 125 |
+
traj_b = state["trajectory_b"]
|
| 126 |
+
|
| 127 |
+
print(
|
| 128 |
+
f" Traj A cumulative: {prev_a_cum:.3f} → {traj_a['cumulative_reward']:.3f}"
|
| 129 |
+
f" ({traj_a['cumulative_reward'] - prev_a_cum:+.3f})"
|
| 130 |
+
)
|
| 131 |
+
print(
|
| 132 |
+
f" Traj B cumulative: {prev_b_cum:.3f} → {traj_b['cumulative_reward']:.3f}"
|
| 133 |
+
f" ({traj_b['cumulative_reward'] - prev_b_cum:+.3f})"
|
| 134 |
+
)
|
| 135 |
+
print(f" Delta: {state['delta']:+.3f} Leading: Trajectory {state['leading_trajectory']}")
|
| 136 |
+
|
| 137 |
+
if terminated:
|
| 138 |
+
break
|
| 139 |
+
|
| 140 |
+
# Episode end
|
| 141 |
+
traj_a_final = state["trajectory_a"]
|
| 142 |
+
traj_b_final = state["trajectory_b"]
|
| 143 |
+
final_delta = state["delta"]
|
| 144 |
+
|
| 145 |
+
contrastive = ab_env.contrastive_reward_calc.compute(
|
| 146 |
+
ab_env._traj_a, ab_env._traj_b
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
winner_label = {
|
| 150 |
+
"A": "A (critic-first was better)",
|
| 151 |
+
"B": "B (defender-first was better)",
|
| 152 |
+
"tie": "tie",
|
| 153 |
+
}.get(contrastive.winning_trajectory, contrastive.winning_trajectory)
|
| 154 |
+
|
| 155 |
+
lesson_map = {
|
| 156 |
+
"critic_first": "Act on the Critic's top severity claim first to maximise early gains.",
|
| 157 |
+
"defender_first": "On scripts with strong core voice, preserve the Defender's concern first.",
|
| 158 |
+
"tie": "Both orderings performed similarly — action choice matters more than sequence.",
|
| 159 |
+
}
|
| 160 |
+
lesson = lesson_map.get(contrastive.winning_trajectory_type, "")
|
| 161 |
+
|
| 162 |
+
print(f"\n{SEP}")
|
| 163 |
+
print(" EPISODE END")
|
| 164 |
+
print(SEP)
|
| 165 |
+
print(f" Trajectory A final cumulative: {traj_a_final['cumulative_reward']:.3f}")
|
| 166 |
+
print(f" Trajectory B final cumulative: {traj_b_final['cumulative_reward']:.3f}")
|
| 167 |
+
print(f" Winner: {winner_label}")
|
| 168 |
+
print(f" Delta: {final_delta:+.3f}")
|
| 169 |
+
print(f" Base reward: {contrastive.base_reward:.4f}")
|
| 170 |
+
print(f" Contrast bonus: {contrastive.contrast_bonus:+.4f}")
|
| 171 |
+
print(f" Contrastive reward: {contrastive.final_reward:.4f}")
|
| 172 |
+
print(f" Lesson: {lesson}")
|
| 173 |
+
print()
|
| 174 |
+
|
| 175 |
+
gate_pass = (
|
| 176 |
+
abs(final_delta) > 1e-6
|
| 177 |
+
and 0.0 <= contrastive.final_reward <= 1.0
|
| 178 |
+
)
|
| 179 |
+
if gate_pass:
|
| 180 |
+
print(
|
| 181 |
+
f"PHASE 10 GATE: PASS — A/B environment running. "
|
| 182 |
+
f"Contrastive reward active. Delta: {final_delta:.3f}."
|
| 183 |
+
)
|
| 184 |
+
else:
|
| 185 |
+
print(
|
| 186 |
+
f"PHASE 10 GATE: FAIL — delta={final_delta:.6f}, "
|
| 187 |
+
f"reward={contrastive.final_reward:.4f}"
|
| 188 |
+
)
|
| 189 |
+
sys.exit(1)
|
| 190 |
+
|
| 191 |
+
return contrastive
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
|
| 195 |
+
parser = argparse.ArgumentParser(description="Run an A/B episode (Phase 10)")
|
| 196 |
+
parser.add_argument("--script", default="S08", help="Script ID (default: S08)")
|
| 197 |
+
parser.add_argument("--steps", type=int, default=4, help="Total steps including forced step 1")
|
| 198 |
+
parser.add_argument("--verbose", action="store_true", help="Show reasoning details")
|
| 199 |
+
args = parser.parse_args()
|
| 200 |
+
|
| 201 |
+
run_ab_episode(args.script, args.steps, args.verbose)
|
viral_script_engine/tests/test_phase10.py
ADDED
|
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Phase 10 tests — A/B Testing Environment Layer."""
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from unittest.mock import MagicMock, patch
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
|
| 9 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 10 |
+
|
| 11 |
+
from viral_script_engine.environment.trajectory import Trajectory, TrajectoryType
|
| 12 |
+
from viral_script_engine.rewards.contrastive_reward import ContrastiveReward, ContrastiveRewardResult
|
| 13 |
+
|
| 14 |
+
_SCRIPTS_PATH = str(
|
| 15 |
+
Path(__file__).parent.parent / "data" / "test_scripts" / "scripts.json"
|
| 16 |
+
)
|
| 17 |
+
_CULTURAL_KB_PATH = str(
|
| 18 |
+
Path(__file__).parent.parent / "data" / "cultural_kb.json"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
# Helpers
|
| 23 |
+
# ---------------------------------------------------------------------------
|
| 24 |
+
|
| 25 |
+
def _make_claim(claim_id: str, severity: str, critique_class: str) -> MagicMock:
|
| 26 |
+
c = MagicMock()
|
| 27 |
+
c.claim_id = claim_id
|
| 28 |
+
c.severity = severity
|
| 29 |
+
c.critique_class = critique_class
|
| 30 |
+
c.claim_text = f"Test claim {claim_id}"
|
| 31 |
+
return c
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _make_defender(core_quote: str, flagged: list = None) -> MagicMock:
|
| 35 |
+
d = MagicMock()
|
| 36 |
+
d.core_strength_quote = core_quote
|
| 37 |
+
d.flagged_critic_claims = flagged or []
|
| 38 |
+
return d
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _make_trajectory(
|
| 42 |
+
traj_type: str,
|
| 43 |
+
cumulative: float,
|
| 44 |
+
script: str = "Test script body content here.",
|
| 45 |
+
) -> Trajectory:
|
| 46 |
+
return Trajectory(
|
| 47 |
+
trajectory_id=f"test_{traj_type}",
|
| 48 |
+
trajectory_type=traj_type,
|
| 49 |
+
initial_script=script,
|
| 50 |
+
current_script=script,
|
| 51 |
+
cumulative_reward=cumulative,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
# Trajectory: forced first action — CRITIC_FIRST
|
| 57 |
+
# ---------------------------------------------------------------------------
|
| 58 |
+
|
| 59 |
+
class TestTrajectoryForcedActionCriticFirst:
|
| 60 |
+
def setup_method(self):
|
| 61 |
+
self.traj = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.0)
|
| 62 |
+
|
| 63 |
+
def test_picks_highest_severity_claim(self):
|
| 64 |
+
claims = [
|
| 65 |
+
_make_claim("C1", "low", "pacing_issue"),
|
| 66 |
+
_make_claim("C2", "high", "hook_weakness"),
|
| 67 |
+
_make_claim("C3", "medium", "cta_buried"),
|
| 68 |
+
]
|
| 69 |
+
action = self.traj.get_forced_first_action(claims, None)
|
| 70 |
+
# highest severity is C2 (high, hook_weakness → hook_rewrite)
|
| 71 |
+
assert action["action_type"] == "hook_rewrite"
|
| 72 |
+
assert action["critique_claim_id"] == "C2"
|
| 73 |
+
|
| 74 |
+
def test_maps_cta_buried_to_cta_placement(self):
|
| 75 |
+
claims = [_make_claim("C1", "high", "cta_buried")]
|
| 76 |
+
action = self.traj.get_forced_first_action(claims, None)
|
| 77 |
+
assert action["action_type"] == "cta_placement"
|
| 78 |
+
assert action["target_section"] == "cta"
|
| 79 |
+
|
| 80 |
+
def test_maps_cultural_mismatch_to_cultural_ref_sub(self):
|
| 81 |
+
claims = [_make_claim("C1", "high", "cultural_mismatch")]
|
| 82 |
+
action = self.traj.get_forced_first_action(claims, None)
|
| 83 |
+
assert action["action_type"] == "cultural_ref_sub"
|
| 84 |
+
|
| 85 |
+
def test_fallback_when_no_claims(self):
|
| 86 |
+
action = self.traj.get_forced_first_action([], None)
|
| 87 |
+
assert action["action_type"] == "hook_rewrite"
|
| 88 |
+
assert "CRITIC_FIRST" in action["reasoning"] or action["reasoning"]
|
| 89 |
+
|
| 90 |
+
def test_reasoning_mentions_critic_first(self):
|
| 91 |
+
claims = [_make_claim("C1", "high", "hook_weakness")]
|
| 92 |
+
action = self.traj.get_forced_first_action(claims, None)
|
| 93 |
+
assert "CRITIC_FIRST" in action["reasoning"]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ---------------------------------------------------------------------------
|
| 97 |
+
# Trajectory: forced first action — DEFENDER_FIRST
|
| 98 |
+
# ---------------------------------------------------------------------------
|
| 99 |
+
|
| 100 |
+
class TestTrajectoryForcedActionDefenderFirst:
|
| 101 |
+
def test_picks_cta_when_core_strength_in_hook(self):
|
| 102 |
+
# Script starts with the core quote → hook is precious
|
| 103 |
+
script = "Why does your phone battery lie? Charge to eighty. Never below twenty."
|
| 104 |
+
traj = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.0, script=script)
|
| 105 |
+
defender = _make_defender(core_quote="Why does your phone battery lie?")
|
| 106 |
+
claims = [_make_claim("C1", "high", "hook_weakness")]
|
| 107 |
+
|
| 108 |
+
action = traj.get_forced_first_action(claims, defender)
|
| 109 |
+
assert action["action_type"] == "cta_placement", (
|
| 110 |
+
f"Expected cta_placement when core strength is in hook, got {action['action_type']}"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
def test_picks_hook_rewrite_when_core_strength_in_body(self):
|
| 114 |
+
# Script hook is entirely generic; core quote only appears after the first 100 chars
|
| 115 |
+
filler = "Stop wasting your money on things that do not matter at all. " * 2 # >100 chars
|
| 116 |
+
core = "UNIQUE_CORE_PHRASE_XYZ"
|
| 117 |
+
script = filler + core
|
| 118 |
+
traj = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.0, script=script)
|
| 119 |
+
# Core quote appears after position 100 — NOT in hook
|
| 120 |
+
defender = _make_defender(core_quote=core)
|
| 121 |
+
claims = [_make_claim("C1", "high", "hook_weakness")]
|
| 122 |
+
|
| 123 |
+
action = traj.get_forced_first_action(claims, defender)
|
| 124 |
+
assert action["action_type"] == "hook_rewrite", (
|
| 125 |
+
f"Expected hook_rewrite when core is NOT in hook, got {action['action_type']}"
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
def test_reasoning_mentions_defender_first(self):
|
| 129 |
+
traj = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.0)
|
| 130 |
+
action = traj.get_forced_first_action([], None)
|
| 131 |
+
assert "DEFENDER_FIRST" in action["reasoning"]
|
| 132 |
+
|
| 133 |
+
def test_skips_flagged_claims_in_defender_first(self):
|
| 134 |
+
script = "Body content only. No hook magic here at all whatsoever for testing purposes."
|
| 135 |
+
traj = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.0, script=script)
|
| 136 |
+
defender = _make_defender(
|
| 137 |
+
core_quote="definitely not in the hook portion of this script",
|
| 138 |
+
flagged=["C1"],
|
| 139 |
+
)
|
| 140 |
+
claims = [
|
| 141 |
+
_make_claim("C1", "high", "hook_weakness"),
|
| 142 |
+
_make_claim("C2", "medium", "pacing_issue"),
|
| 143 |
+
]
|
| 144 |
+
action = traj.get_forced_first_action(claims, defender)
|
| 145 |
+
# C1 is flagged, so should pick C2
|
| 146 |
+
assert action["critique_claim_id"] == "C2"
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ---------------------------------------------------------------------------
|
| 150 |
+
# ContrastiveReward
|
| 151 |
+
# ---------------------------------------------------------------------------
|
| 152 |
+
|
| 153 |
+
class TestContrastiveReward:
|
| 154 |
+
def setup_method(self):
|
| 155 |
+
self.cr = ContrastiveReward()
|
| 156 |
+
|
| 157 |
+
def test_delta_computed_correctly(self):
|
| 158 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.7)
|
| 159 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.5)
|
| 160 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 161 |
+
assert abs(result.delta - 0.2) < 1e-9
|
| 162 |
+
|
| 163 |
+
def test_winning_trajectory_is_a_when_a_higher(self):
|
| 164 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.8)
|
| 165 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.5)
|
| 166 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 167 |
+
assert result.winning_trajectory == "A"
|
| 168 |
+
assert result.winning_trajectory_type == TrajectoryType.CRITIC_FIRST
|
| 169 |
+
|
| 170 |
+
def test_winning_trajectory_is_b_when_b_higher(self):
|
| 171 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.4)
|
| 172 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.7)
|
| 173 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 174 |
+
assert result.winning_trajectory == "B"
|
| 175 |
+
assert result.winning_trajectory_type == TrajectoryType.DEFENDER_FIRST
|
| 176 |
+
|
| 177 |
+
def test_tie_when_delta_is_zero(self):
|
| 178 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.6)
|
| 179 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.6)
|
| 180 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 181 |
+
assert result.winning_trajectory == "tie"
|
| 182 |
+
|
| 183 |
+
def test_contrast_bonus_near_zero_when_delta_small(self):
|
| 184 |
+
# delta = 0.01 → tanh(0.01 * 3) * 0.2 ≈ 0.006 — near zero
|
| 185 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.51)
|
| 186 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.50)
|
| 187 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 188 |
+
assert abs(result.contrast_bonus) < 0.02, (
|
| 189 |
+
f"contrast_bonus should be near 0 for delta=0.01, got {result.contrast_bonus}"
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def test_contrast_bonus_positive_when_delta_large(self):
|
| 193 |
+
# delta = 0.3 → tanh(0.9) * 0.2 ≈ 0.156 — clearly positive
|
| 194 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.7)
|
| 195 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.4)
|
| 196 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 197 |
+
assert result.contrast_bonus > 0.1, (
|
| 198 |
+
f"contrast_bonus should be > 0.1 for delta=0.3, got {result.contrast_bonus}"
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
def test_final_reward_clipped_to_0_1_upper(self):
|
| 202 |
+
# Very high cumulative rewards should still clip to 1.0
|
| 203 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 5.0)
|
| 204 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.1)
|
| 205 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 206 |
+
assert result.final_reward <= 1.0
|
| 207 |
+
|
| 208 |
+
def test_final_reward_clipped_to_0_1_lower(self):
|
| 209 |
+
# Negative cumulative rewards should clip to 0.0
|
| 210 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, -1.0)
|
| 211 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, -2.0)
|
| 212 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 213 |
+
assert result.final_reward >= 0.0
|
| 214 |
+
|
| 215 |
+
def test_final_reward_always_in_0_1(self):
|
| 216 |
+
for cum_a, cum_b in [(0.3, 0.3), (0.9, 0.1), (0.0, 0.0), (0.5, 0.5), (1.0, 0.0)]:
|
| 217 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, cum_a)
|
| 218 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, cum_b)
|
| 219 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 220 |
+
assert 0.0 <= result.final_reward <= 1.0, (
|
| 221 |
+
f"final_reward={result.final_reward} out of [0,1] for "
|
| 222 |
+
f"cum_a={cum_a}, cum_b={cum_b}"
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
def test_base_reward_is_max(self):
|
| 226 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.7)
|
| 227 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.5)
|
| 228 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 229 |
+
assert abs(result.base_reward - 0.7) < 1e-9
|
| 230 |
+
|
| 231 |
+
def test_result_is_contrastive_reward_result_instance(self):
|
| 232 |
+
traj_a = _make_trajectory(TrajectoryType.CRITIC_FIRST, 0.6)
|
| 233 |
+
traj_b = _make_trajectory(TrajectoryType.DEFENDER_FIRST, 0.4)
|
| 234 |
+
result = self.cr.compute(traj_a, traj_b)
|
| 235 |
+
assert isinstance(result, ContrastiveRewardResult)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ---------------------------------------------------------------------------
|
| 239 |
+
# ABScriptEnv — integration tests using mocked env.step() and reset()
|
| 240 |
+
# ---------------------------------------------------------------------------
|
| 241 |
+
|
| 242 |
+
def _fake_obs(script: str = "Test script.", reward: float = 0.5) -> dict:
|
| 243 |
+
return {
|
| 244 |
+
"current_script": script,
|
| 245 |
+
"original_script": script,
|
| 246 |
+
"region": "pan_india_english",
|
| 247 |
+
"platform": "Reels",
|
| 248 |
+
"niche": "personal finance",
|
| 249 |
+
"step_num": 1,
|
| 250 |
+
"max_steps": 3,
|
| 251 |
+
"debate_history": [],
|
| 252 |
+
"reward_components": {"r1_hook_strength": reward, "total": reward},
|
| 253 |
+
"difficulty_level": "easy",
|
| 254 |
+
"episode_id": "ep_test",
|
| 255 |
+
"current_moderation_flags": [],
|
| 256 |
+
"current_originality_flags": [],
|
| 257 |
+
"creator_profile": None,
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def _fake_step_result(script: str = "Test script.", reward: float = 0.5, done: bool = False):
|
| 262 |
+
obs = _fake_obs(script, reward)
|
| 263 |
+
info = {"reward_components": {"r1_hook_strength": reward, "total": reward}}
|
| 264 |
+
return obs, reward, done, False, info
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def _make_real_critique(claim_id="C1", severity="high", critique_class="hook_weakness"):
|
| 268 |
+
"""Return a MagicMock with real CritiqueClaim objects so pydantic validation passes."""
|
| 269 |
+
from viral_script_engine.agents.critic import CritiqueClaim, CritiqueOutput
|
| 270 |
+
claim = CritiqueClaim(
|
| 271 |
+
claim_id=claim_id,
|
| 272 |
+
critique_class=critique_class,
|
| 273 |
+
claim_text=f"Test {critique_class} claim",
|
| 274 |
+
timestamp_range="0:00-0:05",
|
| 275 |
+
evidence="test evidence here",
|
| 276 |
+
is_falsifiable=True,
|
| 277 |
+
severity=severity,
|
| 278 |
+
)
|
| 279 |
+
mock_crit = MagicMock()
|
| 280 |
+
mock_crit.claims = [claim]
|
| 281 |
+
mock_crit.overall_severity = severity
|
| 282 |
+
return mock_crit
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def _make_real_defender(core_quote="hook content here"):
|
| 286 |
+
from viral_script_engine.agents.defender import DefenderOutput
|
| 287 |
+
return DefenderOutput(
|
| 288 |
+
core_strength="Strong hook",
|
| 289 |
+
core_strength_quote=core_quote,
|
| 290 |
+
defense_argument="Preserve this element.",
|
| 291 |
+
flagged_critic_claims=[],
|
| 292 |
+
regional_voice_elements=[],
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class TestABScriptEnvMocked:
|
| 297 |
+
"""Test ABScriptEnv behaviour with env.step() mocked at the env level."""
|
| 298 |
+
|
| 299 |
+
def _make_ab_env(self):
|
| 300 |
+
from viral_script_engine.environment.ab_env import ABScriptEnv
|
| 301 |
+
return ABScriptEnv(
|
| 302 |
+
scripts_path=_SCRIPTS_PATH,
|
| 303 |
+
cultural_kb_path=_CULTURAL_KB_PATH,
|
| 304 |
+
max_steps=3,
|
| 305 |
+
difficulty="easy",
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
def _reset_with_mocks(self, ab_env, core_quote="body content deep here", seed=42):
|
| 309 |
+
"""
|
| 310 |
+
Reset ab_env with mocked critic, defender, and step calls.
|
| 311 |
+
Uses real CritiqueClaim/DefenderOutput to pass pydantic validation.
|
| 312 |
+
Returns the state dict.
|
| 313 |
+
"""
|
| 314 |
+
mock_critique = _make_real_critique("C1", "high", "hook_weakness")
|
| 315 |
+
mock_defender = _make_real_defender(core_quote)
|
| 316 |
+
|
| 317 |
+
with patch.object(ab_env.env_a.critic, "critique", return_value=mock_critique), \
|
| 318 |
+
patch.object(ab_env.env_a.defender, "defend", return_value=mock_defender), \
|
| 319 |
+
patch.object(ab_env.env_b.step, "__call__", side_effect=None) if False else \
|
| 320 |
+
patch.object(ab_env.env_a, "step",
|
| 321 |
+
side_effect=lambda action, **kw: _fake_step_result("Script A.", 0.65)), \
|
| 322 |
+
patch.object(ab_env.env_b, "step",
|
| 323 |
+
side_effect=lambda action, **kw: _fake_step_result("Script B.", 0.55)):
|
| 324 |
+
state = ab_env.reset(seed=seed)
|
| 325 |
+
return state
|
| 326 |
+
|
| 327 |
+
def test_reset_gives_both_trajectory_states(self):
|
| 328 |
+
ab_env = self._make_ab_env()
|
| 329 |
+
state = self._reset_with_mocks(ab_env)
|
| 330 |
+
|
| 331 |
+
assert "trajectory_a" in state
|
| 332 |
+
assert "trajectory_b" in state
|
| 333 |
+
assert "delta" in state
|
| 334 |
+
assert "leading_trajectory" in state
|
| 335 |
+
assert "episode_id" in state
|
| 336 |
+
|
| 337 |
+
def test_both_envs_start_from_same_script(self):
|
| 338 |
+
ab_env = self._make_ab_env()
|
| 339 |
+
self._reset_with_mocks(ab_env, seed=42)
|
| 340 |
+
|
| 341 |
+
# Both trajectories must share the same initial_script (same reset seed)
|
| 342 |
+
assert ab_env._traj_a.initial_script == ab_env._traj_b.initial_script
|
| 343 |
+
|
| 344 |
+
def test_step_1_forced_actions_differ(self):
|
| 345 |
+
"""
|
| 346 |
+
Traj A (critic_first, hook_weakness claim) → hook_rewrite.
|
| 347 |
+
Traj B (defender_first, core in hook) → cta_placement.
|
| 348 |
+
"""
|
| 349 |
+
import json as _json
|
| 350 |
+
scripts = _json.loads(open(_SCRIPTS_PATH).read())
|
| 351 |
+
easy_script = next(s for s in scripts if s["script_id"] == "S01")
|
| 352 |
+
# Use first 30 chars of the real script as the "core quote" so it appears in the hook
|
| 353 |
+
hook_text = easy_script["script_text"][:30]
|
| 354 |
+
|
| 355 |
+
ab_env = self._make_ab_env()
|
| 356 |
+
mock_critique = _make_real_critique("C1", "high", "hook_weakness")
|
| 357 |
+
mock_defender = _make_real_defender(core_quote=hook_text)
|
| 358 |
+
|
| 359 |
+
with patch.object(ab_env.env_a.critic, "critique", return_value=mock_critique), \
|
| 360 |
+
patch.object(ab_env.env_a.defender, "defend", return_value=mock_defender), \
|
| 361 |
+
patch.object(ab_env.env_a, "step",
|
| 362 |
+
side_effect=lambda action, **kw: _fake_step_result()), \
|
| 363 |
+
patch.object(ab_env.env_b, "step",
|
| 364 |
+
side_effect=lambda action, **kw: _fake_step_result()):
|
| 365 |
+
ab_env.reset(seed=42)
|
| 366 |
+
|
| 367 |
+
action_a = ab_env._forced_action_a.get("action_type")
|
| 368 |
+
action_b = ab_env._forced_action_b.get("action_type")
|
| 369 |
+
assert action_a == "hook_rewrite", (
|
| 370 |
+
f"CRITIC_FIRST: expected hook_rewrite, got {action_a}"
|
| 371 |
+
)
|
| 372 |
+
assert action_b == "cta_placement", (
|
| 373 |
+
f"DEFENDER_FIRST (core in hook): expected cta_placement, got {action_b}"
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
def test_step_applies_same_action_to_both(self):
|
| 377 |
+
ab_env = self._make_ab_env()
|
| 378 |
+
self._reset_with_mocks(ab_env)
|
| 379 |
+
|
| 380 |
+
step_calls_a: list = []
|
| 381 |
+
step_calls_b: list = []
|
| 382 |
+
|
| 383 |
+
def track_a(action, **kw):
|
| 384 |
+
step_calls_a.append(action)
|
| 385 |
+
return _fake_step_result("A after free step", 0.7, done=True)
|
| 386 |
+
|
| 387 |
+
def track_b(action, **kw):
|
| 388 |
+
step_calls_b.append(action)
|
| 389 |
+
return _fake_step_result("B after free step", 0.6, done=True)
|
| 390 |
+
|
| 391 |
+
free_action = {
|
| 392 |
+
"action_type": "cta_placement",
|
| 393 |
+
"target_section": "cta",
|
| 394 |
+
"instruction": "Move CTA to end.",
|
| 395 |
+
"critique_claim_id": "C1",
|
| 396 |
+
"reasoning": "test",
|
| 397 |
+
}
|
| 398 |
+
with patch.object(ab_env.env_a, "step", side_effect=track_a), \
|
| 399 |
+
patch.object(ab_env.env_b, "step", side_effect=track_b):
|
| 400 |
+
ab_env.step(free_action)
|
| 401 |
+
|
| 402 |
+
assert len(step_calls_a) == 1
|
| 403 |
+
assert len(step_calls_b) == 1
|
| 404 |
+
assert step_calls_a[0]["action_type"] == step_calls_b[0]["action_type"] == "cta_placement"
|
| 405 |
+
|
| 406 |
+
def test_state_returns_correct_delta(self):
|
| 407 |
+
ab_env = self._make_ab_env()
|
| 408 |
+
self._reset_with_mocks(ab_env)
|
| 409 |
+
|
| 410 |
+
# Manually set cumulative rewards to known values
|
| 411 |
+
ab_env._traj_a.cumulative_reward = 0.7
|
| 412 |
+
ab_env._traj_b.cumulative_reward = 0.5
|
| 413 |
+
|
| 414 |
+
state = ab_env.state()
|
| 415 |
+
assert abs(state["delta"] - 0.2) < 1e-9
|
| 416 |
+
assert state["leading_trajectory"] == "A"
|
| 417 |
+
assert "trajectory_a" in state
|
| 418 |
+
assert "trajectory_b" in state
|
viral_script_engine/training/rollout_function.py
CHANGED
|
@@ -272,3 +272,94 @@ def _config_to_prompt(config: dict) -> str:
|
|
| 272 |
f"CURRICULUM NOTES: {config.get('curriculum_notes', '')}\n\n"
|
| 273 |
"Choose your action:\n<|end|>"
|
| 274 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
f"CURRICULUM NOTES: {config.get('curriculum_notes', '')}\n\n"
|
| 273 |
"Choose your action:\n<|end|>"
|
| 274 |
)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# ---------------------------------------------------------------------------
|
| 278 |
+
# Phase 10 — A/B rollout function
|
| 279 |
+
# ---------------------------------------------------------------------------
|
| 280 |
+
|
| 281 |
+
def _format_ab_observation_prompt(state: dict, max_steps: int) -> str:
|
| 282 |
+
"""Format the A/B observation for the Arbitrator prompt."""
|
| 283 |
+
traj_a = state.get("trajectory_a", {})
|
| 284 |
+
traj_b = state.get("trajectory_b", {})
|
| 285 |
+
delta = state.get("delta", 0.0)
|
| 286 |
+
step_num = state.get("step_num", 1)
|
| 287 |
+
|
| 288 |
+
def _rc_summary(rc: dict) -> str:
|
| 289 |
+
return (
|
| 290 |
+
f"R1={rc.get('r1_hook_strength') or 0.0:.2f} "
|
| 291 |
+
f"R2={rc.get('r2_coherence') or 0.0:.2f} "
|
| 292 |
+
f"R3={rc.get('r3_cultural_alignment') or 0.0:.2f} "
|
| 293 |
+
f"Total={rc.get('total') or 0.0:.2f}"
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
rc_a = traj_a.get("reward_components", {})
|
| 297 |
+
rc_b = traj_b.get("reward_components", {})
|
| 298 |
+
|
| 299 |
+
return (
|
| 300 |
+
f"<|system|>\n{ARBITRATOR_SYSTEM}\n<|end|>\n\n"
|
| 301 |
+
f"<|user|>\n"
|
| 302 |
+
f"TRAJECTORY A (Critic-first approach):\n"
|
| 303 |
+
f"Current script: {traj_a.get('current_script', '')}\n"
|
| 304 |
+
f"Rewards so far: {_rc_summary(rc_a)} Cumulative={traj_a.get('cumulative_reward', 0.0):.3f}\n\n"
|
| 305 |
+
f"TRAJECTORY B (Defender-first approach):\n"
|
| 306 |
+
f"Current script: {traj_b.get('current_script', '')}\n"
|
| 307 |
+
f"Rewards so far: {_rc_summary(rc_b)} Cumulative={traj_b.get('cumulative_reward', 0.0):.3f}\n\n"
|
| 308 |
+
f"Delta (A - B): {delta:.3f}\n"
|
| 309 |
+
f"Step: {step_num}/{max_steps}\n\n"
|
| 310 |
+
"Choose your next action (applied to BOTH trajectories):\n<|end|>"
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def build_ab_rollout_fn(
|
| 315 |
+
ab_env,
|
| 316 |
+
max_steps: int = 5,
|
| 317 |
+
max_new_tokens: int = 256,
|
| 318 |
+
):
|
| 319 |
+
"""
|
| 320 |
+
Rollout function for the A/B environment.
|
| 321 |
+
|
| 322 |
+
The prompt includes both trajectory states so the Arbitrator can see
|
| 323 |
+
how the two paths diverge and learn which starting action leads to
|
| 324 |
+
better cumulative outcomes.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
def rollout_fn(
|
| 328 |
+
prompts: List[str],
|
| 329 |
+
model,
|
| 330 |
+
tokenizer,
|
| 331 |
+
) -> Tuple[List[str], List[float]]:
|
| 332 |
+
completions: List[str] = []
|
| 333 |
+
rewards: List[float] = []
|
| 334 |
+
|
| 335 |
+
for prompt in prompts:
|
| 336 |
+
state = ab_env.reset()
|
| 337 |
+
episode_parts: List[str] = []
|
| 338 |
+
episode_reward = 0.0
|
| 339 |
+
terminated = False
|
| 340 |
+
|
| 341 |
+
for step in range(max_steps - 1): # step 1 is forced; free steps = max_steps-1
|
| 342 |
+
obs_prompt = _format_ab_observation_prompt(state, max_steps)
|
| 343 |
+
full_prompt = prompt + "\n\n" + obs_prompt
|
| 344 |
+
|
| 345 |
+
raw_output = _model_generate(model, tokenizer, full_prompt, max_new_tokens)
|
| 346 |
+
action = _extract_json_action(raw_output)
|
| 347 |
+
episode_parts.append(raw_output)
|
| 348 |
+
|
| 349 |
+
try:
|
| 350 |
+
state, episode_reward, terminated, _, _ = ab_env.step(action)
|
| 351 |
+
except Exception:
|
| 352 |
+
terminated = True
|
| 353 |
+
|
| 354 |
+
if terminated:
|
| 355 |
+
break
|
| 356 |
+
|
| 357 |
+
if not terminated:
|
| 358 |
+
episode_reward = ab_env.reward()
|
| 359 |
+
|
| 360 |
+
completions.append("\n".join(episode_parts))
|
| 361 |
+
rewards.append(episode_reward)
|
| 362 |
+
|
| 363 |
+
return completions, rewards
|
| 364 |
+
|
| 365 |
+
return rollout_fn
|