vajeeda committed on
Commit
258783b
·
1 Parent(s): 0290d7a

phase2 implemented

Browse files
viral_script_engine/agents/baseline_arbitrator.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ from viral_script_engine.agents.llm_backend import LLMBackend
4
+
5
+ SYSTEM_PROMPT = """You are helping improve a short-form video script.
6
+ You have observed a debate between a Critic and a Defender about the script.
7
+ Choose ONE action to take to improve the script.
8
+
9
+ Available actions: hook_rewrite, section_reorder, cultural_ref_sub, cta_placement
10
+
11
+ Respond ONLY with valid JSON:
12
+ {
13
+ "action_type": "hook_rewrite",
14
+ "target_section": "hook",
15
+ "instruction": "specific instruction for the rewriter",
16
+ "critique_claim_id": "C1",
17
+ "reasoning": "brief explanation"
18
+ }"""
19
+
20
+ STRICT_RETRY_SUFFIX = (
21
+ "\n\nIMPORTANT: Your previous response was not valid JSON. "
22
+ "Respond ONLY with the raw JSON object. No markdown fences, no explanation, no preamble."
23
+ )
24
+
25
+ _FALLBACK_ACTION = {
26
+ "action_type": "hook_rewrite",
27
+ "target_section": "hook",
28
+ "instruction": "Rewrite the hook to be more engaging and direct.",
29
+ "critique_claim_id": "C1",
30
+ "reasoning": "Default fallback action.",
31
+ }
32
+
33
+
34
class BaselineArbitratorAgent:
    """
    Untrained Arbitrator for the pre-training baseline.

    Uses a zero-shot instruction — no chain-of-thought, no few-shot examples.
    This ensures the comparison is fair: trained model learns through RL, not prompting.
    """

    def __init__(self, backend: str = "groq", model_name: str = "llama-3.3-70b-versatile"):
        """Create the LLM backend that every call made by this agent goes through."""
        self.llm = LLMBackend(backend=backend, model_name=model_name)

    def _build_user_prompt(self, observation: dict) -> str:
        """Render the script plus the most recent debate round as the user prompt.

        Args:
            observation: mapping read for ``current_script`` and
                ``debate_history``; only the last entry of the history is used.

        Returns:
            The prompt text, with ``None`` standing in for missing claims or
            a missing defender response.
        """
        script = observation.get("current_script", "")
        debate = observation.get("debate_history", [])
        last_round = debate[-1] if debate else {}
        # `or []` keeps an explicit None under "critic_claims" from crashing the loop.
        last_claims = last_round.get("critic_claims") or []
        # NOTE(review): assumes defender_response is dict-like (has .get) — confirm against the env.
        last_defense = last_round.get("defender_response")

        claims_text = "".join(
            f"- [{c.get('claim_id','?')}] {c.get('claim_text','')} (severity: {c.get('severity','')})\n"
            for c in last_claims
        )

        defense_text = ""
        if last_defense:
            defense_text = (
                f"Defender preserved: {last_defense.get('core_strength_quote','')}\n"
                f"Flagged claims: {last_defense.get('flagged_critic_claims', [])}\n"
            )

        return (
            f"SCRIPT:\n{script}\n\n"
            f"CRITIC CLAIMS:\n{claims_text or 'None'}\n"
            f"DEFENDER:\n{defense_text or 'None'}\n\n"
            "Choose one action to improve the script."
        )

    @staticmethod
    def _parse_action(raw):
        """Return the JSON object contained in ``raw``, or ``None`` if there is none.

        The reply is first parsed as-is. If that fails (or yields something
        other than an object), the span from the first ``{`` to the last ``}``
        is tried, which tolerates markdown fences and explanatory preambles.
        Only a ``dict`` is accepted, because callers are promised a dict.
        """
        candidates = [raw]
        if isinstance(raw, str):
            start, end = raw.find("{"), raw.rfind("}")
            if 0 <= start < end:
                candidates.append(raw[start:end + 1])
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except (ValueError, TypeError):
                # ValueError covers JSONDecodeError; TypeError covers a None reply.
                continue
            if isinstance(parsed, dict):
                return parsed
        return None

    def act(self, observation: dict) -> dict:
        """Ask the LLM for one action; retry once with a stricter prompt, then fall back.

        Args:
            observation: mapping passed to ``_build_user_prompt``.

        Returns:
            The parsed action dict, or a copy of ``_FALLBACK_ACTION`` when
            neither attempt yields a JSON object. Errors raised by the
            backend's ``generate`` call are not caught here.
        """
        user_prompt = self._build_user_prompt(observation)
        # Second attempt re-sends the same prompt with the strict-JSON reminder appended.
        for prompt in (user_prompt, user_prompt + STRICT_RETRY_SUFFIX):
            raw = self.llm.generate(SYSTEM_PROMPT, prompt, max_tokens=512)
            action = self._parse_action(raw)
            if action is not None:
                return action
        return _FALLBACK_ACTION.copy()
viral_script_engine/agents/defender.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from typing import List
3
+
4
+ from pydantic import BaseModel
5
+
6
+ from viral_script_engine.agents.llm_backend import LLMBackend
7
+ from viral_script_engine.agents.critic import CritiqueClaim
8
+
9
# System prompt for the defender: asks for the script's strongest element, a
# verbatim quote of it, the critic claims that would damage it, and the
# regional phrases that must survive editing. The em dashes below replace
# mis-encoded byte sequences that would otherwise reach the model verbatim.
SYSTEM_PROMPT = """You are a script defender for short-form video content. Your job is NOT to say the script is perfect.
Your job is to identify what is genuinely working — and protect it from being edited away.

Specifically:
1. Find the single most powerful element of the script. Quote it exactly.
2. Explain why a viewer would respond positively to this element.
3. Review the Critic's claims. Flag any that would destroy the script's core strength or strip its regional authenticity if acted on.
4. List any phrases, idioms, or references that are intentionally regional — these must not be "corrected" away.

OUTPUT (JSON only, no preamble):
{
"core_strength": "one sentence describing the strongest element",
"core_strength_quote": "exact verbatim quote from the script",
"defense_argument": "why this element should be preserved",
"flagged_critic_claims": ["C2", "C3"],
"regional_voice_elements": ["specific phrase 1", "specific phrase 2"]
}"""

# Appended to the user prompt on the single retry after an unparseable reply.
STRICT_RETRY_SUFFIX = (
    "\n\nIMPORTANT: Your previous response was not valid JSON. "
    "Respond ONLY with the raw JSON object. No markdown fences, no explanation, no preamble."
)
31
+
32
+
33
class DefenderParseError(Exception):
    """Signals that the defender's reply could not be turned into a DefenderOutput."""
35
+
36
+
37
class DefenderOutput(BaseModel):
    """Validated shape of the defender's JSON reply.

    The field names mirror the keys requested in the OUTPUT section of the
    defender system prompt.
    """

    # One sentence describing the strongest element of the script.
    core_strength: str
    # Quote of that element; the prompt asks for it verbatim from the script.
    core_strength_quote: str
    # Why the element should be preserved.
    defense_argument: str
    # Critic claim ids the defender objects to (the prompt's example uses "C2", "C3").
    flagged_critic_claims: List[str]
    # Phrases, idioms or references the defender marks as intentionally regional.
    regional_voice_elements: List[str]
43
+
44
+
45
class DefenderAgent:
    """Agent that asks an LLM to defend a script against the critic's claims."""

    def __init__(self, backend: str = "groq", model_name: str = "llama-3.3-70b-versatile"):
        """Create the LLM backend that every call made by this agent goes through."""
        self.llm = LLMBackend(backend=backend, model_name=model_name)

    def _build_user_prompt(
        self,
        script: str,
        critic_claims: List[CritiqueClaim],
        region: str,
        platform: str,
    ) -> str:
        """Render the script, the numbered critic claims, region and platform as one prompt.

        Each claim is listed as ``N. [claim_id] (critique_class) claim_text | Evidence: evidence``;
        an empty claim list is rendered as a fixed placeholder sentence.
        """
        claims_lines = [
            f"{i}. [{claim.claim_id}] ({claim.critique_class}) {claim.claim_text} | Evidence: {claim.evidence}"
            for i, claim in enumerate(critic_claims, start=1)
        ]
        claims_block = "\n".join(claims_lines) if claims_lines else "No critic claims provided."

        return (
            f"SCRIPT:\n{script}\n\n"
            f"CRITIC CLAIMS:\n{claims_block}\n\n"
            f"REGION: {region}\n"
            f"PLATFORM: {platform}\n\n"
            "Defend the script now."
        )

    @staticmethod
    def _coerce_output(raw: str) -> DefenderOutput:
        """Parse ``raw`` into a ``DefenderOutput``.

        The span from the first ``{`` to the last ``}`` is parsed, so a reply
        wrapped in markdown fences or preceded by a preamble is still accepted.

        Raises:
            ValueError: the text is not valid JSON, or the fields fail
                validation (pydantic's ValidationError is a ValueError).
            TypeError: the reply is not text, or the JSON is not an object.
        """
        text = raw
        if isinstance(raw, str):
            start, end = raw.find("{"), raw.rfind("}")
            if 0 <= start < end:
                text = raw[start:end + 1]
        data = json.loads(text)
        return DefenderOutput(**data)

    def _parse_response(self, raw: str, user_prompt: str) -> DefenderOutput:
        """Parse the first reply; on failure re-query once with the strict suffix.

        Args:
            raw: the model's reply to ``user_prompt``.
            user_prompt: the prompt that produced ``raw``; reused for the retry.

        Raises:
            DefenderParseError: both the first reply and the retry failed to
                parse or validate. The underlying error is attached as the cause.
        """
        try:
            return self._coerce_output(raw)
        except (ValueError, TypeError):
            # Deliberate fall-through: the first failure only triggers the retry.
            pass

        strict_prompt = user_prompt + STRICT_RETRY_SUFFIX
        raw2 = self.llm.generate(SYSTEM_PROMPT, strict_prompt, max_tokens=1024)
        try:
            return self._coerce_output(raw2)
        except (ValueError, TypeError) as e:
            raise DefenderParseError(f"Failed to parse defender output after 2 attempts: {e}") from e

    def defend(
        self,
        script: str,
        critic_claims: List[CritiqueClaim],
        region: str,
        platform: str,
    ) -> DefenderOutput:
        """Query the LLM for a defense of ``script`` and return the validated result.

        Raises:
            DefenderParseError: the model did not return usable JSON in two attempts.
        """
        user_prompt = self._build_user_prompt(script, critic_claims, region, platform)
        raw = self.llm.generate(SYSTEM_PROMPT, user_prompt, max_tokens=1024)
        return self._parse_response(raw, user_prompt)
