Spaces:
Sleeping
Sleeping
Upload 13 files
Browse files

- .gitignore +34 -0
- README.md +62 -20
- app.py +844 -0
- auto_scorer.py +240 -0
- bandit_learner.py +330 -0
- compliance.py +26 -0
- db.py +248 -0
- deepseek_client.py +59 -0
- models.py +103 -0
- packages.txt +1 -0
- rag_integration.py +350 -0
- rag_retrieval.py +444 -0
- requirements.txt +16 -3
.gitignore
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment and secrets
|
| 2 |
+
.env
|
| 3 |
+
.streamlit/secrets.toml
|
| 4 |
+
secrets.toml
|
| 5 |
+
|
| 6 |
+
# Python cache
|
| 7 |
+
__pycache__/
|
| 8 |
+
*.pyc
|
| 9 |
+
*.pyo
|
| 10 |
+
*.pyd
|
| 11 |
+
.Python
|
| 12 |
+
*.so
|
| 13 |
+
|
| 14 |
+
# Database files
|
| 15 |
+
*.db
|
| 16 |
+
*.sqlite
|
| 17 |
+
*.sqlite3
|
| 18 |
+
|
| 19 |
+
# IDE files
|
| 20 |
+
.vscode/
|
| 21 |
+
.idea/
|
| 22 |
+
*.swp
|
| 23 |
+
*.swo
|
| 24 |
+
|
| 25 |
+
# OS files
|
| 26 |
+
.DS_Store
|
| 27 |
+
Thumbs.db
|
| 28 |
+
|
| 29 |
+
# Logs
|
| 30 |
+
*.log
|
| 31 |
+
|
| 32 |
+
# Temporary files
|
| 33 |
+
*.tmp
|
| 34 |
+
*.temp
|
README.md
CHANGED
|
@@ -1,20 +1,62 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk:
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
short_description:
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: AI Script Studio
|
| 3 |
+
emoji: 🎬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
+
sdk: streamlit
|
| 7 |
+
sdk_version: 1.37.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Generate Instagram-ready scripts with AI-powered RAG system
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
# 🎬 AI Script Studio
|
| 15 |
+
|
| 16 |
+
Generate Instagram-ready scripts with AI using advanced RAG (Retrieval-Augmented Generation) system.
|
| 17 |
+
|
| 18 |
+
## Features
|
| 19 |
+
|
| 20 |
+
- 🤖 **AI-Powered Generation**: Uses DeepSeek API for high-quality script generation
|
| 21 |
+
- 🧠 **RAG System**: Retrieval-Augmented Generation with semantic search
|
| 22 |
+
- 📊 **Multi-Armed Bandit Learning**: Self-improving generation policies
|
| 23 |
+
- 🎯 **Auto-Scoring**: LLM-based quality assessment
|
| 24 |
+
- 📈 **Rating System**: Human feedback integration with learning
|
| 25 |
+
- 🎨 **Multiple Personas**: Support for different creator styles
|
| 26 |
+
- 📝 **Content Types**: Various Instagram content formats
|
| 27 |
+
|
| 28 |
+
## How It Works
|
| 29 |
+
|
| 30 |
+
1. **Reference Retrieval**: Uses semantic search to find relevant examples
|
| 31 |
+
2. **Policy Learning**: Multi-armed bandit optimizes generation parameters
|
| 32 |
+
3. **AI Generation**: Creates scripts using retrieved references
|
| 33 |
+
4. **Auto-Scoring**: LLM judges quality across 5 dimensions
|
| 34 |
+
5. **Learning Loop**: System improves based on feedback
|
| 35 |
+
|
| 36 |
+
## Usage
|
| 37 |
+
|
| 38 |
+
1. Select your creator persona
|
| 39 |
+
2. Choose content type and tone
|
| 40 |
+
3. Add reference examples (optional)
|
| 41 |
+
4. Generate scripts with AI
|
| 42 |
+
5. Rate and provide feedback
|
| 43 |
+
6. System learns and improves
|
| 44 |
+
|
| 45 |
+
## Technical Stack
|
| 46 |
+
|
| 47 |
+
- **Frontend**: Streamlit
|
| 48 |
+
- **AI**: DeepSeek API
|
| 49 |
+
- **RAG**: Sentence Transformers + FAISS
|
| 50 |
+
- **Database**: SQLite with SQLModel
|
| 51 |
+
- **Learning**: Multi-armed bandit algorithms
|
| 52 |
+
- **Scoring**: LLM-based evaluation
|
| 53 |
+
|
| 54 |
+
## Setup
|
| 55 |
+
|
| 56 |
+
1. Add your DeepSeek API key to the secrets
|
| 57 |
+
2. The app will automatically initialize the database
|
| 58 |
+
3. Start generating scripts!
|
| 59 |
+
|
| 60 |
+
## API Key
|
| 61 |
+
|
| 62 |
+
Get your free API key at: https://platform.deepseek.com/api_keys
|
app.py
ADDED
|
@@ -0,0 +1,844 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, streamlit as st
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
from sqlmodel import select
|
| 4 |
+
from db import init_db, get_session, add_rating
|
| 5 |
+
from models import Script, Revision
|
| 6 |
+
from deepseek_client import generate_scripts, revise_for, selective_rewrite
|
| 7 |
+
# Lazy import for RAG system to improve startup time
|
| 8 |
+
# from rag_integration import generate_scripts_rag
|
| 9 |
+
from compliance import blob_from, score_script
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
# Configure page - MUST be first Streamlit command
|
| 13 |
+
st.set_page_config(
|
| 14 |
+
page_title="🎬 AI Script Studio",
|
| 15 |
+
layout="wide",
|
| 16 |
+
initial_sidebar_state="expanded"
|
| 17 |
+
)
|
| 18 |
+
|
def script_to_json_dict(script):
    """Return *script* as a plain dict that is safe for JSON serialization.

    The timestamp columns are stripped because ``datetime`` values are not
    handled by the default JSON encoder.
    """
    payload = script.model_dump()
    for timestamp_field in ("created_at", "updated_at"):
        payload.pop(timestamp_field, None)
    return payload
# Load environment - works both locally and on Hugging Face Spaces
load_dotenv()

# Initialize database with error handling for cloud deployment
try:
    init_db()
    st.sidebar.write("✅ Database initialized successfully")
except Exception as e:
    st.sidebar.write(f"⚠️ Database init warning: {str(e)}")
    # Continue anyway - some features may be limited


def _resolve_api_key():
    """Return the DeepSeek API key from Streamlit secrets or the environment.

    Returns None when no key is configured. Secrets access is wrapped in
    try/except because ``st.secrets`` raises when no secrets file exists
    in a local development checkout.
    """
    try:
        if hasattr(st, "secrets") and "DEEPSEEK_API_KEY" in st.secrets:
            return st.secrets["DEEPSEEK_API_KEY"]
    except Exception:
        # No secrets store available locally; fall through to the environment.
        pass
    return os.getenv("DEEPSEEK_API_KEY")


# SECURITY FIX: the previous debug block wrote the secret's length and its
# first 10 characters into the sidebar, which is visible to every visitor.
# Never echo any part of an API key (length, prefix, or suffix) to the UI.
api_key = _resolve_api_key()

if not api_key:
    st.error("🔑 **DeepSeek API Key Required**")
    st.markdown("""
**For Local Development:**
- Create a `.env` file and add: `DEEPSEEK_API_KEY=your_key_here`

**For Streamlit Cloud:**
- Go to your app settings → Secrets
- Add: `DEEPSEEK_API_KEY = "your_key_here"`

Get your free API key at: https://platform.deepseek.com/api_keys
""")
    st.stop()
else:
    st.sidebar.write("✅ API key loaded successfully")
# Custom CSS for better styling
st.markdown("""
<style>
.main-header {
    text-align: center;
    padding: 1rem;
    background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
    color: white;
    border-radius: 10px;
    margin-bottom: 2rem;
}
.step-container {
    border: 2px solid #e1e1e1;
    border-radius: 10px;
    padding: 1rem;
    margin-bottom: 1rem;
    background-color: #f8f9fa;
}
.draft-card {
    border: 1px solid #ddd;
    border-radius: 8px;
    padding: 0.8rem;
    margin-bottom: 0.5rem;
    background: white;
    transition: all 0.2s ease;
}
.draft-card:hover {
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
    border-color: #667eea;
}
.success-box {
    background-color: #d4edda;
    border: 1px solid #c3e6cb;
    border-radius: 5px;
    padding: 1rem;
    margin: 1rem 0;
}
</style>
""", unsafe_allow_html=True)

# Page header banner
st.markdown("""
<div class="main-header">
    <h1>🎬 AI Script Studio</h1>
    <p>Generate Instagram-ready scripts with AI • Powered by DeepSeek</p>
</div>
""", unsafe_allow_html=True)

# Seed session-state keys on first run so later reads never KeyError
_SESSION_DEFAULTS = {'generation_step': 'setup', 'generated_count': 0}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
| 124 |
+
|
# Sidebar - Generation Controls
with st.sidebar:
    st.header("🎯 Script Generation")

    # Step 1: Basic Settings
    with st.expander("📝 Step 1: Basic Settings", expanded=True):
        # Creator dropdown is dynamic: built-in defaults merged with every
        # creator already present in the database.
        with get_session() as session:
            stored_creators = [name for name in session.exec(select(Script.creator).distinct()) if name]

        builtin_creators = ["Creator A", "Emily", "Anya", "Ava Cherrry", "Ava Xreyess", "FitBryceAdams", "RealCarlyJane", "Sophie Rain", "Zoe AloneAtHome"]
        creator = st.selectbox(
            "Creator Name",
            sorted(set(builtin_creators + stored_creators)),
            help="Choose from existing creators or your imported scripts"
        )

        content_type = st.selectbox(
            "Content Type",
            ["thirst-trap", "skit", "reaction-prank", "talking-style", "lifestyle", "fake-podcast", "dance-trend", "voice-tease-asmr"],
            help="Choose the type of content you want to create"
        )

        # Scripts usually blend a few vibes, so tone is a multi-select;
        # the result is flattened to a comma-separated string.
        selected_tones = st.multiselect(
            "Tone/Vibe (select multiple)",
            ["naughty", "playful", "suggestive", "funny", "flirty", "bratty", "teasing", "intimate", "witty", "comedic", "confident", "wholesome", "asmr-voice"],
            default=["playful"],
            help="Choose one or more tones - scripts often blend 2-3 vibes"
        )
        tone = ", ".join(selected_tones) if selected_tones else "playful"

        n = st.slider(
            "Number of drafts",
            min_value=1,
            max_value=20,
            value=6,
            help="How many script variations to generate"
        )

    # Step 2: Persona & Style
    with st.expander("👤 Step 2: Persona & Style", expanded=True):
        persona_presets = {
            "Girl-next-door": "girl-next-door; playful; witty; approachable",
            "Bratty tease": "bratty; teasing; demanding; playful attitude",
            "Dominant/In control": "confident; in control; commanding; assertive",
            "Innocent but suggestive": "innocent; sweet; accidentally suggestive; naive charm",
            "Party girl": "outgoing; fun; social; party vibes; energetic",
            "Gym fitspo": "fitness focused; motivational; athletic; body confident",
            "ASMR/Voice fetish": "soft spoken; intimate; soothing; sensual voice",
            "Girlfriend experience": "loving; intimate; caring; relationship vibes",
            "Funny meme-style": "comedic; meme references; internet culture; quirky",
            "Candid/Lifestyle": "authentic; relatable; everyday life; natural"
        }

        picker_col, apply_col = st.columns([0.6, 0.4])
        with picker_col:
            persona_preset = st.selectbox(
                "Persona Preset",
                ["Custom"] + list(persona_presets.keys()),
                help="Choose a preset or use custom"
            )
        with apply_col:
            # Copy the preset text into session state so the text area below
            # picks it up on the rerun triggered by the button click.
            if persona_preset != "Custom" and st.button("📋 Use Preset", use_container_width=True):
                st.session_state.persona_text = persona_presets[persona_preset]

        persona = st.text_area(
            "Persona Description",
            value=st.session_state.get('persona_text', "girl-next-door; playful; witty"),
            help="Describe the character/personality for the scripts"
        )

        # Compliance/boundary presets keep output platform-appropriate.
        boundary_presets = {
            "Safe IG mode": "No explicit words; no sexual acts; suggestive only; no banned IG terms; keep it flirty but clean",
            "Spicy mode": "Innuendos allowed; suggestive language OK; no explicit acts; can be naughty but not graphic",
            "Brand-safe": "No swearing; no sex references; just flirty and fun; wholesome with hint of tease",
            "Mild NSFW": "Moaning sounds OK; wet references allowed; squirt innuendo OK; suggestive but not explicit",
            "Platform optimized": "Avoid flagged keywords; use creative euphemisms; suggestive storytelling style"
        }

        picker_col, apply_col = st.columns([0.6, 0.4])
        with picker_col:
            boundary_preset = st.selectbox(
                "Compliance Preset",
                ["Custom"] + list(boundary_presets.keys()),
                help="Choose platform-appropriate safety rules"
            )
        with apply_col:
            if boundary_preset != "Custom" and st.button("🛡️ Use Preset", use_container_width=True):
                st.session_state.boundaries_text = boundary_presets[boundary_preset]

        boundaries = st.text_area(
            "Content Boundaries",
            value=st.session_state.get('boundaries_text', "No explicit words; no solicitation; no age refs"),
            help="What should the AI avoid? Set your safety guidelines here"
        )

    # Step 3: Advanced Options
    with st.expander("⚡ Step 3: Advanced Options", expanded=False):
        left, right = st.columns(2)

        with left:
            hook_style = st.selectbox(
                "Hook Style",
                ["Auto", "Question", "Confession", "Contrarian", "PSA", "Tease", "Command", "Shock"],
                help="How should the hook start?"
            )
            length = st.selectbox(
                "Target Length",
                ["Auto", "Short (5-7s)", "Medium (8-12s)", "Longer (13-20s)"],
                help="How long should the script be?"
            )
            risk_level = st.slider(
                "Risk Level",
                min_value=1,
                max_value=5,
                value=3,
                help="1=Safe, 3=Suggestive, 5=Spicy"
            )

        with right:
            retention = st.selectbox(
                "Retention Hook",
                ["Auto", "Twist ending", "Shock reveal", "Naughty payoff", "Innocent→dirty flip", "Cliffhanger"],
                help="How to keep viewers watching?"
            )
            shot_type = st.selectbox(
                "Shot Type",
                ["Auto", "POV", "Selfie cam", "Tripod", "Over-the-shoulder", "Mirror shot"],
                help="Camera angle/perspective"
            )
            wardrobe = st.selectbox(
                "Wardrobe/Setting",
                ["Auto", "Gym fit", "Bikini", "Bed outfit", "Towel", "Dress", "Casual", "Kitchen", "Car"],
                help="Setting or outfit context"
            )

    # Step 4: Optional extra references
    with st.expander("📚 Step 4: Extra References (Optional)", expanded=False):
        st.info("💡 The AI automatically uses your database references, but you can add more here")
        refs_text = st.text_area(
            "Additional Reference Lines",
            value="",
            height=100,
            help="Add extra inspiration lines (one per line)"
        )

    # Generation button + reference preview
    st.markdown("---")

    # NOTE: get_hybrid_refs is also used by the generation section further
    # down the script, so this import must remain at module scope.
    from db import get_hybrid_refs

    # Map new content types to existing database types for compatibility
    content_type_mapping = {
        "thirst-trap": "talking_style / thirst_trap",
        "skit": "comedy",
        "reaction-prank": "prank",
        "talking-style": "talking_style",
        "lifestyle": "lifestyle",
        "fake-podcast": "fake-podcast",
        "dance-trend": "trend-adaptation",
        "voice-tease-asmr": "talking_style"
    }

    mapped_content_type = content_type_mapping.get(content_type, content_type)
    ref_count = len(get_hybrid_refs(creator, mapped_content_type, k=6))
    st.info(f"🤖 AI will use {ref_count} database references + your extras")

    generate_button = st.button(
        "🚀 Generate Scripts",
        type="primary",
        use_container_width=True
    )
