| """ | |
| Multi-armed bandit learning system for optimizing generation policies | |
| Learns which retrieval weights and generation parameters work best for each persona/content_type | |
| """ | |
| import numpy as np | |
| import random | |
| from typing import Dict, List, Tuple, Optional | |
| from dataclasses import dataclass | |
| from datetime import datetime, timedelta | |
| from sqlmodel import Session, select | |
| from models import Script, AutoScore, PolicyWeights, Rating | |
| from db import get_session | |


@dataclass
class BanditArm:
    """Represents one configuration of parameters to test."""
    name: str
    semantic_weight: float
    bm25_weight: float
    quality_weight: float
    freshness_weight: float
    temp_low: float
    temp_mid: float
    temp_high: float

    def __post_init__(self):
        # Renormalize retrieval weights so they sum to 1.0 (avoid exact float
        # equality; anything within 1e-9 of 1.0 counts as already normalized)
        total = self.semantic_weight + self.bm25_weight + self.quality_weight + self.freshness_weight
        if abs(total - 1.0) > 1e-9:
            self.semantic_weight /= total
            self.bm25_weight /= total
            self.quality_weight /= total
            self.freshness_weight /= total
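
# Illustrative note (behavior of __post_init__ above, numbers mine): an arm
# declared with raw weights 2/1/1/1 is renormalized to 0.4/0.2/0.2/0.2:
#
#   arm = BanditArm("demo", 2.0, 1.0, 1.0, 1.0, 0.4, 0.7, 0.95)
#   arm.semantic_weight  # -> 0.4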


class PolicyBandit:
    """Multi-armed bandit for learning optimal generation policies."""

    def __init__(self, epsilon: float = 0.15, decay_rate: float = 0.99):
        self.epsilon = epsilon        # Exploration rate
        self.decay_rate = decay_rate  # Epsilon decay over time
        self.min_epsilon = 0.05

        # Define arms (different parameter configurations)
        self.arms = [
            # Current default
            BanditArm("balanced", 0.45, 0.25, 0.20, 0.10, 0.4, 0.7, 0.95),
            # Semantic-heavy (focus on meaning)
            BanditArm("semantic_heavy", 0.60, 0.15, 0.15, 0.10, 0.4, 0.7, 0.95),
            # Quality-focused (use only the best examples)
            BanditArm("quality_focused", 0.35, 0.20, 0.35, 0.10, 0.3, 0.6, 0.85),
            # Freshness-focused (prioritize recent trends)
            BanditArm("fresh_focused", 0.40, 0.20, 0.15, 0.25, 0.5, 0.8, 1.0),
            # Conservative (lower temperatures)
            BanditArm("conservative", 0.45, 0.25, 0.20, 0.10, 0.3, 0.5, 0.7),
            # Creative (higher temperatures)
            BanditArm("creative", 0.45, 0.25, 0.20, 0.10, 0.6, 0.9, 1.2),
            # Text-match heavy (traditional keyword matching)
            BanditArm("text_heavy", 0.25, 0.45, 0.20, 0.10, 0.4, 0.7, 0.95),
        ]

        # Initialize arm statistics
        self.arm_counts = {arm.name: 0 for arm in self.arms}
        self.arm_rewards = {arm.name: 0.0 for arm in self.arms}

    def select_arm(self, persona: str, content_type: str) -> BanditArm:
        """Select an arm using epsilon-greedy selection with a UCB bias."""
        # Load existing policy weights to initialize arm stats
        self._load_arm_stats(persona, content_type)

        # Decay epsilon with the total number of pulls so far
        current_epsilon = max(
            self.min_epsilon,
            self.epsilon * (self.decay_rate ** sum(self.arm_counts.values())),
        )

        if random.random() < current_epsilon:
            # Explore: random arm
            selected_arm = random.choice(self.arms)
            print(f"🔄 Exploring with {selected_arm.name} policy (ε={current_epsilon:.3f})")
        else:
            # Exploit: best arm by UCB confidence bounds
            selected_arm = self._select_best_arm_ucb()
            print(f"⭐ Exploiting with {selected_arm.name} policy")

        return selected_arm
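
    # Illustrative arithmetic (my numbers, not from the source): with the
    # defaults epsilon=0.15 and decay_rate=0.99, after 100 total pulls the
    # exploration rate is 0.15 * 0.99**100 ≈ 0.055, and after roughly 110
    # pulls it bottoms out at min_epsilon = 0.05.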

    def _select_best_arm_ucb(self) -> BanditArm:
        """Select an arm using the UCB1 upper confidence bound."""
        total_counts = sum(self.arm_counts.values())
        if total_counts == 0:
            return self.arms[0]  # Default to the first arm

        best_arm = None
        best_score = float('-inf')

        for arm in self.arms:
            count = self.arm_counts[arm.name]
            if count == 0:
                return arm  # Always try unplayed arms first

            # UCB1 score = average reward + exploration bonus
            avg_reward = self.arm_rewards[arm.name] / count
            confidence = np.sqrt(2 * np.log(total_counts) / count)
            ucb_score = avg_reward + confidence

            if ucb_score > best_score:
                best_score = ucb_score
                best_arm = arm

        return best_arm or self.arms[0]
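
    # Worked example (my numbers, assuming rewards in [0, 1]): with 50 total
    # pulls, an arm pulled 10 times averaging 0.70 scores
    #   0.70 + sqrt(2 * ln(50) / 10) ≈ 0.70 + 0.88 = 1.58,
    # while an arm pulled 35 times averaging 0.80 scores
    #   0.80 + sqrt(2 * ln(50) / 35) ≈ 0.80 + 0.47 = 1.27,
    # so the less-sampled arm is chosen despite its lower average; the
    # exploration bonus shrinks as an arm accumulates pulls.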

    def _load_arm_stats(self, persona: str, content_type: str):
        """Load historical performance for this persona/content_type."""
        # Reset in-memory stats so counts loaded for a previously queried
        # persona/content_type don't leak into this one
        self.arm_counts = {arm.name: 0 for arm in self.arms}
        self.arm_rewards = {arm.name: 0.0 for arm in self.arms}

        with get_session() as ses:
            policy = ses.exec(
                select(PolicyWeights).where(
                    PolicyWeights.persona == persona,
                    PolicyWeights.content_type == content_type
                )
            ).first()

            if policy:
                # Find the matching arm and seed its stats
                for arm in self.arms:
                    if self._arm_matches_policy(arm, policy):
                        self.arm_counts[arm.name] = policy.total_generations
                        self.arm_rewards[arm.name] = policy.success_rate * policy.total_generations
                        break

    def _arm_matches_policy(self, arm: BanditArm, policy: PolicyWeights, tolerance: float = 0.05) -> bool:
        """Check whether an arm matches the stored policy within tolerance."""
        return (
            abs(arm.semantic_weight - policy.semantic_weight) < tolerance and
            abs(arm.bm25_weight - policy.bm25_weight) < tolerance and
            abs(arm.quality_weight - policy.quality_weight) < tolerance and
            abs(arm.freshness_weight - policy.freshness_weight) < tolerance
        )
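
    # Illustrative note (my numbers): with the default tolerance of 0.05, a
    # stored policy whose weights have drifted to 0.43/0.26/0.21/0.10 still
    # matches the "balanced" arm (0.45/0.25/0.20/0.10), since every
    # per-weight difference is below 0.05.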

    def update_reward(self,
                      arm: BanditArm,
                      reward: float,
                      persona: str,
                      content_type: str,
                      script_id: int):
        """Update arm performance with a new reward signal.

        Note: script_id is currently unused here; it is accepted for
        traceability and future per-script attribution.
        """
        # Update in-memory stats
        self.arm_counts[arm.name] += 1
        self.arm_rewards[arm.name] += reward

        # Update the persisted policy
        self._update_policy_weights(arm, reward, persona, content_type)

        avg = self.arm_rewards[arm.name] / self.arm_counts[arm.name]
        print(f"📈 Updated {arm.name}: reward={reward:.3f}, avg={avg:.3f}")

    def _update_policy_weights(self,
                               arm: BanditArm,
                               reward: float,
                               persona: str,
                               content_type: str):
        """Update policy weights in the database."""
        with get_session() as ses:
            policy = ses.exec(
                select(PolicyWeights).where(
                    PolicyWeights.persona == persona,
                    PolicyWeights.content_type == content_type
                )
            ).first()

            if not policy:
                # Create a new policy
                policy = PolicyWeights(
                    persona=persona,
                    content_type=content_type,
                    semantic_weight=arm.semantic_weight,
                    bm25_weight=arm.bm25_weight,
                    quality_weight=arm.quality_weight,
                    freshness_weight=arm.freshness_weight,
                    temp_low=arm.temp_low,
                    temp_mid=arm.temp_mid,
                    temp_high=arm.temp_high,
                    total_generations=1,
                    success_rate=reward
                )
            else:
                # Update the existing policy with an exponential moving average
                alpha = 0.1  # Learning rate
                policy.success_rate = (1 - alpha) * policy.success_rate + alpha * reward
                policy.total_generations += 1

                # If this arm outperformed the running average, shift the
                # stored weights slightly toward it
                if reward > policy.success_rate:
                    shift = 0.05  # Small shift toward the better-performing arm
                    policy.semantic_weight = (1 - shift) * policy.semantic_weight + shift * arm.semantic_weight
                    policy.bm25_weight = (1 - shift) * policy.bm25_weight + shift * arm.bm25_weight
                    policy.quality_weight = (1 - shift) * policy.quality_weight + shift * arm.quality_weight
                    policy.freshness_weight = (1 - shift) * policy.freshness_weight + shift * arm.freshness_weight
                    policy.temp_low = (1 - shift) * policy.temp_low + shift * arm.temp_low
                    policy.temp_mid = (1 - shift) * policy.temp_mid + shift * arm.temp_mid
                    policy.temp_high = (1 - shift) * policy.temp_high + shift * arm.temp_high

            policy.updated_at = datetime.utcnow()
            ses.add(policy)
            ses.commit()
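
    # Illustrative arithmetic (my numbers, not from the source): with
    # alpha = 0.1, a stored success_rate of 0.60 and a new reward of 0.90
    # give 0.9 * 0.60 + 0.1 * 0.90 = 0.63. The reward (0.90) exceeds the
    # updated average (0.63), so each weight then moves 5% of the way toward
    # the arm, e.g. a stored semantic_weight of 0.45 with an arm value of
    # 0.60 becomes 0.95 * 0.45 + 0.05 * 0.60 ≈ 0.458.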

    def calculate_reward(self, script_id: int) -> float:
        """
        Calculate a reward signal from script performance.
        Combines auto-scores and human ratings when available.
        """
        reward_components = []

        with get_session() as ses:
            # Get the auto-score
            auto_score = ses.exec(
                select(AutoScore).where(AutoScore.script_id == script_id)
            ).first()

            if auto_score and auto_score.confidence > 0.5:
                # Weighted composite of auto-scores; the weights sum to 1.0
                # and the component scores are on a 0-5 scale, so dividing by
                # 5.0 normalizes the result to 0-1
                auto_reward = (
                    0.35 * auto_score.overall +
                    0.20 * auto_score.hook +
                    0.15 * auto_score.originality +
                    0.15 * auto_score.style_fit +
                    0.15 * auto_score.safety
                ) / 5.0
                reward_components.append(('auto', auto_reward, auto_score.confidence))

            # Get human ratings
            script = ses.get(Script, script_id)
            if script and script.ratings_count > 0:
                human_reward = script.score_overall / 5.0  # Normalize 0-5 to 0-1
                confidence = min(1.0, script.ratings_count / 3.0)  # More ratings = higher confidence
                reward_components.append(('human', human_reward, confidence))

        if not reward_components:
            return 0.5  # Neutral reward if no scores are available

        # Confidence-weighted average of the reward components
        total_weight = sum(confidence for _, _, confidence in reward_components)
        weighted_reward = sum(
            reward * confidence for _, reward, confidence in reward_components
        ) / total_weight

        return weighted_reward
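
    # Illustrative arithmetic (my numbers, not from the source): an auto
    # component of 0.70 at confidence 0.8 blended with a human component of
    # 0.90 at confidence 0.67 (two ratings) yields
    #   (0.70 * 0.8 + 0.90 * 0.67) / (0.8 + 0.67) ≈ 0.79,
    # so human ratings pull the reward up as they accumulate.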


class PolicyLearner:
    """Main interface for policy learning."""

    def __init__(self):
        self.bandit = PolicyBandit()

    def learn_from_generation_batch(self,
                                    persona: str,
                                    content_type: str,
                                    generated_script_ids: List[int],
                                    selected_arm: BanditArm):
        """Learn from a batch of generated scripts."""
        if not generated_script_ids:
            return

        # Calculate the average reward across the batch
        rewards = [self.bandit.calculate_reward(sid) for sid in generated_script_ids]
        avg_reward = sum(rewards) / len(rewards)

        # Update the bandit with the average performance
        self.bandit.update_reward(
            selected_arm,
            avg_reward,
            persona,
            content_type,
            generated_script_ids[0]  # Representative script ID
        )

        print(f"🧠 Policy learning: {persona}/{content_type} → {avg_reward:.3f} reward")

    def get_optimized_policy(self, persona: str, content_type: str) -> BanditArm:
        """Get the current best policy for this persona/content_type."""
        return self.bandit.select_arm(persona, content_type)
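
    # Illustrative wiring (a sketch; `generate_batch` and the persona and
    # content-type names below are hypothetical, not part of this module):
    #
    #   learner = PolicyLearner()
    #   arm = learner.get_optimized_policy("tech_reviewer", "short_form")
    #   script_ids = generate_batch(arm)  # generate with the arm's weights/temps
    #   learner.learn_from_generation_batch("tech_reviewer", "short_form",
    #                                       script_ids, arm)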

    def run_learning_cycle(self):
        """Run a learning cycle over recent generations."""
        print("🔄 Starting policy learning cycle...")

        # Find recent AI-generated scripts
        cutoff = datetime.utcnow() - timedelta(hours=24)

        with get_session() as ses:
            recent_scripts = list(ses.exec(
                select(Script).where(
                    Script.created_at >= cutoff,
                    Script.source == "ai"
                )
            ))

            # Group by persona/content_type (the script's creator acts as the persona)
            groups = {}
            for script in recent_scripts:
                key = (script.creator, script.content_type)
                if key not in groups:
                    groups[key] = []
                groups[key].append(script.id)

        # Learn from each group
        for (persona, content_type), script_ids in groups.items():
            if len(script_ids) >= 3:  # Require a minimum batch size
                # For now, assume the balanced policy was used.
                # In practice, you'd track which policy produced each generation.
                balanced_arm = next(arm for arm in self.bandit.arms if arm.name == "balanced")
                self.learn_from_generation_batch(persona, content_type, script_ids, balanced_arm)


def run_policy_learning():
    """Main entry point for policy learning."""
    learner = PolicyLearner()
    learner.run_learning_cycle()


if __name__ == "__main__":
    run_policy_learning()