# File size: 13,924 Bytes
# fd88516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
"""

Multi-armed bandit learning system for optimizing generation policies

Learns which retrieval weights and generation parameters work best for each persona/content_type

"""

import numpy as np
import random
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
from datetime import datetime, timedelta
from sqlmodel import Session, select

from models import Script, AutoScore, PolicyWeights, Rating
from db import get_session

@dataclass
class BanditArm:
    """One candidate configuration of retrieval weights and temperatures.

    The four retrieval weights are rescaled in ``__post_init__`` so that they
    sum to 1.0; the three temperature values are stored exactly as given.
    """
    name: str  # identifier of this configuration
    semantic_weight: float
    bm25_weight: float
    quality_weight: float
    freshness_weight: float
    temp_low: float
    temp_mid: float
    temp_high: float

    def __post_init__(self):
        """Normalize the four retrieval weights so they sum to 1.0.

        Raises:
            ValueError: If the weights sum to zero or less, in which case
                there is no meaningful way to rescale them.
        """
        total = (
            self.semantic_weight
            + self.bm25_weight
            + self.quality_weight
            + self.freshness_weight
        )
        if total <= 0:
            raise ValueError(
                f"BanditArm '{self.name}': retrieval weights must sum to a "
                f"positive value, got {total}"
            )

        # Compare with a tolerance: an exact `!= 1.0` test rescales weights
        # that already sum to 1 up to floating-point rounding, perturbing
        # the literal values the caller supplied.
        if abs(total - 1.0) > 1e-9:
            self.semantic_weight /= total
            self.bm25_weight /= total
            self.quality_weight /= total
            self.freshness_weight /= total

