Spaces:
Sleeping
Sleeping
feat(phase12): RetentionCurveSimulator, R10, 150-sample dataset, model trained, 14 tests PASS, gate PASS
Browse files- demo/run_demo.py +107 -4
- docs/progress.md +17 -0
- scripts/run_dummy_episode.py +15 -6
- scripts/train_retention_model.py +48 -0
- session/context.md +11 -10
- session/phase-log.md +1 -0
- session/summary.md +22 -23
- viral_script_engine/environment/env.py +19 -0
- viral_script_engine/environment/observations.py +8 -5
- viral_script_engine/retention/__init__.py +0 -0
- viral_script_engine/retention/curve_predictor.py +152 -0
- viral_script_engine/retention/curve_scorer.py +106 -0
- viral_script_engine/retention/feature_extractor.py +190 -0
- viral_script_engine/retention/model.joblib +3 -0
- viral_script_engine/retention/training_data/__init__.py +0 -0
- viral_script_engine/retention/training_data/build_dataset.py +158 -0
- viral_script_engine/retention/training_data/retention_dataset.json +0 -0
- viral_script_engine/rewards/r10_retention_curve.py +65 -0
- viral_script_engine/rewards/reward_aggregator.py +1 -0
- viral_script_engine/tests/test_phase12.py +325 -0
demo/run_demo.py
CHANGED
|
@@ -283,7 +283,77 @@ def act4_arbitrator_decides(
|
|
| 283 |
console.print()
|
| 284 |
|
| 285 |
|
| 286 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
console.print(Rule("[bold magenta]ACT 5 — THE REWRITE + REWARD[/bold magenta]", style="magenta"))
|
| 288 |
|
| 289 |
diff_text = Text()
|
|
@@ -303,6 +373,30 @@ def act5_rewrite_and_reward(original_script: str, rewritten_script: str, reward_
|
|
| 303 |
|
| 304 |
console.print()
|
| 305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
labels = {
|
| 307 |
"r1_hook_strength": "R1 Hook Strength",
|
| 308 |
"r2_coherence": "R2 Coherence",
|
|
@@ -310,6 +404,7 @@ def act5_rewrite_and_reward(original_script: str, rewritten_script: str, reward_
|
|
| 310 |
"r4_debate_resolution": "R4 Resolution",
|
| 311 |
"r5_defender_preservation": "R5 Preservation",
|
| 312 |
"r9_platform_pacing": "R9 Platform Pacing",
|
|
|
|
| 313 |
}
|
| 314 |
|
| 315 |
table = Table(box=box.SIMPLE_HEAD, show_header=False, padding=(0, 1))
|
|
@@ -550,10 +645,14 @@ def run_compare(script_id: str):
|
|
| 550 |
"total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
|
| 551 |
}
|
| 552 |
|
| 553 |
-
act5_rewrite_and_reward(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
console.print(Panel(
|
| 556 |
-
"[bold green]Demo complete.[/bold green] The Trained Arbitrator's richer reasoning produced
|
| 557 |
"a more targeted rewrite. Run [bold]python training/train_grpo.py[/bold] in Colab to "
|
| 558 |
"train the Arbitrator with GRPO and see real improvement curves.",
|
| 559 |
border_style="green",
|
|
@@ -647,7 +746,11 @@ def run_interactive():
|
|
| 647 |
"total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
|
| 648 |
}
|
| 649 |
|
| 650 |
-
act5_rewrite_and_reward(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 651 |
current_script = new_script
|
| 652 |
|
| 653 |
again = input("Continue to next step? [y/n]: ").strip().lower()
|
|
|
|
| 283 |
console.print()
|
| 284 |
|
| 285 |
|
| 286 |
+
def _retention_ascii_row(level_pct: int, values: list, timepoints: list) -> str:
|
| 287 |
+
"""Render one horizontal row of the ASCII retention chart."""
|
| 288 |
+
threshold = level_pct / 100
|
| 289 |
+
bar = ""
|
| 290 |
+
for v in values:
|
| 291 |
+
bar += "██" if v >= threshold else " "
|
| 292 |
+
label = f"{level_pct:4d}% |"
|
| 293 |
+
return f"{label}{bar}"
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def _render_retention_ascii(values: list, timepoints: list, label: str) -> str:
|
| 297 |
+
"""Render a compact ASCII bar chart of a retention curve."""
|
| 298 |
+
rows = []
|
| 299 |
+
rows.append(f" {label}")
|
| 300 |
+
for level in [100, 75, 50, 25]:
|
| 301 |
+
rows.append(_retention_ascii_row(level, values, timepoints))
|
| 302 |
+
# x-axis
|
| 303 |
+
axis = " +" + "--" * len(timepoints)
|
| 304 |
+
tick_labels = " " + " ".join(f"{t:<2}" for t in timepoints)
|
| 305 |
+
rows.append(axis)
|
| 306 |
+
rows.append(tick_labels + "s")
|
| 307 |
+
return "\n".join(rows)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def _show_retention_curves(
|
| 311 |
+
orig_values: list,
|
| 312 |
+
new_values: list,
|
| 313 |
+
timepoints: list,
|
| 314 |
+
orig_auc: float,
|
| 315 |
+
new_auc: float,
|
| 316 |
+
orig_drop: int,
|
| 317 |
+
new_drop: int,
|
| 318 |
+
) -> None:
|
| 319 |
+
"""Render before/after retention curves as ASCII art in a panel."""
|
| 320 |
+
before_chart = _render_retention_ascii(orig_values, timepoints, "Before rewrite:")
|
| 321 |
+
after_chart = _render_retention_ascii(new_values, timepoints, "After rewrite:")
|
| 322 |
+
|
| 323 |
+
auc_delta = new_auc - orig_auc
|
| 324 |
+
auc_pct = (auc_delta / orig_auc * 100) if orig_auc > 0 else 0.0
|
| 325 |
+
sign = "+" if auc_delta >= 0 else ""
|
| 326 |
+
|
| 327 |
+
drop_line = (
|
| 328 |
+
f"Drop-off point: {orig_drop}s -> {new_drop}s"
|
| 329 |
+
if new_drop != orig_drop
|
| 330 |
+
else f"Drop-off point: {orig_drop}s (unchanged)"
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
body = (
|
| 334 |
+
f"{before_chart}\n\n"
|
| 335 |
+
f"{after_chart}\n\n"
|
| 336 |
+
f"Improvement: AUC {orig_auc:.2f} -> {new_auc:.2f} ({sign}{auc_pct:.0f}%)\n"
|
| 337 |
+
f"{drop_line}"
|
| 338 |
+
)
|
| 339 |
+
console.print(Panel(
|
| 340 |
+
body,
|
| 341 |
+
title="[cyan]PREDICTED RETENTION CURVE[/cyan]",
|
| 342 |
+
border_style="cyan",
|
| 343 |
+
padding=(1, 2),
|
| 344 |
+
))
|
| 345 |
+
console.print()
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def act5_rewrite_and_reward(
|
| 349 |
+
original_script: str,
|
| 350 |
+
rewritten_script: str,
|
| 351 |
+
reward_components: dict,
|
| 352 |
+
baseline_total: float,
|
| 353 |
+
platform: str = "Reels",
|
| 354 |
+
region: str = "pan_india_english",
|
| 355 |
+
action_type: str = "hook_rewrite",
|
| 356 |
+
):
|
| 357 |
console.print(Rule("[bold magenta]ACT 5 — THE REWRITE + REWARD[/bold magenta]", style="magenta"))
|
| 358 |
|
| 359 |
diff_text = Text()
|
|
|
|
| 373 |
|
| 374 |
console.print()
|
| 375 |
|
| 376 |
+
# Phase 12: retention curve visualisation
|
| 377 |
+
try:
|
| 378 |
+
from viral_script_engine.retention.feature_extractor import FeatureExtractor
|
| 379 |
+
from viral_script_engine.retention.curve_predictor import RetentionCurvePredictor
|
| 380 |
+
from viral_script_engine.retention.curve_scorer import RetentionCurveScorer
|
| 381 |
+
extractor = FeatureExtractor()
|
| 382 |
+
predictor = RetentionCurvePredictor()
|
| 383 |
+
if predictor._trained:
|
| 384 |
+
orig_feat = extractor.extract(original_script, platform, region)
|
| 385 |
+
new_feat = extractor.extract(rewritten_script, platform, region)
|
| 386 |
+
orig_curve = predictor.predict(orig_feat)
|
| 387 |
+
new_curve = predictor.predict(new_feat)
|
| 388 |
+
_show_retention_curves(
|
| 389 |
+
orig_values=orig_curve.values,
|
| 390 |
+
new_values=new_curve.values,
|
| 391 |
+
timepoints=orig_curve.timepoints,
|
| 392 |
+
orig_auc=orig_curve.area_under_curve,
|
| 393 |
+
new_auc=new_curve.area_under_curve,
|
| 394 |
+
orig_drop=orig_curve.drop_off_point,
|
| 395 |
+
new_drop=new_curve.drop_off_point,
|
| 396 |
+
)
|
| 397 |
+
except Exception:
|
| 398 |
+
pass
|
| 399 |
+
|
| 400 |
labels = {
|
| 401 |
"r1_hook_strength": "R1 Hook Strength",
|
| 402 |
"r2_coherence": "R2 Coherence",
|
|
|
|
| 404 |
"r4_debate_resolution": "R4 Resolution",
|
| 405 |
"r5_defender_preservation": "R5 Preservation",
|
| 406 |
"r9_platform_pacing": "R9 Platform Pacing",
|
| 407 |
+
"r10_retention_curve": "R10 Retention Curve",
|
| 408 |
}
|
| 409 |
|
| 410 |
table = Table(box=box.SIMPLE_HEAD, show_header=False, padding=(0, 1))
|
|
|
|
| 645 |
"total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
|
| 646 |
}
|
| 647 |
|
| 648 |
+
act5_rewrite_and_reward(
|
| 649 |
+
current_script, new_script, reward_components, baseline_total,
|
| 650 |
+
platform=platform, region=region,
|
| 651 |
+
action_type=str(arb_action.action_type.value),
|
| 652 |
+
)
|
| 653 |
|
| 654 |
console.print(Panel(
|
| 655 |
+
"[bold green]Demo complete.[/bold green] The Trained Arbitrator's richer reasoning produced"
|
| 656 |
"a more targeted rewrite. Run [bold]python training/train_grpo.py[/bold] in Colab to "
|
| 657 |
"train the Arbitrator with GRPO and see real improvement curves.",
|
| 658 |
border_style="green",
|
|
|
|
| 746 |
"total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
|
| 747 |
}
|
| 748 |
|
| 749 |
+
act5_rewrite_and_reward(
|
| 750 |
+
current_script, new_script, reward_components, baseline_total,
|
| 751 |
+
platform=platform, region=region,
|
| 752 |
+
action_type=str(arb_action.action_type.value),
|
| 753 |
+
)
|
| 754 |
current_script = new_script
|
| 755 |
|
| 756 |
again = input("Continue to next step? [y/n]: ").strip().lower()
|
docs/progress.md
CHANGED
|
@@ -152,6 +152,23 @@ Do not read entire codebase to understand progress — read this file.
|
|
| 152 |
✅ test_phase11.py — 24 tests, all passing
|
| 153 |
✅ Phase 11 gate — PHASE 11 GATE: PASS, 6 sessions completed, trend: plateauing
|
| 154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
## Blocked Items
|
| 156 |
❌ GRPOConfig test — blocked by: pyarrow DLL blocked by Windows App Control (works on Linux/Colab)
|
| 157 |
❌ Full GRPO training — blocked by: no local GPU (requires Colab or cloud compute)
|
|
|
|
| 152 |
✅ test_phase11.py — 24 tests, all passing
|
| 153 |
✅ Phase 11 gate — PHASE 11 GATE: PASS, 6 sessions completed, trend: plateauing
|
| 154 |
|
| 155 |
+
## Phase 12 — Retention Curve Simulator
|
| 156 |
+
✅ ScriptFeatures + FeatureExtractor — 14 structural features extracted; platform one-hot; zero LLM calls
|
| 157 |
+
✅ build_dataset.py + retention_dataset.json — 150 rule-based samples (50 high/medium/low); monotonic curve generation
|
| 158 |
+
✅ RetentionCurvePredictor — MultiOutputRegressor(GBR); 10-point curve; train/predict; monotonic enforcement; avg MAE 0.031
|
| 159 |
+
✅ RetentionCurve model — timepoints, values, AUC (trapezoidal), drop_off_point
|
| 160 |
+
✅ retention/model.joblib — trained model saved
|
| 161 |
+
✅ RetentionCurveScorer — ACTION_CURVE_MAP; overall+targeted+regression formula; CurveScorerResult
|
| 162 |
+
✅ RetentionCurveReward (R10) — wraps extractor+predictor+scorer; episode-level original curve cache
|
| 163 |
+
✅ observations.py — r10_retention_curve in RewardComponents; _WEIGHTS updated to 10-reward spec
|
| 164 |
+
✅ reward_aggregator.py — r10_retention_curve in anti-gaming _COMPONENT_FIELDS
|
| 165 |
+
✅ env.py — R10 wired in __init__() and step(); graceful skip if model not trained
|
| 166 |
+
✅ scripts/train_retention_model.py — one-time training; builds dataset if missing; prints MAE
|
| 167 |
+
✅ demo/run_demo.py — ASCII retention curve in Act 5; R10 row in reward table
|
| 168 |
+
✅ scripts/run_dummy_episode.py — R10 gate assertion; Phase 12 GATE message
|
| 169 |
+
✅ test_phase12.py — 14 tests, all passing
|
| 170 |
+
✅ Phase 12 gate — PHASE 12 GATE: PASS, R10 firing
|
| 171 |
+
|
| 172 |
## Blocked Items
|
| 173 |
❌ GRPOConfig test — blocked by: pyarrow DLL blocked by Windows App Control (works on Linux/Colab)
|
| 174 |
❌ Full GRPO training — blocked by: no local GPU (requires Colab or cloud compute)
|
scripts/run_dummy_episode.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
Gate check script for Phase
|
| 3 |
|
| 4 |
Usage:
|
| 5 |
python scripts/run_dummy_episode.py --difficulty easy --steps 3 --verbose
|
|
@@ -71,12 +71,14 @@ def run_episode(difficulty: str, steps: int, verbose: bool):
|
|
| 71 |
|
| 72 |
if verbose:
|
| 73 |
r9 = rc.get("r9_platform_pacing")
|
|
|
|
| 74 |
r1 = rc.get("r1_hook_strength")
|
| 75 |
r2 = rc.get("r2_coherence")
|
| 76 |
r9_str = f"{r9:.3f}" if r9 is not None else "None"
|
|
|
|
| 77 |
print(
|
| 78 |
f" Step {step + 1}: total={reward:.3f} "
|
| 79 |
-
f"R1={r1:.3f} R2={r2:.3f} R9={r9_str}"
|
| 80 |
)
|
| 81 |
|
| 82 |
if terminated:
|
|
@@ -86,7 +88,7 @@ def run_episode(difficulty: str, steps: int, verbose: bool):
|
|
| 86 |
|
| 87 |
|
| 88 |
def main():
|
| 89 |
-
parser = argparse.ArgumentParser(description="Phase
|
| 90 |
parser.add_argument("--difficulty", default="easy", choices=["easy", "medium", "hard"])
|
| 91 |
parser.add_argument("--steps", type=int, default=3)
|
| 92 |
parser.add_argument("--verbose", action="store_true")
|
|
@@ -103,16 +105,23 @@ def main():
|
|
| 103 |
elif not (0.0 <= rc["r9_platform_pacing"] <= 1.0):
|
| 104 |
errors.append(f"Step {i+1}: r9_platform_pacing out of range: {rc['r9_platform_pacing']}")
|
| 105 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
if errors:
|
| 107 |
print("\n[GATE FAIL]")
|
| 108 |
for e in errors:
|
| 109 |
print(f" ERROR: {e}")
|
| 110 |
sys.exit(1)
|
| 111 |
else:
|
|
|
|
|
|
|
|
|
|
| 112 |
print(
|
| 113 |
-
f"\nPHASE
|
| 114 |
-
f"
|
| 115 |
-
f"Cross-platform divergence confirmed."
|
| 116 |
)
|
| 117 |
|
| 118 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
Gate check script for Phase 12 — runs a dummy episode and verifies R9 and R10 fire.
|
| 3 |
|
| 4 |
Usage:
|
| 5 |
python scripts/run_dummy_episode.py --difficulty easy --steps 3 --verbose
|
|
|
|
| 71 |
|
| 72 |
if verbose:
|
| 73 |
r9 = rc.get("r9_platform_pacing")
|
| 74 |
+
r10 = rc.get("r10_retention_curve")
|
| 75 |
r1 = rc.get("r1_hook_strength")
|
| 76 |
r2 = rc.get("r2_coherence")
|
| 77 |
r9_str = f"{r9:.3f}" if r9 is not None else "None"
|
| 78 |
+
r10_str = f"{r10:.3f}" if r10 is not None else "None"
|
| 79 |
print(
|
| 80 |
f" Step {step + 1}: total={reward:.3f} "
|
| 81 |
+
f"R1={r1:.3f} R2={r2:.3f} R9={r9_str} R10={r10_str}"
|
| 82 |
)
|
| 83 |
|
| 84 |
if terminated:
|
|
|
|
| 88 |
|
| 89 |
|
| 90 |
def main():
|
| 91 |
+
parser = argparse.ArgumentParser(description="Phase 12 dummy episode gate check")
|
| 92 |
parser.add_argument("--difficulty", default="easy", choices=["easy", "medium", "hard"])
|
| 93 |
parser.add_argument("--steps", type=int, default=3)
|
| 94 |
parser.add_argument("--verbose", action="store_true")
|
|
|
|
| 105 |
elif not (0.0 <= rc["r9_platform_pacing"] <= 1.0):
|
| 106 |
errors.append(f"Step {i+1}: r9_platform_pacing out of range: {rc['r9_platform_pacing']}")
|
| 107 |
|
| 108 |
+
if rc.get("r10_retention_curve") is None:
|
| 109 |
+
errors.append(f"Step {i+1}: r10_retention_curve is None — R10 not firing")
|
| 110 |
+
elif not (0.0 <= rc["r10_retention_curve"] <= 1.0):
|
| 111 |
+
errors.append(f"Step {i+1}: r10_retention_curve out of range: {rc['r10_retention_curve']}")
|
| 112 |
+
|
| 113 |
if errors:
|
| 114 |
print("\n[GATE FAIL]")
|
| 115 |
for e in errors:
|
| 116 |
print(f" ERROR: {e}")
|
| 117 |
sys.exit(1)
|
| 118 |
else:
|
| 119 |
+
# Compute average AUC improvement for gate message
|
| 120 |
+
r10_scores = [rc.get("r10_retention_curve", 0.0) for rc in steps_data if rc.get("r10_retention_curve") is not None]
|
| 121 |
+
avg_r10 = sum(r10_scores) / len(r10_scores) if r10_scores else 0.0
|
| 122 |
print(
|
| 123 |
+
f"\nPHASE 12 GATE: PASS — Retention curve predictor active. "
|
| 124 |
+
f"R10 firing. AUC improvement: +{avg_r10:.2f}."
|
|
|
|
| 125 |
)
|
| 126 |
|
| 127 |
|
scripts/train_retention_model.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
One-time training script for the RetentionCurvePredictor.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python scripts/train_retention_model.py
|
| 6 |
+
|
| 7 |
+
Steps:
|
| 8 |
+
1. Builds retention_dataset.json if it doesn't exist
|
| 9 |
+
2. Trains the RetentionCurvePredictor
|
| 10 |
+
3. Prints train/val MAE per timepoint
|
| 11 |
+
4. Saves model to viral_script_engine/retention/model.joblib
|
| 12 |
+
"""
|
| 13 |
+
import sys
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
|
| 16 |
+
sys.path.insert(0, str(Path(__file__).parent.parent))
|
| 17 |
+
|
| 18 |
+
from viral_script_engine.retention.training_data.build_dataset import build, _OUTPUT_PATH
|
| 19 |
+
from viral_script_engine.retention.curve_predictor import RetentionCurvePredictor
|
| 20 |
+
|
| 21 |
+
_CULTURAL_KB_PATH = str(
|
| 22 |
+
Path(__file__).parent.parent / "viral_script_engine" / "data" / "cultural_kb.json"
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
|
| 27 |
+
# Step 1: build dataset if missing
|
| 28 |
+
if not _OUTPUT_PATH.exists():
|
| 29 |
+
print("Building retention dataset...")
|
| 30 |
+
out = build()
|
| 31 |
+
print(f" Dataset created: {out}")
|
| 32 |
+
else:
|
| 33 |
+
print(f"Dataset already exists: {_OUTPUT_PATH}")
|
| 34 |
+
|
| 35 |
+
# Step 2: train the predictor
|
| 36 |
+
print("\nTraining RetentionCurvePredictor...")
|
| 37 |
+
predictor = RetentionCurvePredictor()
|
| 38 |
+
result = predictor.train(
|
| 39 |
+
dataset_path=str(_OUTPUT_PATH),
|
| 40 |
+
cultural_kb_path=_CULTURAL_KB_PATH,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
avg_mae = result["avg_mae"]
|
| 44 |
+
print(f"\nRetention model trained. Avg MAE: {avg_mae:.4f}. Model saved.")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
if __name__ == "__main__":
|
| 48 |
+
main()
|
session/context.md
CHANGED
|
@@ -1,21 +1,21 @@
|
|
| 1 |
# Context — Carry Over for Next Session
|
| 2 |
|
| 3 |
## Current Phase
|
| 4 |
-
Phase:
|
| 5 |
-
Prompt file: prompts/phase-
|
| 6 |
Status: complete
|
| 7 |
|
| 8 |
---
|
| 9 |
|
| 10 |
## Currently Working On
|
| 11 |
-
Feature: Phase
|
| 12 |
File(s): N/A
|
| 13 |
-
Status: All
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
## Open Questions
|
| 18 |
-
Is there a Phase
|
| 19 |
|
| 20 |
---
|
| 21 |
|
|
@@ -27,15 +27,16 @@ Full GRPO training requires Colab or cloud GPU
|
|
| 27 |
---
|
| 28 |
|
| 29 |
## Last Commit Message
|
| 30 |
-
feat(
|
| 31 |
|
| 32 |
---
|
| 33 |
|
| 34 |
## Do Not Forget
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
| 39 |
|
| 40 |
---
|
| 41 |
|
|
|
|
| 1 |
# Context — Carry Over for Next Session
|
| 2 |
|
| 3 |
## Current Phase
|
| 4 |
+
Phase: 12
|
| 5 |
+
Prompt file: prompts/phase-12.md
|
| 6 |
Status: complete
|
| 7 |
|
| 8 |
---
|
| 9 |
|
| 10 |
## Currently Working On
|
| 11 |
+
Feature: Phase 12 complete. Awaiting user confirmation to proceed to next phase (if any).
|
| 12 |
File(s): N/A
|
| 13 |
+
Status: All 14 tests pass. Gate script prints PHASE 12 GATE: PASS.
|
| 14 |
|
| 15 |
---
|
| 16 |
|
| 17 |
## Open Questions
|
| 18 |
+
Is there a Phase 13? Check if prompts/phase-13.md exists.
|
| 19 |
|
| 20 |
---
|
| 21 |
|
|
|
|
| 27 |
---
|
| 28 |
|
| 29 |
## Last Commit Message
|
| 30 |
+
feat(phase12): RetentionCurveSimulator, R10, 150-sample dataset, model trained, 14 tests PASS, gate PASS
|
| 31 |
|
| 32 |
---
|
| 33 |
|
| 34 |
## Do Not Forget
|
| 35 |
+
R10 requires trained model — run python scripts/train_retention_model.py first
|
| 36 |
+
RetentionCurvePredictor model saved at viral_script_engine/retention/model.joblib
|
| 37 |
+
MODEL_PATH is Path(__file__).parent / "model.joblib" (relative to curve_predictor.py)
|
| 38 |
+
R10 gracefully skips (score=None) in env.step() if model not trained
|
| 39 |
+
Gate check: python scripts/run_dummy_episode.py --difficulty easy --steps 3 --verbose
|
| 40 |
|
| 41 |
---
|
| 42 |
|
session/phase-log.md
CHANGED
|
@@ -30,6 +30,7 @@ ROLLED BACK — changes reverted, reason in line
|
|
| 30 |
[2026-04-26] [Phase 9] COMPLETE — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
|
| 31 |
[2026-04-26] [Phase 10] COMPLETE — ABScriptEnv, ContrastiveReward, A/B rollout fn, 25 tests PASS, gate PASS
|
| 32 |
[2026-04-26] [Phase 11] COMPLETE — CreatorHistoryBuffer, MemoryCompressor, HistoryStore, 24 tests PASS, gate PASS
|
|
|
|
| 33 |
|
| 34 |
---
|
| 35 |
|
|
|
|
| 30 |
[2026-04-26] [Phase 9] COMPLETE — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
|
| 31 |
[2026-04-26] [Phase 10] COMPLETE — ABScriptEnv, ContrastiveReward, A/B rollout fn, 25 tests PASS, gate PASS
|
| 32 |
[2026-04-26] [Phase 11] COMPLETE — CreatorHistoryBuffer, MemoryCompressor, HistoryStore, 24 tests PASS, gate PASS
|
| 33 |
+
[2026-04-26] [Phase 12] COMPLETE — RetentionCurveSimulator, R10, 150-sample dataset, model MAE 0.031, 14 tests PASS, gate PASS
|
| 34 |
|
| 35 |
---
|
| 36 |
|
session/summary.md
CHANGED
|
@@ -13,40 +13,39 @@ One session = one summary. Previous summaries live in phase-log.md.
|
|
| 13 |
2026-04-26
|
| 14 |
|
| 15 |
### Phase
|
| 16 |
-
Phase
|
| 17 |
|
| 18 |
### What Was Done
|
| 19 |
-
- Created
|
| 20 |
-
-
|
| 21 |
-
-
|
| 22 |
-
- Created
|
| 23 |
-
-
|
| 24 |
-
-
|
| 25 |
-
-
|
| 26 |
-
-
|
| 27 |
-
-
|
| 28 |
-
-
|
| 29 |
-
-
|
| 30 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
### What Was NOT Done (carry over)
|
| 33 |
- Real GRPO training — requires GPU (Colab)
|
| 34 |
|
| 35 |
### Errors Encountered
|
| 36 |
-
-
|
| 37 |
-
- test_penalty_capped_at_0_3: compared semantically different scripts (base sim=0) — fixed to use same-vocab scripts
|
| 38 |
-
- test_same_script_scores_differently_on_reels_vs_feed: _SLOW_SCRIPT both pacing+ratio zeroed out on both platforms — switched to sub-score comparison
|
| 39 |
-
- test_env_r9_fires_in_step: defender.defend() not patched → API call — patched defender with full MagicMock
|
| 40 |
-
- run_dummy_episode.py: R5 needs core_strength_quote from defender mock — added all required fields
|
| 41 |
-
- run_platform_comparison.py: Unicode bar chars fail on Windows cp1252 — switched to ASCII #/.
|
| 42 |
|
| 43 |
### Tests Status
|
| 44 |
-
Phase
|
| 45 |
-
Gate check
|
| 46 |
-
Gate check (platform comparison S03): PASS — R1/R2/R9 all diverge across Reels/Shorts/Feed
|
| 47 |
|
| 48 |
### Commit Messages Generated
|
| 49 |
-
feat(
|
| 50 |
|
| 51 |
---
|
| 52 |
|
|
|
|
| 13 |
2026-04-26
|
| 14 |
|
| 15 |
### Phase
|
| 16 |
+
Phase 12 — Retention Curve Simulator
|
| 17 |
|
| 18 |
### What Was Done
|
| 19 |
+
- Created viral_script_engine/retention/__init__.py — package init
|
| 20 |
+
- Created viral_script_engine/retention/feature_extractor.py — ScriptFeatures pydantic model (14 features + platform one-hot); FeatureExtractor.extract() — zero LLM calls, structural analysis
|
| 21 |
+
- Created viral_script_engine/retention/training_data/__init__.py — package init
|
| 22 |
+
- Created viral_script_engine/retention/training_data/build_dataset.py — 150 rule-based samples (50 high/medium/low); monotonic curve generation from R1/R2/R3 scores
|
| 23 |
+
- Created viral_script_engine/retention/training_data/retention_dataset.json — 150 samples generated
|
| 24 |
+
- Created viral_script_engine/retention/curve_predictor.py — RetentionCurvePredictor (MultiOutputRegressor+GBR); RetentionCurve model with AUC + drop-off; train/predict; monotonic enforcement
|
| 25 |
+
- Created viral_script_engine/retention/model.joblib — trained model, avg MAE 0.031
|
| 26 |
+
- Created viral_script_engine/retention/curve_scorer.py — RetentionCurveScorer; ACTION_CURVE_MAP; overall+targeted+regression formula
|
| 27 |
+
- Created viral_script_engine/rewards/r10_retention_curve.py — RetentionCurveReward; episode-level original curve caching
|
| 28 |
+
- Updated viral_script_engine/environment/observations.py — r10_retention_curve field; updated _WEIGHTS to 10-reward spec
|
| 29 |
+
- Updated viral_script_engine/rewards/reward_aggregator.py — r10_retention_curve in anti-gaming _COMPONENT_FIELDS
|
| 30 |
+
- Updated viral_script_engine/environment/env.py — R10 wired in __init__() and step(); graceful skip if model not trained
|
| 31 |
+
- Created scripts/train_retention_model.py — one-time training script; builds dataset if missing; prints MAE
|
| 32 |
+
- Updated demo/run_demo.py — _render_retention_ascii(); _show_retention_curves() ASCII panel in Act 5; R10 row in reward table
|
| 33 |
+
- Updated scripts/run_dummy_episode.py — R10 check in gate assertions; Phase 12 GATE message
|
| 34 |
+
- Created viral_script_engine/tests/test_phase12.py — 14 tests, all passing
|
| 35 |
+
- Phase 12 gate: PASS
|
| 36 |
|
| 37 |
### What Was NOT Done (carry over)
|
| 38 |
- Real GRPO training — requires GPU (Colab)
|
| 39 |
|
| 40 |
### Errors Encountered
|
| 41 |
+
- None; all 14 tests passed on first run
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
### Tests Status
|
| 44 |
+
Phase 12: 14 passed
|
| 45 |
+
Gate check: PHASE 12 GATE: PASS — Retention curve predictor active. R10 firing.
|
|
|
|
| 46 |
|
| 47 |
### Commit Messages Generated
|
| 48 |
+
feat(phase12): RetentionCurveSimulator, R10, 150-sample dataset, model trained, 14 tests PASS, gate PASS
|
| 49 |
|
| 50 |
---
|
| 51 |
|
viral_script_engine/environment/env.py
CHANGED
|
@@ -30,6 +30,7 @@ from viral_script_engine.rewards.r9_platform_pacing import PlatformPacingReward
|
|
| 30 |
from viral_script_engine.platforms.platform_spec import PlatformRegistry
|
| 31 |
from viral_script_engine.memory.memory_compressor import MemoryCompressor
|
| 32 |
from viral_script_engine.memory.history_store import HistoryStore
|
|
|
|
| 33 |
|
| 34 |
_TIERS = {
|
| 35 |
"easy": ["S01", "S02", "S03", "S04"],
|
|
@@ -85,6 +86,7 @@ class ViralScriptEnv:
|
|
| 85 |
self.platform_registry = PlatformRegistry()
|
| 86 |
self.memory_compressor = MemoryCompressor()
|
| 87 |
self.history_store = HistoryStore()
|
|
|
|
| 88 |
self._state: Optional[EpisodeState] = None
|
| 89 |
self._current_profile: Optional[CreatorProfile] = None
|
| 90 |
self._current_platform: str = "Reels"
|
|
@@ -282,6 +284,22 @@ class ViralScriptEnv:
|
|
| 282 |
# Phase 9: compute R9 platform pacing
|
| 283 |
r9_result = self.r9.score(new_script, platform=self._current_platform)
|
| 284 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
components = RewardComponents(
|
| 286 |
r1_hook_strength=r1_result.score,
|
| 287 |
r2_coherence=r2_result.score,
|
|
@@ -292,6 +310,7 @@ class ViralScriptEnv:
|
|
| 292 |
r7_originality=r7_result.score,
|
| 293 |
r8_persona_fit=r8_score,
|
| 294 |
r9_platform_pacing=r9_result.score,
|
|
|
|
| 295 |
process_reward=process_result.weighted_contribution if process_result else None,
|
| 296 |
)
|
| 297 |
|
|
|
|
| 30 |
from viral_script_engine.platforms.platform_spec import PlatformRegistry
|
| 31 |
from viral_script_engine.memory.memory_compressor import MemoryCompressor
|
| 32 |
from viral_script_engine.memory.history_store import HistoryStore
|
| 33 |
+
from viral_script_engine.rewards.r10_retention_curve import RetentionCurveReward
|
| 34 |
|
| 35 |
_TIERS = {
|
| 36 |
"easy": ["S01", "S02", "S03", "S04"],
|
|
|
|
| 86 |
self.platform_registry = PlatformRegistry()
|
| 87 |
self.memory_compressor = MemoryCompressor()
|
| 88 |
self.history_store = HistoryStore()
|
| 89 |
+
self.r10 = RetentionCurveReward(cultural_kb_path=cultural_kb_path)
|
| 90 |
self._state: Optional[EpisodeState] = None
|
| 91 |
self._current_profile: Optional[CreatorProfile] = None
|
| 92 |
self._current_platform: str = "Reels"
|
|
|
|
| 284 |
# Phase 9: compute R9 platform pacing
|
| 285 |
r9_result = self.r9.score(new_script, platform=self._current_platform)
|
| 286 |
|
| 287 |
+
# Phase 12: compute R10 retention curve reward
|
| 288 |
+
r10_score = None
|
| 289 |
+
if self.r10.predictor._trained:
|
| 290 |
+
try:
|
| 291 |
+
r10_result = self.r10.score(
|
| 292 |
+
original_script=self._state.original_script,
|
| 293 |
+
rewritten_script=new_script,
|
| 294 |
+
platform=self._current_platform,
|
| 295 |
+
region=self._state.region,
|
| 296 |
+
action_type=str(arb_action.action_type.value),
|
| 297 |
+
episode_id=self._state.episode_id,
|
| 298 |
+
)
|
| 299 |
+
r10_score = r10_result.score
|
| 300 |
+
except Exception:
|
| 301 |
+
r10_score = None
|
| 302 |
+
|
| 303 |
components = RewardComponents(
|
| 304 |
r1_hook_strength=r1_result.score,
|
| 305 |
r2_coherence=r2_result.score,
|
|
|
|
| 310 |
r7_originality=r7_result.score,
|
| 311 |
r8_persona_fit=r8_score,
|
| 312 |
r9_platform_pacing=r9_result.score,
|
| 313 |
+
r10_retention_curve=r10_score,
|
| 314 |
process_reward=process_result.weighted_contribution if process_result else None,
|
| 315 |
)
|
| 316 |
|
viral_script_engine/environment/observations.py
CHANGED
|
@@ -6,9 +6,10 @@ from viral_script_engine.agents.critic import CritiqueClaim
|
|
| 6 |
from viral_script_engine.environment.actions import ArbitratorAction
|
| 7 |
|
| 8 |
_WEIGHTS: Dict[str, float] = {
|
| 9 |
-
"r1": 0.
|
| 10 |
-
"r4": 0.10, "r5": 0.
|
| 11 |
-
"r7": 0.
|
|
|
|
| 12 |
}
|
| 13 |
|
| 14 |
|
|
@@ -21,8 +22,9 @@ class RewardComponents(BaseModel):
|
|
| 21 |
r6_safety: Optional[float] = None
|
| 22 |
r7_originality: Optional[float] = None
|
| 23 |
r8_persona_fit: Optional[float] = None # Phase 8: creator persona fit
|
| 24 |
-
r9_platform_pacing: Optional[float] = None
|
| 25 |
-
|
|
|
|
| 26 |
anti_gaming_penalty: float = 0.0
|
| 27 |
total: float = 0.0
|
| 28 |
|
|
@@ -37,6 +39,7 @@ class RewardComponents(BaseModel):
|
|
| 37 |
"r7": self.r7_originality,
|
| 38 |
"r8": self.r8_persona_fit,
|
| 39 |
"r9": self.r9_platform_pacing,
|
|
|
|
| 40 |
}
|
| 41 |
active = {k: v for k, v in vals.items() if v is not None}
|
| 42 |
if not active:
|
|
|
|
| 6 |
from viral_script_engine.environment.actions import ArbitratorAction
|
| 7 |
|
| 8 |
_WEIGHTS: Dict[str, float] = {
|
| 9 |
+
"r1": 0.12, "r2": 0.10, "r3": 0.10,
|
| 10 |
+
"r4": 0.10, "r5": 0.08, "r6": 0.07,
|
| 11 |
+
"r7": 0.07, "r8": 0.08, "r9": 0.08,
|
| 12 |
+
"r10": 0.10,
|
| 13 |
}
|
| 14 |
|
| 15 |
|
|
|
|
| 22 |
r6_safety: Optional[float] = None
|
| 23 |
r7_originality: Optional[float] = None
|
| 24 |
r8_persona_fit: Optional[float] = None # Phase 8: creator persona fit
|
| 25 |
+
r9_platform_pacing: Optional[float] = None # Phase 9: platform pacing fit
|
| 26 |
+
r10_retention_curve: Optional[float] = None # Phase 12: retention curve reward
|
| 27 |
+
process_reward: Optional[float] = None # fired before rewrite (Phase 7)
|
| 28 |
anti_gaming_penalty: float = 0.0
|
| 29 |
total: float = 0.0
|
| 30 |
|
|
|
|
| 39 |
"r7": self.r7_originality,
|
| 40 |
"r8": self.r8_persona_fit,
|
| 41 |
"r9": self.r9_platform_pacing,
|
| 42 |
+
"r10": self.r10_retention_curve,
|
| 43 |
}
|
| 44 |
active = {k: v for k, v in vals.items() if v is not None}
|
| 45 |
if not active:
|
viral_script_engine/retention/__init__.py
ADDED
|
File without changes
|
viral_script_engine/retention/curve_predictor.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
from sklearn.ensemble import GradientBoostingRegressor
|
| 8 |
+
from sklearn.multioutput import MultiOutputRegressor
|
| 9 |
+
import joblib
|
| 10 |
+
|
| 11 |
+
from viral_script_engine.retention.feature_extractor import FeatureExtractor, ScriptFeatures
|
| 12 |
+
|
| 13 |
+
_MODEL_PATH = Path(__file__).parent / "model.joblib"
|
| 14 |
+
|
| 15 |
+
CURVE_TIMEPOINTS = [0, 3, 6, 10, 15, 20, 25, 30, 45, 60]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RetentionCurve(BaseModel):
|
| 19 |
+
timepoints: List[int]
|
| 20 |
+
values: List[float]
|
| 21 |
+
area_under_curve: float
|
| 22 |
+
drop_off_point: int # first timepoint where retention drops below 0.5
|
| 23 |
+
|
| 24 |
+
@classmethod
|
| 25 |
+
def from_values(cls, values: List[float]) -> "RetentionCurve":
|
| 26 |
+
tps = CURVE_TIMEPOINTS
|
| 27 |
+
# Trapezoidal AUC, normalised to [0, 1]
|
| 28 |
+
auc = 0.0
|
| 29 |
+
for i in range(len(tps) - 1):
|
| 30 |
+
dt = tps[i + 1] - tps[i]
|
| 31 |
+
auc += dt * (values[i] + values[i + 1]) / 2
|
| 32 |
+
total_duration = tps[-1] - tps[0]
|
| 33 |
+
auc = auc / total_duration if total_duration > 0 else 0.0
|
| 34 |
+
|
| 35 |
+
drop_off = tps[-1]
|
| 36 |
+
for t, v in zip(tps, values):
|
| 37 |
+
if v < 0.5:
|
| 38 |
+
drop_off = t
|
| 39 |
+
break
|
| 40 |
+
|
| 41 |
+
return cls(
|
| 42 |
+
timepoints=list(tps),
|
| 43 |
+
values=[round(v, 4) for v in values],
|
| 44 |
+
area_under_curve=round(auc, 4),
|
| 45 |
+
drop_off_point=drop_off,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class RetentionCurvePredictor:
|
| 50 |
+
"""
|
| 51 |
+
Predicts a 10-point retention curve from script features.
|
| 52 |
+
10 points = retention at seconds [0, 3, 6, 10, 15, 20, 25, 30, 45, 60].
|
| 53 |
+
|
| 54 |
+
Uses MultiOutputRegressor(GradientBoostingRegressor).
|
| 55 |
+
Lightweight enough to run on CPU without GPU (<1ms per call after training).
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
MODEL_PATH = _MODEL_PATH
|
| 59 |
+
CURVE_TIMEPOINTS = CURVE_TIMEPOINTS
|
| 60 |
+
|
| 61 |
+
def __init__(self):
|
| 62 |
+
if _MODEL_PATH.exists():
|
| 63 |
+
self.model = joblib.load(_MODEL_PATH)
|
| 64 |
+
self._trained = True
|
| 65 |
+
else:
|
| 66 |
+
self.model = MultiOutputRegressor(
|
| 67 |
+
GradientBoostingRegressor(n_estimators=100, max_depth=4, random_state=42)
|
| 68 |
+
)
|
| 69 |
+
self._trained = False
|
| 70 |
+
|
| 71 |
+
def train(
|
| 72 |
+
self,
|
| 73 |
+
dataset_path: Optional[str] = None,
|
| 74 |
+
cultural_kb_path: Optional[str] = None,
|
| 75 |
+
) -> dict:
|
| 76 |
+
"""
|
| 77 |
+
Train on retention_dataset.json. Saves model to MODEL_PATH.
|
| 78 |
+
Returns dict with avg_mae and mae_per_timepoint.
|
| 79 |
+
"""
|
| 80 |
+
if dataset_path is None:
|
| 81 |
+
dataset_path = str(
|
| 82 |
+
Path(__file__).parent / "training_data" / "retention_dataset.json"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
with open(dataset_path, "r", encoding="utf-8") as f:
|
| 86 |
+
data = json.load(f)
|
| 87 |
+
|
| 88 |
+
extractor = FeatureExtractor(cultural_kb_path=cultural_kb_path)
|
| 89 |
+
|
| 90 |
+
X: List[List[float]] = []
|
| 91 |
+
y: List[List[float]] = []
|
| 92 |
+
skipped = 0
|
| 93 |
+
for sample in data["samples"]:
|
| 94 |
+
try:
|
| 95 |
+
features = extractor.extract(
|
| 96 |
+
sample["script_text"], sample["platform"], sample["region"]
|
| 97 |
+
)
|
| 98 |
+
vec = features.to_vector()
|
| 99 |
+
if any(v != v for v in vec): # NaN check
|
| 100 |
+
skipped += 1
|
| 101 |
+
continue
|
| 102 |
+
X.append(vec)
|
| 103 |
+
y.append(sample["retention_curve"])
|
| 104 |
+
except Exception:
|
| 105 |
+
skipped += 1
|
| 106 |
+
|
| 107 |
+
if not X:
|
| 108 |
+
raise RuntimeError("No valid training samples extracted.")
|
| 109 |
+
|
| 110 |
+
X_arr = np.array(X, dtype=float)
|
| 111 |
+
y_arr = np.array(y, dtype=float)
|
| 112 |
+
|
| 113 |
+
n = len(X_arr)
|
| 114 |
+
rng = np.random.RandomState(42)
|
| 115 |
+
idx = rng.permutation(n)
|
| 116 |
+
split = max(1, int(n * 0.8))
|
| 117 |
+
X_train, X_val = X_arr[idx[:split]], X_arr[idx[split:]]
|
| 118 |
+
y_train, y_val = y_arr[idx[:split]], y_arr[idx[split:]]
|
| 119 |
+
|
| 120 |
+
self.model.fit(X_train, y_train)
|
| 121 |
+
self._trained = True
|
| 122 |
+
|
| 123 |
+
val_preds = np.clip(self.model.predict(X_val), 0.0, 1.0)
|
| 124 |
+
mae_per_tp = np.mean(np.abs(val_preds - y_val), axis=0).tolist()
|
| 125 |
+
avg_mae = float(np.mean(mae_per_tp))
|
| 126 |
+
|
| 127 |
+
print(f" Trained on {len(X_train)} samples, validated on {len(X_val)} (skipped {skipped})")
|
| 128 |
+
print(" Train/Val MAE per timepoint:")
|
| 129 |
+
for t, mae in zip(CURVE_TIMEPOINTS, mae_per_tp):
|
| 130 |
+
print(f" {t:2d}s: {mae:.4f}")
|
| 131 |
+
print(f" Avg MAE: {avg_mae:.4f}")
|
| 132 |
+
|
| 133 |
+
joblib.dump(self.model, _MODEL_PATH)
|
| 134 |
+
print(f" Model saved to {_MODEL_PATH}")
|
| 135 |
+
|
| 136 |
+
return {"avg_mae": avg_mae, "mae_per_timepoint": mae_per_tp}
|
| 137 |
+
|
| 138 |
+
def predict(self, features: ScriptFeatures) -> RetentionCurve:
|
| 139 |
+
if not self._trained:
|
| 140 |
+
raise RuntimeError("Model not trained. Run train() first.")
|
| 141 |
+
vec = np.array(features.to_vector(), dtype=float).reshape(1, -1)
|
| 142 |
+
raw = self.model.predict(vec)[0]
|
| 143 |
+
clipped = np.clip(raw, 0.0, 1.0)
|
| 144 |
+
values = self._enforce_monotonic_decrease(clipped).tolist()
|
| 145 |
+
return RetentionCurve.from_values(values)
|
| 146 |
+
|
| 147 |
+
@staticmethod
|
| 148 |
+
def _enforce_monotonic_decrease(values: np.ndarray) -> np.ndarray:
|
| 149 |
+
result = values.copy()
|
| 150 |
+
for i in range(1, len(result)):
|
| 151 |
+
result[i] = min(result[i], result[i - 1])
|
| 152 |
+
return result
|
viral_script_engine/retention/curve_scorer.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.retention.curve_predictor import CURVE_TIMEPOINTS, RetentionCurve
|
| 6 |
+
|
| 7 |
+
_TP_INDEX = {t: i for i, t in enumerate(CURVE_TIMEPOINTS)}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class CurveScorerResult(BaseModel):
|
| 11 |
+
final_score: float
|
| 12 |
+
overall_improvement: float
|
| 13 |
+
targeted_improvement: float
|
| 14 |
+
regression_penalty: float
|
| 15 |
+
improved_timepoints: List[int]
|
| 16 |
+
worsened_timepoints: List[int]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class RetentionCurveScorer:
|
| 20 |
+
"""
|
| 21 |
+
Scores improvement between two retention curves.
|
| 22 |
+
|
| 23 |
+
Rewards targeted improvements at action-relevant timepoints:
|
| 24 |
+
- hook_rewrite → early timepoints (0–6s)
|
| 25 |
+
- section_reorder → mid timepoints (10–20s)
|
| 26 |
+
- cultural_ref_sub → mid-to-late (15–30s)
|
| 27 |
+
- cta_placement → late timepoints (45–60s)
|
| 28 |
+
|
| 29 |
+
Formula:
|
| 30 |
+
final = 0.50 * overall_improvement
|
| 31 |
+
+ 0.35 * targeted_improvement
|
| 32 |
+
- 0.15 * regression_penalty
|
| 33 |
+
clipped to [0, 1]
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
ACTION_CURVE_MAP = {
|
| 37 |
+
"hook_rewrite": [0, 3, 6],
|
| 38 |
+
"section_reorder": [10, 15, 20],
|
| 39 |
+
"cultural_ref_sub": [15, 20, 25, 30],
|
| 40 |
+
"cta_placement": [45, 60],
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
def score(
|
| 44 |
+
self,
|
| 45 |
+
original_curve: RetentionCurve,
|
| 46 |
+
new_curve: RetentionCurve,
|
| 47 |
+
action_type: str,
|
| 48 |
+
) -> CurveScorerResult:
|
| 49 |
+
orig_auc = original_curve.area_under_curve
|
| 50 |
+
new_auc = new_curve.area_under_curve
|
| 51 |
+
|
| 52 |
+
# 1. Overall AUC improvement (relative)
|
| 53 |
+
if orig_auc > 0:
|
| 54 |
+
overall_improvement = (new_auc - orig_auc) / orig_auc
|
| 55 |
+
else:
|
| 56 |
+
overall_improvement = float(new_auc)
|
| 57 |
+
overall_improvement = max(-1.0, min(1.0, overall_improvement))
|
| 58 |
+
|
| 59 |
+
# 2. Targeted improvement at action-relevant timepoints
|
| 60 |
+
target_tps = self.ACTION_CURVE_MAP.get(str(action_type), CURVE_TIMEPOINTS)
|
| 61 |
+
targeted_deltas: List[float] = []
|
| 62 |
+
for tp in target_tps:
|
| 63 |
+
i = _TP_INDEX.get(tp)
|
| 64 |
+
if i is not None and i < len(original_curve.values) and i < len(new_curve.values):
|
| 65 |
+
targeted_deltas.append(new_curve.values[i] - original_curve.values[i])
|
| 66 |
+
|
| 67 |
+
if targeted_deltas:
|
| 68 |
+
targeted_improvement = float(sum(targeted_deltas) / len(targeted_deltas))
|
| 69 |
+
else:
|
| 70 |
+
targeted_improvement = 0.0
|
| 71 |
+
targeted_improvement = max(-1.0, min(1.0, targeted_improvement))
|
| 72 |
+
|
| 73 |
+
# 3. Regression penalty: any timepoint that degraded
|
| 74 |
+
improved: List[int] = []
|
| 75 |
+
worsened: List[int] = []
|
| 76 |
+
worsened_magnitudes: List[float] = []
|
| 77 |
+
|
| 78 |
+
for tp, i in _TP_INDEX.items():
|
| 79 |
+
if i >= len(original_curve.values) or i >= len(new_curve.values):
|
| 80 |
+
continue
|
| 81 |
+
delta = new_curve.values[i] - original_curve.values[i]
|
| 82 |
+
if delta > 0.001:
|
| 83 |
+
improved.append(tp)
|
| 84 |
+
elif delta < -0.001:
|
| 85 |
+
worsened.append(tp)
|
| 86 |
+
worsened_magnitudes.append(abs(delta))
|
| 87 |
+
|
| 88 |
+
regression_penalty = 0.0
|
| 89 |
+
if worsened_magnitudes:
|
| 90 |
+
regression_penalty = min(1.0, sum(worsened_magnitudes) / len(CURVE_TIMEPOINTS))
|
| 91 |
+
|
| 92 |
+
final_score = (
|
| 93 |
+
0.50 * max(0.0, overall_improvement)
|
| 94 |
+
+ 0.35 * max(0.0, targeted_improvement)
|
| 95 |
+
- 0.15 * regression_penalty
|
| 96 |
+
)
|
| 97 |
+
final_score = max(0.0, min(1.0, final_score))
|
| 98 |
+
|
| 99 |
+
return CurveScorerResult(
|
| 100 |
+
final_score=round(final_score, 4),
|
| 101 |
+
overall_improvement=round(overall_improvement, 4),
|
| 102 |
+
targeted_improvement=round(targeted_improvement, 4),
|
| 103 |
+
regression_penalty=round(regression_penalty, 4),
|
| 104 |
+
improved_timepoints=improved,
|
| 105 |
+
worsened_timepoints=worsened,
|
| 106 |
+
)
|
viral_script_engine/retention/feature_extractor.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import re
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
from viral_script_engine.platforms.platform_spec import PlatformRegistry
|
| 9 |
+
|
| 10 |
+
_FILLER_PHRASES = [
|
| 11 |
+
"hey guys", "welcome back", "today i want to", "so today",
|
| 12 |
+
"in this video", "what's up everyone", "hey everyone",
|
| 13 |
+
"guys today", "hello everyone", "so basically", "you know",
|
| 14 |
+
"kind of", "sort of", "basically", "um ", "uh ",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
_COMMON_WORDS = {
|
| 18 |
+
'i', 'the', 'a', 'an', 'my', 'your', 'its', 'it', 'is', 'are',
|
| 19 |
+
'was', 'were', 'be', 'been', "i've", "i'm", "it's", "here's",
|
| 20 |
+
'today', 'and', 'but', 'so', 'that', 'this', 'these', 'those',
|
| 21 |
+
'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
|
| 22 |
+
'or', 'not', 'you', 'we', 'they', 'he', 'she', 'if', 'do',
|
| 23 |
+
'get', 'just', 'up', 'out', 'about', 'what', 'all', 'some',
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
_PROMISE_PATTERNS = [
|
| 27 |
+
r'\d',
|
| 28 |
+
r'\bhow to\b',
|
| 29 |
+
r'\bwhy\b',
|
| 30 |
+
r'\bwhat happens when\b',
|
| 31 |
+
r'\bi made\b',
|
| 32 |
+
r'\bwill\b',
|
| 33 |
+
r'\bguaranteed\b',
|
| 34 |
+
r'\bstep\b',
|
| 35 |
+
r'\btips?\b',
|
| 36 |
+
r'\bsecrets?\b',
|
| 37 |
+
r'\bprove[sd]?\b',
|
| 38 |
+
r'\bhere\'?s\b',
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
_KNOWN_PLATFORMS = ["Reels", "Shorts", "Feed", "TikTok"]
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class ScriptFeatures(BaseModel):
|
| 45 |
+
# Hook features (predicts early drop-off 0–5s)
|
| 46 |
+
hook_word_count: int
|
| 47 |
+
hook_has_number: bool
|
| 48 |
+
hook_has_question: bool
|
| 49 |
+
hook_has_promise: bool
|
| 50 |
+
hook_filler_score: float # 0=no filler, 1=all filler
|
| 51 |
+
|
| 52 |
+
# Pacing features (predicts mid-video retention 5–30s)
|
| 53 |
+
avg_words_per_sentence: float
|
| 54 |
+
sentence_count: int
|
| 55 |
+
short_sentence_ratio: float # sentences < 8 words / total sentences
|
| 56 |
+
section_balance_score: float # how evenly hook:body:cta matches platform spec
|
| 57 |
+
|
| 58 |
+
# Content features (predicts late retention 30s+)
|
| 59 |
+
specificity_score: float # ratio of specific nouns/numbers to total words
|
| 60 |
+
cultural_ref_count: int
|
| 61 |
+
cta_position_ratio: float # word offset of CTA start / total words
|
| 62 |
+
|
| 63 |
+
# Platform fit features
|
| 64 |
+
platform: str
|
| 65 |
+
word_count: int
|
| 66 |
+
length_vs_optimal: float # word_count / optimal_script_length for platform
|
| 67 |
+
|
| 68 |
+
def to_vector(self) -> List[float]:
|
| 69 |
+
platform_one_hot = [1.0 if self.platform == p else 0.0 for p in _KNOWN_PLATFORMS]
|
| 70 |
+
return [
|
| 71 |
+
float(self.hook_word_count),
|
| 72 |
+
1.0 if self.hook_has_number else 0.0,
|
| 73 |
+
1.0 if self.hook_has_question else 0.0,
|
| 74 |
+
1.0 if self.hook_has_promise else 0.0,
|
| 75 |
+
float(self.hook_filler_score),
|
| 76 |
+
float(self.avg_words_per_sentence),
|
| 77 |
+
float(self.sentence_count),
|
| 78 |
+
float(self.short_sentence_ratio),
|
| 79 |
+
float(self.section_balance_score),
|
| 80 |
+
float(self.specificity_score),
|
| 81 |
+
float(self.cultural_ref_count),
|
| 82 |
+
float(self.cta_position_ratio),
|
| 83 |
+
float(self.word_count),
|
| 84 |
+
float(self.length_vs_optimal),
|
| 85 |
+
] + platform_one_hot
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class FeatureExtractor:
|
| 89 |
+
def __init__(self, cultural_kb_path: Optional[str] = None):
|
| 90 |
+
self.platform_registry = PlatformRegistry()
|
| 91 |
+
self._cultural_kb_path = cultural_kb_path
|
| 92 |
+
self._cultural_kb: Optional[dict] = None
|
| 93 |
+
|
| 94 |
+
def _load_kb(self) -> None:
|
| 95 |
+
if self._cultural_kb is not None:
|
| 96 |
+
return
|
| 97 |
+
kb_path = self._cultural_kb_path or str(
|
| 98 |
+
Path(__file__).parent.parent / "data" / "cultural_kb.json"
|
| 99 |
+
)
|
| 100 |
+
with open(kb_path, "r", encoding="utf-8") as f:
|
| 101 |
+
self._cultural_kb = json.load(f)
|
| 102 |
+
|
| 103 |
+
def extract(self, script: str, platform: str, region: str) -> ScriptFeatures:
|
| 104 |
+
self._load_kb()
|
| 105 |
+
spec = self.platform_registry.get(platform)
|
| 106 |
+
sentences = [s for s in re.split(r'(?<=[.!?])\s+', script.strip()) if s.strip()]
|
| 107 |
+
if not sentences:
|
| 108 |
+
sentences = [script]
|
| 109 |
+
total_words = len(script.split())
|
| 110 |
+
|
| 111 |
+
# --- Hook: first ~20% of sentences (min 1, max 3) ---
|
| 112 |
+
n = len(sentences)
|
| 113 |
+
hook_end = max(1, min(3, int(n * 0.2))) if n >= 5 else max(1, min(2, n))
|
| 114 |
+
hook_text = " ".join(sentences[:hook_end])
|
| 115 |
+
hook_lower = hook_text.lower()
|
| 116 |
+
hook_words = hook_text.split()
|
| 117 |
+
|
| 118 |
+
hook_word_count = len(hook_words)
|
| 119 |
+
hook_has_number = bool(re.search(r'\d', hook_text))
|
| 120 |
+
hook_has_question = '?' in hook_text
|
| 121 |
+
hook_has_promise = any(re.search(p, hook_lower) for p in _PROMISE_PATTERNS)
|
| 122 |
+
|
| 123 |
+
filler_hits = sum(1 for phrase in _FILLER_PHRASES if phrase in hook_lower)
|
| 124 |
+
hook_filler_score = min(1.0, filler_hits / max(hook_word_count, 1) * 4)
|
| 125 |
+
|
| 126 |
+
# --- Pacing ---
|
| 127 |
+
sentence_count = n
|
| 128 |
+
words_per_sent = [len(s.split()) for s in sentences]
|
| 129 |
+
avg_words_per_sentence = sum(words_per_sent) / max(n, 1)
|
| 130 |
+
short_sentence_ratio = sum(1 for w in words_per_sent if w < 8) / max(n, 1)
|
| 131 |
+
|
| 132 |
+
# Section balance: compare actual word distribution to platform spec
|
| 133 |
+
cta_start_idx = max(hook_end + 1, n - max(1, int(n * 0.1)))
|
| 134 |
+
hook_w = sum(len(s.split()) for s in sentences[:hook_end])
|
| 135 |
+
body_w = sum(len(s.split()) for s in sentences[hook_end:cta_start_idx])
|
| 136 |
+
cta_w = sum(len(s.split()) for s in sentences[cta_start_idx:])
|
| 137 |
+
total_w = max(hook_w + body_w + cta_w, 1)
|
| 138 |
+
|
| 139 |
+
opt = spec.optimal_sentences_per_section
|
| 140 |
+
opt_total = max(sum(opt.values()), 1)
|
| 141 |
+
opt_hook_r = opt.get("hook", 2) / opt_total
|
| 142 |
+
opt_body_r = opt.get("body", 6) / opt_total
|
| 143 |
+
act_hook_r = hook_w / total_w
|
| 144 |
+
act_body_r = body_w / total_w
|
| 145 |
+
balance_dev = (abs(act_hook_r - opt_hook_r) + abs(act_body_r - opt_body_r)) / 2
|
| 146 |
+
section_balance_score = max(0.0, 1.0 - balance_dev * 4)
|
| 147 |
+
|
| 148 |
+
# --- Content features ---
|
| 149 |
+
words = script.split()
|
| 150 |
+
specific_count = sum(
|
| 151 |
+
1 for w in words
|
| 152 |
+
if (
|
| 153 |
+
re.search(r'\d', w)
|
| 154 |
+
or (len(w) > 1 and w[0].isupper() and w.lower().strip('.,!?;:\'"') not in _COMMON_WORDS)
|
| 155 |
+
)
|
| 156 |
+
)
|
| 157 |
+
specificity_score = min(1.0, specific_count / max(total_words, 1))
|
| 158 |
+
|
| 159 |
+
cultural_ref_count = 0
|
| 160 |
+
if self._cultural_kb and region in self._cultural_kb:
|
| 161 |
+
kb = self._cultural_kb[region]
|
| 162 |
+
script_lower = script.lower()
|
| 163 |
+
cultural_ref_count = (
|
| 164 |
+
sum(1 for r in kb.get("valid_refs", []) if r.lower() in script_lower)
|
| 165 |
+
+ sum(1 for i in kb.get("correct_idioms", []) if i.lower() in script_lower)
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
cta_word_offset = hook_w + body_w
|
| 169 |
+
cta_position_ratio = cta_word_offset / max(total_words, 1)
|
| 170 |
+
|
| 171 |
+
# --- Platform fit ---
|
| 172 |
+
length_vs_optimal = total_words / max(spec.optimal_script_length_words, 1)
|
| 173 |
+
|
| 174 |
+
return ScriptFeatures(
|
| 175 |
+
hook_word_count=hook_word_count,
|
| 176 |
+
hook_has_number=hook_has_number,
|
| 177 |
+
hook_has_question=hook_has_question,
|
| 178 |
+
hook_has_promise=hook_has_promise,
|
| 179 |
+
hook_filler_score=round(hook_filler_score, 4),
|
| 180 |
+
avg_words_per_sentence=round(avg_words_per_sentence, 4),
|
| 181 |
+
sentence_count=sentence_count,
|
| 182 |
+
short_sentence_ratio=round(short_sentence_ratio, 4),
|
| 183 |
+
section_balance_score=round(section_balance_score, 4),
|
| 184 |
+
specificity_score=round(specificity_score, 4),
|
| 185 |
+
cultural_ref_count=cultural_ref_count,
|
| 186 |
+
cta_position_ratio=round(cta_position_ratio, 4),
|
| 187 |
+
platform=platform,
|
| 188 |
+
word_count=total_words,
|
| 189 |
+
length_vs_optimal=round(length_vs_optimal, 4),
|
| 190 |
+
)
|
viral_script_engine/retention/model.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4454048ccec161cf3db43cc3b02a2767b261e8f3df23cdab2a6904ec47eb2e00
|
| 3 |
+
size 1876741
|
viral_script_engine/retention/training_data/__init__.py
ADDED
|
File without changes
|
viral_script_engine/retention/training_data/build_dataset.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Builds retention_dataset.json from rule-based simulation.
|
| 3 |
+
|
| 4 |
+
Encoding known relationships between script quality scores and viewer retention:
|
| 5 |
+
- Hook quality (R1) predicts early drop-off at seconds 0–6
|
| 6 |
+
- Coherence (R2) predicts mid-video retention at seconds 6–20
|
| 7 |
+
- Cultural alignment (R3) predicts late retention at seconds 20–60
|
| 8 |
+
|
| 9 |
+
Dataset format:
|
| 10 |
+
{
|
| 11 |
+
"samples": [
|
| 12 |
+
{
|
| 13 |
+
"script_id": "train_001",
|
| 14 |
+
"script_text": "...",
|
| 15 |
+
"platform": "Reels",
|
| 16 |
+
"region": "Mumbai Gen Z",
|
| 17 |
+
"retention_curve": [1.0, 0.95, ...], # 10 values at seconds [0,3,6,10,15,20,25,30,45,60]
|
| 18 |
+
"curve_source": "rule_based",
|
| 19 |
+
"quality_tier": "high" | "medium" | "low"
|
| 20 |
+
}
|
| 21 |
+
]
|
| 22 |
+
}
|
| 23 |
+
"""
|
| 24 |
+
import json
|
| 25 |
+
import random
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import List, Tuple
|
| 28 |
+
|
| 29 |
+
_TIMEPOINTS = [0, 3, 6, 10, 15, 20, 25, 30, 45, 60]
|
| 30 |
+
_OUTPUT_PATH = Path(__file__).parent / "retention_dataset.json"
|
| 31 |
+
|
| 32 |
+
_PLATFORMS = ["Reels", "Shorts", "Feed", "TikTok"]
|
| 33 |
+
_REGIONS = ["Mumbai Gen Z", "pan_india_english", "delhi_millennial", "bangalore_tech"]
|
| 34 |
+
|
| 35 |
+
_HIGH_SCRIPTS = [
|
| 36 |
+
"Did you know {pct}% of people get this wrong? Here's what actually works. Stop doing what everyone tells you. Use this one simple method instead. The results will genuinely surprise you. Comment 'yes' if you want the full breakdown.",
|
| 37 |
+
"I made {amt}k in 30 days using this strategy. Nobody in your feed is talking about this. Here's exactly what I did step by step. You can start tonight with zero investment. Follow for the full guide.",
|
| 38 |
+
"Your phone is lying to you about money. Here's the truth about compound interest that banks don't want you to know. Start with just $50. Watch what happens after 12 months. This changed everything for me.",
|
| 39 |
+
"Stop scrolling. This is the {amt}-second trick that saved me {pct}% on every bill. I tested it for 3 months. Here's the proof. Save this before it gets taken down.",
|
| 40 |
+
"Why do {pct}% of people fail at saving money? I spent 6 months finding out. The answer surprised me. It has nothing to do with income. Watch till the end for the fix.",
|
| 41 |
+
"The {amt} investing mistake I made at 22 cost me {pct}k. Here's what I wish someone told me. Three rules that actually work. No BS, no courses to sell. Just what changed my life.",
|
| 42 |
+
"How to pay off debt {pct}% faster using the avalanche method. Most people use the wrong strategy. This is the math-backed approach. Takes 5 minutes to set up. Start today.",
|
| 43 |
+
"This bank trick gives you {pct}% more interest — your bank doesn't advertise it. Took me {amt} months to find it. Here's exactly how to set it up in under 2 minutes.",
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
_MED_SCRIPTS = [
|
| 47 |
+
"So today I want to talk about something that I think is really important for a lot of people. Financial planning is something that many people overlook. You should really try to save money regularly if you can. It makes a big difference over time when you think about it.",
|
| 48 |
+
"Hey everyone, welcome back to my channel. Today I'm sharing some tips about managing your finances better. These tips have helped me personally and I hope they help you too. Let me know in the comments what you think about them.",
|
| 49 |
+
"Saving money is actually not that hard once you get into the habit. The main thing is consistency. Try to set aside a fixed amount each month. Over time it really does add up significantly. There are several ways you can approach this.",
|
| 50 |
+
"I've been thinking a lot about financial health lately. It's something that affects everyone. The basics are pretty simple when you break them down. Budget, save, invest — in that order. Most people skip the middle step which is a mistake.",
|
| 51 |
+
"Money management is a skill that anyone can learn. It takes time and practice but it's worth it. Start by tracking your spending for one month. Then identify areas where you can cut back. After that you can start building your savings.",
|
| 52 |
+
]
|
| 53 |
+
|
| 54 |
+
_LOW_SCRIPTS = [
|
| 55 |
+
"Hello guys welcome back um so today basically I wanted to kind of talk about you know like finances and stuff. So basically what I mean is um you should save more money I guess. That's kind of the main point I think. Um yeah so basically just try to do that.",
|
| 56 |
+
"Hey everyone so basically today's video is about money and financial things. I mean you know like it's really important and stuff like that. So yeah basically just save money I guess. Um anyway thanks for watching and stuff.",
|
| 57 |
+
"So um welcome back to my channel. Today I kind of want to sort of discuss um financial things you know. Like basically everyone knows they should save money right. Um so yeah that's basically it I think. Like just try to be better with money or whatever.",
|
| 58 |
+
"Hey guys so um today we're going to talk about kind of like money and finances and all that stuff. So basically um the thing is you know it's pretty important I think. Like I don't know just try to save more I guess. Um yeah so basically that's the main thing.",
|
| 59 |
+
"Welcome back everyone so today basically I wanted to kind of share some thoughts on um financial stuff. Like you know it's important and everything. So basically just try to you know manage your money better or something like that. Um yeah I hope that helps.",
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _pick_script(quality: str) -> str:
|
| 64 |
+
amt = random.randint(10, 99)
|
| 65 |
+
pct = random.randint(60, 98)
|
| 66 |
+
if quality == "high":
|
| 67 |
+
template = random.choice(_HIGH_SCRIPTS)
|
| 68 |
+
elif quality == "medium":
|
| 69 |
+
template = random.choice(_MED_SCRIPTS)
|
| 70 |
+
else:
|
| 71 |
+
template = random.choice(_LOW_SCRIPTS)
|
| 72 |
+
return template.format(amt=amt, pct=pct)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _quality_to_scores(quality: str) -> Tuple[float, float, float]:
|
| 76 |
+
if quality == "high":
|
| 77 |
+
r1 = random.uniform(0.78, 1.0)
|
| 78 |
+
r2 = random.uniform(0.72, 1.0)
|
| 79 |
+
r3 = random.uniform(0.68, 1.0)
|
| 80 |
+
elif quality == "medium":
|
| 81 |
+
r1 = random.uniform(0.38, 0.65)
|
| 82 |
+
r2 = random.uniform(0.38, 0.65)
|
| 83 |
+
r3 = random.uniform(0.38, 0.65)
|
| 84 |
+
else:
|
| 85 |
+
r1 = random.uniform(0.05, 0.25)
|
| 86 |
+
r2 = random.uniform(0.08, 0.28)
|
| 87 |
+
r3 = random.uniform(0.08, 0.28)
|
| 88 |
+
return r1, r2, r3
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def _generate_curve(r1: float, r2: float, r3: float) -> List[float]:
|
| 92 |
+
"""
|
| 93 |
+
Rule-based retention curve at timepoints [0, 3, 6, 10, 15, 20, 25, 30, 45, 60].
|
| 94 |
+
|
| 95 |
+
Rules from phase spec:
|
| 96 |
+
s0 = 1.0
|
| 97 |
+
s3 = 1.0 - (0.4 * (1 - r1)) # hook quality predicts early drop
|
| 98 |
+
s10 = prev - (0.1 * (1 - r2)) # coherence predicts mid-video
|
| 99 |
+
s20 = prev - (0.15 * (1 - r3)) # cultural alignment predicts late
|
| 100 |
+
s60 = prev - 0.05 # natural decay always present
|
| 101 |
+
"""
|
| 102 |
+
noise = lambda lo, hi: random.uniform(lo, hi)
|
| 103 |
+
|
| 104 |
+
s0 = 1.0
|
| 105 |
+
s3 = max(0.0, 1.0 - (0.4 * (1 - r1)) + noise(-0.02, 0.02))
|
| 106 |
+
s6 = max(0.0, s3 - noise(0.02, 0.06))
|
| 107 |
+
s10 = max(0.0, s6 - (0.1 * (1 - r2)) - noise(0.0, 0.03))
|
| 108 |
+
s15 = max(0.0, s10 - noise(0.03, 0.07))
|
| 109 |
+
s20 = max(0.0, s15 - (0.15 * (1 - r3)) - noise(0.0, 0.03))
|
| 110 |
+
s25 = max(0.0, s20 - noise(0.02, 0.05))
|
| 111 |
+
s30 = max(0.0, s25 - noise(0.02, 0.05))
|
| 112 |
+
s45 = max(0.0, s30 - 0.05 - noise(0.0, 0.04))
|
| 113 |
+
s60 = max(0.0, s45 - 0.05)
|
| 114 |
+
|
| 115 |
+
# enforce monotonic decrease
|
| 116 |
+
curve = [s0, s3, s6, s10, s15, s20, s25, s30, s45, s60]
|
| 117 |
+
for i in range(1, len(curve)):
|
| 118 |
+
curve[i] = min(curve[i], curve[i - 1])
|
| 119 |
+
return [round(v, 3) for v in curve]
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def build(output_path: str = None, seed: int = 42) -> str:
|
| 123 |
+
"""Build and save the dataset. Returns the path written."""
|
| 124 |
+
random.seed(seed)
|
| 125 |
+
path = Path(output_path) if output_path else _OUTPUT_PATH
|
| 126 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 127 |
+
|
| 128 |
+
samples = []
|
| 129 |
+
idx = 1
|
| 130 |
+
for quality, count in [("high", 50), ("medium", 50), ("low", 50)]:
|
| 131 |
+
for _ in range(count):
|
| 132 |
+
platform = random.choice(_PLATFORMS)
|
| 133 |
+
region = random.choice(_REGIONS)
|
| 134 |
+
r1, r2, r3 = _quality_to_scores(quality)
|
| 135 |
+
curve = _generate_curve(r1, r2, r3)
|
| 136 |
+
samples.append({
|
| 137 |
+
"script_id": f"train_{idx:03d}",
|
| 138 |
+
"script_text": _pick_script(quality),
|
| 139 |
+
"platform": platform,
|
| 140 |
+
"region": region,
|
| 141 |
+
"retention_curve": curve,
|
| 142 |
+
"curve_source": "rule_based",
|
| 143 |
+
"quality_tier": quality,
|
| 144 |
+
"r1_score": round(r1, 3),
|
| 145 |
+
"r2_score": round(r2, 3),
|
| 146 |
+
"r3_score": round(r3, 3),
|
| 147 |
+
})
|
| 148 |
+
idx += 1
|
| 149 |
+
|
| 150 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 151 |
+
json.dump({"samples": samples}, f, indent=2)
|
| 152 |
+
|
| 153 |
+
return str(path)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
if __name__ == "__main__":
|
| 157 |
+
out = build()
|
| 158 |
+
print(f"Dataset built: {out} (150 samples)")
|
viral_script_engine/retention/training_data/retention_dataset.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
viral_script_engine/rewards/r10_retention_curve.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from viral_script_engine.retention.curve_predictor import RetentionCurve, RetentionCurvePredictor
|
| 6 |
+
from viral_script_engine.retention.curve_scorer import CurveScorerResult, RetentionCurveScorer
|
| 7 |
+
from viral_script_engine.retention.feature_extractor import FeatureExtractor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class RetentionRewardResult(BaseModel):
|
| 11 |
+
score: float
|
| 12 |
+
original_curve: RetentionCurve
|
| 13 |
+
new_curve: RetentionCurve
|
| 14 |
+
curve_delta: CurveScorerResult
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RetentionCurveReward:
|
| 18 |
+
"""
|
| 19 |
+
Wraps the full retention prediction + scoring pipeline into a reward signal.
|
| 20 |
+
|
| 21 |
+
Caches the original curve per episode so the extractor is called only once
|
| 22 |
+
for the original script — subsequent steps reuse the cached curve.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, cultural_kb_path: Optional[str] = None):
|
| 26 |
+
self.extractor = FeatureExtractor(cultural_kb_path=cultural_kb_path)
|
| 27 |
+
self.predictor = RetentionCurvePredictor()
|
| 28 |
+
self.scorer = RetentionCurveScorer()
|
| 29 |
+
self._original_curve_cache: dict = {}
|
| 30 |
+
|
| 31 |
+
def score(
|
| 32 |
+
self,
|
| 33 |
+
original_script: str,
|
| 34 |
+
rewritten_script: str,
|
| 35 |
+
platform: str,
|
| 36 |
+
region: str,
|
| 37 |
+
action_type: str,
|
| 38 |
+
episode_id: str,
|
| 39 |
+
) -> RetentionRewardResult:
|
| 40 |
+
# Cache original curve — compute only once per episode
|
| 41 |
+
if episode_id not in self._original_curve_cache:
|
| 42 |
+
orig_features = self.extractor.extract(original_script, platform, region)
|
| 43 |
+
self._original_curve_cache[episode_id] = self.predictor.predict(orig_features)
|
| 44 |
+
|
| 45 |
+
new_features = self.extractor.extract(rewritten_script, platform, region)
|
| 46 |
+
new_curve = self.predictor.predict(new_features)
|
| 47 |
+
|
| 48 |
+
result = self.scorer.score(
|
| 49 |
+
original_curve=self._original_curve_cache[episode_id],
|
| 50 |
+
new_curve=new_curve,
|
| 51 |
+
action_type=action_type,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return RetentionRewardResult(
|
| 55 |
+
score=result.final_score,
|
| 56 |
+
original_curve=self._original_curve_cache[episode_id],
|
| 57 |
+
new_curve=new_curve,
|
| 58 |
+
curve_delta=result,
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def clear_cache(self, episode_id: Optional[str] = None) -> None:
|
| 62 |
+
if episode_id:
|
| 63 |
+
self._original_curve_cache.pop(episode_id, None)
|
| 64 |
+
else:
|
| 65 |
+
self._original_curve_cache.clear()
|
viral_script_engine/rewards/reward_aggregator.py
CHANGED
|
@@ -12,6 +12,7 @@ _COMPONENT_FIELDS = [
|
|
| 12 |
"r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
|
| 13 |
"r4_debate_resolution", "r5_defender_preservation",
|
| 14 |
"r6_safety", "r7_originality", "r8_persona_fit", "r9_platform_pacing",
|
|
|
|
| 15 |
]
|
| 16 |
|
| 17 |
_DROP_THRESHOLD = 0.25
|
|
|
|
| 12 |
"r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
|
| 13 |
"r4_debate_resolution", "r5_defender_preservation",
|
| 14 |
"r6_safety", "r7_originality", "r8_persona_fit", "r9_platform_pacing",
|
| 15 |
+
"r10_retention_curve",
|
| 16 |
]
|
| 17 |
|
| 18 |
_DROP_THRESHOLD = 0.25
|
viral_script_engine/tests/test_phase12.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Phase 12 tests — Retention Curve Simulator."""
|
| 2 |
+
import json
|
| 3 |
+
import sys
|
| 4 |
+
import tempfile
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from unittest.mock import MagicMock, patch
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import pytest
|
| 10 |
+
|
| 11 |
+
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
| 12 |
+
|
| 13 |
+
from viral_script_engine.retention.feature_extractor import (
|
| 14 |
+
FeatureExtractor,
|
| 15 |
+
ScriptFeatures,
|
| 16 |
+
_KNOWN_PLATFORMS,
|
| 17 |
+
)
|
| 18 |
+
from viral_script_engine.retention.curve_predictor import (
|
| 19 |
+
RetentionCurve,
|
| 20 |
+
RetentionCurvePredictor,
|
| 21 |
+
CURVE_TIMEPOINTS,
|
| 22 |
+
)
|
| 23 |
+
from viral_script_engine.retention.curve_scorer import RetentionCurveScorer
|
| 24 |
+
from viral_script_engine.rewards.r10_retention_curve import RetentionCurveReward
|
| 25 |
+
|
| 26 |
+
_SCRIPTS_PATH = str(
|
| 27 |
+
Path(__file__).parent.parent / "data" / "test_scripts" / "scripts.json"
|
| 28 |
+
)
|
| 29 |
+
_CULTURAL_KB_PATH = str(
|
| 30 |
+
Path(__file__).parent.parent / "data" / "cultural_kb.json"
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
_GOOD_SCRIPT = (
|
| 34 |
+
"Did you know 80% of people get this wrong? Here's what actually works. "
|
| 35 |
+
"Stop doing what everyone tells you. Use this one simple method instead. "
|
| 36 |
+
"The results will surprise you. Follow for more."
|
| 37 |
+
)
|
| 38 |
+
_BAD_SCRIPT = (
|
| 39 |
+
"Hello guys welcome back um so today basically I wanted to kind of talk "
|
| 40 |
+
"about you know like finances and stuff. So basically just try to save money."
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# FeatureExtractor tests
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
def test_feature_extractor_produces_correct_features():
|
| 49 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 50 |
+
features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 51 |
+
|
| 52 |
+
assert isinstance(features, ScriptFeatures)
|
| 53 |
+
assert features.hook_word_count > 0
|
| 54 |
+
assert features.sentence_count > 0
|
| 55 |
+
assert features.word_count > 0
|
| 56 |
+
assert features.platform == "Reels"
|
| 57 |
+
assert features.hook_has_number is True # "80%"
|
| 58 |
+
assert features.hook_has_question is True # "?"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_feature_extractor_bad_script_has_high_filler():
|
| 62 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 63 |
+
features = extractor.extract(_BAD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 64 |
+
|
| 65 |
+
# Bad script should have higher filler score than good script
|
| 66 |
+
good_features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 67 |
+
assert features.hook_filler_score >= good_features.hook_filler_score
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_to_vector_returns_flat_numeric_list():
|
| 71 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 72 |
+
features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 73 |
+
vec = features.to_vector()
|
| 74 |
+
|
| 75 |
+
assert isinstance(vec, list)
|
| 76 |
+
assert len(vec) > 0
|
| 77 |
+
# No NaN values
|
| 78 |
+
for v in vec:
|
| 79 |
+
assert v == v, f"NaN found in vector: {vec}"
|
| 80 |
+
# All values are floats
|
| 81 |
+
for v in vec:
|
| 82 |
+
assert isinstance(v, (int, float))
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def test_to_vector_platform_one_hot():
|
| 86 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 87 |
+
for platform in _KNOWN_PLATFORMS:
|
| 88 |
+
features = extractor.extract(_GOOD_SCRIPT, platform=platform, region="pan_india_english")
|
| 89 |
+
vec = features.to_vector()
|
| 90 |
+
# Last N elements are one-hot platform encoding
|
| 91 |
+
platform_slice = vec[-len(_KNOWN_PLATFORMS):]
|
| 92 |
+
assert sum(platform_slice) == 1.0, f"One-hot sum should be 1 for {platform}"
|
| 93 |
+
assert max(platform_slice) == 1.0
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def test_to_vector_no_nan_for_bad_script():
|
| 97 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 98 |
+
features = extractor.extract(_BAD_SCRIPT, platform="TikTok", region="pan_india_english")
|
| 99 |
+
vec = features.to_vector()
|
| 100 |
+
for v in vec:
|
| 101 |
+
assert v == v, f"NaN found in vector"
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
# RetentionCurvePredictor tests
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
+
|
| 108 |
+
def test_predictor_raises_if_not_trained():
|
| 109 |
+
predictor = RetentionCurvePredictor.__new__(RetentionCurvePredictor)
|
| 110 |
+
predictor.model = None
|
| 111 |
+
predictor._trained = False
|
| 112 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 113 |
+
features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 114 |
+
|
| 115 |
+
with pytest.raises(RuntimeError, match="not trained"):
|
| 116 |
+
predictor.predict(features)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def _make_trained_predictor() -> RetentionCurvePredictor:
|
| 120 |
+
"""Train predictor on a minimal in-memory dataset."""
|
| 121 |
+
from sklearn.ensemble import GradientBoostingRegressor
|
| 122 |
+
from sklearn.multioutput import MultiOutputRegressor
|
| 123 |
+
import numpy as np
|
| 124 |
+
|
| 125 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 126 |
+
scripts = [_GOOD_SCRIPT, _BAD_SCRIPT] * 10
|
| 127 |
+
platforms = ["Reels", "TikTok", "Shorts", "Feed"] * 5
|
| 128 |
+
X, y = [], []
|
| 129 |
+
for i, (sc, pl) in enumerate(zip(scripts, platforms)):
|
| 130 |
+
feat = extractor.extract(sc, platform=pl, region="pan_india_english")
|
| 131 |
+
X.append(feat.to_vector())
|
| 132 |
+
quality = 1.0 if sc == _GOOD_SCRIPT else 0.3
|
| 133 |
+
curve = [max(0.0, quality - j * 0.05) for j in range(len(CURVE_TIMEPOINTS))]
|
| 134 |
+
y.append(curve)
|
| 135 |
+
|
| 136 |
+
model = MultiOutputRegressor(
|
| 137 |
+
GradientBoostingRegressor(n_estimators=10, max_depth=2, random_state=42)
|
| 138 |
+
)
|
| 139 |
+
model.fit(np.array(X), np.array(y))
|
| 140 |
+
|
| 141 |
+
predictor = RetentionCurvePredictor.__new__(RetentionCurvePredictor)
|
| 142 |
+
predictor.model = model
|
| 143 |
+
predictor._trained = True
|
| 144 |
+
return predictor
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def test_predicted_curve_is_monotonically_non_increasing():
|
| 148 |
+
predictor = _make_trained_predictor()
|
| 149 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 150 |
+
features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 151 |
+
curve = predictor.predict(features)
|
| 152 |
+
|
| 153 |
+
for i in range(1, len(curve.values)):
|
| 154 |
+
assert curve.values[i] <= curve.values[i - 1] + 1e-9, (
|
| 155 |
+
f"Curve not monotonic at index {i}: {curve.values[i - 1]} -> {curve.values[i]}"
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def test_predicted_curve_values_in_range():
|
| 160 |
+
predictor = _make_trained_predictor()
|
| 161 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 162 |
+
features = extractor.extract(_BAD_SCRIPT, platform="TikTok", region="pan_india_english")
|
| 163 |
+
curve = predictor.predict(features)
|
| 164 |
+
|
| 165 |
+
for v in curve.values:
|
| 166 |
+
assert 0.0 <= v <= 1.0, f"Value {v} out of [0, 1]"
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def test_predicted_curve_has_correct_timepoints():
|
| 170 |
+
predictor = _make_trained_predictor()
|
| 171 |
+
extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 172 |
+
features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
|
| 173 |
+
curve = predictor.predict(features)
|
| 174 |
+
|
| 175 |
+
assert curve.timepoints == CURVE_TIMEPOINTS
|
| 176 |
+
assert len(curve.values) == len(CURVE_TIMEPOINTS)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# ---------------------------------------------------------------------------
|
| 180 |
+
# RetentionCurveScorer tests
|
| 181 |
+
# ---------------------------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
def _make_curve(values: list) -> RetentionCurve:
|
| 184 |
+
return RetentionCurve.from_values(values)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def test_scorer_rewards_targeted_improvement():
|
| 188 |
+
scorer = RetentionCurveScorer()
|
| 189 |
+
# hook_rewrite targets [0, 3, 6] — improve those timepoints
|
| 190 |
+
orig_values = [1.0, 0.6, 0.5, 0.45, 0.42, 0.40, 0.38, 0.36, 0.32, 0.30]
|
| 191 |
+
new_values = [1.0, 0.85, 0.75, 0.45, 0.42, 0.40, 0.38, 0.36, 0.32, 0.30]
|
| 192 |
+
|
| 193 |
+
result = scorer.score(
|
| 194 |
+
original_curve=_make_curve(orig_values),
|
| 195 |
+
new_curve=_make_curve(new_values),
|
| 196 |
+
action_type="hook_rewrite",
|
| 197 |
+
)
|
| 198 |
+
assert result.final_score > 0
|
| 199 |
+
assert result.targeted_improvement > 0
|
| 200 |
+
assert 3 in result.improved_timepoints or 6 in result.improved_timepoints
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
def test_scorer_applies_regression_penalty_for_worsening():
|
| 204 |
+
scorer = RetentionCurveScorer()
|
| 205 |
+
orig_values = [1.0, 0.9, 0.8, 0.7, 0.65, 0.60, 0.55, 0.50, 0.45, 0.40]
|
| 206 |
+
# Worsen the mid-video section
|
| 207 |
+
new_values = [1.0, 0.9, 0.8, 0.5, 0.45, 0.40, 0.55, 0.50, 0.45, 0.40]
|
| 208 |
+
|
| 209 |
+
result = scorer.score(
|
| 210 |
+
original_curve=_make_curve(orig_values),
|
| 211 |
+
new_curve=_make_curve(new_values),
|
| 212 |
+
action_type="hook_rewrite",
|
| 213 |
+
)
|
| 214 |
+
assert result.regression_penalty > 0
|
| 215 |
+
assert len(result.worsened_timepoints) > 0
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def test_scorer_score_in_range():
|
| 219 |
+
scorer = RetentionCurveScorer()
|
| 220 |
+
orig_values = [1.0, 0.8, 0.7, 0.6, 0.55, 0.50, 0.46, 0.42, 0.38, 0.35]
|
| 221 |
+
new_values = [1.0, 0.85, 0.75, 0.65, 0.60, 0.55, 0.50, 0.46, 0.42, 0.38]
|
| 222 |
+
|
| 223 |
+
result = scorer.score(
|
| 224 |
+
original_curve=_make_curve(orig_values),
|
| 225 |
+
new_curve=_make_curve(new_values),
|
| 226 |
+
action_type="section_reorder",
|
| 227 |
+
)
|
| 228 |
+
assert 0.0 <= result.final_score <= 1.0
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# ---------------------------------------------------------------------------
|
| 232 |
+
# RetentionCurveReward — cache test
|
| 233 |
+
# ---------------------------------------------------------------------------
|
| 234 |
+
|
| 235 |
+
def test_retention_reward_caches_original_curve():
|
| 236 |
+
"""FeatureExtractor.extract should be called only once for the original script per episode."""
|
| 237 |
+
predictor = _make_trained_predictor()
|
| 238 |
+
reward = RetentionCurveReward.__new__(RetentionCurveReward)
|
| 239 |
+
reward.extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
|
| 240 |
+
reward.predictor = predictor
|
| 241 |
+
reward.scorer = RetentionCurveScorer()
|
| 242 |
+
reward._original_curve_cache = {}
|
| 243 |
+
|
| 244 |
+
call_count = {"n": 0}
|
| 245 |
+
original_extract = reward.extractor.extract
|
| 246 |
+
|
| 247 |
+
def counting_extract(script, platform, region):
|
| 248 |
+
call_count["n"] += 1
|
| 249 |
+
return original_extract(script, platform, region)
|
| 250 |
+
|
| 251 |
+
reward.extractor.extract = counting_extract
|
| 252 |
+
|
| 253 |
+
episode_id = "ep_cache_test"
|
| 254 |
+
for _ in range(3):
|
| 255 |
+
reward.score(
|
| 256 |
+
original_script=_GOOD_SCRIPT,
|
| 257 |
+
rewritten_script=_BAD_SCRIPT,
|
| 258 |
+
platform="Reels",
|
| 259 |
+
region="pan_india_english",
|
| 260 |
+
action_type="hook_rewrite",
|
| 261 |
+
episode_id=episode_id,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
# extract called for original once + rewritten on every call = 1 + 3 = 4
|
| 265 |
+
# original is cached after first call → only 1 for original, 3 for rewritten = 4 total
|
| 266 |
+
assert call_count["n"] == 4, (
|
| 267 |
+
f"Expected 4 extract calls (1 original cached + 3 rewritten), got {call_count['n']}"
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ---------------------------------------------------------------------------
|
| 272 |
+
# env.step includes r10 in reward components
|
| 273 |
+
# ---------------------------------------------------------------------------
|
| 274 |
+
|
| 275 |
+
def test_env_step_includes_r10_when_model_trained():
|
| 276 |
+
"""env.step() should include r10_retention_curve in reward components when model is trained."""
|
| 277 |
+
from viral_script_engine.environment.env import ViralScriptEnv
|
| 278 |
+
from unittest.mock import MagicMock
|
| 279 |
+
|
| 280 |
+
env = ViralScriptEnv(
|
| 281 |
+
scripts_path=_SCRIPTS_PATH,
|
| 282 |
+
cultural_kb_path=_CULTURAL_KB_PATH,
|
| 283 |
+
difficulty="easy",
|
| 284 |
+
use_escalation=False,
|
| 285 |
+
use_anti_gaming=False,
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# Inject trained predictor
|
| 289 |
+
predictor = _make_trained_predictor()
|
| 290 |
+
env.r10.predictor = predictor
|
| 291 |
+
|
| 292 |
+
obs, _ = env.reset()
|
| 293 |
+
|
| 294 |
+
mock_critique = MagicMock()
|
| 295 |
+
mock_critique.claims = []
|
| 296 |
+
mock_critique.overall_severity = "low"
|
| 297 |
+
|
| 298 |
+
mock_defender = MagicMock()
|
| 299 |
+
mock_defender.core_strength = "Strong hook"
|
| 300 |
+
mock_defender.core_strength_quote = "Test quote"
|
| 301 |
+
mock_defender.defense_argument = "Good"
|
| 302 |
+
mock_defender.flagged_critic_claims = []
|
| 303 |
+
mock_defender.regional_voice_elements = []
|
| 304 |
+
mock_defender.model_dump.return_value = {}
|
| 305 |
+
|
| 306 |
+
mock_rewrite = MagicMock()
|
| 307 |
+
mock_rewrite.rewritten_script = obs["current_script"]
|
| 308 |
+
mock_rewrite.diff = ""
|
| 309 |
+
|
| 310 |
+
with patch.object(env.critic, "critique", return_value=mock_critique), \
|
| 311 |
+
patch.object(env.defender, "defend", return_value=mock_defender), \
|
| 312 |
+
patch.object(env.rewriter, "rewrite", return_value=mock_rewrite):
|
| 313 |
+
|
| 314 |
+
_, _, _, _, info = env.step({
|
| 315 |
+
"action_type": "hook_rewrite",
|
| 316 |
+
"target_section": "hook",
|
| 317 |
+
"instruction": "Strengthen the hook.",
|
| 318 |
+
"critique_claim_id": "C1",
|
| 319 |
+
"reasoning": "test",
|
| 320 |
+
})
|
| 321 |
+
|
| 322 |
+
rc = info["reward_components"]
|
| 323 |
+
assert "r10_retention_curve" in rc
|
| 324 |
+
assert rc["r10_retention_curve"] is not None
|
| 325 |
+
assert 0.0 <= rc["r10_retention_curve"] <= 1.0
|