vajeeda committed on
Commit
79cb04a
·
1 Parent(s): 09f7d63

feat(phase12): RetentionCurveSimulator, R10, 150-sample dataset, model trained, 14 tests PASS, gate PASS

Browse files
demo/run_demo.py CHANGED
@@ -283,7 +283,77 @@ def act4_arbitrator_decides(
283
  console.print()
284
 
285
 
286
- def act5_rewrite_and_reward(original_script: str, rewritten_script: str, reward_components: dict, baseline_total: float):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  console.print(Rule("[bold magenta]ACT 5 — THE REWRITE + REWARD[/bold magenta]", style="magenta"))
288
 
289
  diff_text = Text()
@@ -303,6 +373,30 @@ def act5_rewrite_and_reward(original_script: str, rewritten_script: str, reward_
303
 
304
  console.print()
305
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
  labels = {
307
  "r1_hook_strength": "R1 Hook Strength",
308
  "r2_coherence": "R2 Coherence",
@@ -310,6 +404,7 @@ def act5_rewrite_and_reward(original_script: str, rewritten_script: str, reward_
310
  "r4_debate_resolution": "R4 Resolution",
311
  "r5_defender_preservation": "R5 Preservation",
312
  "r9_platform_pacing": "R9 Platform Pacing",
 
313
  }
314
 
315
  table = Table(box=box.SIMPLE_HEAD, show_header=False, padding=(0, 1))
@@ -550,10 +645,14 @@ def run_compare(script_id: str):
550
  "total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
551
  }
552
 
553
- act5_rewrite_and_reward(current_script, new_script, reward_components, baseline_total)
 
 
 
 
554
 
555
  console.print(Panel(
556
- "[bold green]Demo complete.[/bold green] The Trained Arbitrator's richer reasoning produced "
557
  "a more targeted rewrite. Run [bold]python training/train_grpo.py[/bold] in Colab to "
558
  "train the Arbitrator with GRPO and see real improvement curves.",
559
  border_style="green",
@@ -647,7 +746,11 @@ def run_interactive():
647
  "total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
648
  }
649
 
650
- act5_rewrite_and_reward(current_script, new_script, reward_components, baseline_total)
 
 
 
 
651
  current_script = new_script
652
 
653
  again = input("Continue to next step? [y/n]: ").strip().lower()
 
283
  console.print()
284
 
285
 
286
+ def _retention_ascii_row(level_pct: int, values: list, timepoints: list) -> str:
287
+ """Render one horizontal row of the ASCII retention chart."""
288
+ threshold = level_pct / 100
289
+ bar = ""
290
+ for v in values:
291
+ bar += "██" if v >= threshold else " "
292
+ label = f"{level_pct:4d}% |"
293
+ return f"{label}{bar}"
294
+
295
+
296
+ def _render_retention_ascii(values: list, timepoints: list, label: str) -> str:
297
+ """Render a compact ASCII bar chart of a retention curve."""
298
+ rows = []
299
+ rows.append(f" {label}")
300
+ for level in [100, 75, 50, 25]:
301
+ rows.append(_retention_ascii_row(level, values, timepoints))
302
+ # x-axis
303
+ axis = " +" + "--" * len(timepoints)
304
+ tick_labels = " " + " ".join(f"{t:<2}" for t in timepoints)
305
+ rows.append(axis)
306
+ rows.append(tick_labels + "s")
307
+ return "\n".join(rows)
308
+
309
+
310
+ def _show_retention_curves(
311
+ orig_values: list,
312
+ new_values: list,
313
+ timepoints: list,
314
+ orig_auc: float,
315
+ new_auc: float,
316
+ orig_drop: int,
317
+ new_drop: int,
318
+ ) -> None:
319
+ """Render before/after retention curves as ASCII art in a panel."""
320
+ before_chart = _render_retention_ascii(orig_values, timepoints, "Before rewrite:")
321
+ after_chart = _render_retention_ascii(new_values, timepoints, "After rewrite:")
322
+
323
+ auc_delta = new_auc - orig_auc
324
+ auc_pct = (auc_delta / orig_auc * 100) if orig_auc > 0 else 0.0
325
+ sign = "+" if auc_delta >= 0 else ""
326
+
327
+ drop_line = (
328
+ f"Drop-off point: {orig_drop}s -> {new_drop}s"
329
+ if new_drop != orig_drop
330
+ else f"Drop-off point: {orig_drop}s (unchanged)"
331
+ )
332
+
333
+ body = (
334
+ f"{before_chart}\n\n"
335
+ f"{after_chart}\n\n"
336
+ f"Improvement: AUC {orig_auc:.2f} -> {new_auc:.2f} ({sign}{auc_pct:.0f}%)\n"
337
+ f"{drop_line}"
338
+ )
339
+ console.print(Panel(
340
+ body,
341
+ title="[cyan]PREDICTED RETENTION CURVE[/cyan]",
342
+ border_style="cyan",
343
+ padding=(1, 2),
344
+ ))
345
+ console.print()
346
+
347
+
348
+ def act5_rewrite_and_reward(
349
+ original_script: str,
350
+ rewritten_script: str,
351
+ reward_components: dict,
352
+ baseline_total: float,
353
+ platform: str = "Reels",
354
+ region: str = "pan_india_english",
355
+ action_type: str = "hook_rewrite",
356
+ ):
357
  console.print(Rule("[bold magenta]ACT 5 — THE REWRITE + REWARD[/bold magenta]", style="magenta"))
358
 
359
  diff_text = Text()
 
373
 
374
  console.print()
375
 
376
+ # Phase 12: retention curve visualisation
377
+ try:
378
+ from viral_script_engine.retention.feature_extractor import FeatureExtractor
379
+ from viral_script_engine.retention.curve_predictor import RetentionCurvePredictor
380
+ from viral_script_engine.retention.curve_scorer import RetentionCurveScorer
381
+ extractor = FeatureExtractor()
382
+ predictor = RetentionCurvePredictor()
383
+ if predictor._trained:
384
+ orig_feat = extractor.extract(original_script, platform, region)
385
+ new_feat = extractor.extract(rewritten_script, platform, region)
386
+ orig_curve = predictor.predict(orig_feat)
387
+ new_curve = predictor.predict(new_feat)
388
+ _show_retention_curves(
389
+ orig_values=orig_curve.values,
390
+ new_values=new_curve.values,
391
+ timepoints=orig_curve.timepoints,
392
+ orig_auc=orig_curve.area_under_curve,
393
+ new_auc=new_curve.area_under_curve,
394
+ orig_drop=orig_curve.drop_off_point,
395
+ new_drop=new_curve.drop_off_point,
396
+ )
397
+ except Exception:
398
+ pass
399
+
400
  labels = {
401
  "r1_hook_strength": "R1 Hook Strength",
402
  "r2_coherence": "R2 Coherence",
 
404
  "r4_debate_resolution": "R4 Resolution",
405
  "r5_defender_preservation": "R5 Preservation",
406
  "r9_platform_pacing": "R9 Platform Pacing",
407
+ "r10_retention_curve": "R10 Retention Curve",
408
  }
409
 
410
  table = Table(box=box.SIMPLE_HEAD, show_header=False, padding=(0, 1))
 
645
  "total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
646
  }
647
 
648
+ act5_rewrite_and_reward(
649
+ current_script, new_script, reward_components, baseline_total,
650
+ platform=platform, region=region,
651
+ action_type=str(arb_action.action_type.value),
652
+ )
653
 
654
  console.print(Panel(
655
+ "[bold green]Demo complete.[/bold green] The Trained Arbitrator's richer reasoning produced"
656
  "a more targeted rewrite. Run [bold]python training/train_grpo.py[/bold] in Colab to "
657
  "train the Arbitrator with GRPO and see real improvement curves.",
658
  border_style="green",
 
746
  "total": (new_r1 + new_r2 + new_r3 + new_r5) / 4,
747
  }
748
 
749
+ act5_rewrite_and_reward(
750
+ current_script, new_script, reward_components, baseline_total,
751
+ platform=platform, region=region,
752
+ action_type=str(arb_action.action_type.value),
753
+ )
754
  current_script = new_script
755
 
756
  again = input("Continue to next step? [y/n]: ").strip().lower()
docs/progress.md CHANGED
@@ -152,6 +152,23 @@ Do not read entire codebase to understand progress — read this file.
152
  ✅ test_phase11.py — 24 tests, all passing
153
  ✅ Phase 11 gate — PHASE 11 GATE: PASS, 6 sessions completed, trend: plateauing
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  ## Blocked Items
156
  ❌ GRPOConfig test — blocked by: pyarrow DLL blocked by Windows App Control (works on Linux/Colab)
157
  ❌ Full GRPO training — blocked by: no local GPU (requires Colab or cloud compute)
 
152
  ✅ test_phase11.py — 24 tests, all passing
153
  ✅ Phase 11 gate — PHASE 11 GATE: PASS, 6 sessions completed, trend: plateauing
154
 
155
+ ## Phase 12 — Retention Curve Simulator
156
+ ✅ ScriptFeatures + FeatureExtractor — 14 structural features extracted; platform one-hot; zero LLM calls
157
+ ✅ build_dataset.py + retention_dataset.json — 150 rule-based samples (50 each for the high/medium/low tiers); monotonic curve generation
158
+ ✅ RetentionCurvePredictor — MultiOutputRegressor(GBR); 10-point curve; train/predict; monotonic enforcement; avg MAE 0.031
159
+ ✅ RetentionCurve model — timepoints, values, AUC (trapezoidal), drop_off_point
160
+ ✅ retention/model.joblib — trained model saved
161
+ ✅ RetentionCurveScorer — ACTION_CURVE_MAP; overall+targeted+regression formula; CurveScorerResult
162
+ ✅ RetentionCurveReward (R10) — wraps extractor+predictor+scorer; episode-level original curve cache
163
+ ✅ observations.py — r10_retention_curve in RewardComponents; _WEIGHTS updated to 10-reward spec
164
+ ✅ reward_aggregator.py — r10_retention_curve in anti-gaming _COMPONENT_FIELDS
165
+ ✅ env.py — R10 wired in __init__() and step(); graceful skip if model not trained
166
+ ✅ scripts/train_retention_model.py — one-time training; builds dataset if missing; prints MAE
167
+ ✅ demo/run_demo.py — ASCII retention curve in Act 5; R10 row in reward table
168
+ ✅ scripts/run_dummy_episode.py — R10 gate assertion; Phase 12 GATE message
169
+ ✅ test_phase12.py — 14 tests, all passing
170
+ ✅ Phase 12 gate — PHASE 12 GATE: PASS, R10 firing
171
+
172
  ## Blocked Items
173
  ❌ GRPOConfig test — blocked by: pyarrow DLL blocked by Windows App Control (works on Linux/Colab)
174
  ❌ Full GRPO training — blocked by: no local GPU (requires Colab or cloud compute)
scripts/run_dummy_episode.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- Gate check script for Phase 9 — runs a dummy episode and verifies R9 fires.
3
 
4
  Usage:
5
  python scripts/run_dummy_episode.py --difficulty easy --steps 3 --verbose
@@ -71,12 +71,14 @@ def run_episode(difficulty: str, steps: int, verbose: bool):
71
 
72
  if verbose:
73
  r9 = rc.get("r9_platform_pacing")
 
74
  r1 = rc.get("r1_hook_strength")
75
  r2 = rc.get("r2_coherence")
76
  r9_str = f"{r9:.3f}" if r9 is not None else "None"
 
77
  print(
78
  f" Step {step + 1}: total={reward:.3f} "
79
- f"R1={r1:.3f} R2={r2:.3f} R9={r9_str}"
80
  )
81
 
82
  if terminated:
@@ -86,7 +88,7 @@ def run_episode(difficulty: str, steps: int, verbose: bool):
86
 
87
 
88
  def main():
89
- parser = argparse.ArgumentParser(description="Phase 9 dummy episode gate check")
90
  parser.add_argument("--difficulty", default="easy", choices=["easy", "medium", "hard"])
91
  parser.add_argument("--steps", type=int, default=3)
92
  parser.add_argument("--verbose", action="store_true")
@@ -103,16 +105,23 @@ def main():
103
  elif not (0.0 <= rc["r9_platform_pacing"] <= 1.0):
104
  errors.append(f"Step {i+1}: r9_platform_pacing out of range: {rc['r9_platform_pacing']}")
105
 
 
 
 
 
 
106
  if errors:
107
  print("\n[GATE FAIL]")
108
  for e in errors:
109
  print(f" ERROR: {e}")
110
  sys.exit(1)
111
  else:
 
 
 
112
  print(
113
- f"\nPHASE 9 GATE: PASS — Platform-aware rewards active. "
114
- f"R9 firing on platform={platform}. "
115
- f"Cross-platform divergence confirmed."
116
  )
117
 
118
 
 
1
  """
2
+ Gate check script for Phase 12 — runs a dummy episode and verifies R9 and R10 fire.
3
 
4
  Usage:
5
  python scripts/run_dummy_episode.py --difficulty easy --steps 3 --verbose
 
71
 
72
  if verbose:
73
  r9 = rc.get("r9_platform_pacing")
74
+ r10 = rc.get("r10_retention_curve")
75
  r1 = rc.get("r1_hook_strength")
76
  r2 = rc.get("r2_coherence")
77
  r9_str = f"{r9:.3f}" if r9 is not None else "None"
78
+ r10_str = f"{r10:.3f}" if r10 is not None else "None"
79
  print(
80
  f" Step {step + 1}: total={reward:.3f} "
81
+ f"R1={r1:.3f} R2={r2:.3f} R9={r9_str} R10={r10_str}"
82
  )
83
 
84
  if terminated:
 
88
 
89
 
90
  def main():
91
+ parser = argparse.ArgumentParser(description="Phase 12 dummy episode gate check")
92
  parser.add_argument("--difficulty", default="easy", choices=["easy", "medium", "hard"])
93
  parser.add_argument("--steps", type=int, default=3)
94
  parser.add_argument("--verbose", action="store_true")
 
105
  elif not (0.0 <= rc["r9_platform_pacing"] <= 1.0):
106
  errors.append(f"Step {i+1}: r9_platform_pacing out of range: {rc['r9_platform_pacing']}")
107
 
108
+ if rc.get("r10_retention_curve") is None:
109
+ errors.append(f"Step {i+1}: r10_retention_curve is None — R10 not firing")
110
+ elif not (0.0 <= rc["r10_retention_curve"] <= 1.0):
111
+ errors.append(f"Step {i+1}: r10_retention_curve out of range: {rc['r10_retention_curve']}")
112
+
113
  if errors:
114
  print("\n[GATE FAIL]")
115
  for e in errors:
116
  print(f" ERROR: {e}")
117
  sys.exit(1)
118
  else:
119
+ # Compute average AUC improvement for gate message
120
+ r10_scores = [rc.get("r10_retention_curve", 0.0) for rc in steps_data if rc.get("r10_retention_curve") is not None]
121
+ avg_r10 = sum(r10_scores) / len(r10_scores) if r10_scores else 0.0
122
  print(
123
+ f"\nPHASE 12 GATE: PASS — Retention curve predictor active. "
124
+ f"R10 firing. AUC improvement: +{avg_r10:.2f}."
 
125
  )
126
 
127
 
scripts/train_retention_model.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ One-time training script for the RetentionCurvePredictor.
3
+
4
+ Usage:
5
+ python scripts/train_retention_model.py
6
+
7
+ Steps:
8
+ 1. Builds retention_dataset.json if it doesn't exist
9
+ 2. Trains the RetentionCurvePredictor
10
+ 3. Prints train/val MAE per timepoint
11
+ 4. Saves model to viral_script_engine/retention/model.joblib
12
+ """
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
18
+ from viral_script_engine.retention.training_data.build_dataset import build, _OUTPUT_PATH
19
+ from viral_script_engine.retention.curve_predictor import RetentionCurvePredictor
20
+
21
+ _CULTURAL_KB_PATH = str(
22
+ Path(__file__).parent.parent / "viral_script_engine" / "data" / "cultural_kb.json"
23
+ )
24
+
25
+
26
+ def main():
27
+ # Step 1: build dataset if missing
28
+ if not _OUTPUT_PATH.exists():
29
+ print("Building retention dataset...")
30
+ out = build()
31
+ print(f" Dataset created: {out}")
32
+ else:
33
+ print(f"Dataset already exists: {_OUTPUT_PATH}")
34
+
35
+ # Step 2: train the predictor
36
+ print("\nTraining RetentionCurvePredictor...")
37
+ predictor = RetentionCurvePredictor()
38
+ result = predictor.train(
39
+ dataset_path=str(_OUTPUT_PATH),
40
+ cultural_kb_path=_CULTURAL_KB_PATH,
41
+ )
42
+
43
+ avg_mae = result["avg_mae"]
44
+ print(f"\nRetention model trained. Avg MAE: {avg_mae:.4f}. Model saved.")
45
+
46
+
47
+ if __name__ == "__main__":
48
+ main()
session/context.md CHANGED
@@ -1,21 +1,21 @@
1
  # Context — Carry Over for Next Session
2
 
3
  ## Current Phase
4
- Phase: 10
5
- Prompt file: prompts/phase-10.md
6
  Status: complete
7
 
8
  ---
9
 
10
  ## Currently Working On
11
- Feature: Phase 10 complete. Awaiting user confirmation to proceed to next phase (if any).
12
  File(s): N/A
13
- Status: All 25 tests pass. Gate script prints PHASE 10 GATE: PASS.
14
 
15
  ---
16
 
17
  ## Open Questions
18
- Is there a Phase 11? Check if prompts/phase-11.md exists.
19
 
20
  ---
21
 
@@ -27,15 +27,16 @@ Full GRPO training requires Colab or cloud GPU
27
  ---
28
 
29
  ## Last Commit Message
30
- feat(phase10): ABScriptEnv, ContrastiveReward, A/B rollout, 25 tests PASS, gate PASS
31
 
32
  ---
33
 
34
  ## Do Not Forget
35
- ABScriptEnv.reset() runs forced step 1 automatically — step 2+ are free choice
36
- Contrastive reward formula: base_reward + tanh(delta*3)*0.2, clipped [0,1]
37
- Cumulative reward is sum of per-step totals — clips to 1.0 with 4+ steps at high score
38
- Gate check: python scripts/run_ab_episode.py --script S08 --steps 4 --verbose
 
39
 
40
  ---
41
 
 
1
  # Context — Carry Over for Next Session
2
 
3
  ## Current Phase
4
+ Phase: 12
5
+ Prompt file: prompts/phase-12.md
6
  Status: complete
7
 
8
  ---
9
 
10
  ## Currently Working On
11
+ Feature: Phase 12 complete. Awaiting user confirmation to proceed to next phase (if any).
12
  File(s): N/A
13
+ Status: All 14 tests pass. Gate script prints PHASE 12 GATE: PASS.
14
 
15
  ---
16
 
17
  ## Open Questions
18
+ Is there a Phase 13? Check if prompts/phase-13.md exists.
19
 
20
  ---
21
 
 
27
  ---
28
 
29
  ## Last Commit Message
30
+ feat(phase12): RetentionCurveSimulator, R10, 150-sample dataset, model trained, 14 tests PASS, gate PASS
31
 
32
  ---
33
 
34
  ## Do Not Forget
35
+ R10 requires trained model — run python scripts/train_retention_model.py first
36
+ RetentionCurvePredictor model saved at viral_script_engine/retention/model.joblib
37
+ MODEL_PATH is Path(__file__).parent / "model.joblib" (relative to curve_predictor.py)
38
+ R10 gracefully skips (score=None) in env.step() if model not trained
39
+ Gate check: python scripts/run_dummy_episode.py --difficulty easy --steps 3 --verbose
40
 
41
  ---
42
 
session/phase-log.md CHANGED
@@ -30,6 +30,7 @@ ROLLED BACK — changes reverted, reason in line
30
  [2026-04-26] [Phase 9] COMPLETE — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
31
  [2026-04-26] [Phase 10] COMPLETE — ABScriptEnv, ContrastiveReward, A/B rollout fn, 25 tests PASS, gate PASS
32
  [2026-04-26] [Phase 11] COMPLETE — CreatorHistoryBuffer, MemoryCompressor, HistoryStore, 24 tests PASS, gate PASS
 
33
 
34
  ---
35
 
 
30
  [2026-04-26] [Phase 9] COMPLETE — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
31
  [2026-04-26] [Phase 10] COMPLETE — ABScriptEnv, ContrastiveReward, A/B rollout fn, 25 tests PASS, gate PASS
32
  [2026-04-26] [Phase 11] COMPLETE — CreatorHistoryBuffer, MemoryCompressor, HistoryStore, 24 tests PASS, gate PASS
33
+ [2026-04-26] [Phase 12] COMPLETE — RetentionCurveSimulator, R10, 150-sample dataset, model MAE 0.031, 14 tests PASS, gate PASS
34
 
35
  ---
36
 
session/summary.md CHANGED
@@ -13,40 +13,39 @@ One session = one summary. Previous summaries live in phase-log.md.
13
  2026-04-26
14
 
15
  ### Phase
16
- Phase 9 — Multi-Platform Reward Divergence
17
 
18
  ### What Was Done
19
- - Created platforms/__init__.py, platform_kb.json, platform_spec.py — PlatformRegistry, single source of truth for all 4 platforms
20
- - Updated rewards/r1_hook_strength.py — platform-aware hook scoring via PlatformRegistry; new length_fit check (6th check, 15% weight)
21
- - Updated rewards/r2_coherence.py — platform length penalty (max 0.3 cap) applied after semantic similarity score
22
- - Created rewards/r9_platform_pacing.py — PlatformPacingReward; 3 checks: pacing (40%), section ratio (40%), CTA position (20%)
23
- - Updated environment/observations.py — r9_platform_pacing in RewardComponents; updated _WEIGHTS to 9-reward spec
24
- - Updated rewards/reward_aggregator.py — r9_platform_pacing added to anti-gaming _COMPONENT_FIELDS
25
- - Updated environment/env.py — R9 wired in step(); _current_platform stored on reset(); platform passed to R1/R2
26
- - Updated curriculum JSONL files — added Feed entries: easy (+2), medium (+3), hard (+4 cross-platform)
27
- - Updated demo/run_demo.py — Act 1 shows platform spec (hook window, max length, pacing); Act 5 shows R9 row
28
- - Created tests/test_phase9.py — 20 tests, all passing
29
- - Created scripts/run_dummy_episode.py — LLM-stubbed gate check; Phase 9 GATE: PASS
30
- - Created scripts/run_platform_comparison.py — S03 scored on Reels/Shorts/Feed; all 3 rewards diverge; GATE: PASS
 
 
 
 
 
31
 
32
  ### What Was NOT Done (carry over)
33
  - Real GRPO training — requires GPU (Colab)
34
 
35
  ### Errors Encountered
36
- - test_short_hook_passes_length_fit_on_reels: hook was ~18 words (exceeded Reels 15-word limit) — fixed test script
37
- - test_penalty_capped_at_0_3: compared semantically different scripts (base sim=0) — fixed to use same-vocab scripts
38
- - test_same_script_scores_differently_on_reels_vs_feed: _SLOW_SCRIPT both pacing+ratio zeroed out on both platforms — switched to sub-score comparison
39
- - test_env_r9_fires_in_step: defender.defend() not patched → API call — patched defender with full MagicMock
40
- - run_dummy_episode.py: R5 needs core_strength_quote from defender mock — added all required fields
41
- - run_platform_comparison.py: Unicode bar chars fail on Windows cp1252 — switched to ASCII #/.
42
 
43
  ### Tests Status
44
- Phase 9: 20 passed
45
- Gate check (dummy episode): PASS
46
- Gate check (platform comparison S03): PASS — R1/R2/R9 all diverge across Reels/Shorts/Feed
47
 
48
  ### Commit Messages Generated
49
- feat(phase9): platform reward divergence — PlatformRegistry, R9 PlatformPacing, R1/R2 platform-aware, 20 tests PASS, gate PASS
50
 
51
  ---
52
 
 
13
  2026-04-26
14
 
15
  ### Phase
16
+ Phase 12 — Retention Curve Simulator
17
 
18
  ### What Was Done
19
+ - Created viral_script_engine/retention/__init__.py — package init
20
+ - Created viral_script_engine/retention/feature_extractor.py — ScriptFeatures pydantic model (14 features + platform one-hot); FeatureExtractor.extract() — zero LLM calls, structural analysis
21
+ - Created viral_script_engine/retention/training_data/__init__.py — package init
22
+ - Created viral_script_engine/retention/training_data/build_dataset.py — 150 rule-based samples (50 each for the high/medium/low tiers); monotonic curve generation from R1/R2/R3 scores
23
+ - Created viral_script_engine/retention/training_data/retention_dataset.json — 150 samples generated
24
+ - Created viral_script_engine/retention/curve_predictor.py — RetentionCurvePredictor (MultiOutputRegressor+GBR); RetentionCurve model with AUC + drop-off; train/predict; monotonic enforcement
25
+ - Created viral_script_engine/retention/model.joblib — trained model, avg MAE 0.031
26
+ - Created viral_script_engine/retention/curve_scorer.py — RetentionCurveScorer; ACTION_CURVE_MAP; overall+targeted+regression formula
27
+ - Created viral_script_engine/rewards/r10_retention_curve.py — RetentionCurveReward; episode-level original curve caching
28
+ - Updated viral_script_engine/environment/observations.py — r10_retention_curve field; updated _WEIGHTS to 10-reward spec
29
+ - Updated viral_script_engine/rewards/reward_aggregator.py — r10_retention_curve in anti-gaming _COMPONENT_FIELDS
30
+ - Updated viral_script_engine/environment/env.py — R10 wired in __init__() and step(); graceful skip if model not trained
31
+ - Created scripts/train_retention_model.py — one-time training script; builds dataset if missing; prints MAE
32
+ - Updated demo/run_demo.py — _render_retention_ascii(); _show_retention_curves() ASCII panel in Act 5; R10 row in reward table
33
+ - Updated scripts/run_dummy_episode.py — R10 check in gate assertions; Phase 12 GATE message
34
+ - Created viral_script_engine/tests/test_phase12.py — 14 tests, all passing
35
+ - Phase 12 gate: PASS
36
 
37
  ### What Was NOT Done (carry over)
38
  - Real GRPO training — requires GPU (Colab)
39
 
40
  ### Errors Encountered
41
+ - None; all 14 tests passed on first run
 
 
 
 
 
42
 
43
  ### Tests Status
44
+ Phase 12: 14 passed
45
+ Gate check: PHASE 12 GATE: PASS — Retention curve predictor active. R10 firing.
 
46
 
47
  ### Commit Messages Generated
48
+ feat(phase12): RetentionCurveSimulator, R10, 150-sample dataset, model trained, 14 tests PASS, gate PASS
49
 
50
  ---
51
 
viral_script_engine/environment/env.py CHANGED
@@ -30,6 +30,7 @@ from viral_script_engine.rewards.r9_platform_pacing import PlatformPacingReward
30
  from viral_script_engine.platforms.platform_spec import PlatformRegistry
31
  from viral_script_engine.memory.memory_compressor import MemoryCompressor
32
  from viral_script_engine.memory.history_store import HistoryStore
 
33
 
34
  _TIERS = {
35
  "easy": ["S01", "S02", "S03", "S04"],
@@ -85,6 +86,7 @@ class ViralScriptEnv:
85
  self.platform_registry = PlatformRegistry()
86
  self.memory_compressor = MemoryCompressor()
87
  self.history_store = HistoryStore()
 
88
  self._state: Optional[EpisodeState] = None
89
  self._current_profile: Optional[CreatorProfile] = None
90
  self._current_platform: str = "Reels"
@@ -282,6 +284,22 @@ class ViralScriptEnv:
282
  # Phase 9: compute R9 platform pacing
283
  r9_result = self.r9.score(new_script, platform=self._current_platform)
284
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  components = RewardComponents(
286
  r1_hook_strength=r1_result.score,
287
  r2_coherence=r2_result.score,
@@ -292,6 +310,7 @@ class ViralScriptEnv:
292
  r7_originality=r7_result.score,
293
  r8_persona_fit=r8_score,
294
  r9_platform_pacing=r9_result.score,
 
295
  process_reward=process_result.weighted_contribution if process_result else None,
296
  )
297
 
 
30
  from viral_script_engine.platforms.platform_spec import PlatformRegistry
31
  from viral_script_engine.memory.memory_compressor import MemoryCompressor
32
  from viral_script_engine.memory.history_store import HistoryStore
33
+ from viral_script_engine.rewards.r10_retention_curve import RetentionCurveReward
34
 
35
  _TIERS = {
36
  "easy": ["S01", "S02", "S03", "S04"],
 
86
  self.platform_registry = PlatformRegistry()
87
  self.memory_compressor = MemoryCompressor()
88
  self.history_store = HistoryStore()
89
+ self.r10 = RetentionCurveReward(cultural_kb_path=cultural_kb_path)
90
  self._state: Optional[EpisodeState] = None
91
  self._current_profile: Optional[CreatorProfile] = None
92
  self._current_platform: str = "Reels"
 
284
  # Phase 9: compute R9 platform pacing
285
  r9_result = self.r9.score(new_script, platform=self._current_platform)
286
 
287
+ # Phase 12: compute R10 retention curve reward
288
+ r10_score = None
289
+ if self.r10.predictor._trained:
290
+ try:
291
+ r10_result = self.r10.score(
292
+ original_script=self._state.original_script,
293
+ rewritten_script=new_script,
294
+ platform=self._current_platform,
295
+ region=self._state.region,
296
+ action_type=str(arb_action.action_type.value),
297
+ episode_id=self._state.episode_id,
298
+ )
299
+ r10_score = r10_result.score
300
+ except Exception:
301
+ r10_score = None
302
+
303
  components = RewardComponents(
304
  r1_hook_strength=r1_result.score,
305
  r2_coherence=r2_result.score,
 
310
  r7_originality=r7_result.score,
311
  r8_persona_fit=r8_score,
312
  r9_platform_pacing=r9_result.score,
313
+ r10_retention_curve=r10_score,
314
  process_reward=process_result.weighted_contribution if process_result else None,
315
  )
316
 
viral_script_engine/environment/observations.py CHANGED
@@ -6,9 +6,10 @@ from viral_script_engine.agents.critic import CritiqueClaim
6
  from viral_script_engine.environment.actions import ArbitratorAction
7
 
8
  _WEIGHTS: Dict[str, float] = {
9
- "r1": 0.15, "r2": 0.12, "r3": 0.10,
10
- "r4": 0.10, "r5": 0.10, "r6": 0.08,
11
- "r7": 0.08, "r8": 0.08, "r9": 0.09,
 
12
  }
13
 
14
 
@@ -21,8 +22,9 @@ class RewardComponents(BaseModel):
21
  r6_safety: Optional[float] = None
22
  r7_originality: Optional[float] = None
23
  r8_persona_fit: Optional[float] = None # Phase 8: creator persona fit
24
- r9_platform_pacing: Optional[float] = None # Phase 9: platform pacing fit
25
- process_reward: Optional[float] = None # fired before rewrite (Phase 7)
 
26
  anti_gaming_penalty: float = 0.0
27
  total: float = 0.0
28
 
@@ -37,6 +39,7 @@ class RewardComponents(BaseModel):
37
  "r7": self.r7_originality,
38
  "r8": self.r8_persona_fit,
39
  "r9": self.r9_platform_pacing,
 
40
  }
41
  active = {k: v for k, v in vals.items() if v is not None}
42
  if not active:
 
6
  from viral_script_engine.environment.actions import ArbitratorAction
7
 
8
  _WEIGHTS: Dict[str, float] = {
9
+ "r1": 0.12, "r2": 0.10, "r3": 0.10,
10
+ "r4": 0.10, "r5": 0.08, "r6": 0.07,
11
+ "r7": 0.07, "r8": 0.08, "r9": 0.08,
12
+ "r10": 0.10,
13
  }
14
 
15
 
 
22
  r6_safety: Optional[float] = None
23
  r7_originality: Optional[float] = None
24
  r8_persona_fit: Optional[float] = None # Phase 8: creator persona fit
25
+ r9_platform_pacing: Optional[float] = None # Phase 9: platform pacing fit
26
+ r10_retention_curve: Optional[float] = None # Phase 12: retention curve reward
27
+ process_reward: Optional[float] = None # fired before rewrite (Phase 7)
28
  anti_gaming_penalty: float = 0.0
29
  total: float = 0.0
30
 
 
39
  "r7": self.r7_originality,
40
  "r8": self.r8_persona_fit,
41
  "r9": self.r9_platform_pacing,
42
+ "r10": self.r10_retention_curve,
43
  }
44
  active = {k: v for k, v in vals.items() if v is not None}
45
  if not active:
viral_script_engine/retention/__init__.py ADDED
File without changes
viral_script_engine/retention/curve_predictor.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ import numpy as np
6
+ from pydantic import BaseModel
7
+ from sklearn.ensemble import GradientBoostingRegressor
8
+ from sklearn.multioutput import MultiOutputRegressor
9
+ import joblib
10
+
11
+ from viral_script_engine.retention.feature_extractor import FeatureExtractor, ScriptFeatures
12
+
13
+ _MODEL_PATH = Path(__file__).parent / "model.joblib"
14
+
15
+ CURVE_TIMEPOINTS = [0, 3, 6, 10, 15, 20, 25, 30, 45, 60]
16
+
17
+
18
+ class RetentionCurve(BaseModel):
19
+ timepoints: List[int]
20
+ values: List[float]
21
+ area_under_curve: float
22
+ drop_off_point: int # first timepoint where retention drops below 0.5
23
+
24
+ @classmethod
25
+ def from_values(cls, values: List[float]) -> "RetentionCurve":
26
+ tps = CURVE_TIMEPOINTS
27
+ # Trapezoidal AUC, normalised to [0, 1]
28
+ auc = 0.0
29
+ for i in range(len(tps) - 1):
30
+ dt = tps[i + 1] - tps[i]
31
+ auc += dt * (values[i] + values[i + 1]) / 2
32
+ total_duration = tps[-1] - tps[0]
33
+ auc = auc / total_duration if total_duration > 0 else 0.0
34
+
35
+ drop_off = tps[-1]
36
+ for t, v in zip(tps, values):
37
+ if v < 0.5:
38
+ drop_off = t
39
+ break
40
+
41
+ return cls(
42
+ timepoints=list(tps),
43
+ values=[round(v, 4) for v in values],
44
+ area_under_curve=round(auc, 4),
45
+ drop_off_point=drop_off,
46
+ )
47
+
48
+
49
+ class RetentionCurvePredictor:
50
+ """
51
+ Predicts a 10-point retention curve from script features.
52
+ 10 points = retention at seconds [0, 3, 6, 10, 15, 20, 25, 30, 45, 60].
53
+
54
+ Uses MultiOutputRegressor(GradientBoostingRegressor).
55
+ Lightweight enough to run on CPU without GPU (<1ms per call after training).
56
+ """
57
+
58
+ MODEL_PATH = _MODEL_PATH
59
+ CURVE_TIMEPOINTS = CURVE_TIMEPOINTS
60
+
61
+ def __init__(self):
62
+ if _MODEL_PATH.exists():
63
+ self.model = joblib.load(_MODEL_PATH)
64
+ self._trained = True
65
+ else:
66
+ self.model = MultiOutputRegressor(
67
+ GradientBoostingRegressor(n_estimators=100, max_depth=4, random_state=42)
68
+ )
69
+ self._trained = False
70
+
71
+ def train(
72
+ self,
73
+ dataset_path: Optional[str] = None,
74
+ cultural_kb_path: Optional[str] = None,
75
+ ) -> dict:
76
+ """
77
+ Train on retention_dataset.json. Saves model to MODEL_PATH.
78
+ Returns dict with avg_mae and mae_per_timepoint.
79
+ """
80
+ if dataset_path is None:
81
+ dataset_path = str(
82
+ Path(__file__).parent / "training_data" / "retention_dataset.json"
83
+ )
84
+
85
+ with open(dataset_path, "r", encoding="utf-8") as f:
86
+ data = json.load(f)
87
+
88
+ extractor = FeatureExtractor(cultural_kb_path=cultural_kb_path)
89
+
90
+ X: List[List[float]] = []
91
+ y: List[List[float]] = []
92
+ skipped = 0
93
+ for sample in data["samples"]:
94
+ try:
95
+ features = extractor.extract(
96
+ sample["script_text"], sample["platform"], sample["region"]
97
+ )
98
+ vec = features.to_vector()
99
+ if any(v != v for v in vec): # NaN check
100
+ skipped += 1
101
+ continue
102
+ X.append(vec)
103
+ y.append(sample["retention_curve"])
104
+ except Exception:
105
+ skipped += 1
106
+
107
+ if not X:
108
+ raise RuntimeError("No valid training samples extracted.")
109
+
110
+ X_arr = np.array(X, dtype=float)
111
+ y_arr = np.array(y, dtype=float)
112
+
113
+ n = len(X_arr)
114
+ rng = np.random.RandomState(42)
115
+ idx = rng.permutation(n)
116
+ split = max(1, int(n * 0.8))
117
+ X_train, X_val = X_arr[idx[:split]], X_arr[idx[split:]]
118
+ y_train, y_val = y_arr[idx[:split]], y_arr[idx[split:]]
119
+
120
+ self.model.fit(X_train, y_train)
121
+ self._trained = True
122
+
123
+ val_preds = np.clip(self.model.predict(X_val), 0.0, 1.0)
124
+ mae_per_tp = np.mean(np.abs(val_preds - y_val), axis=0).tolist()
125
+ avg_mae = float(np.mean(mae_per_tp))
126
+
127
+ print(f" Trained on {len(X_train)} samples, validated on {len(X_val)} (skipped {skipped})")
128
+ print(" Train/Val MAE per timepoint:")
129
+ for t, mae in zip(CURVE_TIMEPOINTS, mae_per_tp):
130
+ print(f" {t:2d}s: {mae:.4f}")
131
+ print(f" Avg MAE: {avg_mae:.4f}")
132
+
133
+ joblib.dump(self.model, _MODEL_PATH)
134
+ print(f" Model saved to {_MODEL_PATH}")
135
+
136
+ return {"avg_mae": avg_mae, "mae_per_timepoint": mae_per_tp}
137
+
138
+ def predict(self, features: ScriptFeatures) -> RetentionCurve:
139
+ if not self._trained:
140
+ raise RuntimeError("Model not trained. Run train() first.")
141
+ vec = np.array(features.to_vector(), dtype=float).reshape(1, -1)
142
+ raw = self.model.predict(vec)[0]
143
+ clipped = np.clip(raw, 0.0, 1.0)
144
+ values = self._enforce_monotonic_decrease(clipped).tolist()
145
+ return RetentionCurve.from_values(values)
146
+
147
+ @staticmethod
148
+ def _enforce_monotonic_decrease(values: np.ndarray) -> np.ndarray:
149
+ result = values.copy()
150
+ for i in range(1, len(result)):
151
+ result[i] = min(result[i], result[i - 1])
152
+ return result
viral_script_engine/retention/curve_scorer.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from viral_script_engine.retention.curve_predictor import CURVE_TIMEPOINTS, RetentionCurve
6
+
7
+ _TP_INDEX = {t: i for i, t in enumerate(CURVE_TIMEPOINTS)}
8
+
9
+
10
class CurveScorerResult(BaseModel):
    """Breakdown of how a new retention curve compares with the original one."""

    # Combined score handed back to callers.
    final_score: float
    # Change in area under the curve between the two curves.
    overall_improvement: float
    # Mean per-timepoint change over the timepoints selected for the action.
    targeted_improvement: float
    # Penalty derived from the timepoints whose retention got worse.
    regression_penalty: float
    # Timepoints (seconds) where the new curve is higher.
    improved_timepoints: List[int]
    # Timepoints (seconds) where the new curve is lower.
    worsened_timepoints: List[int]
17
+
18
+
19
class RetentionCurveScorer:
    """
    Scores improvement between two retention curves.

    Rewards targeted improvements at action-relevant timepoints:
      - hook_rewrite     → early timepoints (0–6s)
      - section_reorder  → mid timepoints (10–20s)
      - cultural_ref_sub → mid-to-late (15–30s)
      - cta_placement    → late timepoints (45–60s)
    Any other action is scored against every timepoint.

    Formula (negative improvements contribute nothing; only the regression
    penalty subtracts from the score):
      final = 0.50 * max(0, overall_improvement)
            + 0.35 * max(0, targeted_improvement)
            - 0.15 * regression_penalty
      clipped to [0, 1]
    """

    ACTION_CURVE_MAP = {
        "hook_rewrite": [0, 3, 6],
        "section_reorder": [10, 15, 20],
        "cultural_ref_sub": [15, 20, 25, 30],
        "cta_placement": [45, 60],
    }

    def score(
        self,
        original_curve: RetentionCurve,
        new_curve: RetentionCurve,
        action_type: str,
    ) -> CurveScorerResult:
        """Compare *new_curve* against *original_curve* for one action.

        Args:
            original_curve: Curve predicted for the script before the action.
            new_curve: Curve predicted for the script after the action.
            action_type: Action name. Plain strings and Enum members whose
                ``value`` is the action name are both accepted.

        Returns:
            A ``CurveScorerResult`` with every component rounded to 4 decimals.
        """
        orig_auc = original_curve.area_under_curve
        new_auc = new_curve.area_under_curve

        # 1. Overall AUC improvement (relative when the baseline is non-zero)
        if orig_auc > 0:
            overall_improvement = (new_auc - orig_auc) / orig_auc
        else:
            overall_improvement = float(new_auc)
        overall_improvement = max(-1.0, min(1.0, overall_improvement))

        # 2. Targeted improvement at action-relevant timepoints.
        # str() of an Enum member is "ClassName.MEMBER", which never matches a
        # key and would silently widen the target to every timepoint, so the
        # member's value is tried first; str(action_type) stays as a fallback.
        action_key = str(getattr(action_type, "value", action_type))
        target_tps = self.ACTION_CURVE_MAP.get(action_key)
        if target_tps is None:
            target_tps = self.ACTION_CURVE_MAP.get(str(action_type), CURVE_TIMEPOINTS)

        # Per-timepoint deltas, limited to indices present in both curves.
        usable = min(len(original_curve.values), len(new_curve.values))
        deltas = {
            tp: new_curve.values[i] - original_curve.values[i]
            for tp, i in _TP_INDEX.items()
            if i < usable
        }

        targeted_deltas: List[float] = [deltas[tp] for tp in target_tps if tp in deltas]
        if targeted_deltas:
            targeted_improvement = float(sum(targeted_deltas) / len(targeted_deltas))
        else:
            targeted_improvement = 0.0
        targeted_improvement = max(-1.0, min(1.0, targeted_improvement))

        # 3. Regression penalty: any timepoint that degraded.
        # Changes within ±0.001 are treated as unchanged.
        improved: List[int] = []
        worsened: List[int] = []
        worsened_magnitudes: List[float] = []
        for tp, delta in deltas.items():
            if delta > 0.001:
                improved.append(tp)
            elif delta < -0.001:
                worsened.append(tp)
                worsened_magnitudes.append(abs(delta))

        regression_penalty = 0.0
        if worsened_magnitudes:
            # Averaged over all timepoints, not only the worsened ones.
            regression_penalty = min(1.0, sum(worsened_magnitudes) / len(CURVE_TIMEPOINTS))

        final_score = (
            0.50 * max(0.0, overall_improvement)
            + 0.35 * max(0.0, targeted_improvement)
            - 0.15 * regression_penalty
        )
        final_score = max(0.0, min(1.0, final_score))

        return CurveScorerResult(
            final_score=round(final_score, 4),
            overall_improvement=round(overall_improvement, 4),
            targeted_improvement=round(targeted_improvement, 4),
            regression_penalty=round(regression_penalty, 4),
            improved_timepoints=improved,
            worsened_timepoints=worsened,
        )
viral_script_engine/retention/feature_extractor.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+ from pathlib import Path
4
+ from typing import List, Optional
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from viral_script_engine.platforms.platform_spec import PlatformRegistry
9
+
10
+ _FILLER_PHRASES = [
11
+ "hey guys", "welcome back", "today i want to", "so today",
12
+ "in this video", "what's up everyone", "hey everyone",
13
+ "guys today", "hello everyone", "so basically", "you know",
14
+ "kind of", "sort of", "basically", "um ", "uh ",
15
+ ]
16
+
17
+ _COMMON_WORDS = {
18
+ 'i', 'the', 'a', 'an', 'my', 'your', 'its', 'it', 'is', 'are',
19
+ 'was', 'were', 'be', 'been', "i've", "i'm", "it's", "here's",
20
+ 'today', 'and', 'but', 'so', 'that', 'this', 'these', 'those',
21
+ 'to', 'of', 'in', 'for', 'on', 'with', 'at', 'by', 'from',
22
+ 'or', 'not', 'you', 'we', 'they', 'he', 'she', 'if', 'do',
23
+ 'get', 'just', 'up', 'out', 'about', 'what', 'all', 'some',
24
+ }
25
+
26
+ _PROMISE_PATTERNS = [
27
+ r'\d',
28
+ r'\bhow to\b',
29
+ r'\bwhy\b',
30
+ r'\bwhat happens when\b',
31
+ r'\bi made\b',
32
+ r'\bwill\b',
33
+ r'\bguaranteed\b',
34
+ r'\bstep\b',
35
+ r'\btips?\b',
36
+ r'\bsecrets?\b',
37
+ r'\bprove[sd]?\b',
38
+ r'\bhere\'?s\b',
39
+ ]
40
+
41
+ _KNOWN_PLATFORMS = ["Reels", "Shorts", "Feed", "TikTok"]
42
+
43
+
44
class ScriptFeatures(BaseModel):
    """Features extracted from one script, grouped by the part of the curve they target."""

    # Hook features (predicts early drop-off 0–5s)
    hook_word_count: int
    hook_has_number: bool
    hook_has_question: bool
    hook_has_promise: bool
    hook_filler_score: float  # 0=no filler, 1=all filler

    # Pacing features (predicts mid-video retention 5–30s)
    avg_words_per_sentence: float
    sentence_count: int
    short_sentence_ratio: float  # sentences < 8 words / total sentences
    section_balance_score: float  # how evenly hook:body:cta matches platform spec

    # Content features (predicts late retention 30s+)
    specificity_score: float  # ratio of specific nouns/numbers to total words
    cultural_ref_count: int
    cta_position_ratio: float  # word offset of CTA start / total words

    # Platform fit features
    platform: str
    word_count: int
    length_vs_optimal: float  # word_count / optimal_script_length for platform

    def to_vector(self) -> List[float]:
        """Flatten the features into a list of floats.

        The 14 scalar features come first, in declaration order (with the
        ``platform`` string left out), followed by a one-hot encoding of
        ``platform`` over ``_KNOWN_PLATFORMS``. An unknown platform yields
        an all-zero one-hot tail.
        """
        hook_part = [
            float(self.hook_word_count),
            float(bool(self.hook_has_number)),
            float(bool(self.hook_has_question)),
            float(bool(self.hook_has_promise)),
            float(self.hook_filler_score),
        ]
        pacing_part = [
            float(self.avg_words_per_sentence),
            float(self.sentence_count),
            float(self.short_sentence_ratio),
            float(self.section_balance_score),
        ]
        content_part = [
            float(self.specificity_score),
            float(self.cultural_ref_count),
            float(self.cta_position_ratio),
        ]
        fit_part = [
            float(self.word_count),
            float(self.length_vs_optimal),
        ]
        one_hot = [1.0 if name == self.platform else 0.0 for name in _KNOWN_PLATFORMS]
        return hook_part + pacing_part + content_part + fit_part + one_hot
86
+
87
+
88
class FeatureExtractor:
    """Turns raw script text into ``ScriptFeatures`` using lexical heuristics."""

    def __init__(self, cultural_kb_path: Optional[str] = None):
        self.platform_registry = PlatformRegistry()
        self._cultural_kb_path = cultural_kb_path
        self._cultural_kb: Optional[dict] = None

    def _load_kb(self) -> None:
        """Read the cultural knowledge base from disk the first time it is needed."""
        if self._cultural_kb is None:
            default_path = Path(__file__).parent.parent / "data" / "cultural_kb.json"
            source = self._cultural_kb_path or str(default_path)
            with open(source, "r", encoding="utf-8") as handle:
                self._cultural_kb = json.load(handle)

    def extract(self, script: str, platform: str, region: str) -> ScriptFeatures:
        """Compute every feature for *script* on *platform* in *region*.

        Sentences are split after ``.``, ``!`` or ``?``; the script is then
        divided into hook / body / CTA by sentence position. All ratio
        features are rounded to 4 decimals.
        """
        self._load_kb()
        spec = self.platform_registry.get(platform)

        pieces = re.split(r'(?<=[.!?])\s+', script.strip())
        # Fall back to the whole script when splitting leaves nothing usable.
        sentences = [piece for piece in pieces if piece.strip()] or [script]
        tokens = script.split()
        total_words = len(tokens)
        n_sentences = len(sentences)
        per_sentence = [len(sentence.split()) for sentence in sentences]

        # --- Hook: first ~20% of sentences (min 1, max 3) ---
        if n_sentences >= 5:
            hook_end = max(1, min(3, int(n_sentences * 0.2)))
        else:
            hook_end = max(1, min(2, n_sentences))
        hook_text = " ".join(sentences[:hook_end])
        hook_lower = hook_text.lower()
        hook_word_count = len(hook_text.split())

        hook_has_number = re.search(r'\d', hook_text) is not None
        hook_has_question = '?' in hook_text
        hook_has_promise = False
        for pattern in _PROMISE_PATTERNS:
            if re.search(pattern, hook_lower):
                hook_has_promise = True
                break

        filler_hits = len([phrase for phrase in _FILLER_PHRASES if phrase in hook_lower])
        hook_filler_score = min(1.0, filler_hits / max(hook_word_count, 1) * 4)

        # --- Pacing ---
        avg_words_per_sentence = sum(per_sentence) / max(n_sentences, 1)
        short_total = len([count for count in per_sentence if count < 8])
        short_sentence_ratio = short_total / max(n_sentences, 1)

        # Section balance: compare actual word distribution to platform spec
        cta_start = max(hook_end + 1, n_sentences - max(1, int(n_sentences * 0.1)))
        hook_w = sum(per_sentence[:hook_end])
        body_w = sum(per_sentence[hook_end:cta_start])
        cta_w = sum(per_sentence[cta_start:])
        total_w = max(hook_w + body_w + cta_w, 1)

        opt = spec.optimal_sentences_per_section
        opt_total = max(sum(opt.values()), 1)
        hook_gap = abs(hook_w / total_w - opt.get("hook", 2) / opt_total)
        body_gap = abs(body_w / total_w - opt.get("body", 6) / opt_total)
        balance_dev = (hook_gap + body_gap) / 2
        section_balance_score = max(0.0, 1.0 - balance_dev * 4)

        # --- Content features ---
        def is_specific(token: str) -> bool:
            # Anything containing a digit, or a capitalised word that is not
            # in the common-word list once lowercased and stripped of punctuation.
            if re.search(r'\d', token):
                return True
            if len(token) <= 1 or not token[0].isupper():
                return False
            return token.lower().strip('.,!?;:\'"') not in _COMMON_WORDS

        specific_count = sum(1 for token in tokens if is_specific(token))
        specificity_score = min(1.0, specific_count / max(total_words, 1))

        cultural_ref_count = 0
        if self._cultural_kb and region in self._cultural_kb:
            entry = self._cultural_kb[region]
            lowered = script.lower()
            ref_hits = [r for r in entry.get("valid_refs", []) if r.lower() in lowered]
            idiom_hits = [i for i in entry.get("correct_idioms", []) if i.lower() in lowered]
            cultural_ref_count = len(ref_hits) + len(idiom_hits)

        cta_position_ratio = (hook_w + body_w) / max(total_words, 1)

        # --- Platform fit ---
        length_vs_optimal = total_words / max(spec.optimal_script_length_words, 1)

        return ScriptFeatures(
            hook_word_count=hook_word_count,
            hook_has_number=hook_has_number,
            hook_has_question=hook_has_question,
            hook_has_promise=hook_has_promise,
            hook_filler_score=round(hook_filler_score, 4),
            avg_words_per_sentence=round(avg_words_per_sentence, 4),
            sentence_count=n_sentences,
            short_sentence_ratio=round(short_sentence_ratio, 4),
            section_balance_score=round(section_balance_score, 4),
            specificity_score=round(specificity_score, 4),
            cultural_ref_count=cultural_ref_count,
            cta_position_ratio=round(cta_position_ratio, 4),
            platform=platform,
            word_count=total_words,
            length_vs_optimal=round(length_vs_optimal, 4),
        )
viral_script_engine/retention/model.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4454048ccec161cf3db43cc3b02a2767b261e8f3df23cdab2a6904ec47eb2e00
3
+ size 1876741
viral_script_engine/retention/training_data/__init__.py ADDED
File without changes
viral_script_engine/retention/training_data/build_dataset.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Builds retention_dataset.json from rule-based simulation.
3
+
4
+ Encoding known relationships between script quality scores and viewer retention:
5
+ - Hook quality (R1) predicts early drop-off at seconds 0–6
6
+ - Coherence (R2) predicts mid-video retention at seconds 6–20
7
+ - Cultural alignment (R3) predicts late retention at seconds 20–60
8
+
9
+ Dataset format:
10
+ {
11
+ "samples": [
12
+ {
13
+ "script_id": "train_001",
14
+ "script_text": "...",
15
+ "platform": "Reels",
16
+ "region": "Mumbai Gen Z",
17
+ "retention_curve": [1.0, 0.95, ...], # 10 values at seconds [0,3,6,10,15,20,25,30,45,60]
18
+ "curve_source": "rule_based",
19
+ "quality_tier": "high" | "medium" | "low"
20
+ }
21
+ ]
22
+ }
23
+ """
24
+ import json
25
+ import random
26
+ from pathlib import Path
27
+ from typing import List, Tuple
28
+
29
+ _TIMEPOINTS = [0, 3, 6, 10, 15, 20, 25, 30, 45, 60]
30
+ _OUTPUT_PATH = Path(__file__).parent / "retention_dataset.json"
31
+
32
+ _PLATFORMS = ["Reels", "Shorts", "Feed", "TikTok"]
33
+ _REGIONS = ["Mumbai Gen Z", "pan_india_english", "delhi_millennial", "bangalore_tech"]
34
+
35
+ _HIGH_SCRIPTS = [
36
+ "Did you know {pct}% of people get this wrong? Here's what actually works. Stop doing what everyone tells you. Use this one simple method instead. The results will genuinely surprise you. Comment 'yes' if you want the full breakdown.",
37
+ "I made {amt}k in 30 days using this strategy. Nobody in your feed is talking about this. Here's exactly what I did step by step. You can start tonight with zero investment. Follow for the full guide.",
38
+ "Your phone is lying to you about money. Here's the truth about compound interest that banks don't want you to know. Start with just $50. Watch what happens after 12 months. This changed everything for me.",
39
+ "Stop scrolling. This is the {amt}-second trick that saved me {pct}% on every bill. I tested it for 3 months. Here's the proof. Save this before it gets taken down.",
40
+ "Why do {pct}% of people fail at saving money? I spent 6 months finding out. The answer surprised me. It has nothing to do with income. Watch till the end for the fix.",
41
+ "The {amt} investing mistake I made at 22 cost me {pct}k. Here's what I wish someone told me. Three rules that actually work. No BS, no courses to sell. Just what changed my life.",
42
+ "How to pay off debt {pct}% faster using the avalanche method. Most people use the wrong strategy. This is the math-backed approach. Takes 5 minutes to set up. Start today.",
43
+ "This bank trick gives you {pct}% more interest — your bank doesn't advertise it. Took me {amt} months to find it. Here's exactly how to set it up in under 2 minutes.",
44
+ ]
45
+
46
+ _MED_SCRIPTS = [
47
+ "So today I want to talk about something that I think is really important for a lot of people. Financial planning is something that many people overlook. You should really try to save money regularly if you can. It makes a big difference over time when you think about it.",
48
+ "Hey everyone, welcome back to my channel. Today I'm sharing some tips about managing your finances better. These tips have helped me personally and I hope they help you too. Let me know in the comments what you think about them.",
49
+ "Saving money is actually not that hard once you get into the habit. The main thing is consistency. Try to set aside a fixed amount each month. Over time it really does add up significantly. There are several ways you can approach this.",
50
+ "I've been thinking a lot about financial health lately. It's something that affects everyone. The basics are pretty simple when you break them down. Budget, save, invest — in that order. Most people skip the middle step which is a mistake.",
51
+ "Money management is a skill that anyone can learn. It takes time and practice but it's worth it. Start by tracking your spending for one month. Then identify areas where you can cut back. After that you can start building your savings.",
52
+ ]
53
+
54
+ _LOW_SCRIPTS = [
55
+ "Hello guys welcome back um so today basically I wanted to kind of talk about you know like finances and stuff. So basically what I mean is um you should save more money I guess. That's kind of the main point I think. Um yeah so basically just try to do that.",
56
+ "Hey everyone so basically today's video is about money and financial things. I mean you know like it's really important and stuff like that. So yeah basically just save money I guess. Um anyway thanks for watching and stuff.",
57
+ "So um welcome back to my channel. Today I kind of want to sort of discuss um financial things you know. Like basically everyone knows they should save money right. Um so yeah that's basically it I think. Like just try to be better with money or whatever.",
58
+ "Hey guys so um today we're going to talk about kind of like money and finances and all that stuff. So basically um the thing is you know it's pretty important I think. Like I don't know just try to save more I guess. Um yeah so basically that's the main thing.",
59
+ "Welcome back everyone so today basically I wanted to kind of share some thoughts on um financial stuff. Like you know it's important and everything. So basically just try to you know manage your money better or something like that. Um yeah I hope that helps.",
60
+ ]
61
+
62
+
63
def _pick_script(quality: str) -> str:
    """Return a randomly chosen script template for *quality* with its placeholders filled.

    "high" and "medium" select their own pools; any other value uses the
    low-quality pool. Both placeholder numbers are drawn before the template
    is chosen, so the order of random draws does not depend on the tier.
    """
    amt = random.randint(10, 99)
    pct = random.randint(60, 98)
    pools = {"high": _HIGH_SCRIPTS, "medium": _MED_SCRIPTS}
    chosen = random.choice(pools.get(quality, _LOW_SCRIPTS))
    return chosen.format(amt=amt, pct=pct)
73
+
74
+
75
+ def _quality_to_scores(quality: str) -> Tuple[float, float, float]:
76
+ if quality == "high":
77
+ r1 = random.uniform(0.78, 1.0)
78
+ r2 = random.uniform(0.72, 1.0)
79
+ r3 = random.uniform(0.68, 1.0)
80
+ elif quality == "medium":
81
+ r1 = random.uniform(0.38, 0.65)
82
+ r2 = random.uniform(0.38, 0.65)
83
+ r3 = random.uniform(0.38, 0.65)
84
+ else:
85
+ r1 = random.uniform(0.05, 0.25)
86
+ r2 = random.uniform(0.08, 0.28)
87
+ r3 = random.uniform(0.08, 0.28)
88
+ return r1, r2, r3
89
+
90
+
91
+ def _generate_curve(r1: float, r2: float, r3: float) -> List[float]:
92
+ """
93
+ Rule-based retention curve at timepoints [0, 3, 6, 10, 15, 20, 25, 30, 45, 60].
94
+
95
+ Rules from phase spec:
96
+ s0 = 1.0
97
+ s3 = 1.0 - (0.4 * (1 - r1)) # hook quality predicts early drop
98
+ s10 = prev - (0.1 * (1 - r2)) # coherence predicts mid-video
99
+ s20 = prev - (0.15 * (1 - r3)) # cultural alignment predicts late
100
+ s60 = prev - 0.05 # natural decay always present
101
+ """
102
+ noise = lambda lo, hi: random.uniform(lo, hi)
103
+
104
+ s0 = 1.0
105
+ s3 = max(0.0, 1.0 - (0.4 * (1 - r1)) + noise(-0.02, 0.02))
106
+ s6 = max(0.0, s3 - noise(0.02, 0.06))
107
+ s10 = max(0.0, s6 - (0.1 * (1 - r2)) - noise(0.0, 0.03))
108
+ s15 = max(0.0, s10 - noise(0.03, 0.07))
109
+ s20 = max(0.0, s15 - (0.15 * (1 - r3)) - noise(0.0, 0.03))
110
+ s25 = max(0.0, s20 - noise(0.02, 0.05))
111
+ s30 = max(0.0, s25 - noise(0.02, 0.05))
112
+ s45 = max(0.0, s30 - 0.05 - noise(0.0, 0.04))
113
+ s60 = max(0.0, s45 - 0.05)
114
+
115
+ # enforce monotonic decrease
116
+ curve = [s0, s3, s6, s10, s15, s20, s25, s30, s45, s60]
117
+ for i in range(1, len(curve)):
118
+ curve[i] = min(curve[i], curve[i - 1])
119
+ return [round(v, 3) for v in curve]
120
+
121
+
122
def build(output_path: str = None, seed: int = 42) -> str:
    """Build and save the dataset. Returns the path written.

    Seeds the global ``random`` generator with *seed*, generates 50 samples
    for each of the high, medium and low tiers (in that order) and writes
    them as ``{"samples": [...]}`` JSON. When *output_path* is not given,
    the module's default output path is used; missing parent directories
    are created.
    """
    random.seed(seed)
    target = Path(output_path) if output_path else _OUTPUT_PATH
    target.parent.mkdir(parents=True, exist_ok=True)

    samples = []
    for tier in ("high", "medium", "low"):
        for _ in range(50):
            # The order of these calls fixes the random stream; keep it.
            platform = random.choice(_PLATFORMS)
            region = random.choice(_REGIONS)
            r1, r2, r3 = _quality_to_scores(tier)
            curve = _generate_curve(r1, r2, r3)
            script_text = _pick_script(tier)
            samples.append({
                "script_id": f"train_{len(samples) + 1:03d}",
                "script_text": script_text,
                "platform": platform,
                "region": region,
                "retention_curve": curve,
                "curve_source": "rule_based",
                "quality_tier": tier,
                "r1_score": round(r1, 3),
                "r2_score": round(r2, 3),
                "r3_score": round(r3, 3),
            })

    with open(target, "w", encoding="utf-8") as handle:
        json.dump({"samples": samples}, handle, indent=2)

    return str(target)
154
+
155
+
156
+ if __name__ == "__main__":
157
+ out = build()
158
+ print(f"Dataset built: {out} (150 samples)")
viral_script_engine/retention/training_data/retention_dataset.json ADDED
The diff for this file is too large to render. See raw diff
 
viral_script_engine/rewards/r10_retention_curve.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from viral_script_engine.retention.curve_predictor import RetentionCurve, RetentionCurvePredictor
6
+ from viral_script_engine.retention.curve_scorer import CurveScorerResult, RetentionCurveScorer
7
+ from viral_script_engine.retention.feature_extractor import FeatureExtractor
8
+
9
+
10
class RetentionRewardResult(BaseModel):
    """Reward value together with the curves and comparison it came from."""

    # Reward value; the scorer's final score.
    score: float
    # Curve predicted for the original script.
    original_curve: RetentionCurve
    # Curve predicted for the rewritten script.
    new_curve: RetentionCurve
    # Full scorer breakdown for the two curves above.
    curve_delta: CurveScorerResult
15
+
16
+
17
class RetentionCurveReward:
    """
    Wraps the full retention prediction + scoring pipeline into a reward signal.

    Caches the original curve per episode so the extractor is called only once
    for the original script — subsequent steps reuse the cached curve.

    The cache is keyed by ``episode_id`` alone and is only emptied through
    ``clear_cache``, so callers should clear finished episodes.
    """

    def __init__(self, cultural_kb_path: Optional[str] = None):
        self.extractor = FeatureExtractor(cultural_kb_path=cultural_kb_path)
        self.predictor = RetentionCurvePredictor()
        self.scorer = RetentionCurveScorer()
        self._original_curve_cache: dict = {}

    def score(
        self,
        original_script: str,
        rewritten_script: str,
        platform: str,
        region: str,
        action_type: str,
        episode_id: str,
    ) -> "RetentionRewardResult":
        """Score *rewritten_script* against the episode's original script.

        Args:
            original_script: Baseline script; only read the first time an
                ``episode_id`` is seen.
            rewritten_script: Script produced by the current step.
            platform: Platform name passed to the feature extractor.
            region: Region key passed to the feature extractor.
            action_type: Action name passed to the scorer.
            episode_id: Cache key for the original script's curve.

        Returns:
            The reward score together with both curves and the scorer breakdown.
        """
        # Cache original curve — compute only once per episode
        original_curve = self._original_curve_cache.get(episode_id)
        if episode_id not in self._original_curve_cache:
            orig_features = self.extractor.extract(original_script, platform, region)
            original_curve = self.predictor.predict(orig_features)
            self._original_curve_cache[episode_id] = original_curve

        new_features = self.extractor.extract(rewritten_script, platform, region)
        new_curve = self.predictor.predict(new_features)

        result = self.scorer.score(
            original_curve=original_curve,
            new_curve=new_curve,
            action_type=action_type,
        )

        return RetentionRewardResult(
            score=result.final_score,
            original_curve=original_curve,
            new_curve=new_curve,
            curve_delta=result,
        )

    def clear_cache(self, episode_id: Optional[str] = None) -> None:
        """Drop one episode's cached curve, or every cached curve.

        Args:
            episode_id: Episode to forget. ``None`` (the default) clears the
                whole cache; unknown ids are ignored.
        """
        # Compare against None rather than truthiness so that an empty-string
        # episode id removes only its own entry instead of wiping the cache.
        if episode_id is not None:
            self._original_curve_cache.pop(episode_id, None)
        else:
            self._original_curve_cache.clear()
viral_script_engine/rewards/reward_aggregator.py CHANGED
@@ -12,6 +12,7 @@ _COMPONENT_FIELDS = [
12
  "r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
13
  "r4_debate_resolution", "r5_defender_preservation",
14
  "r6_safety", "r7_originality", "r8_persona_fit", "r9_platform_pacing",
 
15
  ]
16
 
17
  _DROP_THRESHOLD = 0.25
 
12
  "r1_hook_strength", "r2_coherence", "r3_cultural_alignment",
13
  "r4_debate_resolution", "r5_defender_preservation",
14
  "r6_safety", "r7_originality", "r8_persona_fit", "r9_platform_pacing",
15
+ "r10_retention_curve",
16
  ]
17
 
18
  _DROP_THRESHOLD = 0.25
viral_script_engine/tests/test_phase12.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Phase 12 tests — Retention Curve Simulator."""
2
+ import json
3
+ import sys
4
+ import tempfile
5
+ from pathlib import Path
6
+ from unittest.mock import MagicMock, patch
7
+
8
+ import numpy as np
9
+ import pytest
10
+
11
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
12
+
13
+ from viral_script_engine.retention.feature_extractor import (
14
+ FeatureExtractor,
15
+ ScriptFeatures,
16
+ _KNOWN_PLATFORMS,
17
+ )
18
+ from viral_script_engine.retention.curve_predictor import (
19
+ RetentionCurve,
20
+ RetentionCurvePredictor,
21
+ CURVE_TIMEPOINTS,
22
+ )
23
+ from viral_script_engine.retention.curve_scorer import RetentionCurveScorer
24
+ from viral_script_engine.rewards.r10_retention_curve import RetentionCurveReward
25
+
26
+ _SCRIPTS_PATH = str(
27
+ Path(__file__).parent.parent / "data" / "test_scripts" / "scripts.json"
28
+ )
29
+ _CULTURAL_KB_PATH = str(
30
+ Path(__file__).parent.parent / "data" / "cultural_kb.json"
31
+ )
32
+
33
+ _GOOD_SCRIPT = (
34
+ "Did you know 80% of people get this wrong? Here's what actually works. "
35
+ "Stop doing what everyone tells you. Use this one simple method instead. "
36
+ "The results will surprise you. Follow for more."
37
+ )
38
+ _BAD_SCRIPT = (
39
+ "Hello guys welcome back um so today basically I wanted to kind of talk "
40
+ "about you know like finances and stuff. So basically just try to save money."
41
+ )
42
+
43
+
44
+ # ---------------------------------------------------------------------------
45
+ # FeatureExtractor tests
46
+ # ---------------------------------------------------------------------------
47
+
48
def test_feature_extractor_produces_correct_features():
    """A well-formed script yields populated counts and the expected hook flags."""
    features = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH).extract(
        _GOOD_SCRIPT, platform="Reels", region="pan_india_english"
    )

    assert isinstance(features, ScriptFeatures)
    for count in (features.hook_word_count, features.sentence_count, features.word_count):
        assert count > 0
    assert features.platform == "Reels"
    assert features.hook_has_number is True  # "80%"
    assert features.hook_has_question is True  # "?"
59
+
60
+
61
def test_feature_extractor_bad_script_has_high_filler():
    """The filler-heavy script must not score lower on hook filler than the good one."""
    extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
    bad = extractor.extract(_BAD_SCRIPT, platform="Reels", region="pan_india_english")
    good = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")

    # Bad script should have higher (or at least equal) filler score than good script
    assert bad.hook_filler_score >= good.hook_filler_score
68
+
69
+
70
def test_to_vector_returns_flat_numeric_list():
    """to_vector() yields a non-empty flat list of NaN-free floats."""
    extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
    features = extractor.extract(_GOOD_SCRIPT, platform="Reels", region="pan_india_english")
    vec = features.to_vector()

    assert isinstance(vec, list)
    assert len(vec) > 0
    for v in vec:
        # to_vector() is declared to return floats only; accepting int here
        # would also let raw bools (a subclass of int) slip through.
        assert isinstance(v, float), f"Non-float entry {v!r} in vector: {vec}"
        # NaN is the only value that is not equal to itself.
        assert v == v, f"NaN found in vector: {vec}"
83
+
84
+
85
def test_to_vector_platform_one_hot():
    """Every known platform sets exactly one slot in the trailing one-hot block."""
    extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
    width = len(_KNOWN_PLATFORMS)
    for platform in _KNOWN_PLATFORMS:
        vec = extractor.extract(
            _GOOD_SCRIPT, platform=platform, region="pan_india_english"
        ).to_vector()
        # Last N elements are one-hot platform encoding
        platform_slice = vec[-width:]
        assert sum(platform_slice) == 1.0, f"One-hot sum should be 1 for {platform}"
        assert max(platform_slice) == 1.0
94
+
95
+
96
def test_to_vector_no_nan_for_bad_script():
    """The low-quality script must not produce NaN in its feature vector."""
    extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
    features = extractor.extract(_BAD_SCRIPT, platform="TikTok", region="pan_india_english")
    vec = features.to_vector()
    for v in vec:
        # NaN is the only value that is not equal to itself; include the
        # vector in the message so a failure shows the offending data.
        assert v == v, f"NaN found in vector: {vec}"
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # RetentionCurvePredictor tests
106
+ # ---------------------------------------------------------------------------
107
+
108
def test_predictor_raises_if_not_trained():
    """predict() refuses to run when the predictor is flagged as untrained."""
    features = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH).extract(
        _GOOD_SCRIPT, platform="Reels", region="pan_india_english"
    )

    # Bypass __init__ and set the untrained state by hand.
    untrained = RetentionCurvePredictor.__new__(RetentionCurvePredictor)
    untrained.model = None
    untrained._trained = False

    with pytest.raises(RuntimeError, match="not trained"):
        untrained.predict(features)
117
+
118
+
119
def _make_trained_predictor() -> RetentionCurvePredictor:
    """Train predictor on a minimal in-memory dataset.

    Twenty samples alternate between the good and the bad script across the
    four platforms; each target curve starts at 1.0 (good) or 0.3 (bad) and
    falls by 0.05 per timepoint, floored at 0.
    """
    from sklearn.ensemble import GradientBoostingRegressor
    from sklearn.multioutput import MultiOutputRegressor

    extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
    script_cycle = [_GOOD_SCRIPT, _BAD_SCRIPT] * 10
    platform_cycle = ["Reels", "TikTok", "Shorts", "Feed"] * 5

    rows, targets = [], []
    for text, platform in zip(script_cycle, platform_cycle):
        feat = extractor.extract(text, platform=platform, region="pan_india_english")
        rows.append(feat.to_vector())
        start = 1.0 if text == _GOOD_SCRIPT else 0.3
        targets.append([max(0.0, start - step * 0.05) for step in range(len(CURVE_TIMEPOINTS))])

    regressor = MultiOutputRegressor(
        GradientBoostingRegressor(n_estimators=10, max_depth=2, random_state=42)
    )
    regressor.fit(np.array(rows), np.array(targets))

    # Skip __init__ and install the freshly trained model by hand.
    trained = RetentionCurvePredictor.__new__(RetentionCurvePredictor)
    trained.model = regressor
    trained._trained = True
    return trained
145
+
146
+
147
def test_predicted_curve_is_monotonically_non_increasing():
    """No predicted value may exceed the one before it (within 1e-9)."""
    features = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH).extract(
        _GOOD_SCRIPT, platform="Reels", region="pan_india_english"
    )
    values = _make_trained_predictor().predict(features).values

    for i, (earlier, later) in enumerate(zip(values, values[1:]), start=1):
        assert later <= earlier + 1e-9, (
            f"Curve not monotonic at index {i}: {earlier} -> {later}"
        )