class PolicyBandit:
    """Multi-armed bandit for learning optimal generation policies.

    Chooses among a fixed list of ``BanditArm`` presets using an
    epsilon-greedy rule whose exploit step ranks arms by UCB score.
    Play counts and cumulative rewards are tracked separately for every
    (persona, content_type) pair so that feedback for one pair does not
    influence selection for another.
    """

    def __init__(self, epsilon: float = 0.15, decay_rate: float = 0.99):
        """Create the bandit with its preset arms and zeroed statistics.

        Args:
            epsilon: Initial exploration probability.
            decay_rate: Factor applied to ``epsilon`` once per recorded play.
        """
        self.epsilon = epsilon  # Exploration rate
        self.decay_rate = decay_rate  # Epsilon decay over time
        self.min_epsilon = 0.05  # Floor below which epsilon never decays

        # Define arms (different parameter configurations)
        self.arms = [
            # Current default
            BanditArm("balanced", 0.45, 0.25, 0.20, 0.10, 0.4, 0.7, 0.95),

            # Semantic-heavy (focus on meaning)
            BanditArm("semantic_heavy", 0.60, 0.15, 0.15, 0.10, 0.4, 0.7, 0.95),

            # Quality-focused (use only best examples)
            BanditArm("quality_focused", 0.35, 0.20, 0.35, 0.10, 0.3, 0.6, 0.85),

            # Fresh-focused (prioritize recent trends)
            BanditArm("fresh_focused", 0.40, 0.20, 0.15, 0.25, 0.5, 0.8, 1.0),

            # Conservative (lower temperatures)
            BanditArm("conservative", 0.45, 0.25, 0.20, 0.10, 0.3, 0.5, 0.7),

            # Creative (higher temperatures)
            BanditArm("creative", 0.45, 0.25, 0.20, 0.10, 0.6, 0.9, 1.2),

            # Text-match heavy (traditional keyword matching)
            BanditArm("text_heavy", 0.25, 0.45, 0.20, 0.10, 0.4, 0.7, 0.95),
        ]

        # Statistics of the currently active (persona, content_type) context.
        # Both names are rebound by _activate_context when the context changes.
        self.arm_counts = {arm.name: 0 for arm in self.arms}
        self.arm_rewards = {arm.name: 0.0 for arm in self.arms}

        # (persona, content_type) -> (arm_counts dict, arm_rewards dict)
        self._context_stats = {}

    def _activate_context(self, persona: str, content_type: str) -> None:
        """Point ``arm_counts`` / ``arm_rewards`` at this context's statistics.

        The first context ever activated adopts the dictionaries created in
        ``__init__`` (so single-context use behaves as if there were one
        shared set of statistics); every later context starts from zero.
        """
        key = (persona, content_type)
        if key not in self._context_stats:
            if not self._context_stats:
                self._context_stats[key] = (self.arm_counts, self.arm_rewards)
            else:
                self._context_stats[key] = (
                    {arm.name: 0 for arm in self.arms},
                    {arm.name: 0.0 for arm in self.arms},
                )
        self.arm_counts, self.arm_rewards = self._context_stats[key]

    def select_arm(self, persona: str, content_type: str) -> "BanditArm":
        """Select arm using epsilon-greedy with UCB bias.

        Args:
            persona: Persona the generation is for.
            content_type: Content type the generation is for.

        Returns:
            A random arm with probability ``current_epsilon`` (explore),
            otherwise the arm chosen by ``_select_best_arm_ucb`` (exploit).
        """
        # Keep statistics of different persona/content_type pairs apart,
        # then refresh them from the stored policy for this pair.
        self._activate_context(persona, content_type)
        self._load_arm_stats(persona, content_type)

        # Decay epsilon with the total number of plays, but never below the floor
        total_plays = sum(self.arm_counts.values())
        current_epsilon = max(
            self.min_epsilon,
            self.epsilon * (self.decay_rate ** total_plays),
        )

        if random.random() < current_epsilon:
            # Explore: random arm
            selected_arm = random.choice(self.arms)
            print(f"🔄 Exploring with {selected_arm.name} policy (ε={current_epsilon:.3f})")
        else:
            # Exploit: best arm with UCB confidence bounds
            selected_arm = self._select_best_arm_ucb()
            print(f"⭐ Exploiting with {selected_arm.name} policy")

        return selected_arm

    def _select_best_arm_ucb(self) -> "BanditArm":
        """Select arm using Upper Confidence Bound.

        Returns the first arm if nothing has been played yet, the first arm
        (in list order) that has never been played if one is encountered,
        and otherwise the arm with the highest
        ``average reward + sqrt(2 * ln(total plays) / arm plays)``.
        """
        total_counts = sum(self.arm_counts.values())
        if total_counts == 0:
            return self.arms[0]  # Default to first arm

        best_arm = None
        best_score = float('-inf')

        for arm in self.arms:
            count = self.arm_counts[arm.name]
            if count == 0:
                return arm  # Always try unplayed arms first

            # UCB score = average reward + confidence interval
            avg_reward = self.arm_rewards[arm.name] / count
            confidence = np.sqrt(2 * np.log(total_counts) / count)
            ucb_score = avg_reward + confidence

            if ucb_score > best_score:
                best_score = ucb_score
                best_arm = arm

        return best_arm or self.arms[0]

    def _load_arm_stats(self, persona: str, content_type: str):
        """Load historical performance for this persona/content_type.

        Reads the first ``PolicyWeights`` row for the pair and, if an arm
        matches its weights, overwrites that arm's count and cumulative
        reward in the currently active statistics.
        """
        with get_session() as ses:
            policy = ses.exec(
                select(PolicyWeights).where(
                    PolicyWeights.persona == persona,
                    PolicyWeights.content_type == content_type
                )
            ).first()

            if policy:
                # NOTE(review): the row aggregates every update for this pair,
                # yet the whole total is credited to the single matching arm.
                for arm in self.arms:
                    if self._arm_matches_policy(arm, policy):
                        self.arm_counts[arm.name] = policy.total_generations
                        # success_rate is an average, so scale it back to a sum
                        self.arm_rewards[arm.name] = policy.success_rate * policy.total_generations
                        break

    def _arm_matches_policy(self, arm: "BanditArm", policy: "PolicyWeights", tolerance: float = 0.05) -> bool:
        """Check if an arm matches the stored policy within tolerance.

        Only the four retrieval weights are compared; each must differ by
        strictly less than ``tolerance``. Temperatures are not considered.
        """
        return (
            abs(arm.semantic_weight - policy.semantic_weight) < tolerance and
            abs(arm.bm25_weight - policy.bm25_weight) < tolerance and
            abs(arm.quality_weight - policy.quality_weight) < tolerance and
            abs(arm.freshness_weight - policy.freshness_weight) < tolerance
        )

    def update_reward(self,
                      arm: "BanditArm",
                      reward: float,
                      persona: str,
                      content_type: str,
                      script_id: int):
        """Update arm performance with new reward signal.

        Args:
            arm: The arm that produced the generation(s).
            reward: Observed reward for that arm.
            persona: Persona the reward belongs to.
            content_type: Content type the reward belongs to.
            script_id: Accepted for interface compatibility; not used here.
        """
        # Record the reward against the statistics of its own context
        self._activate_context(persona, content_type)

        # Update in-memory stats
        self.arm_counts[arm.name] += 1
        self.arm_rewards[arm.name] += reward

        # Update database policy
        self._update_policy_weights(arm, reward, persona, content_type)

        print(f"📈 Updated {arm.name}: reward={reward:.3f}, avg={self.arm_rewards[arm.name]/self.arm_counts[arm.name]:.3f}")

    def _update_policy_weights(self,
                               arm: "BanditArm",
                               reward: float,
                               persona: str,
                               content_type: str):
        """Update policy weights in database.

        Creates the ``PolicyWeights`` row for the pair from ``arm`` if none
        exists; otherwise folds ``reward`` into ``success_rate`` with an
        exponential moving average and, when the reward beats that average,
        nudges the stored weights and temperatures toward ``arm``.
        """
        with get_session() as ses:
            policy = ses.exec(
                select(PolicyWeights).where(
                    PolicyWeights.persona == persona,
                    PolicyWeights.content_type == content_type
                )
            ).first()

            if not policy:
                # Create new policy
                policy = PolicyWeights(
                    persona=persona,
                    content_type=content_type,
                    semantic_weight=arm.semantic_weight,
                    bm25_weight=arm.bm25_weight,
                    quality_weight=arm.quality_weight,
                    freshness_weight=arm.freshness_weight,
                    temp_low=arm.temp_low,
                    temp_mid=arm.temp_mid,
                    temp_high=arm.temp_high,
                    total_generations=1,
                    success_rate=reward
                )
            else:
                # Update existing policy with exponential moving average
                alpha = 0.1  # Learning rate
                policy.success_rate = (1 - alpha) * policy.success_rate + alpha * reward
                policy.total_generations += 1

                # If this arm is performing well, shift weights toward it
                if reward > policy.success_rate:
                    shift = 0.05  # Small shift toward better performing arm
                    policy.semantic_weight = (1 - shift) * policy.semantic_weight + shift * arm.semantic_weight
                    policy.bm25_weight = (1 - shift) * policy.bm25_weight + shift * arm.bm25_weight
                    policy.quality_weight = (1 - shift) * policy.quality_weight + shift * arm.quality_weight
                    policy.freshness_weight = (1 - shift) * policy.freshness_weight + shift * arm.freshness_weight

                    policy.temp_low = (1 - shift) * policy.temp_low + shift * arm.temp_low
                    policy.temp_mid = (1 - shift) * policy.temp_mid + shift * arm.temp_mid
                    policy.temp_high = (1 - shift) * policy.temp_high + shift * arm.temp_high

            policy.updated_at = datetime.utcnow()
            ses.add(policy)
            ses.commit()

    def calculate_reward(self, script_id: int) -> float:
        """Calculate reward signal from script performance.

        Combines the auto-score (used only when its confidence exceeds 0.5)
        and the human rating (when at least one rating exists) into a
        confidence-weighted average.

        Args:
            script_id: Primary key of the script to score.

        Returns:
            The weighted reward, or 0.5 when neither signal is available.
        """
        reward_components = []

        with get_session() as ses:
            # Get auto-score
            auto_score = ses.exec(
                select(AutoScore).where(AutoScore.script_id == script_id)
            ).first()

            # NOTE(review): the None guards below assume these columns may be
            # NULL — confirm against the model definitions.
            if auto_score and (auto_score.confidence or 0.0) > 0.5:
                # Weighted composite of auto-scores (weights sum to 1.0)
                auto_reward = (
                    0.35 * auto_score.overall +
                    0.20 * auto_score.hook +
                    0.15 * auto_score.originality +
                    0.15 * auto_score.style_fit +
                    0.15 * auto_score.safety
                ) / 5.0  # Normalize to 0-1

                reward_components.append(('auto', auto_reward, auto_score.confidence))

            # Get human ratings
            script = ses.get(Script, script_id)
            ratings_count = (script.ratings_count or 0) if script else 0
            if ratings_count > 0 and script.score_overall is not None:
                human_reward = script.score_overall / 5.0  # Normalize to 0-1
                confidence = min(1.0, ratings_count / 3.0)  # More ratings = higher confidence
                reward_components.append(('human', human_reward, confidence))

        if not reward_components:
            return 0.5  # Neutral reward if no scores available

        # Weighted average of reward components by confidence
        total_weight = sum(confidence for _, _, confidence in reward_components)
        weighted_reward = sum(
            reward * confidence for _, reward, confidence in reward_components
        ) / total_weight

        return weighted_reward

