# scriptwriter / rag_retrieval.py
# (uploaded by kreemyyyy — "Upload 13 files", commit fd88516 verified)
"""
Enhanced RAG retrieval system for AI Script Studio
Extends the existing hybrid reference system with semantic search and policy learning
"""
import numpy as np
import math
from typing import List, Dict, Tuple, Optional
from sentence_transformers import SentenceTransformer
from sqlmodel import Session, select
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import json
from datetime import datetime, timedelta
from models import Script, Embedding, AutoScore, PolicyWeights, StyleCard
from db import get_session
class RAGRetriever:
    """Hybrid (dense-semantic + lexical TF-IDF) retriever with
    policy-weighted ranking for AI Script Studio references."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Set up the scoring components.

        Args:
            model_name: SentenceTransformer checkpoint to load; the default
                is a lightweight but effective general-purpose encoder.
        """
        # Sparse lexical scorer used alongside the dense encoder.
        self.tfidf = TfidfVectorizer(max_features=1000, stop_words='english')
        # Dense sentence encoder for semantic similarity.
        self.encoder = SentenceTransformer(model_name)
def generate_embeddings(self, script: Script) -> List[Embedding]:
"""Generate embeddings for different parts of a script"""
parts = {
'full': self._get_full_text(script),
'hook': script.hook or '',
'beats': ' '.join(script.beats or []),
'caption': script.caption or ''
}
embeddings = []
for part, text in parts.items():
if text.strip(): # Only embed non-empty parts
vector = self.encoder.encode(text).tolist()
meta = {
'creator': script.creator,
'content_type': script.content_type,
'tone': script.tone,
'quality_score': script.score_overall or 0.0,
'compliance': script.compliance
}
embeddings.append(Embedding(
script_id=script.id,
part=part,
vector=vector,
meta=meta
))
return embeddings
def _get_full_text(self, script: Script) -> str:
"""Combine all script parts into full text"""
parts = [
script.title,
script.hook or '',
' '.join(script.beats or []),
script.voiceover or '',
script.caption or '',
script.cta or ''
]
return ' '.join(p for p in parts if p.strip())
    def hybrid_retrieve(self,
                        query_text: str,
                        persona: str,
                        content_type: str,
                        k: int = 6,
                        global_quality_mean: float = 4.2,
                        shrinkage_alpha: float = 10.0,
                        freshness_tau_days: float = 28.0) -> List[Dict]:
        """
        Production-grade hybrid retrieval with proper score normalization:
        - Semantic similarity (cosine normalized to [0,1])
        - BM25/TF-IDF similarity (min-max normalized per query)
        - Quality scores (Bayesian shrinkage)
        - Freshness boost (exponential decay)
        - Policy-learned weights

        Args:
            query_text: Free-text query to rank reference scripts against.
            persona: Creator persona filter applied to candidate scripts.
            content_type: Content-type filter applied alongside persona.
            k: Maximum number of ranked results to return.
            global_quality_mean: Prior mean for Bayesian shrinkage
                (the quality normalization assumes a 1-5 rating scale).
            shrinkage_alpha: Pseudo-count controlling how strongly scripts
                with few ratings are pulled toward global_quality_mean.
            freshness_tau_days: Decay constant (in days) of the exponential
                freshness boost.

        Returns:
            Up to k dicts sorted by descending combined score, each with
            'script', 'score', 'component_scores', and '_debug' keys.
        """
        # Get policy weights for this persona/content_type
        weights = self._get_policy_weights(persona, content_type)
        with get_session() as ses:
            # Get all relevant scripts: compliant references for this
            # persona/content_type pair only.
            scripts = list(ses.exec(
                select(Script).where(
                    Script.creator == persona,
                    Script.content_type == content_type,
                    Script.is_reference == True,
                    Script.compliance != "fail"
                )
            ))
            if not scripts:
                return []
            # Get 'full'-part embeddings for semantic similarity, filtered
            # with the same predicates so they line up with `scripts`.
            embeddings = list(ses.exec(
                select(Embedding).join(Script, Embedding.script_id == Script.id).where(
                    Embedding.part == 'full',
                    Script.creator == persona,
                    Script.content_type == content_type,
                    Script.is_reference == True,
                    Script.compliance != "fail"
                )
            ))
            # Pre-calculate all raw scores for normalization
            raw_scores = []
            query_embedding = self.encoder.encode(query_text)
            # NOTE(review): utcnow() is naive (and deprecated in 3.12+);
            # assumes Script.created_at is also naive UTC — confirm.
            now = datetime.utcnow()
            for script in scripts:
                # Find matching embedding (linear scan over candidates)
                script_embedding = next(
                    (e for e in embeddings if e.script_id == script.id),
                    None
                )
                # 1. Raw semantic similarity (cosine returns [-1,1])
                if script_embedding:
                    raw_cosine = cosine_similarity(
                        [query_embedding],
                        [script_embedding.vector]
                    )[0][0]
                else:
                    raw_cosine = -1.0  # Worst case for missing embeddings
                # 2. Raw BM25/TF-IDF similarity (re-fit per pair; only
                # comparable within this query's candidate set)
                script_text = self._get_full_text(script)
                raw_bm25 = self._calculate_tfidf_similarity(query_text, script_text)
                raw_scores.append({
                    'script': script,
                    'raw_cosine': raw_cosine,
                    'raw_bm25': raw_bm25
                })
            # Normalize BM25 scores (min-max normalization across this query's candidates)
            bm25_scores = [s['raw_bm25'] for s in raw_scores]
            min_bm25 = min(bm25_scores)
            max_bm25 = max(bm25_scores)
            bm25_range = max_bm25 - min_bm25 + 1e-9  # Avoid division by zero
            # Calculate final normalized scores
            results = []
            for raw_score in raw_scores:
                script = raw_score['script']
                scores = {}
                # 1. Semantic similarity: normalize cosine [-1,1] → [0,1]
                scores['semantic'] = (raw_score['raw_cosine'] + 1.0) / 2.0
                # 2. BM25: min-max normalize within this query's candidate set
                scores['bm25'] = (raw_score['raw_bm25'] - min_bm25) / bm25_range
                # 3. Quality: Bayesian shrinkage toward global mean
                n_ratings = script.ratings_count or 0
                local_quality = script.score_overall or global_quality_mean
                # Shrinkage: blend local mean with global mean based on sample size
                shrunk_quality = (
                    (n_ratings / (n_ratings + shrinkage_alpha)) * local_quality +
                    (shrinkage_alpha / (n_ratings + shrinkage_alpha)) * global_quality_mean
                )
                # Normalize to [0,1] (assuming 1-5 rating scale)
                scores['quality'] = max(0.0, min(1.0, (shrunk_quality - 1) / 4))
                # 4. Freshness: exponential decay (smoother than linear)
                days_old = max(0, (now - script.created_at).days)
                scores['freshness'] = math.exp(-days_old / freshness_tau_days)
                # Combined score using policy weights
                combined_score = (
                    weights.semantic_weight * scores['semantic'] +
                    weights.bm25_weight * scores['bm25'] +
                    weights.quality_weight * scores['quality'] +
                    weights.freshness_weight * scores['freshness']
                )
                results.append({
                    'script': script,
                    'score': combined_score,
                    'component_scores': scores,
                    # Debug info
                    '_debug': {
                        'n_ratings': n_ratings,
                        'raw_quality': local_quality,
                        'shrunk_quality': shrunk_quality,
                        'days_old': days_old
                    }
                })
            # Sort by combined score and return top k
            results.sort(key=lambda x: x['score'], reverse=True)
            return results[:k]