157
+
158
+
159
def test_predicted_curve_values_in_range():
    """Every predicted retention value lies in [0, 1]."""
    features = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH).extract(
        _BAD_SCRIPT, platform="TikTok", region="pan_india_english"
    )
    predicted = _make_trained_predictor().predict(features)

    for v in predicted.values:
        assert 0.0 <= v <= 1.0, f"Value {v} out of [0, 1]"
167
+
168
+
169
def test_predicted_curve_has_correct_timepoints():
    """The predicted curve carries the canonical timepoints, one value each."""
    features = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH).extract(
        _GOOD_SCRIPT, platform="Reels", region="pan_india_english"
    )
    predicted = _make_trained_predictor().predict(features)

    assert predicted.timepoints == CURVE_TIMEPOINTS
    assert len(predicted.values) == len(CURVE_TIMEPOINTS)
177
+
178
+
179
+ # ---------------------------------------------------------------------------
180
+ # RetentionCurveScorer tests
181
+ # ---------------------------------------------------------------------------
182
+
183
def _make_curve(values: list) -> RetentionCurve:
    """Build a ``RetentionCurve`` from a plain list of values."""
    curve = RetentionCurve.from_values(values)
    return curve
185
+
186
+
187
def test_scorer_rewards_targeted_improvement():
    """Raising the early values under ``hook_rewrite`` yields a positive score."""
    scorer = RetentionCurveScorer()
    # hook_rewrite targets [0, 3, 6]; only the second and third values are
    # raised, everything else is left exactly as in the baseline.
    baseline = _make_curve(
        [1.0, 0.6, 0.5, 0.45, 0.42, 0.40, 0.38, 0.36, 0.32, 0.30]
    )
    improved = _make_curve(
        [1.0, 0.85, 0.75, 0.45, 0.42, 0.40, 0.38, 0.36, 0.32, 0.30]
    )

    result = scorer.score(
        original_curve=baseline,
        new_curve=improved,
        action_type="hook_rewrite",
    )

    assert result.final_score > 0
    assert result.targeted_improvement > 0
    assert any(t in result.improved_timepoints for t in (3, 6))