class PolicyLearner:
    """Main interface for policy learning.

    Wraps a single ``PolicyBandit`` and offers batch-level learning, policy
    selection, and a periodic cycle over recently generated scripts.
    """

    def __init__(self):
        """Create the learner with a fresh ``PolicyBandit``."""
        self.bandit = PolicyBandit()

    def learn_from_generation_batch(self,
                                    persona: str,
                                    content_type: str,
                                    generated_script_ids: List[int],
                                    selected_arm: "BanditArm"):
        """Learn from a batch of generated scripts.

        Args:
            persona: Persona the batch was generated for.
            content_type: Content type of the batch.
            generated_script_ids: IDs of the scripts in the batch; an empty
                list makes this a no-op.
            selected_arm: The arm credited with the batch.
        """
        if not generated_script_ids:
            return

        # Calculate average reward from the batch
        rewards = [self.bandit.calculate_reward(sid) for sid in generated_script_ids]
        avg_reward = sum(rewards) / len(rewards)

        # Update bandit with average performance
        self.bandit.update_reward(
            selected_arm,
            avg_reward,
            persona,
            content_type,
            generated_script_ids[0]  # Representative script ID
        )

        # A separator keeps the content type and the reward from running together
        print(f"🧠 Policy learning: {persona}/{content_type} → {avg_reward:.3f} reward")

    def get_optimized_policy(self, persona: str, content_type: str) -> "BanditArm":
        """Get the current best policy for this persona/content_type.

        Delegates to ``PolicyBandit.select_arm``, so the result may be an
        exploratory (random) arm rather than the best-scoring one.
        """
        return self.bandit.select_arm(persona, content_type)

    def run_learning_cycle(self):
        """Run a learning cycle on recent generations.

        Groups AI-sourced scripts created in the last 24 hours by
        (creator, content_type) and learns from every group that has at
        least three scripts.
        """
        print("🔄 Starting policy learning cycle...")

        # Find recent AI-generated scripts by persona/content_type
        cutoff = datetime.utcnow() - timedelta(hours=24)

        # Group by persona/content_type while the session is still open, so
        # no attribute is read from an object whose session has been closed.
        groups = {}
        with get_session() as ses:
            recent_scripts = ses.exec(
                select(Script).where(
                    Script.created_at >= cutoff,
                    Script.source == "ai"
                )
            )
            for script in recent_scripts:
                key = (script.creator, script.content_type)
                groups.setdefault(key, []).append(script.id)

        # Need minimum batch size
        eligible = {key: ids for key, ids in groups.items() if len(ids) >= 3}
        if not eligible:
            return

        # For now, assume they used the balanced policy
        # In practice, you'd track which policy was used for each generation
        balanced_arm = next(arm for arm in self.bandit.arms if arm.name == "balanced")

        # Learn from each group
        for (persona, content_type), script_ids in eligible.items():
            self.learn_from_generation_batch(persona, content_type, script_ids, balanced_arm)

def run_policy_learning():
    """Entry point: build a ``PolicyLearner`` and run one learning cycle."""
    PolicyLearner().run_learning_cycle()

# Run one policy-learning cycle when this file is executed as a script;
# importing the module does not trigger it.
if __name__ == "__main__":
    run_policy_learning()