Spaces:
Sleeping
Sleeping
"""
Integration layer between the existing system and the new RAG capabilities.

Shows how to plug the enhanced system into the current workflow.
"""
| from typing import List, Dict, Any, Optional | |
| import json | |
| from sqlmodel import Session | |
| from datetime import datetime | |
| from models import Script, Embedding, AutoScore, PolicyWeights | |
| from db import get_session, init_db | |
| from deepseek_client import chat, get_api_key | |
| from rag_retrieval import RAGRetriever | |
| from auto_scorer import AutoScorer, ScriptReranker | |
| from bandit_learner import PolicyLearner | |
class EnhancedScriptGenerator:
    """
    Enhanced version of script generation with RAG + policy learning.

    Drop-in replacement for the existing generate_scripts function. The
    pipeline retrieves reference material, generates drafts at several
    policy-selected temperatures, checks them against the retrieved
    references, stores and auto-scores them, reranks them, and finally
    reports the batch back to the policy learner.
    """

    # Maximum number of reference snippets embedded in the user prompt.
    _MAX_PROMPT_REFS = 8
    # Similarity threshold passed to the retriever's copying detection.
    _COPY_SIMILARITY_THRESHOLD = 0.92

    def __init__(self):
        """
        Create the retriever, scorer, reranker and policy learner.

        Raises:
            ValueError: if ``get_api_key()`` returns a falsy value.
        """
        self.retriever = RAGRetriever()
        self.scorer = AutoScorer()
        self.reranker = ScriptReranker()
        self.policy_learner = PolicyLearner()

        # Verify we have API key
        if not get_api_key():
            raise ValueError("DeepSeek API key not found!")

    def generate_scripts_enhanced(self,
                                  persona: str,
                                  boundaries: str,
                                  content_type: str,
                                  tone: str,
                                  manual_refs: Optional[List[str]] = None,
                                  n: int = 6) -> List[Dict]:
        """
        Enhanced script generation with:
        1. RAG-based reference selection
        2. Policy-optimized parameters
        3. Auto-scoring and reranking
        4. Online learning feedback

        Args:
            persona: Persona name; also stored as the script's ``creator``.
            boundaries: Content boundaries inserted into the user prompt.
            content_type: Content type used for policy and retrieval lookup.
            tone: Tone description used in the retrieval query and the prompt.
            manual_refs: Optional caller-supplied reference snippets; they are
                placed ahead of the retrieved references in the prompt.
            n: Number of scripts requested.

        Returns:
            Script dicts in the order produced by the reranker, each carrying
            the script fields plus ``_enhanced_score``, ``_script_id`` and
            ``_compliance``. Drafts that could not be saved do not appear.
        """
        manual = list(manual_refs or [])
        print(f"π€ Enhanced generation: {persona} Γ {content_type} Γ {n} scripts")

        # Step 1: Get optimized policy for this persona/content_type
        policy_arm = self.policy_learner.get_optimized_policy(persona, content_type)

        # Step 2: Build dynamic few-shot pack using RAG
        query_context = f"{persona} {content_type} {tone}"
        few_shot_pack = self.retriever.build_dynamic_few_shot_pack(
            persona=persona,
            content_type=content_type,
            query_context=query_context
        )

        # Step 3: Combine RAG refs with manual refs
        rag_refs = (
            few_shot_pack.get('best_hooks', []) +
            few_shot_pack.get('best_beats', []) +
            few_shot_pack.get('best_captions', [])
        )
        all_refs = manual + rag_refs
        print(f"π Using {len(rag_refs)} RAG refs + {len(manual)} manual refs")

        # Step 4: Enhanced generation with policy-optimized parameters
        drafts = self._generate_with_policy(
            persona=persona,
            boundaries=boundaries,
            content_type=content_type,
            tone=tone,
            refs=all_refs,
            policy_arm=policy_arm,
            n=n,
            few_shot_pack=few_shot_pack
        )

        # Step 5: Anti-copying detection and cleanup
        print("π‘οΈ Checking for similarity to reference content...")
        # Only the retrieved (RAG) references are used for copying detection.
        reference_texts = rag_refs
        cleaned_drafts = []
        for draft in drafts:
            detection_results = self.retriever.detect_copying(
                generated_content=draft,
                reference_texts=reference_texts,
                similarity_threshold=self._COPY_SIMILARITY_THRESHOLD
            )
            if not detection_results['is_copying']:
                cleaned_drafts.append(draft)
                continue

            # The model may return a null/non-string title; never crash on it.
            title_preview = str(draft.get('title') or 'Untitled')[:30]
            print(f"β οΈ Anti-copy triggered for draft: {title_preview}")
            print(f" Max similarity: {detection_results['max_similarity']:.3f}")
            # Auto-rewrite similar content
            cleaned_drafts.append(
                self.retriever.auto_rewrite_similar_content(
                    generated_content=draft,
                    detection_results=detection_results
                )
            )

        # Step 6: Auto-score all generated drafts
        script_ids = self._save_drafts_to_db(cleaned_drafts, persona, content_type, tone)
        auto_scores = [self.scorer.score_and_store(sid) for sid in script_ids]
        print(f"π Auto-scored {len(auto_scores)} drafts")

        # Step 7: Rerank by composite score
        ranked_script_ids = self.reranker.rerank_scripts(script_ids)

        # Step 8: Policy learning feedback
        self.policy_learner.learn_from_generation_batch(
            persona=persona,
            content_type=content_type,
            generated_script_ids=script_ids,
            selected_arm=policy_arm
        )

        # Return drafts in ranked order with scores
        return self._format_enhanced_results(ranked_script_ids, cleaned_drafts)

    @staticmethod
    def _build_system_prompt(few_shot_pack: Dict, count: int) -> str:
        """Build the system prompt asking for a JSON array of ``count`` briefs."""
        lines = [
            "You write Instagram-compliant, suggestive-but-not-explicit Reels briefs.",
            f"STYLE CONTEXT: {few_shot_pack.get('style_card', '')}",
            "BEST PATTERNS TO EMULATE:",
            f"Hooks: {json.dumps(few_shot_pack.get('best_hooks', []))}",
            f"Beats: {json.dumps(few_shot_pack.get('best_beats', []))}",
            f"Captions: {json.dumps(few_shot_pack.get('best_captions', []))}",
            f"AVOID THESE PATTERNS: {json.dumps(few_shot_pack.get('negative_patterns', []))}",
            "Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms.",
            f"Return ONLY JSON: an array of length {count}, each with "
            "{title,hook,beats,voiceover,caption,hashtags,cta}.",
        ]
        return "\n".join(lines) + "\n"

    @staticmethod
    def _build_user_prompt(persona: str,
                           boundaries: str,
                           content_type: str,
                           tone: str,
                           few_shot_pack: Dict,
                           refs: List[str],
                           count: int) -> str:
        """Build the user prompt for one batch of ``count`` scripts."""
        # Limit to the top refs; the limit is applied here in code so that no
        # source comment leaks into the prompt text.
        ref_lines = [
            f"- {ref}"
            for ref in refs[:EnhancedScriptGenerator._MAX_PROMPT_REFS]
        ]
        lines = [
            "",
            f"Persona: {persona}",
            f"Boundaries: {boundaries}",
            f"Content type: {content_type} | Tone: {tone}",
            f"Constraints: {json.dumps(few_shot_pack.get('constraints', {}))}",
            "Reference snippets (inspire, don't copy):",
            *ref_lines,
            f"Generate {count} unique variations. JSON array ONLY.",
        ]
        return "\n".join(lines) + "\n"

    @staticmethod
    def _extract_json_array(text: str) -> List[Dict]:
        """
        Parse the outermost ``[...]`` span of ``text`` as JSON.

        Returns the dict items of the parsed array (other items are dropped
        because later stages call ``.get`` on every draft), or an empty list
        when no bracketed span exists. Invalid JSON inside the span raises
        ``json.JSONDecodeError``.
        """
        start = text.find("[")
        end = text.rfind("]")
        if start < 0 or end <= start:
            return []
        parsed = json.loads(text[start:end + 1])
        return [item for item in parsed if isinstance(item, dict)]

    def _generate_with_policy(self,
                              persona: str,
                              boundaries: str,
                              content_type: str,
                              tone: str,
                              refs: List[str],
                              policy_arm: Any,  # BanditArm
                              n: int,
                              few_shot_pack: Dict) -> List[Dict]:
        """
        Generate scripts using policy-optimized parameters.

        The request is split across the arm's three temperatures
        (``temp_low``, ``temp_mid``, ``temp_high``). Each batch gets its own
        prompts so the system and user messages agree on the batch size. The
        last temperature picks up whatever is still missing, including the
        share of earlier batches that failed.

        Returns:
            At most ``n`` draft dicts; fewer (possibly none) when batches fail
            or return unusable output.
        """
        if n <= 0:
            return []

        temps = [policy_arm.temp_low, policy_arm.temp_mid, policy_arm.temp_high]
        scripts_per_temp = max(1, n // len(temps))
        last_index = len(temps) - 1

        variants: List[Dict] = []
        for i, temp in enumerate(temps):
            remaining = n - len(variants)
            if remaining <= 0:
                break
            # Never ask for more than is still needed (avoids wasted calls
            # for small n); the last batch gets the full remainder.
            batch_size = remaining if i == last_index else min(scripts_per_temp, remaining)

            messages = [
                {"role": "system",
                 "content": self._build_system_prompt(few_shot_pack, batch_size)},
                {"role": "user",
                 "content": self._build_user_prompt(
                     persona, boundaries, content_type, tone,
                     few_shot_pack, refs, batch_size)},
            ]
            try:
                out = chat(messages, temperature=temp)
                batch_variants = self._extract_json_array(out)
            except Exception as e:
                # Best-effort boundary: one failed batch must not abort the
                # others, so report it and move on.
                print(f"β Generation failed at temp={temp}: {e}")
                continue

            if batch_variants:
                variants.extend(batch_variants[:batch_size])
                print(f"β¨ Generated {len(batch_variants)} scripts at temp={temp}")

        return variants[:n]  # Ensure we don't exceed requested count

    def _save_drafts_to_db(self,
                           drafts: List[Dict],
                           persona: str,
                           content_type: str,
                           tone: str) -> List[int]:
        """
        Save generated drafts to database and return script IDs.

        Each draft is committed on its own, followed by its embeddings. A
        failure is reported and rolled back so the session stays usable for
        the remaining drafts. A script whose row was committed keeps its ID in
        the result even if storing its embeddings fails afterwards.
        """
        # Local import, as in the original code; done once instead of per draft.
        from compliance import score_script, blob_from

        script_ids = []
        with get_session() as ses:
            for draft in drafts:
                try:
                    # Calculate basic compliance
                    content_blob = blob_from(draft)
                    compliance_level, _ = score_script(content_blob)

                    script = Script(
                        creator=persona,
                        content_type=content_type,
                        tone=tone,
                        title=draft.get("title", "Generated Script"),
                        hook=draft.get("hook", ""),
                        beats=draft.get("beats", []),
                        voiceover=draft.get("voiceover", ""),
                        caption=draft.get("caption", ""),
                        hashtags=draft.get("hashtags", []),
                        cta=draft.get("cta", ""),
                        compliance=compliance_level,
                        source="ai"
                    )
                    ses.add(script)
                    ses.commit()
                    ses.refresh(script)
                    script_ids.append(script.id)

                    # Generate embeddings for new script and commit them with
                    # this draft, so a later rollback cannot discard them.
                    for embedding in self.retriever.generate_embeddings(script):
                        ses.add(embedding)
                    ses.commit()
                except Exception as e:
                    # A failed flush/commit leaves the session unusable until
                    # it is rolled back.
                    ses.rollback()
                    print(f"β Failed to save draft: {e}")
                    continue
        return script_ids

    def _format_enhanced_results(self,
                                 ranked_script_ids: List[tuple],
                                 original_drafts: List[Dict]) -> List[Dict]:
        """
        Format results with ranking and score information.

        Args:
            ranked_script_ids: ``(script_id, composite_score)`` pairs in the
                order they should be returned.
            original_drafts: Accepted for interface compatibility; the output
                is built from the stored Script rows, not from these drafts.

        Returns:
            One dict per script found in the database, in ranked order; IDs
            with no stored row are skipped.
        """
        results = []
        with get_session() as ses:
            for script_id, composite_score in ranked_script_ids:
                script = ses.get(Script, script_id)
                if not script:
                    continue
                # Convert back to the expected format
                results.append({
                    "title": script.title,
                    "hook": script.hook,
                    "beats": script.beats,
                    "voiceover": script.voiceover,
                    "caption": script.caption,
                    "hashtags": script.hashtags,
                    "cta": script.cta,
                    # Enhanced metadata
                    "_enhanced_score": round(composite_score, 3),
                    "_script_id": script_id,
                    "_compliance": script.compliance,
                })
        return results
# Backward compatibility wrapper
def generate_scripts_rag(persona: str,
                         boundaries: str,
                         content_type: str,
                         tone: str,
                         refs: List[str],
                         n: int = 6) -> List[Dict]:
    """
    Drop-in replacement for the existing generate_scripts function.

    Builds a fresh EnhancedScriptGenerator on every call and forwards all
    arguments to ``generate_scripts_enhanced``; ``refs`` is passed as the
    manual reference list. Returns whatever that method returns.
    """
    request = {
        "persona": persona,
        "boundaries": boundaries,
        "content_type": content_type,
        "tone": tone,
        "manual_refs": refs,
        "n": n,
    }
    return EnhancedScriptGenerator().generate_scripts_enhanced(**request)
def setup_rag_system():
    """
    One-time setup to initialize the RAG system.

    Runs three steps in order, printing progress after each: initialize the
    database, index all existing scripts, and auto-score the scripts from
    the last week.
    """
    print("π§ Setting up RAG system...")

    # Step 1: initialize database with new tables
    init_db()
    print("β Database initialized")

    # Step 2: generate embeddings for existing scripts
    from rag_retrieval import index_all_scripts
    index_all_scripts()
    print("β Existing scripts indexed")

    # Step 3: auto-score recent scripts
    one_week_in_hours = 24 * 7
    weekly_scores = AutoScorer().batch_score_recent(hours=one_week_in_hours)
    print(f"β Auto-scored {len(weekly_scores)} recent scripts")

    print("π RAG system setup complete!")
if __name__ == "__main__":
    # Demo run: prepare the RAG system, then generate a small sample batch.
    setup_rag_system()

    demo_request = dict(
        persona="Anya",
        boundaries="Instagram-safe; suggestive but not explicit",
        content_type="thirst-trap",
        tone="playful, flirty",
        manual_refs=["Just a quick workout session", "Getting ready for the day"],
        n=3,
    )
    generator = EnhancedScriptGenerator()
    results = generator.generate_scripts_enhanced(**demo_request)

    # Print a short summary line plus a hook preview for every result.
    print(f"\n㪠Generated {len(results)} enhanced scripts:")
    for position, entry in enumerate(results, start=1):
        score = entry.get('_enhanced_score', 0)
        compliance = entry.get('_compliance', 'unknown')
        title = entry['title']
        hook_preview = entry['hook'][:60]
        print(f"{position}. {title} (score: {score}, compliance: {compliance})")
        print(f" Hook: {hook_preview}...")