201
+
202
+
203
def test_scorer_applies_regression_penalty_for_worsening():
    """Lowering mid-curve values must produce a penalty and worsened timepoints."""
    scorer = RetentionCurveScorer()
    baseline = _make_curve(
        [1.0, 0.9, 0.8, 0.7, 0.65, 0.60, 0.55, 0.50, 0.45, 0.40]
    )
    # Only the fourth, fifth and sixth values drop; the rest match the baseline.
    worsened = _make_curve(
        [1.0, 0.9, 0.8, 0.5, 0.45, 0.40, 0.55, 0.50, 0.45, 0.40]
    )

    result = scorer.score(
        original_curve=baseline,
        new_curve=worsened,
        action_type="hook_rewrite",
    )

    assert result.regression_penalty > 0
    assert len(result.worsened_timepoints) > 0
216
+
217
+
218
def test_scorer_score_in_range():
    """The final score for a ``section_reorder`` comparison stays within [0, 1]."""
    scorer = RetentionCurveScorer()
    baseline = _make_curve(
        [1.0, 0.8, 0.7, 0.6, 0.55, 0.50, 0.46, 0.42, 0.38, 0.35]
    )
    candidate = _make_curve(
        [1.0, 0.85, 0.75, 0.65, 0.60, 0.55, 0.50, 0.46, 0.42, 0.38]
    )

    result = scorer.score(
        original_curve=baseline,
        new_curve=candidate,
        action_type="section_reorder",
    )

    assert 0.0 <= result.final_score <= 1.0