| 321 |
+
|
| 322 |
+
# Generation Process
|
| 323 |
+
if generate_button:
|
| 324 |
+
with st.spinner("🧠 AI is creating your scripts..."):
|
| 325 |
+
try:
|
| 326 |
+
# Get manual refs from text area
|
| 327 |
+
manual_refs = [x.strip() for x in refs_text.split("\n") if x.strip()]
|
| 328 |
+
|
| 329 |
+
# Get automatic refs from selected creator scripts in database using content type mapping
|
| 330 |
+
auto_refs = get_hybrid_refs(creator, mapped_content_type, k=6)
|
| 331 |
+
|
| 332 |
+
# Combine both
|
| 333 |
+
all_refs = manual_refs + auto_refs
|
| 334 |
+
|
| 335 |
+
# Progress indicator
|
| 336 |
+
progress_bar = st.progress(0)
|
| 337 |
+
status_text = st.empty()
|
| 338 |
+
|
| 339 |
+
status_text.text("🔍 Analyzing references...")
|
| 340 |
+
progress_bar.progress(25)
|
| 341 |
+
time.sleep(0.5)
|
| 342 |
+
|
| 343 |
+
status_text.text("🧠 RAG system selecting optimal references...")
|
| 344 |
+
progress_bar.progress(40)
|
| 345 |
+
time.sleep(0.3)
|
| 346 |
+
|
| 347 |
+
status_text.text("✨ Generating enhanced content with AI learning...")
|
| 348 |
+
progress_bar.progress(60)
|
| 349 |
+
|
| 350 |
+
# Build enhanced prompt from advanced options
|
| 351 |
+
advanced_prompt = ""
|
| 352 |
+
if hook_style != "Auto":
|
| 353 |
+
advanced_prompt += f"Hook style: {hook_style}. "
|
| 354 |
+
if length != "Auto":
|
| 355 |
+
advanced_prompt += f"Target length: {length}. "
|
| 356 |
+
if retention != "Auto":
|
| 357 |
+
advanced_prompt += f"Retention strategy: {retention}. "
|
| 358 |
+
if shot_type != "Auto":
|
| 359 |
+
advanced_prompt += f"Shot type: {shot_type}. "
|
| 360 |
+
if wardrobe != "Auto":
|
| 361 |
+
advanced_prompt += f"Setting/wardrobe: {wardrobe}. "
|
| 362 |
+
if risk_level != 3:
|
| 363 |
+
risk_desc = {1: "very safe", 2: "mild", 3: "suggestive", 4: "spicy", 5: "very spicy"}
|
| 364 |
+
advanced_prompt += f"Risk level: {risk_desc[risk_level]}. "
|
| 365 |
+
|
| 366 |
+
# Enhance boundaries with advanced prompt
|
| 367 |
+
enhanced_boundaries = boundaries
|
| 368 |
+
if advanced_prompt:
|
| 369 |
+
enhanced_boundaries += f"\n\nADVANCED GUIDANCE: {advanced_prompt}"
|
| 370 |
+
|
| 371 |
+
# Generate scripts with enhanced RAG system (lazy import)
|
| 372 |
+
try:
|
| 373 |
+
from rag_integration import generate_scripts_rag
|
| 374 |
+
drafts = generate_scripts_rag(persona, enhanced_boundaries, content_type, tone, all_refs, n=n)
|
| 375 |
+
except ImportError as e:
|
| 376 |
+
st.warning(f"RAG system not available: {e}. Using fallback generation.")
|
| 377 |
+
# Fallback to simple generation
|
| 378 |
+
drafts = generate_scripts(enhanced_boundaries, n)
|
| 379 |
+
|
| 380 |
+
progress_bar.progress(75)
|
| 381 |
+
status_text.text("💾 Saving to database...")
|
| 382 |
+
|
| 383 |
+
# Save to database
|
| 384 |
+
with get_session() as ses:
|
| 385 |
+
for d in drafts:
|
| 386 |
+
lvl, _ = score_script(" ".join([d.get("title",""), d.get("hook",""), *d.get("beats",[]), d.get("voiceover",""), d.get("caption",""), d.get("cta","")]))
|
| 387 |
+
s = Script(
|
| 388 |
+
creator=creator, content_type=content_type, tone=tone,
|
| 389 |
+
title=d["title"], hook=d["hook"], beats=d["beats"],
|
| 390 |
+
voiceover=d["voiceover"], caption=d["caption"],
|
| 391 |
+
hashtags=d.get("hashtags",[]), cta=d.get("cta",""),
|
| 392 |
+
compliance=lvl, source="ai"
|
| 393 |
+
)
|
| 394 |
+
ses.add(s)
|
| 395 |
+
ses.commit()
|
| 396 |
+
|
| 397 |
+
progress_bar.progress(100)
|
| 398 |
+
status_text.text("")
|
| 399 |
+
progress_bar.empty()
|
| 400 |
+
|
| 401 |
+
st.session_state.generated_count += len(drafts)
|
| 402 |
+
st.success(f"🎉 Generated {len(drafts)} scripts successfully!")
|
| 403 |
+
|
| 404 |
+
# Show which refs were used and advanced options
|
| 405 |
+
col1, col2 = st.columns(2)
|
| 406 |
+
with col1:
|
| 407 |
+
if auto_refs:
|
| 408 |
+
st.markdown("**🤖 Hybrid refs used this run:**")
|
| 409 |
+
for line in auto_refs[:3]: # Show first 3
|
| 410 |
+
st.write(f"• {line}")
|
| 411 |
+
|
| 412 |
+
with col2:
|
| 413 |
+
if advanced_prompt:
|
| 414 |
+
st.markdown("**⚡ Advanced options applied:**")
|
| 415 |
+
st.write(f"• {advanced_prompt[:100]}...")
|
| 416 |
+
st.write(f"**📊 Settings:** {tone} • {content_type}")
|
| 417 |
+
|
| 418 |
+
st.balloons()
|
| 419 |
+
|
| 420 |
+
# Auto-refresh to show new drafts
|
| 421 |
+
time.sleep(1)
|
| 422 |
+
st.rerun()
|
| 423 |
+
|
| 424 |
+
except Exception as e:
|
| 425 |
+
st.error(f"❌ Generation failed: {str(e)}")
|
| 426 |
+
st.write("💡 Try adjusting your parameters or check your API key")
|
| 427 |
+
|
| 428 |
+
# Quick Actions
|
| 429 |
+
st.markdown("---")
|
| 430 |
+
st.subheader("⚡ Quick Actions")
|
| 431 |
+
|
| 432 |
+
col1, col2 = st.columns(2)
|
| 433 |
+
with col1:
|
| 434 |
+
if st.button("🔄 Refresh", use_container_width=True):
|
| 435 |
+
st.rerun()
|
| 436 |
+
with col2:
|
| 437 |
+
if st.button("🗑️ Clear All", use_container_width=True, help="Delete all your generated scripts"):
|
| 438 |
+
if st.session_state.get('confirm_clear'):
|
| 439 |
+
with get_session() as ses:
|
| 440 |
+
scripts_to_delete = list(ses.exec(select(Script).where(Script.creator == creator, Script.source == "ai")))
|
| 441 |
+
for script in scripts_to_delete:
|
| 442 |
+
ses.delete(script)
|
| 443 |
+
ses.commit()
|
| 444 |
+
st.success("🗑️ All drafts cleared!")
|
| 445 |
+
st.session_state.confirm_clear = False
|
| 446 |
+
st.rerun()
|
| 447 |
+
else:
|
| 448 |
+
st.session_state.confirm_clear = True
|
| 449 |
+
st.warning("Click again to confirm deletion!")
|
| 450 |
+
|
| 451 |
+
# Main Area
|
| 452 |
+
tab1, tab2, tab3 = st.tabs(["📝 Draft Review", "🎯 Filters", "📊 Analytics"])
|
| 453 |
+
|
| 454 |
+
with tab1:
|
| 455 |
+
# Load drafts
|
| 456 |
+
with get_session() as ses:
|
| 457 |
+
q = select(Script).where(Script.creator == creator, Script.source == "ai")
|
| 458 |
+
all_drafts = list(ses.exec(q))
|
| 459 |
+
|
| 460 |
+
if not all_drafts:
|
| 461 |
+
st.markdown("""
|
| 462 |
+
<div style="text-align: center; padding: 3rem;">
|
| 463 |
+
<h3>🎬 Ready to Create Amazing Scripts?</h3>
|
| 464 |
+
<p style="font-size: 1.2rem; color: #666;">
|
| 465 |
+
👈 Use the sidebar to generate your first batch of AI scripts<br>
|
| 466 |
+
🤖 The AI will learn from successful examples in the database<br>
|
| 467 |
+
✨ Then review, edit, and perfect your scripts here
|
| 468 |
+
</p>
|
| 469 |
+
</div>
|
| 470 |
+
""", unsafe_allow_html=True)
|
| 471 |
+
|
| 472 |
+
if st.session_state.generated_count > 0:
|
| 473 |
+
st.info(f"🎉 You've generated {st.session_state.generated_count} scripts so far! Use filters to find them.")
|
| 474 |
+
else:
|
| 475 |
+
# Draft management
|
| 476 |
+
col1, col2 = st.columns([0.4, 0.6], gap="large")
|
| 477 |
+
|
| 478 |
+
with col1:
|
| 479 |
+
st.subheader(f"📋 Your Drafts ({len(all_drafts)})")
|
| 480 |
+
|
| 481 |
+
# Quick filters
|
| 482 |
+
filter_col1, filter_col2 = st.columns(2)
|
| 483 |
+
with filter_col1:
|
| 484 |
+
compliance_filter = st.selectbox(
|
| 485 |
+
"Compliance",
|
| 486 |
+
["All", "PASS", "WARN", "FAIL"],
|
| 487 |
+
key="compliance_filter"
|
| 488 |
+
)
|
| 489 |
+
with filter_col2:
|
| 490 |
+
sort_by = st.selectbox(
|
| 491 |
+
"Sort by",
|
| 492 |
+
["Newest", "Oldest", "Title"],
|
| 493 |
+
key="sort_filter"
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
# Apply filters
|
| 497 |
+
filtered_drafts = all_drafts
|
| 498 |
+
if compliance_filter != "All":
|
| 499 |
+
filtered_drafts = [d for d in filtered_drafts if d.compliance.upper() == compliance_filter]
|
| 500 |
+
|
| 501 |
+
# Apply sorting
|
| 502 |
+
if sort_by == "Newest":
|
| 503 |
+
filtered_drafts.sort(key=lambda x: x.created_at, reverse=True)
|
| 504 |
+
elif sort_by == "Oldest":
|
| 505 |
+
filtered_drafts.sort(key=lambda x: x.created_at)
|
| 506 |
+
else: # Title
|
| 507 |
+
filtered_drafts.sort(key=lambda x: x.title)
|
| 508 |
+
|
| 509 |
+
# Draft cards
|
| 510 |
+
selected_id = st.session_state.get("selected_id")
|
| 511 |
+
|
| 512 |
+
for draft in filtered_drafts:
|
| 513 |
+
# Compliance color coding
|
| 514 |
+
compliance_color = {
|
| 515 |
+
"pass": "🟢",
|
| 516 |
+
"warn": "🟡",
|
| 517 |
+
"fail": "🔴"
|
| 518 |
+
}.get(draft.compliance, "⚪")
|
| 519 |
+
|
| 520 |
+
# Create card
|
| 521 |
+
with st.container(border=True):
|
| 522 |
+
if st.button(
|
| 523 |
+
f"{compliance_color} {draft.title}",
|
| 524 |
+
key=f"select-{draft.id}",
|
| 525 |
+
use_container_width=True
|
| 526 |
+
):
|
| 527 |
+
st.session_state["selected_id"] = draft.id
|
| 528 |
+
selected_id = draft.id
|
| 529 |
+
|
| 530 |
+
st.caption(f"🎭 {draft.tone} • 📅 {draft.created_at.strftime('%m/%d %H:%M')}")
|
| 531 |
+
|
| 532 |
+
# Preview hook
|
| 533 |
+
if draft.hook:
|
| 534 |
+
st.markdown(f"*{draft.hook[:80]}{'...' if len(draft.hook) > 80 else ''}*")
|
| 535 |
+
|
| 536 |
+
with col2:
|
| 537 |
+
st.subheader("✏️ Script Editor")
|
| 538 |
+
|
| 539 |
+
if not filtered_drafts:
|
| 540 |
+
st.info("No drafts match your filters. Try adjusting the filter settings.")
|
| 541 |
+
else:
|
| 542 |
+
# Auto-select first draft if none selected
|
| 543 |
+
if not selected_id or selected_id not in [d.id for d in filtered_drafts]:
|
| 544 |
+
selected_id = filtered_drafts[0].id
|
| 545 |
+
st.session_state["selected_id"] = selected_id
|
| 546 |
+
|
| 547 |
+
# Get current draft
|
| 548 |
+
current = next((x for x in filtered_drafts if x.id == selected_id), filtered_drafts[0])
|
| 549 |
+
|
| 550 |
+
# Editor tabs
|
| 551 |
+
edit_tab1, edit_tab2, edit_tab3 = st.tabs(["📝 Edit", "🛠️ AI Tools", "📜 History"])
|
| 552 |
+
|
| 553 |
+
with edit_tab1:
|
| 554 |
+
# Main editing fields
|
| 555 |
+
with st.form("edit_script"):
|
| 556 |
+
title = st.text_input("Title", value=current.title)
|
| 557 |
+
hook = st.text_area("Hook", value=current.hook or "", height=80)
|
| 558 |
+
beats_text = st.text_area("Beats (one per line)", value="\n".join(current.beats or []), height=120)
|
| 559 |
+
voiceover = st.text_area("Voiceover", value=current.voiceover or "", height=80)
|
| 560 |
+
caption = st.text_area("Caption", value=current.caption or "", height=100)
|
| 561 |
+
# Clean up hashtags display - remove commas, show as space-separated
|
| 562 |
+
current_hashtags = current.hashtags or []
|
| 563 |
+
hashtags_display = " ".join(current_hashtags) if current_hashtags else ""
|
| 564 |
+
hashtags = st.text_input("Hashtags (space separated)", value=hashtags_display, help="Enter hashtags like: #gym #fitness #workout")
|
| 565 |
+
cta = st.text_input("Call to Action", value=current.cta or "")
|
| 566 |
+
|
| 567 |
+
# Submit button
|
| 568 |
+
if st.form_submit_button("💾 Save Changes", type="primary", use_container_width=True):
|
| 569 |
+
with get_session() as ses:
|
| 570 |
+
dbs = ses.get(Script, current.id)
|
| 571 |
+
dbs.title = title
|
| 572 |
+
dbs.hook = hook
|
| 573 |
+
dbs.beats = [x.strip() for x in beats_text.split("\n") if x.strip()]
|
| 574 |
+
dbs.voiceover = voiceover
|
| 575 |
+
dbs.caption = caption
|
| 576 |
+
# Parse hashtags from space-separated input
|
| 577 |
+
dbs.hashtags = [x.strip() for x in hashtags.split() if x.strip()]
|
| 578 |
+
dbs.cta = cta
|
| 579 |
+
|
| 580 |
+
# Update compliance
|
| 581 |
+
lvl, _ = score_script(blob_from(dbs.model_dump()))
|
| 582 |
+
dbs.compliance = lvl
|
| 583 |
+
|
| 584 |
+
ses.add(dbs)
|
| 585 |
+
ses.commit()
|
| 586 |
+
|
| 587 |
+
st.success("✅ Script saved successfully!")
|
| 588 |
+
time.sleep(1)
|
| 589 |
+
st.rerun()
|
| 590 |
+
|
| 591 |
+
# Rating widget
|
| 592 |
+
st.markdown("### Rate this script (feeds future generations)")
|
| 593 |
+
|
| 594 |
+
# Show current ratings if any
|
| 595 |
+
if current.ratings_count > 0:
|
| 596 |
+
st.info(f"📊 Current ratings ({current.ratings_count} ratings): Overall: {current.score_overall:.1f}/5.0, Hook: {current.score_hook:.1f}/5.0, Originality: {current.score_originality:.1f}/5.0")
|
| 597 |
+
|
| 598 |
+
with st.form("rate_script"):
|
| 599 |
+
colA, colB, colC, colD, colE = st.columns(5)
|
| 600 |
+
overall = colA.slider("Overall", 1.0, 5.0, 4.0, 0.5)
|
| 601 |
+
hook_s = colB.slider("Hook clarity", 1.0, 5.0, 4.0, 0.5)
|
| 602 |
+
orig_s = colC.slider("Originality", 1.0, 5.0, 4.0, 0.5)
|
| 603 |
+
fit_s = colD.slider("Style fit", 1.0, 5.0, 4.0, 0.5)
|
| 604 |
+
safe_s = colE.slider("Safety", 1.0, 5.0, 4.0, 0.5)
|
| 605 |
+
notes = st.text_input("Notes (optional)")
|
| 606 |
+
|
| 607 |
+
if st.form_submit_button("💫 Save rating", type="secondary", use_container_width=True):
|
| 608 |
+
add_rating(
|
| 609 |
+
script_id=current.id,
|
| 610 |
+
overall=overall, hook=hook_s, originality=orig_s,
|
| 611 |
+
style_fit=fit_s, safety=safe_s, notes=notes, rater="human"
|
| 612 |
+
)
|
| 613 |
+
st.success("Rating saved. Future generations will weigh this higher.")
|
| 614 |
+
time.sleep(1)
|
| 615 |
+
st.rerun()
|
| 616 |
+
|
| 617 |
+
with edit_tab2:
|
| 618 |
+
st.write("🤖 **AI-Powered Improvements**")
|
| 619 |
+
|
| 620 |
+
# Quick AI actions
|
| 621 |
+
col1, col2 = st.columns(2)
|
| 622 |
+
|
| 623 |
+
with col1:
|
| 624 |
+
if st.button("🛡️ Make Safer", use_container_width=True):
|
| 625 |
+
with st.spinner("Making content safer..."):
|
| 626 |
+
revised = revise_for("be Instagram-compliant and safer", script_to_json_dict(current), "Remove risky phrases; keep intent and beat order.")
|
| 627 |
+
with get_session() as ses:
|
| 628 |
+
dbs = ses.get(Script, current.id)
|
| 629 |
+
before = dbs.caption
|
| 630 |
+
dbs.caption = revised.get("caption", dbs.caption)
|
| 631 |
+
lvl, _ = score_script(blob_from(revised))
|
| 632 |
+
dbs.compliance = lvl
|
| 633 |
+
ses.add(dbs)
|
| 634 |
+
ses.commit()
|
| 635 |
+
ses.add(Revision(script_id=dbs.id, label="Auto safer", field="caption", before=before, after=dbs.caption))
|
| 636 |
+
ses.commit()
|
| 637 |
+
st.success("✅ Content made safer!")
|
| 638 |
+
st.rerun()
|
| 639 |
+
|
| 640 |
+
if st.button("✨ More Playful", use_container_width=True):
|
| 641 |
+
with st.spinner("Adding playful vibes..."):
|
| 642 |
+
revised = revise_for("be more playful (keep safe)", script_to_json_dict(current), "Increase playful tone without adding risk.")
|
| 643 |
+
with get_session() as ses:
|
| 644 |
+
dbs = ses.get(Script, current.id)
|
| 645 |
+
before = dbs.hook
|
| 646 |
+
dbs.hook = revised.get("hook", dbs.hook)
|
| 647 |
+
ses.add(dbs)
|
| 648 |
+
ses.commit()
|
| 649 |
+
ses.add(Revision(script_id=dbs.id, label="More playful", field="hook", before=before, after=dbs.hook))
|
| 650 |
+
ses.commit()
|
| 651 |
+
st.success("✨ Added playful energy!")
|
| 652 |
+
st.rerun()
|
| 653 |
+
|
| 654 |
+
with col2:
|
| 655 |
+
if st.button("✂️ Shorter Hook", use_container_width=True):
|
| 656 |
+
with st.spinner("Tightening hook..."):
|
| 657 |
+
revised = revise_for("shorten the hook to <= 8 words", script_to_json_dict(current), "Shorten only the hook, keep intent.")
|
| 658 |
+
with get_session() as ses:
|
| 659 |
+
dbs = ses.get(Script, current.id)
|
| 660 |
+
before = dbs.hook
|
| 661 |
+
dbs.hook = revised.get("hook", dbs.hook)
|
| 662 |
+
ses.add(dbs)
|
| 663 |
+
ses.commit()
|
| 664 |
+
ses.add(Revision(script_id=dbs.id, label="Shorter hook", field="hook", before=before, after=dbs.hook))
|
| 665 |
+
ses.commit()
|
| 666 |
+
st.success("✂️ Hook tightened!")
|
| 667 |
+
st.rerun()
|
| 668 |
+
|
| 669 |
+
if st.button("🇬🇧 Localize (UK)", use_container_width=True):
|
| 670 |
+
with st.spinner("Localizing content..."):
|
| 671 |
+
revised = revise_for("localize to UK English", script_to_json_dict(current), "Adjust spelling/phrasing to UK without changing content.")
|
| 672 |
+
with get_session() as ses:
|
| 673 |
+
dbs = ses.get(Script, current.id)
|
| 674 |
+
before = dbs.caption
|
| 675 |
+
dbs.caption = revised.get("caption", dbs.caption)
|
| 676 |
+
ses.add(dbs)
|
| 677 |
+
ses.commit()
|
| 678 |
+
ses.add(Revision(script_id=dbs.id, label="Localize UK", field="caption", before=before, after=dbs.caption))
|
| 679 |
+
ses.commit()
|
| 680 |
+
st.success("🇬🇧 Localized to UK!")
|
| 681 |
+
st.rerun()
|
| 682 |
+
|
| 683 |
+
# Custom rewrite section
|
| 684 |
+
st.markdown("---")
|
| 685 |
+
st.write("🎯 **Custom Rewrite**")
|
| 686 |
+
|
| 687 |
+
with st.form("custom_rewrite"):
|
| 688 |
+
rewrite_col1, rewrite_col2 = st.columns([0.6, 0.4])
|
| 689 |
+
|
| 690 |
+
with rewrite_col1:
|
| 691 |
+
field = st.selectbox("Field to Edit", ["title","hook","voiceover","caption","cta","beats"])
|
| 692 |
+
snippet = st.text_input("Exact text you want to change")
|
| 693 |
+
|
| 694 |
+
with rewrite_col2:
|
| 695 |
+
prompt = st.text_input("How to rewrite it")
|
| 696 |
+
|
| 697 |
+
if st.form_submit_button("🪄 Rewrite", use_container_width=True):
|
| 698 |
+
if snippet and prompt:
|
| 699 |
+
with st.spinner("AI is rewriting..."):
|
| 700 |
+
draft = script_to_json_dict(current)
|
| 701 |
+
revised = selective_rewrite(draft, field, snippet, prompt)
|
| 702 |
+
with get_session() as ses:
|
| 703 |
+
dbs = ses.get(Script, current.id)
|
| 704 |
+
before = getattr(dbs, field)
|
| 705 |
+
setattr(dbs, field, revised.get(field, before))
|
| 706 |
+
lvl, _ = score_script(blob_from(dbs.model_dump()))
|
| 707 |
+
dbs.compliance = lvl
|
| 708 |
+
ses.add(dbs)
|
| 709 |
+
ses.commit()
|
| 710 |
+
ses.add(Revision(script_id=dbs.id, label="Custom rewrite", field=field, before=str(before), after=str(getattr(dbs, field))))
|
| 711 |
+
ses.commit()
|
| 712 |
+
st.success("🪄 Rewrite complete!")
|
| 713 |
+
st.rerun()
|
| 714 |
+
else:
|
| 715 |
+
st.error("Please fill in both the text and rewrite instructions")
|
| 716 |
+
|
| 717 |
+
with edit_tab3:
|
| 718 |
+
st.write("📜 **Revision History**")
|
| 719 |
+
|
| 720 |
+
with get_session() as ses:
|
| 721 |
+
revisions = list(ses.exec(
|
| 722 |
+
select(Revision).where(Revision.script_id==current.id).order_by(Revision.created_at.desc())
|
| 723 |
+
))
|
| 724 |
+
|
| 725 |
+
if not revisions:
|
| 726 |
+
st.info("No revisions yet. Make some changes to see the history!")
|
| 727 |
+
else:
|
| 728 |
+
for rev in revisions:
|
| 729 |
+
with st.expander(f"🔄 {rev.label} • {rev.field} • {rev.created_at.strftime('%m/%d %H:%M')}"):
|
| 730 |
+
col1, col2 = st.columns(2)
|
| 731 |
+
with col1:
|
| 732 |
+
st.write("**Before:**")
|
| 733 |
+
st.code(rev.before)
|
| 734 |
+
with col2:
|
| 735 |
+
st.write("**After:**")
|
| 736 |
+
st.code(rev.after)
|
| 737 |
+
|
| 738 |
+
with tab2:
|
| 739 |
+
st.subheader("🎯 Advanced Filters & Search")
|
| 740 |
+
|
| 741 |
+
# Advanced filtering interface
|
| 742 |
+
filter_col1, filter_col2, filter_col3 = st.columns(3)
|
| 743 |
+
|
| 744 |
+
with filter_col1:
|
| 745 |
+
creator_filter = st.selectbox("Creator", ["All"] + ["Creator A", "Emily"])
|
| 746 |
+
content_filter = st.selectbox("Content Type", ["All"] + ["thirst-trap", "lifestyle", "comedy", "prank", "fake-podcast", "trend-adaptation"])
|
| 747 |
+
|
| 748 |
+
with filter_col2:
|
| 749 |
+
compliance_filter_adv = st.selectbox("Compliance Status", ["All", "PASS", "WARN", "FAIL"])
|
| 750 |
+
source_filter = st.selectbox("Source", ["All", "AI Generated", "Imported", "Manual"])
|
| 751 |
+
|
| 752 |
+
with filter_col3:
|
| 753 |
+
date_filter = st.selectbox("Date Range", ["All Time", "Today", "This Week", "This Month"])
|
| 754 |
+
search_text = st.text_input("🔍 Search in titles/content")
|
| 755 |
+
|
| 756 |
+
# Apply advanced filters and show results
|
| 757 |
+
with get_session() as ses:
|
| 758 |
+
query = select(Script)
|
| 759 |
+
|
| 760 |
+
# Apply filters
|
| 761 |
+
if creator_filter != "All":
|
| 762 |
+
query = query.where(Script.creator == creator_filter)
|
| 763 |
+
if content_filter != "All":
|
| 764 |
+
query = query.where(Script.content_type == content_filter)
|
| 765 |
+
if compliance_filter_adv != "All":
|
| 766 |
+
query = query.where(Script.compliance == compliance_filter_adv.lower())
|
| 767 |
+
|
| 768 |
+
filtered_results = list(ses.exec(query))
|
| 769 |
+
|
| 770 |
+
# Search in text
|
| 771 |
+
if search_text:
|
| 772 |
+
filtered_results = [
|
| 773 |
+
r for r in filtered_results
|
| 774 |
+
if search_text.lower() in r.title.lower() or
|
| 775 |
+
search_text.lower() in (r.hook or "").lower() or
|
| 776 |
+
search_text.lower() in (r.caption or "").lower()
|
| 777 |
+
]
|
| 778 |
+
|
| 779 |
+
st.write(f"**Found {len(filtered_results)} scripts**")
|
| 780 |
+
|
| 781 |
+
# Display filtered results
|
| 782 |
+
if filtered_results:
|
| 783 |
+
for script in filtered_results[:10]: # Show first 10
|
| 784 |
+
with st.expander(f"{script.compliance.upper()} • {script.title} • {script.creator}"):
|
| 785 |
+
st.write(f"**Hook:** {script.hook}")
|
| 786 |
+
st.write(f"**Type:** {script.content_type} • **Tone:** {script.tone}")
|
| 787 |
+
st.write(f"**Created:** {script.created_at.strftime('%Y-%m-%d %H:%M')}")
|
| 788 |
+
|
| 789 |
+
with tab3:
|
| 790 |
+
st.subheader("📊 Script Analytics")
|
| 791 |
+
|
| 792 |
+
# Get all scripts for analytics
|
| 793 |
+
with get_session() as ses:
|
| 794 |
+
all_scripts = list(ses.exec(select(Script)))
|
| 795 |
+
|
| 796 |
+
if all_scripts:
|
| 797 |
+
# Create metrics
|
| 798 |
+
col1, col2, col3, col4 = st.columns(4)
|
| 799 |
+
|
| 800 |
+
with col1:
|
| 801 |
+
st.metric("Total Scripts", len(all_scripts))
|
| 802 |
+
|
| 803 |
+
with col2:
|
| 804 |
+
ai_generated = len([s for s in all_scripts if s.source == "ai"])
|
| 805 |
+
st.metric("AI Generated", ai_generated)
|
| 806 |
+
|
| 807 |
+
with col3:
|
| 808 |
+
passed_compliance = len([s for s in all_scripts if s.compliance == "pass"])
|
| 809 |
+
st.metric("Compliance PASS", passed_compliance)
|
| 810 |
+
|
| 811 |
+
with col4:
|
| 812 |
+
unique_creators = len(set(s.creator for s in all_scripts))
|
| 813 |
+
st.metric("Creators", unique_creators)
|
| 814 |
+
|
| 815 |
+
# Charts and insights
|
| 816 |
+
st.markdown("### 📈 Content Insights")
|
| 817 |
+
|
| 818 |
+
# Compliance distribution
|
| 819 |
+
compliance_counts = {}
|
| 820 |
+
for script in all_scripts:
|
| 821 |
+
compliance_counts[script.compliance] = compliance_counts.get(script.compliance, 0) + 1
|
| 822 |
+
|
| 823 |
+
if compliance_counts:
|
| 824 |
+
st.bar_chart(compliance_counts)
|
| 825 |
+
|
| 826 |
+
# Content type distribution
|
| 827 |
+
type_counts = {}
|
| 828 |
+
for script in all_scripts:
|
| 829 |
+
type_counts[script.content_type] = type_counts.get(script.content_type, 0) + 1
|
| 830 |
+
|
| 831 |
+
if type_counts:
|
| 832 |
+
st.bar_chart(type_counts)
|
| 833 |
+
|
| 834 |
+
else:
|
| 835 |
+
st.info("📊 Generate some scripts to see analytics!")
|
| 836 |
+
|
| 837 |
+
# Footer
|
| 838 |
+
st.markdown("---")
|
| 839 |
+
st.markdown("""
|
| 840 |
+
<div style="text-align: center; color: #666; padding: 1rem;">
|
| 841 |
+
🎬 AI Script Studio • Built with Streamlit & DeepSeek AI<br>
|
| 842 |
+
💡 Tip: Generate scripts in batches, then refine with AI tools for best results
|
| 843 |
+
</div>
|
| 844 |
+
""", unsafe_allow_html=True)
|
auto_scorer.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Auto-scoring system using LLM judges for script quality assessment
|
| 3 |
+
Integrates with existing DeepSeek client
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from typing import Dict, List, Tuple
|
| 8 |
+
from sqlmodel import Session, select
|
| 9 |
+
from datetime import datetime, timedelta
|
| 10 |
+
|
| 11 |
+
from models import Script, AutoScore, PolicyWeights
|
| 12 |
+
from db import get_session
|
| 13 |
+
from deepseek_client import chat
|
| 14 |
+
|
| 15 |
+
class AutoScorer:
|
| 16 |
+
def __init__(self, confidence_threshold: float = 0.7):
|
| 17 |
+
self.confidence_threshold = confidence_threshold
|
| 18 |
+
|
| 19 |
+
def score_script(self, script_data: Dict) -> Dict[str, float]:
|
| 20 |
+
"""
|
| 21 |
+
Score a script using LLM judge across 5 dimensions
|
| 22 |
+
Returns scores and confidence level
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
system_prompt = """You are an expert Instagram content analyst. Score this script on 5 dimensions (1-5 scale):
|
| 26 |
+
|
| 27 |
+
1. OVERALL: General quality and effectiveness (1=poor, 5=excellent)
|
| 28 |
+
2. HOOK: How compelling is the opening (1=boring, 5=irresistible)
|
| 29 |
+
3. ORIGINALITY: How unique/creative (1=generic, 5=highly original)
|
| 30 |
+
4. STYLE_FIT: How well it matches the persona (1=off-brand, 5=perfect fit)
|
| 31 |
+
5. SAFETY: Instagram compliance (1=risky, 5=completely safe)
|
| 32 |
+
|
| 33 |
+
Return ONLY a JSON object with: {"overall": X, "hook": X, "originality": X, "style_fit": X, "safety": X, "confidence": X, "reasoning": "brief explanation"}
|
| 34 |
+
|
| 35 |
+
Be consistent and objective. Confidence should be 0.1-1.0 based on how certain you are."""
|
| 36 |
+
|
| 37 |
+
user_prompt = f"""
|
| 38 |
+
Script to score:
|
| 39 |
+
Title: {script_data.get('title', '')}
|
| 40 |
+
Hook: {script_data.get('hook', '')}
|
| 41 |
+
Beats: {script_data.get('beats', [])}
|
| 42 |
+
Caption: {script_data.get('caption', '')}
|
| 43 |
+
Persona: {script_data.get('creator', '')}
|
| 44 |
+
Content Type: {script_data.get('content_type', '')}
|
| 45 |
+
Tone: {script_data.get('tone', '')}
|
| 46 |
+
|
| 47 |
+
Score this script now."""
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
response = chat([
|
| 51 |
+
{"role": "system", "content": system_prompt},
|
| 52 |
+
{"role": "user", "content": user_prompt}
|
| 53 |
+
], temperature=0.3) # Low temperature for consistent scoring
|
| 54 |
+
|
| 55 |
+
# Extract JSON from response
|
| 56 |
+
start = response.find("{")
|
| 57 |
+
end = response.rfind("}") + 1
|
| 58 |
+
|
| 59 |
+
if start >= 0 and end > start:
|
| 60 |
+
scores = json.loads(response[start:end])
|
| 61 |
+
|
| 62 |
+
# Validate scores are in range
|
| 63 |
+
required_keys = ['overall', 'hook', 'originality', 'style_fit', 'safety']
|
| 64 |
+
for key in required_keys:
|
| 65 |
+
if key not in scores or not (1 <= scores[key] <= 5):
|
| 66 |
+
raise ValueError(f"Invalid score for {key}")
|
| 67 |
+
|
| 68 |
+
# Ensure confidence is present and valid
|
| 69 |
+
if 'confidence' not in scores or not (0.1 <= scores['confidence'] <= 1.0):
|
| 70 |
+
scores['confidence'] = 0.7 # Default confidence
|
| 71 |
+
|
| 72 |
+
return scores
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError("No valid JSON found in response")
|
| 75 |
+
|
| 76 |
+
except Exception as e:
|
| 77 |
+
print(f"Auto-scoring failed: {e}")
|
| 78 |
+
# Return neutral scores with low confidence
|
| 79 |
+
return {
|
| 80 |
+
'overall': 3.0,
|
| 81 |
+
'hook': 3.0,
|
| 82 |
+
'originality': 3.0,
|
| 83 |
+
'style_fit': 3.0,
|
| 84 |
+
'safety': 3.0,
|
| 85 |
+
'confidence': 0.3,
|
| 86 |
+
'reasoning': f"Scoring failed: {str(e)}"
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
def score_and_store(self, script_id: int) -> AutoScore:
|
| 90 |
+
"""Score a script and store in database"""
|
| 91 |
+
with get_session() as ses:
|
| 92 |
+
script = ses.get(Script, script_id)
|
| 93 |
+
if not script:
|
| 94 |
+
raise ValueError(f"Script {script_id} not found")
|
| 95 |
+
|
| 96 |
+
# Prepare script data for scoring
|
| 97 |
+
script_data = {
|
| 98 |
+
'title': script.title,
|
| 99 |
+
'hook': script.hook,
|
| 100 |
+
'beats': script.beats,
|
| 101 |
+
'caption': script.caption,
|
| 102 |
+
'creator': script.creator,
|
| 103 |
+
'content_type': script.content_type,
|
| 104 |
+
'tone': script.tone
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
# Get scores
|
| 108 |
+
scores = self.score_script(script_data)
|
| 109 |
+
|
| 110 |
+
# Store auto-score
|
| 111 |
+
auto_score = AutoScore(
|
| 112 |
+
script_id=script_id,
|
| 113 |
+
overall=scores['overall'],
|
| 114 |
+
hook=scores['hook'],
|
| 115 |
+
originality=scores['originality'],
|
| 116 |
+
style_fit=scores['style_fit'],
|
| 117 |
+
safety=scores['safety'],
|
| 118 |
+
confidence=scores['confidence'],
|
| 119 |
+
notes=scores.get('reasoning', '')
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
ses.add(auto_score)
|
| 123 |
+
ses.commit()
|
| 124 |
+
ses.refresh(auto_score)
|
| 125 |
+
|
| 126 |
+
return auto_score
|
| 127 |
+
|
| 128 |
+
def batch_score_recent(self, hours: int = 24) -> List[AutoScore]:
|
| 129 |
+
"""Score all recently generated scripts that haven't been auto-scored"""
|
| 130 |
+
cutoff = datetime.utcnow() - timedelta(hours=hours)
|
| 131 |
+
|
| 132 |
+
with get_session() as ses:
|
| 133 |
+
# Find scripts without auto-scores
|
| 134 |
+
recent_scripts = ses.exec(
|
| 135 |
+
select(Script).where(
|
| 136 |
+
Script.created_at >= cutoff,
|
| 137 |
+
Script.source == "ai" # Only score AI-generated scripts
|
| 138 |
+
)
|
| 139 |
+
).all()
|
| 140 |
+
|
| 141 |
+
# Filter out already scored
|
| 142 |
+
unscored = []
|
| 143 |
+
for script in recent_scripts:
|
| 144 |
+
existing_score = ses.exec(
|
| 145 |
+
select(AutoScore).where(AutoScore.script_id == script.id)
|
| 146 |
+
).first()
|
| 147 |
+
if not existing_score:
|
| 148 |
+
unscored.append(script)
|
| 149 |
+
|
| 150 |
+
print(f"Auto-scoring {len(unscored)} recent scripts...")
|
| 151 |
+
|
| 152 |
+
results = []
|
| 153 |
+
for script in unscored:
|
| 154 |
+
try:
|
| 155 |
+
auto_score = self.score_and_store(script.id)
|
| 156 |
+
results.append(auto_score)
|
| 157 |
+
print(f"Scored script {script.id}: {auto_score.overall:.1f}/5.0")
|
| 158 |
+
except Exception as e:
|
| 159 |
+
print(f"Failed to score script {script.id}: {e}")
|
| 160 |
+
|
| 161 |
+
return results
|
| 162 |
+
|
| 163 |
+
class ScriptReranker:
    """Rerank generated scripts using composite scoring."""

    def __init__(self, weights: Dict[str, float] = None):
        # Default weights sum to 1.0 so composites stay on the 1-5 scale.
        self.weights = weights or {
            'overall': 0.35,
            'hook': 0.20,
            'originality': 0.15,
            'style_fit': 0.15,
            'safety': 0.15
        }

    def rerank_scripts(self, script_ids: List[int]) -> List[Tuple[int, float]]:
        """
        Rerank scripts by composite score.

        Scoring source priority per script:
        1. Auto-score with confidence >= 0.5,
        2. Human ratings stored on the Script row,
        3. Neutral 3.0 default.

        Args:
            script_ids: IDs of the scripts to rank.

        Returns:
            List of (script_id, composite_score) sorted by score descending.
        """

        results = []

        with get_session() as ses:
            for script_id in script_ids:
                # Prefer a sufficiently confident auto-score.
                auto_score = ses.exec(
                    select(AutoScore).where(AutoScore.script_id == script_id)
                ).first()

                if auto_score and auto_score.confidence >= 0.5:
                    composite = (
                        self.weights['overall'] * auto_score.overall +
                        self.weights['hook'] * auto_score.hook +
                        self.weights['originality'] * auto_score.originality +
                        self.weights['style_fit'] * auto_score.style_fit +
                        self.weights['safety'] * auto_score.safety
                    )
                else:
                    # Fall back to human ratings if available.
                    script = ses.get(Script, script_id)
                    if script and script.ratings_count > 0:
                        composite = (
                            self.weights['overall'] * (script.score_overall or 3.0) +
                            self.weights['hook'] * (script.score_hook or 3.0) +
                            self.weights['originality'] * (script.score_originality or 3.0) +
                            self.weights['style_fit'] * (script.score_style_fit or 3.0) +
                            self.weights['safety'] * (script.score_safety or 3.0)
                        )
                    else:
                        # No signal at all: neutral midpoint score.
                        composite = 3.0

                results.append((script_id, composite))

        # Sort by composite score descending.
        results.sort(key=lambda x: x[1], reverse=True)
        return results

    def get_best_script(self, script_ids: List[int]) -> int:
        """Get the ID of the highest-scoring script.

        Args:
            script_ids: Non-empty list of candidate script IDs.

        Returns:
            The script id with the highest composite score.

        Raises:
            ValueError: If `script_ids` is empty (previously this
                surfaced as an opaque IndexError).
        """
        if not script_ids:
            raise ValueError("script_ids must not be empty")
        ranked = self.rerank_scripts(script_ids)
        return ranked[0][0] if ranked else script_ids[0]
|
| 224 |
+
|
| 225 |
+
def auto_score_pipeline():
    """Main pipeline to auto-score recent scripts.

    Scores everything generated in the last 24 hours that has no
    auto-score yet, then prints a short summary to stdout.
    """
    new_scores = AutoScorer().batch_score_recent(hours=24)

    # Guard clause: nothing scored, nothing to report.
    if not new_scores:
        print("No new scripts to score.")
        return

    print(f"\n📊 Auto-scoring Results ({len(new_scores)} scripts):")
    for score in new_scores:
        print(f"Script {score.script_id}: {score.overall:.1f}/5.0 (confidence: {score.confidence:.2f})")
|
| 238 |
+
|
| 239 |
+
# Allow running this module directly (e.g. `python auto_scorer.py`)
# to score recent scripts as a one-off batch job.
if __name__ == "__main__":
    auto_score_pipeline()
|
bandit_learner.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multi-armed bandit learning system for optimizing generation policies
|
| 3 |
+
Learns which retrieval weights and generation parameters work best for each persona/content_type
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import random
|
| 8 |
+
from typing import Dict, List, Tuple, Optional
|
| 9 |
+
from dataclasses import dataclass
|
| 10 |
+
from datetime import datetime, timedelta
|
| 11 |
+
from sqlmodel import Session, select
|
| 12 |
+
|
| 13 |
+
from models import Script, AutoScore, PolicyWeights, Rating
|
| 14 |
+
from db import get_session
|
| 15 |
+
|
| 16 |
+
@dataclass
class BanditArm:
    """One candidate configuration of retrieval weights and sampling temperatures.

    The four retrieval weights are renormalized to sum to 1.0 after
    construction; the three temperatures are passed through unchanged.
    """
    name: str
    semantic_weight: float
    bm25_weight: float
    quality_weight: float
    freshness_weight: float
    temp_low: float
    temp_mid: float
    temp_high: float

    def __post_init__(self):
        # Ensure the retrieval weights sum to 1.0. Compare with a tolerance
        # instead of exact float equality (sums like 0.45+0.25+0.20+0.10 are
        # not exactly 1.0 in binary floating point), and skip normalization
        # when the total is zero to avoid a ZeroDivisionError.
        total = self.semantic_weight + self.bm25_weight + self.quality_weight + self.freshness_weight
        if total > 0 and abs(total - 1.0) > 1e-9:
            self.semantic_weight /= total
            self.bm25_weight /= total
            self.quality_weight /= total
            self.freshness_weight /= total
|
| 36 |
+
|
| 37 |
+
class PolicyBandit:
    """Multi-armed bandit for learning optimal generation policies.

    Each arm is a BanditArm (retrieval weights + temperatures). Arm
    selection is epsilon-greedy with a UCB tie-break on the exploit path;
    rewards are derived from auto-scores and human ratings and persisted
    to the PolicyWeights table per (persona, content_type).
    """

    def __init__(self, epsilon: float = 0.15, decay_rate: float = 0.99):
        self.epsilon = epsilon  # Exploration rate
        self.decay_rate = decay_rate  # Epsilon decay over time
        self.min_epsilon = 0.05  # floor so some exploration always remains

        # Define arms (different parameter configurations).
        # Positional args: semantic, bm25, quality, freshness weights,
        # then temp_low, temp_mid, temp_high.
        self.arms = [
            # Current default
            BanditArm("balanced", 0.45, 0.25, 0.20, 0.10, 0.4, 0.7, 0.95),

            # Semantic-heavy (focus on meaning)
            BanditArm("semantic_heavy", 0.60, 0.15, 0.15, 0.10, 0.4, 0.7, 0.95),

            # Quality-focused (use only best examples)
            BanditArm("quality_focused", 0.35, 0.20, 0.35, 0.10, 0.3, 0.6, 0.85),

            # Fresh-focused (prioritize recent trends)
            BanditArm("fresh_focused", 0.40, 0.20, 0.15, 0.25, 0.5, 0.8, 1.0),

            # Conservative (lower temperatures)
            BanditArm("conservative", 0.45, 0.25, 0.20, 0.10, 0.3, 0.5, 0.7),

            # Creative (higher temperatures)
            BanditArm("creative", 0.45, 0.25, 0.20, 0.10, 0.6, 0.9, 1.2),

            # Text-match heavy (traditional keyword matching)
            BanditArm("text_heavy", 0.25, 0.45, 0.20, 0.10, 0.4, 0.7, 0.95)
        ]

        # Initialize arm statistics: pull counts and cumulative rewards,
        # keyed by arm name.
        self.arm_counts = {arm.name: 0 for arm in self.arms}
        self.arm_rewards = {arm.name: 0.0 for arm in self.arms}

    def select_arm(self, persona: str, content_type: str) -> BanditArm:
        """Select an arm using epsilon-greedy with UCB bias.

        Side effect: loads persisted stats for this persona/content_type
        and prints which policy was chosen.
        """

        # Load existing policy weights to initialize arm stats
        self._load_arm_stats(persona, content_type)

        # Decay epsilon over time: the more total pulls recorded, the less
        # we explore, but never below min_epsilon.
        current_epsilon = max(self.min_epsilon, self.epsilon * (self.decay_rate ** sum(self.arm_counts.values())))

        if random.random() < current_epsilon:
            # Explore: random arm
            selected_arm = random.choice(self.arms)
            print(f"🔄 Exploring with {selected_arm.name} policy (ε={current_epsilon:.3f})")
        else:
            # Exploit: best arm with UCB confidence bounds
            selected_arm = self._select_best_arm_ucb()
            print(f"⭐ Exploiting with {selected_arm.name} policy")

        return selected_arm

    def _select_best_arm_ucb(self) -> BanditArm:
        """Select the arm with the highest Upper Confidence Bound score."""
        total_counts = sum(self.arm_counts.values())
        if total_counts == 0:
            return self.arms[0]  # Default to first arm

        best_arm = None
        best_score = float('-inf')

        for arm in self.arms:
            count = self.arm_counts[arm.name]
            if count == 0:
                return arm  # Always try unplayed arms first

            # UCB score = average reward + confidence interval
            avg_reward = self.arm_rewards[arm.name] / count
            confidence = np.sqrt(2 * np.log(total_counts) / count)
            ucb_score = avg_reward + confidence

            if ucb_score > best_score:
                best_score = ucb_score
                best_arm = arm

        return best_arm or self.arms[0]

    def _load_arm_stats(self, persona: str, content_type: str):
        """Load historical performance for this persona/content_type.

        Only the single arm whose weights match the stored policy (within
        tolerance) receives the persisted counts/rewards.
        """
        with get_session() as ses:
            policy = ses.exec(
                select(PolicyWeights).where(
                    PolicyWeights.persona == persona,
                    PolicyWeights.content_type == content_type
                )
            ).first()

            if policy:
                # Find matching arm and update stats
                for arm in self.arms:
                    if self._arm_matches_policy(arm, policy):
                        self.arm_counts[arm.name] = policy.total_generations
                        # Cumulative reward reconstructed as rate × count.
                        self.arm_rewards[arm.name] = policy.success_rate * policy.total_generations
                        break

    def _arm_matches_policy(self, arm: BanditArm, policy: PolicyWeights, tolerance: float = 0.05) -> bool:
        """Check if an arm matches the stored policy weights within tolerance."""
        return (
            abs(arm.semantic_weight - policy.semantic_weight) < tolerance and
            abs(arm.bm25_weight - policy.bm25_weight) < tolerance and
            abs(arm.quality_weight - policy.quality_weight) < tolerance and
            abs(arm.freshness_weight - policy.freshness_weight) < tolerance
        )

    def update_reward(self,
                      arm: BanditArm,
                      reward: float,
                      persona: str,
                      content_type: str,
                      script_id: int):
        """Update arm performance with a new reward signal (0..1).

        NOTE(review): script_id is currently unused here — confirm whether
        it should be recorded alongside the policy update.
        """

        # Update in-memory stats
        self.arm_counts[arm.name] += 1
        self.arm_rewards[arm.name] += reward

        # Update database policy
        self._update_policy_weights(arm, reward, persona, content_type)

        print(f"📈 Updated {arm.name}: reward={reward:.3f}, avg={self.arm_rewards[arm.name]/self.arm_counts[arm.name]:.3f}")

    def _update_policy_weights(self,
                               arm: BanditArm,
                               reward: float,
                               persona: str,
                               content_type: str):
        """Upsert the PolicyWeights row for (persona, content_type)."""
        with get_session() as ses:
            policy = ses.exec(
                select(PolicyWeights).where(
                    PolicyWeights.persona == persona,
                    PolicyWeights.content_type == content_type
                )
            ).first()

            if not policy:
                # Create new policy seeded from this arm's configuration.
                policy = PolicyWeights(
                    persona=persona,
                    content_type=content_type,
                    semantic_weight=arm.semantic_weight,
                    bm25_weight=arm.bm25_weight,
                    quality_weight=arm.quality_weight,
                    freshness_weight=arm.freshness_weight,
                    temp_low=arm.temp_low,
                    temp_mid=arm.temp_mid,
                    temp_high=arm.temp_high,
                    total_generations=1,
                    success_rate=reward
                )
            else:
                # Update existing policy with exponential moving average
                alpha = 0.1  # Learning rate
                policy.success_rate = (1 - alpha) * policy.success_rate + alpha * reward
                policy.total_generations += 1

                # If this arm is performing well, shift weights toward it.
                # Note: success_rate was already updated above, so the
                # comparison is against the post-EMA rate.
                if reward > policy.success_rate:
                    shift = 0.05  # Small shift toward better performing arm
                    policy.semantic_weight = (1 - shift) * policy.semantic_weight + shift * arm.semantic_weight
                    policy.bm25_weight = (1 - shift) * policy.bm25_weight + shift * arm.bm25_weight
                    policy.quality_weight = (1 - shift) * policy.quality_weight + shift * arm.quality_weight
                    policy.freshness_weight = (1 - shift) * policy.freshness_weight + shift * arm.freshness_weight

                    policy.temp_low = (1 - shift) * policy.temp_low + shift * arm.temp_low
                    policy.temp_mid = (1 - shift) * policy.temp_mid + shift * arm.temp_mid
                    policy.temp_high = (1 - shift) * policy.temp_high + shift * arm.temp_high

            policy.updated_at = datetime.utcnow()
            ses.add(policy)
            ses.commit()

    def calculate_reward(self, script_id: int) -> float:
        """
        Calculate a 0..1 reward signal from script performance.

        Combines auto-scores and human ratings when available, weighting
        each component by its confidence; returns 0.5 (neutral) when no
        signal exists.
        """
        reward_components = []

        with get_session() as ses:
            # Get auto-score (only trusted when confidence > 0.5)
            auto_score = ses.exec(
                select(AutoScore).where(AutoScore.script_id == script_id)
            ).first()

            if auto_score and auto_score.confidence > 0.5:
                # Weighted composite of auto-scores (scores are on a 1..5 scale)
                auto_reward = (
                    0.35 * auto_score.overall +
                    0.20 * auto_score.hook +
                    0.15 * auto_score.originality +
                    0.15 * auto_score.style_fit +
                    0.15 * auto_score.safety
                ) / 5.0  # Normalize to 0-1

                reward_components.append(('auto', auto_reward, auto_score.confidence))

            # Get human ratings (cached aggregate on the Script row)
            script = ses.get(Script, script_id)
            if script and script.ratings_count > 0:
                human_reward = script.score_overall / 5.0  # Normalize to 0-1
                confidence = min(1.0, script.ratings_count / 3.0)  # More ratings = higher confidence
                reward_components.append(('human', human_reward, confidence))

        if not reward_components:
            return 0.5  # Neutral reward if no scores available

        # Weighted average of reward components by confidence
        total_weight = sum(confidence for _, _, confidence in reward_components)
        weighted_reward = sum(
            reward * confidence for _, reward, confidence in reward_components
        ) / total_weight

        return weighted_reward
|
| 255 |
+
|
| 256 |
+
class PolicyLearner:
    """Main interface for policy learning: wraps a PolicyBandit and feeds it
    rewards computed from recently generated scripts."""

    def __init__(self):
        self.bandit = PolicyBandit()

    def learn_from_generation_batch(self,
                                    persona: str,
                                    content_type: str,
                                    generated_script_ids: List[int],
                                    selected_arm: BanditArm):
        """Learn from a batch of generated scripts.

        Computes the mean reward over the batch and credits it to the arm
        that produced the batch. No-op on an empty batch.
        """

        if not generated_script_ids:
            return

        # Calculate average reward from the batch
        rewards = [self.bandit.calculate_reward(sid) for sid in generated_script_ids]
        avg_reward = sum(rewards) / len(rewards)

        # Update bandit with average performance
        self.bandit.update_reward(
            selected_arm,
            avg_reward,
            persona,
            content_type,
            generated_script_ids[0]  # Representative script ID
        )

        print(f"🧠 Policy learning: {persona}/{content_type} → {avg_reward:.3f} reward")

    def get_optimized_policy(self, persona: str, content_type: str) -> BanditArm:
        """Get the current best policy (arm) for this persona/content_type."""
        return self.bandit.select_arm(persona, content_type)

    def run_learning_cycle(self):
        """Run a learning cycle over AI-generated scripts from the last 24h."""
        print("🔄 Starting policy learning cycle...")

        # Find recent AI-generated scripts by persona/content_type
        cutoff = datetime.utcnow() - timedelta(hours=24)

        with get_session() as ses:
            recent_scripts = list(ses.exec(
                select(Script).where(
                    Script.created_at >= cutoff,
                    Script.source == "ai"
                )
            ))

        # Group script ids by (persona, content_type); Script.creator is
        # used as the persona key.
        groups = {}
        for script in recent_scripts:
            key = (script.creator, script.content_type)
            if key not in groups:
                groups[key] = []
            groups[key].append(script.id)

        # Learn from each group
        for (persona, content_type), script_ids in groups.items():
            if len(script_ids) >= 3:  # Need minimum batch size
                # For now, assume they used the balanced policy
                # In practice, you'd track which policy was used for each generation
                balanced_arm = next(arm for arm in self.bandit.arms if arm.name == "balanced")
                self.learn_from_generation_batch(persona, content_type, script_ids, balanced_arm)
|
| 321 |
+
|
| 322 |
+
def run_policy_learning():
    """Main entry point for policy learning"""
    PolicyLearner().run_learning_cycle()


if __name__ == "__main__":
    run_policy_learning()
|
| 329 |
+
|
| 330 |
+
|
compliance.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re

# Raw patterns kept as public sets for backward compatibility with any code
# that inspects them.
BANNED = {r"\b(naked|explicit|porn|onlyfans\.com)\b"}
CAUTION = {r"\b(hot|naughty|spicy|thirsty)\b"}

# Compile once at import time instead of recompiling on every call.
_BANNED_RE = [re.compile(p) for p in BANNED]
_CAUTION_RE = [re.compile(p) for p in CAUTION]

def compliance_level(text: str):
    """Classify *text* as "pass", "warn" or "fail".

    Returns a (level, reasons) tuple. Any banned phrase fails the text
    outright; each matching caution pattern appends a "caution phrase"
    reason and downgrades the level to "warn".
    """
    low = text.lower()
    for pat in _BANNED_RE:
        if pat.search(low):
            return "fail", ["banned phrase"]
    reasons = []
    for pat in _CAUTION_RE:
        if pat.search(low):
            reasons.append("caution phrase")
    return ("warn" if reasons else "pass"), reasons

def score_script(blob: str):
    """Score one flattened script blob; alias for compliance_level."""
    return compliance_level(blob)

def blob_from(script: dict) -> str:
    """Flatten the scoreable text fields of a script dict into one string.

    Missing keys contribute empty strings; beats are joined with spaces.
    """
    parts = [
        script.get("title",""), script.get("hook",""),
        " ".join(script.get("beats",[])),
        script.get("voiceover",""), script.get("caption",""), script.get("cta","")
    ]
    return " ".join(parts)
|
db.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# db.py
|
| 2 |
+
import os, json, random
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from typing import List, Iterable, Tuple, Optional
|
| 5 |
+
from sqlmodel import SQLModel, create_engine, Session, select
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
# ---- Configure DB ----
# DB_URL may be overridden via the environment; defaults to a local SQLite
# file next to the app. The engine is shared by every session in this module.
DB_URL = os.environ.get("DB_URL", "sqlite:///studio.db")
engine = create_engine(DB_URL, echo=False)
|
| 11 |
+
|
| 12 |
+
# ---- Models ----
|
| 13 |
+
from models import Script, Rating # make sure Script has: is_reference: bool, plus the other fields
|
| 14 |
+
|
| 15 |
+
# ---- Init / Session ----
|
| 16 |
+
def init_db() -> None:
    """Create all tables registered on SQLModel.metadata (idempotent)."""
    SQLModel.metadata.create_all(engine)

@contextmanager
def get_session():
    """Context manager yielding a Session bound to the module engine.

    The session is closed on exit; committing is the caller's job.
    """
    with Session(engine) as ses:
        yield ses
|
| 23 |
+
|
| 24 |
+
# ---- Helpers for import ----
|
| 25 |
+
|
| 26 |
+
def _payload_from_jsonl_row(row: dict) -> Tuple[dict, str, str]:
|
| 27 |
+
"""
|
| 28 |
+
Map a JSONL row (the file I generated for you) into Script columns.
|
| 29 |
+
Returns (payload, dedupe_key_title, dedupe_key_creator).
|
| 30 |
+
You can also add 'external_id' to Script model and dedupe on that.
|
| 31 |
+
"""
|
| 32 |
+
# Prefer using the JSON 'id' as an external identifier:
|
| 33 |
+
external_id = row.get("id", "")
|
| 34 |
+
|
| 35 |
+
# Tone could be an array; flatten for now
|
| 36 |
+
tone = ", ".join(row.get("tonality", [])) or "playful"
|
| 37 |
+
|
| 38 |
+
# Compact caption: use caption options line as a quick reference
|
| 39 |
+
caption = " | ".join(row.get("caption_options", []))[:180]
|
| 40 |
+
|
| 41 |
+
payload = dict(
|
| 42 |
+
# core identity
|
| 43 |
+
creator=row.get("model_name", "Unknown"),
|
| 44 |
+
content_type=(row.get("video_type", "") or "talking_style").lower(),
|
| 45 |
+
tone=tone,
|
| 46 |
+
title=external_id or row.get("theme", "") or "Imported Script",
|
| 47 |
+
hook=row.get("video_hook") or "",
|
| 48 |
+
|
| 49 |
+
# structured fields
|
| 50 |
+
beats=row.get("storyboard", []) or [],
|
| 51 |
+
voiceover="",
|
| 52 |
+
caption=caption,
|
| 53 |
+
hashtags=row.get("hashtags", []) or [],
|
| 54 |
+
cta="",
|
| 55 |
+
|
| 56 |
+
# flags
|
| 57 |
+
source="import",
|
| 58 |
+
is_reference=True, # mark imported examples as references
|
| 59 |
+
compliance="pass", # we'll score again after save if you want
|
| 60 |
+
)
|
| 61 |
+
return payload, payload["title"], payload["creator"]
|
| 62 |
+
|
| 63 |
+
def _score_and_update_compliance(s: Script) -> None:
    """Optional: score compliance using the simple rule-checker.

    Best-effort by design: if the compliance module is missing or the
    scorer raises, the script keeps its current compliance value.
    """
    try:
        # Imported lazily so db.py works even without the compliance module.
        from compliance import blob_from, score_script
        lvl, _ = score_script(blob_from(s.dict()))
        s.compliance = lvl
    except Exception:
        # If no compliance module or error, keep default
        pass
|
| 72 |
+
|
| 73 |
+
def _iter_jsonl(path: str) -> Iterable[dict]:
|
| 74 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 75 |
+
for line in f:
|
| 76 |
+
line = line.strip()
|
| 77 |
+
if not line:
|
| 78 |
+
continue
|
| 79 |
+
yield json.loads(line)
|
| 80 |
+
|
| 81 |
+
# ---- Public: Importer ----
|
| 82 |
+
def import_jsonl(path: str) -> int:
    """
    Import (upsert) scripts from a JSONL file produced earlier.

    Dedupe is by (creator, title): an existing row with the same pair is
    overwritten field-by-field, otherwise a new Script is inserted. Each
    row is (re)scored for compliance before commit. Returns the count of
    upserted rows. The whole file is committed in one transaction.
    """
    init_db()
    count = 0
    with get_session() as ses:
        for row in _iter_jsonl(path):
            payload, key_title, key_creator = _payload_from_jsonl_row(row)

            existing = ses.exec(
                select(Script).where(
                    Script.title == key_title,
                    Script.creator == key_creator
                )
            ).first()

            if existing:
                # Update all fields from the fresh payload.
                for k, v in payload.items():
                    setattr(existing, k, v)
                _score_and_update_compliance(existing)
                existing.updated_at = datetime.utcnow()
                ses.add(existing)
            else:
                obj = Script(**payload)
                _score_and_update_compliance(obj)
                ses.add(obj)

            count += 1
        ses.commit()
    return count
|
| 115 |
+
|
| 116 |
+
# ---- Ratings API ----
|
| 117 |
+
def add_rating(script_id: int,
               overall: float,
               hook: Optional[float] = None,
               originality: Optional[float] = None,
               style_fit: Optional[float] = None,
               safety: Optional[float] = None,
               notes: Optional[str] = None,
               rater: str = "human") -> None:
    """Record one rating event and refresh the script's cached averages.

    ``overall`` is required; per-dimension scores are optional and only
    the supplied ones enter the corresponding averages.
    """
    with get_session() as ses:
        # store rating event (full history is kept)
        ses.add(Rating(
            script_id=script_id, overall=overall, hook=hook,
            originality=originality, style_fit=style_fit, safety=safety,
            notes=notes, rater=rater
        ))
        ses.commit()
        # recompute cached aggregates on Script
        _recompute_script_aggregates(ses, script_id)
        ses.commit()

def _recompute_script_aggregates(ses: Session, script_id: int) -> None:
    """Recompute and cache per-dimension rating averages on the Script row.

    Caller is responsible for committing the session afterwards.
    """
    rows = list(ses.exec(select(Rating).where(Rating.script_id == script_id)))
    if not rows:
        return
    def avg(field):
        # Average only the ratings that actually supplied this dimension.
        vals = [getattr(r, field) for r in rows if getattr(r, field) is not None]
        return round(sum(vals)/len(vals), 3) if vals else None
    s: Script = ses.get(Script, script_id)
    s.score_overall = avg("overall")
    s.score_hook = avg("hook")
    s.score_originality = avg("originality")
    s.score_style_fit = avg("style_fit")
    s.score_safety = avg("safety")
    s.ratings_count = len(rows)
    s.updated_at = datetime.utcnow()
    ses.add(s)
|
| 153 |
+
|
| 154 |
+
# ---- Public: Reference retrieval for generation ----
|
| 155 |
+
def extract_snippets_from_script(s: Script, max_lines: int = 3) -> List[str]:
    """Pull up to *max_lines* short reference snippets from a script.

    Sources, in order: the hook, the first two beats, and the caption
    (trimmed to 120 chars). Empty strings and duplicates are dropped.
    """
    candidates: List[str] = []
    if s.hook:
        candidates.append(s.hook.strip())
    if s.beats:
        # first 1–2 beats carry the flavor without bloating the prompt
        candidates.extend(b.strip() for b in s.beats[:2])
    if s.caption:
        candidates.append(s.caption.strip()[:120])

    # order-preserving dedupe, dropping empties
    seen = set()
    unique: List[str] = []
    for snippet in candidates:
        if snippet and snippet not in seen:
            seen.add(snippet)
            unique.append(snippet)
    return unique[:max_lines]
|
| 169 |
+
|
| 170 |
+
def get_library_refs(creator: str, content_type: str, k: int = 6) -> List[str]:
    """Return reference snippets from the k newest matching reference scripts.

    Filters to this creator/content_type, reference rows only, excluding
    compliance failures; flattens each script's snippets and dedupes,
    capping the result at 8 lines to keep prompts lean.
    """
    with get_session() as ses:
        rows = list(ses.exec(
            select(Script)
            .where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,
                Script.compliance != "fail"
            )
            .order_by(Script.created_at.desc())
        ))[:k]

        snippets: List[str] = []
        for r in rows:
            snippets.extend(extract_snippets_from_script(r))
        # final dedupe (order-preserving)
        seen, uniq = set(), []
        for s in snippets:
            if s not in seen:
                uniq.append(s); seen.add(s)
        return uniq[:8]