viral_script_engine/data/cultural_kb.json ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "mumbai_gen_z": {
3
+ "valid_refs": [
4
+ "Bandra", "CSMT", "dabba", "auto", "local train", "startup scene",
5
+ "Zomato", "Swiggy", "IPL", "sea link", "Juhu", "Versova", "Aarey",
6
+ "SoBo", "Powai", "Andheri", "Dharavi", "Bandstand", "Colaba",
7
+ "Bollywood 2020s"
8
+ ],
9
+ "invalid_signals": [
10
+ "trunk call", "STD booth", "pager", "cassette", "VHS", "walkman",
11
+ "disco", "beeper", "doordarshan only", "telegram boy",
12
+ "ration card queue", "old delhi", "tonga", "coolie",
13
+ "chawl gossip pre-2000"
14
+ ],
15
+ "correct_idioms": [
16
+ "ek dum solid", "full on", "kya scene hai", "bindaas", "jugaad nahi",
17
+ "yaar sun", "bhai chill maar", "ekdum mast", "kya baat hai",
18
+ "ekdum lit", "arey yaar", "full on vibe", "kya scene hai bhai",
19
+ "timepass nahi", "mast hai re", "solid plan", "chal be"
20
+ ],
21
+ "anachronistic_signals": []
22
+ },
23
+ "tier2_hindi_belt": {
24
+ "valid_refs": [
25
+ "kirana store", "mandap", "jugaad", "sabzi mandi", "panchayat",
26
+ "mela", "dal-chawal", "halwai", "tehsil", "chaupal", "gaon",
27
+ "kheti", "bori", "dhaba", "cycle", "tubewell", "Kanpur",
28
+ "Lucknow", "Patna", "Varanasi"
29
+ ],
30
+ "invalid_signals": [
31
+ "unicorn startup", "runway", "pivot", "SaaS", "venture capital",
32
+ "IPO", "fintech", "coworking space", "metro commute", "craft beer",
33
+ "avocado toast", "brunch", "penthouse", "valet parking",
34
+ "rooftop lounge", "angel investor", "Series A", "growth hacking",
35
+ "disruptive"
36
+ ],
37
+ "correct_idioms": [
38
+ "bilkul sahi", "arey bhai", "baat pakki", "suno ji", "haan ji",
39
+ "seedha baat", "desi style", "apna kaam", "ek number",
40
+ "bahut badhiya", "kya baat", "sahi mein", "dhamaal", "arre wah",
41
+ "mast jugaad", "thoda adjust karo", "kal milte hain"
42
+ ],
43
+ "anachronistic_signals": []
44
+ },
45
+ "pan_india_english": {
46
+ "valid_refs": [
47
+ "India", "Indian", "subcontinent", "festivals", "cricket", "desi",
48
+ "masala", "biryani", "chai", "temple", "monsoon", "Diwali",
49
+ "Holi", "market", "college"
50
+ ],
51
+ "invalid_signals": [
52
+ "kachcha road", "pind", "naali", "akhada", "maidan ke paas",
53
+ "chakki", "chulha", "dung cake", "lathicharge", "tehsildar",
54
+ "patwari", "sarpanch only", "gaon waale", "pradhan ji", "khet mein"
55
+ ],
56
+ "correct_idioms": [
57
+ "absolutely", "for sure", "totally agree", "makes sense",
58
+ "of course", "right", "exactly", "fair point", "well said",
59
+ "spot on", "valid point", "true that", "no doubt", "solid",
60
+ "good point", "great insight", "well done"
61
+ ],
62
+ "anachronistic_signals": []
63
+ },
64
+ "hinglish": {
65
+ "valid_refs": [
66
+ "yaar", "bhai", "dost", "scene", "vibe", "timepass", "jugaad",
67
+ "mast", "bindaas", "desi", "ekdum", "kal", "aaj", "kya", "kaise"
68
+ ],
69
+ "invalid_signals": [
70
+ "I am going to the marketplace", "Please be advised", "Pursuant to",
71
+ "Herewith", "As per our discussion", "In reference to",
72
+ "Kindly note", "I beg to inform", "With due respect",
73
+ "at your earliest convenience", "enclosed herewith",
74
+ "please find attached", "as discussed", "further to my email",
75
+ "I hope this email finds you"
76
+ ],
77
+ "correct_idioms": [
78
+ "yaar sun", "bhai ekdum", "kya baat hai", "arey chill maar",
79
+ "solid plan hai", "mast scene hai", "full on enjoy",
80
+ "ekdum bindaas", "bhai sahi bol raha hai", "kya jugaad lagaya",
81
+ "ek dum mast", "timepass nahi yaar", "kya vibe hai",
82
+ "full on mast", "ekdum sahi", "arey yaar kya scene",
83
+ "bas ek chance"
84
+ ],
85
+ "anachronistic_signals": []
86
+ }
87
+ }
viral_script_engine/data/golden_fixtures/fixture_S01.json CHANGED
@@ -8,59 +8,59 @@
8
  {
9
  "claim_id": "C1",
10
  "critique_class": "hook_weakness",
11
- "claim_text": "The hook at 0:00-0:10 promises a life-changing secret but delays its reveal until 0:45, by which time some viewers may have lost interest",
12
- "timestamp_range": "0:00-0:10",
13
- "evidence": "Okay so real talk \u00e2\u20ac\u201d I've been broke my whole life. Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything.",
14
  "is_falsifiable": true,
15
  "severity": "medium"
16
  },
17
  {
18
  "claim_id": "C2",
19
  "critique_class": "pacing_issue",
20
- "claim_text": "The script spends 15 seconds showing off the apartment at 0:20-0:35, which may slow down the narrative and cause some viewers to lose interest",
21
- "timestamp_range": "0:20-0:35",
22
- "evidence": "But first, let me show you my apartment. Pretty nice right? Took me three years to get here.",
23
  "is_falsifiable": true,
24
- "severity": "low"
25
  },
26
  {
27
  "claim_id": "C3",
28
- "critique_class": "coherence_break",
29
- "claim_text": "The transition from 'I've been broke my whole life' to 'my apartment is pretty nice' at 0:10-0:20 is abrupt and may confuse some viewers",
30
- "timestamp_range": "0:10-0:20",
31
- "evidence": "Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything. But first, let me show you my apartment.",
32
  "is_falsifiable": true,
33
  "severity": "medium"
34
  },
35
  {
36
  "claim_id": "C4",
37
- "critique_class": "cta_buried",
38
- "claim_text": "The call-to-action to follow the creator for more information is buried at the end of the script at 1:00-1:05 and may be missed by viewers who don't watch until the end",
39
- "timestamp_range": "1:00-1:05",
40
  "evidence": "If you want to know which funds I use, follow me and I'll post the list tomorrow.",
41
  "is_falsifiable": true,
42
  "severity": "high"
43
  },
44
  {
45
  "claim_id": "C5",
46
- "critique_class": "retention_risk",
47
- "claim_text": "The script assumes the viewer is familiar with mutual funds and SIPs at 0:40-0:50, which may cause some viewers to feel lost or disconnected from the content",
48
- "timestamp_range": "0:40-0:50",
49
- "evidence": "The secret? Mutual funds. Just SIPs.",
50
  "is_falsifiable": true,
51
  "severity": "medium"
52
  },
53
  {
54
  "claim_id": "C6",
55
- "critique_class": "cultural_mismatch",
56
- "claim_text": "The script uses a very casual tone at 0:00-0:10, which may not resonate with all viewers in the target region, particularly those who prefer more formal or professional content",
57
- "timestamp_range": "0:00-0:10",
58
- "evidence": "Okay so real talk \u00e2\u20ac\u201d I've been broke my whole life.",
59
  "is_falsifiable": true,
60
- "severity": "low"
61
  }
62
  ],
63
- "overall_severity": "medium",
64
- "raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook at 0:00-0:10 promises a life-changing secret but delays its reveal until 0:45, by which time some viewers may have lost interest\",\n \"timestamp_range\": \"0:00-0:10\",\n \"evidence\": \"Okay so real talk \u00e2\u20ac\u201d I've been broke my whole life. Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C2\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script spends 15 seconds showing off the apartment at 0:20-0:35, which may slow down the narrative and cause some viewers to lose interest\",\n \"timestamp_range\": \"0:20-0:35\",\n \"evidence\": \"But first, let me show you my apartment. Pretty nice right? Took me three years to get here.\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n },\n {\n \"claim_id\": \"C3\",\n \"critique_class\": \"coherence_break\",\n \"claim_text\": \"The transition from 'I've been broke my whole life' to 'my apartment is pretty nice' at 0:10-0:20 is abrupt and may confuse some viewers\",\n \"timestamp_range\": \"0:10-0:20\",\n \"evidence\": \"Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything. 
But first, let me show you my apartment.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C4\",\n \"critique_class\": \"cta_buried\",\n \"claim_text\": \"The call-to-action to follow the creator for more information is buried at the end of the script at 1:00-1:05 and may be missed by viewers who don't watch until the end\",\n \"timestamp_range\": \"1:00-1:05\",\n \"evidence\": \"If you want to know which funds I use, follow me and I'll post the list tomorrow.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C5\",\n \"critique_class\": \"retention_risk\",\n \"claim_text\": \"The script assumes the viewer is familiar with mutual funds and SIPs at 0:40-0:50, which may cause some viewers to feel lost or disconnected from the content\",\n \"timestamp_range\": \"0:40-0:50\",\n \"evidence\": \"The secret? Mutual funds. Just SIPs.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C6\",\n \"critique_class\": \"cultural_mismatch\",\n \"claim_text\": \"The script uses a very casual tone at 0:00-0:10, which may not resonate with all viewers in the target region, particularly those who prefer more formal or professional content\",\n \"timestamp_range\": \"0:00-0:10\",\n \"evidence\": \"Okay so real talk \u00e2\u20ac\u201d I've been broke my whole life.\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n }\n ],\n \"overall_severity\": \"medium\"\n}"
65
  }
66
  }
 
8
  {
9
  "claim_id": "C1",
10
  "critique_class": "hook_weakness",
11
+ "claim_text": "The hook promises a life-changing trick but takes 0:22 seconds to reveal, by which time viewers might have lost interest",
12
+ "timestamp_range": "0:00-0:03",
13
+ "evidence": "I've been broke my whole life. Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything.",
14
  "is_falsifiable": true,
15
  "severity": "medium"
16
  },
17
  {
18
  "claim_id": "C2",
19
  "critique_class": "pacing_issue",
20
+ "claim_text": "The script switches abruptly from the personal story to a promotional message at 0:45, disrupting the narrative flow",
21
+ "timestamp_range": "0:45-0:55",
22
+ "evidence": "The secret? Mutual funds. Just SIPs. I'm serious.",
23
  "is_falsifiable": true,
24
+ "severity": "high"
25
  },
26
  {
27
  "claim_id": "C3",
28
+ "critique_class": "cta_buried",
29
+ "claim_text": "The call-to-action to follow the creator for the list of funds is buried at the end of the script and might be missed by viewers",
30
+ "timestamp_range": "1:10-1:15",
31
+ "evidence": "If you want to know which funds I use, follow me and I'll post the list tomorrow.",
32
  "is_falsifiable": true,
33
  "severity": "medium"
34
  },
35
  {
36
  "claim_id": "C4",
37
+ "critique_class": "retention_risk",
38
+ "claim_text": "The script assumes viewers will wait for the next day's post for the list of funds, which might lead to a drop in retention",
39
+ "timestamp_range": "1:10-1:15",
40
  "evidence": "If you want to know which funds I use, follow me and I'll post the list tomorrow.",
41
  "is_falsifiable": true,
42
  "severity": "high"
43
  },
44
  {
45
  "claim_id": "C5",
46
+ "critique_class": "coherence_break",
47
+ "claim_text": "The script's tone shifts from relatable and personal to overly promotional and sales-y, potentially alienating the target audience",
48
+ "timestamp_range": "N/A",
49
+ "evidence": "The entire script's tone and language",
50
  "is_falsifiable": true,
51
  "severity": "medium"
52
  },
53
  {
54
  "claim_id": "C6",
55
+ "critique_class": "pacing_issue",
56
+ "claim_text": "The script tries to cram too much information and calls-to-action in the last 10 seconds, which might overwhelm viewers",
57
+ "timestamp_range": "1:15-1:20",
58
+ "evidence": "Like and save this video before Instagram hides it.",
59
  "is_falsifiable": true,
60
+ "severity": "high"
61
  }