229
+
230
+
231
+ # ---------------------------------------------------------------------------
232
+ # RetentionCurveReward — cache test
233
+ # ---------------------------------------------------------------------------
234
+
235
def test_retention_reward_caches_original_curve():
    """FeatureExtractor.extract should be called only once for the original script per episode."""
    predictor = _make_trained_predictor()

    # Build the reward without running __init__ and wire its collaborators by hand.
    reward = RetentionCurveReward.__new__(RetentionCurveReward)
    reward.extractor = FeatureExtractor(cultural_kb_path=_CULTURAL_KB_PATH)
    reward.predictor = predictor
    reward.scorer = RetentionCurveScorer()
    reward._original_curve_cache = {}

    # Replace extract on this extractor instance with a wrapper that counts calls.
    call_count = {"n": 0}
    real_extract = reward.extractor.extract

    def counting_extract(script, platform, region):
        call_count["n"] += 1
        return real_extract(script, platform, region)

    reward.extractor.extract = counting_extract

    episode_id = "ep_cache_test"
    score_kwargs = {
        "original_script": _GOOD_SCRIPT,
        "rewritten_script": _BAD_SCRIPT,
        "platform": "Reels",
        "region": "pan_india_english",
        "action_type": "hook_rewrite",
        "episode_id": episode_id,
    }
    for _attempt in range(3):
        reward.score(**score_kwargs)

    # The original script is extracted on the first call only (cached after
    # that); the rewritten script is extracted on all three calls: 1 + 3 = 4.
    assert call_count["n"] == 4, (
        f"Expected 4 extract calls (1 original cached + 3 rewritten), got {call_count['n']}"
    )