def _calculate_tfidf_similarity(self, query: str, doc: str) -> float:
"""Calculate TF-IDF similarity between query and document"""
try:
tfidf_matrix = self.tfidf.fit_transform([query, doc])
similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
return float(similarity)
except:
return 0.0
def _get_policy_weights(self, persona: str, content_type: str) -> PolicyWeights:
"""Get learned policy weights or create defaults"""
with get_session() as ses:
weights = ses.exec(
select(PolicyWeights).where(
PolicyWeights.persona == persona,
PolicyWeights.content_type == content_type
)
).first()
if not weights:
# Create default weights
weights = PolicyWeights(
persona=persona,
content_type=content_type
)
ses.add(weights)
ses.commit()
ses.refresh(weights)
return weights
def build_dynamic_few_shot_pack(self,
persona: str,
content_type: str,
query_context: str = "") -> Dict:
"""Build dynamic few-shot examples pack optimized for this request"""
# Get best references via hybrid retrieval
references = self.hybrid_retrieve(
query_text=query_context or f"{persona} {content_type}",
persona=persona,
content_type=content_type,
k=6
)
if not references:
return {"style_card": "", "examples": [], "constraints": {}}
# Extract best examples by type
best_hooks = []
best_beats = []
best_captions = []
for ref in references[:4]: # Use top 4 references
script = ref['script']
if script.hook and len(best_hooks) < 2:
best_hooks.append(script.hook)
if script.beats and len(best_beats) < 1:
best_beats.extend(script.beats[:2]) # First 2 beats
if script.caption and len(best_captions) < 1:
best_captions.append(script.caption)
# Get or create style card
style_card = self._get_style_card(persona, content_type)
return {
"style_card": f"Persona: {persona} | Content: {content_type}",
"best_hooks": best_hooks[:2],
"best_beats": best_beats[:3],
"best_captions": best_captions[:1],
"constraints": {
"max_length": "15-25 seconds",
"compliance": "Instagram-safe",
"tone": references[0]['script'].tone if references else "playful"
},
"negative_patterns": style_card.negative_patterns if style_card else []
}
def _get_style_card(self, persona: str, content_type: str) -> Optional[StyleCard]:
"""Get existing style card or return None"""
with get_session() as ses:
return ses.exec(
select(StyleCard).where(
StyleCard.persona == persona,
StyleCard.content_type == content_type
)
).first()
def detect_copying(self,
generated_content: Dict,
reference_texts: List[str],
similarity_threshold: float = 0.92) -> Dict:
"""
Detect if generated content is too similar to reference material.
Returns detection results with flagged content and similarity scores.
Args:
generated_content: Dict with keys like 'hook', 'caption', 'beats', etc.
reference_texts: List of reference text snippets to compare against
similarity_threshold: Cosine similarity threshold (0.92 recommended)
Returns:
Dict with detection results and recommendations
"""
detection_results = {
'is_copying': False,
'flagged_fields': [],
'max_similarity': 0.0,
'rewrite_recommendations': []
}
if not reference_texts:
return detection_results
# Encode all reference texts
reference_embeddings = self.encoder.encode(reference_texts)
# Fields to check for copying
fields_to_check = ['hook', 'caption', 'cta']
for field in fields_to_check:
if field in generated_content and generated_content[field]:
generated_text = str(generated_content[field])
# Skip very short texts (less than 10 characters)
if len(generated_text.strip()) < 10:
continue
# Encode generated text
generated_embedding = self.encoder.encode([generated_text])
# Calculate similarity to all reference texts
similarities = cosine_similarity(generated_embedding, reference_embeddings)[0]
max_sim = float(np.max(similarities))
# Update overall max similarity
detection_results['max_similarity'] = max(detection_results['max_similarity'], max_sim)
# Check if similarity exceeds threshold
if max_sim >= similarity_threshold:
detection_results['is_copying'] = True
detection_results['flagged_fields'].append({
'field': field,
'text': generated_text,
'similarity': max_sim,
'similar_reference': reference_texts[int(np.argmax(similarities))]
})
# Generate rewrite recommendation
if max_sim >= 0.95:
urgency = "CRITICAL"
action = "Completely rewrite this content"
elif max_sim >= 0.92:
urgency = "HIGH"
action = "Significantly rephrase this content"
else:
urgency = "MEDIUM"
action = "Minor rewording may be needed"
detection_results['rewrite_recommendations'].append({
'field': field,
'urgency': urgency,
'action': action,
'original': generated_text
})
return detection_results
def auto_rewrite_similar_content(self,
generated_content: Dict,
detection_results: Dict,
rewrite_instruction: str = "Rewrite to be more original while keeping the same intent") -> Dict:
"""
Automatically rewrite content that's too similar to references.
Args:
generated_content: The original generated content
detection_results: Results from detect_copying()
rewrite_instruction: Instructions for how to rewrite
Returns:
Rewritten content dict
"""
if not detection_results['is_copying']:
return generated_content
rewritten_content = generated_content.copy()
for flag in detection_results['flagged_fields']:
field = flag['field']
original_text = flag['text']
# Simple rewrite strategy: add instruction to modify the text
# In a production system, you'd call the LLM to rewrite
rewrite_prompt = f"""
Original: {original_text}
This text is too similar to existing reference material.
Please rewrite it to be more original while keeping the same intent and tone.
Make it clearly different from the reference but equally engaging.
Rewritten version:
"""
# For now, add a flag that this needs rewriting
# In production, you'd call your LLM API here
rewritten_content[field] = f"[NEEDS_REWRITE] {original_text}"
# Log the issue
print(f"🚨 Anti-copy detection: {field} flagged (similarity: {flag['similarity']:.3f})")
print(f" Original: {original_text[:60]}...")
print(f" Similar to: {flag['similar_reference'][:60]}...")
return rewritten_content
def index_all_scripts():
    """Backfill embeddings for every script that has none yet, committing
    all new rows in a single transaction at the end."""
    retriever = RAGRetriever()
    with get_session() as ses:
        all_scripts = list(ses.exec(select(Script)))
        for script in all_scripts:
            # Skip scripts that were already indexed on a previous run.
            already_indexed = ses.exec(
                select(Embedding).where(Embedding.script_id == script.id)
            ).first()
            if already_indexed:
                continue
            for row in retriever.generate_embeddings(script):
                ses.add(row)
            print(f"Generated embeddings for script {script.id}")
        ses.commit()
        print(f"Indexing complete! Processed {len(all_scripts)} scripts.")


if __name__ == "__main__":
    # Run this to index your existing scripts
    index_all_scripts()