62
  ],
63
+ "overall_severity": "high",
64
+ "raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook promises a life-changing trick but takes 0:22 seconds to reveal, by which time viewers might have lost interest\",\n \"timestamp_range\": \"0:00-0:03\",\n \"evidence\": \"I've been broke my whole life. Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C2\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script switches abruptly from the personal story to a promotional message at 0:45, disrupting the narrative flow\",\n \"timestamp_range\": \"0:45-0:55\",\n \"evidence\": \"The secret? Mutual funds. Just SIPs. I'm serious.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C3\",\n \"critique_class\": \"cta_buried\",\n \"claim_text\": \"The call-to-action to follow the creator for the list of funds is buried at the end of the script and might be missed by viewers\",\n \"timestamp_range\": \"1:10-1:15\",\n \"evidence\": \"If you want to know which funds I use, follow me and I'll post the list tomorrow.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C4\",\n \"critique_class\": \"retention_risk\",\n \"claim_text\": \"The script assumes viewers will wait for the next day's post for the list of funds, which might lead to a drop in retention\",\n \"timestamp_range\": \"1:10-1:15\",\n \"evidence\": \"If you want to know which funds I use, follow me and I'll post the list tomorrow.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C5\",\n \"critique_class\": \"coherence_break\",\n \"claim_text\": \"The script's tone shifts from relatable and personal to overly promotional and sales-y, potentially alienating the target audience\",\n \"timestamp_range\": 
\"N/A\",\n \"evidence\": \"The entire script's tone and language\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C6\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script tries to cram too much information and calls-to-action in the last 10 seconds, which might overwhelm viewers\",\n \"timestamp_range\": \"1:15-1:20\",\n \"evidence\": \"Like and save this video before Instagram hides it.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n }\n ],\n \"overall_severity\": \"high\"\n}"
65
  }
66
  }
viral_script_engine/data/golden_fixtures/fixture_S02.json CHANGED
@@ -8,59 +8,59 @@
8
  {
9
  "claim_id": "C1",
10
  "critique_class": "hook_weakness",
11
- "claim_text": "The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at 2:04-2:10 shows the total cost is five hundred rupees, which may be confusing for viewers who were expecting to see how to create outfits for one thousand rupees",
12
- "timestamp_range": "0:00-0:03, 2:04-2:10",
13
- "evidence": "Five outfits, one thousand rupees. ... Grand total \u00e2\u20ac\u201d five hundred rupees for five outfits.",
14
  "is_falsifiable": true,
15
  "severity": "medium"
16
  },
17
  {
18
  "claim_id": "C2",
19
  "critique_class": "pacing_issue",
20
- "claim_text": "The script at 0:44-0:56 has a pacing issue where the creator says 'wait I need to find it. Okay found it.' which disrupts the flow of showcasing outfits",
21
- "timestamp_range": "0:44-0:56",
22
- "evidence": "Outfit three \u00e2\u20ac\u201d wait I need to find it. Okay found it.",
23
  "is_falsifiable": true,
24
  "severity": "low"
25
  },
26
  {
27
  "claim_id": "C3",
28
  "critique_class": "retention_risk",
29
- "claim_text": "The script at 2:20-2:30 has a retention risk where the creator asks viewers to comment their city, which may not be engaging enough to retain viewers' interest",
30
- "timestamp_range": "2:20-2:30",
31
- "evidence": "Comment your city and I'll do a version for your local markets.",
32
  "is_falsifiable": true,
33
  "severity": "medium"
34
  },
35
  {
36
  "claim_id": "C4",
37
- "critique_class": "coherence_break",
38
- "claim_text": "The script at 1:46-1:54 has a coherence break where the creator abruptly shifts from showcasing outfits to talking about a saree drape tutorial",
39
- "timestamp_range": "1:46-1:54",
40
- "evidence": "Outfit five \u00e2\u20ac\u201d this entire saree drape tutorial took me two hours",
41
  "is_falsifiable": true,
42
  "severity": "high"
43
  },
44
  {
45
  "claim_id": "C5",
46
- "critique_class": "cta_buried",
47
- "claim_text": "The script buries the call-to-action at 2:20-2:30, where the creator asks viewers to comment their city, but it's not prominent enough",
48
- "timestamp_range": "2:20-2:30",
49
- "evidence": "Comment your city and I'll do a version for your local markets.",
50
  "is_falsifiable": true,
51
  "severity": "medium"
52
  },
53
  {
54
  "claim_id": "C6",
55
  "critique_class": "cultural_mismatch",
56
- "claim_text": "The script assumes all viewers are familiar with Mumbai-based markets like Linking Road and Sarojini Nagar, which may not be relatable to viewers from other regions",
57
  "timestamp_range": "N/A",
58
- "evidence": "Linking Road, Sarojini Nagar",
59
  "is_falsifiable": true,
60
  "severity": "low"
61
  }
62
  ],
63
  "overall_severity": "medium",
64
- "raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at 2:04-2:10 shows the total cost is five hundred rupees, which may be confusing for viewers who were expecting to see how to create outfits for one thousand rupees\",\n \"timestamp_range\": \"0:00-0:03, 2:04-2:10\",\n \"evidence\": \"Five outfits, one thousand rupees. ... Grand total \u00e2\u20ac\u201d five hundred rupees for five outfits.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C2\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script at 0:44-0:56 has a pacing issue where the creator says 'wait I need to find it. Okay found it.' which disrupts the flow of showcasing outfits\",\n \"timestamp_range\": \"0:44-0:56\",\n \"evidence\": \"Outfit three \u00e2\u20ac\u201d wait I need to find it. Okay found it.\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n },\n {\n \"claim_id\": \"C3\",\n \"critique_class\": \"retention_risk\",\n \"claim_text\": \"The script at 2:20-2:30 has a retention risk where the creator asks viewers to comment their city, which may not be engaging enough to retain viewers' interest\",\n \"timestamp_range\": \"2:20-2:30\",\n \"evidence\": \"Comment your city and I'll do a version for your local markets.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C4\",\n \"critique_class\": \"coherence_break\",\n \"claim_text\": \"The script at 1:46-1:54 has a coherence break where the creator abruptly shifts from showcasing outfits to talking about a saree drape tutorial\",\n \"timestamp_range\": \"1:46-1:54\",\n \"evidence\": \"Outfit five \u00e2\u20ac\u201d this entire saree drape tutorial took me two hours\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C5\",\n \"critique_class\": \"cta_buried\",\n 
\"claim_text\": \"The script buries the call-to-action at 2:20-2:30, where the creator asks viewers to comment their city, but it's not prominent enough\",\n \"timestamp_range\": \"2:20-2:30\",\n \"evidence\": \"Comment your city and I'll do a version for your local markets.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C6\",\n \"critique_class\": \"cultural_mismatch\",\n \"claim_text\": \"The script assumes all viewers are familiar with Mumbai-based markets like Linking Road and Sarojini Nagar, which may not be relatable to viewers from other regions\",\n \"timestamp_range\": \"N/A\",\n \"evidence\": \"Linking Road, Sarojini Nagar\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n }\n ],\n \"overall_severity\": \"medium\"\n}"
65
  }
66
  }
 
8
  {
9
  "claim_id": "C1",
10
  "critique_class": "hook_weakness",
11
+ "claim_text": "The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at 1:45 shows the total cost is five hundred rupees, which may confuse viewers",
12
+ "timestamp_range": "0:00-0:03, 1:45",
13
+ "evidence": "Five outfits, one thousand rupees",
14
  "is_falsifiable": true,
15
  "severity": "medium"
16
  },
17
  {
18
  "claim_id": "C2",
19
  "critique_class": "pacing_issue",
20
+ "claim_text": "The script has an abrupt pause at 0:45-0:50 where the creator says 'wait I need to find it. Okay found it.', which disrupts the flow of the video",
21
+ "timestamp_range": "0:45-0:50",
22
+ "evidence": "wait I need to find it. Okay found it.",
23
  "is_falsifiable": true,
24
  "severity": "low"
25
  },
26
  {
27
  "claim_id": "C3",
28
  "critique_class": "retention_risk",
29
+ "claim_text": "The creator asks viewers to 'save this video' at 1:20-1:25, but this call-to-action may not be compelling enough to encourage viewers to take action",
30
+ "timestamp_range": "1:20-1:25",
31
+ "evidence": "this entire saree drape tutorial took me two hours so please save this video",
32
  "is_falsifiable": true,
33
  "severity": "medium"
34
  },
35
  {
36
  "claim_id": "C4",
37
+ "critique_class": "cta_buried",
38
+ "claim_text": "The call-to-action 'Comment your city and I'll do a version for your local markets' is buried at the end of the script at 1:50-2:00, which may be missed by viewers who drop off early",
39
+ "timestamp_range": "1:50-2:00",
40
+ "evidence": "Comment your city and I'll do a version for your local markets",
41
  "is_falsifiable": true,
42
  "severity": "high"
43
  },
44
  {
45
  "claim_id": "C5",
46
+ "critique_class": "coherence_break",
47
+ "claim_text": "The script jumps abruptly from showcasing outfits to providing a grand total at 1:45, without a clear transition or summary of the previous outfits",
48
+ "timestamp_range": "1:40-1:50",
49
+ "evidence": "Grand total \u00e2\u20ac\u201d five hundred rupees for five outfits",
50
  "is_falsifiable": true,
51
  "severity": "medium"
52
  },
53
  {
54
  "claim_id": "C6",
55
  "critique_class": "cultural_mismatch",
56
+ "claim_text": "The script assumes that all viewers are familiar with Mumbai's local markets and terms like 'Linking Road' and 'Sarojini Nagar', which may not resonate with viewers from other regions",
57
  "timestamp_range": "N/A",
58
+ "evidence": "thrifted kurta from Linking Road, forty rupees, styled with mom's old dupatta, zero rupees",
59
  "is_falsifiable": true,
60
  "severity": "low"
61
  }
62
  ],
63
  "overall_severity": "medium",
64
+ "raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at 1:45 shows the total cost is five hundred rupees, which may confuse viewers\",\n \"timestamp_range\": \"0:00-0:03, 1:45\",\n \"evidence\": \"Five outfits, one thousand rupees\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C2\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script has an abrupt pause at 0:45-0:50 where the creator says 'wait I need to find it. Okay found it.', which disrupts the flow of the video\",\n \"timestamp_range\": \"0:45-0:50\",\n \"evidence\": \"wait I need to find it. Okay found it.\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n },\n {\n \"claim_id\": \"C3\",\n \"critique_class\": \"retention_risk\",\n \"claim_text\": \"The creator asks viewers to 'save this video' at 1:20-1:25, but this call-to-action may not be compelling enough to encourage viewers to take action\",\n \"timestamp_range\": \"1:20-1:25\",\n \"evidence\": \"this entire saree drape tutorial took me two hours so please save this video\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C4\",\n \"critique_class\": \"cta_buried\",\n \"claim_text\": \"The call-to-action 'Comment your city and I'll do a version for your local markets' is buried at the end of the script at 1:50-2:00, which may be missed by viewers who drop off early\",\n \"timestamp_range\": \"1:50-2:00\",\n \"evidence\": \"Comment your city and I'll do a version for your local markets\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C5\",\n \"critique_class\": \"coherence_break\",\n \"claim_text\": \"The script jumps abruptly from showcasing outfits to providing a grand total at 1:45, without a clear transition or summary of the previous outfits\",\n \"timestamp_range\": 
\"1:40-1:50\",\n \"evidence\": \"Grand total \u00e2\u20ac\u201d five hundred rupees for five outfits\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C6\",\n \"critique_class\": \"cultural_mismatch\",\n \"claim_text\": \"The script assumes that all viewers are familiar with Mumbai's local markets and terms like 'Linking Road' and 'Sarojini Nagar', which may not resonate with viewers from other regions\",\n \"timestamp_range\": \"N/A\",\n \"evidence\": \"thrifted kurta from Linking Road, forty rupees, styled with mom's old dupatta, zero rupees\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n }\n ],\n \"overall_severity\": \"medium\"\n}"
65
  }