269
+
270
+
271
+ # ---------------------------------------------------------------------------
272
+ # env.step includes r10 in reward components
273
+ # ---------------------------------------------------------------------------
274
+
275
def test_env_step_includes_r10_when_model_trained():
    """env.step() should include r10_retention_curve in reward components when model is trained."""
    from viral_script_engine.environment.env import ViralScriptEnv
    from unittest.mock import MagicMock

    env = ViralScriptEnv(
        scripts_path=_SCRIPTS_PATH,
        cultural_kb_path=_CULTURAL_KB_PATH,
        difficulty="easy",
        use_escalation=False,
        use_anti_gaming=False,
    )

    # Swap in a predictor that is already trained.
    env.r10.predictor = _make_trained_predictor()

    obs, _ = env.reset()

    # Canned results returned by the patched critic, defender and rewriter.
    critique_stub = MagicMock(claims=[], overall_severity="low")

    defender_stub = MagicMock(
        core_strength="Strong hook",
        core_strength_quote="Test quote",
        defense_argument="Good",
        flagged_critic_claims=[],
        regional_voice_elements=[],
    )
    defender_stub.model_dump.return_value = {}

    # The rewrite hands back the current script unchanged, with an empty diff.
    rewrite_stub = MagicMock(rewritten_script=obs["current_script"], diff="")

    action = {
        "action_type": "hook_rewrite",
        "target_section": "hook",
        "instruction": "Strengthen the hook.",
        "critique_claim_id": "C1",
        "reasoning": "test",
    }

    with patch.object(env.critic, "critique", return_value=critique_stub):
        with patch.object(env.defender, "defend", return_value=defender_stub):
            with patch.object(env.rewriter, "rewrite", return_value=rewrite_stub):
                _, _, _, _, info = env.step(action)

    rc = info["reward_components"]
    assert "r10_retention_curve" in rc
    assert rc["r10_retention_curve"] is not None
    assert 0.0 <= rc["r10_retention_curve"] <= 1.0
325
+ assert 0.0 <= rc["r10_retention_curve"] <= 1.0