Spaces:
Sleeping
Sleeping
File size: 7,401 Bytes
cf49347 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 |
import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv
# XGBoost regressor class, matching the model type trained in model-prep.py
from xgboost import XGBRegressor # Use Regressor as per model-prep
import pickle
from importlib.util import spec_from_file_location
import sys
# Since we are loading artifacts, we don't strictly need model-prep.py logic anymore.
# But keeping basic imports is fine.
# Pull settings (e.g. GEMINI_API_KEY, read later via os.getenv) from a local
# .env file into the process environment.
load_dotenv()

# --- GLOBAL STATE ---
# Every artifact handle starts as None and is filled in by initialize_app().
MODEL = VECTORIZER = KNOWLEDGE_DF = ST_MODEL = None
def initialize_app():
    """Load every pre-computed artifact and publish it through the module globals.

    Reads, relative to the current working directory:
      * ``tiktok_knowledge_base.parquet``  -> ``KNOWLEDGE_DF``
      * ``viral_model.json`` (XGBoost regressor) -> ``MODEL``
      * ``tfidf_vectorizer.pkl`` (pickled TF-IDF vectorizer) -> ``VECTORIZER``
    and loads the ``all-MiniLM-L6-v2`` SentenceTransformer -> ``ST_MODEL``
    (on Apple "mps" when available, otherwise "cpu").

    The globals are assigned only after every load has succeeded, so a
    failure part-way through leaves them unchanged.

    Raises:
        FileNotFoundError: if any of the three artifact files is missing
            (they are produced by running ``model-prep.py``).
    """
    global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL
    print("⏳ initializing app: Loading pre-computed artifacts...")

    parquet_path = 'tiktok_knowledge_base.parquet'
    model_path = "viral_model.json"
    vectorizer_path = "tfidf_vectorizer.pkl"

    # Check every artifact up front so a missing file yields one clear
    # message instead of an opaque XGBoost/pickle error after the (slow)
    # parquet load has already run.
    for path in (parquet_path, model_path, vectorizer_path):
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required file '{path}' not found! Run model-prep.py first."
            )

    # 1. Knowledge base (descriptions, views, embeddings).
    print(f"📂 Loading data from {parquet_path}...")
    knowledge_df = pd.read_parquet(parquet_path)

    # 2. View-prediction model.
    print("🧠 Loading XGBoost Model...")
    model = XGBRegressor()
    model.load_model(model_path)

    # 3. TF-IDF vectorizer.
    print("🤖 Loading TF-IDF Vectorizer...")
    # SECURITY: pickle.load can execute arbitrary code. This file is meant to
    # be produced locally by model-prep.py; never point it at untrusted data.
    with open(vectorizer_path, "rb") as f:
        tfidf = pickle.load(f)

    # 4. Sentence embedding model.
    print("🌐 Loading SentenceTransformer...")
    import torch  # deferred: only needed here to choose the encoder device
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    MODEL = model
    VECTORIZER = tfidf
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("✅ App initialized (Inference Mode)!")
# Gemini models tried in order; a later entry is used only when the request
# to every earlier one raised.
_GEMINI_MODELS = ('gemini-2.5-flash-lite', 'gemini-1.5-flash')


def _predict_views(text):
    """Predict the view count for *text* with the loaded XGBoost model.

    The feature row is the dense TF-IDF vector of *text* followed by fixed
    metadata: 15 s duration, posted at 18:00, weekday 0, and the number of
    '#' characters in *text*. The model output is inverted with ``expm1``
    (i.e. treated as log1p(views)); negative results are clamped to 0.
    """
    text_vec = VECTORIZER.transform([text]).toarray()
    meta_vec = np.array([[15, 18, 0, text.count('#')]])
    log_views = MODEL.predict(np.hstack((text_vec, meta_vec)))[0]
    # A regressor can output log values below 0, which expm1 would turn into
    # a negative view count.
    return max(0, int(np.expm1(log_views)))


def _top_similar_descriptions(user_input, k=3):
    """Return up to *k* top-quartile descriptions most similar to *user_input*.

    Only KNOWLEDGE_DF rows whose ``views`` exceed the 75th percentile are
    searched, ranked by cosine similarity of ST_MODEL embeddings (most similar
    first). Returns an empty list when no row passes the filter.
    """
    views = KNOWLEDGE_DF['views']
    high_perf_df = KNOWLEDGE_DF[views > views.quantile(0.75)]
    if high_perf_df.empty:
        # np.stack would raise on an empty sequence.
        return []
    user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)[0]
    top_indices = similarities.argsort()[::-1][:k]
    return high_perf_df.iloc[top_indices]['description'].tolist()