|
| 192 |
+
|
| 193 |
+
# ---- HYBRID reference retrieval ----
|
| 194 |
+
def get_hybrid_refs(creator: str, content_type: str, k: int = 6,
                    top_n: int = 3, explore_n: int = 2, newest_n: int = 1) -> List[str]:
    """
    Mix of:
      - top_n best scored references (exploit)
      - explore_n random references (explore)
      - newest_n most recent references (freshness)
    Returns flattened snippet list (cap ~8 to keep prompt lean).

    Note: the explore bucket is random, so results are non-deterministic
    across calls when more references exist than top_n + newest_n.
    """
    with get_session() as ses:
        all_refs = list(ses.exec(
            select(Script).where(
                Script.creator == creator,
                Script.content_type == content_type,
                Script.is_reference == True,
                Script.compliance != "fail"
            )
        ))

        if not all_refs:
            return []

        # sort by score_overall (fallback to 0 for unrated) and pick top_n
        scored = sorted(all_refs, key=lambda s: (s.score_overall or 0.0), reverse=True)
        best = scored[:top_n]

        # newest by created_at
        newest = sorted(all_refs, key=lambda s: s.created_at, reverse=True)[:newest_n]

        # explore = random sample from the remainder
        remainder = [r for r in all_refs if r not in best and r not in newest]
        explore = random.sample(remainder, min(explore_n, len(remainder))) if remainder else []

        # merge buckets in priority order (preserve order, dedupe by id)
        chosen_scripts = []
        seen_ids = set()
        for bucket in (best, explore, newest):
            for s in bucket:
                if s.id not in seen_ids:
                    chosen_scripts.append(s)
                    seen_ids.add(s.id)

        # cut to k scripts
        chosen_scripts = chosen_scripts[:k]

        # flatten snippets and cap to keep prompt compact
        snippets: List[str] = []
        for s in chosen_scripts:
            snippets.extend(extract_snippets_from_script(s))
        # dedupe again and cap ~8 lines
        seen, out = set(), []
        for sn in snippets:
            if sn not in seen:
                out.append(sn); seen.add(sn)
        return out[:8]