66
  }
viral_script_engine/environment/env.py CHANGED
@@ -3,6 +3,7 @@ import random
3
  from typing import Optional, Tuple
4
 
5
  from viral_script_engine.agents.critic import CriticAgent
 
6
  from viral_script_engine.agents.rewriter import RewriterAgent
7
  from viral_script_engine.environment.actions import ArbitratorAction
8
  from viral_script_engine.environment.episode_state import EpisodeState
@@ -11,6 +12,9 @@ from viral_script_engine.environment.observations import (
11
  )
12
  from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
13
  from viral_script_engine.rewards.r2_coherence import CoherenceReward
 
 
 
14
  from viral_script_engine.rewards.reward_aggregator import RewardAggregator
15
 
16
  _TIERS = {
@@ -28,6 +32,7 @@ class ViralScriptEnv:
28
  max_steps: int = 5,
29
  difficulty: str = "easy",
30
  use_anti_gaming: bool = True,
 
31
  ):
32
  self.max_steps = max_steps
33
  self.difficulty = difficulty
@@ -39,9 +44,13 @@ class ViralScriptEnv:
39
  tier_ids = _TIERS[difficulty]
40
  self._scripts = [s for s in all_scripts if s["script_id"] in tier_ids]
41
  self.critic = CriticAgent()
 
42
  self.rewriter = RewriterAgent()
43
  self.r1 = HookStrengthReward()
44
  self.r2 = CoherenceReward()
 
 
 
45
  self.aggregator = RewardAggregator()
46
  self._state: Optional[EpisodeState] = None
47
 
@@ -52,9 +61,11 @@ class ViralScriptEnv:
52
 
53
  r1_result = self.r1.score(script["script_text"])
54
  r2_result = self.r2.score(script["script_text"], script["script_text"])
 
55
  initial_rewards = RewardComponents(
56
  r1_hook_strength=r1_result.score,
57
  r2_coherence=r2_result.score,
 
58
  )
59
  initial_rewards.compute_total()
60
 
@@ -79,27 +90,68 @@ class ViralScriptEnv:
79
  niche=self._state.niche,
80
  )
81
 
 
 
 
 
 
 
 
82
  rewrite_result = self.rewriter.rewrite(self._state.current_script, arb_action)
83
  new_script = rewrite_result.rewritten_script
84
 
85
  r1_result = self.r1.score(new_script)
86
  r2_result = self.r2.score(self._state.original_script, new_script)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  components = RewardComponents(
88
  r1_hook_strength=r1_result.score,
89
  r2_coherence=r2_result.score,
 
 
 
90
  )
91
 
92
  self._state.action_history.append(arb_action.action_type)
93
  if self.use_anti_gaming:
94
- components = self.aggregator.compute(
95
- components, self._state.episode_start_rewards, self._state.action_history
 
 
 
 
96
  )
97
  else:
98
  components.compute_total()
 
 
 
 
 
 
 
 
 
99
 
100
  round_ = DebateRound(
101
  step_num=self._state.step_num,
102
  critic_claims=critique.claims,
 
103
  arbitrator_action=arb_action,
104
  rewrite_diff=rewrite_result.diff,
105
  reward_components=components,
@@ -109,14 +161,19 @@ class ViralScriptEnv:
109
  self._state.last_reward_components = components
110
  self._state.step_num += 1
111
 
 
 
 
 
112
  terminated = (
113
  self._state.step_num >= self._state.max_steps
114
  or components.total >= 0.9
115
  )
116
  info = {
117
  "reward_components": components.model_dump(),
118
- "anti_gaming_triggered": components.anti_gaming_penalty > 0,
119
- "penalty_reason": "anti_gaming" if components.anti_gaming_penalty > 0 else None,
 
120
  }
121
  return self._build_observation().model_dump(), components.total, terminated, False, info
122
 
@@ -132,6 +189,7 @@ class ViralScriptEnv:
132
  "step_num": s.step_num,
133
  "difficulty_level": s.difficulty_level,
134
  "episode_id": s.episode_id,
 
135
  }
136
 
137
  def _build_observation(self) -> Observation:
 
3
  from typing import Optional, Tuple
4
 
5
  from viral_script_engine.agents.critic import CriticAgent
6
+ from viral_script_engine.agents.defender import DefenderAgent
7
  from viral_script_engine.agents.rewriter import RewriterAgent
8
  from viral_script_engine.environment.actions import ArbitratorAction
9
  from viral_script_engine.environment.episode_state import EpisodeState
 
12
  )
13
  from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
14
  from viral_script_engine.rewards.r2_coherence import CoherenceReward
15
+ from viral_script_engine.rewards.r3_cultural_alignment import CulturalAlignmentReward
16
+ from viral_script_engine.rewards.r4_debate_resolution import DebateResolutionReward
17
+ from viral_script_engine.rewards.r5_defender_preservation import DefenderPreservationReward
18
  from viral_script_engine.rewards.reward_aggregator import RewardAggregator
19
 
20
  _TIERS = {
 
32
  max_steps: int = 5,
33
  difficulty: str = "easy",
34
  use_anti_gaming: bool = True,
35
+ cultural_kb_path: str = "data/cultural_kb.json",
36
  ):
37
  self.max_steps = max_steps
38
  self.difficulty = difficulty
 
44
  tier_ids = _TIERS[difficulty]
45
  self._scripts = [s for s in all_scripts if s["script_id"] in tier_ids]
46
  self.critic = CriticAgent()
47
+ self.defender = DefenderAgent()
48
  self.rewriter = RewriterAgent()
49
  self.r1 = HookStrengthReward()
50
  self.r2 = CoherenceReward()
51
+ self.r3 = CulturalAlignmentReward(knowledge_base_path=cultural_kb_path)
52
+ self.r4 = DebateResolutionReward(critic_agent=self.critic)
53
+ self.r5 = DefenderPreservationReward()
54
  self.aggregator = RewardAggregator()
55
  self._state: Optional[EpisodeState] = None
56
 
 
61
 
62
  r1_result = self.r1.score(script["script_text"])
63
  r2_result = self.r2.score(script["script_text"], script["script_text"])
64
+ r3_result = self.r3.score(script["script_text"], script.get("region", "pan_india_english"))
65
  initial_rewards = RewardComponents(
66
  r1_hook_strength=r1_result.score,
67
  r2_coherence=r2_result.score,
68
+ r3_cultural_alignment=r3_result.score,
69
  )
70
  initial_rewards.compute_total()
71
 
 
90
  niche=self._state.niche,
91
  )
92
 
93
+ defender_output = self.defender.defend(
94
+ script=self._state.current_script,
95
+ critic_claims=critique.claims,
96
+ region=self._state.region,
97
+ platform=self._state.platform,
98
+ )
99
+
100
  rewrite_result = self.rewriter.rewrite(self._state.current_script, arb_action)
101
  new_script = rewrite_result.rewritten_script
102
 
103
  r1_result = self.r1.score(new_script)
104
  r2_result = self.r2.score(self._state.original_script, new_script)
105
+ r3_result = self.r3.score(new_script, self._state.region)
106
+
107
+ targeted_claim = next(
108
+ (c for c in critique.claims if c.claim_id == arb_action.critique_claim_id),
109
+ critique.claims[0] if critique.claims else None,
110
+ )
111
+ r4_result = self.r4.score(
112
+ new_script=new_script,
113
+ original_action=arb_action,
114
+ original_claim=targeted_claim,
115
+ region=self._state.region,
116
+ platform=self._state.platform,
117
+ niche=self._state.niche,
118
+ ) if targeted_claim else None
119
+
120
+ r5_result = self.r5.score(defender_output, new_script)
121
+
122
  components = RewardComponents(
123
  r1_hook_strength=r1_result.score,
124
  r2_coherence=r2_result.score,
125
+ r3_cultural_alignment=r3_result.score,
126
+ r4_debate_resolution=r4_result.score if r4_result else None,
127
+ r5_defender_preservation=r5_result.score,
128
  )
129
 
130
  self._state.action_history.append(arb_action.action_type)
131
  if self.use_anti_gaming:
132
+ components, anti_log = self.aggregator.compute(
133
+ components,
134
+ self._state.episode_start_rewards,
135
+ self._state.action_history,
136
+ episode_id=self._state.episode_id,
137
+ step_num=self._state.step_num,
138
  )
139
  else:
140
  components.compute_total()
141
+ from viral_script_engine.rewards.reward_aggregator import AntiGamingLog
142
+ anti_log = AntiGamingLog(
143
+ episode_id=self._state.episode_id,
144
+ step_num=self._state.step_num,
145
+ triggered=False,
146
+ penalty_applied=0.0,
147
+ pre_penalty_total=components.total,
148
+ post_penalty_total=components.total,
149
+ )
150
 
151
  round_ = DebateRound(
152
  step_num=self._state.step_num,
153
  critic_claims=critique.claims,
154
+ defender_response=defender_output.model_dump(),
155
  arbitrator_action=arb_action,
156
  rewrite_diff=rewrite_result.diff,
157
  reward_components=components,
 
161
  self._state.last_reward_components = components
162
  self._state.step_num += 1
163
 
164
+ if not hasattr(self._state, "anti_gaming_logs"):
165
+ self._state.anti_gaming_logs = []
166
+ self._state.anti_gaming_logs.append(anti_log.model_dump())
167
+
168
  terminated = (
169
  self._state.step_num >= self._state.max_steps
170
  or components.total >= 0.9
171
  )
172
  info = {
173
  "reward_components": components.model_dump(),
174
+ "anti_gaming_triggered": anti_log.triggered,
175
+ "penalty_reason": anti_log.rule_triggered,
176
+ "anti_gaming_log": anti_log.model_dump(),
177
  }
178
  return self._build_observation().model_dump(), components.total, terminated, False, info
179
 
 
189
  "step_num": s.step_num,
190
  "difficulty_level": s.difficulty_level,
191
  "episode_id": s.episode_id,
192
+ "anti_gaming_logs": getattr(s, "anti_gaming_logs", []),
193
  }
194
 
195
  def _build_observation(self) -> Observation:
viral_script_engine/rewards/r3_cultural_alignment.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+ import json
4
+ import re
5
+
6
+
7
@dataclass
class CulturalRewardResult:
    """Outcome of scoring one script against a region's cultural knowledge base."""

    # Alignment score, clipped to the range [0.0, 1.0].
    score: float
    # Knowledge-base "valid_refs" entries that occur in the script.
    valid_refs_found: List[str]
    # Knowledge-base "correct_idioms" entries that occur in the script.
    correct_idioms_found: List[str]
    # Knowledge-base "invalid_signals" entries that occur in the script.
    invalid_signals_found: List[str]
    # Knowledge-base "anachronistic_signals" entries that occur in the script.
    anachronistic_signals_found: List[str]
    # Region key the script was scored against.
    region: str


