kreemyyyy committed on
Commit fd88516 · verified · 1 Parent(s): 0c3a5df

Upload 13 files

Files changed (13)
  1. .gitignore +34 -0
  2. README.md +62 -20
  3. app.py +844 -0
  4. auto_scorer.py +240 -0
  5. bandit_learner.py +330 -0
  6. compliance.py +26 -0
  7. db.py +248 -0
  8. deepseek_client.py +59 -0
  9. models.py +103 -0
  10. packages.txt +1 -0
  11. rag_integration.py +350 -0
  12. rag_retrieval.py +444 -0
  13. requirements.txt +16 -3
.gitignore ADDED
@@ -0,0 +1,34 @@
+ # Environment and secrets
+ .env
+ .streamlit/secrets.toml
+ secrets.toml
+
+ # Python cache
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .Python
+ *.so
+
+ # Database files
+ *.db
+ *.sqlite
+ *.sqlite3
+
+ # IDE files
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+
+ # OS files
+ .DS_Store
+ Thumbs.db
+
+ # Logs
+ *.log
+
+ # Temporary files
+ *.tmp
+ *.temp
README.md CHANGED
@@ -1,20 +1,62 @@
- ---
- title: Scriptwriter
- emoji: 🚀
- colorFrom: red
- colorTo: red
- sdk: docker
- app_port: 8501
- tags:
- - streamlit
- pinned: false
- short_description: Streamlit template space
- license: mit
- ---
-
- # Welcome to Streamlit!
-
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
-
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
+ ---
+ title: AI Script Studio
+ emoji: 🎬
+ colorFrom: blue
+ colorTo: purple
+ sdk: streamlit
+ sdk_version: 1.37.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: Generate Instagram-ready scripts with AI-powered RAG system
+ ---
+
+ # 🎬 AI Script Studio
+
+ Generate Instagram-ready scripts with AI using an advanced RAG (Retrieval-Augmented Generation) system.
+
+ ## Features
+
+ - 🤖 **AI-Powered Generation**: Uses the DeepSeek API for high-quality script generation
+ - 🧠 **RAG System**: Retrieval-Augmented Generation with semantic search
+ - 📊 **Multi-Armed Bandit Learning**: Self-improving generation policies
+ - 🎯 **Auto-Scoring**: LLM-based quality assessment
+ - 📈 **Rating System**: Human feedback integration with learning
+ - 🎨 **Multiple Personas**: Support for different creator styles
+ - 📝 **Content Types**: Various Instagram content formats
+
+ ## How It Works
+
+ 1. **Reference Retrieval**: Uses semantic search to find relevant examples
+ 2. **Policy Learning**: A multi-armed bandit optimizes generation parameters
+ 3. **AI Generation**: Creates scripts using the retrieved references
+ 4. **Auto-Scoring**: An LLM judges quality across 5 dimensions
+ 5. **Learning Loop**: The system improves based on feedback
+
+ ## Usage
+
+ 1. Select your creator persona
+ 2. Choose a content type and tone
+ 3. Add reference examples (optional)
+ 4. Generate scripts with AI
+ 5. Rate them and provide feedback
+ 6. The system learns and improves
+
+ ## Technical Stack
+
+ - **Frontend**: Streamlit
+ - **AI**: DeepSeek API
+ - **RAG**: Sentence Transformers + FAISS
+ - **Database**: SQLite with SQLModel
+ - **Learning**: Multi-armed bandit algorithms
+ - **Scoring**: LLM-based evaluation
+
+ ## Setup
+
+ 1. Add your DeepSeek API key to the secrets
+ 2. The app will automatically initialize the database
+ 3. Start generating scripts!
+
+ ## API Key
+
+ Get your free API key at: https://platform.deepseek.com/api_keys
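The "Policy Learning" step above (a multi-armed bandit over generation parameters) can be sketched as an epsilon-greedy loop. This is an illustrative sketch only — the arm names, epsilon value, and reward scale are assumptions, not taken from the repository's `bandit_learner.py`:

```python
import random

class EpsilonGreedyBandit:
    """Minimal epsilon-greedy bandit over generation-parameter 'arms'."""

    def __init__(self, arms, epsilon=0.1):
        self.arms = list(arms)
        self.epsilon = epsilon
        self.counts = {a: 0 for a in self.arms}    # pulls per arm
        self.values = {a: 0.0 for a in self.arms}  # running mean reward

    def select(self):
        # Explore with probability epsilon, otherwise exploit the best arm
        if random.random() < self.epsilon:
            return random.choice(self.arms)
        return max(self.arms, key=lambda a: self.values[a])

    def update(self, arm, reward):
        # Incremental mean update: v += (r - v) / n
        self.counts[arm] += 1
        self.values[arm] += (reward - self.values[arm]) / self.counts[arm]

# Hypothetical usage: treat hook styles as arms, human ratings as reward
bandit = EpsilonGreedyBandit(["short-hook", "question-hook", "shock-hook"], epsilon=0.2)
arm = bandit.select()
bandit.update(arm, reward=4.5)
```

The "Learning Loop" then amounts to calling `update` with each human or auto-scored rating, so arms that consistently earn higher ratings are selected more often.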
app.py ADDED
@@ -0,0 +1,844 @@
+ import os, streamlit as st
+ from dotenv import load_dotenv
+ from sqlmodel import select
+ from db import init_db, get_session, add_rating
+ from models import Script, Revision
+ from deepseek_client import generate_scripts, revise_for, selective_rewrite
+ # Lazy import for RAG system to improve startup time
+ # from rag_integration import generate_scripts_rag
+ from compliance import blob_from, score_script
+ import time
+
+ # Configure page - MUST be first Streamlit command
+ st.set_page_config(
+     page_title="🎬 AI Script Studio",
+     layout="wide",
+     initial_sidebar_state="expanded"
+ )
+
+ def script_to_json_dict(script):
+     """Convert script to JSON-serializable dictionary"""
+     data = script.model_dump()
+     # Remove datetime fields that cause JSON serialization issues
+     data.pop('created_at', None)
+     data.pop('updated_at', None)
+     return data
+
+ # Load environment - works both locally and on Hugging Face Spaces
+ load_dotenv()
+
+ # Initialize database with error handling for cloud deployment
+ try:
+     init_db()
+     st.sidebar.write("✅ Database initialized successfully")
+ except Exception as e:
+     st.sidebar.write(f"⚠️ Database init warning: {str(e)}")
+     # Continue anyway - some features may be limited
+
+ # Check for API key in Streamlit secrets or environment
+ api_key = st.secrets.get("DEEPSEEK_API_KEY") if hasattr(st, 'secrets') and "DEEPSEEK_API_KEY" in st.secrets else os.getenv("DEEPSEEK_API_KEY")
+
+ # DEBUG INFO - remove after fixing
+ if hasattr(st, 'secrets'):
+     st.sidebar.write("🔍 DEBUG: Secrets available")
+     if "DEEPSEEK_API_KEY" in st.secrets:
+         st.sidebar.write("✅ DEEPSEEK_API_KEY found in secrets")
+         st.sidebar.write(f"🔑 Key length: {len(st.secrets['DEEPSEEK_API_KEY'])}")
+         st.sidebar.write(f"🔑 Key starts with: {st.secrets['DEEPSEEK_API_KEY'][:10]}...")
+     else:
+         st.sidebar.write("❌ DEEPSEEK_API_KEY NOT in secrets")
+         st.sidebar.write(f"Available secrets: {list(st.secrets.keys())}")
+ else:
+     st.sidebar.write("❌ No secrets available")
+
+ if not api_key:
+     st.error("🔑 **DeepSeek API Key Required**")
+     st.markdown("""
+     **For Local Development:**
+     - Create a `.env` file and add: `DEEPSEEK_API_KEY=your_key_here`
+
+     **For Streamlit Cloud:**
+     - Go to your app settings → Secrets
+     - Add: `DEEPSEEK_API_KEY = "your_key_here"`
+
+     Get your free API key at: https://platform.deepseek.com/api_keys
+     """)
+     st.stop()
+ else:
+     st.sidebar.write("✅ API key loaded successfully")
+
+
+ # Custom CSS for better styling
+ st.markdown("""
+ <style>
+ .main-header {
+     text-align: center;
+     padding: 1rem;
+     background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
+     color: white;
+     border-radius: 10px;
+     margin-bottom: 2rem;
+ }
+ .step-container {
+     border: 2px solid #e1e1e1;
+     border-radius: 10px;
+     padding: 1rem;
+     margin-bottom: 1rem;
+     background-color: #f8f9fa;
+ }
+ .draft-card {
+     border: 1px solid #ddd;
+     border-radius: 8px;
+     padding: 0.8rem;
+     margin-bottom: 0.5rem;
+     background: white;
+     transition: all 0.2s ease;
+ }
+ .draft-card:hover {
+     box-shadow: 0 2px 8px rgba(0,0,0,0.1);
+     border-color: #667eea;
+ }
+ .success-box {
+     background-color: #d4edda;
+     border: 1px solid #c3e6cb;
+     border-radius: 5px;
+     padding: 1rem;
+     margin: 1rem 0;
+ }
+ </style>
+ """, unsafe_allow_html=True)
+
+ # Header
+ st.markdown("""
+ <div class="main-header">
+     <h1>🎬 AI Script Studio</h1>
+     <p>Generate Instagram-ready scripts with AI • Powered by DeepSeek</p>
+ </div>
+ """, unsafe_allow_html=True)
+
+ # Initialize session state
+ if 'generation_step' not in st.session_state:
+     st.session_state.generation_step = 'setup'
+ if 'generated_count' not in st.session_state:
+     st.session_state.generated_count = 0
+
+ # Sidebar - Generation Controls
+ with st.sidebar:
+     st.header("🎯 Script Generation")
+
+     # Step 1: Basic Settings
+     with st.expander("📝 Step 1: Basic Settings", expanded=True):
+         # Dynamic creator dropdown (pull from database + defaults)
+         with get_session() as ses:
+             db_creators = list(ses.exec(select(Script.creator).distinct()))
+             db_creator_names = [c for c in db_creators if c]
+
+         default_creators = ["Creator A", "Emily", "Anya", "Ava Cherrry", "Ava Xreyess", "FitBryceAdams", "RealCarlyJane", "Sophie Rain", "Zoe AloneAtHome"]
+         all_creators = list(set(default_creators + db_creator_names))
+         creator_options = sorted(all_creators)
+         creator = st.selectbox(
+             "Creator Name",
+             creator_options,
+             help="Choose from existing creators or your imported scripts"
+         )
+
+         # Expanded content types
+         content_type = st.selectbox(
+             "Content Type",
+             ["thirst-trap", "skit", "reaction-prank", "talking-style", "lifestyle", "fake-podcast", "dance-trend", "voice-tease-asmr"],
+             help="Choose the type of content you want to create"
+         )
+
+         # Multi-select tones
+         tone_options = ["naughty", "playful", "suggestive", "funny", "flirty", "bratty", "teasing", "intimate", "witty", "comedic", "confident", "wholesome", "asmr-voice"]
+         selected_tones = st.multiselect(
+             "Tone/Vibe (select multiple)",
+             tone_options,
+             default=["playful"],
+             help="Choose one or more tones - scripts often blend 2-3 vibes"
+         )
+         tone = ", ".join(selected_tones) if selected_tones else "playful"
+
+         n = st.slider(
+             "Number of drafts",
+             min_value=1,
+             max_value=20,
+             value=6,
+             help="How many script variations to generate"
+         )
+
+     # Step 2: Persona & Style
+     with st.expander("👤 Step 2: Persona & Style", expanded=True):
+         # Persona presets
+         persona_presets = {
+             "Girl-next-door": "girl-next-door; playful; witty; approachable",
+             "Bratty tease": "bratty; teasing; demanding; playful attitude",
+             "Dominant/In control": "confident; in control; commanding; assertive",
+             "Innocent but suggestive": "innocent; sweet; accidentally suggestive; naive charm",
+             "Party girl": "outgoing; fun; social; party vibes; energetic",
+             "Gym fitspo": "fitness focused; motivational; athletic; body confident",
+             "ASMR/Voice fetish": "soft spoken; intimate; soothing; sensual voice",
+             "Girlfriend experience": "loving; intimate; caring; relationship vibes",
+             "Funny meme-style": "comedic; meme references; internet culture; quirky",
+             "Candid/Lifestyle": "authentic; relatable; everyday life; natural"
+         }
+
+         col1, col2 = st.columns([0.6, 0.4])
+         with col1:
+             persona_preset = st.selectbox(
+                 "Persona Preset",
+                 ["Custom"] + list(persona_presets.keys()),
+                 help="Choose a preset or use custom"
+             )
+
+         with col2:
+             if persona_preset != "Custom":
+                 if st.button("📋 Use Preset", use_container_width=True):
+                     st.session_state.persona_text = persona_presets[persona_preset]
+
+         persona = st.text_area(
+             "Persona Description",
+             value=st.session_state.get('persona_text', "girl-next-door; playful; witty"),
+             help="Describe the character/personality for the scripts"
+         )
+
+         # Compliance/Boundaries presets
+         boundary_presets = {
+             "Safe IG mode": "No explicit words; no sexual acts; suggestive only; no banned IG terms; keep it flirty but clean",
+             "Spicy mode": "Innuendos allowed; suggestive language OK; no explicit acts; can be naughty but not graphic",
+             "Brand-safe": "No swearing; no sex references; just flirty and fun; wholesome with hint of tease",
+             "Mild NSFW": "Moaning sounds OK; wet references allowed; squirt innuendo OK; suggestive but not explicit",
+             "Platform optimized": "Avoid flagged keywords; use creative euphemisms; suggestive storytelling style"
+         }
+
+         col1, col2 = st.columns([0.6, 0.4])
+         with col1:
+             boundary_preset = st.selectbox(
+                 "Compliance Preset",
+                 ["Custom"] + list(boundary_presets.keys()),
+                 help="Choose platform-appropriate safety rules"
+             )
+
+         with col2:
+             if boundary_preset != "Custom":
+                 if st.button("🛡️ Use Preset", use_container_width=True):
+                     st.session_state.boundaries_text = boundary_presets[boundary_preset]
+
+         boundaries = st.text_area(
+             "Content Boundaries",
+             value=st.session_state.get('boundaries_text', "No explicit words; no solicitation; no age refs"),
+             help="What should the AI avoid? Set your safety guidelines here"
+         )
+
+     # Step 3: Advanced Options
+     with st.expander("⚡ Step 3: Advanced Options", expanded=False):
+         col1, col2 = st.columns(2)
+
+         with col1:
+             # Hook style
+             hook_style = st.selectbox(
+                 "Hook Style",
+                 ["Auto", "Question", "Confession", "Contrarian", "PSA", "Tease", "Command", "Shock"],
+                 help="How should the hook start?"
+             )
+
+             # Length
+             length = st.selectbox(
+                 "Target Length",
+                 ["Auto", "Short (5-7s)", "Medium (8-12s)", "Longer (13-20s)"],
+                 help="How long should the script be?"
+             )
+
+             # Risk level
+             risk_level = st.slider(
+                 "Risk Level",
+                 min_value=1,
+                 max_value=5,
+                 value=3,
+                 help="1=Safe, 3=Suggestive, 5=Spicy"
+             )
+
+         with col2:
+             # Retention gimmick
+             retention = st.selectbox(
+                 "Retention Hook",
+                 ["Auto", "Twist ending", "Shock reveal", "Naughty payoff", "Innocent→dirty flip", "Cliffhanger"],
+                 help="How to keep viewers watching?"
+             )
+
+             # Shot type
+             shot_type = st.selectbox(
+                 "Shot Type",
+                 ["Auto", "POV", "Selfie cam", "Tripod", "Over-the-shoulder", "Mirror shot"],
+                 help="Camera angle/perspective"
+             )
+
+             # Wardrobe
+             wardrobe = st.selectbox(
+                 "Wardrobe/Setting",
+                 ["Auto", "Gym fit", "Bikini", "Bed outfit", "Towel", "Dress", "Casual", "Kitchen", "Car"],
+                 help="Setting or outfit context"
+             )
+
+     # Step 4: Optional References
+     with st.expander("📚 Step 4: Extra References (Optional)", expanded=False):
+         st.info("💡 The AI automatically uses your database references, but you can add more here")
+         refs_text = st.text_area(
+             "Additional Reference Lines",
+             value="",
+             height=100,
+             help="Add extra inspiration lines (one per line)"
+         )
+
+     # Generation Button
+     st.markdown("---")
+
+     # Show reference count
+     from db import get_hybrid_refs
+
+     # Map new content types to existing database types for compatibility
+     content_type_mapping = {
+         "thirst-trap": "talking_style / thirst_trap",
+         "skit": "comedy",
+         "reaction-prank": "prank",
+         "talking-style": "talking_style",
+         "lifestyle": "lifestyle",
+         "fake-podcast": "fake-podcast",
+         "dance-trend": "trend-adaptation",
+         "voice-tease-asmr": "talking_style"
+     }
+
+     mapped_content_type = content_type_mapping.get(content_type, content_type)
+     ref_count = len(get_hybrid_refs(creator, mapped_content_type, k=6))
+
+     st.info(f"🤖 AI will use {ref_count} database references + your extras")
+
+     generate_button = st.button(
+         "🚀 Generate Scripts",
+         type="primary",
+         use_container_width=True
+     )
+
+ # Generation Process
+ if generate_button:
+     with st.spinner("🧠 AI is creating your scripts..."):
+         try:
+             # Get manual refs from text area
+             manual_refs = [x.strip() for x in refs_text.split("\n") if x.strip()]
+
+             # Get automatic refs from selected creator scripts in database using content type mapping
+             auto_refs = get_hybrid_refs(creator, mapped_content_type, k=6)
+
+             # Combine both
+             all_refs = manual_refs + auto_refs
+
+             # Progress indicator
+             progress_bar = st.progress(0)
+             status_text = st.empty()
+
+             status_text.text("🔍 Analyzing references...")
+             progress_bar.progress(25)
+             time.sleep(0.5)
+
+             status_text.text("🧠 RAG system selecting optimal references...")
+             progress_bar.progress(40)
+             time.sleep(0.3)
+
+             status_text.text("✨ Generating enhanced content with AI learning...")
+             progress_bar.progress(60)
+
+             # Build enhanced prompt from advanced options
+             advanced_prompt = ""
+             if hook_style != "Auto":
+                 advanced_prompt += f"Hook style: {hook_style}. "
+             if length != "Auto":
+                 advanced_prompt += f"Target length: {length}. "
+             if retention != "Auto":
+                 advanced_prompt += f"Retention strategy: {retention}. "
+             if shot_type != "Auto":
+                 advanced_prompt += f"Shot type: {shot_type}. "
+             if wardrobe != "Auto":
+                 advanced_prompt += f"Setting/wardrobe: {wardrobe}. "
+             if risk_level != 3:
+                 risk_desc = {1: "very safe", 2: "mild", 3: "suggestive", 4: "spicy", 5: "very spicy"}
+                 advanced_prompt += f"Risk level: {risk_desc[risk_level]}. "
+
+             # Enhance boundaries with advanced prompt
+             enhanced_boundaries = boundaries
+             if advanced_prompt:
+                 enhanced_boundaries += f"\n\nADVANCED GUIDANCE: {advanced_prompt}"
+
+             # Generate scripts with enhanced RAG system (lazy import)
+             try:
+                 from rag_integration import generate_scripts_rag
+                 drafts = generate_scripts_rag(persona, enhanced_boundaries, content_type, tone, all_refs, n=n)
+             except ImportError as e:
+                 st.warning(f"RAG system not available: {e}. Using fallback generation.")
+                 # Fallback to simple generation
+                 drafts = generate_scripts(enhanced_boundaries, n)
+
+             progress_bar.progress(75)
+             status_text.text("💾 Saving to database...")
+
+             # Save to database
+             with get_session() as ses:
+                 for d in drafts:
+                     lvl, _ = score_script(" ".join([d.get("title",""), d.get("hook",""), *d.get("beats",[]), d.get("voiceover",""), d.get("caption",""), d.get("cta","")]))
+                     s = Script(
+                         creator=creator, content_type=content_type, tone=tone,
+                         title=d["title"], hook=d["hook"], beats=d["beats"],
+                         voiceover=d["voiceover"], caption=d["caption"],
+                         hashtags=d.get("hashtags",[]), cta=d.get("cta",""),
+                         compliance=lvl, source="ai"
+                     )
+                     ses.add(s)
+                 ses.commit()
+
+             progress_bar.progress(100)
+             status_text.text("")
+             progress_bar.empty()
+
+             st.session_state.generated_count += len(drafts)
+             st.success(f"🎉 Generated {len(drafts)} scripts successfully!")
+
+             # Show which refs were used and advanced options
+             col1, col2 = st.columns(2)
+             with col1:
+                 if auto_refs:
+                     st.markdown("**🤖 Hybrid refs used this run:**")
+                     for line in auto_refs[:3]:  # Show first 3
+                         st.write(f"• {line}")
+
+             with col2:
+                 if advanced_prompt:
+                     st.markdown("**⚡ Advanced options applied:**")
+                     st.write(f"• {advanced_prompt[:100]}...")
+                 st.write(f"**📊 Settings:** {tone} • {content_type}")
+
+             st.balloons()
+
+             # Auto-refresh to show new drafts
+             time.sleep(1)
+             st.rerun()
+
+         except Exception as e:
+             st.error(f"❌ Generation failed: {str(e)}")
+             st.write("💡 Try adjusting your parameters or check your API key")
+
+ # Quick Actions
+ st.markdown("---")
+ st.subheader("⚡ Quick Actions")
+
+ col1, col2 = st.columns(2)
+ with col1:
+     if st.button("🔄 Refresh", use_container_width=True):
+         st.rerun()
+ with col2:
+     if st.button("🗑️ Clear All", use_container_width=True, help="Delete all your generated scripts"):
+         if st.session_state.get('confirm_clear'):
+             with get_session() as ses:
+                 scripts_to_delete = list(ses.exec(select(Script).where(Script.creator == creator, Script.source == "ai")))
+                 for script in scripts_to_delete:
+                     ses.delete(script)
+                 ses.commit()
+             st.success("🗑️ All drafts cleared!")
+             st.session_state.confirm_clear = False
+             st.rerun()
+         else:
+             st.session_state.confirm_clear = True
+             st.warning("Click again to confirm deletion!")
+
+ # Main Area
+ tab1, tab2, tab3 = st.tabs(["📝 Draft Review", "🎯 Filters", "📊 Analytics"])
+
+ with tab1:
+     # Load drafts
+     with get_session() as ses:
+         q = select(Script).where(Script.creator == creator, Script.source == "ai")
+         all_drafts = list(ses.exec(q))
+
+     if not all_drafts:
+         st.markdown("""
+         <div style="text-align: center; padding: 3rem;">
+             <h3>🎬 Ready to Create Amazing Scripts?</h3>
+             <p style="font-size: 1.2rem; color: #666;">
+                 👈 Use the sidebar to generate your first batch of AI scripts<br>
+                 🤖 The AI will learn from successful examples in the database<br>
+                 ✨ Then review, edit, and perfect your scripts here
+             </p>
+         </div>
+         """, unsafe_allow_html=True)
+
+         if st.session_state.generated_count > 0:
+             st.info(f"🎉 You've generated {st.session_state.generated_count} scripts so far! Use filters to find them.")
+     else:
+         # Draft management
+         col1, col2 = st.columns([0.4, 0.6], gap="large")
+
+         with col1:
+             st.subheader(f"📋 Your Drafts ({len(all_drafts)})")
+
+             # Quick filters
+             filter_col1, filter_col2 = st.columns(2)
+             with filter_col1:
+                 compliance_filter = st.selectbox(
+                     "Compliance",
+                     ["All", "PASS", "WARN", "FAIL"],
+                     key="compliance_filter"
+                 )
+             with filter_col2:
+                 sort_by = st.selectbox(
+                     "Sort by",
+                     ["Newest", "Oldest", "Title"],
+                     key="sort_filter"
+                 )
+
+             # Apply filters
+             filtered_drafts = all_drafts
+             if compliance_filter != "All":
+                 filtered_drafts = [d for d in filtered_drafts if d.compliance.upper() == compliance_filter]
+
+             # Apply sorting
+             if sort_by == "Newest":
+                 filtered_drafts.sort(key=lambda x: x.created_at, reverse=True)
+             elif sort_by == "Oldest":
+                 filtered_drafts.sort(key=lambda x: x.created_at)
+             else:  # Title
+                 filtered_drafts.sort(key=lambda x: x.title)
+
+             # Draft cards
+             selected_id = st.session_state.get("selected_id")
+
+             for draft in filtered_drafts:
+                 # Compliance color coding
+                 compliance_color = {
+                     "pass": "🟢",
+                     "warn": "🟡",
+                     "fail": "🔴"
+                 }.get(draft.compliance, "⚪")
+
+                 # Create card
+                 with st.container(border=True):
+                     if st.button(
+                         f"{compliance_color} {draft.title}",
+                         key=f"select-{draft.id}",
+                         use_container_width=True
+                     ):
+                         st.session_state["selected_id"] = draft.id
+                         selected_id = draft.id
+
+                     st.caption(f"🎭 {draft.tone} • 📅 {draft.created_at.strftime('%m/%d %H:%M')}")
+
+                     # Preview hook
+                     if draft.hook:
+                         st.markdown(f"*{draft.hook[:80]}{'...' if len(draft.hook) > 80 else ''}*")
+
+         with col2:
+             st.subheader("✏️ Script Editor")
+
+             if not filtered_drafts:
+                 st.info("No drafts match your filters. Try adjusting the filter settings.")
+             else:
+                 # Auto-select first draft if none selected
+                 if not selected_id or selected_id not in [d.id for d in filtered_drafts]:
+                     selected_id = filtered_drafts[0].id
+                     st.session_state["selected_id"] = selected_id
+
+                 # Get current draft
+                 current = next((x for x in filtered_drafts if x.id == selected_id), filtered_drafts[0])
+
+                 # Editor tabs
+                 edit_tab1, edit_tab2, edit_tab3 = st.tabs(["📝 Edit", "🛠️ AI Tools", "📜 History"])
+
+                 with edit_tab1:
+                     # Main editing fields
+                     with st.form("edit_script"):
+                         title = st.text_input("Title", value=current.title)
+                         hook = st.text_area("Hook", value=current.hook or "", height=80)
+                         beats_text = st.text_area("Beats (one per line)", value="\n".join(current.beats or []), height=120)
+                         voiceover = st.text_area("Voiceover", value=current.voiceover or "", height=80)
+                         caption = st.text_area("Caption", value=current.caption or "", height=100)
+                         # Clean up hashtags display - remove commas, show as space-separated
+                         current_hashtags = current.hashtags or []
+                         hashtags_display = " ".join(current_hashtags) if current_hashtags else ""
+                         hashtags = st.text_input("Hashtags (space separated)", value=hashtags_display, help="Enter hashtags like: #gym #fitness #workout")
+                         cta = st.text_input("Call to Action", value=current.cta or "")
+
+                         # Submit button
+                         if st.form_submit_button("💾 Save Changes", type="primary", use_container_width=True):
+                             with get_session() as ses:
+                                 dbs = ses.get(Script, current.id)
+                                 dbs.title = title
+                                 dbs.hook = hook
+                                 dbs.beats = [x.strip() for x in beats_text.split("\n") if x.strip()]
+                                 dbs.voiceover = voiceover
+                                 dbs.caption = caption
+                                 # Parse hashtags from space-separated input
+                                 dbs.hashtags = [x.strip() for x in hashtags.split() if x.strip()]
+                                 dbs.cta = cta
+
+                                 # Update compliance
+                                 lvl, _ = score_script(blob_from(dbs.model_dump()))
+                                 dbs.compliance = lvl
+
+                                 ses.add(dbs)
+                                 ses.commit()
+
+                             st.success("✅ Script saved successfully!")
+                             time.sleep(1)
+                             st.rerun()
+
+                     # Rating widget
+                     st.markdown("### Rate this script (feeds future generations)")
+
+                     # Show current ratings if any
+                     if current.ratings_count > 0:
+                         st.info(f"📊 Current ratings ({current.ratings_count} ratings): Overall: {current.score_overall:.1f}/5.0, Hook: {current.score_hook:.1f}/5.0, Originality: {current.score_originality:.1f}/5.0")
+
+                     with st.form("rate_script"):
+                         colA, colB, colC, colD, colE = st.columns(5)
+                         overall = colA.slider("Overall", 1.0, 5.0, 4.0, 0.5)
+                         hook_s = colB.slider("Hook clarity", 1.0, 5.0, 4.0, 0.5)
+                         orig_s = colC.slider("Originality", 1.0, 5.0, 4.0, 0.5)
+                         fit_s = colD.slider("Style fit", 1.0, 5.0, 4.0, 0.5)
+                         safe_s = colE.slider("Safety", 1.0, 5.0, 4.0, 0.5)
+                         notes = st.text_input("Notes (optional)")
+
+                         if st.form_submit_button("💫 Save rating", type="secondary", use_container_width=True):
+                             add_rating(
+                                 script_id=current.id,
+                                 overall=overall, hook=hook_s, originality=orig_s,
+                                 style_fit=fit_s, safety=safe_s, notes=notes, rater="human"
+                             )
+                             st.success("Rating saved. Future generations will weigh this higher.")
+                             time.sleep(1)
+                             st.rerun()
+
+                 with edit_tab2:
+                     st.write("🤖 **AI-Powered Improvements**")
+
+                     # Quick AI actions
+                     col1, col2 = st.columns(2)
+
+                     with col1:
+                         if st.button("🛡️ Make Safer", use_container_width=True):
+                             with st.spinner("Making content safer..."):
+                                 revised = revise_for("be Instagram-compliant and safer", script_to_json_dict(current), "Remove risky phrases; keep intent and beat order.")
+                                 with get_session() as ses:
+                                     dbs = ses.get(Script, current.id)
+                                     before = dbs.caption
+                                     dbs.caption = revised.get("caption", dbs.caption)
+                                     lvl, _ = score_script(blob_from(revised))
+                                     dbs.compliance = lvl
+                                     ses.add(dbs)
+                                     ses.commit()
+                                     ses.add(Revision(script_id=dbs.id, label="Auto safer", field="caption", before=before, after=dbs.caption))
+                                     ses.commit()
+                             st.success("✅ Content made safer!")
+                             st.rerun()
+
+                         if st.button("✨ More Playful", use_container_width=True):
+                             with st.spinner("Adding playful vibes..."):
+                                 revised = revise_for("be more playful (keep safe)", script_to_json_dict(current), "Increase playful tone without adding risk.")
+                                 with get_session() as ses:
+                                     dbs = ses.get(Script, current.id)
+                                     before = dbs.hook
+                                     dbs.hook = revised.get("hook", dbs.hook)
+                                     ses.add(dbs)
+                                     ses.commit()
+                                     ses.add(Revision(script_id=dbs.id, label="More playful", field="hook", before=before, after=dbs.hook))
+                                     ses.commit()
+                             st.success("✨ Added playful energy!")
+                             st.rerun()
+
+                     with col2:
+                         if st.button("✂️ Shorter Hook", use_container_width=True):
+                             with st.spinner("Tightening hook..."):
+                                 revised = revise_for("shorten the hook to <= 8 words", script_to_json_dict(current), "Shorten only the hook, keep intent.")
+                                 with get_session() as ses:
+                                     dbs = ses.get(Script, current.id)
+                                     before = dbs.hook
+                                     dbs.hook = revised.get("hook", dbs.hook)
+                                     ses.add(dbs)
+                                     ses.commit()
+                                     ses.add(Revision(script_id=dbs.id, label="Shorter hook", field="hook", before=before, after=dbs.hook))
+                                     ses.commit()
+                             st.success("✂️ Hook tightened!")
+                             st.rerun()
+
+                         if st.button("🇬🇧 Localize (UK)", use_container_width=True):
+                             with st.spinner("Localizing content..."):
+                                 revised = revise_for("localize to UK English", script_to_json_dict(current), "Adjust spelling/phrasing to UK without changing content.")
+                                 with get_session() as ses:
+                                     dbs = ses.get(Script, current.id)
+                                     before = dbs.caption
+                                     dbs.caption = revised.get("caption", dbs.caption)
+                                     ses.add(dbs)
+                                     ses.commit()
+                                     ses.add(Revision(script_id=dbs.id, label="Localize UK", field="caption", before=before, after=dbs.caption))
+                                     ses.commit()
+                             st.success("🇬🇧 Localized to UK!")
+                             st.rerun()
+
+                     # Custom rewrite section
+                     st.markdown("---")
+                     st.write("🎯 **Custom Rewrite**")
+
+                     with st.form("custom_rewrite"):
+                         rewrite_col1, rewrite_col2 = st.columns([0.6, 0.4])
+
+                         with rewrite_col1:
+                             field = st.selectbox("Field to Edit", ["title","hook","voiceover","caption","cta","beats"])
+                             snippet = st.text_input("Exact text you want to change")
+
+                         with rewrite_col2:
+                             prompt = st.text_input("How to rewrite it")
+
+                         if st.form_submit_button("🪄 Rewrite", use_container_width=True):
+                             if snippet and prompt:
+                                 with st.spinner("AI is rewriting..."):
+                                     draft = script_to_json_dict(current)
+                                     revised = selective_rewrite(draft, field, snippet, prompt)
+                                     with get_session() as ses:
+                                         dbs = ses.get(Script, current.id)
+                                         before = getattr(dbs, field)
+                                         setattr(dbs, field, revised.get(field, before))
+                                         lvl, _ = score_script(blob_from(dbs.model_dump()))
+                                         dbs.compliance = lvl
+                                         ses.add(dbs)
+                                         ses.commit()
+                                         ses.add(Revision(script_id=dbs.id, label="Custom rewrite", field=field, before=str(before), after=str(getattr(dbs, field))))
+                                         ses.commit()
+                                 st.success("🪄 Rewrite complete!")
+                                 st.rerun()
+                             else:
+                                 st.error("Please fill in both the text and rewrite instructions")
+
+                 with edit_tab3:
+                     st.write("📜 **Revision History**")
+
+                     with get_session() as ses:
+                         revisions = list(ses.exec(
+                             select(Revision).where(Revision.script_id==current.id).order_by(Revision.created_at.desc())
+                         ))
+
+                     if not revisions:
+                         st.info("No revisions yet. Make some changes to see the history!")
+                     else:
+                         for rev in revisions:
+                             with st.expander(f"🔄 {rev.label} • {rev.field} • {rev.created_at.strftime('%m/%d %H:%M')}"):
+                                 col1, col2 = st.columns(2)
+                                 with col1:
+                                     st.write("**Before:**")
+                                     st.code(rev.before)
+                                 with col2:
+                                     st.write("**After:**")
+                                     st.code(rev.after)
+
738
+ with tab2:
739
+ st.subheader("🎯 Advanced Filters & Search")
740
+
741
+ # Advanced filtering interface
742
+ filter_col1, filter_col2, filter_col3 = st.columns(3)
743
+
744
+ with filter_col1:
745
+ creator_filter = st.selectbox("Creator", ["All"] + ["Creator A", "Emily"])
746
+ content_filter = st.selectbox("Content Type", ["All"] + ["thirst-trap", "lifestyle", "comedy", "prank", "fake-podcast", "trend-adaptation"])
747
+
748
+ with filter_col2:
749
+ compliance_filter_adv = st.selectbox("Compliance Status", ["All", "PASS", "WARN", "FAIL"])
750
+ source_filter = st.selectbox("Source", ["All", "AI Generated", "Imported", "Manual"])
751
+
752
+ with filter_col3:
753
+ date_filter = st.selectbox("Date Range", ["All Time", "Today", "This Week", "This Month"])
754
+ search_text = st.text_input("🔍 Search in titles/content")
755
+
756
+ # Apply advanced filters and show results
757
+ with get_session() as ses:
758
+ query = select(Script)
759
+
760
+ # Apply filters
761
+ if creator_filter != "All":
762
+ query = query.where(Script.creator == creator_filter)
763
+ if content_filter != "All":
764
+ query = query.where(Script.content_type == content_filter)
765
+ if compliance_filter_adv != "All":
766
+ query = query.where(Script.compliance == compliance_filter_adv.lower())
767
+
768
+ filtered_results = list(ses.exec(query))
769
+
770
+ # Search in text
771
+ if search_text:
772
+ filtered_results = [
773
+ r for r in filtered_results
774
+ if search_text.lower() in r.title.lower() or
775
+ search_text.lower() in (r.hook or "").lower() or
776
+ search_text.lower() in (r.caption or "").lower()
777
+ ]
778
+
779
+ st.write(f"**Found {len(filtered_results)} scripts**")
780
+
781
+ # Display filtered results
782
+ if filtered_results:
783
+ for script in filtered_results[:10]: # Show first 10
784
+ with st.expander(f"{script.compliance.upper()} • {script.title} • {script.creator}"):
785
+ st.write(f"**Hook:** {script.hook}")
786
+ st.write(f"**Type:** {script.content_type} • **Tone:** {script.tone}")
787
+ st.write(f"**Created:** {script.created_at.strftime('%Y-%m-%d %H:%M')}")
788
+
789
+ with tab3:
790
+ st.subheader("📊 Script Analytics")
791
+
792
+ # Get all scripts for analytics
793
+ with get_session() as ses:
794
+ all_scripts = list(ses.exec(select(Script)))
795
+
796
+ if all_scripts:
797
+ # Create metrics
798
+ col1, col2, col3, col4 = st.columns(4)
799
+
800
+ with col1:
801
+ st.metric("Total Scripts", len(all_scripts))
802
+
803
+ with col2:
804
+ ai_generated = len([s for s in all_scripts if s.source == "ai"])
805
+ st.metric("AI Generated", ai_generated)
806
+
807
+ with col3:
808
+ passed_compliance = len([s for s in all_scripts if s.compliance == "pass"])
809
+ st.metric("Compliance PASS", passed_compliance)
810
+
811
+ with col4:
812
+ unique_creators = len(set(s.creator for s in all_scripts))
813
+ st.metric("Creators", unique_creators)
814
+
815
+ # Charts and insights
816
+ st.markdown("### 📈 Content Insights")
817
+
818
+ # Compliance distribution
819
+ compliance_counts = {}
820
+ for script in all_scripts:
821
+ compliance_counts[script.compliance] = compliance_counts.get(script.compliance, 0) + 1
822
+
823
+ if compliance_counts:
824
+ st.bar_chart(compliance_counts)
825
+
826
+ # Content type distribution
827
+ type_counts = {}
828
+ for script in all_scripts:
829
+ type_counts[script.content_type] = type_counts.get(script.content_type, 0) + 1
830
+
831
+ if type_counts:
832
+ st.bar_chart(type_counts)
833
+
834
+ else:
835
+ st.info("📊 Generate some scripts to see analytics!")
836
+
837
+ # Footer
838
+ st.markdown("---")
839
+ st.markdown("""
840
+ <div style="text-align: center; color: #666; padding: 1rem;">
841
+ 🎬 AI Script Studio • Built with Streamlit & DeepSeek AI<br>
842
+ 💡 Tip: Generate scripts in batches, then refine with AI tools for best results
843
+ </div>
844
+ """, unsafe_allow_html=True)
auto_scorer.py ADDED
@@ -0,0 +1,240 @@
+ """
+ Auto-scoring system using LLM judges for script quality assessment
+ Integrates with existing DeepSeek client
+ """
+
+ import json
+ from typing import Dict, List, Tuple
+ from sqlmodel import Session, select
+ from datetime import datetime, timedelta
+
+ from models import Script, AutoScore, PolicyWeights
+ from db import get_session
+ from deepseek_client import chat
+
+ class AutoScorer:
+     def __init__(self, confidence_threshold: float = 0.7):
+         self.confidence_threshold = confidence_threshold
+
+     def score_script(self, script_data: Dict) -> Dict[str, float]:
+         """
+         Score a script using LLM judge across 5 dimensions
+         Returns scores and confidence level
+         """
+
+         system_prompt = """You are an expert Instagram content analyst. Score this script on 5 dimensions (1-5 scale):
+
+ 1. OVERALL: General quality and effectiveness (1=poor, 5=excellent)
+ 2. HOOK: How compelling is the opening (1=boring, 5=irresistible)
+ 3. ORIGINALITY: How unique/creative (1=generic, 5=highly original)
+ 4. STYLE_FIT: How well it matches the persona (1=off-brand, 5=perfect fit)
+ 5. SAFETY: Instagram compliance (1=risky, 5=completely safe)
+
+ Return ONLY a JSON object with: {"overall": X, "hook": X, "originality": X, "style_fit": X, "safety": X, "confidence": X, "reasoning": "brief explanation"}
+
+ Be consistent and objective. Confidence should be 0.1-1.0 based on how certain you are."""
+
+         user_prompt = f"""
+ Script to score:
+ Title: {script_data.get('title', '')}
+ Hook: {script_data.get('hook', '')}
+ Beats: {script_data.get('beats', [])}
+ Caption: {script_data.get('caption', '')}
+ Persona: {script_data.get('creator', '')}
+ Content Type: {script_data.get('content_type', '')}
+ Tone: {script_data.get('tone', '')}
+
+ Score this script now."""
+
+         try:
+             response = chat([
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_prompt}
+             ], temperature=0.3)  # Low temperature for consistent scoring
+
+             # Extract JSON from response
+             start = response.find("{")
+             end = response.rfind("}") + 1
+
+             if start >= 0 and end > start:
+                 scores = json.loads(response[start:end])
+
+                 # Validate scores are in range
+                 required_keys = ['overall', 'hook', 'originality', 'style_fit', 'safety']
+                 for key in required_keys:
+                     if key not in scores or not (1 <= scores[key] <= 5):
+                         raise ValueError(f"Invalid score for {key}")
+
+                 # Ensure confidence is present and valid
+                 if 'confidence' not in scores or not (0.1 <= scores['confidence'] <= 1.0):
+                     scores['confidence'] = 0.7  # Default confidence
+
+                 return scores
+             else:
+                 raise ValueError("No valid JSON found in response")
+
+         except Exception as e:
+             print(f"Auto-scoring failed: {e}")
+             # Return neutral scores with low confidence
+             return {
+                 'overall': 3.0,
+                 'hook': 3.0,
+                 'originality': 3.0,
+                 'style_fit': 3.0,
+                 'safety': 3.0,
+                 'confidence': 0.3,
+                 'reasoning': f"Scoring failed: {str(e)}"
+             }
+
+     def score_and_store(self, script_id: int) -> AutoScore:
+         """Score a script and store in database"""
+         with get_session() as ses:
+             script = ses.get(Script, script_id)
+             if not script:
+                 raise ValueError(f"Script {script_id} not found")
+
+             # Prepare script data for scoring
+             script_data = {
+                 'title': script.title,
+                 'hook': script.hook,
+                 'beats': script.beats,
+                 'caption': script.caption,
+                 'creator': script.creator,
+                 'content_type': script.content_type,
+                 'tone': script.tone
+             }
+
+             # Get scores
+             scores = self.score_script(script_data)
+
+             # Store auto-score
+             auto_score = AutoScore(
+                 script_id=script_id,
+                 overall=scores['overall'],
+                 hook=scores['hook'],
+                 originality=scores['originality'],
+                 style_fit=scores['style_fit'],
+                 safety=scores['safety'],
+                 confidence=scores['confidence'],
+                 notes=scores.get('reasoning', '')
+             )
+
+             ses.add(auto_score)
+             ses.commit()
+             ses.refresh(auto_score)
+
+             return auto_score
+
+     def batch_score_recent(self, hours: int = 24) -> List[AutoScore]:
+         """Score all recently generated scripts that haven't been auto-scored"""
+         cutoff = datetime.utcnow() - timedelta(hours=hours)
+
+         with get_session() as ses:
+             # Find scripts without auto-scores
+             recent_scripts = ses.exec(
+                 select(Script).where(
+                     Script.created_at >= cutoff,
+                     Script.source == "ai"  # Only score AI-generated scripts
+                 )
+             ).all()
+
+             # Filter out already scored
+             unscored = []
+             for script in recent_scripts:
+                 existing_score = ses.exec(
+                     select(AutoScore).where(AutoScore.script_id == script.id)
+                 ).first()
+                 if not existing_score:
+                     unscored.append(script)
+
+         print(f"Auto-scoring {len(unscored)} recent scripts...")
+
+         results = []
+         for script in unscored:
+             try:
+                 auto_score = self.score_and_store(script.id)
+                 results.append(auto_score)
+                 print(f"Scored script {script.id}: {auto_score.overall:.1f}/5.0")
+             except Exception as e:
+                 print(f"Failed to score script {script.id}: {e}")
+
+         return results
+
+ class ScriptReranker:
+     """Rerank generated scripts using composite scoring"""
+
+     def __init__(self, weights: Dict[str, float] = None):
+         self.weights = weights or {
+             'overall': 0.35,
+             'hook': 0.20,
+             'originality': 0.15,
+             'style_fit': 0.15,
+             'safety': 0.15
+         }
+
+     def rerank_scripts(self, script_ids: List[int]) -> List[Tuple[int, float]]:
+         """
+         Rerank scripts by composite score
+         Returns list of (script_id, composite_score) sorted by score descending
+         """
+
+         results = []
+
+         with get_session() as ses:
+             for script_id in script_ids:
+                 # Try to get auto-score first
+                 auto_score = ses.exec(
+                     select(AutoScore).where(AutoScore.script_id == script_id)
+                 ).first()
+
+                 if auto_score and auto_score.confidence >= 0.5:
+                     # Use auto-scores
+                     composite = (
+                         self.weights['overall'] * auto_score.overall +
+                         self.weights['hook'] * auto_score.hook +
+                         self.weights['originality'] * auto_score.originality +
+                         self.weights['style_fit'] * auto_score.style_fit +
+                         self.weights['safety'] * auto_score.safety
+                     )
+                 else:
+                     # Fall back to human ratings if available
+                     script = ses.get(Script, script_id)
+                     if script and script.ratings_count > 0:
+                         composite = (
+                             self.weights['overall'] * (script.score_overall or 3.0) +
+                             self.weights['hook'] * (script.score_hook or 3.0) +
+                             self.weights['originality'] * (script.score_originality or 3.0) +
+                             self.weights['style_fit'] * (script.score_style_fit or 3.0) +
+                             self.weights['safety'] * (script.score_safety or 3.0)
+                         )
+                     else:
+                         # Default neutral score
+                         composite = 3.0
+
+                 results.append((script_id, composite))
+
+         # Sort by composite score descending
+         results.sort(key=lambda x: x[1], reverse=True)
+         return results
+
+     def get_best_script(self, script_ids: List[int]) -> int:
+         """Get the ID of the highest-scoring script"""
+         ranked = self.rerank_scripts(script_ids)
+         return ranked[0][0] if ranked else script_ids[0]
+
+ def auto_score_pipeline():
+     """Main pipeline to auto-score recent scripts"""
+     scorer = AutoScorer()
+
+     # Score recent scripts
+     new_scores = scorer.batch_score_recent(hours=24)
+
+     if new_scores:
+         print(f"\n📊 Auto-scoring Results ({len(new_scores)} scripts):")
+         for score in new_scores:
+             print(f"Script {score.script_id}: {score.overall:.1f}/5.0 (confidence: {score.confidence:.2f})")
+     else:
+         print("No new scripts to score.")
+
+ if __name__ == "__main__":
+     auto_score_pipeline()
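`ScriptReranker`'s composite is a plain weighted sum over the five scoring dimensions. A minimal standalone sketch of that arithmetic, using the same default weights (the sample scores are made up; since the weights sum to 1.0, the composite stays on the same 1-5 scale as its inputs):

```python
# Default dimension weights from ScriptReranker
WEIGHTS = {"overall": 0.35, "hook": 0.20, "originality": 0.15, "style_fit": 0.15, "safety": 0.15}

def composite(scores: dict) -> float:
    # Weighted sum over the five dimensions; weights sum to 1.0,
    # so the result is a convex combination of the inputs
    return sum(w * scores[k] for k, w in WEIGHTS.items())

sample = {"overall": 4.0, "hook": 5.0, "originality": 3.0, "style_fit": 4.0, "safety": 5.0}
# 0.35*4 + 0.20*5 + 0.15*3 + 0.15*4 + 0.15*5 = 4.2
```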
bandit_learner.py ADDED
@@ -0,0 +1,330 @@
+ """
+ Multi-armed bandit learning system for optimizing generation policies
+ Learns which retrieval weights and generation parameters work best for each persona/content_type
+ """
+
+ import numpy as np
+ import random
+ from typing import Dict, List, Tuple, Optional
+ from dataclasses import dataclass
+ from datetime import datetime, timedelta
+ from sqlmodel import Session, select
+
+ from models import Script, AutoScore, PolicyWeights, Rating
+ from db import get_session
+
+ @dataclass
+ class BanditArm:
+     """Represents one configuration of parameters to test"""
+     name: str
+     semantic_weight: float
+     bm25_weight: float
+     quality_weight: float
+     freshness_weight: float
+     temp_low: float
+     temp_mid: float
+     temp_high: float
+
+     def __post_init__(self):
+         # Ensure weights sum to 1.0
+         total = self.semantic_weight + self.bm25_weight + self.quality_weight + self.freshness_weight
+         if total != 1.0:
+             self.semantic_weight /= total
+             self.bm25_weight /= total
+             self.quality_weight /= total
+             self.freshness_weight /= total
+
+ class PolicyBandit:
+     """Multi-armed bandit for learning optimal generation policies"""
+
+     def __init__(self, epsilon: float = 0.15, decay_rate: float = 0.99):
+         self.epsilon = epsilon  # Exploration rate
+         self.decay_rate = decay_rate  # Epsilon decay over time
+         self.min_epsilon = 0.05
+
+         # Define arms (different parameter configurations)
+         self.arms = [
+             # Current default
+             BanditArm("balanced", 0.45, 0.25, 0.20, 0.10, 0.4, 0.7, 0.95),
+
+             # Semantic-heavy (focus on meaning)
+             BanditArm("semantic_heavy", 0.60, 0.15, 0.15, 0.10, 0.4, 0.7, 0.95),
+
+             # Quality-focused (use only best examples)
+             BanditArm("quality_focused", 0.35, 0.20, 0.35, 0.10, 0.3, 0.6, 0.85),
+
+             # Fresh-focused (prioritize recent trends)
+             BanditArm("fresh_focused", 0.40, 0.20, 0.15, 0.25, 0.5, 0.8, 1.0),
+
+             # Conservative (lower temperatures)
+             BanditArm("conservative", 0.45, 0.25, 0.20, 0.10, 0.3, 0.5, 0.7),
+
+             # Creative (higher temperatures)
+             BanditArm("creative", 0.45, 0.25, 0.20, 0.10, 0.6, 0.9, 1.2),
+
+             # Text-match heavy (traditional keyword matching)
+             BanditArm("text_heavy", 0.25, 0.45, 0.20, 0.10, 0.4, 0.7, 0.95)
+         ]
+
+         # Initialize arm statistics
+         self.arm_counts = {arm.name: 0 for arm in self.arms}
+         self.arm_rewards = {arm.name: 0.0 for arm in self.arms}
+
+     def select_arm(self, persona: str, content_type: str) -> BanditArm:
+         """Select arm using epsilon-greedy with UCB bias"""
+
+         # Load existing policy weights to initialize arm stats
+         self._load_arm_stats(persona, content_type)
+
+         # Decay epsilon over time
+         current_epsilon = max(self.min_epsilon, self.epsilon * (self.decay_rate ** sum(self.arm_counts.values())))
+
+         if random.random() < current_epsilon:
+             # Explore: random arm
+             selected_arm = random.choice(self.arms)
+             print(f"🔄 Exploring with {selected_arm.name} policy (ε={current_epsilon:.3f})")
+         else:
+             # Exploit: best arm with UCB confidence bounds
+             selected_arm = self._select_best_arm_ucb()
+             print(f"⭐ Exploiting with {selected_arm.name} policy")
+
+         return selected_arm
+
+     def _select_best_arm_ucb(self) -> BanditArm:
+         """Select arm using Upper Confidence Bound"""
+         total_counts = sum(self.arm_counts.values())
+         if total_counts == 0:
+             return self.arms[0]  # Default to first arm
+
+         best_arm = None
+         best_score = float('-inf')
+
+         for arm in self.arms:
+             count = self.arm_counts[arm.name]
+             if count == 0:
+                 return arm  # Always try unplayed arms first
+
+             # UCB score = average reward + confidence interval
+             avg_reward = self.arm_rewards[arm.name] / count
+             confidence = np.sqrt(2 * np.log(total_counts) / count)
+             ucb_score = avg_reward + confidence
+
+             if ucb_score > best_score:
+                 best_score = ucb_score
+                 best_arm = arm
+
+         return best_arm or self.arms[0]
+
+     def _load_arm_stats(self, persona: str, content_type: str):
+         """Load historical performance for this persona/content_type"""
+         with get_session() as ses:
+             policy = ses.exec(
+                 select(PolicyWeights).where(
+                     PolicyWeights.persona == persona,
+                     PolicyWeights.content_type == content_type
+                 )
+             ).first()
+
+             if policy:
+                 # Find matching arm and update stats
+                 for arm in self.arms:
+                     if self._arm_matches_policy(arm, policy):
+                         self.arm_counts[arm.name] = policy.total_generations
+                         self.arm_rewards[arm.name] = policy.success_rate * policy.total_generations
+                         break
+
+     def _arm_matches_policy(self, arm: BanditArm, policy: PolicyWeights, tolerance: float = 0.05) -> bool:
+         """Check if an arm matches the stored policy within tolerance"""
+         return (
+             abs(arm.semantic_weight - policy.semantic_weight) < tolerance and
+             abs(arm.bm25_weight - policy.bm25_weight) < tolerance and
+             abs(arm.quality_weight - policy.quality_weight) < tolerance and
+             abs(arm.freshness_weight - policy.freshness_weight) < tolerance
+         )
+
+     def update_reward(self,
+                       arm: BanditArm,
+                       reward: float,
+                       persona: str,
+                       content_type: str,
+                       script_id: int):
+         """Update arm performance with new reward signal"""
+
+         # Update in-memory stats
+         self.arm_counts[arm.name] += 1
+         self.arm_rewards[arm.name] += reward
+
+         # Update database policy
+         self._update_policy_weights(arm, reward, persona, content_type)
+
+         print(f"📈 Updated {arm.name}: reward={reward:.3f}, avg={self.arm_rewards[arm.name]/self.arm_counts[arm.name]:.3f}")
+
+     def _update_policy_weights(self,
+                                arm: BanditArm,
+                                reward: float,
+                                persona: str,
+                                content_type: str):
+         """Update policy weights in database"""
+         with get_session() as ses:
+             policy = ses.exec(
+                 select(PolicyWeights).where(
+                     PolicyWeights.persona == persona,
+                     PolicyWeights.content_type == content_type
+                 )
+             ).first()
+
+             if not policy:
+                 # Create new policy
+                 policy = PolicyWeights(
+                     persona=persona,
+                     content_type=content_type,
+                     semantic_weight=arm.semantic_weight,
+                     bm25_weight=arm.bm25_weight,
+                     quality_weight=arm.quality_weight,
+                     freshness_weight=arm.freshness_weight,
+                     temp_low=arm.temp_low,
+                     temp_mid=arm.temp_mid,
+                     temp_high=arm.temp_high,
+                     total_generations=1,
+                     success_rate=reward
+                 )
+             else:
+                 # Update existing policy with exponential moving average
+                 alpha = 0.1  # Learning rate
+                 policy.success_rate = (1 - alpha) * policy.success_rate + alpha * reward
+                 policy.total_generations += 1
+
+                 # If this arm is performing well, shift weights toward it
+                 if reward > policy.success_rate:
+                     shift = 0.05  # Small shift toward better performing arm
+                     policy.semantic_weight = (1 - shift) * policy.semantic_weight + shift * arm.semantic_weight
+                     policy.bm25_weight = (1 - shift) * policy.bm25_weight + shift * arm.bm25_weight
+                     policy.quality_weight = (1 - shift) * policy.quality_weight + shift * arm.quality_weight
+                     policy.freshness_weight = (1 - shift) * policy.freshness_weight + shift * arm.freshness_weight
+
+                     policy.temp_low = (1 - shift) * policy.temp_low + shift * arm.temp_low
+                     policy.temp_mid = (1 - shift) * policy.temp_mid + shift * arm.temp_mid
+                     policy.temp_high = (1 - shift) * policy.temp_high + shift * arm.temp_high
+
+             policy.updated_at = datetime.utcnow()
+             ses.add(policy)
+             ses.commit()
+
+     def calculate_reward(self, script_id: int) -> float:
+         """
+         Calculate reward signal from script performance
+         Combines auto-scores and human ratings when available
+         """
+         reward_components = []
+
+         with get_session() as ses:
+             # Get auto-score
+             auto_score = ses.exec(
+                 select(AutoScore).where(AutoScore.script_id == script_id)
+             ).first()
+
+             if auto_score and auto_score.confidence > 0.5:
+                 # Weighted composite of auto-scores
+                 auto_reward = (
+                     0.35 * auto_score.overall +
+                     0.20 * auto_score.hook +
+                     0.15 * auto_score.originality +
+                     0.15 * auto_score.style_fit +
+                     0.15 * auto_score.safety
+                 ) / 5.0  # Normalize to 0-1
+
+                 reward_components.append(('auto', auto_reward, auto_score.confidence))
+
+             # Get human ratings
+             script = ses.get(Script, script_id)
+             if script and script.ratings_count > 0:
+                 human_reward = script.score_overall / 5.0  # Normalize to 0-1
+                 confidence = min(1.0, script.ratings_count / 3.0)  # More ratings = higher confidence
+                 reward_components.append(('human', human_reward, confidence))
+
+         if not reward_components:
+             return 0.5  # Neutral reward if no scores available
+
+         # Weighted average of reward components by confidence
+         total_weight = sum(confidence for _, _, confidence in reward_components)
+         weighted_reward = sum(
+             reward * confidence for _, reward, confidence in reward_components
+         ) / total_weight
+
+         return weighted_reward
+
+ class PolicyLearner:
+     """Main interface for policy learning"""
+
+     def __init__(self):
+         self.bandit = PolicyBandit()
+
+     def learn_from_generation_batch(self,
+                                     persona: str,
+                                     content_type: str,
+                                     generated_script_ids: List[int],
+                                     selected_arm: BanditArm):
+         """Learn from a batch of generated scripts"""
+
+         if not generated_script_ids:
+             return
+
+         # Calculate average reward from the batch
+         rewards = [self.bandit.calculate_reward(sid) for sid in generated_script_ids]
+         avg_reward = sum(rewards) / len(rewards)
+
+         # Update bandit with average performance
+         self.bandit.update_reward(
+             selected_arm,
+             avg_reward,
+             persona,
+             content_type,
+             generated_script_ids[0]  # Representative script ID
+         )
+
+         print(f"🧠 Policy learning: {persona}/{content_type} → {avg_reward:.3f} reward")
+
+     def get_optimized_policy(self, persona: str, content_type: str) -> BanditArm:
+         """Get the current best policy for this persona/content_type"""
+         return self.bandit.select_arm(persona, content_type)
+
+     def run_learning_cycle(self):
+         """Run a learning cycle on recent generations"""
+         print("🔄 Starting policy learning cycle...")
+
+         # Find recent AI-generated scripts by persona/content_type
+         cutoff = datetime.utcnow() - timedelta(hours=24)
+
+         with get_session() as ses:
+             recent_scripts = list(ses.exec(
+                 select(Script).where(
+                     Script.created_at >= cutoff,
+                     Script.source == "ai"
+                 )
+             ))
+
+         # Group by persona/content_type
+         groups = {}
+         for script in recent_scripts:
+             key = (script.creator, script.content_type)
+             if key not in groups:
+                 groups[key] = []
+             groups[key].append(script.id)
+
+         # Learn from each group
+         for (persona, content_type), script_ids in groups.items():
+             if len(script_ids) >= 3:  # Need minimum batch size
+                 # For now, assume they used the balanced policy
+                 # In practice, you'd track which policy was used for each generation
+                 balanced_arm = next(arm for arm in self.bandit.arms if arm.name == "balanced")
+                 self.learn_from_generation_batch(persona, content_type, script_ids, balanced_arm)
+
+ def run_policy_learning():
+     """Main entry point for policy learning"""
+     learner = PolicyLearner()
+     learner.run_learning_cycle()
+
+ if __name__ == "__main__":
+     run_policy_learning()
compliance.py ADDED
@@ -0,0 +1,26 @@
+ import re
+
+ BANNED = {r"\b(naked|explicit|porn|onlyfans\.com)\b"}
+ CAUTION = {r"\b(hot|naughty|spicy|thirsty)\b"}
+
+ def compliance_level(text: str):
+     low = text.lower()
+     for pat in BANNED:
+         if re.search(pat, low):
+             return "fail", ["banned phrase"]
+     reasons = []
+     for pat in CAUTION:
+         if re.search(pat, low):
+             reasons.append("caution phrase")
+     return ("warn" if reasons else "pass"), reasons
+
+ def score_script(blob: str):
+     return compliance_level(blob)
+
+ def blob_from(script: dict) -> str:
+     parts = [
+         script.get("title",""), script.get("hook",""),
+         " ".join(script.get("beats",[])),
+         script.get("voiceover",""), script.get("caption",""), script.get("cta","")
+     ]
+     return " ".join(parts)
db.py ADDED
@@ -0,0 +1,248 @@
1
+ # db.py
2
+ import os, json, random
3
+ from contextlib import contextmanager
4
+ from typing import List, Iterable, Tuple, Optional
5
+ from sqlmodel import SQLModel, create_engine, Session, select
6
+ from datetime import datetime
7
+
8
+ # ---- Configure DB ----
9
+ DB_URL = os.environ.get("DB_URL", "sqlite:///studio.db")
10
+ engine = create_engine(DB_URL, echo=False)
11
+
12
+ # ---- Models ----
13
+ from models import Script, Rating # make sure Script has: is_reference: bool, plus the other fields
14
+
15
+ # ---- Init / Session ----
16
+ def init_db() -> None:
17
+ SQLModel.metadata.create_all(engine)
18
+
19
+ @contextmanager
20
+ def get_session():
21
+ with Session(engine) as ses:
22
+ yield ses
23
+
24
+ # ---- Helpers for import ----
25
+
26
+ def _payload_from_jsonl_row(row: dict) -> Tuple[dict, str, str]:
27
+ """
28
+ Map a JSONL row (the file I generated for you) into Script columns.
29
+ Returns (payload, dedupe_key_title, dedupe_key_creator).
30
+ You can also add 'external_id' to Script model and dedupe on that.
31
+ """
32
+ # Prefer using the JSON 'id' as an external identifier:
33
+ external_id = row.get("id", "")
34
+
35
+ # Tone could be an array; flatten for now
36
+ tone = ", ".join(row.get("tonality", [])) or "playful"
37
+
38
+ # Compact caption: use caption options line as a quick reference
39
+ caption = " | ".join(row.get("caption_options", []))[:180]
40
+
41
+ payload = dict(
42
+ # core identity
43
+ creator=row.get("model_name", "Unknown"),
44
+ content_type=(row.get("video_type", "") or "talking_style").lower(),
45
+ tone=tone,
46
+ title=external_id or row.get("theme", "") or "Imported Script",
47
+ hook=row.get("video_hook") or "",
48
+
49
+ # structured fields
50
+ beats=row.get("storyboard", []) or [],
51
+ voiceover="",
52
+ caption=caption,
53
+ hashtags=row.get("hashtags", []) or [],
54
+ cta="",
55
+
56
+ # flags
57
+ source="import",
58
+ is_reference=True, # mark imported examples as references
59
+ compliance="pass", # we'll score again after save if you want
60
+ )
61
+ return payload, payload["title"], payload["creator"]
62
+
63
+ def _score_and_update_compliance(s: Script) -> None:
64
+ """Optional: score compliance using your simple rule-checker."""
65
+ try:
66
+ from compliance import blob_from, score_script
67
+ lvl, _ = score_script(blob_from(s.dict()))
68
+ s.compliance = lvl
69
+ except Exception:
70
+ # If no compliance module or error, keep default
71
+ pass
72
+
73
+ def _iter_jsonl(path: str) -> Iterable[dict]:
74
+ with open(path, "r", encoding="utf-8") as f:
75
+ for line in f:
76
+ line = line.strip()
77
+ if not line:
78
+ continue
79
+ yield json.loads(line)
80
+
81
+ # ---- Public: Importer ----
82
+ def import_jsonl(path: str) -> int:
83
+ """
84
+ Import (upsert) scripts from a JSONL file produced earlier.
85
+ Dedupe by (creator, title). Returns count of upserted rows.
86
+ """
87
+ init_db()
88
+ count = 0
89
+ with get_session() as ses:
90
+ for row in _iter_jsonl(path):
91
+ payload, key_title, key_creator = _payload_from_jsonl_row(row)
92
+
93
+ existing = ses.exec(
94
+ select(Script).where(
95
+ Script.title == key_title,
96
+ Script.creator == key_creator
97
+ )
98
+ ).first()
99
+
100
+ if existing:
101
+ # Update all fields
102
+ for k, v in payload.items():
103
+ setattr(existing, k, v)
104
+ _score_and_update_compliance(existing)
105
+ existing.updated_at = datetime.utcnow()
106
+ ses.add(existing)
107
+ else:
108
+ obj = Script(**payload)
109
+ _score_and_update_compliance(obj)
110
+ ses.add(obj)
111
+
112
+ count += 1
113
+ ses.commit()
114
+ return count
+
+ # ---- Ratings API ----
+ def add_rating(script_id: int,
+                overall: float,
+                hook: Optional[float] = None,
+                originality: Optional[float] = None,
+                style_fit: Optional[float] = None,
+                safety: Optional[float] = None,
+                notes: Optional[str] = None,
+                rater: str = "human") -> None:
+     with get_session() as ses:
+         # store the rating event
+         ses.add(Rating(
+             script_id=script_id, overall=overall, hook=hook,
+             originality=originality, style_fit=style_fit, safety=safety,
+             notes=notes, rater=rater
+         ))
+         ses.commit()
+         # recompute cached aggregates on Script
+         _recompute_script_aggregates(ses, script_id)
+         ses.commit()
+
+ def _recompute_script_aggregates(ses: Session, script_id: int) -> None:
+     rows = list(ses.exec(select(Rating).where(Rating.script_id == script_id)))
+     if not rows:
+         return
+
+     def avg(field):
+         vals = [getattr(r, field) for r in rows if getattr(r, field) is not None]
+         return round(sum(vals) / len(vals), 3) if vals else None
+
+     s: Script = ses.get(Script, script_id)
+     s.score_overall = avg("overall")
+     s.score_hook = avg("hook")
+     s.score_originality = avg("originality")
+     s.score_style_fit = avg("style_fit")
+     s.score_safety = avg("safety")
+     s.ratings_count = len(rows)
+     s.updated_at = datetime.utcnow()
+     ses.add(s)
+
+ # ---- Public: Reference retrieval for generation ----
+ def extract_snippets_from_script(s: Script, max_lines: int = 3) -> List[str]:
+     items: List[str] = []
+     if s.hook:
+         items.append(s.hook.strip())
+     if s.beats:
+         items.extend([b.strip() for b in s.beats[:2]])  # first 1–2 beats
+     if s.caption:
+         items.append(s.caption.strip()[:120])
+     # dedupe while preserving order
+     seen, uniq = set(), []
+     for it in items:
+         if it and it not in seen:
+             uniq.append(it)
+             seen.add(it)
+     return uniq[:max_lines]
+
+ def get_library_refs(creator: str, content_type: str, k: int = 6) -> List[str]:
+     with get_session() as ses:
+         rows = list(ses.exec(
+             select(Script)
+             .where(
+                 Script.creator == creator,
+                 Script.content_type == content_type,
+                 Script.is_reference == True,
+                 Script.compliance != "fail"
+             )
+             .order_by(Script.created_at.desc())
+         ))[:k]
+
+         snippets: List[str] = []
+         for r in rows:
+             snippets.extend(extract_snippets_from_script(r))
+         # final dedupe
+         seen, uniq = set(), []
+         for s in snippets:
+             if s not in seen:
+                 uniq.append(s)
+                 seen.add(s)
+         return uniq[:8]
+
+ # ---- HYBRID reference retrieval ----
+ def get_hybrid_refs(creator: str, content_type: str, k: int = 6,
+                     top_n: int = 3, explore_n: int = 2, newest_n: int = 1) -> List[str]:
+     """
+     Mix of:
+       - top_n best-scored references (exploit)
+       - explore_n random references (explore)
+       - newest_n most recent references (freshness)
+     Returns a flattened snippet list (capped at ~8 to keep the prompt lean).
+     """
+     with get_session() as ses:
+         all_refs = list(ses.exec(
+             select(Script).where(
+                 Script.creator == creator,
+                 Script.content_type == content_type,
+                 Script.is_reference == True,
+                 Script.compliance != "fail"
+             )
+         ))
+
+         if not all_refs:
+             return []
+
+         # sort by score_overall (fallback to 0) and pick top_n
+         scored = sorted(all_refs, key=lambda s: (s.score_overall or 0.0), reverse=True)
+         best = scored[:top_n]
+
+         # newest by created_at
+         newest = sorted(all_refs, key=lambda s: s.created_at, reverse=True)[:newest_n]
+
+         # explore = random sample from the remainder
+         remainder = [r for r in all_refs if r not in best and r not in newest]
+         explore = random.sample(remainder, min(explore_n, len(remainder))) if remainder else []
+
+         # merge (preserve order, dedupe by id)
+         chosen_scripts = []
+         seen_ids = set()
+         for bucket in (best, explore, newest):
+             for s in bucket:
+                 if s.id not in seen_ids:
+                     chosen_scripts.append(s)
+                     seen_ids.add(s.id)
+
+         # cut to k scripts
+         chosen_scripts = chosen_scripts[:k]
+
+         # flatten snippets and cap to keep the prompt compact
+         snippets: List[str] = []
+         for s in chosen_scripts:
+             snippets.extend(extract_snippets_from_script(s))
+         # dedupe again and cap at ~8 lines
+         seen, out = set(), []
+         for sn in snippets:
+             if sn not in seen:
+                 out.append(sn)
+                 seen.add(sn)
+         return out[:8]
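The exploit/explore/freshness mix in `get_hybrid_refs` can be exercised without a database. A small sketch with stand-in objects (`pick_refs` and the simplified `score`/`created` fields are hypothetical stand-ins for the Script columns):

```python
import random
from types import SimpleNamespace

def pick_refs(all_refs, k=6, top_n=3, explore_n=2, newest_n=1, rng=random):
    # exploit: best by score (missing scores count as 0)
    best = sorted(all_refs, key=lambda s: s.score or 0.0, reverse=True)[:top_n]
    # freshness: most recently created
    newest = sorted(all_refs, key=lambda s: s.created, reverse=True)[:newest_n]
    # explore: random sample from whatever is left over
    remainder = [r for r in all_refs if r not in best and r not in newest]
    explore = rng.sample(remainder, min(explore_n, len(remainder))) if remainder else []
    # merge buckets in priority order, deduping by id
    chosen, seen = [], set()
    for bucket in (best, explore, newest):
        for s in bucket:
            if s.id not in seen:
                chosen.append(s)
                seen.add(s.id)
    return chosen[:k]

refs = [SimpleNamespace(id=i, score=sc, created=i)
        for i, sc in enumerate([4.5, None, 3.0, 4.9, 2.5])]
chosen = pick_refs(refs, top_n=2, explore_n=1, newest_n=1)
# ids 3 and 0 are the two best-scored; id 4 is the newest; one of 1/2 is random
print([s.id for s in chosen])
```

The bucket priority (best, then explore, then newest) means a script that is both top-scored and newest is only counted once, exactly as in the function above.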
deepseek_client.py ADDED
@@ -0,0 +1,59 @@
+ import os, requests, json
+ import streamlit as st
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # Get API key from Streamlit secrets or environment
+ def get_api_key():
+     if hasattr(st, 'secrets') and "DEEPSEEK_API_KEY" in st.secrets:
+         return st.secrets["DEEPSEEK_API_KEY"]
+     return os.getenv("DEEPSEEK_API_KEY")
+
+ DEEPSEEK_API_KEY = get_api_key()
+ BASE = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
+
+ def chat(messages, model="deepseek-chat", temperature=0.9):
+     headers = {"Authorization": f"Bearer {DEEPSEEK_API_KEY}", "Content-Type": "application/json"}
+     payload = {"model": model, "messages": messages, "temperature": temperature}
+     r = requests.post(f"{BASE}/chat/completions", headers=headers, data=json.dumps(payload), timeout=60)
+     r.raise_for_status()
+     return r.json()["choices"][0]["message"]["content"]
+
+ def generate_scripts(persona, boundaries, content_type, tone, refs, n=6):
+     system = (
+         "You write Instagram-compliant, suggestive-but-not-explicit Reels briefs. "
+         "Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms. "
+         "Return ONLY JSON: an array of length N, each with {title,hook,beats,voiceover,caption,hashtags,cta}."
+     )
+     user = f"""
+ Persona: {persona}
+ Boundaries: {boundaries}
+ Content type: {content_type} | Tone: {tone} | Duration: 15–25s
+ Reference snippets (inspire, don't copy):
+ {chr(10).join(f"- {r}" for r in refs)}
+
+ N = {n}
+ JSON array ONLY.
+ """
+     out = chat([{"role": "system", "content": system}, {"role": "user", "content": user}])
+     # Be lenient if the model wraps the JSON in extra text
+     start = out.find("[")
+     end = out.rfind("]")
+     return json.loads(out[start:end+1])
+
+ def revise_for(prompt_label, draft: dict, guidance: str):
+     system = f"You revise scripts to {prompt_label}. Keep intent; return ONLY JSON with the same schema."
+     user = json.dumps({"draft": draft, "guidance": guidance})
+     out = chat([{"role": "system", "content": system}, {"role": "user", "content": user}], temperature=0.6)
+     start = out.find("{")
+     end = out.rfind("}")
+     return json.loads(out[start:end+1])
+
+ def selective_rewrite(draft: dict, field: str, snippet: str, prompt: str):
+     system = "You rewrite only the targeted snippet inside the specified field. Keep style. Return ONLY JSON."
+     user = json.dumps({"field": field, "snippet": snippet, "prompt": prompt, "draft": draft})
+     out = chat([{"role": "system", "content": system}, {"role": "user", "content": user}], temperature=0.7)
+     start = out.find("{")
+     end = out.rfind("}")
+     return json.loads(out[start:end+1])
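All three functions rely on the same lenient find/rfind slice to dig the JSON out of the model reply; when no bracket is present, `find` returns -1 and the slice fails in a confusing way. A hardened version of that extraction, as a suggested standalone helper rather than existing code:

```python
import json

def extract_json_array(text: str):
    """Pull the first [...] span out of a model reply and parse it.

    Mirrors the lenient find/rfind pattern used above, but raises a
    clear error instead of slicing with -1 when no array is present.
    """
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end <= start:
        raise ValueError("no JSON array found in model output")
    return json.loads(text[start:end + 1])

reply = 'Sure! Here are the briefs:\n[{"title": "Sunrise stretch"}]\nDone.'
print(extract_json_array(reply))  # → [{'title': 'Sunrise stretch'}]
```

The same shape works for the `{...}` object case in `revise_for` and `selective_rewrite` by swapping the bracket characters.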
models.py ADDED
@@ -0,0 +1,103 @@
+ from datetime import datetime
+ from typing import List, Optional
+ from sqlmodel import SQLModel, Field, Column
+ from sqlalchemy import JSON
+
+ class Script(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     creator: str
+     content_type: str
+     tone: str
+     title: str
+     hook: str
+     beats: List[str] = Field(sa_column=Column(JSON))
+     voiceover: str
+     caption: str
+     hashtags: List[str] = Field(sa_column=Column(JSON))
+     cta: str
+     compliance: str = "pass"    # pass | warn | fail
+     source: str = "ai"          # ai | manual | import
+     is_reference: bool = False  # mark imported examples as references
+
+     # --- NEW: cached aggregates from ratings (all optional) ---
+     score_overall: Optional[float] = None      # 1..5 (avg)
+     score_hook: Optional[float] = None         # 1..5 (avg)
+     score_originality: Optional[float] = None  # 1..5 (avg)
+     score_style_fit: Optional[float] = None    # 1..5 (avg)
+     score_safety: Optional[float] = None       # 1..5 (avg)
+     ratings_count: int = 0
+
+     created_at: datetime = Field(default_factory=datetime.utcnow)
+     updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+ class Revision(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     script_id: int = Field(index=True)
+     label: str
+     field: str
+     before: str
+     after: str
+     created_at: datetime = Field(default_factory=datetime.utcnow)
+
+ # NEW: store every rating event so you keep history
+ class Rating(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     script_id: int = Field(index=True)
+     rater: str = "human"  # optional: store user/email
+     overall: float        # 1..5
+     hook: Optional[float] = None
+     originality: Optional[float] = None
+     style_fit: Optional[float] = None
+     safety: Optional[float] = None
+     notes: Optional[str] = None
+     created_at: datetime = Field(default_factory=datetime.utcnow)
+
+ # RAG enhancement models
+ class Embedding(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     script_id: int = Field(index=True)
+     part: str = Field(index=True)  # 'full', 'hook', 'beats', 'caption'
+     vector: List[float] = Field(sa_column=Column(JSON))
+     meta: dict = Field(sa_column=Column(JSON))
+     created_at: datetime = Field(default_factory=datetime.utcnow)
+
+ class AutoScore(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     script_id: int = Field(index=True)
+     overall: float
+     hook: float
+     originality: float
+     style_fit: float
+     safety: float
+     confidence: float = 0.8  # LLM judge confidence
+     notes: Optional[str] = None
+     created_at: datetime = Field(default_factory=datetime.utcnow)
+
+ class PolicyWeights(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     persona: str = Field(index=True)
+     content_type: str = Field(index=True)
+     # Retrieval weights
+     semantic_weight: float = 0.45
+     bm25_weight: float = 0.25
+     quality_weight: float = 0.20
+     freshness_weight: float = 0.10
+     # Generation params
+     temp_low: float = 0.4
+     temp_mid: float = 0.7
+     temp_high: float = 0.95
+     # Performance tracking
+     success_rate: float = 0.0
+     total_generations: int = 0
+     updated_at: datetime = Field(default_factory=datetime.utcnow)
+
+ class StyleCard(SQLModel, table=True, extend_existing=True):
+     id: Optional[int] = Field(default=None, primary_key=True)
+     persona: str = Field(index=True)
+     content_type: str = Field(index=True)
+     exemplar_hooks: List[str] = Field(sa_column=Column(JSON))
+     exemplar_beats: List[str] = Field(sa_column=Column(JSON))
+     exemplar_captions: List[str] = Field(sa_column=Column(JSON))
+     negative_patterns: List[str] = Field(sa_column=Column(JSON))
+     constraints: dict = Field(sa_column=Column(JSON))
+     updated_at: datetime = Field(default_factory=datetime.utcnow)
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
rag_integration.py ADDED
@@ -0,0 +1,350 @@
+ """
+ Integration layer between the existing system and the new RAG capabilities.
+ Shows how to plug the enhanced system into the current workflow.
+ """
+
+ from typing import List, Dict, Any, Optional
+ import json
+ from sqlmodel import Session
+ from datetime import datetime
+
+ from models import Script, Embedding, AutoScore, PolicyWeights
+ from db import get_session, init_db
+ from deepseek_client import chat, get_api_key
+ from rag_retrieval import RAGRetriever
+ from auto_scorer import AutoScorer, ScriptReranker
+ from bandit_learner import PolicyLearner
+
+ class EnhancedScriptGenerator:
+     """
+     Enhanced version of script generation with RAG + policy learning.
+     Drop-in replacement for the existing generate_scripts function.
+     """
+
+     def __init__(self):
+         self.retriever = RAGRetriever()
+         self.scorer = AutoScorer()
+         self.reranker = ScriptReranker()
+         self.policy_learner = PolicyLearner()
+
+         # Verify we have an API key
+         if not get_api_key():
+             raise ValueError("DeepSeek API key not found!")
+
+     def generate_scripts_enhanced(self,
+                                   persona: str,
+                                   boundaries: str,
+                                   content_type: str,
+                                   tone: str,
+                                   manual_refs: Optional[List[str]] = None,
+                                   n: int = 6) -> List[Dict]:
+         """
+         Enhanced script generation with:
+         1. RAG-based reference selection
+         2. Policy-optimized parameters
+         3. Auto-scoring and reranking
+         4. Online learning feedback
+         """
+
+         print(f"🤖 Enhanced generation: {persona} × {content_type} × {n} scripts")
+
+         # Step 1: Get the optimized policy for this persona/content_type
+         policy_arm = self.policy_learner.get_optimized_policy(persona, content_type)
+
+         # Step 2: Build a dynamic few-shot pack using RAG
+         query_context = f"{persona} {content_type} {tone}"
+         few_shot_pack = self.retriever.build_dynamic_few_shot_pack(
+             persona=persona,
+             content_type=content_type,
+             query_context=query_context
+         )
+
+         # Step 3: Combine RAG refs with manual refs
+         rag_refs = (
+             few_shot_pack.get('best_hooks', []) +
+             few_shot_pack.get('best_beats', []) +
+             few_shot_pack.get('best_captions', [])
+         )
+         all_refs = (manual_refs or []) + rag_refs
+
+         print(f"📚 Using {len(rag_refs)} RAG refs + {len(manual_refs or [])} manual refs")
+
+         # Step 4: Enhanced generation with policy-optimized parameters
+         drafts = self._generate_with_policy(
+             persona=persona,
+             boundaries=boundaries,
+             content_type=content_type,
+             tone=tone,
+             refs=all_refs,
+             policy_arm=policy_arm,
+             n=n,
+             few_shot_pack=few_shot_pack
+         )
+
+         # Step 5: Anti-copying detection and cleanup
+         print("🛡️ Checking for similarity to reference content...")
+
+         # Extract reference texts for copying detection
+         reference_texts = rag_refs
+         cleaned_drafts = []
+
+         for draft in drafts:
+             # Check for copying
+             detection_results = self.retriever.detect_copying(
+                 generated_content=draft,
+                 reference_texts=reference_texts,
+                 similarity_threshold=0.92
+             )
+
+             if detection_results['is_copying']:
+                 print(f"⚠️ Anti-copy triggered for draft: {draft.get('title', 'Untitled')[:30]}")
+                 print(f"   Max similarity: {detection_results['max_similarity']:.3f}")
+
+                 # Auto-rewrite overly similar content
+                 cleaned_draft = self.retriever.auto_rewrite_similar_content(
+                     generated_content=draft,
+                     detection_results=detection_results
+                 )
+                 cleaned_drafts.append(cleaned_draft)
+             else:
+                 cleaned_drafts.append(draft)
+
+         # Step 6: Auto-score all generated drafts
+         script_ids = self._save_drafts_to_db(cleaned_drafts, persona, content_type, tone)
+         auto_scores = [self.scorer.score_and_store(sid) for sid in script_ids]
+
+         print(f"📊 Auto-scored {len(auto_scores)} drafts")
+
+         # Step 7: Rerank by composite score
+         ranked_script_ids = self.reranker.rerank_scripts(script_ids)
+
+         # Step 8: Policy-learning feedback
+         self.policy_learner.learn_from_generation_batch(
+             persona=persona,
+             content_type=content_type,
+             generated_script_ids=script_ids,
+             selected_arm=policy_arm
+         )
+
+         # Return drafts in ranked order with scores
+         return self._format_enhanced_results(ranked_script_ids, cleaned_drafts)
+
+     def _generate_with_policy(self,
+                               persona: str,
+                               boundaries: str,
+                               content_type: str,
+                               tone: str,
+                               refs: List[str],
+                               policy_arm: Any,  # BanditArm
+                               n: int,
+                               few_shot_pack: Dict) -> List[Dict]:
+         """Generate scripts using policy-optimized parameters"""
+
+         # Enhanced system prompt with few-shot pack context
+         system = f"""You write Instagram-compliant, suggestive-but-not-explicit Reels briefs.
+
+ STYLE CONTEXT: {few_shot_pack.get('style_card', '')}
+
+ BEST PATTERNS TO EMULATE:
+ Hooks: {json.dumps(few_shot_pack.get('best_hooks', []))}
+ Beats: {json.dumps(few_shot_pack.get('best_beats', []))}
+ Captions: {json.dumps(few_shot_pack.get('best_captions', []))}
+
+ AVOID THESE PATTERNS: {json.dumps(few_shot_pack.get('negative_patterns', []))}
+
+ Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms.
+ Return ONLY JSON: an array of length {n}, each with {{title,hook,beats,voiceover,caption,hashtags,cta}}.
+ """
+
+         # Limit to the top 8 refs to keep the prompt compact
+         user = f"""
+ Persona: {persona}
+ Boundaries: {boundaries}
+ Content type: {content_type} | Tone: {tone}
+ Constraints: {json.dumps(few_shot_pack.get('constraints', {}))}
+
+ Reference snippets (inspire, don't copy):
+ {chr(10).join(f"- {r}" for r in refs[:8])}
+
+ Generate {n} unique variations. JSON array ONLY.
+ """
+
+         # Generate with multiple temperatures (policy-optimized)
+         variants = []
+         temps = [policy_arm.temp_low, policy_arm.temp_mid, policy_arm.temp_high]
+         scripts_per_temp = max(1, n // len(temps))
+
+         for i, temp in enumerate(temps):
+             batch_size = scripts_per_temp
+             if i == len(temps) - 1:  # The last batch gets the remainder
+                 batch_size = n - len(variants)
+
+             if batch_size <= 0:
+                 break
+
+             try:
+                 out = chat([
+                     {"role": "system", "content": system},
+                     {"role": "user", "content": user.replace(f"Generate {n}", f"Generate {batch_size}")}
+                 ], temperature=temp)
+
+                 # Extract the JSON array
+                 start = out.find("[")
+                 end = out.rfind("]")
+                 if start >= 0 and end > start:
+                     batch_variants = json.loads(out[start:end+1])
+                     variants.extend(batch_variants[:batch_size])
+                     print(f"✨ Generated {len(batch_variants)} scripts at temp={temp}")
+
+             except Exception as e:
+                 print(f"❌ Generation failed at temp={temp}: {e}")
+
+         return variants[:n]  # Ensure we don't exceed the requested count
+
+     def _save_drafts_to_db(self,
+                            drafts: List[Dict],
+                            persona: str,
+                            content_type: str,
+                            tone: str) -> List[int]:
+         """Save generated drafts to the database and return their script IDs"""
+
+         script_ids = []
+
+         with get_session() as ses:
+             for draft in drafts:
+                 try:
+                     # Calculate basic compliance
+                     from compliance import score_script, blob_from
+                     content_blob = blob_from(draft)
+                     compliance_level, _ = score_script(content_blob)
+
+                     script = Script(
+                         creator=persona,
+                         content_type=content_type,
+                         tone=tone,
+                         title=draft.get("title", "Generated Script"),
+                         hook=draft.get("hook", ""),
+                         beats=draft.get("beats", []),
+                         voiceover=draft.get("voiceover", ""),
+                         caption=draft.get("caption", ""),
+                         hashtags=draft.get("hashtags", []),
+                         cta=draft.get("cta", ""),
+                         compliance=compliance_level,
+                         source="ai"
+                     )
+
+                     ses.add(script)
+                     ses.commit()
+                     ses.refresh(script)
+
+                     script_ids.append(script.id)
+
+                     # Generate embeddings for the new script
+                     embeddings = self.retriever.generate_embeddings(script)
+                     for embedding in embeddings:
+                         ses.add(embedding)
+
+                 except Exception as e:
+                     print(f"❌ Failed to save draft: {e}")
+                     continue
+
+             ses.commit()
+
+         return script_ids
+
+     def _format_enhanced_results(self,
+                                  ranked_script_ids: List[tuple],
+                                  original_drafts: List[Dict]) -> List[Dict]:
+         """Format results with ranking and score information"""
+
+         # Create a lookup for the original drafts by content
+         draft_lookup = {}
+         for i, draft in enumerate(original_drafts):
+             key = draft.get("title", "") + draft.get("hook", "")
+             draft_lookup[key] = draft
+
+         results = []
+
+         with get_session() as ses:
+             for script_id, composite_score in ranked_script_ids:
+                 script = ses.get(Script, script_id)
+                 if script:
+                     # Convert back to the expected format
+                     result = {
+                         "title": script.title,
+                         "hook": script.hook,
+                         "beats": script.beats,
+                         "voiceover": script.voiceover,
+                         "caption": script.caption,
+                         "hashtags": script.hashtags,
+                         "cta": script.cta,
+                         # Enhanced metadata
+                         "_enhanced_score": round(composite_score, 3),
+                         "_script_id": script_id,
+                         "_compliance": script.compliance
+                     }
+                     results.append(result)
+
+         return results
+
+ # Backward-compatibility wrapper
+ def generate_scripts_rag(persona: str,
+                          boundaries: str,
+                          content_type: str,
+                          tone: str,
+                          refs: List[str],
+                          n: int = 6) -> List[Dict]:
+     """
+     Drop-in replacement for the existing generate_scripts function.
+     Uses the enhanced RAG system while maintaining API compatibility.
+     """
+     generator = EnhancedScriptGenerator()
+     return generator.generate_scripts_enhanced(
+         persona=persona,
+         boundaries=boundaries,
+         content_type=content_type,
+         tone=tone,
+         manual_refs=refs,
+         n=n
+     )
+
+ def setup_rag_system():
+     """One-time setup to initialize the RAG system"""
+     print("🔧 Setting up RAG system...")
+
+     # Initialize the database with the new tables
+     init_db()
+     print("✅ Database initialized")
+
+     # Generate embeddings for existing scripts
+     from rag_retrieval import index_all_scripts
+     index_all_scripts()
+     print("✅ Existing scripts indexed")
+
+     # Auto-score recent scripts
+     scorer = AutoScorer()
+     recent_scores = scorer.batch_score_recent(hours=24 * 7)  # Last week
+     print(f"✅ Auto-scored {len(recent_scores)} recent scripts")
+
+     print("🎉 RAG system setup complete!")
+
+ if __name__ == "__main__":
+     # Demo the enhanced system
+     setup_rag_system()
+
+     # Test generation
+     generator = EnhancedScriptGenerator()
+     results = generator.generate_scripts_enhanced(
+         persona="Anya",
+         boundaries="Instagram-safe; suggestive but not explicit",
+         content_type="thirst-trap",
+         tone="playful, flirty",
+         manual_refs=["Just a quick workout session", "Getting ready for the day"],
+         n=3
+     )
+
+     print(f"\n🎬 Generated {len(results)} enhanced scripts:")
+     for i, script in enumerate(results, 1):
+         score = script.get('_enhanced_score', 0)
+         compliance = script.get('_compliance', 'unknown')
+         print(f"{i}. {script['title']} (score: {score}, compliance: {compliance})")
+         print(f"   Hook: {script['hook'][:60]}...")
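`_generate_with_policy` spreads the `n` requested scripts across the three policy temperatures, with the last tier absorbing the remainder (assuming every earlier call succeeds in producing its full batch). The split, isolated as a hypothetical helper for illustration:

```python
def batch_sizes(n, num_temps=3):
    """Split n scripts across temperature tiers the way _generate_with_policy
    does: floor(n / num_temps) per tier, with the final tier taking whatever
    is left; tiers that would get zero scripts are dropped."""
    per = max(1, n // num_temps)
    sizes, used = [], 0
    for i in range(num_temps):
        size = per if i < num_temps - 1 else n - used
        if size <= 0:
            break
        sizes.append(size)
        used += size
    return sizes

print(batch_sizes(6), batch_sizes(7), batch_sizes(2))
# → [2, 2, 2] [2, 2, 3] [1, 1]
```

Note that in the real method the remainder is computed from `len(variants)`, so a failed low-temperature batch is silently made up for by the final high-temperature batch.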
rag_retrieval.py ADDED
@@ -0,0 +1,444 @@
+ """
+ Enhanced RAG retrieval system for AI Script Studio.
+ Extends the existing hybrid reference system with semantic search and policy learning.
+ """
+
+ import numpy as np
+ import math
+ from typing import List, Dict, Tuple, Optional
+ from sentence_transformers import SentenceTransformer
+ from sqlmodel import Session, select
+ from sklearn.feature_extraction.text import TfidfVectorizer
+ from sklearn.metrics.pairwise import cosine_similarity
+ import json
+ from datetime import datetime, timedelta
+
+ from models import Script, Embedding, AutoScore, PolicyWeights, StyleCard
+ from db import get_session
+
+ class RAGRetriever:
+     def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
+         """Initialize with a lightweight but effective embedding model"""
+         self.encoder = SentenceTransformer(model_name)
+         self.tfidf = TfidfVectorizer(max_features=1000, stop_words='english')
+
+     def generate_embeddings(self, script: Script) -> List[Embedding]:
+         """Generate embeddings for the different parts of a script"""
+         parts = {
+             'full': self._get_full_text(script),
+             'hook': script.hook or '',
+             'beats': ' '.join(script.beats or []),
+             'caption': script.caption or ''
+         }
+
+         embeddings = []
+         for part, text in parts.items():
+             if text.strip():  # Only embed non-empty parts
+                 vector = self.encoder.encode(text).tolist()
+                 meta = {
+                     'creator': script.creator,
+                     'content_type': script.content_type,
+                     'tone': script.tone,
+                     'quality_score': script.score_overall or 0.0,
+                     'compliance': script.compliance
+                 }
+                 embeddings.append(Embedding(
+                     script_id=script.id,
+                     part=part,
+                     vector=vector,
+                     meta=meta
+                 ))
+         return embeddings
+
+     def _get_full_text(self, script: Script) -> str:
+         """Combine all script parts into one full text"""
+         parts = [
+             script.title,
+             script.hook or '',
+             ' '.join(script.beats or []),
+             script.voiceover or '',
+             script.caption or '',
+             script.cta or ''
+         ]
+         return ' '.join(p for p in parts if p.strip())
+
+     def hybrid_retrieve(self,
+                         query_text: str,
+                         persona: str,
+                         content_type: str,
+                         k: int = 6,
+                         global_quality_mean: float = 4.2,
+                         shrinkage_alpha: float = 10.0,
+                         freshness_tau_days: float = 28.0) -> List[Dict]:
+         """
+         Production-grade hybrid retrieval with proper score normalization:
+         - Semantic similarity (cosine normalized to [0,1])
+         - BM25/TF-IDF similarity (min-max normalized per query)
+         - Quality scores (Bayesian shrinkage)
+         - Freshness boost (exponential decay)
+         - Policy-learned weights
+         """
+
+         # Get policy weights for this persona/content_type
+         weights = self._get_policy_weights(persona, content_type)
+
+         with get_session() as ses:
+             # Get all relevant scripts
+             scripts = list(ses.exec(
+                 select(Script).where(
+                     Script.creator == persona,
+                     Script.content_type == content_type,
+                     Script.is_reference == True,
+                     Script.compliance != "fail"
+                 )
+             ))
+
+             if not scripts:
+                 return []
+
+             # Get embeddings for semantic similarity
+             embeddings = list(ses.exec(
+                 select(Embedding).join(Script, Embedding.script_id == Script.id).where(
+                     Embedding.part == 'full',
+                     Script.creator == persona,
+                     Script.content_type == content_type,
+                     Script.is_reference == True,
+                     Script.compliance != "fail"
+                 )
+             ))
+
+             # Pre-calculate all raw scores for normalization
+             raw_scores = []
+             query_embedding = self.encoder.encode(query_text)
+             now = datetime.utcnow()
+
+             for script in scripts:
+                 # Find the matching embedding
+                 script_embedding = next(
+                     (e for e in embeddings if e.script_id == script.id),
+                     None
+                 )
+
+                 # 1. Raw semantic similarity (cosine returns [-1,1])
+                 if script_embedding:
+                     raw_cosine = cosine_similarity(
+                         [query_embedding],
+                         [script_embedding.vector]
+                     )[0][0]
+                 else:
+                     raw_cosine = -1.0  # Worst case for missing embeddings
+
+                 # 2. Raw BM25/TF-IDF similarity
+                 script_text = self._get_full_text(script)
+                 raw_bm25 = self._calculate_tfidf_similarity(query_text, script_text)
+
+                 raw_scores.append({
+                     'script': script,
+                     'raw_cosine': raw_cosine,
+                     'raw_bm25': raw_bm25
+                 })
+
+             # Normalize BM25 scores (min-max normalization across this query's candidates)
+             bm25_scores = [s['raw_bm25'] for s in raw_scores]
+             min_bm25 = min(bm25_scores)
+             max_bm25 = max(bm25_scores)
+             bm25_range = max_bm25 - min_bm25 + 1e-9  # Avoid division by zero
+
+             # Calculate final normalized scores
+             results = []
+
+             for raw_score in raw_scores:
+                 script = raw_score['script']
+                 scores = {}
+
+                 # 1. Semantic similarity: normalize cosine [-1,1] → [0,1]
+                 scores['semantic'] = (raw_score['raw_cosine'] + 1.0) / 2.0
+
+                 # 2. BM25: min-max normalize within this query's candidate set
+                 scores['bm25'] = (raw_score['raw_bm25'] - min_bm25) / bm25_range
+
+                 # 3. Quality: Bayesian shrinkage toward the global mean
+                 n_ratings = script.ratings_count or 0
+                 local_quality = script.score_overall or global_quality_mean
+
+                 # Shrinkage: blend the local mean with the global mean based on sample size
+                 shrunk_quality = (
+                     (n_ratings / (n_ratings + shrinkage_alpha)) * local_quality +
+                     (shrinkage_alpha / (n_ratings + shrinkage_alpha)) * global_quality_mean
+                 )
+
+                 # Normalize to [0,1] (assuming a 1-5 rating scale)
+                 scores['quality'] = max(0.0, min(1.0, (shrunk_quality - 1) / 4))
+
+                 # 4. Freshness: exponential decay (smoother than linear)
+                 days_old = max(0, (now - script.created_at).days)
+                 scores['freshness'] = math.exp(-days_old / freshness_tau_days)
+
+                 # Combined score using policy weights
+                 combined_score = (
+                     weights.semantic_weight * scores['semantic'] +
+                     weights.bm25_weight * scores['bm25'] +
+                     weights.quality_weight * scores['quality'] +
+                     weights.freshness_weight * scores['freshness']
+                 )
+
+                 results.append({
+                     'script': script,
+                     'score': combined_score,
+                     'component_scores': scores,
+                     # Debug info
+                     '_debug': {
+                         'n_ratings': n_ratings,
+                         'raw_quality': local_quality,
+                         'shrunk_quality': shrunk_quality,
+                         'days_old': days_old
+                     }
+                 })
+
+             # Sort by combined score and return the top k
+             results.sort(key=lambda x: x['score'], reverse=True)
+             return results[:k]
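The normalization and Bayesian shrinkage above can be sanity-checked in isolation. A sketch reproducing the arithmetic with the default weights from PolicyWeights (`combined_score` is an illustrative standalone function, not a method of the class):

```python
import math

def combined_score(raw_cosine, raw_bm25, bm25_lo, bm25_hi,
                   local_quality, n_ratings, days_old,
                   global_mean=4.2, alpha=10.0, tau=28.0,
                   w=(0.45, 0.25, 0.20, 0.10)):
    """Score fusion mirroring hybrid_retrieve: cosine mapped [-1,1] → [0,1],
    BM25 min-max normalized per query, quality shrunk toward the global mean,
    freshness decayed exponentially, all fused with the (assumed) default weights."""
    semantic = (raw_cosine + 1.0) / 2.0
    bm25 = (raw_bm25 - bm25_lo) / (bm25_hi - bm25_lo + 1e-9)
    shrunk = ((n_ratings / (n_ratings + alpha)) * local_quality +
              (alpha / (n_ratings + alpha)) * global_mean)
    quality = max(0.0, min(1.0, (shrunk - 1) / 4))
    freshness = math.exp(-days_old / tau)
    return w[0] * semantic + w[1] * bm25 + w[2] * quality + w[3] * freshness

# A 5.0-rated script with only 2 ratings is pulled most of the way back toward
# the 4.2 global mean (only 2/12 of the weight is on the local mean), while the
# same script with 200 ratings keeps nearly its full local quality.
few = combined_score(0.6, 0.3, 0.0, 0.5, local_quality=5.0, n_ratings=2, days_old=7)
many = combined_score(0.6, 0.3, 0.0, 0.5, local_quality=5.0, n_ratings=200, days_old=7)
print(round(few, 3), round(many, 3))
```

This is why a single enthusiastic rating cannot catapult a script to the top of the reference pool: with `alpha=10`, it takes on the order of ten ratings before the local mean dominates.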
201
+
202
+ def _calculate_tfidf_similarity(self, query: str, doc: str) -> float:
203
+ """Calculate TF-IDF similarity between query and document"""
204
+ try:
205
+ tfidf_matrix = self.tfidf.fit_transform([query, doc])
206
+ similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
207
+ return float(similarity)
208
+ except:
209
+ return 0.0
210
+
211
+ def _get_policy_weights(self, persona: str, content_type: str) -> PolicyWeights:
212
+ """Get learned policy weights or create defaults"""
213
+ with get_session() as ses:
214
+ weights = ses.exec(
215
+ select(PolicyWeights).where(
216
+ PolicyWeights.persona == persona,
217
+ PolicyWeights.content_type == content_type
218
+ )
219
+ ).first()
220
+
221
+ if not weights:
222
+ # Create default weights
223
+ weights = PolicyWeights(
224
+ persona=persona,
225
+ content_type=content_type
226
+ )
227
+ ses.add(weights)
228
+ ses.commit()
229
+ ses.refresh(weights)
230
+
231
+ return weights
232
+
233
+ def build_dynamic_few_shot_pack(self,
234
+ persona: str,
235
+ content_type: str,
236
+ query_context: str = "") -> Dict:
237
+ """Build dynamic few-shot examples pack optimized for this request"""
238
+
239
+ # Get best references via hybrid retrieval
240
+ references = self.hybrid_retrieve(
241
+ query_text=query_context or f"{persona} {content_type}",
242
+ persona=persona,
243
+ content_type=content_type,
244
+ k=6
245
+ )
246
+
247
+ if not references:
248
+ return {"style_card": "", "examples": [], "constraints": {}}
249
+
250
+ # Extract best examples by type
251
+ best_hooks = []
252
+ best_beats = []
253
+ best_captions = []
254
+
255
+ for ref in references[:4]: # Use top 4 references
256
+ script = ref['script']
257
+ if script.hook and len(best_hooks) < 2:
258
+ best_hooks.append(script.hook)
259
+ if script.beats and len(best_beats) < 1:
260
+ best_beats.extend(script.beats[:2]) # First 2 beats
261
+ if script.caption and len(best_captions) < 1:
262
+ best_captions.append(script.caption)
263
+
264
+ # Get or create style card
265
+ style_card = self._get_style_card(persona, content_type)
266
+
267
+ return {
268
+ "style_card": f"Persona: {persona} | Content: {content_type}",
269
+ "best_hooks": best_hooks[:2],
270
+ "best_beats": best_beats[:3],
271
+ "best_captions": best_captions[:1],
272
+ "constraints": {
273
+ "max_length": "15-25 seconds",
274
+ "compliance": "Instagram-safe",
275
+ "tone": references[0]['script'].tone if references else "playful"
276
+ },
277
+ "negative_patterns": style_card.negative_patterns if style_card else []
278
+ }
279
+
280
+    def _get_style_card(self, persona: str, content_type: str) -> Optional[StyleCard]:
+        """Get the existing style card or return None"""
+        with get_session() as ses:
+            return ses.exec(
+                select(StyleCard).where(
+                    StyleCard.persona == persona,
+                    StyleCard.content_type == content_type
+                )
+            ).first()
+
+ def detect_copying(self,
291
+ generated_content: Dict,
292
+ reference_texts: List[str],
293
+ similarity_threshold: float = 0.92) -> Dict:
294
+ """
295
+ Detect if generated content is too similar to reference material.
296
+ Returns detection results with flagged content and similarity scores.
297
+
298
+ Args:
299
+ generated_content: Dict with keys like 'hook', 'caption', 'beats', etc.
300
+ reference_texts: List of reference text snippets to compare against
301
+ similarity_threshold: Cosine similarity threshold (0.92 recommended)
302
+
303
+ Returns:
304
+ Dict with detection results and recommendations
305
+ """
306
+
307
+ detection_results = {
308
+ 'is_copying': False,
309
+ 'flagged_fields': [],
310
+ 'max_similarity': 0.0,
311
+ 'rewrite_recommendations': []
312
+ }
313
+
314
+ if not reference_texts:
315
+ return detection_results
316
+
317
+ # Encode all reference texts
318
+ reference_embeddings = self.encoder.encode(reference_texts)
319
+
320
+ # Fields to check for copying
321
+ fields_to_check = ['hook', 'caption', 'cta']
322
+
323
+ for field in fields_to_check:
324
+ if field in generated_content and generated_content[field]:
325
+ generated_text = str(generated_content[field])
326
+
327
+ # Skip very short texts (less than 10 characters)
328
+ if len(generated_text.strip()) < 10:
329
+ continue
330
+
331
+ # Encode generated text
332
+ generated_embedding = self.encoder.encode([generated_text])
333
+
334
+ # Calculate similarity to all reference texts
335
+ similarities = cosine_similarity(generated_embedding, reference_embeddings)[0]
336
+ max_sim = float(np.max(similarities))
337
+
338
+ # Update overall max similarity
339
+ detection_results['max_similarity'] = max(detection_results['max_similarity'], max_sim)
340
+
341
+ # Check if similarity exceeds threshold
342
+ if max_sim >= similarity_threshold:
343
+ detection_results['is_copying'] = True
344
+ detection_results['flagged_fields'].append({
345
+ 'field': field,
346
+ 'text': generated_text,
347
+ 'similarity': max_sim,
348
+ 'similar_reference': reference_texts[int(np.argmax(similarities))]
349
+ })
350
+
351
+ # Generate rewrite recommendation
352
+ if max_sim >= 0.95:
353
+ urgency = "CRITICAL"
354
+ action = "Completely rewrite this content"
355
+ elif max_sim >= 0.92:
356
+ urgency = "HIGH"
357
+ action = "Significantly rephrase this content"
358
+ else:
359
+ urgency = "MEDIUM"
360
+ action = "Minor rewording may be needed"
361
+
362
+ detection_results['rewrite_recommendations'].append({
363
+ 'field': field,
364
+ 'urgency': urgency,
365
+ 'action': action,
366
+ 'original': generated_text
367
+ })
368
+
369
+ return detection_results
370
+
371
+ def auto_rewrite_similar_content(self,
372
+ generated_content: Dict,
373
+ detection_results: Dict,
374
+ rewrite_instruction: str = "Rewrite to be more original while keeping the same intent") -> Dict:
375
+ """
376
+ Automatically rewrite content that's too similar to references.
377
+
378
+ Args:
379
+ generated_content: The original generated content
380
+ detection_results: Results from detect_copying()
381
+ rewrite_instruction: Instructions for how to rewrite
382
+
383
+ Returns:
384
+ Rewritten content dict
385
+ """
386
+
387
+ if not detection_results['is_copying']:
388
+ return generated_content
389
+
390
+ rewritten_content = generated_content.copy()
391
+
392
+ for flag in detection_results['flagged_fields']:
393
+ field = flag['field']
394
+ original_text = flag['text']
395
+
396
+ # Simple rewrite strategy: add instruction to modify the text
397
+ # In a production system, you'd call the LLM to rewrite
398
+ rewrite_prompt = f"""
399
+ Original: {original_text}
400
+
401
+ This text is too similar to existing reference material.
402
+ Please rewrite it to be more original while keeping the same intent and tone.
403
+ Make it clearly different from the reference but equally engaging.
404
+
405
+ Rewritten version:
406
+ """
407
+
408
+ # For now, add a flag that this needs rewriting
409
+ # In production, you'd call your LLM API here
410
+ rewritten_content[field] = f"[NEEDS_REWRITE] {original_text}"
411
+
412
+ # Log the issue
413
+ print(f"🚨 Anti-copy detection: {field} flagged (similarity: {flag['similarity']:.3f})")
414
+ print(f" Original: {original_text[:60]}...")
415
+ print(f" Similar to: {flag['similar_reference'][:60]}...")
416
+
417
+ return rewritten_content
418
+
419
+def index_all_scripts():
+    """Utility function to generate embeddings for all existing scripts"""
+    retriever = RAGRetriever()
+
+    with get_session() as ses:
+        scripts = list(ses.exec(select(Script)))
+
+        for script in scripts:
+            # Skip scripts that already have embeddings
+            existing = ses.exec(
+                select(Embedding).where(Embedding.script_id == script.id)
+            ).first()
+
+            if not existing:
+                embeddings = retriever.generate_embeddings(script)
+                for embedding in embeddings:
+                    ses.add(embedding)
+
+                print(f"Generated embeddings for script {script.id}")
+
+        ses.commit()
+        print(f"Indexing complete! Processed {len(scripts)} scripts.")
+
+if __name__ == "__main__":
+    # Run this to index your existing scripts
+    index_all_scripts()
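The copy-detection flow above reduces to one idea: embed the generated text and every reference, take the maximum cosine similarity, and flag anything at or above the threshold. A minimal standalone sketch of that thresholding, using plain word-count vectors in place of sentence-transformer embeddings so it runs without a model download (`flag_near_copies` and `cosine_sim` are illustrative names, not part of the module):

```python
import math
from collections import Counter

def cosine_sim(a: str, b: str) -> float:
    """Cosine similarity over word-count vectors (a stand-in for embeddings)."""
    va, vb = Counter(a.lower().split()), Counter(b.lower().split())
    dot = sum(va[w] * vb[w] for w in va)
    na = math.sqrt(sum(c * c for c in va.values()))
    nb = math.sqrt(sum(c * c for c in vb.values()))
    return dot / (na * nb) if na and nb else 0.0

def flag_near_copies(generated: str, references: list[str],
                     threshold: float = 0.92) -> dict:
    """Mirror the detect_copying() thresholding: flag the max similarity."""
    max_sim = max((cosine_sim(generated, r) for r in references), default=0.0)
    return {"max_similarity": max_sim, "is_copying": max_sim >= threshold}

result = flag_near_copies(
    "stop scrolling this hack changes everything",
    ["stop scrolling this hack changes everything", "three lighting tips"],
)
print(result["is_copying"])  # → True (identical text has similarity 1.0)
```

The real module uses sentence-transformer embeddings, which catch paraphrases that word-count vectors miss; the thresholding logic is the same either way.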
requirements.txt CHANGED
@@ -1,3 +1,16 @@
- altair
- pandas
- streamlit
+ streamlit>=1.37.1
+ sqlmodel>=0.0.16
+ pydantic>=1.10.15
+ python-dotenv>=1.0.1
+ requests>=2.32.3
+ sqlalchemy>=2.0.0
+
+ # RAG Enhancement Dependencies
+ sentence-transformers>=2.2.2
+ scikit-learn>=1.3.0
+ numpy>=1.24.0
+ faiss-cpu>=1.7.4
+
+ # Additional dependencies for deployment
+ torch>=2.0.0
+ transformers>=4.30.0
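The `_debug` payload in `hybrid_retrieve` exposes `raw_quality`, `n_ratings`, and `shrunk_quality`, which points to rating shrinkage toward a prior: with few ratings, the quality score is pulled toward a global baseline so a single 5-star rating cannot dominate retrieval. The actual formula lives outside this chunk; a common Bayesian-average form is sketched below as an assumption (`shrink_quality`, `prior`, and `prior_weight` are illustrative names):

```python
def shrink_quality(raw_quality: float, n_ratings: int,
                   prior: float = 0.5, prior_weight: float = 5.0) -> float:
    """Bayesian-average shrinkage: with few ratings, trust the prior more."""
    return (prior_weight * prior + n_ratings * raw_quality) / (prior_weight + n_ratings)

print(shrink_quality(1.0, 0))    # → 0.5 (no ratings: exactly the prior)
print(shrink_quality(1.0, 100))  # many ratings: close to the raw score
```

`prior_weight` acts as a pseudo-count: it is the number of real ratings needed before the raw score and the prior carry equal weight.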