|
deepseek_client.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, requests, json
|
| 2 |
+
import streamlit as st
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
# Get API key from Streamlit secrets or environment
|
| 8 |
+
def get_api_key():
    """Return the DeepSeek API key from Streamlit secrets, else the environment.

    NOTE(review): accessing ``st.secrets`` can raise in some Streamlit
    versions when no secrets file exists — confirm this is safe at import
    time in the deployment environment.
    """
    if hasattr(st, 'secrets') and "DEEPSEEK_API_KEY" in st.secrets:
        return st.secrets["DEEPSEEK_API_KEY"]
    return os.getenv("DEEPSEEK_API_KEY")

# Resolved once at import time; chat() reads these module-level values.
DEEPSEEK_API_KEY = get_api_key()
BASE = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
| 15 |
+
|
| 16 |
+
def chat(messages, model="deepseek-chat", temperature=0.9):
    """POST a chat completion to the DeepSeek API and return the reply text.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.
        model: DeepSeek model identifier.
        temperature: sampling temperature forwarded to the API.

    Raises:
        requests.HTTPError: on any non-2xx response.
    """
    headers = {"Authorization": f"Bearer {DEEPSEEK_API_KEY}"}
    payload = {"model": model, "messages": messages, "temperature": temperature}
    # json= lets requests serialize the body and set Content-Type itself,
    # replacing the manual json.dumps + header.
    r = requests.post(f"{BASE}/chat/completions", headers=headers, json=payload, timeout=60)
    r.raise_for_status()
    return r.json()["choices"][0]["message"]["content"]
