File size: 13,468 Bytes
fd88516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
"""
Integration layer between the existing system and new RAG capabilities
Shows how to plug the enhanced system into the current workflow
"""

from typing import List, Dict, Any, Optional
import json
from sqlmodel import Session
from datetime import datetime

from models import Script, Embedding, AutoScore, PolicyWeights
from db import get_session, init_db
from deepseek_client import chat, get_api_key
from rag_retrieval import RAGRetriever
from auto_scorer import AutoScorer, ScriptReranker
from bandit_learner import PolicyLearner

class EnhancedScriptGenerator:
    """
    Enhanced version of script generation with RAG + policy learning.

    Drop-in replacement for the existing generate_scripts function. A single
    run fetches a policy arm and a few-shot pack, generates drafts with the
    chat model, rewrites drafts flagged as copies of the retrieved references,
    stores and auto-scores them, reranks them and reports the batch back to
    the policy learner.
    """

    # Maximum number of reference snippets listed in the user prompt.
    MAX_PROMPT_REFS = 8
    # Threshold handed to RAGRetriever.detect_copying for every draft.
    COPY_SIMILARITY_THRESHOLD = 0.92

    def __init__(self):
        """Build the collaborating components and verify an API key exists.

        Raises:
            ValueError: if get_api_key() returns a falsy value.
        """
        self.retriever = RAGRetriever()
        self.scorer = AutoScorer()
        self.reranker = ScriptReranker()
        self.policy_learner = PolicyLearner()

        # Verify we have API key
        if not get_api_key():
            raise ValueError("DeepSeek API key not found!")

    def generate_scripts_enhanced(self,
                                persona: str,
                                boundaries: str,
                                content_type: str,
                                tone: str,
                                manual_refs: Optional[List[str]] = None,
                                n: int = 6) -> List[Dict]:
        """
        Enhanced script generation with:
        1. RAG-based reference selection
        2. Policy-optimized parameters
        3. Auto-scoring and reranking
        4. Online learning feedback

        Args:
            persona: Creator persona; also stored as the script's creator.
            boundaries: Content boundaries, passed verbatim into the prompt.
            content_type: Content category used for policy and retrieval.
            tone: Requested tone, passed into the prompt and stored.
            manual_refs: Optional caller-supplied reference snippets; they are
                listed before the retrieved ones. None means no manual refs.
            n: Number of scripts requested.

        Returns:
            Script dicts in reranked order, each carrying the script fields
            plus "_enhanced_score", "_script_id" and "_compliance".
        """
        manual_refs = manual_refs or []

        print(f"🤖 Enhanced generation: {persona} × {content_type} × {n} scripts")

        # Step 1: Get optimized policy for this persona/content_type
        policy_arm = self.policy_learner.get_optimized_policy(persona, content_type)

        # Step 2: Build dynamic few-shot pack using RAG
        few_shot_pack = self.retriever.build_dynamic_few_shot_pack(
            persona=persona,
            content_type=content_type,
            query_context=f"{persona} {content_type} {tone}"
        )

        # Step 3: Combine RAG refs with manual refs (manual refs come first)
        rag_refs = (
            few_shot_pack.get('best_hooks', []) +
            few_shot_pack.get('best_beats', []) +
            few_shot_pack.get('best_captions', [])
        )
        all_refs = manual_refs + rag_refs

        print(f"📚 Using {len(rag_refs)} RAG refs + {len(manual_refs)} manual refs")

        # Step 4: Enhanced generation with policy-optimized parameters
        drafts = self._generate_with_policy(
            persona=persona,
            boundaries=boundaries,
            content_type=content_type,
            tone=tone,
            refs=all_refs,
            policy_arm=policy_arm,
            n=n,
            few_shot_pack=few_shot_pack
        )

        # Step 5: Anti-copying detection and cleanup.
        # Only the retrieved (RAG) references are compared against.
        print("🛡️ Checking for similarity to reference content...")
        cleaned_drafts = self._rewrite_copied_drafts(drafts, rag_refs)

        # Step 6: Auto-score all generated drafts
        script_ids = self._save_drafts_to_db(cleaned_drafts, persona, content_type, tone)
        auto_scores = [self.scorer.score_and_store(sid) for sid in script_ids]

        print(f"📊 Auto-scored {len(auto_scores)} drafts")

        # Step 7: Rerank by composite score
        ranked_script_ids = self.reranker.rerank_scripts(script_ids)

        # Step 8: Policy learning feedback
        self.policy_learner.learn_from_generation_batch(
            persona=persona,
            content_type=content_type,
            generated_script_ids=script_ids,
            selected_arm=policy_arm
        )

        # Return drafts in ranked order with scores
        return self._format_enhanced_results(ranked_script_ids, cleaned_drafts)

    def _rewrite_copied_drafts(self,
                               drafts: List[Dict],
                               reference_texts: List[str]) -> List[Dict]:
        """Return the drafts with every copy-flagged draft replaced by its rewrite.

        Each draft is checked with the retriever's detect_copying; flagged
        drafts are passed to auto_rewrite_similar_content, the rest are kept
        unchanged. Order is preserved.
        """
        cleaned_drafts = []

        for draft in drafts:
            detection_results = self.retriever.detect_copying(
                generated_content=draft,
                reference_texts=reference_texts,
                similarity_threshold=self.COPY_SIMILARITY_THRESHOLD
            )

            if not detection_results['is_copying']:
                cleaned_drafts.append(draft)
                continue

            # A missing or non-string title must not break the log line.
            title = str(draft.get('title') or 'Untitled')[:30]
            print(f"⚠️ Anti-copy triggered for draft: {title}")
            print(f"   Max similarity: {detection_results['max_similarity']:.3f}")

            # Auto-rewrite similar content
            cleaned_drafts.append(
                self.retriever.auto_rewrite_similar_content(
                    generated_content=draft,
                    detection_results=detection_results
                )
            )

        return cleaned_drafts

    @classmethod
    def _build_prompts(cls,
                       persona: str,
                       boundaries: str,
                       content_type: str,
                       tone: str,
                       refs: List[str],
                       few_shot_pack: Dict,
                       batch_size: int) -> tuple:
        """Build the (system, user) prompt pair for one batch of batch_size scripts.

        Both prompts ask for the same number of scripts, and at most
        MAX_PROMPT_REFS reference snippets are listed.
        """
        # Built outside the f-string so no source comment can leak into the prompt.
        ref_lines = "\n".join(f"- {r}" for r in refs[:cls.MAX_PROMPT_REFS])

        # Enhanced system prompt with few-shot pack context
        system = f"""You write Instagram-compliant, suggestive-but-not-explicit Reels briefs.

STYLE CONTEXT: {few_shot_pack.get('style_card', '')}

BEST PATTERNS TO EMULATE:
Hooks: {json.dumps(few_shot_pack.get('best_hooks', []))}
Beats: {json.dumps(few_shot_pack.get('best_beats', []))}
Captions: {json.dumps(few_shot_pack.get('best_captions', []))}

AVOID THESE PATTERNS: {json.dumps(few_shot_pack.get('negative_patterns', []))}

Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms.
Return ONLY JSON: an array of length {batch_size}, each with {{title,hook,beats,voiceover,caption,hashtags,cta}}.
"""

        user = f"""
Persona: {persona}
Boundaries: {boundaries}
Content type: {content_type} | Tone: {tone}
Constraints: {json.dumps(few_shot_pack.get('constraints', {}))}

Reference snippets (inspire, don't copy):
{ref_lines}

Generate {batch_size} unique variations. JSON array ONLY.
"""
        return system, user

    @staticmethod
    def _extract_json_array(raw: str) -> List[Dict]:
        """Parse the outermost [...] span of raw and keep only its dict items.

        Returns an empty list when raw contains no bracketed span.

        Raises:
            json.JSONDecodeError: if the bracketed span is not valid JSON.
        """
        start = raw.find("[")
        end = raw.rfind("]")
        if start < 0 or end <= start:
            return []
        parsed = json.loads(raw[start:end + 1])
        # Downstream code calls .get() on every draft, so drop non-dict items.
        return [item for item in parsed if isinstance(item, dict)]

    def _generate_with_policy(self,
                            persona: str,
                            boundaries: str,
                            content_type: str,
                            tone: str,
                            refs: List[str],
                            policy_arm: Any,  # BanditArm
                            n: int,
                            few_shot_pack: Dict) -> List[Dict]:
        """Generate scripts using policy-optimized parameters.

        One chat call is made per temperature of the arm (temp_low, temp_mid,
        temp_high). Each call asks for an even share of n; the last call asks
        for whatever is still missing, so it also compensates for earlier
        batches that failed or came back short. A failed batch is logged and
        skipped rather than aborting the run.

        Returns:
            At most n draft dicts; fewer when batches fail.
        """
        temps = [policy_arm.temp_low, policy_arm.temp_mid, policy_arm.temp_high]
        scripts_per_temp = max(1, n // len(temps))
        last_index = len(temps) - 1
        variants: List[Dict] = []

        for i, temp in enumerate(temps):
            remaining = n - len(variants)
            if remaining <= 0:
                # Enough drafts already; don't spend another API call.
                break

            if i == last_index:  # Last batch gets remainder
                batch_size = remaining
            else:
                batch_size = min(scripts_per_temp, remaining)

            system, user = self._build_prompts(
                persona=persona,
                boundaries=boundaries,
                content_type=content_type,
                tone=tone,
                refs=refs,
                few_shot_pack=few_shot_pack,
                batch_size=batch_size
            )

            try:
                out = chat([
                    {"role": "system", "content": system},
                    {"role": "user", "content": user}
                ], temperature=temp)
                batch_variants = self._extract_json_array(out)
            except Exception as e:
                # Best-effort boundary: one bad batch must not lose the others.
                print(f"❌ Generation failed at temp={temp}: {e}")
                continue

            if not batch_variants:
                print(f"⚠️ No usable scripts in response at temp={temp}")
                continue

            variants.extend(batch_variants[:batch_size])
            print(f"✨ Generated {len(batch_variants)} scripts at temp={temp}")

        return variants[:n]  # Ensure we don't exceed requested count

    def _save_drafts_to_db(self,
                          drafts: List[Dict],
                          persona: str,
                          content_type: str,
                          tone: str) -> List[int]:
        """Save generated drafts to database and return script IDs.

        Each draft is committed on its own, followed by a separate commit for
        its embeddings. A failure rolls the session back and moves on to the
        next draft, so one bad draft cannot block the rest.

        Returns:
            IDs of the scripts that were stored, in draft order. A script whose
            embeddings failed to store is still included.
        """
        # Imported once per call instead of once per draft.
        from compliance import score_script, blob_from

        script_ids = []

        # NOTE(review): assumes get_session() yields a SQLModel/SQLAlchemy
        # Session (rollback() is called on it below) - confirm against db.py.
        with get_session() as ses:
            for draft in drafts:
                try:
                    # Calculate basic compliance
                    compliance_level, _ = score_script(blob_from(draft))

                    script = Script(
                        creator=persona,
                        content_type=content_type,
                        tone=tone,
                        title=draft.get("title", "Generated Script"),
                        hook=draft.get("hook", ""),
                        beats=draft.get("beats", []),
                        voiceover=draft.get("voiceover", ""),
                        caption=draft.get("caption", ""),
                        hashtags=draft.get("hashtags", []),
                        cta=draft.get("cta", ""),
                        compliance=compliance_level,
                        source="ai"
                    )

                    ses.add(script)
                    ses.commit()
                    ses.refresh(script)
                    script_id = script.id
                except Exception as e:
                    # Without a rollback the session stays in a failed state
                    # and every later add/commit would raise as well.
                    ses.rollback()
                    print(f"❌ Failed to save draft: {e}")
                    continue

                script_ids.append(script_id)

                try:
                    # Generate embeddings for new script and persist them now,
                    # so a later draft's rollback cannot discard them.
                    for embedding in self.retriever.generate_embeddings(script):
                        ses.add(embedding)
                    ses.commit()
                except Exception as e:
                    ses.rollback()
                    print(f"❌ Failed to save embeddings for script {script_id}: {e}")

        return script_ids

    def _format_enhanced_results(self,
                               ranked_script_ids: List[tuple],
                               original_drafts: List[Dict]) -> List[Dict]:
        """Format results with ranking and score information.

        Args:
            ranked_script_ids: (script_id, composite_score) pairs in ranked order.
            original_drafts: Not used; results are read back from the database.
                Kept so the signature stays compatible.

        Returns:
            One dict per script found in the database, in the given order.
            IDs with no matching row are skipped.
        """
        results = []

        with get_session() as ses:
            for script_id, composite_score in ranked_script_ids:
                script = ses.get(Script, script_id)
                if not script:
                    continue

                # Convert back to the expected format
                results.append({
                    "title": script.title,
                    "hook": script.hook,
                    "beats": script.beats,
                    "voiceover": script.voiceover,
                    "caption": script.caption,
                    "hashtags": script.hashtags,
                    "cta": script.cta,
                    # Enhanced metadata
                    "_enhanced_score": round(composite_score, 3),
                    "_script_id": script_id,
                    "_compliance": script.compliance
                })

        return results

# Backward compatibility wrapper
def generate_scripts_rag(persona: str,
                        boundaries: str,
                        content_type: str,
                        tone: str,
                        refs: List[str],
                        n: int = 6) -> List[Dict]:
    """
    Function-style entry point with the signature of the existing
    generate_scripts function.

    Creates a fresh EnhancedScriptGenerator on every call and forwards all
    arguments to generate_scripts_enhanced, passing refs as manual_refs.
    """
    return EnhancedScriptGenerator().generate_scripts_enhanced(
        persona,
        boundaries,
        content_type,
        tone,
        manual_refs=refs,
        n=n,
    )

def setup_rag_system():
    """One-time setup to initialize the RAG system.

    Runs three steps in order, printing progress after each:
    init_db(), index_all_scripts(), and AutoScorer.batch_score_recent over
    the last 168 hours (one week).
    """
    print("🔧 Setting up RAG system...")

    # Initialize database with new tables
    init_db()
    print("✅ Database initialized")

    # Generate embeddings for existing scripts.
    # Imported here, after init_db(), exactly as before.
    from rag_retrieval import index_all_scripts
    index_all_scripts()
    print("✅ Existing scripts indexed")

    # Auto-score recent scripts
    scorer = AutoScorer()
    recent_scores = scorer.batch_score_recent(hours=24*7)  # Last week
    print(f"✅ Auto-scored {len(recent_scores)} recent scripts")

    print("🎉 RAG system setup complete!")

if __name__ == "__main__":
    # Demo run: initialise the RAG system, then generate a small sample batch
    setup_rag_system()

    # Arguments for the sample generation request
    demo_request = {
        "persona": "Anya",
        "boundaries": "Instagram-safe; suggestive but not explicit",
        "content_type": "thirst-trap",
        "tone": "playful, flirty",
        "manual_refs": ["Just a quick workout session", "Getting ready for the day"],
        "n": 3,
    }
    generator = EnhancedScriptGenerator()
    results = generator.generate_scripts_enhanced(**demo_request)

    # Print a ranked summary: title, score, compliance and a hook preview
    print(f"\n🎬 Generated {len(results)} enhanced scripts:")
    for i, script in enumerate(results, start=1):
        score = script.get('_enhanced_score', 0)
        compliance = script.get('_compliance', 'unknown')
        print(f"{i}. {script['title']} (score: {score}, compliance: {compliance})")
        print(f"   Hook: {script['hook'][:60]}...")