Spaces:
Sleeping
Sleeping
File size: 13,468 Bytes
fd88516 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
"""
Integration layer between the existing system and new RAG capabilities
Shows how to plug the enhanced system into the current workflow
"""
from typing import List, Dict, Any, Optional
import json
from sqlmodel import Session
from datetime import datetime
from models import Script, Embedding, AutoScore, PolicyWeights
from db import get_session, init_db
from deepseek_client import chat, get_api_key
from rag_retrieval import RAGRetriever
from auto_scorer import AutoScorer, ScriptReranker
from bandit_learner import PolicyLearner
class EnhancedScriptGenerator:
    """
    Enhanced version of script generation with RAG + policy learning.

    Drop-in replacement for the existing generate_scripts function. It combines:
    RAG-selected reference snippets, policy-selected sampling temperatures,
    anti-copy cleanup, auto-scoring / reranking, and policy-learner feedback.
    """

    # Maximum number of reference snippets quoted in the user prompt.
    MAX_PROMPT_REFS = 8
    # Similarity above which a draft is treated as copying a reference.
    SIMILARITY_THRESHOLD = 0.92

    def __init__(self):
        self.retriever = RAGRetriever()
        self.scorer = AutoScorer()
        self.reranker = ScriptReranker()
        self.policy_learner = PolicyLearner()
        # Verify we have API key
        if not get_api_key():
            raise ValueError("DeepSeek API key not found!")

    def generate_scripts_enhanced(self,
                                  persona: str,
                                  boundaries: str,
                                  content_type: str,
                                  tone: str,
                                  manual_refs: Optional[List[str]] = None,
                                  n: int = 6) -> List[Dict]:
        """
        Generate up to ``n`` scripts with RAG references and policy feedback.

        Pipeline:
        1. Ask the policy learner for an arm (temperatures) for persona/content_type.
        2. Build a dynamic few-shot pack via RAG.
        3. Combine manual refs (first) with RAG refs for the prompt.
        4. Generate drafts across the arm's low/mid/high temperatures.
        5. Auto-rewrite drafts flagged as too similar to the RAG references.
        6. Persist drafts and auto-score them.
        7. Rerank by composite score.
        8. Report the batch back to the policy learner.

        Args:
            persona: Creator persona name (also stored as ``Script.creator``).
            boundaries: Free-text content boundaries passed into the prompt.
            content_type: Content type label.
            tone: Desired tone.
            manual_refs: Optional extra reference snippets; placed before RAG refs.
            n: Number of scripts requested.

        Returns:
            Result dicts in ranked order, each with the script fields plus
            ``_enhanced_score``, ``_script_id`` and ``_compliance``. Drafts that
            could not be saved to the DB are not included.
        """
        manual_refs = list(manual_refs or [])
        print(f"🤖 Enhanced generation: {persona} × {content_type} × {n} scripts")

        # Step 1: Get optimized policy for this persona/content_type
        policy_arm = self.policy_learner.get_optimized_policy(persona, content_type)

        # Step 2: Build dynamic few-shot pack using RAG
        query_context = f"{persona} {content_type} {tone}"
        few_shot_pack = self.retriever.build_dynamic_few_shot_pack(
            persona=persona,
            content_type=content_type,
            query_context=query_context,
        )

        # Step 3: Combine RAG refs with manual refs
        rag_refs = (
            few_shot_pack.get('best_hooks', [])
            + few_shot_pack.get('best_beats', [])
            + few_shot_pack.get('best_captions', [])
        )
        all_refs = manual_refs + rag_refs
        print(f"📚 Using {len(rag_refs)} RAG refs + {len(manual_refs)} manual refs")

        # Step 4: Generation with policy-optimized parameters
        drafts = self._generate_with_policy(
            persona=persona,
            boundaries=boundaries,
            content_type=content_type,
            tone=tone,
            refs=all_refs,
            policy_arm=policy_arm,
            n=n,
            few_shot_pack=few_shot_pack,
        )

        # Step 5: Anti-copying detection and cleanup.
        # NOTE(review): only RAG refs are checked, although manual refs are also
        # shown to the model as "inspire, don't copy" - confirm this is intended.
        print("🛡️ Checking for similarity to reference content...")
        cleaned_drafts = [self._clean_draft(draft, rag_refs) for draft in drafts]

        # Step 6: Persist and auto-score all generated drafts
        script_ids = self._save_drafts_to_db(cleaned_drafts, persona, content_type, tone)
        auto_scores = [self.scorer.score_and_store(sid) for sid in script_ids]
        print(f"📊 Auto-scored {len(auto_scores)} drafts")

        # Step 7: Rerank by composite score
        ranked_script_ids = self.reranker.rerank_scripts(script_ids)

        # Step 8: Policy learning feedback
        self.policy_learner.learn_from_generation_batch(
            persona=persona,
            content_type=content_type,
            generated_script_ids=script_ids,
            selected_arm=policy_arm,
        )

        # Return drafts in ranked order with scores
        return self._format_enhanced_results(ranked_script_ids, cleaned_drafts)

    def _clean_draft(self, draft: Dict, reference_texts: List[str]) -> Dict:
        """Return ``draft``, or an auto-rewritten copy if it is too similar to a reference."""
        detection_results = self.retriever.detect_copying(
            generated_content=draft,
            reference_texts=reference_texts,
            similarity_threshold=self.SIMILARITY_THRESHOLD,
        )
        if not detection_results['is_copying']:
            return draft

        # `or` guards against an explicit None/empty title, which would break slicing.
        title = str(draft.get('title') or 'Untitled')
        print(f"⚠️ Anti-copy triggered for draft: {title[:30]}")
        print(f" Max similarity: {detection_results['max_similarity']:.3f}")
        return self.retriever.auto_rewrite_similar_content(
            generated_content=draft,
            detection_results=detection_results,
        )

    @staticmethod
    def _build_prompts(persona: str,
                       boundaries: str,
                       content_type: str,
                       tone: str,
                       refs: List[str],
                       few_shot_pack: Dict,
                       count: int,
                       max_refs: int = MAX_PROMPT_REFS) -> tuple:
        """
        Build the ``(system, user)`` chat prompts asking for exactly ``count`` scripts.

        Both prompts state the same count, so the model is never told two
        different array lengths. Only the first ``max_refs`` refs are quoted.
        """
        # Built outside the f-string: backslashes are not allowed in f-string
        # expressions before Python 3.12.
        refs_text = "\n".join(f"- {r}" for r in refs[:max_refs])
        system = (
            "You write Instagram-compliant, suggestive-but-not-explicit Reels briefs.\n"
            f"STYLE CONTEXT: {few_shot_pack.get('style_card', '')}\n"
            "BEST PATTERNS TO EMULATE:\n"
            f"Hooks: {json.dumps(few_shot_pack.get('best_hooks', []))}\n"
            f"Beats: {json.dumps(few_shot_pack.get('best_beats', []))}\n"
            f"Captions: {json.dumps(few_shot_pack.get('best_captions', []))}\n"
            f"AVOID THESE PATTERNS: {json.dumps(few_shot_pack.get('negative_patterns', []))}\n"
            "Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms.\n"
            f"Return ONLY JSON: an array of length {count}, each with "
            "{title,hook,beats,voiceover,caption,hashtags,cta}.\n"
        )
        user = (
            "\n"
            f"Persona: {persona}\n"
            f"Boundaries: {boundaries}\n"
            f"Content type: {content_type} | Tone: {tone}\n"
            f"Constraints: {json.dumps(few_shot_pack.get('constraints', {}))}\n"
            "Reference snippets (inspire, don't copy):\n"
            f"{refs_text}\n"
            f"Generate {count} unique variations. JSON array ONLY.\n"
        )
        return system, user

    @staticmethod
    def _parse_script_array(text: str) -> List[Dict]:
        """
        Extract the outermost ``[...]`` JSON array from model output.

        Returns only the object (dict) items; non-object items are dropped
        because downstream code calls ``.get`` on every draft. Returns ``[]``
        when no array span is present.

        Raises:
            ValueError: (``json.JSONDecodeError``) if the span is not valid JSON.
        """
        start = text.find("[")
        end = text.rfind("]")
        if start < 0 or end <= start:
            return []
        parsed = json.loads(text[start:end + 1])
        if not isinstance(parsed, list):
            return []
        return [item for item in parsed if isinstance(item, dict)]

    def _generate_with_policy(self,
                              persona: str,
                              boundaries: str,
                              content_type: str,
                              tone: str,
                              refs: List[str],
                              policy_arm: Any,  # BanditArm
                              n: int,
                              few_shot_pack: Dict) -> List[Dict]:
        """
        Generate up to ``n`` drafts, spread across the policy arm's temperatures.

        Each of ``temp_low``/``temp_mid``/``temp_high`` gets ``max(1, n // 3)``
        drafts, capped by what is still needed. The last temperature takes the
        remainder, including any shortfall from earlier failed batches. A
        failed batch is logged and skipped (best-effort).
        """
        temps = [policy_arm.temp_low, policy_arm.temp_mid, policy_arm.temp_high]
        per_temp = max(1, n // len(temps))
        variants: List[Dict] = []

        for i, temp in enumerate(temps):
            remaining = n - len(variants)
            if remaining <= 0:
                break
            is_last = i == len(temps) - 1
            batch_size = remaining if is_last else min(per_temp, remaining)

            system, user = self._build_prompts(
                persona, boundaries, content_type, tone, refs, few_shot_pack, batch_size
            )
            try:
                out = chat([
                    {"role": "system", "content": system},
                    {"role": "user", "content": user},
                ], temperature=temp)
                batch = self._parse_script_array(out)
            except Exception as e:
                print(f"❌ Generation failed at temp={temp}: {e}")
                continue

            variants.extend(batch[:batch_size])
            print(f"✨ Generated {len(batch)} scripts at temp={temp}")

        return variants[:n]  # Ensure we don't exceed requested count

    def _save_drafts_to_db(self,
                           drafts: List[Dict],
                           persona: str,
                           content_type: str,
                           tone: str) -> List[int]:
        """
        Save drafts as ``Script`` rows (with embeddings) and return their IDs.

        Each draft is committed independently. On failure the session is rolled
        back, so later drafts can still be saved, and the draft is skipped. A
        draft whose script row was committed keeps its ID even if embedding
        creation fails afterwards.
        """
        script_ids: List[int] = []
        with get_session() as ses:
            for draft in drafts:
                try:
                    # Calculate basic compliance
                    from compliance import score_script, blob_from
                    compliance_level, _ = score_script(blob_from(draft))

                    script = Script(
                        creator=persona,
                        content_type=content_type,
                        tone=tone,
                        title=draft.get("title", "Generated Script"),
                        hook=draft.get("hook", ""),
                        beats=draft.get("beats", []),
                        voiceover=draft.get("voiceover", ""),
                        caption=draft.get("caption", ""),
                        hashtags=draft.get("hashtags", []),
                        cta=draft.get("cta", ""),
                        compliance=compliance_level,
                        source="ai",
                    )
                    ses.add(script)
                    ses.commit()
                    ses.refresh(script)
                    script_ids.append(script.id)

                    # Generate embeddings for new script. Commit them now so a
                    # rollback for a later draft cannot discard them.
                    for embedding in self.retriever.generate_embeddings(script):
                        ses.add(embedding)
                    ses.commit()
                except Exception as e:
                    # A failed flush/commit leaves the session unusable until rolled back.
                    ses.rollback()
                    print(f"❌ Failed to save draft: {e}")
        return script_ids

    def _format_enhanced_results(self,
                                 ranked_script_ids: List[tuple],
                                 original_drafts: List[Dict]) -> List[Dict]:
        """
        Reload ranked scripts from the DB and format them as result dicts.

        Args:
            ranked_script_ids: ``(script_id, composite_score)`` pairs in ranked order.
            original_drafts: Unused; kept for signature compatibility. Results
                are rebuilt from the persisted ``Script`` rows.

        Returns:
            One dict per script still present in the DB, in the given order.
        """
        results: List[Dict] = []
        with get_session() as ses:
            for script_id, composite_score in ranked_script_ids:
                script = ses.get(Script, script_id)
                if script is None:
                    continue
                results.append({
                    "title": script.title,
                    "hook": script.hook,
                    "beats": script.beats,
                    "voiceover": script.voiceover,
                    "caption": script.caption,
                    "hashtags": script.hashtags,
                    "cta": script.cta,
                    # Enhanced metadata
                    "_enhanced_score": round(composite_score, 3),
                    "_script_id": script_id,
                    "_compliance": script.compliance,
                })
        return results
# Backward compatibility wrapper
def generate_scripts_rag(persona: str,
                         boundaries: str,
                         content_type: str,
                         tone: str,
                         refs: List[str],
                         n: int = 6) -> List[Dict]:
    """
    Drop-in replacement for the legacy ``generate_scripts`` function.

    Keeps the legacy signature, builds a fresh ``EnhancedScriptGenerator`` per
    call and forwards ``refs`` as its ``manual_refs``.

    Returns:
        The ranked result dicts produced by ``generate_scripts_enhanced``.
    """
    request = dict(
        persona=persona,
        boundaries=boundaries,
        content_type=content_type,
        tone=tone,
        manual_refs=refs,
        n=n,
    )
    return EnhancedScriptGenerator().generate_scripts_enhanced(**request)
def setup_rag_system(score_window_hours: int = 24 * 7):
    """
    One-time setup to initialize the RAG system.

    Calls ``init_db()``, then ``rag_retrieval.index_all_scripts()``, then
    ``AutoScorer().batch_score_recent(hours=score_window_hours)``, logging
    progress after each step.

    Args:
        score_window_hours: Window passed to ``batch_score_recent``; defaults to
            one week, as before.
    """
    print("🔧 Setting up RAG system...")

    # Initialize database with new tables
    init_db()
    print("✅ Database initialized")

    # Generate embeddings for existing scripts
    from rag_retrieval import index_all_scripts
    index_all_scripts()
    print("✅ Existing scripts indexed")

    # Auto-score recent scripts
    scorer = AutoScorer()
    recent_scores = scorer.batch_score_recent(hours=score_window_hours)
    print(f"✅ Auto-scored {len(recent_scores)} recent scripts")

    print("🎉 RAG system setup complete!")
if __name__ == "__main__":
    # Demo the enhanced system: one-time setup, then a small generation run.
    setup_rag_system()

    # Test generation
    generator = EnhancedScriptGenerator()
    results = generator.generate_scripts_enhanced(
        persona="Anya",
        boundaries="Instagram-safe; suggestive but not explicit",
        content_type="thirst-trap",
        tone="playful, flirty",
        manual_refs=["Just a quick workout session", "Getting ready for the day"],
        n=3,
    )

    print(f"\n🎬 Generated {len(results)} enhanced scripts:")
    for i, script in enumerate(results, 1):
        score = script.get('_enhanced_score', 0)
        compliance = script.get('_compliance', 'unknown')
        # hook comes from the DB row and may be None; avoid slicing None.
        hook = script.get('hook') or ''
        print(f"{i}. {script['title']} (score: {score}, compliance: {compliance})")
        print(f" Hook: {hook[:60]}...")
|