|
| 22 |
+
|
| 23 |
+
def generate_scripts(persona, boundaries, content_type, tone, refs, n=6):
    """Generate *n* Reels script briefs (list of dicts) via the LLM.

    Returns the parsed JSON array of {title,hook,beats,voiceover,caption,
    hashtags,cta} objects.

    Raises:
        ValueError: if the model reply contains no JSON array.
        json.JSONDecodeError: if the extracted span is not valid JSON.
    """
    system = (
        "You write Instagram-compliant, suggestive-but-not-explicit Reels briefs. "
        "Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms. "
        "Return ONLY JSON: an array of length N, each with {title,hook,beats,voiceover,caption,hashtags,cta}."
    )
    user = f"""
Persona: {persona}
Boundaries: {boundaries}
Content type: {content_type} | Tone: {tone} | Duration: 15–25s
Reference snippets (inspire, don't copy):
{chr(10).join(f"- {r}" for r in refs)}

N = {n}
JSON array ONLY.
"""
    out = chat([{"role":"system","content":system},{"role":"user","content":user}])
    # Be lenient if the model wraps the JSON with prose, but fail loudly
    # when no array markers are present instead of decoding garbage
    # (find() returning -1 would otherwise slice nonsense).
    start = out.find("[")
    end = out.rfind("]")
    if start == -1 or end < start:
        raise ValueError("Model response contained no JSON array")
    return json.loads(out[start:end+1])
