Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced RAG retrieval system for AI Script Studio | |
| Extends the existing hybrid reference system with semantic search and policy learning | |
| """ | |
| import numpy as np | |
| import math | |
| from typing import List, Dict, Tuple, Optional | |
| from sentence_transformers import SentenceTransformer | |
| from sqlmodel import Session, select | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| import json | |
| from datetime import datetime, timedelta | |
| from models import Script, Embedding, AutoScore, PolicyWeights, StyleCard | |
| from db import get_session | |
class RAGRetriever:
    """Hybrid semantic + lexical retriever with policy-weighted ranking.

    Combines dense sentence embeddings, TF-IDF similarity, rating-based
    quality, and freshness into a single learned-weight score.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Load the sentence encoder and prepare the TF-IDF vectorizer.

        Args:
            model_name: sentence-transformers model id; the default is a
                small, fast general-purpose model.
        """
        self.encoder = SentenceTransformer(model_name)
        self.tfidf = TfidfVectorizer(max_features=1000, stop_words='english')
| def generate_embeddings(self, script: Script) -> List[Embedding]: | |
| """Generate embeddings for different parts of a script""" | |
| parts = { | |
| 'full': self._get_full_text(script), | |
| 'hook': script.hook or '', | |
| 'beats': ' '.join(script.beats or []), | |
| 'caption': script.caption or '' | |
| } | |
| embeddings = [] | |
| for part, text in parts.items(): | |
| if text.strip(): # Only embed non-empty parts | |
| vector = self.encoder.encode(text).tolist() | |
| meta = { | |
| 'creator': script.creator, | |
| 'content_type': script.content_type, | |
| 'tone': script.tone, | |
| 'quality_score': script.score_overall or 0.0, | |
| 'compliance': script.compliance | |
| } | |
| embeddings.append(Embedding( | |
| script_id=script.id, | |
| part=part, | |
| vector=vector, | |
| meta=meta | |
| )) | |
| return embeddings | |
| def _get_full_text(self, script: Script) -> str: | |
| """Combine all script parts into full text""" | |
| parts = [ | |
| script.title, | |
| script.hook or '', | |
| ' '.join(script.beats or []), | |
| script.voiceover or '', | |
| script.caption or '', | |
| script.cta or '' | |
| ] | |
| return ' '.join(p for p in parts if p.strip()) | |
| def hybrid_retrieve(self, | |
| query_text: str, | |
| persona: str, | |
| content_type: str, | |
| k: int = 6, | |
| global_quality_mean: float = 4.2, | |
| shrinkage_alpha: float = 10.0, | |
| freshness_tau_days: float = 28.0) -> List[Dict]: | |
| """ | |
| Production-grade hybrid retrieval with proper score normalization: | |
| - Semantic similarity (cosine normalized to [0,1]) | |
| - BM25/TF-IDF similarity (min-max normalized per query) | |
| - Quality scores (Bayesian shrinkage) | |
| - Freshness boost (exponential decay) | |
| - Policy-learned weights | |
| """ | |
| # Get policy weights for this persona/content_type | |
| weights = self._get_policy_weights(persona, content_type) | |
| with get_session() as ses: | |
| # Get all relevant scripts | |
| scripts = list(ses.exec( | |
| select(Script).where( | |
| Script.creator == persona, | |
| Script.content_type == content_type, | |
| Script.is_reference == True, | |
| Script.compliance != "fail" | |
| ) | |
| )) | |
| if not scripts: | |
| return [] | |
| # Get embeddings for semantic similarity | |
| embeddings = list(ses.exec( | |
| select(Embedding).join(Script, Embedding.script_id == Script.id).where( | |
| Embedding.part == 'full', | |
| Script.creator == persona, | |
| Script.content_type == content_type, | |
| Script.is_reference == True, | |
| Script.compliance != "fail" | |
| ) | |
| )) | |
| # Pre-calculate all raw scores for normalization | |
| raw_scores = [] | |
| query_embedding = self.encoder.encode(query_text) | |
| now = datetime.utcnow() | |
| for script in scripts: | |
| # Find matching embedding | |
| script_embedding = next( | |
| (e for e in embeddings if e.script_id == script.id), | |
| None | |
| ) | |
| # 1. Raw semantic similarity (cosine returns [-1,1]) | |
| if script_embedding: | |
| raw_cosine = cosine_similarity( | |
| [query_embedding], | |
| [script_embedding.vector] | |
| )[0][0] | |
| else: | |
| raw_cosine = -1.0 # Worst case for missing embeddings | |
| # 2. Raw BM25/TF-IDF similarity | |
| script_text = self._get_full_text(script) | |
| raw_bm25 = self._calculate_tfidf_similarity(query_text, script_text) | |
| raw_scores.append({ | |
| 'script': script, | |
| 'raw_cosine': raw_cosine, | |
| 'raw_bm25': raw_bm25 | |
| }) | |
| # Normalize BM25 scores (min-max normalization across this query's candidates) | |
| bm25_scores = [s['raw_bm25'] for s in raw_scores] | |
| min_bm25 = min(bm25_scores) | |
| max_bm25 = max(bm25_scores) | |
| bm25_range = max_bm25 - min_bm25 + 1e-9 # Avoid division by zero | |
| # Calculate final normalized scores | |
| results = [] | |
| for raw_score in raw_scores: | |
| script = raw_score['script'] | |
| scores = {} | |
| # 1. Semantic similarity: normalize cosine [-1,1] → [0,1] | |
| scores['semantic'] = (raw_score['raw_cosine'] + 1.0) / 2.0 | |
| # 2. BM25: min-max normalize within this query's candidate set | |
| scores['bm25'] = (raw_score['raw_bm25'] - min_bm25) / bm25_range | |
| # 3. Quality: Bayesian shrinkage toward global mean | |
| n_ratings = script.ratings_count or 0 | |
| local_quality = script.score_overall or global_quality_mean | |
| # Shrinkage: blend local mean with global mean based on sample size | |
| shrunk_quality = ( | |
| (n_ratings / (n_ratings + shrinkage_alpha)) * local_quality + | |
| (shrinkage_alpha / (n_ratings + shrinkage_alpha)) * global_quality_mean | |
| ) | |
| # Normalize to [0,1] (assuming 1-5 rating scale) | |
| scores['quality'] = max(0.0, min(1.0, (shrunk_quality - 1) / 4)) | |
| # 4. Freshness: exponential decay (smoother than linear) | |
| days_old = max(0, (now - script.created_at).days) | |
| scores['freshness'] = math.exp(-days_old / freshness_tau_days) | |
| # Combined score using policy weights | |
| combined_score = ( | |
| weights.semantic_weight * scores['semantic'] + | |
| weights.bm25_weight * scores['bm25'] + | |
| weights.quality_weight * scores['quality'] + | |
| weights.freshness_weight * scores['freshness'] | |
| ) | |
| results.append({ | |
| 'script': script, | |
| 'score': combined_score, | |
| 'component_scores': scores, | |
| # Debug info | |
| '_debug': { | |
| 'n_ratings': n_ratings, | |
| 'raw_quality': local_quality, | |
| 'shrunk_quality': shrunk_quality, | |
| 'days_old': days_old | |
| } | |
| }) | |
| # Sort by combined score and return top k | |
| results.sort(key=lambda x: x['score'], reverse=True) | |
| return results[:k] | |
| def _calculate_tfidf_similarity(self, query: str, doc: str) -> float: | |
| """Calculate TF-IDF similarity between query and document""" | |
| try: | |
| tfidf_matrix = self.tfidf.fit_transform([query, doc]) | |
| similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0] | |
| return float(similarity) | |
| except: | |
| return 0.0 | |
| def _get_policy_weights(self, persona: str, content_type: str) -> PolicyWeights: | |
| """Get learned policy weights or create defaults""" | |
| with get_session() as ses: | |
| weights = ses.exec( | |
| select(PolicyWeights).where( | |
| PolicyWeights.persona == persona, | |
| PolicyWeights.content_type == content_type | |
| ) | |
| ).first() | |
| if not weights: | |
| # Create default weights | |
| weights = PolicyWeights( | |
| persona=persona, | |
| content_type=content_type | |
| ) | |
| ses.add(weights) | |
| ses.commit() | |
| ses.refresh(weights) | |
| return weights | |
| def build_dynamic_few_shot_pack(self, | |
| persona: str, | |
| content_type: str, | |
| query_context: str = "") -> Dict: | |
| """Build dynamic few-shot examples pack optimized for this request""" | |
| # Get best references via hybrid retrieval | |
| references = self.hybrid_retrieve( | |
| query_text=query_context or f"{persona} {content_type}", | |
| persona=persona, | |
| content_type=content_type, | |
| k=6 | |
| ) | |
| if not references: | |
| return {"style_card": "", "examples": [], "constraints": {}} | |
| # Extract best examples by type | |
| best_hooks = [] | |
| best_beats = [] | |
| best_captions = [] | |
| for ref in references[:4]: # Use top 4 references | |
| script = ref['script'] | |
| if script.hook and len(best_hooks) < 2: | |
| best_hooks.append(script.hook) | |
| if script.beats and len(best_beats) < 1: | |
| best_beats.extend(script.beats[:2]) # First 2 beats | |
| if script.caption and len(best_captions) < 1: | |
| best_captions.append(script.caption) | |
| # Get or create style card | |
| style_card = self._get_style_card(persona, content_type) | |
| return { | |
| "style_card": f"Persona: {persona} | Content: {content_type}", | |
| "best_hooks": best_hooks[:2], | |
| "best_beats": best_beats[:3], | |
| "best_captions": best_captions[:1], | |
| "constraints": { | |
| "max_length": "15-25 seconds", | |
| "compliance": "Instagram-safe", | |
| "tone": references[0]['script'].tone if references else "playful" | |
| }, | |
| "negative_patterns": style_card.negative_patterns if style_card else [] | |
| } | |
| def _get_style_card(self, persona: str, content_type: str) -> Optional[StyleCard]: | |
| """Get existing style card or return None""" | |
| with get_session() as ses: | |
| return ses.exec( | |
| select(StyleCard).where( | |
| StyleCard.persona == persona, | |
| StyleCard.content_type == content_type | |
| ) | |
| ).first() | |
| def detect_copying(self, | |
| generated_content: Dict, | |
| reference_texts: List[str], | |
| similarity_threshold: float = 0.92) -> Dict: | |
| """ | |
| Detect if generated content is too similar to reference material. | |
| Returns detection results with flagged content and similarity scores. | |
| Args: | |
| generated_content: Dict with keys like 'hook', 'caption', 'beats', etc. | |
| reference_texts: List of reference text snippets to compare against | |
| similarity_threshold: Cosine similarity threshold (0.92 recommended) | |
| Returns: | |
| Dict with detection results and recommendations | |
| """ | |
| detection_results = { | |
| 'is_copying': False, | |
| 'flagged_fields': [], | |
| 'max_similarity': 0.0, | |
| 'rewrite_recommendations': [] | |
| } | |
| if not reference_texts: | |
| return detection_results | |
| # Encode all reference texts | |
| reference_embeddings = self.encoder.encode(reference_texts) | |
| # Fields to check for copying | |
| fields_to_check = ['hook', 'caption', 'cta'] | |
| for field in fields_to_check: | |
| if field in generated_content and generated_content[field]: | |
| generated_text = str(generated_content[field]) | |
| # Skip very short texts (less than 10 characters) | |
| if len(generated_text.strip()) < 10: | |
| continue | |
| # Encode generated text | |
| generated_embedding = self.encoder.encode([generated_text]) | |
| # Calculate similarity to all reference texts | |
| similarities = cosine_similarity(generated_embedding, reference_embeddings)[0] | |
| max_sim = float(np.max(similarities)) | |
| # Update overall max similarity | |
| detection_results['max_similarity'] = max(detection_results['max_similarity'], max_sim) | |
| # Check if similarity exceeds threshold | |
| if max_sim >= similarity_threshold: | |
| detection_results['is_copying'] = True | |
| detection_results['flagged_fields'].append({ | |
| 'field': field, | |
| 'text': generated_text, | |
| 'similarity': max_sim, | |
| 'similar_reference': reference_texts[int(np.argmax(similarities))] | |
| }) | |
| # Generate rewrite recommendation | |
| if max_sim >= 0.95: | |
| urgency = "CRITICAL" | |
| action = "Completely rewrite this content" | |
| elif max_sim >= 0.92: | |
| urgency = "HIGH" | |
| action = "Significantly rephrase this content" | |
| else: | |
| urgency = "MEDIUM" | |
| action = "Minor rewording may be needed" | |
| detection_results['rewrite_recommendations'].append({ | |
| 'field': field, | |
| 'urgency': urgency, | |
| 'action': action, | |
| 'original': generated_text | |
| }) | |
| return detection_results | |
| def auto_rewrite_similar_content(self, | |
| generated_content: Dict, | |
| detection_results: Dict, | |
| rewrite_instruction: str = "Rewrite to be more original while keeping the same intent") -> Dict: | |
| """ | |
| Automatically rewrite content that's too similar to references. | |
| Args: | |
| generated_content: The original generated content | |
| detection_results: Results from detect_copying() | |
| rewrite_instruction: Instructions for how to rewrite | |
| Returns: | |
| Rewritten content dict | |
| """ | |
| if not detection_results['is_copying']: | |
| return generated_content | |
| rewritten_content = generated_content.copy() | |
| for flag in detection_results['flagged_fields']: | |
| field = flag['field'] | |
| original_text = flag['text'] | |
| # Simple rewrite strategy: add instruction to modify the text | |
| # In a production system, you'd call the LLM to rewrite | |
| rewrite_prompt = f""" | |
| Original: {original_text} | |
| This text is too similar to existing reference material. | |
| Please rewrite it to be more original while keeping the same intent and tone. | |
| Make it clearly different from the reference but equally engaging. | |
| Rewritten version: | |
| """ | |
| # For now, add a flag that this needs rewriting | |
| # In production, you'd call your LLM API here | |
| rewritten_content[field] = f"[NEEDS_REWRITE] {original_text}" | |
| # Log the issue | |
| print(f"🚨 Anti-copy detection: {field} flagged (similarity: {flag['similarity']:.3f})") | |
| print(f" Original: {original_text[:60]}...") | |
| print(f" Similar to: {flag['similar_reference'][:60]}...") | |
| return rewritten_content | |
def index_all_scripts():
    """Backfill embeddings for every script that has none yet.

    Scripts that already have at least one Embedding row are skipped;
    all new rows are committed in a single transaction at the end.
    """
    retriever = RAGRetriever()
    with get_session() as ses:
        scripts = list(ses.exec(select(Script)))
        for script in scripts:
            already_indexed = ses.exec(
                select(Embedding).where(Embedding.script_id == script.id)
            ).first()
            if already_indexed:
                continue
            for embedding in retriever.generate_embeddings(script):
                ses.add(embedding)
            print(f"Generated embeddings for script {script.id}")
        ses.commit()
    print(f"Indexing complete! Processed {len(scripts)} scripts.")
if __name__ == "__main__":
    # CLI entry point: run this module directly to backfill embeddings
    # for all existing scripts in the database.
    index_all_scripts()