Spaces:
Sleeping
feat(phase1): OpenEnv scaffold + R1/R2 rewards — PHASE 1 GATE PASS
Browse files
- environment/actions.py: ActionType enum + ArbitratorAction model
- environment/observations.py: RewardComponents (weighted, normalised), DebateRound, Observation
- environment/episode_state.py: mutable episode state dataclass
- environment/env.py: ViralScriptEnv — Gymnasium-compatible reset/step/state, difficulty tiers, anti-gaming wired
- rewards/r1_hook_strength.py: 5-check rule-based hook scorer (promise, curiosity, specificity, front-load, anti-filler)
- rewards/r2_coherence.py: sentence-transformers cosine similarity with 4-range score mapping + embedding cache
- rewards/reward_aggregator.py: catastrophic-drop (>0.2) + action-diversity anti-gaming rules
- agents/rewriter.py: RewriterAgent wrapping LLMBackend with unified diff output
- scripts/run_dummy_episode.py: demo runner with rich output and gate check
- tests/test_rewards.py: 10 tests — R1 (5), R2 (2), aggregator (3)
- tests/test_environment.py: 6 tests — all LLM calls mocked via golden fixtures
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- viral_script_engine/agents/rewriter.py +44 -0
- viral_script_engine/environment/__init__.py +0 -0
- viral_script_engine/environment/actions.py +17 -0
- viral_script_engine/environment/env.py +151 -0
- viral_script_engine/environment/episode_state.py +48 -0
- viral_script_engine/environment/observations.py +60 -0
- viral_script_engine/rewards/__init__.py +0 -0
- viral_script_engine/rewards/base.py +7 -0
- viral_script_engine/rewards/r1_hook_strength.py +107 -0
- viral_script_engine/rewards/r2_coherence.py +48 -0
- viral_script_engine/rewards/reward_aggregator.py +42 -0
- viral_script_engine/scripts/run_dummy_episode.py +149 -0
- viral_script_engine/tests/test_environment.py +107 -0
- viral_script_engine/tests/test_rewards.py +117 -0
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import difflib
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
from viral_script_engine.agents.llm_backend import LLMBackend
|
| 5 |
+
from viral_script_engine.environment.actions import ArbitratorAction
|
| 6 |
+
|
| 7 |
+
_SYSTEM_PROMPT = (
|
| 8 |
+
"You are a professional script editor for short-form social media video. "
|
| 9 |
+
"Apply ONLY the instruction given. Do not make any other changes. "
|
| 10 |
+
"Do not add new ideas. Do not change the creator's voice or regional language patterns. "
|
| 11 |
+
"Return ONLY the rewritten script text, no commentary."
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RewriteResult(BaseModel):
    """Outcome of a single rewrite call, as built by ``RewriterAgent.rewrite``."""

    # Script text exactly as returned by the LLM backend.
    rewritten_script: str
    # Unified diff from the input script to the rewritten script, as one string.
    diff: str
    # Whitespace-delimited word count of the rewrite minus that of the input.
    word_count_delta: int
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RewriterAgent:
    """Applies one arbitrator instruction to a script through an LLM backend."""

    def __init__(self, backend: str = "groq", model_name: str = "llama-3.3-70b-versatile"):
        """Create the agent.

        Args:
            backend: Backend identifier forwarded to ``LLMBackend``.
            model_name: Model identifier forwarded to ``LLMBackend``.
        """
        self.llm = LLMBackend(backend=backend, model_name=model_name)

    @staticmethod
    def _build_diff(original: str, rewritten: str) -> str:
        """Return a unified diff of two texts as a newline-terminated string.

        Lines are compared without their terminators and re-joined with
        ``"\\n"``. Feeding ``splitlines(keepends=True)`` into ``unified_diff``
        and joining with ``""`` fuses adjacent diff lines (``-old+new``)
        whenever the last line of a text has no trailing newline, which is
        the common case for LLM output.

        Returns:
            The diff text, or ``""`` when the two texts have identical lines.
        """
        diff_lines = list(difflib.unified_diff(
            original.splitlines(),
            rewritten.splitlines(),
            fromfile="original",
            tofile="rewritten",
            lineterm="",
        ))
        if not diff_lines:
            return ""
        return "\n".join(diff_lines) + "\n"

    def rewrite(self, current_script: str, action: ArbitratorAction) -> RewriteResult:
        """Ask the LLM to apply ``action`` to ``current_script``.

        Args:
            current_script: The script text to edit.
            action: The arbitrator's decision; its action type, target
                section and instruction are embedded in the prompt.

        Returns:
            A ``RewriteResult`` holding the raw LLM output, a unified diff
            against ``current_script`` and the change in word count.
        """
        user_prompt = (
            f"CURRENT SCRIPT:\n{current_script}\n\n"
            f"ACTION TYPE: {action.action_type.value}\n"
            f"TARGET SECTION: {action.target_section}\n"
            f"INSTRUCTION: {action.instruction}\n\n"
            "Apply the instruction and return ONLY the rewritten script."
        )
        rewritten = self.llm.generate(_SYSTEM_PROMPT, user_prompt, max_tokens=2048)
        return RewriteResult(
            rewritten_script=rewritten,
            diff=self._build_diff(current_script, rewritten),
            word_count_delta=len(rewritten.split()) - len(current_script.split()),
        )
|
|
File without changes
|
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from pydantic import BaseModel
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class ActionType(str, Enum):
    """Closed set of edit operations an arbitrator action can name.

    The ``str`` mixin makes every member compare equal to its string value,
    so plain strings such as ``"hook_rewrite"`` identify a member.
    """

    HOOK_REWRITE = "hook_rewrite"
    SECTION_REORDER = "section_reorder"
    CULTURAL_REF_SUB = "cultural_ref_sub"
    CTA_PLACEMENT = "cta_placement"
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ArbitratorAction(BaseModel):
    """One arbitrator decision: which edit to apply, where, and why."""

    action_type: ActionType
    target_section: str  # "hook" | "body" | "cta" | "full"
    instruction: str  # natural language instruction to the Rewriter
    critique_claim_id: str  # which CritiqueClaim this responds to, e.g. "C2"
    reasoning: str  # why this action was chosen
|
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import random
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.agents.critic import CriticAgent
|
| 6 |
+
from viral_script_engine.agents.rewriter import RewriterAgent
|
| 7 |
+
from viral_script_engine.environment.actions import ArbitratorAction
|
| 8 |
+
from viral_script_engine.environment.episode_state import EpisodeState
|
| 9 |
+
from viral_script_engine.environment.observations import (
|
| 10 |
+
DebateRound, Observation, RewardComponents,
|
| 11 |
+
)
|
| 12 |
+
from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
|
| 13 |
+
from viral_script_engine.rewards.r2_coherence import CoherenceReward
|
| 14 |
+
from viral_script_engine.rewards.reward_aggregator import RewardAggregator
|
| 15 |
+
|
| 16 |
+
_TIERS = {
|
| 17 |
+
"easy": ["S01", "S02", "S03", "S04"],
|
| 18 |
+
"medium": ["S05", "S06", "S07"],
|
| 19 |
+
"hard": ["S08", "S09", "S10"],
|
| 20 |
+
"self_generated": [],
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ViralScriptEnv:
    """Script-editing environment exposing Gymnasium-style ``reset``/``step``.

    An episode works on one script drawn from the configured difficulty tier.
    Each step validates an arbitrator action, asks the critic for claims about
    the current script, lets the rewriter apply the action, and scores the
    result with the R1 (hook strength) and R2 (coherence) rewards.
    """

    def __init__(
        self,
        scripts_path: str = "data/test_scripts/scripts.json",
        max_steps: int = 5,
        difficulty: str = "easy",
        use_anti_gaming: bool = True,
    ):
        """Load the script pool and build the agents and reward functions.

        Args:
            scripts_path: Path to a JSON file containing a list of script
                records; each record is filtered by its ``"script_id"``.
            max_steps: Step budget stored on every new episode.
            difficulty: Key into ``_TIERS`` selecting which script ids are used.
            use_anti_gaming: When true, ``step`` routes rewards through
                ``RewardAggregator.compute``; otherwise only the plain
                weighted total is computed.

        Raises:
            KeyError: If ``difficulty`` is not a key of ``_TIERS``.
            OSError: If ``scripts_path`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        self.max_steps = max_steps
        self.difficulty = difficulty
        self.use_anti_gaming = use_anti_gaming

        # JSON interchange text is UTF-8; do not rely on the platform default
        # encoding, which breaks on non-ASCII script text on some systems.
        with open(scripts_path, encoding="utf-8") as f:
            all_scripts = json.load(f)

        tier_ids = _TIERS[difficulty]
        self._scripts = [s for s in all_scripts if s["script_id"] in tier_ids]
        self.critic = CriticAgent()
        self.rewriter = RewriterAgent()
        self.r1 = HookStrengthReward()
        self.r2 = CoherenceReward()
        self.aggregator = RewardAggregator()
        self._state: Optional[EpisodeState] = None

    def reset(self, seed=None, options=None) -> Tuple[dict, dict]:
        """Start a new episode on a randomly chosen script.

        Args:
            seed: When not ``None``, seeds the module-level ``random``
                generator before the script is chosen.
            options: Accepted for signature compatibility; not used.

        Returns:
            ``(observation, info)`` where the observation is the dumped
            ``Observation`` model and ``info`` is an empty dict.

        Raises:
            IndexError: If the selected tier contains no scripts.
        """
        if seed is not None:
            random.seed(seed)
        if not self._scripts:
            # random.choice raises a bare IndexError on an empty pool; keep the
            # exception type but say which tier is empty.
            raise IndexError(
                f"No scripts available for difficulty {self.difficulty!r}; "
                "cannot start an episode"
            )
        script = random.choice(self._scripts)

        # Baseline rewards: R2 compares the untouched script with itself.
        r1_result = self.r1.score(script["script_text"])
        r2_result = self.r2.score(script["script_text"], script["script_text"])
        initial_rewards = RewardComponents(
            r1_hook_strength=r1_result.score,
            r2_coherence=r2_result.score,
        )
        initial_rewards.compute_total()

        self._state = EpisodeState.new(
            script=script,
            max_steps=self.max_steps,
            difficulty_level=self.difficulty,
            initial_rewards=initial_rewards,
        )
        return self._build_observation().model_dump(), {}

    def step(self, action: dict) -> Tuple[dict, float, bool, bool, dict]:
        """Apply one arbitrator action and score the rewritten script.

        Args:
            action: Mapping of ``ArbitratorAction`` fields; validated by
                constructing the model.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The reward
            is the total of the step's reward components. ``terminated`` is
            true once the step budget is used up or the total reaches 0.9;
            ``truncated`` is always ``False``. ``info`` carries the dumped
            reward components and the anti-gaming flags.

        Raises:
            RuntimeError: If called before ``reset``.
        """
        if self._state is None:
            raise RuntimeError("Call reset() before step()")

        arb_action = ArbitratorAction(**action)

        # The critique is recorded in the debate history; it is not fed to the
        # rewriter here.
        critique = self.critic.critique(
            script=self._state.current_script,
            region=self._state.region,
            platform=self._state.platform,
            niche=self._state.niche,
        )

        rewrite_result = self.rewriter.rewrite(self._state.current_script, arb_action)
        new_script = rewrite_result.rewritten_script

        # Coherence is always measured against the episode's original script,
        # not against the previous step's script.
        r1_result = self.r1.score(new_script)
        r2_result = self.r2.score(self._state.original_script, new_script)
        components = RewardComponents(
            r1_hook_strength=r1_result.score,
            r2_coherence=r2_result.score,
        )

        # Record the action first so the diversity rule sees the current one.
        self._state.action_history.append(arb_action.action_type)
        if self.use_anti_gaming:
            components = self.aggregator.compute(
                components, self._state.episode_start_rewards, self._state.action_history
            )
        else:
            components.compute_total()

        round_ = DebateRound(
            step_num=self._state.step_num,
            critic_claims=critique.claims,
            arbitrator_action=arb_action,
            rewrite_diff=rewrite_result.diff,
            reward_components=components,
        )
        self._state.debate_history.append(round_)
        self._state.current_script = new_script
        self._state.last_reward_components = components
        self._state.step_num += 1

        terminated = (
            self._state.step_num >= self._state.max_steps
            or components.total >= 0.9
        )
        info = {
            "reward_components": components.model_dump(),
            "anti_gaming_triggered": components.anti_gaming_penalty > 0,
            "penalty_reason": "anti_gaming" if components.anti_gaming_penalty > 0 else None,
        }
        return self._build_observation().model_dump(), components.total, terminated, False, info

    def state(self) -> dict:
        """Return a plain-dict snapshot of the episode, or ``{}`` before ``reset``."""
        if self._state is None:
            return {}
        s = self._state
        return {
            "current_script": s.current_script,
            "original_script": s.original_script,
            "debate_history": [r.model_dump() for r in s.debate_history],
            "reward_components": s.last_reward_components.model_dump(),
            "step_num": s.step_num,
            "difficulty_level": s.difficulty_level,
            "episode_id": s.episode_id,
        }

    def _build_observation(self) -> Observation:
        """Copy the current episode state into an ``Observation`` model."""
        s = self._state
        return Observation(
            current_script=s.current_script,
            original_script=s.original_script,
            region=s.region,
            platform=s.platform,
            niche=s.niche,
            step_num=s.step_num,
            max_steps=s.max_steps,
            debate_history=s.debate_history,
            reward_components=s.last_reward_components,
            difficulty_level=s.difficulty_level,
            episode_id=s.episode_id,
        )
|
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import uuid
|
| 3 |
+
from dataclasses import dataclass, field
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
from viral_script_engine.environment.actions import ActionType
|
| 7 |
+
from viral_script_engine.environment.observations import DebateRound, RewardComponents
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class EpisodeState:
    """Mutable per-episode record kept by the environment between steps."""

    # Unique identifier of the episode (a UUID4 string when built via ``new``).
    episode_id: str
    # Script text at episode start.
    original_script: str
    # Script text after the most recent rewrite.
    current_script: str
    region: str
    platform: str
    niche: str
    # Number of steps taken so far.
    step_num: int
    # Step budget for the episode.
    max_steps: int
    debate_history: List[DebateRound]
    # Rewards of the untouched script.
    episode_start_rewards: RewardComponents
    # Rewards after the most recent step (the start rewards until then).
    last_reward_components: RewardComponents
    difficulty_level: str
    # Action types taken so far, oldest first.
    action_history: List[ActionType]

    @classmethod
    def new(
        cls,
        script: dict,
        max_steps: int,
        difficulty_level: str,
        initial_rewards: RewardComponents,
    ) -> EpisodeState:
        """Build the state for a fresh episode from one script record.

        Args:
            script: Record providing ``"script_text"``, ``"region"``,
                ``"platform"`` and ``"niche"``.
            max_steps: Step budget for the episode.
            difficulty_level: Tier label stored on the state.
            initial_rewards: Rewards of the untouched script; stored as both
                the episode-start and the last-seen rewards.

        Returns:
            A state at step 0 with empty histories and a new UUID4 id.
        """
        new_id = str(uuid.uuid4())
        text = script["script_text"]
        context = {name: script[name] for name in ("region", "platform", "niche")}
        return cls(
            episode_id=new_id,
            original_script=text,
            current_script=text,
            step_num=0,
            max_steps=max_steps,
            debate_history=[],
            episode_start_rewards=initial_rewards,
            last_reward_components=initial_rewards,
            difficulty_level=difficulty_level,
            action_history=[],
            **context,
        )
|
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
from typing import Any, Dict, List, Optional
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.agents.critic import CritiqueClaim
|
| 6 |
+
from viral_script_engine.environment.actions import ArbitratorAction
|
| 7 |
+
|
| 8 |
+
_WEIGHTS: Dict[str, float] = {
|
| 9 |
+
"r1": 0.25, "r2": 0.20, "r3": 0.20, "r4": 0.20, "r5": 0.15
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class RewardComponents(BaseModel):
    """Per-step reward signals plus their penalised, clamped aggregate.

    A component left as ``None`` is treated as not measured and is excluded
    from the weighted average.
    """

    r1_hook_strength: Optional[float] = None
    r2_coherence: Optional[float] = None
    r3_cultural_alignment: Optional[float] = None
    r4_debate_resolution: Optional[float] = None
    r5_defender_preservation: Optional[float] = None
    anti_gaming_penalty: float = 0.0
    total: float = 0.0

    def compute_total(self) -> float:
        """Recompute, store and return ``total``.

        The total is the weighted mean of the components that are set, with
        weights renormalised over just those components, minus
        ``anti_gaming_penalty``, clamped to ``[0.0, 1.0]``. With no component
        set the total is ``0.0``.
        """
        keyed = (
            ("r1", self.r1_hook_strength),
            ("r2", self.r2_coherence),
            ("r3", self.r3_cultural_alignment),
            ("r4", self.r4_debate_resolution),
            ("r5", self.r5_defender_preservation),
        )
        present = [(key, value) for key, value in keyed if value is not None]
        if not present:
            self.total = 0.0
            return 0.0

        # Renormalise so that missing components do not drag the mean down.
        weight_sum = sum(_WEIGHTS[key] for key, _ in present)
        weighted = sum(_WEIGHTS[key] * value for key, value in present) / weight_sum
        capped = min(1.0, weighted - self.anti_gaming_penalty)
        self.total = max(0.0, capped)
        return self.total
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class DebateRound(BaseModel):
    """Record of one environment step: critique, chosen action, diff and rewards."""

    # Step index at which this round was recorded.
    step_num: int
    # Claims returned by the critic for the script as it stood at this step.
    critic_claims: List[CritiqueClaim]
    # Optional; ViralScriptEnv.step does not set this field.
    defender_response: Optional[Any] = None
    arbitrator_action: Optional[ArbitratorAction] = None
    # Diff text produced by the rewriter for this step.
    rewrite_diff: Optional[str] = None
    reward_components: Optional[RewardComponents] = None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Observation(BaseModel):
    """Snapshot of the episode state that the environment returns to the caller."""

    current_script: str
    original_script: str
    region: str
    platform: str
    niche: str
    # Steps taken so far and the step budget of the episode.
    step_num: int
    max_steps: int
    # One entry per completed step, oldest first.
    debate_history: List[DebateRound]
    # Rewards after the most recent step (the initial rewards right after reset).
    reward_components: RewardComponents
    difficulty_level: str
    episode_id: str
|
|
File without changes
|
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BaseReward(ABC):
    """Abstract interface for reward functions; subclasses must implement ``score``."""

    @abstractmethod
    def score(self, *args, **kwargs):
        """Compute the reward; arguments and return type are defined by each subclass."""
        pass
|
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Dict
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.rewards.base import BaseReward
|
| 6 |
+
|
| 7 |
+
_DEAD_OPENERS = [
|
| 8 |
+
"hey guys", "welcome back", "today i want to", "so today",
|
| 9 |
+
"in this video", "what's up everyone", "hey everyone",
|
| 10 |
+
"guys today", "hello everyone", "so basically",
|
| 11 |
+
]
|
| 12 |
+
|
| 13 |
+
_COMMON_WORDS = {
|
| 14 |
+
'i', 'the', 'a', 'an', 'my', 'your', 'its', 'it', 'is', 'are',
|
| 15 |
+
'was', 'were', 'be', 'been', "i've", "i'm", "it's", "here's",
|
| 16 |
+
'today', 'and', 'but', 'so', 'that', 'this', 'these', 'those',
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class HookRewardResult:
    """Result of the rule-based hook scorer."""

    # Fraction of the five hook checks that passed, in [0.0, 1.0].
    score: float
    # Number of checks that passed.
    checks_passed: int
    # Outcome of each check, keyed by check name.
    check_details: Dict[str, bool] = field(default_factory=dict)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _extract_hook(text: str) -> str:
|
| 28 |
+
sentences = re.split(r'(?<=[.!?])\s+', text.strip())
|
| 29 |
+
hook = " ".join(sentences[:3]) if len(sentences) >= 3 else text
|
| 30 |
+
words = hook.split()
|
| 31 |
+
return " ".join(words[:50]) if len(words) > 50 else hook
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class HookStrengthReward(BaseReward):
    """Rule-based scorer for the opening ("hook") of a script.

    Five boolean checks are run on the hook; the score is the fraction passed.
    """

    def score(self, script: str) -> HookRewardResult:
        """Score the hook of ``script``.

        Args:
            script: Full script text; the hook is taken with ``_extract_hook``.

        Returns:
            A ``HookRewardResult`` with the fraction of checks passed, the
            count of checks passed and the per-check outcomes.
        """
        hook = _extract_hook(script)
        hook_lower = hook.lower()
        first_sentence = re.split(r'(?<=[.!?])\s+', hook.strip())[0].lower()

        checks = {
            "promise": self._check_promise(hook_lower),
            "curiosity": self._check_curiosity(hook_lower),
            # Specificity looks at capitalisation, so it needs the original case.
            "specificity": self._check_specificity(hook),
            "front_load": self._check_front_load(first_sentence),
            "anti_filler": self._check_anti_filler(hook_lower),
        }
        passed = sum(checks.values())
        return HookRewardResult(
            score=min(1.0, max(0.0, passed / 5)),
            checks_passed=passed,
            check_details=checks,
        )

    def _check_promise(self, hook: str) -> bool:
        """Return True if the lower-cased hook signals a concrete payoff.

        Fails outright when a generic greeting phrase appears anywhere in the
        hook; otherwise passes on a digit, "how to", "why",
        "what happens when" or "i made".
        """
        bad = ["hey guys", "welcome back", "today we're talking about"]
        if any(b in hook for b in bad):
            return False
        patterns = [
            r'\d',
            r'\bhow to\b',
            r'\bwhy\b',
            r'\bwhat happens when\b',
            r'\bi made\b',
        ]
        return any(re.search(p, hook) for p in patterns)

    def _check_curiosity(self, hook: str) -> bool:
        """Return True if the lower-cased hook opens a curiosity gap.

        Requires a question mark or one of the teaser phrases. The check is
        rejected when the first sentence is a question that also contains
        "is", "are", "was", "were", "means" or "equals".
        """
        patterns = [
            r'\?',
            r"but here'?s the thing",
            r"most \w+ don'?t know",
            r"the secret is",
            r"nobody tells you",
            r"most people don'?t",
        ]
        if not any(re.search(p, hook) for p in patterns):
            return False
        first = re.split(r'(?<=[.!?])\s+', hook)[0]
        if re.search(r'\?', first) and re.search(r'\b(is|are|was|were|means|equals)\b', first):
            return False
        return True

    def _check_specificity(self, hook: str) -> bool:
        """Return True if the original-case hook contains a specific detail.

        A digit counts, as does any capitalised word that is not the first
        word of its sentence and is not listed in ``_COMMON_WORDS``.
        """
        if re.search(r'\d', hook):
            return True
        sentences = re.split(r'(?<=[.!?])\s+', hook)
        for sentence in sentences:
            # Skip the first word: it is capitalised by sentence convention.
            words = sentence.split()[1:]
            for w in words:
                clean = w.strip('.,!?;:\'"')
                if clean and clean[0].isupper() and clean.lower() not in _COMMON_WORDS:
                    return True
        return False

    def _check_front_load(self, first_sentence: str) -> bool:
        """Return True if the lower-cased first sentence carries at least two signals.

        The three signals are: a digit, a promise phrase, and a question mark.
        """
        signals = 0
        if re.search(r'\d', first_sentence):
            signals += 1
        promise_patterns = [r'\bhow to\b', r'\bwhy\b', r'\bwhat happens when\b', r'\bi made\b']
        if any(re.search(p, first_sentence) for p in promise_patterns):
            signals += 1
        if re.search(r'\?', first_sentence):
            signals += 1
        return signals >= 2

    def _check_anti_filler(self, hook: str) -> bool:
        """Return True unless the lower-cased hook begins with a dead opener.

        Leading whitespace is ignored: for short scripts the hook is the raw
        script text, so a leading space or newline would otherwise let an
        opener such as "hey guys" pass unnoticed.
        """
        return not hook.lstrip().startswith(tuple(_DEAD_OPENERS))
|
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
from viral_script_engine.rewards.base import BaseReward
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class CoherenceRewardResult:
    """Result of the embedding-similarity coherence scorer."""

    # Reward derived from the similarity through the range mapping.
    score: float
    # Cosine similarity between the two embeddings.
    raw_similarity: float
    # Label of the similarity range the pair fell into.
    interpretation: str
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class CoherenceReward(BaseReward):
    """Scores how closely a rewritten script stays to the original in meaning.

    Texts are embedded with a sentence-transformers model that is loaded on
    first use; embeddings are cached on the class by SHA-256 of the text.
    """

    # Class-level cache, so it is shared by all instances.
    _cache: dict = {}

    def __init__(self):
        """Create the scorer without loading the embedding model."""
        self._model = None

    def _get_model(self):
        """Return the embedding model, loading it on the first call."""
        if self._model is not None:
            return self._model
        # Imported lazily so that constructing the scorer stays cheap.
        from sentence_transformers import SentenceTransformer
        self._model = SentenceTransformer("all-MiniLM-L6-v2")
        return self._model

    def _embed(self, text: str):
        """Return the cached embedding tensor for ``text``, computing it if needed."""
        digest = hashlib.sha256(text.encode()).hexdigest()
        if digest in self._cache:
            return self._cache[digest]
        self._cache[digest] = self._get_model().encode(text, convert_to_tensor=True)
        return self._cache[digest]

    def _cosine_sim(self, a, b) -> float:
        """Return the cosine similarity of two embeddings as a Python float."""
        from sentence_transformers.util import cos_sim
        return float(cos_sim(a, b)[0][0])

    def score(self, original: str, rewritten: str) -> CoherenceRewardResult:
        """Map the similarity of ``original`` and ``rewritten`` to a reward.

        Ranges of the cosine similarity ``sim``:
            * ``sim > 0.95``: flat 0.8, labelled "barely_changed";
            * ``0.80 <= sim <= 0.95``: linear from 0.5 up to 1.0;
            * ``0.65 <= sim < 0.80``: linear from 0.0 up to 0.5;
            * below 0.65: 0.0, labelled "drifted_too_far".
        """
        original_vec = self._embed(original)
        rewritten_vec = self._embed(rewritten)
        sim = self._cosine_sim(original_vec, rewritten_vec)

        if sim > 0.95:
            return CoherenceRewardResult(
                score=0.8, raw_similarity=sim, interpretation="barely_changed"
            )
        if sim >= 0.80:
            value = 0.5 + (sim - 0.80) / 0.15 * 0.5
            return CoherenceRewardResult(
                score=value, raw_similarity=sim, interpretation="good_coherence"
            )
        if sim >= 0.65:
            value = (sim - 0.65) / 0.15 * 0.5
            return CoherenceRewardResult(
                score=value, raw_similarity=sim, interpretation="moderate_drift"
            )
        return CoherenceRewardResult(
            score=0.0, raw_similarity=sim, interpretation="drifted_too_far"
        )
|
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from viral_script_engine.environment.actions import ActionType
|
| 5 |
+
from viral_script_engine.environment.observations import RewardComponents
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
_COMPONENT_FIELDS = [
|
| 10 |
+
"r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
|
| 11 |
+
"r4_debate_resolution", "r5_defender_preservation",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RewardAggregator:
    """Computes the total reward and applies the two anti-gaming rules."""

    def compute(
        self,
        components: RewardComponents,
        episode_start_components: RewardComponents,
        action_history: List[ActionType],
    ) -> RewardComponents:
        """Finalise ``components`` in place and return it.

        Rule 1 (catastrophic drop): if any component that is set both now and
        at episode start has fallen by more than 0.2, the total becomes 0.0
        and the penalty is set to the size of that drop. Only the first such
        component in ``_COMPONENT_FIELDS`` order is reported.

        Rule 2 (action diversity): if the last three actions are all of the
        same type, 0.15 is subtracted from the total.

        Args:
            components: Rewards of the current step; mutated.
            episode_start_components: Rewards of the untouched script.
            action_history: Action types so far, including the current one.
        """
        components.compute_total()

        # Rule 1: catastrophic drop relative to the episode start.
        for name in _COMPONENT_FIELDS:
            current = getattr(components, name)
            baseline = getattr(episode_start_components, name)
            if current is None or baseline is None:
                continue
            if current < baseline - 0.2:
                logger.warning("Catastrophic drop in %s: %.3f -> %.3f", name, baseline, current)
                components.total = 0.0
                components.anti_gaming_penalty = baseline - current
                return components

        # Rule 2: the same action type three times in a row.
        repeated = len(action_history) >= 3 and len(set(action_history[-3:])) == 1
        penalty = 0.15 if repeated else 0.0
        if repeated:
            logger.warning("Action diversity penalty: last 3 actions all %s", action_history[-1])

        components.anti_gaming_penalty = penalty
        components.total = max(0.0, min(1.0, components.total - penalty))
        return components
|
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import argparse
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from rich.console import Console
|
| 10 |
+
from rich.panel import Panel
|
| 11 |
+
from rich.table import Table
|
| 12 |
+
from rich import box
|
| 13 |
+
|
| 14 |
+
load_dotenv()
|
| 15 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 16 |
+
|
| 17 |
+
from viral_script_engine.environment.actions import ActionType
|
| 18 |
+
from viral_script_engine.environment.env import ViralScriptEnv
|
| 19 |
+
|
| 20 |
+
console = Console()
|
| 21 |
+
BASE_DIR = Path(__file__).parent.parent
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def build_random_action(action_type: ActionType) -> dict:
    """Return a canned arbitrator-action payload for ``action_type``.

    Each action type maps to a fixed target section and instruction; the
    claim id is always ``"C1"``. The result has the keys of ``ArbitratorAction``.
    """
    presets = {
        ActionType.HOOK_REWRITE: ("hook", "Rewrite the hook to open with a specific number or bold claim."),
        ActionType.SECTION_REORDER: ("body", "Move the strongest point to immediately follow the hook."),
        ActionType.CULTURAL_REF_SUB: ("full", "Replace any generic references with locally relevant ones."),
        ActionType.CTA_PLACEMENT: ("cta", "Move the call-to-action earlier, before the 80% mark."),
    }
    target, text = presets[action_type]
    payload = {"action_type": action_type.value}
    payload["target_section"] = target
    payload["instruction"] = text
    payload["critique_claim_id"] = "C1"
    payload["reasoning"] = f"Demo run: applying {action_type.value}"
    return payload
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_episode(difficulty: str, steps: int, verbose: bool) -> dict:
    """Run one demo episode with randomly chosen actions and return its log.

    Args:
        difficulty: Tier passed to ``ViralScriptEnv``.
        steps: Maximum number of steps; also used as the environment's budget.
        verbose: When true, print a per-step metrics table and the diff.

    Returns:
        A dict with the episode id, the difficulty, one entry per executed
        step and the environment's final state.
    """
    # NOTE(review): block nesting below was reconstructed from a flattened
    # listing; confirm that the diff panel belongs under ``if verbose``.
    scripts_path = str(BASE_DIR / "data" / "test_scripts" / "scripts.json")
    env = ViralScriptEnv(scripts_path=scripts_path, max_steps=steps, difficulty=difficulty)

    obs, _ = env.reset()
    console.print(Panel(
        f"[bold]Episode started[/bold]\n"
        f"Difficulty: {difficulty} | Max steps: {steps}\n"
        f"Region: {obs['region']} | Platform: {obs['platform']} | Niche: {obs['niche']}\n"
        f"Episode ID: {obs['episode_id']}",
        title="[bold blue]Phase 1 Demo Episode[/bold blue]",
        border_style="blue",
    ))

    episode_log = {
        "episode_id": obs["episode_id"],
        "difficulty": difficulty,
        "steps": [],
        "final_state": None,
    }

    for step_num in range(steps):
        # The demo policy picks an action type uniformly at random.
        action_type = random.choice(list(ActionType))
        action = build_random_action(action_type)

        obs, reward, terminated, truncated, info = env.step(action)
        rc = info["reward_components"]

        if verbose:
            t = Table(title=f"Step {step_num + 1} — {action_type.value}", box=box.SIMPLE_HEAD)
            t.add_column("Metric", style="cyan", min_width=22)
            t.add_column("Value", min_width=12)
            r1_val = rc.get("r1_hook_strength")
            r2_val = rc.get("r2_coherence")
            t.add_row("R1 Hook Strength", f"{r1_val:.3f}" if r1_val is not None else "N/A")
            t.add_row("R2 Coherence", f"{r2_val:.3f}" if r2_val is not None else "N/A")
            t.add_row("Total Reward", f"[bold]{reward:.3f}[/bold]")
            if info.get("anti_gaming_triggered"):
                t.add_row("Anti-Gaming Penalty", f"[red]{rc.get('anti_gaming_penalty', 0):.3f}[/red]")
                t.add_row("Penalty Reason", f"[red]{info.get('penalty_reason', '')}[/red]")
            t.add_row("Terminated", str(terminated))
            console.print(t)

            if obs.get("debate_history"):
                latest = obs["debate_history"][-1]
                if latest.get("rewrite_diff"):
                    # Only the first 600 characters of the diff are shown.
                    console.print(Panel(
                        latest["rewrite_diff"][:600] or "(no diff)",
                        title="Script Diff",
                        border_style="yellow",
                    ))

        episode_log["steps"].append({
            "step": step_num + 1,
            "action": action,
            "reward": reward,
            "reward_components": rc,
            "anti_gaming": info.get("anti_gaming_triggered", False),
            "terminated": terminated,
        })

        # Stop early once the environment reports the episode is over.
        if terminated:
            break

    final_state = env.state()
    episode_log["final_state"] = final_state
    final_rc = final_state["reward_components"]

    console.print(Panel(
        f"[bold green]Final Reward:[/bold green] {final_rc.get('total', 0):.3f}\n"
        f"R1 Hook Strength: {final_rc.get('r1_hook_strength', 'N/A')}\n"
        f"R2 Coherence: {final_rc.get('r2_coherence', 'N/A')}\n"
        f"Steps completed: {final_state['step_num']}",
        title="Episode Summary",
        border_style="green",
    ))

    return episode_log
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
    """Run one dummy episode from the command line and report the Phase 1 gate.

    Parses ``--difficulty``, ``--steps`` and ``--verbose``, runs the episode
    through ``run_episode``, dumps the returned episode log as JSON under
    ``BASE_DIR / "logs"`` and prints a PASS/FAIL panel.

    The gate passes when the final state's reward components contain both an
    R1 hook-strength value and an R2 coherence value, and the log file exists.

    Returns:
        int: ``0`` when the gate passes, ``1`` otherwise, so that a calling
        shell or CI job can detect a failed gate from the exit status.
    """
    parser = argparse.ArgumentParser(description="Run Phase 1 dummy episode")
    parser.add_argument("--difficulty", default="easy", choices=["easy", "medium", "hard"])
    parser.add_argument("--steps", type=int, default=3)
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    episode_log = run_episode(args.difficulty, args.steps, args.verbose)

    logs_dir = BASE_DIR / "logs"
    # parents=True: do not lose a finished episode because an intermediate
    # directory is missing.
    logs_dir.mkdir(parents=True, exist_ok=True)
    log_path = logs_dir / f"episode_{episode_log['episode_id']}.json"
    # Explicit encoding keeps the log independent of the platform locale.
    with open(log_path, "w", encoding="utf-8") as f:
        # default=str stringifies any value json cannot serialise natively.
        json.dump(episode_log, f, indent=2, default=str)
    console.print(f"[dim]Episode log saved -> {log_path}[/dim]")

    final_rc = episode_log["final_state"]["reward_components"]
    gate_pass = (
        final_rc.get("r1_hook_strength") is not None
        and final_rc.get("r2_coherence") is not None
        and log_path.exists()
    )
    style = "bold green" if gate_pass else "bold red"
    label = f"PHASE 1 GATE: {'PASS' if gate_pass else 'FAIL'}"
    console.print(Panel(f"[{style}]{label}[/{style}]", border_style="green" if gate_pass else "red"))

    return 0 if gate_pass else 1


if __name__ == "__main__":
    # Propagate the gate result as the process exit status.
    raise SystemExit(main())
|
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from viral_script_engine.agents.critic import CritiqueOutput, CritiqueClaim
|
| 8 |
+
from viral_script_engine.agents.rewriter import RewriteResult
|
| 9 |
+
from viral_script_engine.environment.actions import ActionType, ArbitratorAction
|
| 10 |
+
|
| 11 |
+
# Both locations live under the "data" directory that sits next to this
# file's parent directory.
FIXTURE_DIR = Path(__file__).parent.parent.joinpath("data", "golden_fixtures")
SCRIPTS_PATH = str(
    Path(__file__).parent.parent.joinpath("data", "test_scripts", "scripts.json")
)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def load_fixture(script_id: str) -> dict:
    """Load the golden fixture JSON for ``script_id``.

    Args:
        script_id: Identifier used to build the file name
            ``fixture_<script_id>.json`` inside ``FIXTURE_DIR``.

    Returns:
        The parsed JSON document.

    Raises:
        FileNotFoundError: If no fixture file exists for ``script_id``.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON interchange text is UTF-8; do not depend on the platform's default
    # encoding, which would mis-decode non-ASCII fixture content.
    with open(FIXTURE_DIR / f"fixture_{script_id}.json", encoding="utf-8") as f:
        return json.load(f)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def make_mock_critique() -> CritiqueOutput:
    """Build a ``CritiqueOutput`` from the "critique" entry of the S01 fixture."""
    critique_data = load_fixture("S01")["critique"]
    parsed_claims = [
        CritiqueClaim(**claim_fields) for claim_fields in critique_data["claims"]
    ]
    severity = critique_data["overall_severity"]
    raw = critique_data["raw_response"]
    return CritiqueOutput(claims=parsed_claims, overall_severity=severity, raw_response=raw)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def make_mock_rewrite(current_script: str, action: ArbitratorAction) -> RewriteResult:
    """Return a deterministic fake rewrite of ``current_script``.

    The script gets a fixed marker appended, the diff is a placeholder string
    and the word-count delta is always 1. ``action`` is not read here;
    presumably it exists to mirror the real rewrite call signature (confirm
    against the rewriter agent).
    """
    marker = " [REWRITTEN]"
    marked_script = current_script + marker
    return RewriteResult(rewritten_script=marked_script, diff="@@ diff @@", word_count_delta=1)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Action payload reused by every environment test: a hook rewrite that cites
# critique claim "C1".
SAMPLE_ACTION = dict(
    action_type=ActionType.HOOK_REWRITE.value,
    target_section="hook",
    instruction="Make the hook more attention-grabbing with a specific number.",
    critique_claim_id="C1",
    reasoning="Hook is weak per C1",
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@pytest.fixture
def env():
    """Yield a ``ViralScriptEnv`` whose critic and rewriter classes are patched.

    ``CriticAgent`` and ``RewriterAgent`` are replaced inside the environment
    module for the lifetime of the fixture: the critic always returns the
    fixture-backed critique and the rewriter delegates to ``make_mock_rewrite``.
    """
    with patch("viral_script_engine.environment.env.CriticAgent") as critic_factory:
        with patch("viral_script_engine.environment.env.RewriterAgent") as rewriter_factory:
            fake_critic = MagicMock()
            fake_critic.critique.return_value = make_mock_critique()
            critic_factory.return_value = fake_critic

            fake_rewriter = MagicMock()
            fake_rewriter.rewrite.side_effect = make_mock_rewrite
            rewriter_factory.return_value = fake_rewriter

            # Imported while the patches are active, as in the original fixture.
            from viral_script_engine.environment.env import ViralScriptEnv

            yield ViralScriptEnv(scripts_path=SCRIPTS_PATH, max_steps=5, difficulty="easy")
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def test_reset_returns_valid_observation(env):
    """After reset the observation is at step 0 of 5 with R1 and R2 populated."""
    observation, _info = env.reset(seed=42)

    assert "current_script" in observation
    assert observation["step_num"] == 0
    assert observation["max_steps"] == 5

    components = observation["reward_components"]
    assert components["r1_hook_strength"] is not None
    assert components["r2_coherence"] is not None
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def test_step_completes_without_error(env):
    """One step returns a float reward and exposes reward components in info."""
    env.reset(seed=42)

    step_result = env.step(SAMPLE_ACTION)
    _observation, reward, _terminated, _truncated, info = step_result

    assert isinstance(reward, float)
    assert "reward_components" in info
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def test_step_increments_step_num(env):
    """The observed step counter reads 1 and then 2 over two consecutive steps."""
    env.reset(seed=42)

    for expected_step in (1, 2):
        observation, *_rest = env.step(SAMPLE_ACTION)
        assert observation["step_num"] == expected_step
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def test_anti_gaming_penalty_fires_on_repeated_action(env):
    """After the same action is submitted three times, the last info flags anti-gaming."""
    env.reset(seed=42)

    last_info = None
    for _attempt in range(3):
        _observation, _reward, _terminated, _truncated, last_info = env.step(SAMPLE_ACTION)

    assert last_info["anti_gaming_triggered"]
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_episode_terminates_at_max_steps(env):
    """After five steps (the fixture's max_steps) the terminated flag is set."""
    env.reset(seed=42)

    done = False
    for _step in range(5):
        _observation, _reward, done, _truncated, _info = env.step(SAMPLE_ACTION)

    assert done
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_reward_clipped_to_0_1(env):
    """The reward returned by a single step lies in the closed interval [0, 1]."""
    env.reset(seed=42)

    step_result = env.step(SAMPLE_ACTION)
    _observation, reward, _terminated, _truncated, _info = step_result

    assert reward >= 0.0 and reward <= 1.0
|
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
|
| 3 |
+
from viral_script_engine.rewards.r2_coherence import CoherenceReward
|
| 4 |
+
from viral_script_engine.rewards.reward_aggregator import RewardAggregator
|
| 5 |
+
from viral_script_engine.environment.observations import RewardComponents
|
| 6 |
+
from viral_script_engine.environment.actions import ActionType
|
| 7 |
+
|
| 8 |
+
# ── R1 test hooks ─────────────────────────────────────────────────────────────
|
| 9 |
+
# Hooks the R1 tests expect to score above 0.8.
HOOK_HIGH_1 = "".join((
    "I made $10,000 in 30 days with 3 crypto strategies. ",
    "Here's the secret most people don't know. ",
    "This completely changed how I invest.",
))
HOOK_HIGH_2 = "".join((
    "Why 95% of people fail at losing weight in 2024. ",
    "Most people don't know this simple truth. ",
    "It's not about calories at all.",
))

# Hooks the R1 tests expect to score below 0.3.
HOOK_LOW_1 = "".join((
    "Hey guys, welcome back to my channel! ",
    "Today I want to talk about some stuff. ",
    "It's going to be super interesting!",
))
HOOK_LOW_2 = "".join((
    "So basically today I'm going to talk about fitness. ",
    "It's really important for everyone. ",
    "Let's get started with some tips.",
))

# Hook the R1 tests expect to land in the middle band, 0.3 to 0.7 inclusive.
HOOK_EDGE = "".join((
    "What nobody tells you about starting a business in India. ",
    "I found out the hard way. ",
    "Here's my experience.",
))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@pytest.fixture
def r1():
    """Provide a newly constructed ``HookStrengthReward`` to each test."""
    hook_scorer = HookStrengthReward()
    return hook_scorer
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@pytest.fixture
def r2():
    """Provide a newly constructed ``CoherenceReward`` to each test."""
    coherence_scorer = CoherenceReward()
    return coherence_scorer
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
@pytest.fixture
def aggregator():
    """Provide a newly constructed ``RewardAggregator`` to each test."""
    reward_aggregator = RewardAggregator()
    return reward_aggregator
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# ── R1 tests ──────────────────────────────────────────────────────────────────
|
| 52 |
+
def test_r1_high_score_1(r1):
    """``HOOK_HIGH_1`` must score strictly above 0.8."""
    hook_score = r1.score(HOOK_HIGH_1).score
    assert hook_score > 0.8
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def test_r1_high_score_2(r1):
    """``HOOK_HIGH_2`` must score strictly above 0.8."""
    hook_score = r1.score(HOOK_HIGH_2).score
    assert hook_score > 0.8
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def test_r1_low_score_1(r1):
    """``HOOK_LOW_1`` must score strictly below 0.3."""
    hook_score = r1.score(HOOK_LOW_1).score
    assert hook_score < 0.3
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def test_r1_low_score_2(r1):
    """``HOOK_LOW_2`` must score strictly below 0.3."""
    hook_score = r1.score(HOOK_LOW_2).score
    assert hook_score < 0.3
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def test_r1_edge_case(r1):
    """``HOOK_EDGE`` must land in the middle band, 0.3 to 0.7 inclusive."""
    hook_score = r1.score(HOOK_EDGE).score
    assert 0.3 <= hook_score <= 0.7
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ── R2 tests ──────────────────────────────────────────────────────────────────
|
| 78 |
+
def test_r2_identical_strings(r2):
    """Scoring a script against itself is expected to yield exactly 0.8."""
    script = "This is a test script for the viral script engine."
    coherence = r2.score(script, script)
    assert coherence.score == 0.8
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_r2_different_strings(r2):
    """Two unrelated texts are expected to yield a coherence score of exactly 0.0."""
    original_text = "I made $10,000 with crypto in 30 days using these 3 strategies."
    unrelated_text = "The history of ancient Rome spans over a thousand years of conquest."
    coherence = r2.score(original_text, unrelated_text)
    assert coherence.score == 0.0
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ── Aggregator tests ──────────────────────────────────────────────────────────
|
| 92 |
+
def test_aggregator_catastrophic_drop(aggregator):
    """When R1 falls from 0.8 to 0.3 the aggregated total is expected to be 0.0."""
    baseline = RewardComponents(r1_hook_strength=0.8, r2_coherence=0.7)
    baseline.compute_total()
    degraded = RewardComponents(r1_hook_strength=0.3, r2_coherence=0.7)

    action_history = [ActionType.HOOK_REWRITE]
    outcome = aggregator.compute(degraded, baseline, action_history)

    assert outcome.total == 0.0
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def test_aggregator_diversity_penalty(aggregator):
    """Three identical actions in a row incur a 0.15 anti-gaming penalty.

    Both reward components improve from 0.6 to 0.7, so only the repeated
    ``HOOK_REWRITE`` history can lower the total below 0.7.
    """
    start = RewardComponents(r1_hook_strength=0.6, r2_coherence=0.6)
    start.compute_total()
    current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
    history = [ActionType.HOOK_REWRITE] * 3

    result = aggregator.compute(current, start, history)

    # Compare with a tolerance: the penalty is a float and may be produced by
    # arithmetic rather than stored as the literal 0.15.
    assert result.anti_gaming_penalty == pytest.approx(0.15)
    assert result.total < 0.7
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def test_aggregator_no_penalty(aggregator):
    """A history of three different action types yields no penalty and a positive total."""
    baseline = RewardComponents(r1_hook_strength=0.6, r2_coherence=0.6)
    baseline.compute_total()
    improved = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)

    varied_history = [
        ActionType.HOOK_REWRITE,
        ActionType.CTA_PLACEMENT,
        ActionType.SECTION_REORDER,
    ]
    outcome = aggregator.compute(improved, baseline, varied_history)

    assert outcome.anti_gaming_penalty == 0.0
    assert outcome.total > 0
|