class CulturalAlignmentReward:
    """Score how well a script matches a region's cultural knowledge base.

    The knowledge base is a JSON object keyed by region. Each region entry may
    hold the lists ``valid_refs``, ``correct_idioms``, ``invalid_signals`` and
    ``anachronistic_signals``. Matching is a case-insensitive substring test.
    """

    def __init__(self, knowledge_base_path: str = "data/cultural_kb.json"):
        """Load the knowledge base from ``knowledge_base_path`` (UTF-8 JSON)."""
        with open(knowledge_base_path, "r", encoding="utf-8") as f:
            self._kb = json.load(f)

    @staticmethod
    def _terms(entry: dict, key: str) -> List[str]:
        """Return the usable terms stored under ``key`` in a region entry.

        A missing or null list is treated as empty. Blank and non-string
        entries are dropped: an empty string is a substring of every script,
        so keeping it would count as a match on every call.
        """
        return [t for t in (entry.get(key) or []) if isinstance(t, str) and t.strip()]

    @staticmethod
    def _found(terms: List[str], haystack: str) -> List[str]:
        """Return the terms that occur in the already lower-cased ``haystack``."""
        return [t for t in terms if t.lower() in haystack]

    def score(self, script: str, region: str) -> CulturalRewardResult:
        """Score ``script`` against the knowledge base entry for ``region``.

        Returns a neutral 0.5 with empty match lists when the region is not in
        the knowledge base. Otherwise the score is
        ``(valid refs + idioms - invalid signals - anachronisms)`` divided by
        the number of valid refs and idioms the region defines, clipped to
        [0.0, 1.0].
        """
        if region not in self._kb:
            return CulturalRewardResult(
                score=0.5,
                valid_refs_found=[],
                correct_idioms_found=[],
                invalid_signals_found=[],
                anachronistic_signals_found=[],
                region=region,
            )

        kb = self._kb[region]
        script_lower = (script or "").lower()

        valid_refs = self._terms(kb, "valid_refs")
        correct_idioms = self._terms(kb, "correct_idioms")

        valid_refs_found = self._found(valid_refs, script_lower)
        correct_idioms_found = self._found(correct_idioms, script_lower)
        invalid_signals_found = self._found(self._terms(kb, "invalid_signals"), script_lower)
        anachronistic_signals_found = self._found(
            self._terms(kb, "anachronistic_signals"), script_lower
        )

        numerator = (
            len(valid_refs_found)
            + len(correct_idioms_found)
            - len(invalid_signals_found)
            - len(anachronistic_signals_found)
        )
        # max(..., 1) keeps the division defined for a region with no positive terms.
        denominator = max(len(valid_refs) + len(correct_idioms), 1)
        clipped_score = max(0.0, min(1.0, numerator / denominator))

        return CulturalRewardResult(
            score=clipped_score,
            valid_refs_found=valid_refs_found,
            correct_idioms_found=correct_idioms_found,
            invalid_signals_found=invalid_signals_found,
            anachronistic_signals_found=anachronistic_signals_found,
            region=region,
        )
viral_script_engine/rewards/r4_debate_resolution.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from viral_script_engine.agents.critic import CriticAgent, CritiqueClaim
4
+ from viral_script_engine.environment.actions import ArbitratorAction
5
+
6
+
7
@dataclass
class DebateResolutionResult:
    """Verdict on whether a rewrite resolved the critique claim it targeted."""

    # Reward value assigned to the resolution outcome.
    score: float
    # Label describing the outcome, e.g. "resolved", "partially_resolved" or "persists".
    resolution_status: str
    # claim_id of the claim that the arbitrator action targeted.
    original_claim_id: str
    # critique_class of that targeted claim.
    original_claim_class: str
    # Number of claims returned when the rewritten script was critiqued again.
    new_claims_count: int
14
+
15
+
16
+ def _parse_seconds(ts: str) -> float:
17
+ if not ts or ts == "N/A":
18
+ return -1.0
19
+ try:
20
+ start = ts.split("-")[0].strip()
21
+ parts = start.split(":")
22
+ return float(parts[0]) * 60 + float(parts[1])
23
+ except Exception:
24
+ return -1.0
25
+
26
+
27
class DebateResolutionReward:
    """Reward that checks whether a rewrite made the targeted critique go away.

    The rewritten script is critiqued again with the supplied critic. The
    targeted claim counts as still present when a new claim has the same
    critique class and, if the original timestamp is parseable, starts within
    five seconds of it.
    """

    def __init__(self, critic_agent: CriticAgent):
        """Keep the critic used to critique rewritten scripts again."""
        self.critic = critic_agent

    def score(
        self,
        new_script: str,
        original_action: ArbitratorAction,
        original_claim: CritiqueClaim,
        region: str,
        platform: str,
        niche: str,
    ) -> DebateResolutionResult:
        """Critique ``new_script`` again and grade the fate of ``original_claim``.

        Returns 1.0 ("resolved") when no matching claim remains, 0.5
        ("partially_resolved") when the most severe matching claim is "low",
        and 0.0 ("persists") otherwise. An unrecognised severity is ranked
        like "medium". ``original_action`` is accepted but not used.
        """
        fresh_claims = self.critic.critique(new_script, region, platform, niche).claims

        anchor = _parse_seconds(original_claim.timestamp_range)
        claim_class = original_claim.critique_class

        def _still_applies(candidate) -> bool:
            # A different critique class never matches the original claim.
            if candidate.critique_class != claim_class:
                return False
            # Without a usable original timestamp, class equality is enough.
            if anchor == -1.0:
                return True
            moment = _parse_seconds(candidate.timestamp_range)
            return moment != -1.0 and abs(moment - anchor) <= 5

        survivors = [claim for claim in fresh_claims if _still_applies(claim)]

        if not survivors:
            verdict = (1.0, "resolved")
        else:
            rank = {"high": 2, "medium": 1, "low": 0}
            harshest = max(survivors, key=lambda c: rank.get(c.severity, 1))
            if harshest.severity == "low":
                verdict = (0.5, "partially_resolved")
            else:
                verdict = (0.0, "persists")

        value, status = verdict
        return DebateResolutionResult(
            score=value,
            resolution_status=status,
            original_claim_id=original_claim.claim_id,
            original_claim_class=claim_class,
            new_claims_count=len(fresh_claims),
        )
viral_script_engine/rewards/r5_defender_preservation.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List
3
+
4
+ from viral_script_engine.agents.defender import DefenderOutput
5
+
6
+
7
@dataclass
class DefenderPreservationResult:
    """Result of checking whether a rewrite kept the Defender's quoted strength."""

    # Final reward value derived from the similarity.
    score: float
    # Highest cosine similarity between the quote and any rewritten sentence.
    max_similarity: float
    # Rewritten sentence that produced max_similarity ("" when there is none).
    best_matching_sentence: str
12
+
13
+
14
+ def _sentence_split(text: str) -> List[str]:
15
+ import re
16
+ sentences = re.split(r"(?<=[.!?])\s+", text.strip())
17
+ return [s for s in sentences if s.strip()]
18
+
19
+
20
class DefenderPreservationReward:
    """Reward a rewrite for keeping the Defender's quoted core strength.

    The quote and every sentence of the rewritten script are embedded with the
    "all-MiniLM-L6-v2" sentence-transformers model, and the reward is derived
    from the best cosine similarity. The model and the embedding cache live on
    the class, so all instances share them.
    """

    # Lazily loaded SentenceTransformer shared by every instance.
    _model = None
    # Embedding cache keyed by the SHA-256 hex digest of the text.
    _cache: dict = {}
    # Cap on cached embeddings. Every distinct rewritten sentence adds an
    # entry, so without a cap the cache grows for as long as the process runs.
    _CACHE_MAX_ENTRIES = 4096

    def _get_model(self):
        """Return the shared model, loading it on first use."""
        if DefenderPreservationReward._model is None:
            from sentence_transformers import SentenceTransformer
            DefenderPreservationReward._model = SentenceTransformer("all-MiniLM-L6-v2")
        return DefenderPreservationReward._model

    def _embed(self, text: str):
        """Return the embedding of ``text``, reusing a cached value when present.

        The cache is least-recently-used: a hit moves the entry to the newest
        position, and the oldest entries are evicted once the cap is reached.
        """
        import hashlib

        cache = DefenderPreservationReward._cache
        key = hashlib.sha256(text.encode()).hexdigest()
        if key in cache:
            # Pop so the re-insert below marks this entry as most recent.
            embedding = cache.pop(key)
        else:
            embedding = self._get_model().encode(text, convert_to_tensor=True)
            limit = max(1, DefenderPreservationReward._CACHE_MAX_ENTRIES)
            # Dicts keep insertion order, so the first key is the oldest.
            while len(cache) >= limit:
                cache.pop(next(iter(cache)))
        cache[key] = embedding
        return embedding

    def score(self, defender_output: DefenderOutput, rewritten_script: str) -> DefenderPreservationResult:
        """Score how closely any rewritten sentence matches the Defender's quote.

        Returns 1.0 when the best similarity is at least 0.85, the similarity
        itself when it is at least 0.65, and 0.0 otherwise. A blank quote or a
        script with no sentences scores 0.0 without running the model.
        """
        quote = defender_output.core_strength_quote
        sentences = _sentence_split(rewritten_script or "")

        # Nothing to compare: a blank quote carries no content to preserve.
        if not quote or not quote.strip() or not sentences:
            return DefenderPreservationResult(score=0.0, max_similarity=0.0, best_matching_sentence="")

        from sentence_transformers.util import cos_sim

        quote_emb = self._embed(quote)
        sims = [(float(cos_sim(quote_emb, self._embed(s))[0][0]), s) for s in sentences]
        max_sim, best_sent = max(sims, key=lambda x: x[0])

        if max_sim >= 0.85:
            final_score = 1.0
        elif max_sim >= 0.65:
            final_score = max_sim
        else:
            final_score = 0.0

        return DefenderPreservationResult(
            score=final_score,
            max_similarity=max_sim,
            best_matching_sentence=best_sent,
        )
viral_script_engine/rewards/reward_aggregator.py CHANGED
@@ -1,5 +1,7 @@
1
  import logging
2
- from typing import List
 
 
3
 
4
  from viral_script_engine.environment.actions import ActionType
5
  from viral_script_engine.environment.observations import RewardComponents
@@ -11,6 +13,19 @@ _COMPONENT_FIELDS = [
11
  "r4_debate_resolution", "r5_defender_preservation",
12
  ]
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  class RewardAggregator:
16
  def compute(
@@ -18,20 +33,31 @@ class RewardAggregator:
18
  components: RewardComponents,
19
  episode_start_components: RewardComponents,
20
  action_history: List[ActionType],
21
- ) -> RewardComponents:
 
 
22
  components.compute_total()
 
23
 
24
- # Anti-gaming rule 1: catastrophic drop (>0.2 drop in any component)
25
  for field in _COMPONENT_FIELDS:
26
  curr = getattr(components, field)
27
  start = getattr(episode_start_components, field)
28
- if curr is not None and start is not None and curr < start - 0.2:
29
  logger.warning("Catastrophic drop in %s: %.3f -> %.3f", field, start, curr)
30
  components.total = 0.0
31
  components.anti_gaming_penalty = start - curr
32
- return components
 
 
 
 
 
 
 
 
 
 
33
 
34
- # Anti-gaming rule 2: action diversity (last 3 same ActionType)
35
  penalty = 0.0
36
  if len(action_history) >= 3 and len(set(action_history[-3:])) == 1:
37
  penalty = 0.15
@@ -39,4 +65,15 @@ class RewardAggregator:
39
 
40
  components.anti_gaming_penalty = penalty
41
  components.total = max(0.0, min(1.0, components.total - penalty))
42
- return components
 
 
 
 
 
 
 
 
 
 
 
 
1
  import logging
