vajeeda Claude Sonnet 4.6 commited on
Commit
41ea373
·
1 Parent(s): 5fcd0ee

feat(phase1): OpenEnv scaffold + R1/R2 rewards — PHASE 1 GATE PASS

Browse files

- environment/actions.py: ActionType enum + ArbitratorAction model
- environment/observations.py: RewardComponents (weighted, normalised), DebateRound, Observation
- environment/episode_state.py: mutable episode state dataclass
- environment/env.py: ViralScriptEnv — Gymnasium-compatible reset/step/state, difficulty tiers, anti-gaming wired
- rewards/r1_hook_strength.py: 5-check rule-based hook scorer (promise, curiosity, specificity, front-load, anti-filler)
- rewards/r2_coherence.py: sentence-transformers cosine similarity with 4-range score mapping + embedding cache
- rewards/reward_aggregator.py: catastrophic-drop (>0.2) + action-diversity anti-gaming rules
- agents/rewriter.py: RewriterAgent wrapping LLMBackend with unified diff output
- scripts/run_dummy_episode.py: demo runner with rich output and gate check
- tests/test_rewards.py: 10 tests — R1 (5), R2 (2), aggregator (3)
- tests/test_environment.py: 6 tests — all LLM calls mocked via golden fixtures

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

viral_script_engine/agents/rewriter.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import difflib
2
+ from pydantic import BaseModel
3
+
4
+ from viral_script_engine.agents.llm_backend import LLMBackend
5
+ from viral_script_engine.environment.actions import ArbitratorAction
6
+
7
+ _SYSTEM_PROMPT = (
8
+ "You are a professional script editor for short-form social media video. "
9
+ "Apply ONLY the instruction given. Do not make any other changes. "
10
+ "Do not add new ideas. Do not change the creator's voice or regional language patterns. "
11
+ "Return ONLY the rewritten script text, no commentary."
12
+ )
13
+
14
+
15
+ class RewriteResult(BaseModel):
16
+ rewritten_script: str
17
+ diff: str
18
+ word_count_delta: int
19
+
20
+
21
+ class RewriterAgent:
22
+ def __init__(self, backend: str = "groq", model_name: str = "llama-3.3-70b-versatile"):
23
+ self.llm = LLMBackend(backend=backend, model_name=model_name)
24
+
25
+ def rewrite(self, current_script: str, action: ArbitratorAction) -> RewriteResult:
26
+ user_prompt = (
27
+ f"CURRENT SCRIPT:\n{current_script}\n\n"
28
+ f"ACTION TYPE: {action.action_type.value}\n"
29
+ f"TARGET SECTION: {action.target_section}\n"
30
+ f"INSTRUCTION: {action.instruction}\n\n"
31
+ "Apply the instruction and return ONLY the rewritten script."
32
+ )
33
+ rewritten = self.llm.generate(_SYSTEM_PROMPT, user_prompt, max_tokens=2048)
34
+ diff_lines = list(difflib.unified_diff(
35
+ current_script.splitlines(keepends=True),
36
+ rewritten.splitlines(keepends=True),
37
+ fromfile="original",
38
+ tofile="rewritten",
39
+ ))
40
+ return RewriteResult(
41
+ rewritten_script=rewritten,
42
+ diff="".join(diff_lines),
43
+ word_count_delta=len(rewritten.split()) - len(current_script.split()),
44
+ )
viral_script_engine/environment/__init__.py ADDED
File without changes
viral_script_engine/environment/actions.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from pydantic import BaseModel
3
+
4
+
5
+ class ActionType(str, Enum):
6
+ HOOK_REWRITE = "hook_rewrite"
7
+ SECTION_REORDER = "section_reorder"
8
+ CULTURAL_REF_SUB = "cultural_ref_sub"
9
+ CTA_PLACEMENT = "cta_placement"
10
+
11
+
12
+ class ArbitratorAction(BaseModel):
13
+ action_type: ActionType
14
+ target_section: str # "hook" | "body" | "cta" | "full"
15
+ instruction: str # natural language instruction to the Rewriter
16
+ critique_claim_id: str # which CritiqueClaim this responds to, e.g. "C2"
17
+ reasoning: str # why this action was chosen
viral_script_engine/environment/env.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ from typing import Optional, Tuple
4
+
5
+ from viral_script_engine.agents.critic import CriticAgent
6
+ from viral_script_engine.agents.rewriter import RewriterAgent
7
+ from viral_script_engine.environment.actions import ArbitratorAction
8
+ from viral_script_engine.environment.episode_state import EpisodeState
9
+ from viral_script_engine.environment.observations import (
10
+ DebateRound, Observation, RewardComponents,
11
+ )
12
+ from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
13
+ from viral_script_engine.rewards.r2_coherence import CoherenceReward
14
+ from viral_script_engine.rewards.reward_aggregator import RewardAggregator
15
+
16
+ _TIERS = {
17
+ "easy": ["S01", "S02", "S03", "S04"],
18
+ "medium": ["S05", "S06", "S07"],
19
+ "hard": ["S08", "S09", "S10"],
20
+ "self_generated": [],
21
+ }
22
+
23
+
24
+ class ViralScriptEnv:
25
+ def __init__(
26
+ self,
27
+ scripts_path: str = "data/test_scripts/scripts.json",
28
+ max_steps: int = 5,
29
+ difficulty: str = "easy",
30
+ use_anti_gaming: bool = True,
31
+ ):
32
+ self.max_steps = max_steps
33
+ self.difficulty = difficulty
34
+ self.use_anti_gaming = use_anti_gaming
35
+
36
+ with open(scripts_path) as f:
37
+ all_scripts = json.load(f)
38
+
39
+ tier_ids = _TIERS[difficulty]
40
+ self._scripts = [s for s in all_scripts if s["script_id"] in tier_ids]
41
+ self.critic = CriticAgent()
42
+ self.rewriter = RewriterAgent()
43
+ self.r1 = HookStrengthReward()
44
+ self.r2 = CoherenceReward()
45
+ self.aggregator = RewardAggregator()
46
+ self._state: Optional[EpisodeState] = None
47
+
48
+ def reset(self, seed=None, options=None) -> Tuple[dict, dict]:
49
+ if seed is not None:
50
+ random.seed(seed)
51
+ script = random.choice(self._scripts)
52
+
53
+ r1_result = self.r1.score(script["script_text"])
54
+ r2_result = self.r2.score(script["script_text"], script["script_text"])
55
+ initial_rewards = RewardComponents(
56
+ r1_hook_strength=r1_result.score,
57
+ r2_coherence=r2_result.score,
58
+ )
59
+ initial_rewards.compute_total()
60
+
61
+ self._state = EpisodeState.new(
62
+ script=script,
63
+ max_steps=self.max_steps,
64
+ difficulty_level=self.difficulty,
65
+ initial_rewards=initial_rewards,
66
+ )
67
+ return self._build_observation().model_dump(), {}
68
+
69
+ def step(self, action: dict) -> Tuple[dict, float, bool, bool, dict]:
70
+ if self._state is None:
71
+ raise RuntimeError("Call reset() before step()")
72
+
73
+ arb_action = ArbitratorAction(**action)
74
+
75
+ critique = self.critic.critique(
76
+ script=self._state.current_script,
77
+ region=self._state.region,
78
+ platform=self._state.platform,
79
+ niche=self._state.niche,
80
+ )
81
+
82
+ rewrite_result = self.rewriter.rewrite(self._state.current_script, arb_action)
83
+ new_script = rewrite_result.rewritten_script
84
+
85
+ r1_result = self.r1.score(new_script)
86
+ r2_result = self.r2.score(self._state.original_script, new_script)
87
+ components = RewardComponents(
88
+ r1_hook_strength=r1_result.score,
89
+ r2_coherence=r2_result.score,
90
+ )
91
+
92
+ self._state.action_history.append(arb_action.action_type)
93
+ if self.use_anti_gaming:
94
+ components = self.aggregator.compute(
95
+ components, self._state.episode_start_rewards, self._state.action_history
96
+ )
97
+ else:
98
+ components.compute_total()
99
+
100
+ round_ = DebateRound(
101
+ step_num=self._state.step_num,
102
+ critic_claims=critique.claims,
103
+ arbitrator_action=arb_action,
104
+ rewrite_diff=rewrite_result.diff,
105
+ reward_components=components,
106
+ )
107
+ self._state.debate_history.append(round_)
108
+ self._state.current_script = new_script
109
+ self._state.last_reward_components = components
110
+ self._state.step_num += 1
111
+
112
+ terminated = (
113
+ self._state.step_num >= self._state.max_steps
114
+ or components.total >= 0.9
115
+ )
116
+ info = {
117
+ "reward_components": components.model_dump(),
118
+ "anti_gaming_triggered": components.anti_gaming_penalty > 0,
119
+ "penalty_reason": "anti_gaming" if components.anti_gaming_penalty > 0 else None,
120
+ }
121
+ return self._build_observation().model_dump(), components.total, terminated, False, info
122
+
123
+ def state(self) -> dict:
124
+ if self._state is None:
125
+ return {}
126
+ s = self._state
127
+ return {
128
+ "current_script": s.current_script,
129
+ "original_script": s.original_script,
130
+ "debate_history": [r.model_dump() for r in s.debate_history],
131
+ "reward_components": s.last_reward_components.model_dump(),
132
+ "step_num": s.step_num,
133
+ "difficulty_level": s.difficulty_level,
134
+ "episode_id": s.episode_id,
135
+ }
136
+
137
+ def _build_observation(self) -> Observation:
138
+ s = self._state
139
+ return Observation(
140
+ current_script=s.current_script,
141
+ original_script=s.original_script,
142
+ region=s.region,
143
+ platform=s.platform,
144
+ niche=s.niche,
145
+ step_num=s.step_num,
146
+ max_steps=s.max_steps,
147
+ debate_history=s.debate_history,
148
+ reward_components=s.last_reward_components,
149
+ difficulty_level=s.difficulty_level,
150
+ episode_id=s.episode_id,
151
+ )
viral_script_engine/environment/episode_state.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import uuid
3
+ from dataclasses import dataclass, field
4
+ from typing import List
5
+
6
+ from viral_script_engine.environment.actions import ActionType
7
+ from viral_script_engine.environment.observations import DebateRound, RewardComponents
8
+
9
+
10
+ @dataclass
11
+ class EpisodeState:
12
+ episode_id: str
13
+ original_script: str
14
+ current_script: str
15
+ region: str
16
+ platform: str
17
+ niche: str
18
+ step_num: int
19
+ max_steps: int
20
+ debate_history: List[DebateRound]
21
+ episode_start_rewards: RewardComponents
22
+ last_reward_components: RewardComponents
23
+ difficulty_level: str
24
+ action_history: List[ActionType]
25
+
26
+ @classmethod
27
+ def new(
28
+ cls,
29
+ script: dict,
30
+ max_steps: int,
31
+ difficulty_level: str,
32
+ initial_rewards: RewardComponents,
33
+ ) -> EpisodeState:
34
+ return cls(
35
+ episode_id=str(uuid.uuid4()),
36
+ original_script=script["script_text"],
37
+ current_script=script["script_text"],
38
+ region=script["region"],
39
+ platform=script["platform"],
40
+ niche=script["niche"],
41
+ step_num=0,
42
+ max_steps=max_steps,
43
+ debate_history=[],
44
+ episode_start_rewards=initial_rewards,
45
+ last_reward_components=initial_rewards,
46
+ difficulty_level=difficulty_level,
47
+ action_history=[],
48
+ )
viral_script_engine/environment/observations.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any, Dict, List, Optional
3
+ from pydantic import BaseModel
4
+
5
+ from viral_script_engine.agents.critic import CritiqueClaim
6
+ from viral_script_engine.environment.actions import ArbitratorAction
7
+
8
+ _WEIGHTS: Dict[str, float] = {
9
+ "r1": 0.25, "r2": 0.20, "r3": 0.20, "r4": 0.20, "r5": 0.15
10
+ }
11
+
12
+
13
+ class RewardComponents(BaseModel):
14
+ r1_hook_strength: Optional[float] = None
15
+ r2_coherence: Optional[float] = None
16
+ r3_cultural_alignment: Optional[float] = None
17
+ r4_debate_resolution: Optional[float] = None
18
+ r5_defender_preservation: Optional[float] = None
19
+ anti_gaming_penalty: float = 0.0
20
+ total: float = 0.0
21
+
22
+ def compute_total(self) -> float:
23
+ vals = {
24
+ "r1": self.r1_hook_strength,
25
+ "r2": self.r2_coherence,
26
+ "r3": self.r3_cultural_alignment,
27
+ "r4": self.r4_debate_resolution,
28
+ "r5": self.r5_defender_preservation,
29
+ }
30
+ active = {k: v for k, v in vals.items() if v is not None}
31
+ if not active:
32
+ self.total = 0.0
33
+ return 0.0
34
+ norm = sum(_WEIGHTS[k] for k in active)
35
+ weighted = sum(_WEIGHTS[k] * v for k, v in active.items()) / norm
36
+ self.total = max(0.0, min(1.0, weighted - self.anti_gaming_penalty))
37
+ return self.total
38
+
39
+
40
+ class DebateRound(BaseModel):
41
+ step_num: int
42
+ critic_claims: List[CritiqueClaim]
43
+ defender_response: Optional[Any] = None
44
+ arbitrator_action: Optional[ArbitratorAction] = None
45
+ rewrite_diff: Optional[str] = None
46
+ reward_components: Optional[RewardComponents] = None
47
+
48
+
49
+ class Observation(BaseModel):
50
+ current_script: str
51
+ original_script: str
52
+ region: str
53
+ platform: str
54
+ niche: str
55
+ step_num: int
56
+ max_steps: int
57
+ debate_history: List[DebateRound]
58
+ reward_components: RewardComponents
59
+ difficulty_level: str
60
+ episode_id: str
viral_script_engine/rewards/__init__.py ADDED
File without changes
viral_script_engine/rewards/base.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class BaseReward(ABC):
5
+ @abstractmethod
6
+ def score(self, *args, **kwargs):
7
+ pass
viral_script_engine/rewards/r1_hook_strength.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from dataclasses import dataclass, field
3
+ from typing import Dict
4
+
5
+ from viral_script_engine.rewards.base import BaseReward
6
+
7
+ _DEAD_OPENERS = [
8
+ "hey guys", "welcome back", "today i want to", "so today",
9
+ "in this video", "what's up everyone", "hey everyone",
10
+ "guys today", "hello everyone", "so basically",
11
+ ]
12
+
13
+ _COMMON_WORDS = {
14
+ 'i', 'the', 'a', 'an', 'my', 'your', 'its', 'it', 'is', 'are',
15
+ 'was', 'were', 'be', 'been', "i've", "i'm", "it's", "here's",
16
+ 'today', 'and', 'but', 'so', 'that', 'this', 'these', 'those',
17
+ }
18
+
19
+
20
+ @dataclass
21
+ class HookRewardResult:
22
+ score: float
23
+ checks_passed: int
24
+ check_details: Dict[str, bool] = field(default_factory=dict)
25
+
26
+
27
+ def _extract_hook(text: str) -> str:
28
+ sentences = re.split(r'(?<=[.!?])\s+', text.strip())
29
+ hook = " ".join(sentences[:3]) if len(sentences) >= 3 else text
30
+ words = hook.split()
31
+ return " ".join(words[:50]) if len(words) > 50 else hook
32
+
33
+
34
+ class HookStrengthReward(BaseReward):
35
+ def score(self, script: str) -> HookRewardResult:
36
+ hook = _extract_hook(script)
37
+ hook_lower = hook.lower()
38
+ first_sentence = re.split(r'(?<=[.!?])\s+', hook.strip())[0].lower()
39
+
40
+ checks = {
41
+ "promise": self._check_promise(hook_lower),
42
+ "curiosity": self._check_curiosity(hook_lower),
43
+ "specificity": self._check_specificity(hook),
44
+ "front_load": self._check_front_load(first_sentence),
45
+ "anti_filler": self._check_anti_filler(hook_lower),
46
+ }
47
+ passed = sum(checks.values())
48
+ return HookRewardResult(
49
+ score=min(1.0, max(0.0, passed / 5)),
50
+ checks_passed=passed,
51
+ check_details=checks,
52
+ )
53
+
54
+ def _check_promise(self, hook: str) -> bool:
55
+ bad = ["hey guys", "welcome back", "today we're talking about"]
56
+ if any(b in hook for b in bad):
57
+ return False
58
+ patterns = [
59
+ r'\d',
60
+ r'\bhow to\b',
61
+ r'\bwhy\b',
62
+ r'\bwhat happens when\b',
63
+ r'\bi made\b',
64
+ ]
65
+ return any(re.search(p, hook) for p in patterns)
66
+
67
+ def _check_curiosity(self, hook: str) -> bool:
68
+ patterns = [
69
+ r'\?',
70
+ r"but here'?s the thing",
71
+ r"most \w+ don'?t know",
72
+ r"the secret is",
73
+ r"nobody tells you",
74
+ r"most people don'?t",
75
+ ]
76
+ if not any(re.search(p, hook) for p in patterns):
77
+ return False
78
+ first = re.split(r'(?<=[.!?])\s+', hook)[0]
79
+ if re.search(r'\?', first) and re.search(r'\b(is|are|was|were|means|equals)\b', first):
80
+ return False
81
+ return True
82
+
83
+ def _check_specificity(self, hook: str) -> bool:
84
+ if re.search(r'\d', hook):
85
+ return True
86
+ sentences = re.split(r'(?<=[.!?])\s+', hook)
87
+ for sentence in sentences:
88
+ words = sentence.split()[1:]
89
+ for w in words:
90
+ clean = w.strip('.,!?;:\'"')
91
+ if clean and clean[0].isupper() and clean.lower() not in _COMMON_WORDS:
92
+ return True
93
+ return False
94
+
95
+ def _check_front_load(self, first_sentence: str) -> bool:
96
+ signals = 0
97
+ if re.search(r'\d', first_sentence):
98
+ signals += 1
99
+ promise_patterns = [r'\bhow to\b', r'\bwhy\b', r'\bwhat happens when\b', r'\bi made\b']
100
+ if any(re.search(p, first_sentence) for p in promise_patterns):
101
+ signals += 1
102
+ if re.search(r'\?', first_sentence):
103
+ signals += 1
104
+ return signals >= 2
105
+
106
+ def _check_anti_filler(self, hook: str) -> bool:
107
+ return not any(hook.startswith(opener) for opener in _DEAD_OPENERS)
viral_script_engine/rewards/r2_coherence.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ from dataclasses import dataclass
3
+
4
+ from viral_script_engine.rewards.base import BaseReward
5
+
6
+
7
+ @dataclass
8
+ class CoherenceRewardResult:
9
+ score: float
10
+ raw_similarity: float
11
+ interpretation: str
12
+
13
+
14
+ class CoherenceReward(BaseReward):
15
+ _cache: dict = {}
16
+
17
+ def __init__(self):
18
+ self._model = None
19
+
20
+ def _get_model(self):
21
+ if self._model is None:
22
+ from sentence_transformers import SentenceTransformer
23
+ self._model = SentenceTransformer("all-MiniLM-L6-v2")
24
+ return self._model
25
+
26
+ def _embed(self, text: str):
27
+ key = hashlib.sha256(text.encode()).hexdigest()
28
+ if key not in self._cache:
29
+ self._cache[key] = self._get_model().encode(text, convert_to_tensor=True)
30
+ return self._cache[key]
31
+
32
+ def _cosine_sim(self, a, b) -> float:
33
+ from sentence_transformers.util import cos_sim
34
+ return float(cos_sim(a, b)[0][0])
35
+
36
+ def score(self, original: str, rewritten: str) -> CoherenceRewardResult:
37
+ sim = self._cosine_sim(self._embed(original), self._embed(rewritten))
38
+ if sim > 0.95:
39
+ score, interpretation = 0.8, "barely_changed"
40
+ elif sim >= 0.80:
41
+ score = 0.5 + (sim - 0.80) / 0.15 * 0.5
42
+ interpretation = "good_coherence"
43
+ elif sim >= 0.65:
44
+ score = (sim - 0.65) / 0.15 * 0.5
45
+ interpretation = "moderate_drift"
46
+ else:
47
+ score, interpretation = 0.0, "drifted_too_far"
48
+ return CoherenceRewardResult(score=score, raw_similarity=sim, interpretation=interpretation)
viral_script_engine/rewards/reward_aggregator.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+
4
+ from viral_script_engine.environment.actions import ActionType
5
+ from viral_script_engine.environment.observations import RewardComponents
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ _COMPONENT_FIELDS = [
10
+ "r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
11
+ "r4_debate_resolution", "r5_defender_preservation",
12
+ ]
13
+
14
+
15
+ class RewardAggregator:
16
+ def compute(
17
+ self,
18
+ components: RewardComponents,
19
+ episode_start_components: RewardComponents,
20
+ action_history: List[ActionType],
21
+ ) -> RewardComponents:
22
+ components.compute_total()
23
+
24
+ # Anti-gaming rule 1: catastrophic drop (>0.2 drop in any component)
25
+ for field in _COMPONENT_FIELDS:
26
+ curr = getattr(components, field)
27
+ start = getattr(episode_start_components, field)
28
+ if curr is not None and start is not None and curr < start - 0.2:
29
+ logger.warning("Catastrophic drop in %s: %.3f -> %.3f", field, start, curr)
30
+ components.total = 0.0
31
+ components.anti_gaming_penalty = start - curr
32
+ return components
33
+
34
+ # Anti-gaming rule 2: action diversity (last 3 same ActionType)
35
+ penalty = 0.0
36
+ if len(action_history) >= 3 and len(set(action_history[-3:])) == 1:
37
+ penalty = 0.15
38
+ logger.warning("Action diversity penalty: last 3 actions all %s", action_history[-1])
39
+
40
+ components.anti_gaming_penalty = penalty
41
+ components.total = max(0.0, min(1.0, components.total - penalty))
42
+ return components
viral_script_engine/scripts/run_dummy_episode.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import argparse
3
+ import json
4
+ import random
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ from dotenv import load_dotenv
9
+ from rich.console import Console
10
+ from rich.panel import Panel
11
+ from rich.table import Table
12
+ from rich import box
13
+
14
+ load_dotenv()
15
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
16
+
17
+ from viral_script_engine.environment.actions import ActionType
18
+ from viral_script_engine.environment.env import ViralScriptEnv
19
+
20
+ console = Console()
21
+ BASE_DIR = Path(__file__).parent.parent
22
+
23
+
24
+ def build_random_action(action_type: ActionType) -> dict:
25
+ labels = {
26
+ ActionType.HOOK_REWRITE: ("hook", "Rewrite the hook to open with a specific number or bold claim."),
27
+ ActionType.SECTION_REORDER: ("body", "Move the strongest point to immediately follow the hook."),
28
+ ActionType.CULTURAL_REF_SUB: ("full", "Replace any generic references with locally relevant ones."),
29
+ ActionType.CTA_PLACEMENT: ("cta", "Move the call-to-action earlier, before the 80% mark."),
30
+ }
31
+ section, instruction = labels[action_type]
32
+ return {
33
+ "action_type": action_type.value,
34
+ "target_section": section,
35
+ "instruction": instruction,
36
+ "critique_claim_id": "C1",
37
+ "reasoning": f"Demo run: applying {action_type.value}",
38
+ }
39
+
40
+
41
+ def run_episode(difficulty: str, steps: int, verbose: bool) -> dict:
42
+ scripts_path = str(BASE_DIR / "data" / "test_scripts" / "scripts.json")
43
+ env = ViralScriptEnv(scripts_path=scripts_path, max_steps=steps, difficulty=difficulty)
44
+
45
+ obs, _ = env.reset()
46
+ console.print(Panel(
47
+ f"[bold]Episode started[/bold]\n"
48
+ f"Difficulty: {difficulty} | Max steps: {steps}\n"
49
+ f"Region: {obs['region']} | Platform: {obs['platform']} | Niche: {obs['niche']}\n"
50
+ f"Episode ID: {obs['episode_id']}",
51
+ title="[bold blue]Phase 1 Demo Episode[/bold blue]",
52
+ border_style="blue",
53
+ ))
54
+
55
+ episode_log = {
56
+ "episode_id": obs["episode_id"],
57
+ "difficulty": difficulty,
58
+ "steps": [],
59
+ "final_state": None,
60
+ }
61
+
62
+ for step_num in range(steps):
63
+ action_type = random.choice(list(ActionType))
64
+ action = build_random_action(action_type)
65
+
66
+ obs, reward, terminated, truncated, info = env.step(action)
67
+ rc = info["reward_components"]
68
+
69
+ if verbose:
70
+ t = Table(title=f"Step {step_num + 1} — {action_type.value}", box=box.SIMPLE_HEAD)
71
+ t.add_column("Metric", style="cyan", min_width=22)
72
+ t.add_column("Value", min_width=12)
73
+ r1_val = rc.get("r1_hook_strength")
74
+ r2_val = rc.get("r2_coherence")
75
+ t.add_row("R1 Hook Strength", f"{r1_val:.3f}" if r1_val is not None else "N/A")
76
+ t.add_row("R2 Coherence", f"{r2_val:.3f}" if r2_val is not None else "N/A")
77
+ t.add_row("Total Reward", f"[bold]{reward:.3f}[/bold]")
78
+ if info.get("anti_gaming_triggered"):
79
+ t.add_row("Anti-Gaming Penalty", f"[red]{rc.get('anti_gaming_penalty', 0):.3f}[/red]")
80
+ t.add_row("Penalty Reason", f"[red]{info.get('penalty_reason', '')}[/red]")
81
+ t.add_row("Terminated", str(terminated))
82
+ console.print(t)
83
+
84
+ if obs.get("debate_history"):
85
+ latest = obs["debate_history"][-1]
86
+ if latest.get("rewrite_diff"):
87
+ console.print(Panel(
88
+ latest["rewrite_diff"][:600] or "(no diff)",
89
+ title="Script Diff",
90
+ border_style="yellow",
91
+ ))
92
+
93
+ episode_log["steps"].append({
94
+ "step": step_num + 1,
95
+ "action": action,
96
+ "reward": reward,
97
+ "reward_components": rc,
98
+ "anti_gaming": info.get("anti_gaming_triggered", False),
99
+ "terminated": terminated,
100
+ })
101
+
102
+ if terminated:
103
+ break
104
+
105
+ final_state = env.state()
106
+ episode_log["final_state"] = final_state
107
+ final_rc = final_state["reward_components"]
108
+
109
+ console.print(Panel(
110
+ f"[bold green]Final Reward:[/bold green] {final_rc.get('total', 0):.3f}\n"
111
+ f"R1 Hook Strength: {final_rc.get('r1_hook_strength', 'N/A')}\n"
112
+ f"R2 Coherence: {final_rc.get('r2_coherence', 'N/A')}\n"
113
+ f"Steps completed: {final_state['step_num']}",
114
+ title="Episode Summary",
115
+ border_style="green",
116
+ ))
117
+
118
+ return episode_log
119
+
120
+
121
+ def main():
122
+ parser = argparse.ArgumentParser(description="Run Phase 1 dummy episode")
123
+ parser.add_argument("--difficulty", default="easy", choices=["easy", "medium", "hard"])
124
+ parser.add_argument("--steps", type=int, default=3)
125
+ parser.add_argument("--verbose", action="store_true")
126
+ args = parser.parse_args()
127
+
128
+ episode_log = run_episode(args.difficulty, args.steps, args.verbose)
129
+
130
+ logs_dir = BASE_DIR / "logs"
131
+ logs_dir.mkdir(exist_ok=True)
132
+ log_path = logs_dir / f"episode_{episode_log['episode_id']}.json"
133
+ with open(log_path, "w") as f:
134
+ json.dump(episode_log, f, indent=2, default=str)
135
+ console.print(f"[dim]Episode log saved -> {log_path}[/dim]")
136
+
137
+ final_rc = episode_log["final_state"]["reward_components"]
138
+ gate_pass = (
139
+ final_rc.get("r1_hook_strength") is not None
140
+ and final_rc.get("r2_coherence") is not None
141
+ and log_path.exists()
142
+ )
143
+ style = "bold green" if gate_pass else "bold red"
144
+ label = f"PHASE 1 GATE: {'PASS' if gate_pass else 'FAIL'}"
145
+ console.print(Panel(f"[{style}]{label}[/{style}]", border_style="green" if gate_pass else "red"))
146
+
147
+
148
+ if __name__ == "__main__":
149
+ main()
viral_script_engine/tests/test_environment.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pytest
6
+
7
+ from viral_script_engine.agents.critic import CritiqueOutput, CritiqueClaim
8
+ from viral_script_engine.agents.rewriter import RewriteResult
9
+ from viral_script_engine.environment.actions import ActionType, ArbitratorAction
10
+
11
+ FIXTURE_DIR = Path(__file__).parent.parent / "data" / "golden_fixtures"
12
+ SCRIPTS_PATH = str(Path(__file__).parent.parent / "data" / "test_scripts" / "scripts.json")
13
+
14
+
15
+ def load_fixture(script_id: str) -> dict:
16
+ with open(FIXTURE_DIR / f"fixture_{script_id}.json") as f:
17
+ return json.load(f)
18
+
19
+
20
+ def make_mock_critique() -> CritiqueOutput:
21
+ fixture = load_fixture("S01")
22
+ claims = [CritiqueClaim(**c) for c in fixture["critique"]["claims"]]
23
+ return CritiqueOutput(
24
+ claims=claims,
25
+ overall_severity=fixture["critique"]["overall_severity"],
26
+ raw_response=fixture["critique"]["raw_response"],
27
+ )
28
+
29
+
30
+ def make_mock_rewrite(current_script: str, action: ArbitratorAction) -> RewriteResult:
31
+ return RewriteResult(
32
+ rewritten_script=current_script + " [REWRITTEN]",
33
+ diff="@@ diff @@",
34
+ word_count_delta=1,
35
+ )
36
+
37
+
38
+ SAMPLE_ACTION = {
39
+ "action_type": ActionType.HOOK_REWRITE.value,
40
+ "target_section": "hook",
41
+ "instruction": "Make the hook more attention-grabbing with a specific number.",
42
+ "critique_claim_id": "C1",
43
+ "reasoning": "Hook is weak per C1",
44
+ }
45
+
46
+
47
+ @pytest.fixture
48
+ def env():
49
+ with (
50
+ patch("viral_script_engine.environment.env.CriticAgent") as mock_critic_cls,
51
+ patch("viral_script_engine.environment.env.RewriterAgent") as mock_rewriter_cls,
52
+ ):
53
+ mock_critic = MagicMock()
54
+ mock_critic.critique.return_value = make_mock_critique()
55
+ mock_critic_cls.return_value = mock_critic
56
+
57
+ mock_rewriter = MagicMock()
58
+ mock_rewriter.rewrite.side_effect = make_mock_rewrite
59
+ mock_rewriter_cls.return_value = mock_rewriter
60
+
61
+ from viral_script_engine.environment.env import ViralScriptEnv
62
+ yield ViralScriptEnv(scripts_path=SCRIPTS_PATH, max_steps=5, difficulty="easy")
63
+
64
+
65
+ def test_reset_returns_valid_observation(env):
66
+ obs, info = env.reset(seed=42)
67
+ assert "current_script" in obs
68
+ assert obs["step_num"] == 0
69
+ assert obs["max_steps"] == 5
70
+ assert obs["reward_components"]["r1_hook_strength"] is not None
71
+ assert obs["reward_components"]["r2_coherence"] is not None
72
+
73
+
74
+ def test_step_completes_without_error(env):
75
+ env.reset(seed=42)
76
+ obs, reward, terminated, truncated, info = env.step(SAMPLE_ACTION)
77
+ assert isinstance(reward, float)
78
+ assert "reward_components" in info
79
+
80
+
81
+ def test_step_increments_step_num(env):
82
+ env.reset(seed=42)
83
+ obs, *_ = env.step(SAMPLE_ACTION)
84
+ assert obs["step_num"] == 1
85
+ obs, *_ = env.step(SAMPLE_ACTION)
86
+ assert obs["step_num"] == 2
87
+
88
+
89
+ def test_anti_gaming_penalty_fires_on_repeated_action(env):
90
+ env.reset(seed=42)
91
+ for _ in range(3):
92
+ obs, reward, _, _, info = env.step(SAMPLE_ACTION)
93
+ assert info["anti_gaming_triggered"]
94
+
95
+
96
+ def test_episode_terminates_at_max_steps(env):
97
+ env.reset(seed=42)
98
+ terminated = False
99
+ for _ in range(5):
100
+ obs, reward, terminated, truncated, info = env.step(SAMPLE_ACTION)
101
+ assert terminated
102
+
103
+
104
+ def test_reward_clipped_to_0_1(env):
105
+ env.reset(seed=42)
106
+ _, reward, _, _, _ = env.step(SAMPLE_ACTION)
107
+ assert 0.0 <= reward <= 1.0
viral_script_engine/tests/test_rewards.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
3
+ from viral_script_engine.rewards.r2_coherence import CoherenceReward
4
+ from viral_script_engine.rewards.reward_aggregator import RewardAggregator
5
+ from viral_script_engine.environment.observations import RewardComponents
6
+ from viral_script_engine.environment.actions import ActionType
7
+
8
+ # ── R1 test hooks ─────────────────────────────────────────────────────────────
9
+ HOOK_HIGH_1 = (
10
+ "I made $10,000 in 30 days with 3 crypto strategies. "
11
+ "Here's the secret most people don't know. "
12
+ "This completely changed how I invest."
13
+ )
14
+ HOOK_HIGH_2 = (
15
+ "Why 95% of people fail at losing weight in 2024. "
16
+ "Most people don't know this simple truth. "
17
+ "It's not about calories at all."
18
+ )
19
+ HOOK_LOW_1 = (
20
+ "Hey guys, welcome back to my channel! "
21
+ "Today I want to talk about some stuff. "
22
+ "It's going to be super interesting!"
23
+ )
24
+ HOOK_LOW_2 = (
25
+ "So basically today I'm going to talk about fitness. "
26
+ "It's really important for everyone. "
27
+ "Let's get started with some tips."
28
+ )
29
+ HOOK_EDGE = (
30
+ "What nobody tells you about starting a business in India. "
31
+ "I found out the hard way. "
32
+ "Here's my experience."
33
+ )
34
+
35
+
36
+ @pytest.fixture
37
+ def r1():
38
+ return HookStrengthReward()
39
+
40
+
41
+ @pytest.fixture
42
+ def r2():
43
+ return CoherenceReward()
44
+
45
+
46
+ @pytest.fixture
47
+ def aggregator():
48
+ return RewardAggregator()
49
+
50
+
51
+ # ── R1 tests ──────────────────────────────────────────────────────────────────
52
+ def test_r1_high_score_1(r1):
53
+ result = r1.score(HOOK_HIGH_1)
54
+ assert result.score > 0.8
55
+
56
+
57
+ def test_r1_high_score_2(r1):
58
+ result = r1.score(HOOK_HIGH_2)
59
+ assert result.score > 0.8
60
+
61
+
62
+ def test_r1_low_score_1(r1):
63
+ result = r1.score(HOOK_LOW_1)
64
+ assert result.score < 0.3
65
+
66
+
67
+ def test_r1_low_score_2(r1):
68
+ result = r1.score(HOOK_LOW_2)
69
+ assert result.score < 0.3
70
+
71
+
72
+ def test_r1_edge_case(r1):
73
+ result = r1.score(HOOK_EDGE)
74
+ assert 0.3 <= result.score <= 0.7
75
+
76
+
77
+ # ── R2 tests ──────────────────────────────────────────────────────────────────
78
+ def test_r2_identical_strings(r2):
79
+ text = "This is a test script for the viral script engine."
80
+ result = r2.score(text, text)
81
+ assert result.score == 0.8
82
+
83
+
84
+ def test_r2_different_strings(r2):
85
+ orig = "I made $10,000 with crypto in 30 days using these 3 strategies."
86
+ diff = "The history of ancient Rome spans over a thousand years of conquest."
87
+ result = r2.score(orig, diff)
88
+ assert result.score == 0.0
89
+
90
+
91
+ # ── Aggregator tests ──────────────────────────────────────────────────────────
92
+ def test_aggregator_catastrophic_drop(aggregator):
93
+ start = RewardComponents(r1_hook_strength=0.8, r2_coherence=0.7)
94
+ start.compute_total()
95
+ current = RewardComponents(r1_hook_strength=0.3, r2_coherence=0.7)
96
+ result = aggregator.compute(current, start, [ActionType.HOOK_REWRITE])
97
+ assert result.total == 0.0
98
+
99
+
100
+ def test_aggregator_diversity_penalty(aggregator):
101
+ start = RewardComponents(r1_hook_strength=0.6, r2_coherence=0.6)
102
+ start.compute_total()
103
+ current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
104
+ history = [ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE]
105
+ result = aggregator.compute(current, start, history)
106
+ assert result.anti_gaming_penalty == 0.15
107
+ assert result.total < 0.7
108
+
109
+
110
+ def test_aggregator_no_penalty(aggregator):
111
+ start = RewardComponents(r1_hook_strength=0.6, r2_coherence=0.6)
112
+ start.compute_total()
113
+ current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
114
+ history = [ActionType.HOOK_REWRITE, ActionType.CTA_PLACEMENT, ActionType.SECTION_REORDER]
115
+ result = aggregator.compute(current, start, history)
116
+ assert result.anti_gaming_penalty == 0.0
117
+ assert result.total > 0