|
| 44 |
+
|
| 45 |
+
def revise_for(prompt_label, draft: dict, guidance: str):
    """Ask the model to revise *draft* per *guidance*; return the revised dict.

    Raises:
        ValueError: if the model reply contains no JSON object.
        json.JSONDecodeError: if the extracted span is not valid JSON.
    """
    system = f"You revise scripts to {prompt_label}. Keep intent; return ONLY JSON with the same schema."
    user = json.dumps({"draft": draft, "guidance": guidance})
    # Lower temperature: revision should stay close to the original draft.
    out = chat([{"role":"system","content":system},{"role":"user","content":user}], temperature=0.6)
    start = out.find("{")
    end = out.rfind("}")
    # Guard against replies with no JSON object at all (find() == -1).
    if start == -1 or end < start:
        raise ValueError("Model response contained no JSON object")
    return json.loads(out[start:end+1])
|
| 52 |
+
|
| 53 |
+
def selective_rewrite(draft: dict, field: str, snippet: str, prompt: str):
    """Rewrite only *snippet* inside *field* of *draft*; return the updated dict.

    Raises:
        ValueError: if the model reply contains no JSON object.
        json.JSONDecodeError: if the extracted span is not valid JSON.
    """
    system = "You rewrite only the targeted snippet inside the specified field. Keep style. Return ONLY JSON."
    user = json.dumps({"field": field, "snippet": snippet, "prompt": prompt, "draft": draft})
    out = chat([{"role":"system","content":system},{"role":"user","content":user}], temperature=0.7)
    start = out.find("{")
    end = out.rfind("}")
    # Guard against replies with no JSON object at all (find() == -1).
    if start == -1 or end < start:
        raise ValueError("Model response contained no JSON object")
    return json.loads(out[start:end+1])
|
models.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from datetime import datetime
|
| 2 |
+
from typing import List, Optional
|
| 3 |
+
from sqlmodel import SQLModel, Field, Column
|
| 4 |
+
from sqlalchemy import JSON
|
| 5 |
+
|
| 6 |
+
class Script(SQLModel, table=True, extend_existing=True):
    """One Reels script draft (AI-generated, manual, or imported reference).

    Also caches aggregate rating statistics so list views can sort/filter
    without joining against the Rating table on every query.
    """
    id: Optional[int] = Field(default=None, primary_key=True)
    creator: str       # persona name the script belongs to
    content_type: str  # e.g. "thirst-trap"
    tone: str
    title: str
    hook: str          # opening line of the script
    beats: List[str] = Field(sa_column=Column(JSON))     # visual beat list, stored as JSON
    voiceover: str
    caption: str
    hashtags: List[str] = Field(sa_column=Column(JSON))  # stored as JSON
    cta: str
    compliance: str = "pass"    # pass | warn | fail
    source: str = "ai"          # ai | manual | import
    is_reference: bool = False  # mark imported examples as references (used by RAG retrieval)

    # --- NEW: cached aggregates from ratings (all optional) ---
    score_overall: Optional[float] = None      # 1..5 (avg)
    score_hook: Optional[float] = None         # 1..5 (avg)
    score_originality: Optional[float] = None  # 1..5 (avg)
    score_style_fit: Optional[float] = None    # 1..5 (avg)
    score_safety: Optional[float] = None       # 1..5 (avg)
    ratings_count: int = 0                     # number of Rating rows aggregated above

    # NOTE(review): datetime.utcnow yields naive timestamps and is deprecated
    # in Python 3.12 — consider timezone-aware timestamps in a migration.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
| 32 |
+
|
| 33 |
+
class Revision(SQLModel, table=True, extend_existing=True):
    """Audit record of a single field rewrite applied to a script."""
    id: Optional[int] = Field(default=None, primary_key=True)
    script_id: int = Field(index=True)  # references Script.id (by convention, no FK constraint)
    label: str   # which revision action/prompt produced this change
    field: str   # name of the script field that was rewritten
    before: str  # field value prior to the rewrite
    after: str   # field value after the rewrite
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 41 |
+
|
| 42 |
+
# NEW: store every rating event so you keep history
|
| 43 |
+
class Rating(SQLModel, table=True, extend_existing=True):
    """A single human rating event for a script; full history is kept."""
    id: Optional[int] = Field(default=None, primary_key=True)
    script_id: int = Field(index=True)  # references Script.id (by convention)
    rater: str = "human"  # optional: store user/email
    overall: float        # 1..5
    hook: Optional[float] = None         # 1..5
    originality: Optional[float] = None  # 1..5
    style_fit: Optional[float] = None    # 1..5
    safety: Optional[float] = None       # 1..5
    notes: Optional[str] = None          # free-form reviewer notes
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 54 |
+
|
| 55 |
+
# RAG Enhancement Models
|
| 56 |
+
class Embedding(SQLModel, table=True, extend_existing=True):
    """Cached sentence-embedding vector for one part of a script (RAG)."""
    id: Optional[int] = Field(default=None, primary_key=True)
    script_id: int = Field(index=True)
    part: str = Field(index=True)  # 'full', 'hook', 'beats', 'caption'
    vector: List[float] = Field(sa_column=Column(JSON))  # embedding stored as a JSON float list
    meta: dict = Field(sa_column=Column(JSON))           # denormalized script attrs (creator, tone, quality, ...)
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 63 |
+
|
| 64 |
+
class AutoScore(SQLModel, table=True, extend_existing=True):
    """LLM-judge scores for a script, parallel to human Rating rows."""
    id: Optional[int] = Field(default=None, primary_key=True)
    script_id: int = Field(index=True)
    overall: float      # judge's overall score
    hook: float
    originality: float
    style_fit: float
    safety: float
    confidence: float = 0.8  # LLM judge confidence
    notes: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
|
| 75 |
+
|
| 76 |
+
class PolicyWeights(SQLModel, table=True, extend_existing=True):
    """Learned retrieval/generation knobs per (persona, content_type).

    The four retrieval weights are combined linearly in
    RAGRetriever.hybrid_retrieve (defaults sum to 1.0); the three
    temperatures drive the multi-temperature generation batches.
    """
    id: Optional[int] = Field(default=None, primary_key=True)
    persona: str = Field(index=True)
    content_type: str = Field(index=True)
    # Retrieval weights (defaults sum to 1.0)
    semantic_weight: float = 0.45
    bm25_weight: float = 0.25
    quality_weight: float = 0.20
    freshness_weight: float = 0.10
    # Generation params: low/mid/high sampling temperatures
    temp_low: float = 0.4
    temp_mid: float = 0.7
    temp_high: float = 0.95
    # Performance tracking for the bandit learner
    success_rate: float = 0.0
    total_generations: int = 0
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
| 93 |
+
|
| 94 |
+
class StyleCard(SQLModel, table=True, extend_existing=True):
    """Curated exemplars and constraints describing a persona's style."""
    id: Optional[int] = Field(default=None, primary_key=True)
    persona: str = Field(index=True)
    content_type: str = Field(index=True)
    exemplar_hooks: List[str] = Field(sa_column=Column(JSON))     # best hook examples
    exemplar_beats: List[str] = Field(sa_column=Column(JSON))     # best beat examples
    exemplar_captions: List[str] = Field(sa_column=Column(JSON))  # best caption examples
    negative_patterns: List[str] = Field(sa_column=Column(JSON))  # patterns generation should avoid
    constraints: dict = Field(sa_column=Column(JSON))             # free-form generation constraints
    updated_at: datetime = Field(default_factory=datetime.utcnow)
