Spaces:
Sleeping
Sleeping
phase2 implemented
Browse files- viral_script_engine/agents/baseline_arbitrator.py +83 -0
- viral_script_engine/agents/defender.py +93 -0
- viral_script_engine/data/cultural_kb.json +87 -0
- viral_script_engine/data/golden_fixtures/fixture_S01.json +25 -25
- viral_script_engine/data/golden_fixtures/fixture_S02.json +20 -20
- viral_script_engine/environment/env.py +62 -4
- viral_script_engine/rewards/r3_cultural_alignment.py +58 -0
- viral_script_engine/rewards/r4_debate_resolution.py +82 -0
- viral_script_engine/rewards/r5_defender_preservation.py +62 -0
- viral_script_engine/rewards/reward_aggregator.py +44 -7
- viral_script_engine/scripts/run_baseline.py +217 -0
- viral_script_engine/tests/test_environment.py +37 -0
- viral_script_engine/tests/test_phase2.py +325 -0
- viral_script_engine/tests/test_rewards.py +3 -3
viral_script_engine/agents/baseline_arbitrator.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from viral_script_engine.agents.llm_backend import LLMBackend
|
| 4 |
+
|
| 5 |
+
SYSTEM_PROMPT = """You are helping improve a short-form video script.
|
| 6 |
+
You have observed a debate between a Critic and a Defender about the script.
|
| 7 |
+
Choose ONE action to take to improve the script.
|
| 8 |
+
|
| 9 |
+
Available actions: hook_rewrite, section_reorder, cultural_ref_sub, cta_placement
|
| 10 |
+
|
| 11 |
+
Respond ONLY with valid JSON:
|
| 12 |
+
{
|
| 13 |
+
"action_type": "hook_rewrite",
|
| 14 |
+
"target_section": "hook",
|
| 15 |
+
"instruction": "specific instruction for the rewriter",
|
| 16 |
+
"critique_claim_id": "C1",
|
| 17 |
+
"reasoning": "brief explanation"
|
| 18 |
+
}"""
|
| 19 |
+
|
| 20 |
+
STRICT_RETRY_SUFFIX = (
|
| 21 |
+
"\n\nIMPORTANT: Your previous response was not valid JSON. "
|
| 22 |
+
"Respond ONLY with the raw JSON object. No markdown fences, no explanation, no preamble."
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
_FALLBACK_ACTION = {
|
| 26 |
+
"action_type": "hook_rewrite",
|
| 27 |
+
"target_section": "hook",
|
| 28 |
+
"instruction": "Rewrite the hook to be more engaging and direct.",
|
| 29 |
+
"critique_claim_id": "C1",
|
| 30 |
+
"reasoning": "Default fallback action.",
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class BaselineArbitratorAgent:
    """
    Untrained Arbitrator for the pre-training baseline.

    Uses zero-shot instruction — no chain-of-thought, no few-shot examples.
    This ensures the comparison is fair: trained model learns through RL, not prompting.
    """

    def __init__(self, backend: str = "groq", model_name: str = "llama-3.3-70b-versatile"):
        """Create the LLM backend used to pick actions.

        Args:
            backend: Backend identifier forwarded to ``LLMBackend``.
            model_name: Model identifier forwarded to ``LLMBackend``.
        """
        self.llm = LLMBackend(backend=backend, model_name=model_name)

    def _build_user_prompt(self, observation: dict) -> str:
        """Render the script and the most recent debate round as a text prompt.

        Args:
            observation: Mapping read with ``.get``; the keys used are
                ``current_script`` and ``debate_history`` (a list of round
                dicts carrying ``critic_claims`` and ``defender_response``).

        Returns:
            The prompt text. Missing claims or defense are rendered as ``None``.
        """
        script = observation.get("current_script", "")
        debate = observation.get("debate_history", [])

        last_claims = []
        last_defense = None
        if debate:
            # Only the latest round is shown to the model.
            last_round = debate[-1]
            # ``or []`` guards against an explicit null in the round dict.
            last_claims = last_round.get("critic_claims", []) or []
            last_defense = last_round.get("defender_response")

        claims_text = "".join(
            f"- [{c.get('claim_id','?')}] {c.get('claim_text','')} (severity: {c.get('severity','')})\n"
            for c in last_claims
        )

        defense_text = ""
        if last_defense:
            defense_text = (
                f"Defender preserved: {last_defense.get('core_strength_quote','')}\n"
                f"Flagged claims: {last_defense.get('flagged_critic_claims', [])}\n"
            )

        return (
            f"SCRIPT:\n{script}\n\n"
            f"CRITIC CLAIMS:\n{claims_text or 'None'}\n"
            f"DEFENDER:\n{defense_text or 'None'}\n\n"
            "Choose one action to improve the script."
        )

    @staticmethod
    def _parse_action(raw) -> "dict | None":
        """Parse ``raw`` as a JSON object; return ``None`` when that is not possible.

        Valid JSON that is not an object (list, string, number, ...) is rejected
        as well, because ``act`` promises its callers a dict.
        """
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError, RecursionError):
            # TypeError: raw is not str/bytes; ValueError: malformed JSON
            # (json.JSONDecodeError subclasses it); RecursionError: absurd nesting.
            return None
        return parsed if isinstance(parsed, dict) else None

    def act(self, observation: dict) -> dict:
        """Ask the LLM for one action, retrying once with a stricter prompt.

        Args:
            observation: Environment observation passed to ``_build_user_prompt``.

        Returns:
            The parsed action dict, or a copy of ``_FALLBACK_ACTION`` when both
            attempts fail to produce a JSON object.
        """
        user_prompt = self._build_user_prompt(observation)
        raw = self.llm.generate(SYSTEM_PROMPT, user_prompt, max_tokens=512)
        action = self._parse_action(raw)
        if action is not None:
            return action

        # Second and last attempt: same prompt plus an explicit JSON-only reminder.
        strict_prompt = user_prompt + STRICT_RETRY_SUFFIX
        raw_retry = self.llm.generate(SYSTEM_PROMPT, strict_prompt, max_tokens=512)
        action = self._parse_action(raw_retry)
        if action is not None:
            return action

        # Copy so callers cannot mutate the shared module-level default.
        return _FALLBACK_ACTION.copy()
|
viral_script_engine/agents/defender.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
|
| 6 |
+
from viral_script_engine.agents.llm_backend import LLMBackend
|
| 7 |
+
from viral_script_engine.agents.critic import CritiqueClaim
|
| 8 |
+
|
| 9 |
+
SYSTEM_PROMPT = """You are a script defender for short-form video content. Your job is NOT to say the script is perfect.
|
| 10 |
+
Your job is to identify what is genuinely working — and protect it from being edited away.
|
| 11 |
+
|
| 12 |
+
Specifically:
|
| 13 |
+
1. Find the single most powerful element of the script. Quote it exactly.
|
| 14 |
+
2. Explain why a viewer would respond positively to this element.
|
| 15 |
+
3. Review the Critic's claims. Flag any that would destroy the script's core strength or strip its regional authenticity if acted on.
|
| 16 |
+
4. List any phrases, idioms, or references that are intentionally regional — these must not be "corrected" away.
|
| 17 |
+
|
| 18 |
+
OUTPUT (JSON only, no preamble):
|
| 19 |
+
{
|
| 20 |
+
"core_strength": "one sentence describing the strongest element",
|
| 21 |
+
"core_strength_quote": "exact verbatim quote from the script",
|
| 22 |
+
"defense_argument": "why this element should be preserved",
|
| 23 |
+
"flagged_critic_claims": ["C2", "C3"],
|
| 24 |
+
"regional_voice_elements": ["specific phrase 1", "specific phrase 2"]
|
| 25 |
+
}"""
|
| 26 |
+
|
| 27 |
+
STRICT_RETRY_SUFFIX = (
|
| 28 |
+
"\n\nIMPORTANT: Your previous response was not valid JSON. "
|
| 29 |
+
"Respond ONLY with the raw JSON object. No markdown fences, no explanation, no preamble."
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class DefenderParseError(Exception):
    """Raised when the Defender's LLM output cannot be parsed after the retry."""
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DefenderOutput(BaseModel):
    """Validated Defender reply; fields mirror the JSON keys requested in the system prompt."""

    # One sentence describing the strongest element of the script.
    core_strength: str
    # Verbatim quote of that element taken from the script.
    core_strength_quote: str
    # Why the element should be preserved.
    defense_argument: str
    # Claim ids (e.g. "C2") the Defender flags.
    flagged_critic_claims: List[str]
    # Phrases the Defender marks as intentionally regional.
    regional_voice_elements: List[str]
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DefenderAgent:
    """LLM-backed agent that defends a script against the Critic's claims."""

    def __init__(self, backend: str = "groq", model_name: str = "llama-3.3-70b-versatile"):
        """Create the LLM backend used to generate defenses.

        Args:
            backend: Backend identifier forwarded to ``LLMBackend``.
            model_name: Model identifier forwarded to ``LLMBackend``.
        """
        self.llm = LLMBackend(backend=backend, model_name=model_name)

    def _build_user_prompt(
        self,
        script: str,
        critic_claims: List[CritiqueClaim],
        region: str,
        platform: str,
    ) -> str:
        """Render the script, numbered critic claims, region and platform as a prompt.

        Each claim line shows its 1-based position, ``claim_id``,
        ``critique_class``, ``claim_text`` and ``evidence``. An empty claim list
        is rendered as a fixed placeholder sentence.
        """
        claims_lines = [
            f"{i}. [{claim.claim_id}] ({claim.critique_class}) {claim.claim_text} | Evidence: {claim.evidence}"
            for i, claim in enumerate(critic_claims, start=1)
        ]
        claims_block = "\n".join(claims_lines) if claims_lines else "No critic claims provided."

        return (
            f"SCRIPT:\n{script}\n\n"
            f"CRITIC CLAIMS:\n{claims_block}\n\n"
            f"REGION: {region}\n"
            f"PLATFORM: {platform}\n\n"
            "Defend the script now."
        )

    def _parse_response(self, raw: str, user_prompt: str) -> DefenderOutput:
        """Parse ``raw`` into a ``DefenderOutput``, retrying the LLM once on failure.

        Args:
            raw: First LLM reply, expected to be a JSON object.
            user_prompt: Prompt that produced ``raw``; reused for the retry with
                ``STRICT_RETRY_SUFFIX`` appended.

        Raises:
            DefenderParseError: When the retry reply also fails to parse or
                validate. The underlying error is chained as ``__cause__``.
        """
        try:
            data = json.loads(raw)
            return DefenderOutput(**data)
        except Exception:
            # Broad on purpose: any decode or validation failure triggers the
            # single retry rather than aborting the episode.
            strict_prompt = user_prompt + STRICT_RETRY_SUFFIX
            raw2 = self.llm.generate(SYSTEM_PROMPT, strict_prompt, max_tokens=1024)
            try:
                data = json.loads(raw2)
                return DefenderOutput(**data)
            except Exception as e:
                raise DefenderParseError(
                    f"Failed to parse defender output after 2 attempts: {e}"
                ) from e

    def defend(
        self,
        script: str,
        critic_claims: List[CritiqueClaim],
        region: str,
        platform: str,
    ) -> DefenderOutput:
        """Generate and parse a defense of ``script`` against ``critic_claims``.

        Raises:
            DefenderParseError: Propagated from ``_parse_response``.
        """
        user_prompt = self._build_user_prompt(script, critic_claims, region, platform)
        raw = self.llm.generate(SYSTEM_PROMPT, user_prompt, max_tokens=1024)
        return self._parse_response(raw, user_prompt)
|
viral_script_engine/data/cultural_kb.json
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"mumbai_gen_z": {
|
| 3 |
+
"valid_refs": [
|
| 4 |
+
"Bandra", "CSMT", "dabba", "auto", "local train", "startup scene",
|
| 5 |
+
"Zomato", "Swiggy", "IPL", "sea link", "Juhu", "Versova", "Aarey",
|
| 6 |
+
"SoBo", "Powai", "Andheri", "Dharavi", "Bandstand", "Colaba",
|
| 7 |
+
"Bollywood 2020s"
|
| 8 |
+
],
|
| 9 |
+
"invalid_signals": [
|
| 10 |
+
"trunk call", "STD booth", "pager", "cassette", "VHS", "walkman",
|
| 11 |
+
"disco", "beeper", "doordarshan only", "telegram boy",
|
| 12 |
+
"ration card queue", "old delhi", "tonga", "coolie",
|
| 13 |
+
"chawl gossip pre-2000"
|
| 14 |
+
],
|
| 15 |
+
"correct_idioms": [
|
| 16 |
+
"ek dum solid", "full on", "kya scene hai", "bindaas", "jugaad nahi",
|
| 17 |
+
"yaar sun", "bhai chill maar", "ekdum mast", "kya baat hai",
|
| 18 |
+
"ekdum lit", "arey yaar", "full on vibe", "kya scene hai bhai",
|
| 19 |
+
"timepass nahi", "mast hai re", "solid plan", "chal be"
|
| 20 |
+
],
|
| 21 |
+
"anachronistic_signals": []
|
| 22 |
+
},
|
| 23 |
+
"tier2_hindi_belt": {
|
| 24 |
+
"valid_refs": [
|
| 25 |
+
"kirana store", "mandap", "jugaad", "sabzi mandi", "panchayat",
|
| 26 |
+
"mela", "dal-chawal", "halwai", "tehsil", "chaupal", "gaon",
|
| 27 |
+
"kheti", "bori", "dhaba", "cycle", "tubewell", "Kanpur",
|
| 28 |
+
"Lucknow", "Patna", "Varanasi"
|
| 29 |
+
],
|
| 30 |
+
"invalid_signals": [
|
| 31 |
+
"unicorn startup", "runway", "pivot", "SaaS", "venture capital",
|
| 32 |
+
"IPO", "fintech", "coworking space", "metro commute", "craft beer",
|
| 33 |
+
"avocado toast", "brunch", "penthouse", "valet parking",
|
| 34 |
+
"rooftop lounge", "angel investor", "Series A", "growth hacking",
|
| 35 |
+
"disruptive"
|
| 36 |
+
],
|
| 37 |
+
"correct_idioms": [
|
| 38 |
+
"bilkul sahi", "arey bhai", "baat pakki", "suno ji", "haan ji",
|
| 39 |
+
"seedha baat", "desi style", "apna kaam", "ek number",
|
| 40 |
+
"bahut badhiya", "kya baat", "sahi mein", "dhamaal", "arre wah",
|
| 41 |
+
"mast jugaad", "thoda adjust karo", "kal milte hain"
|
| 42 |
+
],
|
| 43 |
+
"anachronistic_signals": []
|
| 44 |
+
},
|
| 45 |
+
"pan_india_english": {
|
| 46 |
+
"valid_refs": [
|
| 47 |
+
"India", "Indian", "subcontinent", "festivals", "cricket", "desi",
|
| 48 |
+
"masala", "biryani", "chai", "temple", "monsoon", "Diwali",
|
| 49 |
+
"Holi", "market", "college"
|
| 50 |
+
],
|
| 51 |
+
"invalid_signals": [
|
| 52 |
+
"kachcha road", "pind", "naali", "akhada", "maidan ke paas",
|
| 53 |
+
"chakki", "chulha", "dung cake", "lathicharge", "tehsildar",
|
| 54 |
+
"patwari", "sarpanch only", "gaon waale", "pradhan ji", "khet mein"
|
| 55 |
+
],
|
| 56 |
+
"correct_idioms": [
|
| 57 |
+
"absolutely", "for sure", "totally agree", "makes sense",
|
| 58 |
+
"of course", "right", "exactly", "fair point", "well said",
|
| 59 |
+
"spot on", "valid point", "true that", "no doubt", "solid",
|
| 60 |
+
"good point", "great insight", "well done"
|
| 61 |
+
],
|
| 62 |
+
"anachronistic_signals": []
|
| 63 |
+
},
|
| 64 |
+
"hinglish": {
|
| 65 |
+
"valid_refs": [
|
| 66 |
+
"yaar", "bhai", "dost", "scene", "vibe", "timepass", "jugaad",
|
| 67 |
+
"mast", "bindaas", "desi", "ekdum", "kal", "aaj", "kya", "kaise"
|
| 68 |
+
],
|
| 69 |
+
"invalid_signals": [
|
| 70 |
+
"I am going to the marketplace", "Please be advised", "Pursuant to",
|
| 71 |
+
"Herewith", "As per our discussion", "In reference to",
|
| 72 |
+
"Kindly note", "I beg to inform", "With due respect",
|
| 73 |
+
"at your earliest convenience", "enclosed herewith",
|
| 74 |
+
"please find attached", "as discussed", "further to my email",
|
| 75 |
+
"I hope this email finds you"
|
| 76 |
+
],
|
| 77 |
+
"correct_idioms": [
|
| 78 |
+
"yaar sun", "bhai ekdum", "kya baat hai", "arey chill maar",
|
| 79 |
+
"solid plan hai", "mast scene hai", "full on enjoy",
|
| 80 |
+
"ekdum bindaas", "bhai sahi bol raha hai", "kya jugaad lagaya",
|
| 81 |
+
"ek dum mast", "timepass nahi yaar", "kya vibe hai",
|
| 82 |
+
"full on mast", "ekdum sahi", "arey yaar kya scene",
|
| 83 |
+
"bas ek chance"
|
| 84 |
+
],
|
| 85 |
+
"anachronistic_signals": []
|
| 86 |
+
}
|
| 87 |
+
}
|
viral_script_engine/data/golden_fixtures/fixture_S01.json
CHANGED
|
@@ -8,59 +8,59 @@
|
|
| 8 |
{
|
| 9 |
"claim_id": "C1",
|
| 10 |
"critique_class": "hook_weakness",
|
| 11 |
-
"claim_text": "The hook
|
| 12 |
-
"timestamp_range": "0:00-0:
|
| 13 |
-
"evidence": "
|
| 14 |
"is_falsifiable": true,
|
| 15 |
"severity": "medium"
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"claim_id": "C2",
|
| 19 |
"critique_class": "pacing_issue",
|
| 20 |
-
"claim_text": "The script
|
| 21 |
-
"timestamp_range": "0:
|
| 22 |
-
"evidence": "
|
| 23 |
"is_falsifiable": true,
|
| 24 |
-
"severity": "
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"claim_id": "C3",
|
| 28 |
-
"critique_class": "
|
| 29 |
-
"claim_text": "The
|
| 30 |
-
"timestamp_range": "
|
| 31 |
-
"evidence": "
|
| 32 |
"is_falsifiable": true,
|
| 33 |
"severity": "medium"
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"claim_id": "C4",
|
| 37 |
-
"critique_class": "
|
| 38 |
-
"claim_text": "The
|
| 39 |
-
"timestamp_range": "1:
|
| 40 |
"evidence": "If you want to know which funds I use, follow me and I'll post the list tomorrow.",
|
| 41 |
"is_falsifiable": true,
|
| 42 |
"severity": "high"
|
| 43 |
},
|
| 44 |
{
|
| 45 |
"claim_id": "C5",
|
| 46 |
-
"critique_class": "
|
| 47 |
-
"claim_text": "The script
|
| 48 |
-
"timestamp_range": "
|
| 49 |
-
"evidence": "The
|
| 50 |
"is_falsifiable": true,
|
| 51 |
"severity": "medium"
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"claim_id": "C6",
|
| 55 |
-
"critique_class": "
|
| 56 |
-
"claim_text": "The script
|
| 57 |
-
"timestamp_range": "
|
| 58 |
-
"evidence": "
|
| 59 |
"is_falsifiable": true,
|
| 60 |
-
"severity": "
|
| 61 |
}
|
| 62 |
],
|
| 63 |
-
"overall_severity": "
|
| 64 |
-
"raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook
|
| 65 |
}
|
| 66 |
}
|
|
|
|
| 8 |
{
|
| 9 |
"claim_id": "C1",
|
| 10 |
"critique_class": "hook_weakness",
|
| 11 |
+
"claim_text": "The hook promises a life-changing trick but takes 0:22 seconds to reveal, by which time viewers might have lost interest",
|
| 12 |
+
"timestamp_range": "0:00-0:03",
|
| 13 |
+
"evidence": "I've been broke my whole life. Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything.",
|
| 14 |
"is_falsifiable": true,
|
| 15 |
"severity": "medium"
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"claim_id": "C2",
|
| 19 |
"critique_class": "pacing_issue",
|
| 20 |
+
"claim_text": "The script switches abruptly from the personal story to a promotional message at 0:45, disrupting the narrative flow",
|
| 21 |
+
"timestamp_range": "0:45-0:55",
|
| 22 |
+
"evidence": "The secret? Mutual funds. Just SIPs. I'm serious.",
|
| 23 |
"is_falsifiable": true,
|
| 24 |
+
"severity": "high"
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"claim_id": "C3",
|
| 28 |
+
"critique_class": "cta_buried",
|
| 29 |
+
"claim_text": "The call-to-action to follow the creator for the list of funds is buried at the end of the script and might be missed by viewers",
|
| 30 |
+
"timestamp_range": "1:10-1:15",
|
| 31 |
+
"evidence": "If you want to know which funds I use, follow me and I'll post the list tomorrow.",
|
| 32 |
"is_falsifiable": true,
|
| 33 |
"severity": "medium"
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"claim_id": "C4",
|
| 37 |
+
"critique_class": "retention_risk",
|
| 38 |
+
"claim_text": "The script assumes viewers will wait for the next day's post for the list of funds, which might lead to a drop in retention",
|
| 39 |
+
"timestamp_range": "1:10-1:15",
|
| 40 |
"evidence": "If you want to know which funds I use, follow me and I'll post the list tomorrow.",
|
| 41 |
"is_falsifiable": true,
|
| 42 |
"severity": "high"
|
| 43 |
},
|
| 44 |
{
|
| 45 |
"claim_id": "C5",
|
| 46 |
+
"critique_class": "coherence_break",
|
| 47 |
+
"claim_text": "The script's tone shifts from relatable and personal to overly promotional and sales-y, potentially alienating the target audience",
|
| 48 |
+
"timestamp_range": "N/A",
|
| 49 |
+
"evidence": "The entire script's tone and language",
|
| 50 |
"is_falsifiable": true,
|
| 51 |
"severity": "medium"
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"claim_id": "C6",
|
| 55 |
+
"critique_class": "pacing_issue",
|
| 56 |
+
"claim_text": "The script tries to cram too much information and calls-to-action in the last 10 seconds, which might overwhelm viewers",
|
| 57 |
+
"timestamp_range": "1:15-1:20",
|
| 58 |
+
"evidence": "Like and save this video before Instagram hides it.",
|
| 59 |
"is_falsifiable": true,
|
| 60 |
+
"severity": "high"
|
| 61 |
}
|
| 62 |
],
|
| 63 |
+
"overall_severity": "high",
|
| 64 |
+
"raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook promises a life-changing trick but takes 0:22 seconds to reveal, by which time viewers might have lost interest\",\n \"timestamp_range\": \"0:00-0:03\",\n \"evidence\": \"I've been broke my whole life. Like actually broke. Not the aesthetic broke, the can't-pay-rent broke. And then one day I found this one trick that changed everything.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C2\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script switches abruptly from the personal story to a promotional message at 0:45, disrupting the narrative flow\",\n \"timestamp_range\": \"0:45-0:55\",\n \"evidence\": \"The secret? Mutual funds. Just SIPs. I'm serious.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C3\",\n \"critique_class\": \"cta_buried\",\n \"claim_text\": \"The call-to-action to follow the creator for the list of funds is buried at the end of the script and might be missed by viewers\",\n \"timestamp_range\": \"1:10-1:15\",\n \"evidence\": \"If you want to know which funds I use, follow me and I'll post the list tomorrow.\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C4\",\n \"critique_class\": \"retention_risk\",\n \"claim_text\": \"The script assumes viewers will wait for the next day's post for the list of funds, which might lead to a drop in retention\",\n \"timestamp_range\": \"1:10-1:15\",\n \"evidence\": \"If you want to know which funds I use, follow me and I'll post the list tomorrow.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C5\",\n \"critique_class\": \"coherence_break\",\n \"claim_text\": \"The script's tone shifts from relatable and personal to overly promotional and sales-y, potentially alienating the target audience\",\n \"timestamp_range\": 
\"N/A\",\n \"evidence\": \"The entire script's tone and language\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C6\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script tries to cram too much information and calls-to-action in the last 10 seconds, which might overwhelm viewers\",\n \"timestamp_range\": \"1:15-1:20\",\n \"evidence\": \"Like and save this video before Instagram hides it.\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n }\n ],\n \"overall_severity\": \"high\"\n}"
|
| 65 |
}
|
| 66 |
}
|
viral_script_engine/data/golden_fixtures/fixture_S02.json
CHANGED
|
@@ -8,59 +8,59 @@
|
|
| 8 |
{
|
| 9 |
"claim_id": "C1",
|
| 10 |
"critique_class": "hook_weakness",
|
| 11 |
-
"claim_text": "The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at
|
| 12 |
-
"timestamp_range": "0:00-0:03,
|
| 13 |
-
"evidence": "Five outfits, one thousand rupees
|
| 14 |
"is_falsifiable": true,
|
| 15 |
"severity": "medium"
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"claim_id": "C2",
|
| 19 |
"critique_class": "pacing_issue",
|
| 20 |
-
"claim_text": "The script at 0:
|
| 21 |
-
"timestamp_range": "0:
|
| 22 |
-
"evidence": "
|
| 23 |
"is_falsifiable": true,
|
| 24 |
"severity": "low"
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"claim_id": "C3",
|
| 28 |
"critique_class": "retention_risk",
|
| 29 |
-
"claim_text": "The
|
| 30 |
-
"timestamp_range": "
|
| 31 |
-
"evidence": "
|
| 32 |
"is_falsifiable": true,
|
| 33 |
"severity": "medium"
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"claim_id": "C4",
|
| 37 |
-
"critique_class": "
|
| 38 |
-
"claim_text": "The
|
| 39 |
-
"timestamp_range": "1:
|
| 40 |
-
"evidence": "
|
| 41 |
"is_falsifiable": true,
|
| 42 |
"severity": "high"
|
| 43 |
},
|
| 44 |
{
|
| 45 |
"claim_id": "C5",
|
| 46 |
-
"critique_class": "
|
| 47 |
-
"claim_text": "The script
|
| 48 |
-
"timestamp_range": "
|
| 49 |
-
"evidence": "
|
| 50 |
"is_falsifiable": true,
|
| 51 |
"severity": "medium"
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"claim_id": "C6",
|
| 55 |
"critique_class": "cultural_mismatch",
|
| 56 |
-
"claim_text": "The script assumes all viewers are familiar with Mumbai
|
| 57 |
"timestamp_range": "N/A",
|
| 58 |
-
"evidence": "Linking Road,
|
| 59 |
"is_falsifiable": true,
|
| 60 |
"severity": "low"
|
| 61 |
}
|
| 62 |
],
|
| 63 |
"overall_severity": "medium",
|
| 64 |
-
"raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at
|
| 65 |
}
|
| 66 |
}
|
|
|
|
| 8 |
{
|
| 9 |
"claim_id": "C1",
|
| 10 |
"critique_class": "hook_weakness",
|
| 11 |
+
"claim_text": "The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at 1:45 shows the total cost is five hundred rupees, which may confuse viewers",
|
| 12 |
+
"timestamp_range": "0:00-0:03, 1:45",
|
| 13 |
+
"evidence": "Five outfits, one thousand rupees",
|
| 14 |
"is_falsifiable": true,
|
| 15 |
"severity": "medium"
|
| 16 |
},
|
| 17 |
{
|
| 18 |
"claim_id": "C2",
|
| 19 |
"critique_class": "pacing_issue",
|
| 20 |
+
"claim_text": "The script has an abrupt pause at 0:45-0:50 where the creator says 'wait I need to find it. Okay found it.', which disrupts the flow of the video",
|
| 21 |
+
"timestamp_range": "0:45-0:50",
|
| 22 |
+
"evidence": "wait I need to find it. Okay found it.",
|
| 23 |
"is_falsifiable": true,
|
| 24 |
"severity": "low"
|
| 25 |
},
|
| 26 |
{
|
| 27 |
"claim_id": "C3",
|
| 28 |
"critique_class": "retention_risk",
|
| 29 |
+
"claim_text": "The creator asks viewers to 'save this video' at 1:20-1:25, but this call-to-action may not be compelling enough to encourage viewers to take action",
|
| 30 |
+
"timestamp_range": "1:20-1:25",
|
| 31 |
+
"evidence": "this entire saree drape tutorial took me two hours so please save this video",
|
| 32 |
"is_falsifiable": true,
|
| 33 |
"severity": "medium"
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"claim_id": "C4",
|
| 37 |
+
"critique_class": "cta_buried",
|
| 38 |
+
"claim_text": "The call-to-action 'Comment your city and I'll do a version for your local markets' is buried at the end of the script at 1:50-2:00, which may be missed by viewers who drop off early",
|
| 39 |
+
"timestamp_range": "1:50-2:00",
|
| 40 |
+
"evidence": "Comment your city and I'll do a version for your local markets",
|
| 41 |
"is_falsifiable": true,
|
| 42 |
"severity": "high"
|
| 43 |
},
|
| 44 |
{
|
| 45 |
"claim_id": "C5",
|
| 46 |
+
"critique_class": "coherence_break",
|
| 47 |
+
"claim_text": "The script jumps abruptly from showcasing outfits to providing a grand total at 1:45, without a clear transition or summary of the previous outfits",
|
| 48 |
+
"timestamp_range": "1:40-1:50",
|
| 49 |
+
"evidence": "Grand total \u00e2\u20ac\u201d five hundred rupees for five outfits",
|
| 50 |
"is_falsifiable": true,
|
| 51 |
"severity": "medium"
|
| 52 |
},
|
| 53 |
{
|
| 54 |
"claim_id": "C6",
|
| 55 |
"critique_class": "cultural_mismatch",
|
| 56 |
+
"claim_text": "The script assumes that all viewers are familiar with Mumbai's local markets and terms like 'Linking Road' and 'Sarojini Nagar', which may not resonate with viewers from other regions",
|
| 57 |
"timestamp_range": "N/A",
|
| 58 |
+
"evidence": "thrifted kurta from Linking Road, forty rupees, styled with mom's old dupatta, zero rupees",
|
| 59 |
"is_falsifiable": true,
|
| 60 |
"severity": "low"
|
| 61 |
}
|
| 62 |
],
|
| 63 |
"overall_severity": "medium",
|
| 64 |
+
"raw_response": "{\n \"claims\": [\n {\n \"claim_id\": \"C1\",\n \"critique_class\": \"hook_weakness\",\n \"claim_text\": \"The hook at 0:00-0:03 promises five outfits for one thousand rupees, but the reveal at 1:45 shows the total cost is five hundred rupees, which may confuse viewers\",\n \"timestamp_range\": \"0:00-0:03, 1:45\",\n \"evidence\": \"Five outfits, one thousand rupees\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C2\",\n \"critique_class\": \"pacing_issue\",\n \"claim_text\": \"The script has an abrupt pause at 0:45-0:50 where the creator says 'wait I need to find it. Okay found it.', which disrupts the flow of the video\",\n \"timestamp_range\": \"0:45-0:50\",\n \"evidence\": \"wait I need to find it. Okay found it.\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n },\n {\n \"claim_id\": \"C3\",\n \"critique_class\": \"retention_risk\",\n \"claim_text\": \"The creator asks viewers to 'save this video' at 1:20-1:25, but this call-to-action may not be compelling enough to encourage viewers to take action\",\n \"timestamp_range\": \"1:20-1:25\",\n \"evidence\": \"this entire saree drape tutorial took me two hours so please save this video\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C4\",\n \"critique_class\": \"cta_buried\",\n \"claim_text\": \"The call-to-action 'Comment your city and I'll do a version for your local markets' is buried at the end of the script at 1:50-2:00, which may be missed by viewers who drop off early\",\n \"timestamp_range\": \"1:50-2:00\",\n \"evidence\": \"Comment your city and I'll do a version for your local markets\",\n \"is_falsifiable\": true,\n \"severity\": \"high\"\n },\n {\n \"claim_id\": \"C5\",\n \"critique_class\": \"coherence_break\",\n \"claim_text\": \"The script jumps abruptly from showcasing outfits to providing a grand total at 1:45, without a clear transition or summary of the previous outfits\",\n \"timestamp_range\": 
\"1:40-1:50\",\n \"evidence\": \"Grand total \u00e2\u20ac\u201d five hundred rupees for five outfits\",\n \"is_falsifiable\": true,\n \"severity\": \"medium\"\n },\n {\n \"claim_id\": \"C6\",\n \"critique_class\": \"cultural_mismatch\",\n \"claim_text\": \"The script assumes that all viewers are familiar with Mumbai's local markets and terms like 'Linking Road' and 'Sarojini Nagar', which may not resonate with viewers from other regions\",\n \"timestamp_range\": \"N/A\",\n \"evidence\": \"thrifted kurta from Linking Road, forty rupees, styled with mom's old dupatta, zero rupees\",\n \"is_falsifiable\": true,\n \"severity\": \"low\"\n }\n ],\n \"overall_severity\": \"medium\"\n}"
|
| 65 |
}
|
| 66 |
}
|
viral_script_engine/environment/env.py
CHANGED
|
@@ -3,6 +3,7 @@ import random
|
|
| 3 |
from typing import Optional, Tuple
|
| 4 |
|
| 5 |
from viral_script_engine.agents.critic import CriticAgent
|
|
|
|
| 6 |
from viral_script_engine.agents.rewriter import RewriterAgent
|
| 7 |
from viral_script_engine.environment.actions import ArbitratorAction
|
| 8 |
from viral_script_engine.environment.episode_state import EpisodeState
|
|
@@ -11,6 +12,9 @@ from viral_script_engine.environment.observations import (
|
|
| 11 |
)
|
| 12 |
from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
|
| 13 |
from viral_script_engine.rewards.r2_coherence import CoherenceReward
|
|
|
|
|
|
|
|
|
|
| 14 |
from viral_script_engine.rewards.reward_aggregator import RewardAggregator
|
| 15 |
|
| 16 |
_TIERS = {
|
|
@@ -28,6 +32,7 @@ class ViralScriptEnv:
|
|
| 28 |
max_steps: int = 5,
|
| 29 |
difficulty: str = "easy",
|
| 30 |
use_anti_gaming: bool = True,
|
|
|
|
| 31 |
):
|
| 32 |
self.max_steps = max_steps
|
| 33 |
self.difficulty = difficulty
|
|
@@ -39,9 +44,13 @@ class ViralScriptEnv:
|
|
| 39 |
tier_ids = _TIERS[difficulty]
|
| 40 |
self._scripts = [s for s in all_scripts if s["script_id"] in tier_ids]
|
| 41 |
self.critic = CriticAgent()
|
|
|
|
| 42 |
self.rewriter = RewriterAgent()
|
| 43 |
self.r1 = HookStrengthReward()
|
| 44 |
self.r2 = CoherenceReward()
|
|
|
|
|
|
|
|
|
|
| 45 |
self.aggregator = RewardAggregator()
|
| 46 |
self._state: Optional[EpisodeState] = None
|
| 47 |
|
|
@@ -52,9 +61,11 @@ class ViralScriptEnv:
|
|
| 52 |
|
| 53 |
r1_result = self.r1.score(script["script_text"])
|
| 54 |
r2_result = self.r2.score(script["script_text"], script["script_text"])
|
|
|
|
| 55 |
initial_rewards = RewardComponents(
|
| 56 |
r1_hook_strength=r1_result.score,
|
| 57 |
r2_coherence=r2_result.score,
|
|
|
|
| 58 |
)
|
| 59 |
initial_rewards.compute_total()
|
| 60 |
|
|
@@ -79,27 +90,68 @@ class ViralScriptEnv:
|
|
| 79 |
niche=self._state.niche,
|
| 80 |
)
|
| 81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
rewrite_result = self.rewriter.rewrite(self._state.current_script, arb_action)
|
| 83 |
new_script = rewrite_result.rewritten_script
|
| 84 |
|
| 85 |
r1_result = self.r1.score(new_script)
|
| 86 |
r2_result = self.r2.score(self._state.original_script, new_script)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
components = RewardComponents(
|
| 88 |
r1_hook_strength=r1_result.score,
|
| 89 |
r2_coherence=r2_result.score,
|
|
|
|
|
|
|
|
|
|
| 90 |
)
|
| 91 |
|
| 92 |
self._state.action_history.append(arb_action.action_type)
|
| 93 |
if self.use_anti_gaming:
|
| 94 |
-
components = self.aggregator.compute(
|
| 95 |
-
components,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
)
|
| 97 |
else:
|
| 98 |
components.compute_total()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
round_ = DebateRound(
|
| 101 |
step_num=self._state.step_num,
|
| 102 |
critic_claims=critique.claims,
|
|
|
|
| 103 |
arbitrator_action=arb_action,
|
| 104 |
rewrite_diff=rewrite_result.diff,
|
| 105 |
reward_components=components,
|
|
@@ -109,14 +161,19 @@ class ViralScriptEnv:
|
|
| 109 |
self._state.last_reward_components = components
|
| 110 |
self._state.step_num += 1
|
| 111 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
terminated = (
|
| 113 |
self._state.step_num >= self._state.max_steps
|
| 114 |
or components.total >= 0.9
|
| 115 |
)
|
| 116 |
info = {
|
| 117 |
"reward_components": components.model_dump(),
|
| 118 |
-
"anti_gaming_triggered":
|
| 119 |
-
"penalty_reason":
|
|
|
|
| 120 |
}
|
| 121 |
return self._build_observation().model_dump(), components.total, terminated, False, info
|
| 122 |
|
|
@@ -132,6 +189,7 @@ class ViralScriptEnv:
|
|
| 132 |
"step_num": s.step_num,
|
| 133 |
"difficulty_level": s.difficulty_level,
|
| 134 |
"episode_id": s.episode_id,
|
|
|
|
| 135 |
}
|
| 136 |
|
| 137 |
def _build_observation(self) -> Observation:
|
|
|
|
| 3 |
from typing import Optional, Tuple
|
| 4 |
|
| 5 |
from viral_script_engine.agents.critic import CriticAgent
|
| 6 |
+
from viral_script_engine.agents.defender import DefenderAgent
|
| 7 |
from viral_script_engine.agents.rewriter import RewriterAgent
|
| 8 |
from viral_script_engine.environment.actions import ArbitratorAction
|
| 9 |
from viral_script_engine.environment.episode_state import EpisodeState
|
|
|
|
| 12 |
)
|
| 13 |
from viral_script_engine.rewards.r1_hook_strength import HookStrengthReward
|
| 14 |
from viral_script_engine.rewards.r2_coherence import CoherenceReward
|
| 15 |
+
from viral_script_engine.rewards.r3_cultural_alignment import CulturalAlignmentReward
|
| 16 |
+
from viral_script_engine.rewards.r4_debate_resolution import DebateResolutionReward
|
| 17 |
+
from viral_script_engine.rewards.r5_defender_preservation import DefenderPreservationReward
|
| 18 |
from viral_script_engine.rewards.reward_aggregator import RewardAggregator
|
| 19 |
|
| 20 |
_TIERS = {
|
|
|
|
| 32 |
max_steps: int = 5,
|
| 33 |
difficulty: str = "easy",
|
| 34 |
use_anti_gaming: bool = True,
|
| 35 |
+
cultural_kb_path: str = "data/cultural_kb.json",
|
| 36 |
):
|
| 37 |
self.max_steps = max_steps
|
| 38 |
self.difficulty = difficulty
|
|
|
|
| 44 |
tier_ids = _TIERS[difficulty]
|
| 45 |
self._scripts = [s for s in all_scripts if s["script_id"] in tier_ids]
|
| 46 |
self.critic = CriticAgent()
|
| 47 |
+
self.defender = DefenderAgent()
|
| 48 |
self.rewriter = RewriterAgent()
|
| 49 |
self.r1 = HookStrengthReward()
|
| 50 |
self.r2 = CoherenceReward()
|
| 51 |
+
self.r3 = CulturalAlignmentReward(knowledge_base_path=cultural_kb_path)
|
| 52 |
+
self.r4 = DebateResolutionReward(critic_agent=self.critic)
|
| 53 |
+
self.r5 = DefenderPreservationReward()
|
| 54 |
self.aggregator = RewardAggregator()
|
| 55 |
self._state: Optional[EpisodeState] = None
|
| 56 |
|
|
|
|
| 61 |
|
| 62 |
r1_result = self.r1.score(script["script_text"])
|
| 63 |
r2_result = self.r2.score(script["script_text"], script["script_text"])
|
| 64 |
+
r3_result = self.r3.score(script["script_text"], script.get("region", "pan_india_english"))
|
| 65 |
initial_rewards = RewardComponents(
|
| 66 |
r1_hook_strength=r1_result.score,
|
| 67 |
r2_coherence=r2_result.score,
|
| 68 |
+
r3_cultural_alignment=r3_result.score,
|
| 69 |
)
|
| 70 |
initial_rewards.compute_total()
|
| 71 |
|
|
|
|
| 90 |
niche=self._state.niche,
|
| 91 |
)
|
| 92 |
|
| 93 |
+
defender_output = self.defender.defend(
|
| 94 |
+
script=self._state.current_script,
|
| 95 |
+
critic_claims=critique.claims,
|
| 96 |
+
region=self._state.region,
|
| 97 |
+
platform=self._state.platform,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
rewrite_result = self.rewriter.rewrite(self._state.current_script, arb_action)
|
| 101 |
new_script = rewrite_result.rewritten_script
|
| 102 |
|
| 103 |
r1_result = self.r1.score(new_script)
|
| 104 |
r2_result = self.r2.score(self._state.original_script, new_script)
|
| 105 |
+
r3_result = self.r3.score(new_script, self._state.region)
|
| 106 |
+
|
| 107 |
+
targeted_claim = next(
|
| 108 |
+
(c for c in critique.claims if c.claim_id == arb_action.critique_claim_id),
|
| 109 |
+
critique.claims[0] if critique.claims else None,
|
| 110 |
+
)
|
| 111 |
+
r4_result = self.r4.score(
|
| 112 |
+
new_script=new_script,
|
| 113 |
+
original_action=arb_action,
|
| 114 |
+
original_claim=targeted_claim,
|
| 115 |
+
region=self._state.region,
|
| 116 |
+
platform=self._state.platform,
|
| 117 |
+
niche=self._state.niche,
|
| 118 |
+
) if targeted_claim else None
|
| 119 |
+
|
| 120 |
+
r5_result = self.r5.score(defender_output, new_script)
|
| 121 |
+
|
| 122 |
components = RewardComponents(
|
| 123 |
r1_hook_strength=r1_result.score,
|
| 124 |
r2_coherence=r2_result.score,
|
| 125 |
+
r3_cultural_alignment=r3_result.score,
|
| 126 |
+
r4_debate_resolution=r4_result.score if r4_result else None,
|
| 127 |
+
r5_defender_preservation=r5_result.score,
|
| 128 |
)
|
| 129 |
|
| 130 |
self._state.action_history.append(arb_action.action_type)
|
| 131 |
if self.use_anti_gaming:
|
| 132 |
+
components, anti_log = self.aggregator.compute(
|
| 133 |
+
components,
|
| 134 |
+
self._state.episode_start_rewards,
|
| 135 |
+
self._state.action_history,
|
| 136 |
+
episode_id=self._state.episode_id,
|
| 137 |
+
step_num=self._state.step_num,
|
| 138 |
)
|
| 139 |
else:
|
| 140 |
components.compute_total()
|
| 141 |
+
from viral_script_engine.rewards.reward_aggregator import AntiGamingLog
|
| 142 |
+
anti_log = AntiGamingLog(
|
| 143 |
+
episode_id=self._state.episode_id,
|
| 144 |
+
step_num=self._state.step_num,
|
| 145 |
+
triggered=False,
|
| 146 |
+
penalty_applied=0.0,
|
| 147 |
+
pre_penalty_total=components.total,
|
| 148 |
+
post_penalty_total=components.total,
|
| 149 |
+
)
|
| 150 |
|
| 151 |
round_ = DebateRound(
|
| 152 |
step_num=self._state.step_num,
|
| 153 |
critic_claims=critique.claims,
|
| 154 |
+
defender_response=defender_output.model_dump(),
|
| 155 |
arbitrator_action=arb_action,
|
| 156 |
rewrite_diff=rewrite_result.diff,
|
| 157 |
reward_components=components,
|
|
|
|
| 161 |
self._state.last_reward_components = components
|
| 162 |
self._state.step_num += 1
|
| 163 |
|
| 164 |
+
if not hasattr(self._state, "anti_gaming_logs"):
|
| 165 |
+
self._state.anti_gaming_logs = []
|
| 166 |
+
self._state.anti_gaming_logs.append(anti_log.model_dump())
|
| 167 |
+
|
| 168 |
terminated = (
|
| 169 |
self._state.step_num >= self._state.max_steps
|
| 170 |
or components.total >= 0.9
|
| 171 |
)
|
| 172 |
info = {
|
| 173 |
"reward_components": components.model_dump(),
|
| 174 |
+
"anti_gaming_triggered": anti_log.triggered,
|
| 175 |
+
"penalty_reason": anti_log.rule_triggered,
|
| 176 |
+
"anti_gaming_log": anti_log.model_dump(),
|
| 177 |
}
|
| 178 |
return self._build_observation().model_dump(), components.total, terminated, False, info
|
| 179 |
|
|
|
|
| 189 |
"step_num": s.step_num,
|
| 190 |
"difficulty_level": s.difficulty_level,
|
| 191 |
"episode_id": s.episode_id,
|
| 192 |
+
"anti_gaming_logs": getattr(s, "anti_gaming_logs", []),
|
| 193 |
}
|
| 194 |
|
| 195 |
def _build_observation(self) -> Observation:
|
viral_script_engine/rewards/r3_cultural_alignment.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List
|
| 3 |
+
import json
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class CulturalRewardResult:
    """Result of scoring one script against a region's cultural knowledge base."""

    score: float  # final reward, clipped to [0.0, 1.0]
    valid_refs_found: List[str]  # "valid_refs" KB entries found in the script
    correct_idioms_found: List[str]  # "correct_idioms" KB entries found in the script
    invalid_signals_found: List[str]  # "invalid_signals" KB entries found (penalised)
    anachronistic_signals_found: List[str]  # "anachronistic_signals" KB entries found (penalised)
    region: str  # region key the script was scored against


class CulturalAlignmentReward:
    """Score how well a script's references fit a region's cultural knowledge base.

    The knowledge base is a JSON object keyed by region. Each region maps the
    categories ``valid_refs``, ``correct_idioms``, ``invalid_signals`` and
    ``anachronistic_signals`` to lists of terms.
    """

    # Score returned when the requested region has no knowledge-base entry.
    _UNKNOWN_REGION_SCORE = 0.5

    def __init__(self, knowledge_base_path: str = "data/cultural_kb.json"):
        """Load the knowledge base from ``knowledge_base_path`` (UTF-8 JSON).

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError: if the file is not valid JSON.
        """
        with open(knowledge_base_path, "r", encoding="utf-8") as f:
            self._kb = json.load(f)

    @staticmethod
    def _find_terms(terms, text_lower: str) -> List[str]:
        """Return the entries of ``terms`` that occur in ``text_lower`` as whole terms.

        A term only counts when it is not embedded in a longer word, so "IPL"
        no longer matches inside "multiple" and "UPI" no longer matches inside
        "occupied". Blank entries are ignored; with plain substring matching
        an empty string would match every script.
        """
        found = []
        for term in terms or []:
            needle = term.lower().strip()
            if not needle:
                continue
            pattern = r"(?<!\w)" + re.escape(needle) + r"(?!\w)"
            if re.search(pattern, text_lower):
                found.append(term)
        return found

    def score(self, script: str, region: str) -> CulturalRewardResult:
        """Score ``script`` against the knowledge base entry for ``region``.

        The raw score is ``(valid refs + correct idioms - invalid signals -
        anachronistic signals) / (total valid refs + total correct idioms in
        the KB)`` and is clipped to ``[0.0, 1.0]``. A region missing from the
        knowledge base yields a neutral 0.5 with empty match lists. A category
        missing from a region entry is treated as an empty list.
        """
        kb = self._kb.get(region)
        if kb is None:
            return CulturalRewardResult(
                score=self._UNKNOWN_REGION_SCORE,
                valid_refs_found=[],
                correct_idioms_found=[],
                invalid_signals_found=[],
                anachronistic_signals_found=[],
                region=region,
            )

        script_lower = (script or "").lower()
        valid_refs = kb.get("valid_refs") or []
        correct_idioms = kb.get("correct_idioms") or []

        valid_refs_found = self._find_terms(valid_refs, script_lower)
        correct_idioms_found = self._find_terms(correct_idioms, script_lower)
        invalid_signals_found = self._find_terms(kb.get("invalid_signals"), script_lower)
        anachronistic_signals_found = self._find_terms(
            kb.get("anachronistic_signals"), script_lower
        )

        numerator = (
            len(valid_refs_found)
            + len(correct_idioms_found)
            - len(invalid_signals_found)
            - len(anachronistic_signals_found)
        )
        # max(..., 1) avoids division by zero for a region with no positive terms.
        denominator = max(len(valid_refs) + len(correct_idioms), 1)
        clipped_score = max(0.0, min(1.0, numerator / denominator))

        return CulturalRewardResult(
            score=clipped_score,
            valid_refs_found=valid_refs_found,
            correct_idioms_found=correct_idioms_found,
            invalid_signals_found=invalid_signals_found,
            anachronistic_signals_found=anachronistic_signals_found,
            region=region,
        )
|
viral_script_engine/rewards/r4_debate_resolution.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
from viral_script_engine.agents.critic import CriticAgent, CritiqueClaim
|
| 4 |
+
from viral_script_engine.environment.actions import ArbitratorAction
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class DebateResolutionResult:
    """Outcome of re-critiquing a rewritten script against the targeted claim."""

    score: float  # 1.0 resolved, 0.5 partially resolved, 0.0 persists
    resolution_status: str  # "resolved" | "partially_resolved" | "persists"
    original_claim_id: str
    original_claim_class: str
    new_claims_count: int  # total claims raised by the fresh critique


# Ordering used to pick the most severe matching claim; unknown labels rank as medium.
_SEVERITY_RANK = {"high": 2, "medium": 1, "low": 0}


def _parse_seconds(ts: str) -> float:
    """Return the start of a timestamp range in seconds, or -1.0 if unparseable.

    Accepts ``MM:SS`` or ``HH:MM:SS`` start times, optionally followed by a
    range end separated by a hyphen, en dash or em dash (``"0:05-0:10"``,
    ``"0:05–0:10"``). Empty values and ``"N/A"`` yield -1.0.
    """
    if not ts or ts == "N/A":
        return -1.0
    try:
        # Normalise typographic dashes so only the range start is parsed.
        start = ts.replace("\u2013", "-").replace("\u2014", "-").split("-", 1)[0].strip()
        parts = [float(p) for p in start.split(":")]
    except (AttributeError, TypeError, ValueError):
        return -1.0
    if len(parts) not in (2, 3):
        return -1.0
    seconds = 0.0
    for part in parts:
        seconds = seconds * 60 + part
    return seconds


class DebateResolutionReward:
    """Reward that checks whether a rewrite resolved the critic claim it targeted."""

    def __init__(self, critic_agent: CriticAgent, timestamp_tolerance_s: float = 5.0):
        """Store the critic used for re-critiquing.

        Args:
            critic_agent: agent whose ``critique(script, region, platform, niche)``
                result exposes a ``claims`` list.
            timestamp_tolerance_s: maximum distance in seconds between the start
                of the original claim and a new claim for them to be considered
                the same issue.
        """
        self.critic = critic_agent
        self.timestamp_tolerance_s = timestamp_tolerance_s

    def score(
        self,
        new_script: str,
        original_action: ArbitratorAction,
        original_claim: CritiqueClaim,
        region: str,
        platform: str,
        niche: str,
    ) -> DebateResolutionResult:
        """Re-critique ``new_script`` and grade whether ``original_claim`` persists.

        A new claim matches the original when it has the same ``critique_class``
        and, if the original timestamp is parseable, starts within the
        configured tolerance of it. No match gives 1.0 ("resolved"). Matches
        whose worst severity is "low" give 0.5 ("partially_resolved"). Anything
        else gives 0.0 ("persists"). ``original_action`` is accepted for
        interface compatibility and is not used by the scoring.
        """
        new_claims = self.critic.critique(new_script, region, platform, niche).claims

        orig_ts = _parse_seconds(original_claim.timestamp_range)
        orig_class = original_claim.critique_class

        def _result(score: float, status: str) -> DebateResolutionResult:
            return DebateResolutionResult(
                score=score,
                resolution_status=status,
                original_claim_id=original_claim.claim_id,
                original_claim_class=orig_class,
                new_claims_count=len(new_claims),
            )

        matching = []
        for claim in new_claims:
            if claim.critique_class != orig_class:
                continue
            if orig_ts == -1.0:
                # No usable timestamp on the original claim: match on class alone.
                matching.append(claim)
                continue
            new_ts = _parse_seconds(claim.timestamp_range)
            if new_ts != -1.0 and abs(new_ts - orig_ts) <= self.timestamp_tolerance_s:
                matching.append(claim)

        if not matching:
            return _result(1.0, "resolved")

        worst = max(matching, key=lambda c: _SEVERITY_RANK.get(c.severity, 1))
        if worst.severity == "low":
            return _result(0.5, "partially_resolved")
        return _result(0.0, "persists")
|
viral_script_engine/rewards/r5_defender_preservation.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
from viral_script_engine.agents.defender import DefenderOutput
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class DefenderPreservationResult:
    """How well the defender's quoted core strength survived the rewrite."""

    score: float  # thresholded reward derived from max_similarity
    max_similarity: float  # best cosine similarity between the quote and any sentence
    best_matching_sentence: str  # sentence of the rewrite that produced max_similarity


def _sentence_split(text: str) -> List[str]:
    """Split ``text`` after ``.``, ``!`` or ``?`` followed by whitespace; drop blanks."""
    import re
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    return [s for s in sentences if s.strip()]


class DefenderPreservationReward:
    """Reward the rewrite for keeping the sentence the defender flagged as the core strength.

    The embedding model and the embedding cache are class-level, so they are
    shared by every instance in the process.
    """

    _model = None
    _cache: dict = {}
    # Upper bound on cached embeddings. Every rewritten sentence is new text, so
    # an unbounded cache would grow for the whole lifetime of a training run.
    _CACHE_MAX_ENTRIES = 4096

    def _get_model(self):
        """Load the sentence-embedding model on first use and reuse it afterwards."""
        if DefenderPreservationReward._model is None:
            from sentence_transformers import SentenceTransformer
            DefenderPreservationReward._model = SentenceTransformer("all-MiniLM-L6-v2")
        return DefenderPreservationReward._model

    def _embed(self, text: str):
        """Return the embedding of ``text``, cached by the SHA-256 of its content."""
        import hashlib
        key = hashlib.sha256(text.encode()).hexdigest()
        cache = DefenderPreservationReward._cache
        if key not in cache:
            if len(cache) >= self._CACHE_MAX_ENTRIES:
                # dicts keep insertion order, so this evicts the oldest entry (FIFO).
                cache.pop(next(iter(cache)))
            cache[key] = self._get_model().encode(text, convert_to_tensor=True)
        return cache[key]

    def score(self, defender_output: DefenderOutput, rewritten_script: str) -> DefenderPreservationResult:
        """Score how closely any sentence of ``rewritten_script`` matches the defender's quote.

        The maximum cosine similarity between ``defender_output.core_strength_quote``
        and each sentence maps to the reward as follows: ``>= 0.85`` gives 1.0,
        ``>= 0.65`` gives the similarity itself, otherwise 0.0. An empty quote or
        a script with no sentences scores 0.0 without loading the embedding model.
        """
        quote = defender_output.core_strength_quote or ""
        sentences = _sentence_split(rewritten_script or "")

        # Nothing to compare: answer before paying for the model import and load.
        if not quote.strip() or not sentences:
            return DefenderPreservationResult(score=0.0, max_similarity=0.0, best_matching_sentence="")

        from sentence_transformers.util import cos_sim

        quote_emb = self._embed(quote)
        sims = [(float(cos_sim(quote_emb, self._embed(s))[0][0]), s) for s in sentences]
        max_sim, best_sent = max(sims, key=lambda x: x[0])

        if max_sim >= 0.85:
            final_score = 1.0
        elif max_sim >= 0.65:
            final_score = max_sim
        else:
            final_score = 0.0

        return DefenderPreservationResult(
            score=final_score,
            max_similarity=max_sim,
            best_matching_sentence=best_sent,
        )
|
viral_script_engine/rewards/reward_aggregator.py
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
import logging
|
| 2 |
-
from typing import List
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from viral_script_engine.environment.actions import ActionType
|
| 5 |
from viral_script_engine.environment.observations import RewardComponents
|
|
@@ -11,6 +13,19 @@ _COMPONENT_FIELDS = [
|
|
| 11 |
"r4_debate_resolution", "r5_defender_preservation",
|
| 12 |
]
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
class RewardAggregator:
|
| 16 |
def compute(
|
|
@@ -18,20 +33,31 @@ class RewardAggregator:
|
|
| 18 |
components: RewardComponents,
|
| 19 |
episode_start_components: RewardComponents,
|
| 20 |
action_history: List[ActionType],
|
| 21 |
-
|
|
|
|
|
|
|
| 22 |
components.compute_total()
|
|
|
|
| 23 |
|
| 24 |
-
# Anti-gaming rule 1: catastrophic drop (>0.2 drop in any component)
|
| 25 |
for field in _COMPONENT_FIELDS:
|
| 26 |
curr = getattr(components, field)
|
| 27 |
start = getattr(episode_start_components, field)
|
| 28 |
-
if curr is not None and start is not None and curr < start -
|
| 29 |
logger.warning("Catastrophic drop in %s: %.3f -> %.3f", field, start, curr)
|
| 30 |
components.total = 0.0
|
| 31 |
components.anti_gaming_penalty = start - curr
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
# Anti-gaming rule 2: action diversity (last 3 same ActionType)
|
| 35 |
penalty = 0.0
|
| 36 |
if len(action_history) >= 3 and len(set(action_history[-3:])) == 1:
|
| 37 |
penalty = 0.15
|
|
@@ -39,4 +65,15 @@ class RewardAggregator:
|
|
| 39 |
|
| 40 |
components.anti_gaming_penalty = penalty
|
| 41 |
components.total = max(0.0, min(1.0, components.total - penalty))
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import logging
|
| 2 |
+
from typing import List, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
|
| 6 |
from viral_script_engine.environment.actions import ActionType
|
| 7 |
from viral_script_engine.environment.observations import RewardComponents
|
|
|
|
| 13 |
"r4_debate_resolution", "r5_defender_preservation",
|
| 14 |
]
|
| 15 |
|
| 16 |
+
_DROP_THRESHOLD = 0.25
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AntiGamingLog(BaseModel):
|
| 20 |
+
episode_id: str
|
| 21 |
+
step_num: int
|
| 22 |
+
triggered: bool
|
| 23 |
+
rule_triggered: Optional[str] = None
|
| 24 |
+
component_that_dropped: Optional[str] = None
|
| 25 |
+
penalty_applied: float
|
| 26 |
+
pre_penalty_total: float
|
| 27 |
+
post_penalty_total: float
|
| 28 |
+
|
| 29 |
|
| 30 |
class RewardAggregator:
|
| 31 |
def compute(
|
|
|
|
| 33 |
components: RewardComponents,
|
| 34 |
episode_start_components: RewardComponents,
|
| 35 |
action_history: List[ActionType],
|
| 36 |
+
episode_id: str = "",
|
| 37 |
+
step_num: int = 0,
|
| 38 |
+
) -> Tuple[RewardComponents, AntiGamingLog]:
|
| 39 |
components.compute_total()
|
| 40 |
+
pre_penalty_total = components.total
|
| 41 |
|
|
|
|
| 42 |
for field in _COMPONENT_FIELDS:
|
| 43 |
curr = getattr(components, field)
|
| 44 |
start = getattr(episode_start_components, field)
|
| 45 |
+
if curr is not None and start is not None and curr < start - _DROP_THRESHOLD:
|
| 46 |
logger.warning("Catastrophic drop in %s: %.3f -> %.3f", field, start, curr)
|
| 47 |
components.total = 0.0
|
| 48 |
components.anti_gaming_penalty = start - curr
|
| 49 |
+
log = AntiGamingLog(
|
| 50 |
+
episode_id=episode_id,
|
| 51 |
+
step_num=step_num,
|
| 52 |
+
triggered=True,
|
| 53 |
+
rule_triggered="catastrophic_drop",
|
| 54 |
+
component_that_dropped=field,
|
| 55 |
+
penalty_applied=components.anti_gaming_penalty,
|
| 56 |
+
pre_penalty_total=pre_penalty_total,
|
| 57 |
+
post_penalty_total=0.0,
|
| 58 |
+
)
|
| 59 |
+
return components, log
|
| 60 |
|
|
|
|
| 61 |
penalty = 0.0
|
| 62 |
if len(action_history) >= 3 and len(set(action_history[-3:])) == 1:
|
| 63 |
penalty = 0.15
|
|
|
|
| 65 |
|
| 66 |
components.anti_gaming_penalty = penalty
|
| 67 |
components.total = max(0.0, min(1.0, components.total - penalty))
|
| 68 |
+
|
| 69 |
+
log = AntiGamingLog(
|
| 70 |
+
episode_id=episode_id,
|
| 71 |
+
step_num=step_num,
|
| 72 |
+
triggered=penalty > 0,
|
| 73 |
+
rule_triggered="action_repetition" if penalty > 0 else None,
|
| 74 |
+
component_that_dropped=None,
|
| 75 |
+
penalty_applied=penalty,
|
| 76 |
+
pre_penalty_total=pre_penalty_total,
|
| 77 |
+
post_penalty_total=components.total,
|
| 78 |
+
)
|
| 79 |
+
return components, log
|
viral_script_engine/scripts/run_baseline.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from rich.console import Console
|
| 9 |
+
from rich.table import Table
|
| 10 |
+
from rich import box
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 14 |
+
|
| 15 |
+
from viral_script_engine.agents.baseline_arbitrator import BaselineArbitratorAgent
|
| 16 |
+
from viral_script_engine.environment.env import ViralScriptEnv
|
| 17 |
+
|
| 18 |
+
console = Console()
|
| 19 |
+
BASE_DIR = Path(__file__).parent.parent
|
| 20 |
+
LOGS_DIR = BASE_DIR / "logs"
|
| 21 |
+
LOGS_DIR.mkdir(exist_ok=True)
|
| 22 |
+
|
| 23 |
+
_SCHEDULE = (
|
| 24 |
+
[(i, "easy") for i in range(1, 9)]
|
| 25 |
+
+ [(i, "medium") for i in range(9, 17)]
|
| 26 |
+
+ [(i, "hard") for i in range(17, 21)]
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
_REWARD_KEYS = ["r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
|
| 30 |
+
"r4_debate_resolution", "r5_defender_preservation"]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _make_env(difficulty: str) -> ViralScriptEnv:
    """Build a 5-step environment for ``difficulty`` using the packaged data files."""
    data_dir = BASE_DIR / "data"
    scripts_file = data_dir / "test_scripts" / "scripts.json"
    kb_file = data_dir / "cultural_kb.json"
    env = ViralScriptEnv(
        scripts_path=str(scripts_file),
        cultural_kb_path=str(kb_file),
        max_steps=5,
        difficulty=difficulty,
    )
    return env
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def run_episode(ep_num: int, difficulty: str, agent: BaselineArbitratorAgent) -> dict:
    """Run one baseline episode and return a JSON-serialisable summary.

    Args:
        ep_num: 1-based episode number, used only for labelling the result.
        difficulty: difficulty tier passed to the environment factory.
        agent: arbitrator whose ``act(obs)`` picks the action for each step.

    Returns:
        Dict with per-step reward components, the anti-gaming logs, the
        original and final script, and ``total_reward``. Note that
        ``total_reward`` is the reward of the last executed step, not a sum
        over steps.
    """
    env = _make_env(difficulty)
    obs, _ = env.reset()

    episode_id = obs["episode_id"]
    state = env.state()
    # Previously hard-coded to "unknown"; use the id when the state exposes one.
    script_id = state.get("script_id", "unknown")
    original_script = state.get("original_script", "")

    steps_log = []
    total_reward = 0.0

    for _ in range(env.max_steps):
        action = agent.act(obs)
        obs, reward, terminated, truncated, info = env.step(action)
        rc = info["reward_components"]
        anti_log = info.get("anti_gaming_log", {})

        steps_log.append({
            "r1": rc.get("r1_hook_strength"),
            "r2": rc.get("r2_coherence"),
            "r3": rc.get("r3_cultural_alignment"),
            "r4": rc.get("r4_debate_resolution"),
            "r5": rc.get("r5_defender_preservation"),
            "total": reward,
            "anti_gaming_triggered": anti_log.get("triggered", False),
            "penalty": anti_log.get("penalty_applied", 0.0),
        })
        total_reward = reward

        if terminated or truncated:
            break

    final_state = env.state()
    final_script = final_state.get("current_script", "")

    return {
        "episode_num": ep_num,
        "episode_id": episode_id,
        "difficulty": difficulty,
        "script_id": script_id,
        "steps": steps_log,
        "total_reward": total_reward,
        "anti_gaming_logs": final_state.get("anti_gaming_logs", []),
        "original_script": original_script,
        "final_script": final_script,
    }
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def main():
    """Run every scheduled baseline episode, persist results and plots, and report the gate.

    An episode that raises is recorded with ``total_reward`` 0.0 and an
    ``error`` field so the remaining episodes still run. The gate is reported
    as PASS only when no episode errored.
    """
    agent = BaselineArbitratorAgent()
    all_episodes = []
    total_episodes = len(_SCHEDULE)
    failed = 0

    for ep_num, difficulty in _SCHEDULE:
        console.print(f"[dim]Episode {ep_num:02d}/{total_episodes} ({difficulty})...[/dim]")
        try:
            result = run_episode(ep_num, difficulty, agent)
            all_episodes.append(result)
            console.print(
                f"  -> total_reward={result['total_reward']:.3f}  "
                f"steps={len(result['steps'])}"
            )
        except Exception as e:
            # Top-level boundary: keep the run going and record the failure.
            failed += 1
            console.print(f"  [red]ERROR episode {ep_num}: {e}[/red]")
            all_episodes.append({
                "episode_num": ep_num,
                "episode_id": "",
                "difficulty": difficulty,
                "script_id": "error",
                "steps": [],
                "total_reward": 0.0,
                "anti_gaming_logs": [],
                "original_script": "",
                "final_script": "",
                "error": str(e),
            })

    results_path = LOGS_DIR / "baseline_results.json"
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(all_episodes, f, indent=2, default=str)

    _save_plots(all_episodes)
    _print_summary(all_episodes)

    mean_total = (
        float(np.mean([e["total_reward"] for e in all_episodes])) if all_episodes else 0.0
    )
    if failed:
        # Do not claim a pass when part of the baseline never ran.
        console.print(
            f"\n[bold red]PHASE 2 GATE: FAIL \u2014 {failed}/{total_episodes} episodes errored. "
            f"Mean total reward (errored episodes count as 0): {mean_total:.2f}[/bold red]"
        )
    else:
        console.print(
            f"\n[bold green]PHASE 2 GATE: PASS \u2014 Baseline curves saved. "
            f"Pre-training mean total reward: {mean_total:.2f}[/bold green]"
        )
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def _collect_reward_series(episodes: list, key: str, default: float = 0.0) -> list:
    """Return one value per episode: the last non-None ``key`` among its steps.

    Args:
        episodes: episode dicts, each optionally holding a ``"steps"`` list of
            per-step dicts.
        key: step field to read, e.g. ``"r1"``.
        default: value used for an episode with no steps or no non-None value
            for ``key``.
    """
    series = []
    for ep in episodes:
        # `or []` also covers an explicit "steps": None.
        steps = ep.get("steps") or []
        vals = [v for v in (s.get(key) for s in steps) if v is not None]
        series.append(vals[-1] if vals else default)
    return series
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _save_plots(episodes: list):
    """Save a 2x3 grid of per-episode reward curves to ``LOGS_DIR``.

    One panel is drawn for each of r1-r5 (last non-None step value per
    episode) and one for the episode total. Nothing is written when
    ``episodes`` is empty.
    """
    if not episodes:
        # min()/max() over the episode numbers below would raise on an empty list.
        console.print("[yellow]No episodes to plot; skipping reward curves.[/yellow]")
        return

    import matplotlib
    matplotlib.use("Agg")  # headless backend: render to file without a display
    import matplotlib.pyplot as plt

    labels = {
        "r1": "R1 Hook Strength",
        "r2": "R2 Coherence",
        "r3": "R3 Cultural Alignment",
        "r4": "R4 Debate Resolution",
        "r5": "R5 Defender Preservation",
        "total": "Total Reward",
    }
    ep_nums = [e["episode_num"] for e in episodes]

    fig, axes = plt.subplots(2, 3, figsize=(14, 8), dpi=150)
    fig.suptitle(
        "Baseline (Untrained) Arbitrator \u2014 Pre-Training Reward Curves",
        fontsize=13,
    )

    for idx, (key, label) in enumerate(labels.items()):
        ax = axes[idx // 3][idx % 3]
        if key == "total":
            series = [e["total_reward"] for e in episodes]
        else:
            series = _collect_reward_series(episodes, key)
        ax.plot(ep_nums, series, marker="o", linewidth=1.5, markersize=4)
        ax.set_title(label, fontsize=10)
        ax.set_xlabel("Episode", fontsize=8)
        ax.set_ylabel("Reward", fontsize=8)
        ax.set_ylim(0, 1)
        ax.set_xlim(min(ep_nums) - 0.5, max(ep_nums) + 0.5)
        ax.tick_params(labelsize=7)
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    plot_path = LOGS_DIR / "baseline_reward_curves.png"
    fig.savefig(str(plot_path), dpi=150)
    plt.close(fig)  # release this figure explicitly
    console.print(f"[dim]Curves saved -> {plot_path}[/dim]")
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _print_summary(episodes: list):
    """Print a table of mean, std, min and max for each reward across ``episodes``.

    r1-r5 use the last non-None step value per episode. "Total Reward" uses
    each episode's ``total_reward``. Nothing is printed for an empty list,
    because the statistics are undefined there.
    """
    if not episodes:
        console.print("[yellow]No episodes to summarise.[/yellow]")
        return

    table = Table(
        title=f"Baseline Results \u2014 Mean +/- Std ({len(episodes)} episodes)",
        box=box.SIMPLE_HEAD,
    )
    table.add_column("Reward", style="cyan", min_width=28)
    for heading in ("Mean", "Std", "Min", "Max"):
        table.add_column(heading, min_width=8)

    label_map = {
        "r1": "R1 Hook Strength",
        "r2": "R2 Coherence",
        "r3": "R3 Cultural Alignment",
        "r4": "R4 Debate Resolution",
        "r5": "R5 Defender Preservation",
        "total": "Total Reward",
    }
    for key, label in label_map.items():
        if key == "total":
            vals = [e["total_reward"] for e in episodes]
        else:
            vals = _collect_reward_series(episodes, key)
        arr = np.array(vals, dtype=float)
        table.add_row(
            label,
            f"{arr.mean():.3f}",
            f"{arr.std():.3f}",
            f"{arr.min():.3f}",
            f"{arr.max():.3f}",
        )

    console.print(table)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
main()
|
viral_script_engine/tests/test_environment.py
CHANGED
|
@@ -49,6 +49,10 @@ def env():
|
|
| 49 |
with (
|
| 50 |
patch("viral_script_engine.environment.env.CriticAgent") as mock_critic_cls,
|
| 51 |
patch("viral_script_engine.environment.env.RewriterAgent") as mock_rewriter_cls,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
):
|
| 53 |
mock_critic = MagicMock()
|
| 54 |
mock_critic.critique.return_value = make_mock_critique()
|
|
@@ -58,6 +62,39 @@ def env():
|
|
| 58 |
mock_rewriter.rewrite.side_effect = make_mock_rewrite
|
| 59 |
mock_rewriter_cls.return_value = mock_rewriter
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
from viral_script_engine.environment.env import ViralScriptEnv
|
| 62 |
yield ViralScriptEnv(scripts_path=SCRIPTS_PATH, max_steps=5, difficulty="easy")
|
| 63 |
|
|
|
|
| 49 |
with (
|
| 50 |
patch("viral_script_engine.environment.env.CriticAgent") as mock_critic_cls,
|
| 51 |
patch("viral_script_engine.environment.env.RewriterAgent") as mock_rewriter_cls,
|
| 52 |
+
patch("viral_script_engine.environment.env.DefenderAgent") as mock_defender_cls,
|
| 53 |
+
patch("viral_script_engine.environment.env.CulturalAlignmentReward") as mock_r3_cls,
|
| 54 |
+
patch("viral_script_engine.environment.env.DebateResolutionReward") as mock_r4_cls,
|
| 55 |
+
patch("viral_script_engine.environment.env.DefenderPreservationReward") as mock_r5_cls,
|
| 56 |
):
|
| 57 |
mock_critic = MagicMock()
|
| 58 |
mock_critic.critique.return_value = make_mock_critique()
|
|
|
|
| 62 |
mock_rewriter.rewrite.side_effect = make_mock_rewrite
|
| 63 |
mock_rewriter_cls.return_value = mock_rewriter
|
| 64 |
|
| 65 |
+
mock_defender = MagicMock()
|
| 66 |
+
from viral_script_engine.agents.defender import DefenderOutput
|
| 67 |
+
mock_defender.defend.return_value = DefenderOutput(
|
| 68 |
+
core_strength="strong hook",
|
| 69 |
+
core_strength_quote="test quote",
|
| 70 |
+
defense_argument="preserve it",
|
| 71 |
+
flagged_critic_claims=[],
|
| 72 |
+
regional_voice_elements=[],
|
| 73 |
+
)
|
| 74 |
+
mock_defender_cls.return_value = mock_defender
|
| 75 |
+
|
| 76 |
+
mock_r3 = MagicMock()
|
| 77 |
+
mock_r3.score.return_value = MagicMock(score=0.6)
|
| 78 |
+
mock_r3_cls.return_value = mock_r3
|
| 79 |
+
|
| 80 |
+
mock_r4 = MagicMock()
|
| 81 |
+
from viral_script_engine.rewards.r4_debate_resolution import DebateResolutionResult
|
| 82 |
+
mock_r4.score.return_value = DebateResolutionResult(
|
| 83 |
+
score=0.8,
|
| 84 |
+
resolution_status="resolved",
|
| 85 |
+
original_claim_id="C1",
|
| 86 |
+
original_claim_class="hook_weakness",
|
| 87 |
+
new_claims_count=2,
|
| 88 |
+
)
|
| 89 |
+
mock_r4_cls.return_value = mock_r4
|
| 90 |
+
|
| 91 |
+
mock_r5 = MagicMock()
|
| 92 |
+
from viral_script_engine.rewards.r5_defender_preservation import DefenderPreservationResult
|
| 93 |
+
mock_r5.score.return_value = DefenderPreservationResult(
|
| 94 |
+
score=0.9, max_similarity=0.9, best_matching_sentence="test quote"
|
| 95 |
+
)
|
| 96 |
+
mock_r5_cls.return_value = mock_r5
|
| 97 |
+
|
| 98 |
from viral_script_engine.environment.env import ViralScriptEnv
|
| 99 |
yield ViralScriptEnv(scripts_path=SCRIPTS_PATH, max_steps=5, difficulty="easy")
|
| 100 |
|
viral_script_engine/tests/test_phase2.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import pytest
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.agents.defender import DefenderAgent, DefenderOutput, DefenderParseError
|
| 6 |
+
from viral_script_engine.agents.critic import CritiqueClaim
|
| 7 |
+
from viral_script_engine.environment.actions import ActionType, ArbitratorAction
|
| 8 |
+
from viral_script_engine.rewards.r3_cultural_alignment import CulturalAlignmentReward
|
| 9 |
+
from viral_script_engine.rewards.r4_debate_resolution import DebateResolutionReward, DebateResolutionResult
|
| 10 |
+
from viral_script_engine.rewards.r5_defender_preservation import DefenderPreservationReward
|
| 11 |
+
from viral_script_engine.rewards.reward_aggregator import RewardAggregator, AntiGamingLog
|
| 12 |
+
from viral_script_engine.environment.observations import RewardComponents
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ─── fixtures ────────────────────────────────────────────────────────────────
|
| 16 |
+
|
| 17 |
+
# Canned Defender reply, serialised to the JSON string the patched LLM backend returns.
MOCK_DEFENDER_RESPONSE = json.dumps(
    dict(
        core_strength="Relatable hook about saving money",
        core_strength_quote="Let me tell you a secret",
        defense_argument="This creates immediate viewer curiosity and should not be changed.",
        flagged_critic_claims=["C2"],
        regional_voice_elements=["yaar", "ek dum solid"],
    )
)


# Two critic claims shared by the DefenderAgent tests: C1 (hook) and C2 (CTA).
MOCK_CRITIQUE_CLAIMS = [
    CritiqueClaim(
        claim_id=claim_id,
        critique_class=critique_class,
        claim_text=claim_text,
        timestamp_range=timestamp_range,
        evidence=evidence,
        is_falsifiable=True,
        severity=severity,
    )
    for claim_id, critique_class, claim_text, timestamp_range, evidence, severity in (
        ("C1", "hook_weakness", "Weak hook.", "0:00-0:03", "Let me tell you a secret", "high"),
        ("C2", "cta_buried", "CTA at end.", "0:45-0:50", "Like and save this video", "medium"),
    )
]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@pytest.fixture
def mock_defender_llm(monkeypatch):
    """Patch LLMBackend.generate so every call returns MOCK_DEFENDER_RESPONSE."""

    def _canned_generate(self, sys_prompt, usr_prompt, **kw):
        # Prompts and extra keyword arguments are ignored; the reply is fixed.
        return MOCK_DEFENDER_RESPONSE

    monkeypatch.setattr(
        "viral_script_engine.agents.llm_backend.LLMBackend.generate",
        _canned_generate,
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ─── Step 1: DefenderAgent ───────────────────────────────────────────────────
|
| 56 |
+
|
| 57 |
+
def test_defender_parses_output(mock_defender_llm):
    """DefenderAgent.defend turns the canned JSON reply into a DefenderOutput."""
    defend_kwargs = dict(
        script="Let me tell you a secret about saving money. yaar, ek dum solid plan.",
        critic_claims=MOCK_CRITIQUE_CLAIMS,
        region="mumbai_gen_z",
        platform="instagram",
    )
    parsed = DefenderAgent().defend(**defend_kwargs)

    # Each field checked here comes straight from MOCK_DEFENDER_RESPONSE.
    assert isinstance(parsed, DefenderOutput)
    assert parsed.core_strength_quote == "Let me tell you a secret"
    assert "C2" in parsed.flagged_critic_claims
    assert "yaar" in parsed.regional_voice_elements
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_defender_retries_on_invalid_json(monkeypatch):
    """An unparseable first reply is followed by a second backend call that succeeds."""
    calls = 0

    def fake_generate(self, sys_prompt, usr_prompt, **kw):
        nonlocal calls
        calls += 1
        # Only the very first reply is garbage; every later one is valid JSON.
        return "NOT JSON AT ALL" if calls == 1 else MOCK_DEFENDER_RESPONSE

    monkeypatch.setattr(
        "viral_script_engine.agents.llm_backend.LLMBackend.generate",
        fake_generate,
    )

    outcome = DefenderAgent().defend(
        "script", MOCK_CRITIQUE_CLAIMS, "mumbai_gen_z", "instagram"
    )

    assert isinstance(outcome, DefenderOutput)
    assert calls == 2
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def test_defender_raises_after_two_failures(monkeypatch):
    """When every backend reply is unparseable, defend() raises DefenderParseError."""

    def always_bad(self, sys_prompt, usr_prompt, **kw):
        return "BAD JSON"

    monkeypatch.setattr(
        "viral_script_engine.agents.llm_backend.LLMBackend.generate",
        always_bad,
    )

    defender = DefenderAgent()
    with pytest.raises(DefenderParseError):
        defender.defend("script", MOCK_CRITIQUE_CLAIMS, "mumbai_gen_z", "instagram")
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ─── Step 2: R3 CulturalAlignmentReward ──────────────────────────────────────
|
| 101 |
+
|
| 102 |
+
@pytest.fixture
def r3(tmp_path):
    """Build a CulturalAlignmentReward backed by a two-region knowledge base on disk."""
    mumbai_entry = {
        "valid_refs": ["Bandra", "CSMT", "local train", "Swiggy", "IPL"],
        "correct_idioms": ["ek dum solid", "kya scene hai", "full on"],
        "invalid_signals": ["trunk call", "VHS", "walkman"],
        "anachronistic_signals": [],
    }
    tier2_entry = {
        "valid_refs": ["kirana store", "sabzi mandi", "jugaad", "panchayat", "mela"],
        "correct_idioms": ["bilkul sahi", "arey bhai", "seedha baat"],
        "invalid_signals": ["SaaS", "venture capital", "coworking space"],
        "anachronistic_signals": [],
    }

    # Write the knowledge base into pytest's per-test temp dir and load it by path.
    kb_file = tmp_path / "test_kb.json"
    kb_file.write_text(
        json.dumps({"mumbai_gen_z": mumbai_entry, "tier2_hindi_belt": tier2_entry}),
        encoding="utf-8",
    )
    return CulturalAlignmentReward(knowledge_base_path=str(kb_file))
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def test_r3_scores_regional_script(r3):
    """A script using Mumbai references scores above zero and reports a matched ref."""
    outcome = r3.score(
        "Take the local train to Bandra. IPL is on at night. ek dum solid plan yaar.",
        "mumbai_gen_z",
    )

    assert outcome.score > 0.0
    found = outcome.valid_refs_found
    assert "local train" in found or "Bandra" in found
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_r3_scores_non_regional_script_lower(r3):
    """A generic script must not outscore one that uses Mumbai references."""
    generic_text = "Buy on Amazon. Use your credit card. Free delivery available nationwide."
    regional_text = "Take local train to Bandra. IPL is on. ek dum solid."

    # The regional script is scored first, then the generic one, for the same region.
    regional_result = r3.score(regional_text, "mumbai_gen_z")
    generic_result = r3.score(generic_text, "mumbai_gen_z")

    assert regional_result.score >= generic_result.score
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def test_r3_penalises_invalid_signals(r3):
    """A script made only of invalid signals scores 0.0 and reports those signals."""
    outcome = r3.score(
        "This is like an old VHS walkman trunk call era.", "mumbai_gen_z"
    )

    assert outcome.score == 0.0
    assert len(outcome.invalid_signals_found) > 0
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def test_r3_neutral_for_unknown_region(r3):
    """A region absent from the knowledge base yields a score of 0.5."""
    outcome = r3.score("any script", "unknown_region_xyz")

    assert outcome.score == 0.5
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def test_r3_tier2_valid(r3):
    """A script using tier-2 Hindi-belt references and idioms scores above zero."""
    tier2_text = (
        "Went to kirana store, met at sabzi mandi, pure jugaad. bilkul sahi plan."
    )
    outcome = r3.score(tier2_text, "tier2_hindi_belt")

    assert outcome.score > 0.0
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def test_r3_tier2_penalises_metro_jargon(r3):
    """Metro start-up jargon is reported as invalid signals for the tier-2 region."""
    jargon_text = "We raised SaaS venture capital at a coworking space."
    outcome = r3.score(jargon_text, "tier2_hindi_belt")

    assert len(outcome.invalid_signals_found) > 0
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
# ─── Step 3: R4 DebateResolutionReward ───────────────────────────────────────
|
| 164 |
+
|
| 165 |
+
def _make_critique_output(claims):
    """Wrap *claims* in a CritiqueOutput with "medium" severity and an empty raw response."""
    from viral_script_engine.agents.critic import CritiqueOutput

    output_fields = {
        "claims": claims,
        "overall_severity": "medium",
        "raw_response": "",
    }
    return CritiqueOutput(**output_fields)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def _make_claim(claim_id, critique_class, timestamp_range, severity="high"):
    """Build a falsifiable CritiqueClaim with placeholder claim text and evidence."""
    claim_fields = {
        "claim_id": claim_id,
        "critique_class": critique_class,
        "claim_text": "test claim",
        "timestamp_range": timestamp_range,
        "evidence": "evidence text",
        "is_falsifiable": True,
        "severity": severity,
    }
    return CritiqueClaim(**claim_fields)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def _make_action():
    """Return a fixed hook-rewrite ArbitratorAction that targets claim "C1"."""
    action_fields = {
        "action_type": ActionType.HOOK_REWRITE,
        "target_section": "hook",
        "instruction": "fix hook",
        "critique_claim_id": "C1",
        "reasoning": "test",
    }
    return ArbitratorAction(**action_fields)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def test_r4_resolved_when_no_matching_claim():
    """R4 returns 1.0 / "resolved" when the re-critique holds only a different claim class."""
    # The stubbed critic reports a CTA issue, not the original hook weakness.
    recritique = _make_critique_output([_make_claim("C1", "cta_buried", "0:45-0:50")])
    critic_stub = MagicMock()
    critic_stub.critique.return_value = recritique

    reward = DebateResolutionReward(critic_agent=critic_stub)
    addressed = _make_claim("C1", "hook_weakness", "0:00-0:03", "high")
    outcome = reward.score(
        "new script", _make_action(), addressed, "mumbai_gen_z", "instagram", "finance"
    )

    assert outcome.score == 1.0
    assert outcome.resolution_status == "resolved"
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def test_r4_partially_resolved_when_severity_drops():
    """R4 returns 0.5 / "partially_resolved" when the same class reappears at lower severity."""
    # Same critique class as the original claim, but "low" instead of "high".
    recritique = _make_critique_output(
        [_make_claim("C1", "hook_weakness", "0:01-0:04", "low")]
    )
    critic_stub = MagicMock()
    critic_stub.critique.return_value = recritique

    reward = DebateResolutionReward(critic_agent=critic_stub)
    addressed = _make_claim("C1", "hook_weakness", "0:00-0:03", "high")
    outcome = reward.score(
        "new script", _make_action(), addressed, "mumbai_gen_z", "instagram", "finance"
    )

    assert outcome.score == 0.5
    assert outcome.resolution_status == "partially_resolved"
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def test_r4_persists_when_same_severity():
    """R4 returns 0.0 / "persists" when the same class reappears at unchanged severity."""
    # Same critique class and the same "high" severity as the original claim.
    recritique = _make_critique_output(
        [_make_claim("C1", "hook_weakness", "0:01-0:03", "high")]
    )
    critic_stub = MagicMock()
    critic_stub.critique.return_value = recritique

    reward = DebateResolutionReward(critic_agent=critic_stub)
    addressed = _make_claim("C1", "hook_weakness", "0:00-0:03", "high")
    outcome = reward.score(
        "new script", _make_action(), addressed, "mumbai_gen_z", "instagram", "finance"
    )

    assert outcome.score == 0.0
    assert outcome.resolution_status == "persists"
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ─── Step 4: R5 DefenderPreservationReward ───────────────────────────────────
|
| 232 |
+
|
| 233 |
+
def _make_defender_output(quote: str) -> DefenderOutput:
    """Build a DefenderOutput whose core-strength quote is *quote*; other fields are fixed."""
    output_fields = {
        "core_strength": "Strong opening",
        "core_strength_quote": quote,
        "defense_argument": "Should be preserved.",
        "flagged_critic_claims": [],
        "regional_voice_elements": [],
    }
    return DefenderOutput(**output_fields)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def test_r5_high_score_when_quote_present():
    """R5 scores at least 0.85 when the defended quote appears verbatim in the script."""
    preserved_quote = "Let me tell you a secret about saving money every month."
    rewritten_script = (
        "Let me tell you a secret about saving money every month. This is the key insight."
    )

    reward = DefenderPreservationReward()
    outcome = reward.score(_make_defender_output(preserved_quote), rewritten_script)

    assert outcome.score >= 0.85
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def test_r5_zero_score_when_quote_absent():
    """R5 scores below 0.65 when the script shares nothing with the defended quote.

    NOTE(review): the test name says "zero", but the assertion only requires
    a score under 0.65 — confirm which bound the reward is meant to satisfy.
    """
    lost_quote = "Completely different text that shares nothing with rewrite."
    unrelated_script = (
        "Today we discuss quantum physics and neutron stars in distant galaxies."
    )

    reward = DefenderPreservationReward()
    outcome = reward.score(_make_defender_output(lost_quote), unrelated_script)

    assert outcome.score < 0.65
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# ─── Step 5/6: AntiGamingLog and RewardAggregator ────────────────────────────
|
| 262 |
+
|
| 263 |
+
def _make_components(**kwargs) -> RewardComponents:
    """Return RewardComponents built from defaults overridden by *kwargs*.

    ``compute_total()`` is called on the instance before it is returned.
    """
    values = {
        "r1_hook_strength": 0.7,
        "r2_coherence": 0.7,
        "r3_cultural_alignment": 0.7,
        "r4_debate_resolution": None,
        "r5_defender_preservation": None,
        **kwargs,  # caller-supplied overrides win over the defaults above
    }
    components = RewardComponents(**values)
    components.compute_total()
    return components
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def test_anti_gaming_catastrophic_drop_zeroes_reward():
    """A coherence fall from 0.8 to 0.4 zeroes the total and logs a catastrophic drop."""
    baseline = _make_components(r2_coherence=0.8)
    degraded = _make_components(r2_coherence=0.4)

    scored, gaming_log = RewardAggregator().compute(
        degraded, baseline, [], episode_id="ep1", step_num=1
    )

    assert scored.total == 0.0
    assert gaming_log.triggered is True
    assert gaming_log.rule_triggered == "catastrophic_drop"
    assert gaming_log.component_that_dropped == "r2_coherence"
    assert gaming_log.post_penalty_total == 0.0
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def test_anti_gaming_diversity_penalty_fires_on_3x_same():
    """Three identical actions in the history trigger the 0.15 repetition penalty."""
    baseline = _make_components()
    unchanged = _make_components()
    repeated_actions = [ActionType.HOOK_REWRITE for _ in range(3)]

    _scored, gaming_log = RewardAggregator().compute(
        unchanged, baseline, repeated_actions, episode_id="ep2", step_num=2
    )

    assert gaming_log.triggered is True
    assert gaming_log.rule_triggered == "action_repetition"
    assert gaming_log.penalty_applied == 0.15
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def test_anti_gaming_log_not_triggered_clean():
    """Unchanged components with a varied action history leave the log untriggered."""
    baseline = _make_components()
    unchanged = _make_components()
    varied_actions = [
        ActionType.HOOK_REWRITE,
        ActionType.CTA_PLACEMENT,
        ActionType.SECTION_REORDER,
    ]

    _scored, gaming_log = RewardAggregator().compute(
        unchanged, baseline, varied_actions, episode_id="ep3", step_num=1
    )

    assert gaming_log.triggered is False
    assert gaming_log.rule_triggered is None
    assert gaming_log.penalty_applied == 0.0
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def test_anti_gaming_log_fields_populated():
    """compute() returns an AntiGamingLog carrying the episode id and step number passed in."""
    baseline = _make_components(r1_hook_strength=0.9)
    degraded = _make_components(r1_hook_strength=0.5)

    # Only the log half of the returned pair is inspected here.
    _scored, gaming_log = RewardAggregator().compute(
        degraded, baseline, [], episode_id="myep", step_num=3
    )

    assert isinstance(gaming_log, AntiGamingLog)
    assert gaming_log.episode_id == "myep"
    assert gaming_log.step_num == 3
|
viral_script_engine/tests/test_rewards.py
CHANGED
|
@@ -93,7 +93,7 @@ def test_aggregator_catastrophic_drop(aggregator):
|
|
| 93 |
start = RewardComponents(r1_hook_strength=0.8, r2_coherence=0.7)
|
| 94 |
start.compute_total()
|
| 95 |
current = RewardComponents(r1_hook_strength=0.3, r2_coherence=0.7)
|
| 96 |
-
result = aggregator.compute(current, start, [ActionType.HOOK_REWRITE])
|
| 97 |
assert result.total == 0.0
|
| 98 |
|
| 99 |
|
|
@@ -102,7 +102,7 @@ def test_aggregator_diversity_penalty(aggregator):
|
|
| 102 |
start.compute_total()
|
| 103 |
current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
|
| 104 |
history = [ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE]
|
| 105 |
-
result = aggregator.compute(current, start, history)
|
| 106 |
assert result.anti_gaming_penalty == 0.15
|
| 107 |
assert result.total < 0.7
|
| 108 |
|
|
@@ -112,6 +112,6 @@ def test_aggregator_no_penalty(aggregator):
|
|
| 112 |
start.compute_total()
|
| 113 |
current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
|
| 114 |
history = [ActionType.HOOK_REWRITE, ActionType.CTA_PLACEMENT, ActionType.SECTION_REORDER]
|
| 115 |
-
result = aggregator.compute(current, start, history)
|
| 116 |
assert result.anti_gaming_penalty == 0.0
|
| 117 |
assert result.total > 0
|
|
|
|
| 93 |
start = RewardComponents(r1_hook_strength=0.8, r2_coherence=0.7)
|
| 94 |
start.compute_total()
|
| 95 |
current = RewardComponents(r1_hook_strength=0.3, r2_coherence=0.7)
|
| 96 |
+
result, log = aggregator.compute(current, start, [ActionType.HOOK_REWRITE])
|
| 97 |
assert result.total == 0.0
|
| 98 |
|
| 99 |
|
|
|
|
| 102 |
start.compute_total()
|
| 103 |
current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
|
| 104 |
history = [ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE, ActionType.HOOK_REWRITE]
|
| 105 |
+
result, log = aggregator.compute(current, start, history)
|
| 106 |
assert result.anti_gaming_penalty == 0.15
|
| 107 |
assert result.total < 0.7
|
| 108 |
|
|
|
|
| 112 |
start.compute_total()
|
| 113 |
current = RewardComponents(r1_hook_strength=0.7, r2_coherence=0.7)
|
| 114 |
history = [ActionType.HOOK_REWRITE, ActionType.CTA_PLACEMENT, ActionType.SECTION_REORDER]
|
| 115 |
+
result, log = aggregator.compute(current, start, history)
|
| 116 |
assert result.anti_gaming_penalty == 0.0
|
| 117 |
assert result.total > 0
|