2
+ from typing import List, Optional, Tuple
3
+
4
+ from pydantic import BaseModel
5
 
6
  from viral_script_engine.environment.actions import ActionType
7
  from viral_script_engine.environment.observations import RewardComponents
 
13
  "r4_debate_resolution", "r5_defender_preservation",
14
  ]
15
 
16
+ _DROP_THRESHOLD = 0.25
17
+
18
+
19
class AntiGamingLog(BaseModel):
    """Per-step record of what the anti-gaming rules did to the reward."""

    # Episode the step belongs to.
    episode_id: str
    # Step number within the episode.
    step_num: int
    # True when an anti-gaming rule fired on this step.
    triggered: bool
    # Name of the rule that fired, or None when none did.
    rule_triggered: Optional[str] = None
    # Reward component whose drop fired the rule, when applicable.
    component_that_dropped: Optional[str] = None
    # Penalty value recorded for this step.
    penalty_applied: float
    # Total reward before any penalty.
    pre_penalty_total: float
    # Total reward after the penalty.
    post_penalty_total: float
28
+
29
 
30
  class RewardAggregator:
31
  def compute(
 
33
  components: RewardComponents,
34
  episode_start_components: RewardComponents,
35
  action_history: List[ActionType],
36
+ episode_id: str = "",
37
+ step_num: int = 0,
38
+ ) -> Tuple[RewardComponents, AntiGamingLog]:
39
  components.compute_total()
40
+ pre_penalty_total = components.total
41
 
 
42
  for field in _COMPONENT_FIELDS:
43
  curr = getattr(components, field)
44
  start = getattr(episode_start_components, field)
45
+ if curr is not None and start is not None and curr < start - _DROP_THRESHOLD:
46
  logger.warning("Catastrophic drop in %s: %.3f -> %.3f", field, start, curr)
47
  components.total = 0.0
48
  components.anti_gaming_penalty = start - curr
49
+ log = AntiGamingLog(
50
+ episode_id=episode_id,
51
+ step_num=step_num,
52
+ triggered=True,
53
+ rule_triggered="catastrophic_drop",
54
+ component_that_dropped=field,
55
+ penalty_applied=components.anti_gaming_penalty,
56
+ pre_penalty_total=pre_penalty_total,
57
+ post_penalty_total=0.0,
58
+ )
59
+ return components, log
60
 
 
61
  penalty = 0.0
62
  if len(action_history) >= 3 and len(set(action_history[-3:])) == 1:
63
  penalty = 0.15
 
65
 
66
  components.anti_gaming_penalty = penalty
67
  components.total = max(0.0, min(1.0, components.total - penalty))
68
+
69
+ log = AntiGamingLog(
70
+ episode_id=episode_id,
71
+ step_num=step_num,
72
+ triggered=penalty > 0,
73
+ rule_triggered="action_repetition" if penalty > 0 else None,
74
+ component_that_dropped=None,
75
+ penalty_applied=penalty,
76
+ pre_penalty_total=pre_penalty_total,
77
+ post_penalty_total=components.total,
78
+ )
79
+ return components, log
viral_script_engine/scripts/run_baseline.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import json
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ from dotenv import load_dotenv
8
+ from rich.console import Console
9
+ from rich.table import Table
10
+ from rich import box
11
+
12
+ load_dotenv()
13
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
14
+
15
+ from viral_script_engine.agents.baseline_arbitrator import BaselineArbitratorAgent
16
+ from viral_script_engine.environment.env import ViralScriptEnv
17
+
18
# Shared rich console for progress lines and the summary table.
console = Console()

# Package directory (two levels up from this script) and its log folder.
BASE_DIR = Path(__file__).parent.parent
LOGS_DIR = BASE_DIR / "logs"
LOGS_DIR.mkdir(exist_ok=True)

# (episode number, difficulty) pairs: episodes 1-8 easy, 9-16 medium, 17-20 hard.
_SCHEDULE = [
    (episode, "easy" if episode <= 8 else "medium" if episode <= 16 else "hard")
    for episode in range(1, 21)
]

# Names of the reward components as they appear in the environment's info dict.
_REWARD_KEYS = [
    "r1_hook_strength",
    "r2_coherence",
    "r3_cultural_alignment",
    "r4_debate_resolution",
    "r5_defender_preservation",
]
31
+
32
+
33
def _make_env(difficulty: str) -> ViralScriptEnv:
    """Build a five-step environment for ``difficulty`` using the packaged data files."""
    data_dir = BASE_DIR / "data"
    scripts_file = data_dir / "test_scripts" / "scripts.json"
    kb_file = data_dir / "cultural_kb.json"
    return ViralScriptEnv(
        scripts_path=str(scripts_file),
        cultural_kb_path=str(kb_file),
        max_steps=5,
        difficulty=difficulty,
    )