def _build_prompt(user_input, examples):
    """Build the Gemini rewrite prompt from the draft and reference examples."""
    if examples:
        numbered = "\n".join(f"{i}. {desc}" for i, desc in enumerate(examples, start=1))
        context = (
            f"Here are {len(examples)} successful, viral videos that are similar to my topic:\n"
            f"{numbered}\n"
        )
    else:
        context = ""
    # NOTE(review): "and full video plan" conflicts with "Return ONLY the new
    # description", and the reply is re-scored as a description — confirm
    # which output is intended.
    return (
        "\nYou are a TikTok Virality Expert.\n"
        f'My Draft Description: "{user_input}"\n'
        f"{context}"
        "Task: Rewrite my draft description to make it go viral and full video plan.\n"
        "Use the slang, hashtag style, and structure of the successful examples provided.\n"
        "Keep it under 60 words plus hashtags. Return ONLY the new description.\n"
    )


def _generate_improved_idea(prompt):
    """Send *prompt* to Gemini and return the stripped reply text.

    Models in _GEMINI_MODELS are tried in order; the next one is used only if
    the previous request raised. If all fail, the first (preferred) model's
    error is re-raised.
    """
    errors = []
    for model_name in _GEMINI_MODELS:
        try:
            response = genai.GenerativeModel(model_name).generate_content(prompt)
            return response.text.strip()
        except Exception as err:  # broad on purpose: any failure triggers the fallback
            errors.append(err)
    raise errors[0]


def predict_and_optimize(user_input):
    """Score a draft description, rewrite it with Gemini, and re-score the rewrite.

    Requires initialize_app() to have populated MODEL, VECTORIZER,
    KNOWLEDGE_DF and ST_MODEL.

    Args:
        user_input: Draft video description from the UI textbox.

    Returns:
        A 5-tuple of strings matching the Gradio outputs, in order:
        (original predicted views, similar viral descriptions, improved
        description or error message, new predicted views, uplift).
        Unavailable fields are "" (empty input) or "N/A" (missing API key or
        Gemini failure).
    """
    if not user_input or not user_input.strip():
        return "Please enter a video description.", "", "", "", ""

    # --- 1. INITIAL PREDICTION ---
    initial_views = _predict_views(user_input)
    initial_str = f"{initial_views:,}"

    # --- 2. VECTOR SEARCH ---
    top_videos = _top_similar_descriptions(user_input)
    similar_videos_str = "\n\n".join(
        f"{i}. {v}" for i, v in enumerate(top_videos, start=1)
    )

    # --- 3. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return initial_str, similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
    genai.configure(api_key=api_key)
    try:
        improved_idea = _generate_improved_idea(_build_prompt(user_input, top_videos))
    except Exception as e:
        return initial_str, similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"

    # --- 4. RE-SCORING ---
    new_views = _predict_views(improved_idea)
    if initial_views > 0:
        uplift_pct = ((new_views - initial_views) / initial_views) * 100
        uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
    elif new_views > 0:
        # A percentage change from a zero baseline is undefined; report the
        # absolute gain instead of dividing by zero.
        uplift_str = f"+{new_views:,} views"
    else:
        uplift_str = "No significant uplift"
    return initial_str, similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
# --- GRADIO UI ---
# Left column: description input and the two action buttons. Right column:
# prediction and optimization outputs. Both buttons route to
# predict_and_optimize, whose 5-tuple maps onto `prediction_outputs` in order.
# NOTE(review): container nesting below is reconstructed — confirm against
# the deployed layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Viral Content Optimizer")
    gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")
    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3,
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚡", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📊 Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
            with gr.Accordion("🔍 Similar Viral Videos (Reference)", open=False):
                similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    # Order must match the tuple returned by predict_and_optimize.
    prediction_outputs = [initial_views, similar_videos, improved_text, new_views, uplift]

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text],
        outputs=prediction_outputs,
    )

    # Demo Button Logic: 1. Fill Text -> 2. Run Prediction
    demo_text = "POV: You realize you forgot to turn off your mic during the all-hands meeting 😅 #fail #fyp #corporate"
    demo_btn.click(
        fn=lambda: demo_text,
        inputs=None,
        outputs=input_text,
    ).then(
        fn=predict_and_optimize,
        inputs=gr.State(demo_text),  # Pass directly to avoid race condition with UI update
        outputs=prediction_outputs,
    )
# Run initialization
def _main():
    """Load the artifacts, then serve the UI on 0.0.0.0:7860."""
    initialize_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _main()
|