|
packages.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
ffmpeg
|
rag_integration.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Integration layer between the existing system and new RAG capabilities
|
| 3 |
+
Shows how to plug the enhanced system into the current workflow
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Dict, Any, Optional
|
| 7 |
+
import json
|
| 8 |
+
from sqlmodel import Session
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from models import Script, Embedding, AutoScore, PolicyWeights
|
| 12 |
+
from db import get_session, init_db
|
| 13 |
+
from deepseek_client import chat, get_api_key
|
| 14 |
+
from rag_retrieval import RAGRetriever
|
| 15 |
+
from auto_scorer import AutoScorer, ScriptReranker
|
| 16 |
+
from bandit_learner import PolicyLearner
|
| 17 |
+
|
| 18 |
+
class EnhancedScriptGenerator:
    """
    Enhanced version of script generation with RAG + policy learning.
    Drop-in replacement for the existing generate_scripts function.

    Pipeline per batch: pick a learned policy arm, retrieve few-shot
    references, generate at three temperatures, strip near-copies of the
    references, persist + auto-score the drafts, rerank by composite score,
    and feed the outcome back into the bandit learner.
    """

    def __init__(self):
        """Wire up the retrieval/scoring/reranking/learning collaborators."""
        self.retriever = RAGRetriever()
        self.scorer = AutoScorer()
        self.reranker = ScriptReranker()
        self.policy_learner = PolicyLearner()

        # Fail fast: everything below requires the DeepSeek API.
        if not get_api_key():
            raise ValueError("DeepSeek API key not found!")

    def generate_scripts_enhanced(self,
                                  persona: str,
                                  boundaries: str,
                                  content_type: str,
                                  tone: str,
                                  manual_refs: Optional[List[str]] = None,
                                  n: int = 6) -> List[Dict]:
        """
        Enhanced script generation with:
        1. RAG-based reference selection
        2. Policy-optimized parameters
        3. Auto-scoring and reranking
        4. Online learning feedback

        Args:
            persona: creator persona to write for.
            boundaries: content boundaries, passed verbatim to the prompt.
            content_type: content bucket (e.g. "thirst-trap").
            tone: tone descriptor, passed verbatim to the prompt.
            manual_refs: optional extra reference snippets from the caller.
            n: number of drafts to produce.

        Returns:
            Drafts in ranked order; each dict carries the script fields plus
            _enhanced_score, _script_id, and _compliance metadata.
        """

        print(f"🤖 Enhanced generation: {persona} × {content_type} × {n} scripts")

        # Step 1: Get optimized policy for this persona/content_type
        policy_arm = self.policy_learner.get_optimized_policy(persona, content_type)

        # Step 2: Build dynamic few-shot pack using RAG
        query_context = f"{persona} {content_type} {tone}"
        few_shot_pack = self.retriever.build_dynamic_few_shot_pack(
            persona=persona,
            content_type=content_type,
            query_context=query_context
        )

        # Step 3: Combine RAG refs with manual refs
        rag_refs = (
            few_shot_pack.get('best_hooks', []) +
            few_shot_pack.get('best_beats', []) +
            few_shot_pack.get('best_captions', [])
        )
        all_refs = (manual_refs or []) + rag_refs

        print(f"📚 Using {len(rag_refs)} RAG refs + {len(manual_refs or [])} manual refs")

        # Step 4: Enhanced generation with policy-optimized parameters
        drafts = self._generate_with_policy(
            persona=persona,
            boundaries=boundaries,
            content_type=content_type,
            tone=tone,
            refs=all_refs,
            policy_arm=policy_arm,
            n=n,
            few_shot_pack=few_shot_pack
        )

        # Step 5: Anti-copying detection and cleanup
        print("🛡️ Checking for similarity to reference content...")

        # Only RAG-sourced snippets are checked for copying; manual refs are
        # the caller's own material.
        reference_texts = rag_refs
        cleaned_drafts = []

        for draft in drafts:
            # Check for copying against the reference snippets
            detection_results = self.retriever.detect_copying(
                generated_content=draft,
                reference_texts=reference_texts,
                similarity_threshold=0.92
            )

            if detection_results['is_copying']:
                print(f"⚠️ Anti-copy triggered for draft: {draft.get('title', 'Untitled')[:30]}")
                print(f"   Max similarity: {detection_results['max_similarity']:.3f}")

                # Auto-rewrite content that is too close to a reference
                cleaned_draft = self.retriever.auto_rewrite_similar_content(
                    generated_content=draft,
                    detection_results=detection_results
                )
                cleaned_drafts.append(cleaned_draft)
            else:
                cleaned_drafts.append(draft)

        # Step 6: Persist drafts, then auto-score each one
        script_ids = self._save_drafts_to_db(cleaned_drafts, persona, content_type, tone)
        auto_scores = [self.scorer.score_and_store(sid) for sid in script_ids]

        print(f"📊 Auto-scored {len(auto_scores)} drafts")

        # Step 7: Rerank by composite score
        ranked_script_ids = self.reranker.rerank_scripts(script_ids)

        # Step 8: Policy learning feedback
        self.policy_learner.learn_from_generation_batch(
            persona=persona,
            content_type=content_type,
            generated_script_ids=script_ids,
            selected_arm=policy_arm
        )

        # Return drafts in ranked order with scores
        return self._format_enhanced_results(ranked_script_ids, cleaned_drafts)

    def _generate_with_policy(self,
                              persona: str,
                              boundaries: str,
                              content_type: str,
                              tone: str,
                              refs: List[str],
                              policy_arm: Any,  # BanditArm
                              n: int,
                              few_shot_pack: Dict) -> List[Dict]:
        """Generate scripts across the policy arm's low/mid/high temperatures."""

        # Enhanced system prompt with few-shot pack context
        system = f"""You write Instagram-compliant, suggestive-but-not-explicit Reels briefs.

STYLE CONTEXT: {few_shot_pack.get('style_card', '')}

BEST PATTERNS TO EMULATE:
Hooks: {json.dumps(few_shot_pack.get('best_hooks', []))}
Beats: {json.dumps(few_shot_pack.get('best_beats', []))}
Captions: {json.dumps(few_shot_pack.get('best_captions', []))}

AVOID THESE PATTERNS: {json.dumps(few_shot_pack.get('negative_patterns', []))}

Use tight hooks, concrete visual beats, clear CTAs. Avoid explicit sexual terms.
Return ONLY JSON: an array of length {n}, each with {{title,hook,beats,voiceover,caption,hashtags,cta}}.
"""

        # Only the top 8 reference snippets go into the prompt.
        # Fix: this used to be an inline "#" comment INSIDE the triple-quoted
        # f-string, so the literal text " # Limit to top 8 refs" leaked into
        # the prompt sent to the model.
        ref_lines = chr(10).join(f"- {r}" for r in refs[:8])
        user = f"""
Persona: {persona}
Boundaries: {boundaries}
Content type: {content_type} | Tone: {tone}
Constraints: {json.dumps(few_shot_pack.get('constraints', {}))}

Reference snippets (inspire, don't copy):
{ref_lines}

Generate {n} unique variations. JSON array ONLY.
"""

        # Generate with multiple temperatures (policy-optimized)
        variants = []
        temps = [policy_arm.temp_low, policy_arm.temp_mid, policy_arm.temp_high]
        scripts_per_temp = max(1, n // len(temps))

        for i, temp in enumerate(temps):
            batch_size = scripts_per_temp
            if i == len(temps) - 1:  # Last batch gets remainder
                batch_size = n - len(variants)

            if batch_size <= 0:
                break

            try:
                out = chat([
                    {"role": "system", "content": system},
                    {"role": "user", "content": user.replace(f"Generate {n}", f"Generate {batch_size}")}
                ], temperature=temp)

                # Extract the JSON array, tolerating prose around it
                start = out.find("[")
                end = out.rfind("]")
                if start >= 0 and end > start:
                    batch_variants = json.loads(out[start:end+1])
                    variants.extend(batch_variants[:batch_size])
                    print(f"✨ Generated {len(batch_variants)} scripts at temp={temp}")

            except Exception as e:
                # Best-effort: a failed temperature batch is skipped, not fatal
                print(f"❌ Generation failed at temp={temp}: {e}")

        return variants[:n]  # Ensure we don't exceed requested count

    def _save_drafts_to_db(self,
                           drafts: List[Dict],
                           persona: str,
                           content_type: str,
                           tone: str) -> List[int]:
        """Persist drafts (with compliance + embeddings) and return their ids."""

        # Hoisted out of the per-draft loop: import once, reuse for every draft.
        from compliance import score_script, blob_from

        script_ids = []

        with get_session() as ses:
            for draft in drafts:
                try:
                    # Calculate basic compliance for the combined text
                    content_blob = blob_from(draft)
                    compliance_level, _ = score_script(content_blob)

                    script = Script(
                        creator=persona,
                        content_type=content_type,
                        tone=tone,
                        title=draft.get("title", "Generated Script"),
                        hook=draft.get("hook", ""),
                        beats=draft.get("beats", []),
                        voiceover=draft.get("voiceover", ""),
                        caption=draft.get("caption", ""),
                        hashtags=draft.get("hashtags", []),
                        cta=draft.get("cta", ""),
                        compliance=compliance_level,
                        source="ai"
                    )

                    ses.add(script)
                    ses.commit()  # commit per draft so script.id is populated
                    ses.refresh(script)

                    script_ids.append(script.id)

                    # Generate embeddings for the new script; committed below
                    embeddings = self.retriever.generate_embeddings(script)
                    for embedding in embeddings:
                        ses.add(embedding)

                except Exception as e:
                    # Best-effort batch: a bad draft is skipped, not fatal
                    print(f"❌ Failed to save draft: {e}")
                    continue

            ses.commit()

        return script_ids

    def _format_enhanced_results(self,
                                 ranked_script_ids: List[tuple],
                                 original_drafts: List[Dict]) -> List[Dict]:
        """Re-read ranked scripts from the DB and format them for callers.

        ``original_drafts`` is accepted for interface compatibility, but the
        authoritative content is loaded by script id. (Fix: removed a
        draft-lookup dict that was built here and never used.)
        """

        results = []

        with get_session() as ses:
            for script_id, composite_score in ranked_script_ids:
                script = ses.get(Script, script_id)
                if script:
                    # Convert back to the expected draft format
                    result = {
                        "title": script.title,
                        "hook": script.hook,
                        "beats": script.beats,
                        "voiceover": script.voiceover,
                        "caption": script.caption,
                        "hashtags": script.hashtags,
                        "cta": script.cta,
                        # Enhanced metadata
                        "_enhanced_score": round(composite_score, 3),
                        "_script_id": script_id,
                        "_compliance": script.compliance
                    }
                    results.append(result)

        return results
|
| 288 |
+
|
| 289 |
+
# Backward compatibility wrapper
|
| 290 |
+
def generate_scripts_rag(persona: str,
                         boundaries: str,
                         content_type: str,
                         tone: str,
                         refs: List[str],
                         n: int = 6) -> List[Dict]:
    """Compatibility shim around :class:`EnhancedScriptGenerator`.

    Mirrors the signature of the legacy ``generate_scripts`` helper so
    existing call sites keep working, while routing generation through the
    RAG-enhanced pipeline. ``refs`` are forwarded as manual references.
    """
    return EnhancedScriptGenerator().generate_scripts_enhanced(
        persona=persona,
        boundaries=boundaries,
        content_type=content_type,
        tone=tone,
        manual_refs=refs,
        n=n,
    )
|
| 309 |
+
|
| 310 |
+
def setup_rag_system():
    """One-time bootstrap for the RAG pipeline.

    Creates any missing tables, embeds every existing script, and runs the
    auto-scorer over roughly the last week of drafts, printing progress at
    each step.
    """
    print("🔧 Setting up RAG system...")

    init_db()  # creates the new RAG tables alongside the existing ones
    print("✅ Database initialized")

    # Embed all existing scripts so retrieval has something to search.
    from rag_retrieval import index_all_scripts
    index_all_scripts()
    print("✅ Existing scripts indexed")

    # Run the LLM judge over the last 7 days of scripts.
    week_of_scores = AutoScorer().batch_score_recent(hours=24*7)
    print(f"✅ Auto-scored {len(week_of_scores)} recent scripts")

    print("🎉 RAG system setup complete!")
|
| 329 |
+
|
| 330 |
+
if __name__ == "__main__":
    # Demo entry point: initializes the RAG system, then runs one small
    # generation batch and prints a summary of the ranked results.
    setup_rag_system()

    # Test generation with a sample persona and a couple of manual refs
    generator = EnhancedScriptGenerator()
    results = generator.generate_scripts_enhanced(
        persona="Anya",
        boundaries="Instagram-safe; suggestive but not explicit",
        content_type="thirst-trap",
        tone="playful, flirty",
        manual_refs=["Just a quick workout session", "Getting ready for the day"],
        n=3
    )

    # Print each draft's title plus the metadata attached by the pipeline
    print(f"\n🎬 Generated {len(results)} enhanced scripts:")
    for i, script in enumerate(results, 1):
        score = script.get('_enhanced_score', 0)
        compliance = script.get('_compliance', 'unknown')
        print(f"{i}. {script['title']} (score: {score}, compliance: {compliance})")
        print(f"   Hook: {script['hook'][:60]}...")
|
rag_retrieval.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Enhanced RAG retrieval system for AI Script Studio
|
| 3 |
+
Extends the existing hybrid reference system with semantic search and policy learning
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import math
|
| 8 |
+
from typing import List, Dict, Tuple, Optional
|
| 9 |
+
from sentence_transformers import SentenceTransformer
|
| 10 |
+
from sqlmodel import Session, select
|
| 11 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 12 |
+
from sklearn.metrics.pairwise import cosine_similarity
|
| 13 |
+
import json
|
| 14 |
+
from datetime import datetime, timedelta
|
| 15 |
+
|
| 16 |
+
from models import Script, Embedding, AutoScore, PolicyWeights, StyleCard
|
| 17 |
+
from db import get_session
|
| 18 |
+
|
| 19 |
+
class RAGRetriever:
|
| 20 |
+
def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
|
| 21 |
+
"""Initialize with lightweight but effective embedding model"""
|
| 22 |
+
self.encoder = SentenceTransformer(model_name)
|
| 23 |
+
self.tfidf = TfidfVectorizer(max_features=1000, stop_words='english')
|
| 24 |
+
|
| 25 |
+
def generate_embeddings(self, script: Script) -> List[Embedding]:
|
| 26 |
+
"""Generate embeddings for different parts of a script"""
|
| 27 |
+
parts = {
|
| 28 |
+
'full': self._get_full_text(script),
|
| 29 |
+
'hook': script.hook or '',
|
| 30 |
+
'beats': ' '.join(script.beats or []),
|
| 31 |
+
'caption': script.caption or ''
|
| 32 |
+
}
|
| 33 |
+
|
| 34 |
+
embeddings = []
|
| 35 |
+
for part, text in parts.items():
|
| 36 |
+
if text.strip(): # Only embed non-empty parts
|
| 37 |
+
vector = self.encoder.encode(text).tolist()
|
| 38 |
+
meta = {
|
| 39 |
+
'creator': script.creator,
|
| 40 |
+
'content_type': script.content_type,
|
| 41 |
+
'tone': script.tone,
|
| 42 |
+
'quality_score': script.score_overall or 0.0,
|
| 43 |
+
'compliance': script.compliance
|
| 44 |
+
}
|
| 45 |
+
embeddings.append(Embedding(
|
| 46 |
+
script_id=script.id,
|
| 47 |
+
part=part,
|
| 48 |
+
vector=vector,
|
| 49 |
+
meta=meta
|
| 50 |
+
))
|
| 51 |
+
return embeddings
|
| 52 |
+
|
| 53 |
+
def _get_full_text(self, script: Script) -> str:
|
| 54 |
+
"""Combine all script parts into full text"""
|
| 55 |
+
parts = [
|
| 56 |
+
script.title,
|
| 57 |
+
script.hook or '',
|
| 58 |
+
' '.join(script.beats or []),
|
| 59 |
+
script.voiceover or '',
|
| 60 |
+
script.caption or '',
|
| 61 |
+
script.cta or ''
|
| 62 |
+
]
|
| 63 |
+
return ' '.join(p for p in parts if p.strip())
|
| 64 |
+
|
| 65 |
+
    def hybrid_retrieve(self,
                        query_text: str,
                        persona: str,
                        content_type: str,
                        k: int = 6,
                        global_quality_mean: float = 4.2,
                        shrinkage_alpha: float = 10.0,
                        freshness_tau_days: float = 28.0) -> List[Dict]:
        """
        Production-grade hybrid retrieval with proper score normalization:
        - Semantic similarity (cosine normalized to [0,1])
        - BM25/TF-IDF similarity (min-max normalized per query)
        - Quality scores (Bayesian shrinkage)
        - Freshness boost (exponential decay)
        - Policy-learned weights

        Args:
            query_text: free text describing what to retrieve.
            persona: restrict candidates to this creator.
            content_type: restrict candidates to this content bucket.
            k: maximum number of results to return.
            global_quality_mean: prior mean used by the shrinkage estimator.
            shrinkage_alpha: pseudo-count controlling shrinkage strength.
            freshness_tau_days: e-folding time of the freshness decay.

        Returns:
            Up to k dicts {'script', 'score', 'component_scores', '_debug'},
            sorted by descending combined score. Candidates are reference
            scripts only, excluding compliance failures.
        """

        # Get policy weights for this persona/content_type
        weights = self._get_policy_weights(persona, content_type)

        with get_session() as ses:
            # Candidate pool: reference scripts in this bucket that did not
            # fail compliance.
            scripts = list(ses.exec(
                select(Script).where(
                    Script.creator == persona,
                    Script.content_type == content_type,
                    Script.is_reference == True,
                    Script.compliance != "fail"
                )
            ))

            if not scripts:
                return []

            # Get 'full'-text embeddings for semantic similarity
            embeddings = list(ses.exec(
                select(Embedding).join(Script, Embedding.script_id == Script.id).where(
                    Embedding.part == 'full',
                    Script.creator == persona,
                    Script.content_type == content_type,
                    Script.is_reference == True,
                    Script.compliance != "fail"
                )
            ))

            # Pre-calculate all raw scores for normalization
            raw_scores = []
            query_embedding = self.encoder.encode(query_text)
            now = datetime.utcnow()

            for script in scripts:
                # Find matching embedding (linear scan over the pool)
                script_embedding = next(
                    (e for e in embeddings if e.script_id == script.id),
                    None
                )

                # 1. Raw semantic similarity (cosine returns [-1,1])
                if script_embedding:
                    raw_cosine = cosine_similarity(
                        [query_embedding],
                        [script_embedding.vector]
                    )[0][0]
                else:
                    raw_cosine = -1.0  # Worst case for missing embeddings

                # 2. Raw BM25/TF-IDF similarity (refit per pair, so only
                # comparable within this query's candidate set)
                script_text = self._get_full_text(script)
                raw_bm25 = self._calculate_tfidf_similarity(query_text, script_text)

                raw_scores.append({
                    'script': script,
                    'raw_cosine': raw_cosine,
                    'raw_bm25': raw_bm25
                })

            # Normalize BM25 scores (min-max normalization across this query's candidates)
            bm25_scores = [s['raw_bm25'] for s in raw_scores]
            min_bm25 = min(bm25_scores)
            max_bm25 = max(bm25_scores)
            bm25_range = max_bm25 - min_bm25 + 1e-9  # Avoid division by zero

            # Calculate final normalized scores
            results = []

            for raw_score in raw_scores:
                script = raw_score['script']
                scores = {}

                # 1. Semantic similarity: normalize cosine [-1,1] → [0,1]
                scores['semantic'] = (raw_score['raw_cosine'] + 1.0) / 2.0

                # 2. BM25: min-max normalize within this query's candidate set
                scores['bm25'] = (raw_score['raw_bm25'] - min_bm25) / bm25_range

                # 3. Quality: Bayesian shrinkage toward global mean
                n_ratings = script.ratings_count or 0
                local_quality = script.score_overall or global_quality_mean

                # Shrinkage: blend local mean with global mean based on sample size
                shrunk_quality = (
                    (n_ratings / (n_ratings + shrinkage_alpha)) * local_quality +
                    (shrinkage_alpha / (n_ratings + shrinkage_alpha)) * global_quality_mean
                )

                # Normalize to [0,1] (assuming 1-5 rating scale)
                scores['quality'] = max(0.0, min(1.0, (shrunk_quality - 1) / 4))

                # 4. Freshness: exponential decay (smoother than linear)
                days_old = max(0, (now - script.created_at).days)
                scores['freshness'] = math.exp(-days_old / freshness_tau_days)

                # Combined score: linear blend using the learned policy weights
                combined_score = (
                    weights.semantic_weight * scores['semantic'] +
                    weights.bm25_weight * scores['bm25'] +
                    weights.quality_weight * scores['quality'] +
                    weights.freshness_weight * scores['freshness']
                )

                results.append({
                    'script': script,
                    'score': combined_score,
                    'component_scores': scores,
                    # Debug info for inspecting the ranking decision
                    '_debug': {
                        'n_ratings': n_ratings,
                        'raw_quality': local_quality,
                        'shrunk_quality': shrunk_quality,
                        'days_old': days_old
                    }
                })

            # Sort by combined score and return top k.
            # NOTE(review): the returned Script instances are used after the
            # session closes — fine for attributes already loaded here, but
            # confirm no lazy-loading happens downstream.
            results.sort(key=lambda x: x['score'], reverse=True)
            return results[:k]
|
| 201 |
+
|
| 202 |
+
def _calculate_tfidf_similarity(self, query: str, doc: str) -> float:
|
| 203 |
+
"""Calculate TF-IDF similarity between query and document"""
|
| 204 |
+
try:
|
| 205 |
+
tfidf_matrix = self.tfidf.fit_transform([query, doc])
|
| 206 |
+
similarity = cosine_similarity(tfidf_matrix[0:1], tfidf_matrix[1:2])[0][0]
|
| 207 |
+
return float(similarity)
|
| 208 |
+
except:
|
| 209 |
+
return 0.0
|
| 210 |
+
|
| 211 |
+
def _get_policy_weights(self, persona: str, content_type: str) -> PolicyWeights:
    """Return the learned policy weights for a (persona, content_type) pair.

    When no row exists yet, a default PolicyWeights record is created,
    persisted, and returned.
    """
    with get_session() as ses:
        stmt = select(PolicyWeights).where(
            PolicyWeights.persona == persona,
            PolicyWeights.content_type == content_type
        )
        row = ses.exec(stmt).first()

        if row is None:
            # First request for this segment: seed it with model defaults.
            row = PolicyWeights(persona=persona, content_type=content_type)
            ses.add(row)
            ses.commit()
            ses.refresh(row)

        return row
|
| 232 |
+
|
| 233 |
+
def build_dynamic_few_shot_pack(self,
                                persona: str,
                                content_type: str,
                                query_context: str = "") -> Dict:
    """Build a dynamic few-shot examples pack optimized for this request.

    Retrieves the best reference scripts via hybrid retrieval and distills
    them into a compact prompt pack (hooks, beats, captions, constraints).

    Args:
        persona: Persona the examples should match.
        content_type: Content category the examples should match.
        query_context: Optional free-text context; defaults to
            "<persona> <content_type>" when empty.

    Returns:
        Dict with keys: style_card, examples, best_hooks, best_beats,
        best_captions, constraints, negative_patterns. The same key set is
        returned whether or not references were found, so callers can index
        fields without special-casing the empty result.
    """

    # Get best references via hybrid retrieval
    references = self.hybrid_retrieve(
        query_text=query_context or f"{persona} {content_type}",
        persona=persona,
        content_type=content_type,
        k=6
    )

    if not references:
        # Bug fix: previously returned only {style_card, examples,
        # constraints}, so callers indexing e.g. "best_hooks" crashed.
        return {
            "style_card": "",
            "examples": [],
            "best_hooks": [],
            "best_beats": [],
            "best_captions": [],
            "constraints": {},
            "negative_patterns": []
        }

    # Extract best examples by type
    best_hooks = []
    best_beats = []
    best_captions = []

    for ref in references[:4]:  # Use top 4 references
        script = ref['script']
        if script.hook and len(best_hooks) < 2:
            best_hooks.append(script.hook)
        if script.beats and len(best_beats) < 1:
            best_beats.extend(script.beats[:2])  # First 2 beats
        if script.caption and len(best_captions) < 1:
            best_captions.append(script.caption)

    # Get or create style card (only negative_patterns is used below)
    style_card = self._get_style_card(persona, content_type)

    return {
        "style_card": f"Persona: {persona} | Content: {content_type}",
        "examples": [],
        "best_hooks": best_hooks[:2],
        "best_beats": best_beats[:3],
        "best_captions": best_captions[:1],
        "constraints": {
            "max_length": "15-25 seconds",
            "compliance": "Instagram-safe",
            "tone": references[0]['script'].tone if references else "playful"
        },
        "negative_patterns": style_card.negative_patterns if style_card else []
    }
|
| 279 |
+
|
| 280 |
+
def _get_style_card(self, persona: str, content_type: str) -> Optional[StyleCard]:
    """Look up the persisted style card for this persona/content pairing.

    Returns None when no card exists; never creates a default row.
    """
    query = select(StyleCard).where(
        StyleCard.persona == persona,
        StyleCard.content_type == content_type
    )
    with get_session() as ses:
        return ses.exec(query).first()
|
| 289 |
+
|
| 290 |
+
def detect_copying(self,
                   generated_content: Dict,
                   reference_texts: List[str],
                   similarity_threshold: float = 0.92) -> Dict:
    """
    Detect if generated content is too similar to reference material.
    Returns detection results with flagged content and similarity scores.

    Args:
        generated_content: Dict with keys like 'hook', 'caption', 'beats', etc.
        reference_texts: List of reference text snippets to compare against
        similarity_threshold: Cosine similarity threshold (0.92 recommended)

    Returns:
        Dict with detection results and recommendations
    """

    report = {
        'is_copying': False,
        'flagged_fields': [],
        'max_similarity': 0.0,
        'rewrite_recommendations': []
    }

    # Nothing to compare against — trivially not copying.
    if not reference_texts:
        return report

    # Embed every reference once up front; each generated field is then
    # compared against the whole batch.
    ref_vecs = self.encoder.encode(reference_texts)

    # Only free-text fields are checked for copying.
    for field in ('hook', 'caption', 'cta'):
        value = generated_content.get(field)
        if not value:
            continue

        text = str(value)

        # Skip very short texts (less than 10 characters)
        if len(text.strip()) < 10:
            continue

        gen_vec = self.encoder.encode([text])
        sims = cosine_similarity(gen_vec, ref_vecs)[0]
        top_sim = float(np.max(sims))

        # Track the highest similarity seen across all fields.
        report['max_similarity'] = max(report['max_similarity'], top_sim)

        if top_sim < similarity_threshold:
            continue

        report['is_copying'] = True
        report['flagged_fields'].append({
            'field': field,
            'text': text,
            'similarity': top_sim,
            'similar_reference': reference_texts[int(np.argmax(sims))]
        })

        # Tiered rewrite recommendation based on how close the match is.
        if top_sim >= 0.95:
            urgency, action = "CRITICAL", "Completely rewrite this content"
        elif top_sim >= 0.92:
            urgency, action = "HIGH", "Significantly rephrase this content"
        else:
            # Reachable only when the caller passes a threshold below 0.92.
            urgency, action = "MEDIUM", "Minor rewording may be needed"

        report['rewrite_recommendations'].append({
            'field': field,
            'urgency': urgency,
            'action': action,
            'original': text
        })

    return report
|
| 370 |
+
|
| 371 |
+
def auto_rewrite_similar_content(self,
                                 generated_content: Dict,
                                 detection_results: Dict,
                                 rewrite_instruction: str = "Rewrite to be more original while keeping the same intent") -> Dict:
    """
    Automatically rewrite content that's too similar to references.

    Current implementation is a placeholder: it prefixes each flagged
    field with "[NEEDS_REWRITE]" and logs the hit. A production version
    would call the LLM with `rewrite_instruction` to produce the rewrite.

    Args:
        generated_content: The original generated content
        detection_results: Results from detect_copying()
        rewrite_instruction: Instructions for how to rewrite

    Returns:
        Rewritten content dict (the original dict is not mutated; a
        shallow copy is returned when any field is flagged).
    """

    if not detection_results['is_copying']:
        return generated_content

    rewritten_content = generated_content.copy()

    for flag in detection_results['flagged_fields']:
        field = flag['field']
        original_text = flag['text']

        # Mark the field so downstream code (or a human) knows it must be
        # rewritten. In production, call your LLM API here instead.
        # (Removed: an f-string rewrite prompt that was built each
        # iteration but never used.)
        rewritten_content[field] = f"[NEEDS_REWRITE] {original_text}"

        # Log the issue for observability.
        print(f"🚨 Anti-copy detection: {field} flagged (similarity: {flag['similarity']:.3f})")
        print(f" Original: {original_text[:60]}...")
        print(f" Similar to: {flag['similar_reference'][:60]}...")

    return rewritten_content
|
| 418 |
+
|
| 419 |
+
def index_all_scripts():
    """Backfill embeddings for every stored script that has none yet."""
    retriever = RAGRetriever()

    with get_session() as ses:
        all_scripts = list(ses.exec(select(Script)))

        for item in all_scripts:
            # Skip scripts that already have embeddings on record.
            already_indexed = ses.exec(
                select(Embedding).where(Embedding.script_id == item.id)
            ).first()
            if already_indexed:
                continue

            for emb in retriever.generate_embeddings(item):
                ses.add(emb)

            print(f"Generated embeddings for script {item.id}")

        # Single commit at the end persists all new embeddings at once.
        ses.commit()

    print(f"Indexing complete! Processed {len(all_scripts)} scripts.")
|
| 441 |
+
|
| 442 |
+
if __name__ == "__main__":
    # One-off CLI entry point: run this module directly to backfill
    # embeddings for all existing scripts.
    index_all_scripts()
|
requirements.txt
CHANGED
|
@@ -1,3 +1,16 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit>=1.37.1
|
| 2 |
+
sqlmodel>=0.0.16
|
| 3 |
+
pydantic>=1.10.15
|
| 4 |
+
python-dotenv>=1.0.1
|
| 5 |
+
requests>=2.32.3
|
| 6 |
+
sqlalchemy>=2.0.0
|
| 7 |
+
|
| 8 |
+
# RAG Enhancement Dependencies
|
| 9 |
+
sentence-transformers>=2.2.2
|
| 10 |
+
scikit-learn>=1.3.0
|
| 11 |
+
numpy>=1.24.0
|
| 12 |
+
faiss-cpu>=1.7.4
|
| 13 |
+
|
| 14 |
+
# Additional dependencies for deployment
|
| 15 |
+
torch>=2.0.0
|
| 16 |
+
transformers>=4.30.0
|