40
+
41
+
42
def run_episode(ep_num: int, difficulty: str, agent: BaselineArbitratorAgent) -> dict:
    """Run one baseline episode and return a JSON-serialisable record of it.

    A fresh environment is created for ``difficulty``; ``agent`` acts on each
    observation until the environment terminates, truncates, or ``max_steps``
    is reached. The record holds the per-step reward components, the reward of
    the last step, the anti-gaming logs, and the original and final scripts.
    """
    env = _make_env(difficulty)
    obs, _ = env.reset()

    episode_id = obs["episode_id"]
    state = env.state()
    # NOTE(review): this used to be hard-coded to "unknown". It now reads the
    # id from the state when the environment exposes one -- confirm that
    # ViralScriptEnv.state() includes "script_id"; if not, the old placeholder
    # is kept.
    script_id = state.get("script_id", "unknown")
    original_script = state.get("original_script", "")

    steps_log = []
    total_reward = 0.0

    for _ in range(env.max_steps):
        action = agent.act(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        rc = info["reward_components"]
        anti_log = info.get("anti_gaming_log", {})

        steps_log.append({
            "r1": rc.get("r1_hook_strength"),
            "r2": rc.get("r2_coherence"),
            "r3": rc.get("r3_cultural_alignment"),
            "r4": rc.get("r4_debate_resolution"),
            "r5": rc.get("r5_defender_preservation"),
            "total": reward,
            "anti_gaming_triggered": anti_log.get("triggered", False),
            "penalty": anti_log.get("penalty_applied", 0.0),
        })
        # Overwritten, not accumulated: the episode reports its last step's reward.
        total_reward = reward

        if terminated or truncated:
            break

    final_state = env.state()
    final_script = final_state.get("current_script", "")

    return {
        "episode_num": ep_num,
        "episode_id": episode_id,
        "difficulty": difficulty,
        "script_id": script_id,
        "steps": steps_log,
        "total_reward": total_reward,
        "anti_gaming_logs": final_state.get("anti_gaming_logs", []),
        "original_script": original_script,
        "final_script": final_script,
    }
90
+
91
+
92
def main():
    """Run every scheduled baseline episode, then save results, plots and a summary.

    A failing episode does not stop the run: it is recorded with an "error"
    field and a total reward of 0.0. The final gate line reports FAIL when
    every episode errored, and otherwise PASS with a warning if some failed.
    """
    agent = BaselineArbitratorAgent()
    all_episodes = []
    planned = len(_SCHEDULE)

    for ep_num, difficulty in _SCHEDULE:
        console.print(f"[dim]Episode {ep_num:02d}/{planned} ({difficulty})...[/dim]")
        try:
            result = run_episode(ep_num, difficulty, agent)
        except Exception as e:
            # Best-effort run: record the failure and continue with the next episode.
            console.print(f"  [red]ERROR episode {ep_num}: {e}[/red]")
            all_episodes.append({
                "episode_num": ep_num,
                "episode_id": "",
                "difficulty": difficulty,
                "script_id": "error",
                "steps": [],
                "total_reward": 0.0,
                "anti_gaming_logs": [],
                "original_script": "",
                "final_script": "",
                "error": str(e),
            })
        else:
            all_episodes.append(result)
            console.print(
                f"  -> total_reward={result['total_reward']:.3f}  "
                f"steps={len(result['steps'])}"
            )

    results_path = LOGS_DIR / "baseline_results.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(all_episodes, f, indent=2, default=str)

    _save_plots(all_episodes)
    _print_summary(all_episodes)

    failed = sum(1 for e in all_episodes if "error" in e)
    if failed == len(all_episodes):
        # Nothing was measured, so a PASS line here would be misleading.
        console.print(
            "\n[bold red]PHASE 2 GATE: FAIL — no episode completed; "
            "no baseline was measured.[/bold red]"
        )
        return
    if failed:
        console.print(
            f"[yellow]{failed} of {len(all_episodes)} episodes errored and "
            f"count as 0.0 in the mean below.[/yellow]"
        )

    mean_total = float(np.mean([e["total_reward"] for e in all_episodes]))
    console.print(
        f"\n[bold green]PHASE 2 GATE: PASS — Baseline curves saved. "
        f"Pre-training mean total reward: {mean_total:.2f}[/bold green]"
    )
132
+
133
+
134
+ def _collect_reward_series(episodes: list, key: str):
135
+ series = []
136
+ for ep in episodes:
137
+ vals = [s.get(key) for s in ep.get("steps", []) if s.get(key) is not None]
138
+ series.append(vals[-1] if vals else 0.0)
139
+ return series
140
+
141
+
142
def _save_plots(episodes: list):
    """Save a 2x3 grid of per-episode reward curves to ``baseline_reward_curves.png``.

    One panel is drawn for each of R1-R5 (last recorded value per episode) and
    one for the episode's total reward. Nothing is written when ``episodes``
    is empty.
    """
    if not episodes:
        # min()/max() below would raise on an empty episode list.
        console.print("[yellow]No episodes to plot; skipping reward curves.[/yellow]")
        return

    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    labels = {
        "r1": "R1 Hook Strength",
        "r2": "R2 Coherence",
        "r3": "R3 Cultural Alignment",
        "r4": "R4 Debate Resolution",
        "r5": "R5 Defender Preservation",
        "total": "Total Reward",
    }
    ep_nums = [e["episode_num"] for e in episodes]

    fig, axes = plt.subplots(2, 3, figsize=(14, 8), dpi=150)
    fig.suptitle(
        "Baseline (Untrained) Arbitrator — Pre-Training Reward Curves",
        fontsize=13,
    )

    for idx, (key, title) in enumerate(labels.items()):
        ax = axes[idx // 3][idx % 3]
        if key == "total":
            series = [e["total_reward"] for e in episodes]
        else:
            series = _collect_reward_series(episodes, key)
        ax.plot(ep_nums, series, marker="o", linewidth=1.5, markersize=4)
        ax.set_title(title, fontsize=10)
        ax.set_xlabel("Episode", fontsize=8)
        ax.set_ylabel("Reward", fontsize=8)
        ax.set_ylim(0, 1)
        ax.set_xlim(min(ep_nums) - 0.5, max(ep_nums) + 0.5)
        ax.tick_params(labelsize=7)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plot_path = LOGS_DIR / "baseline_reward_curves.png"
    fig.savefig(str(plot_path), dpi=150)
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    console.print(f"[dim]Curves saved -> {plot_path}[/dim]")
181
+
182
+
183
def _print_summary(episodes: list):
    """Print a table of mean, std, min and max for each reward across ``episodes``.

    R1-R5 use the last recorded value of each episode, and the total uses the
    episode's ``total_reward``. The title reports the actual number of
    episodes, and rows show "n/a" when there are none.
    """
    table = Table(
        title=f"Baseline Results — Mean +/- Std ({len(episodes)} episodes)",
        box=box.SIMPLE_HEAD,
    )
    table.add_column("Reward", style="cyan", min_width=28)
    for heading in ("Mean", "Std", "Min", "Max"):
        table.add_column(heading, min_width=8)

    label_map = {
        "r1": "R1 Hook Strength",
        "r2": "R2 Coherence",
        "r3": "R3 Cultural Alignment",
        "r4": "R4 Debate Resolution",
        "r5": "R5 Defender Preservation",
        "total": "Total Reward",
    }
    for key, label in label_map.items():
        if key == "total":
            vals = [e["total_reward"] for e in episodes]
        else:
            vals = _collect_reward_series(episodes, key)
        arr = np.array(vals, dtype=float)
        if arr.size == 0:
            # min()/max() of an empty array raise, so show placeholders instead.
            table.add_row(label, "n/a", "n/a", "n/a", "n/a")
            continue
        table.add_row(
            label,
            f"{arr.mean():.3f}",
            f"{arr.std():.3f}",
            f"{arr.min():.3f}",
            f"{arr.max():.3f}",
        )

    console.print(table)
214
+
215
+
216
+ if __name__ == "__main__":
217
+ main()
viral_script_engine/tests/test_environment.py CHANGED
@@ -49,6 +49,10 @@ def env():
49
  with (
50
  patch("viral_script_engine.environment.env.CriticAgent") as mock_critic_cls,
51
  patch("viral_script_engine.environment.env.RewriterAgent") as mock_rewriter_cls,
 
 
 
 
52
  ):
53
  mock_critic = MagicMock()
54
  mock_critic.critique.return_value = make_mock_critique()
@@ -58,6 +62,39 @@ def env():
58
  mock_rewriter.rewrite.side_effect = make_mock_rewrite
59
  mock_rewriter_cls.return_value = mock_rewriter
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  from viral_script_engine.environment.env import ViralScriptEnv
62
  yield ViralScriptEnv(scripts_path=SCRIPTS_PATH, max_steps=5, difficulty="easy")
63
 
 
49
  with (
50
  patch("viral_script_engine.environment.env.CriticAgent") as mock_critic_cls,
51
  patch("viral_script_engine.environment.env.RewriterAgent") as mock_rewriter_cls,
52
+ patch("viral_script_engine.environment.env.DefenderAgent") as mock_defender_cls,
53
+ patch("viral_script_engine.environment.env.CulturalAlignmentReward") as mock_r3_cls,
54
+ patch("viral_script_engine.environment.env.DebateResolutionReward") as mock_r4_cls,
55
+ patch("viral_script_engine.environment.env.DefenderPreservationReward") as mock_r5_cls,
56
  ):
57
  mock_critic = MagicMock()
58
  mock_critic.critique.return_value = make_mock_critique()
 
62
  mock_rewriter.rewrite.side_effect = make_mock_rewrite
63
  mock_rewriter_cls.return_value = mock_rewriter
64
 
65
+ mock_defender = MagicMock()
66
+ from viral_script_engine.agents.defender import DefenderOutput
67
+ mock_defender.defend.return_value = DefenderOutput(
68
+ core_strength="strong hook",
69
+ core_strength_quote="test quote",
70
+ defense_argument="preserve it",
71
+ flagged_critic_claims=[],
72
+ regional_voice_elements=[],
73
+ )
74
+ mock_defender_cls.return_value = mock_defender
75
+
76
+ mock_r3 = MagicMock()
77
+ mock_r3.score.return_value = MagicMock(score=0.6)
78
+ mock_r3_cls.return_value = mock_r3
79
+
80
+ mock_r4 = MagicMock()
81
+ from viral_script_engine.rewards.r4_debate_resolution import DebateResolutionResult
82
+ mock_r4.score.return_value = DebateResolutionResult(
83
+ score=0.8,
84
+ resolution_status="resolved",
85
+ original_claim_id="C1",
86
+ original_claim_class="hook_weakness",
87
+ new_claims_count=2,
88
+ )
89
+ mock_r4_cls.return_value = mock_r4
90
+
91
+ mock_r5 = MagicMock()
92
+ from viral_script_engine.rewards.r5_defender_preservation import DefenderPreservationResult
93
+ mock_r5.score.return_value = DefenderPreservationResult(
94
+ score=0.9, max_similarity=0.9, best_matching_sentence="test quote"
95
+ )
96
+ mock_r5_cls.return_value = mock_r5
97
+
98
  from viral_script_engine.environment.env import ViralScriptEnv
99
  yield ViralScriptEnv(scripts_path=SCRIPTS_PATH, max_steps=5, difficulty="easy")
100
 
viral_script_engine/tests/test_phase2.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import pytest
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ from viral_script_engine.agents.defender import DefenderAgent, DefenderOutput, DefenderParseError
6
+ from viral_script_engine.agents.critic import CritiqueClaim
7
+ from viral_script_engine.environment.actions import ActionType, ArbitratorAction
8
+ from viral_script_engine.rewards.r3_cultural_alignment import CulturalAlignmentReward
9
+ from viral_script_engine.rewards.r4_debate_resolution import DebateResolutionReward, DebateResolutionResult
10
+ from viral_script_engine.rewards.r5_defender_preservation import DefenderPreservationReward
11
+ from viral_script_engine.rewards.reward_aggregator import RewardAggregator, AntiGamingLog
12
+ from viral_script_engine.environment.observations import RewardComponents
13
+
14
+
15
+ # ─── fixtures ────────────────────────────────────────────────────────────────
16
+
17
# Canned JSON text that the stubbed LLM returns for Defender calls.
_DEFENDER_PAYLOAD = {
    "core_strength": "Relatable hook about saving money",
    "core_strength_quote": "Let me tell you a secret",
    "defense_argument": "This creates immediate viewer curiosity and should not be changed.",
    "flagged_critic_claims": ["C2"],
    "regional_voice_elements": ["yaar", "ek dum solid"],
}
MOCK_DEFENDER_RESPONSE = json.dumps(_DEFENDER_PAYLOAD)
24
+
25
# Two falsifiable critic claims shared by the tests: a high-severity hook
# weakness (C1) and a medium-severity buried CTA (C2).
MOCK_CRITIQUE_CLAIMS = [
    CritiqueClaim(
        claim_id=claim_id,
        critique_class=critique_class,
        claim_text=claim_text,
        timestamp_range=timestamp_range,
        evidence=evidence,
        is_falsifiable=True,
        severity=severity,
    )
    for claim_id, critique_class, claim_text, timestamp_range, evidence, severity in (
        ("C1", "hook_weakness", "Weak hook.", "0:00-0:03", "Let me tell you a secret", "high"),
        ("C2", "cta_buried", "CTA at end.", "0:45-0:50", "Like and save this video", "medium"),
    )
]
45
+
46
+
47
@pytest.fixture
def mock_defender_llm(monkeypatch):
    """Patch ``LLMBackend.generate`` so every call returns MOCK_DEFENDER_RESPONSE."""

    def _canned_generate(self, sys_prompt, usr_prompt, **kw):
        # Prompts and extra keyword arguments are accepted but ignored.
        return MOCK_DEFENDER_RESPONSE

    monkeypatch.setattr(
        "viral_script_engine.agents.llm_backend.LLMBackend.generate",
        _canned_generate,
    )
53
+
54
+
55
+ # ─── Step 1: DefenderAgent ────────────────────────────────────────────────────
56
+
57
def test_defender_parses_output(mock_defender_llm):
    """With the backend patched, ``defend`` returns a DefenderOutput carrying the canned fields."""
    defend_kwargs = {
        "script": "Let me tell you a secret about saving money. yaar, ek dum solid plan.",
        "critic_claims": MOCK_CRITIQUE_CLAIMS,
        "region": "mumbai_gen_z",
        "platform": "instagram",
    }

    output = DefenderAgent().defend(**defend_kwargs)

    assert isinstance(output, DefenderOutput)
    assert output.core_strength_quote == "Let me tell you a secret"
    assert "C2" in output.flagged_critic_claims
    assert "yaar" in output.regional_voice_elements
69
+
70
+
71
def test_defender_retries_on_invalid_json(monkeypatch):
    """A junk first reply followed by valid JSON yields a DefenderOutput after exactly two calls."""
    replies = []

    def fake_generate(self, sys_prompt, usr_prompt, **kw):
        # First call returns unparseable text; every later call returns the canned JSON.
        reply = MOCK_DEFENDER_RESPONSE if replies else "NOT JSON AT ALL"
        replies.append(reply)
        return reply

    monkeypatch.setattr(
        "viral_script_engine.agents.llm_backend.LLMBackend.generate",
        fake_generate,
    )

    defender = DefenderAgent()
    output = defender.defend("script", MOCK_CRITIQUE_CLAIMS, "mumbai_gen_z", "instagram")

    assert isinstance(output, DefenderOutput)
    assert len(replies) == 2
88
+
89
+
90
def test_defender_raises_after_two_failures(monkeypatch):
    """``defend`` raises DefenderParseError when the backend only ever returns non-JSON text."""

    def always_bad(self, sys_prompt, usr_prompt, **kw):
        return "BAD JSON"

    monkeypatch.setattr(
        "viral_script_engine.agents.llm_backend.LLMBackend.generate",
        always_bad,
    )

    defender = DefenderAgent()
    with pytest.raises(DefenderParseError):
        defender.defend("script", MOCK_CRITIQUE_CLAIMS, "mumbai_gen_z", "instagram")
98
+
99
+
100
+ # ─── Step 2: R3 CulturalAlignmentReward ──────────────────────────────────────
101
+
102
@pytest.fixture
def r3(tmp_path):
    """Return a CulturalAlignmentReward backed by a two-region knowledge base file in tmp_path."""
    mumbai_entry = {
        "valid_refs": ["Bandra", "CSMT", "local train", "Swiggy", "IPL"],
        "correct_idioms": ["ek dum solid", "kya scene hai", "full on"],
        "invalid_signals": ["trunk call", "VHS", "walkman"],
        "anachronistic_signals": [],
    }
    tier2_entry = {
        "valid_refs": ["kirana store", "sabzi mandi", "jugaad", "panchayat", "mela"],
        "correct_idioms": ["bilkul sahi", "arey bhai", "seedha baat"],
        "invalid_signals": ["SaaS", "venture capital", "coworking space"],
        "anachronistic_signals": [],
    }
    knowledge_base = {"mumbai_gen_z": mumbai_entry, "tier2_hindi_belt": tier2_entry}

    kb_file = tmp_path / "test_kb.json"
    kb_file.write_text(json.dumps(knowledge_base), encoding="utf-8")
    return CulturalAlignmentReward(knowledge_base_path=str(kb_file))
121
+
122
+
123
def test_r3_scores_regional_script(r3):
    """A script full of Mumbai references scores above zero and reports a matched reference."""
    outcome = r3.score(
        "Take the local train to Bandra. IPL is on at night. ek dum solid plan yaar.",
        "mumbai_gen_z",
    )

    assert outcome.score > 0.0
    assert any(ref in outcome.valid_refs_found for ref in ("local train", "Bandra"))
128
+
129
+
130
def test_r3_scores_non_regional_script_lower(r3):
    """A region-flavoured script is never scored below a generic one (ties are accepted)."""
    generic_script = "Buy on Amazon. Use your credit card. Free delivery available nationwide."
    regional_script = "Take local train to Bandra. IPL is on. ek dum solid."

    # The regional script is scored first, as in the original call order.
    regional = r3.score(regional_script, "mumbai_gen_z")
    non_regional = r3.score(generic_script, "mumbai_gen_z")

    assert regional.score >= non_regional.score
137
+
138
+
139
def test_r3_penalises_invalid_signals(r3):
    """A script made only of invalid signals scores exactly 0.0 and lists the signals found."""
    outcome = r3.score("This is like an old VHS walkman trunk call era.", "mumbai_gen_z")

    assert outcome.score == 0.0
    assert len(outcome.invalid_signals_found) >= 1
144
+
145
+
146
def test_r3_neutral_for_unknown_region(r3):
    """A region missing from the knowledge base gets the neutral score 0.5."""
    outcome = r3.score("any script", "unknown_region_xyz")

    assert outcome.score == 0.5
149
+
150
+
151
def test_r3_tier2_valid(r3):
    """A script using tier-2 Hindi-belt references and idioms scores above zero."""
    outcome = r3.score(
        "Went to kirana store, met at sabzi mandi, pure jugaad. bilkul sahi plan.",
        "tier2_hindi_belt",
    )

    assert outcome.score > 0.0
155
+
156
+
157
def test_r3_tier2_penalises_metro_jargon(r3):
    """Metro start-up jargon is reported as invalid signals for the tier-2 region."""
    outcome = r3.score(
        "We raised SaaS venture capital at a coworking space.",
        "tier2_hindi_belt",
    )

    assert len(outcome.invalid_signals_found) >= 1
161
+
162
+
163
+ # ─── Step 3: R4 DebateResolutionReward ───────────────────────────────────────
164
+
165
def _make_critique_output(claims):
    """Wrap *claims* in a CritiqueOutput with medium overall severity and an empty raw response."""
    # Kept as a function-local import, exactly as the helper was written.
    from viral_script_engine.agents.critic import CritiqueOutput

    fields = {"claims": claims, "overall_severity": "medium", "raw_response": ""}
    return CritiqueOutput(**fields)
168
+
169
+
170
def _make_claim(claim_id, critique_class, timestamp_range, severity="high"):
    """Build a falsifiable CritiqueClaim with placeholder claim text and evidence."""
    fields = {
        "claim_id": claim_id,
        "critique_class": critique_class,
        "claim_text": "test claim",
        "timestamp_range": timestamp_range,
        "evidence": "evidence text",
        "is_falsifiable": True,
        "severity": severity,
    }
    return CritiqueClaim(**fields)
180
+
181
+
182
def _make_action():
    """Return a hook-rewrite ArbitratorAction that targets claim C1."""
    fields = {
        "action_type": ActionType.HOOK_REWRITE,
        "target_section": "hook",
        "instruction": "fix hook",
        "critique_claim_id": "C1",
        "reasoning": "test",
    }
    return ArbitratorAction(**fields)
190
+
191
+
192
def test_r4_resolved_when_no_matching_claim():
    """When the re-critique only raises a different critique class, R4 scores 1.0 / "resolved"."""
    recritique = _make_critique_output([_make_claim("C1", "cta_buried", "0:45-0:50")])
    critic_stub = MagicMock()
    critic_stub.critique.return_value = recritique

    reward = DebateResolutionReward(critic_agent=critic_stub)
    original_claim = _make_claim("C1", "hook_weakness", "0:00-0:03", "high")
    outcome = reward.score(
        "new script",
        _make_action(),
        original_claim,
        "mumbai_gen_z",
        "instagram",
        "finance",
    )

    assert outcome.score == 1.0
    assert outcome.resolution_status == "resolved"
203
+
204
+
205
def test_r4_partially_resolved_when_severity_drops():
    """Same critique class re-raised at lower severity scores 0.5 / "partially_resolved"."""
    recritique = _make_critique_output([_make_claim("C1", "hook_weakness", "0:01-0:04", "low")])
    critic_stub = MagicMock()
    critic_stub.critique.return_value = recritique

    reward = DebateResolutionReward(critic_agent=critic_stub)
    original_claim = _make_claim("C1", "hook_weakness", "0:00-0:03", "high")
    outcome = reward.score(
        "new script",
        _make_action(),
        original_claim,
        "mumbai_gen_z",
        "instagram",
        "finance",
    )

    assert outcome.score == 0.5
    assert outcome.resolution_status == "partially_resolved"
216
+
217
+
218
def test_r4_persists_when_same_severity():
    """Same critique class re-raised at the same severity scores 0.0 / "persists"."""
    recritique = _make_critique_output([_make_claim("C1", "hook_weakness", "0:01-0:03", "high")])
    critic_stub = MagicMock()
    critic_stub.critique.return_value = recritique

    reward = DebateResolutionReward(critic_agent=critic_stub)
    original_claim = _make_claim("C1", "hook_weakness", "0:00-0:03", "high")
    outcome = reward.score(
        "new script",
        _make_action(),
        original_claim,
        "mumbai_gen_z",
        "instagram",
        "finance",
    )

    assert outcome.score == 0.0
    assert outcome.resolution_status == "persists"
229
+
230
+
231
+ # ─── Step 4: R5 DefenderPreservationReward ───────────────────────────────────
232
+
233
def _make_defender_output(quote: str) -> DefenderOutput:
    """Build a DefenderOutput whose only varying field is ``core_strength_quote``."""
    fields = {
        "core_strength": "Strong opening",
        "core_strength_quote": quote,
        "defense_argument": "Should be preserved.",
        "flagged_critic_claims": [],
        "regional_voice_elements": [],
    }
    return DefenderOutput(**fields)
241
+
242
+
243
def test_r5_high_score_when_quote_present():
    """A rewrite that still contains the defended quote verbatim scores at least 0.85."""
    preserved_quote = "Let me tell you a secret about saving money every month."
    rewritten = "Let me tell you a secret about saving money every month. This is the key insight."

    reward = DefenderPreservationReward()
    outcome = reward.score(_make_defender_output(preserved_quote), rewritten)

    assert outcome.score >= 0.85
250
+
251
+
252
def test_r5_zero_score_when_quote_absent():
    """A rewrite unrelated to the defended quote scores below 0.65.

    Despite the name, the check is the 0.65 threshold, not an exact zero.
    """
    lost_quote = "Completely different text that shares nothing with rewrite."
    rewritten = "Today we discuss quantum physics and neutron stars in distant galaxies."

    reward = DefenderPreservationReward()
    outcome = reward.score(_make_defender_output(lost_quote), rewritten)

    assert outcome.score < 0.65
259
+
260
+
261
+ # ─── Step 5/6: AntiGamingLog and RewardAggregator ────────────────────────────
262
+
263
def _make_components(**kwargs) -> RewardComponents:
    """Build RewardComponents from test defaults, apply *kwargs* overrides, and compute the total."""
    baseline = {
        "r1_hook_strength": 0.7,
        "r2_coherence": 0.7,
        "r3_cultural_alignment": 0.7,
        "r4_debate_resolution": None,
        "r5_defender_preservation": None,
    }
    # Caller-supplied values win over the defaults.
    components = RewardComponents(**{**baseline, **kwargs})
    components.compute_total()
    return components
274
+
275
+
276
def test_anti_gaming_catastrophic_drop_zeroes_reward():
    """Dropping r2_coherence from 0.8 to 0.4 zeroes the total and logs a catastrophic drop."""
    before = _make_components(r2_coherence=0.8)
    after = _make_components(r2_coherence=0.4)

    components, gaming_log = RewardAggregator().compute(
        after, before, [], episode_id="ep1", step_num=1
    )

    assert components.total == 0.0
    assert gaming_log.triggered is True
    assert gaming_log.rule_triggered == "catastrophic_drop"
    assert gaming_log.component_that_dropped == "r2_coherence"
    assert gaming_log.post_penalty_total == 0.0
288
+
289
+
290
def test_anti_gaming_diversity_penalty_fires_on_3x_same():
    """Three identical actions in the history trigger the action-repetition rule.

    Start and current components are identical defaults, so the only
    difference from a clean run is the repeated HOOK_REWRITE history.
    """
    aggregator = RewardAggregator()
    start = _make_components()
    current = _make_components()
    history = [ActionType.HOOK_REWRITE] * 3

    # Only the log is inspected here, so the aggregated components are discarded.
    _, log = aggregator.compute(current, start, history, episode_id="ep2", step_num=2)

    assert log.triggered is True
    assert log.rule_triggered == "action_repetition"
    # Tolerant float comparison so the check does not hinge on exact rounding.
    assert log.penalty_applied == pytest.approx(0.15)
301
+
302
+
303
def test_anti_gaming_log_not_triggered_clean():
    """With unchanged components and a varied action history, no anti-gaming rule fires."""
    aggregator = RewardAggregator()
    start = _make_components()
    current = _make_components()
    history = [ActionType.HOOK_REWRITE, ActionType.CTA_PLACEMENT, ActionType.SECTION_REORDER]

    # Only the log is inspected here, so the aggregated components are discarded.
    _, log = aggregator.compute(current, start, history, episode_id="ep3", step_num=1)

    assert log.triggered is False
    assert log.rule_triggered is None
    assert log.penalty_applied == 0.0
314
+
315
+
316
def test_anti_gaming_log_fields_populated():
    """The returned log is an AntiGamingLog that echoes the episode id and step number."""
    before = _make_components(r1_hook_strength=0.9)
    after = _make_components(r1_hook_strength=0.5)

    outcome = RewardAggregator().compute(after, before, [], episode_id="myep", step_num=3)
    _, gaming_log = outcome

    assert gaming_log.episode_id == "myep"
    assert gaming_log.step_num == 3
    assert isinstance(gaming_log, AntiGamingLog)
viral_script_engine/tests/test_rewards.py CHANGED
@@ -93,7 +93,7 @@ def test_aggregator_catastrophic_drop(aggregator):
93
  start = RewardComponents(r1_hook_strength=0.8, r2_coherence=0.7)
94
  start.compute_total()
95
  current = RewardComponents(r1_hook_strength=0.3, r2_coherence=0.7)
96
- result = aggregator.compute(current, start, [ActionType.HOOK_REWRITE])
97
  assert result.total == 0.0
98
 
99
 
@@ -102,7 +102,7 @@ def test_aggregator_diversity_penalty(aggregator):
102
  start.compute_total()
103
  current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
104
  history = [ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE]
105
- result = aggregator.compute(current, start, history)
106
  assert result.anti_gaming_penalty == 0.15
107
  assert result.total < 0.7
108
 
@@ -112,6 +112,6 @@ def test_aggregator_no_penalty(aggregator):
112
  start.compute_total()
113
  current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
114
  history = [ActionType.HOOK_REWRITE, ActionType.CTA_PLACEMENT, ActionType.SECTION_REORDER]
115
- result = aggregator.compute(current, start, history)
116
  assert result.anti_gaming_penalty == 0.0
117
  assert result.total > 0
 
93
  start = RewardComponents(r1_hook_strength=0.8, r2_coherence=0.7)
94
  start.compute_total()
95
  current = RewardComponents(r1_hook_strength=0.3, r2_coherence=0.7)
96
+ result, log = aggregator.compute(current, start, [ActionType.HOOK_REWRITE])
97
  assert result.total == 0.0
98
 
99
 
 
102
  start.compute_total()
103
  current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
104
  history = [ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE]
105
+ result, log = aggregator.compute(current, start, history)
106
  assert result.anti_gaming_penalty == 0.15
107
  assert result.total < 0.7
108
 
 
112
  start.compute_total()
113
  current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
114
  history = [ActionType.HOOK_REWRITE, ActionType.CTA_PLACEMENT, ActionType.SECTION_REORDER]
115
+ result, log = aggregator.compute(current, start, history)
116
  assert result.anti_gaming_penalty == 0.0
117
  assert result.total > 0