import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

# XGBRegressor matches the regressor trained and saved by model-prep.py.
from xgboost import XGBRegressor
import pickle
# NOTE(review): the two imports below are currently unused in this script.
from importlib.util import spec_from_file_location
import sys

# Since we load pre-computed artifacts, model-prep.py logic is not needed here.

# Load environment variables (GEMINI_API_KEY is read from here).
load_dotenv()

# --- CONFIGURATION ---
PARQUET_PATH = "tiktok_knowledge_base.parquet"
MODEL_PATH = "viral_model.json"
VECTORIZER_PATH = "tfidf_vectorizer.pkl"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"

# Gemini models tried in order; the first one that answers successfully wins.
GEMINI_MODELS = ("gemini-2.5-flash-lite", "gemini-1.5-flash")

# Default metadata assumed for every draft: 15s duration, posted 18:00 (6 PM),
# weekday 0. The hashtag count is appended per text.
# NOTE(review): column order must match the features used in model-prep.py.
DEFAULT_DURATION_S = 15
DEFAULT_HOUR = 18
DEFAULT_WEEKDAY = 0

# Knowledge-base rows whose views exceed this quantile count as "viral hits".
VIRAL_QUANTILE = 0.75
TOP_K = 3

# --- GLOBAL STATE ---
MODEL = None
VECTORIZER = None
KNOWLEDGE_DF = None
ST_MODEL = None
# Viral subset of KNOWLEDGE_DF, cached once at startup so each request does
# not re-run the quantile filter and re-stack every embedding.
HIGH_PERF_DESCRIPTIONS = None
HIGH_PERF_EMBEDDINGS = None


def initialize_app():
    """Load all pre-computed artifacts into the module-level globals.

    Reads the knowledge-base parquet, the XGBoost regressor, the pickled TF-IDF
    vectorizer and a SentenceTransformer, then caches the descriptions and
    stacked embeddings of the high-performing subset used for vector search.

    Raises:
        FileNotFoundError: if the knowledge-base parquet does not exist.
    """
    global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL
    global HIGH_PERF_DESCRIPTIONS, HIGH_PERF_EMBEDDINGS

    print("⏳ initializing app: Loading pre-computed artifacts...")

    # 1. Load Parquet Data (Knowledge Base)
    if not os.path.exists(PARQUET_PATH):
        raise FileNotFoundError(f"Required file '{PARQUET_PATH}' not found! Run model-prep.py first.")
    print(f"📂 Loading data from {PARQUET_PATH}...")
    knowledge_df = pd.read_parquet(PARQUET_PATH)

    # 2. Load Model
    print("🧠 Loading XGBoost Model...")
    model = XGBRegressor()
    model.load_model(MODEL_PATH)

    # 3. Load Vectorizer
    # NOTE: pickle can execute arbitrary code on load; only load a vectorizer
    # produced by our own model-prep.py, never a file from an untrusted source.
    print("🔤 Loading TF-IDF Vectorizer...")
    with open(VECTORIZER_PATH, "rb") as f:
        tfidf = pickle.load(f)

    # 4. Load Sentence Transformer
    print("🔌 Loading SentenceTransformer...")
    import torch  # local import: only needed here to pick a device

    if torch.backends.mps.is_available():
        device = "mps"
    elif torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    st_model = SentenceTransformer(EMBEDDING_MODEL_NAME, device=device)

    # 5. Cache the viral subset (top quartile by views) for similarity search.
    threshold = knowledge_df["views"].quantile(VIRAL_QUANTILE)
    high_perf_df = knowledge_df[knowledge_df["views"] > threshold]
    if high_perf_df.empty:
        # e.g. all rows share the same view count, so nothing is strictly
        # above the quantile; fall back to searching the whole knowledge base.
        high_perf_df = knowledge_df

    MODEL = model
    VECTORIZER = tfidf
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    HIGH_PERF_DESCRIPTIONS = high_perf_df["description"].tolist()
    HIGH_PERF_EMBEDDINGS = (
        np.stack(high_perf_df["embedding"].values) if len(high_perf_df) else None
    )
    print("✅ App initialized (Inference Mode)!")


def _predict_views(text):
    """Predict the view count for `text` using the default posting metadata.

    The feature row is the dense TF-IDF vector of `text` followed by
    [duration, hour, weekday, hashtag count]. The model output is treated as
    log1p(views) and inverted with expm1. Because expm1(x) > -1, truncating
    with int() never yields a negative count, but it can yield 0.
    """
    text_vec = VECTORIZER.transform([text]).toarray()
    meta_vec = np.array([[DEFAULT_DURATION_S, DEFAULT_HOUR, DEFAULT_WEEKDAY, text.count("#")]])
    feat_vec = np.hstack((text_vec, meta_vec))
    log_views = MODEL.predict(feat_vec)[0]
    return int(np.expm1(log_views))


def _find_similar_viral(user_input, k=TOP_K):
    """Return up to `k` viral descriptions most cosine-similar to `user_input`.

    Results are ordered from most to least similar. The list may be shorter
    than `k`, or empty, when the cached viral subset is small.
    """
    if HIGH_PERF_EMBEDDINGS is None or not HIGH_PERF_DESCRIPTIONS:
        return []
    user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
    similarities = cosine_similarity(user_embedding, HIGH_PERF_EMBEDDINGS)[0]
    top_indices = np.argsort(similarities)[::-1][:k]
    return [HIGH_PERF_DESCRIPTIONS[i] for i in top_indices]


def _build_prompt(user_input, examples):
    """Build the Gemini rewrite prompt from the draft and the example videos."""
    example_lines = "\n".join(f"{i}. {v}" for i, v in enumerate(examples, 1))
    # NOTE(review): the task asks for a "full video plan" but also says to return
    # ONLY a <60-word description; the result is re-scored as a description.
    return f"""
You are a TikTok Virality Expert.

My Draft Description: "{user_input}"

Here are {len(examples)} successful, viral videos that are similar to my topic:
{example_lines}

Task: Rewrite my draft description to make it go viral and full video plan.
Use the slang, hashtag style, and structure of the successful examples provided.
Keep it under 60 words plus hashtags.
Return ONLY the new description.
"""


def _generate_improved_description(prompt):
    """Ask Gemini for a rewrite, trying each model in GEMINI_MODELS in order.

    Constructing `genai.GenerativeModel` does not contact the API, so an
    unavailable model only fails inside `generate_content`. The fallback must
    therefore wrap the request itself, not the constructor.

    Raises:
        Exception: the last error raised if every model fails.
    """
    last_error = None
    for model_name in GEMINI_MODELS:
        try:
            response = genai.GenerativeModel(model_name).generate_content(prompt)
            # .text raises if the response was blocked / has no candidates.
            return response.text.strip()
        except Exception as e:  # SDK error types vary (API, network, safety)
            last_error = e
    raise last_error


def _format_uplift(initial_views, new_views):
    """Format the relative change between two predicted view counts."""
    if initial_views <= 0:
        # A percentage change from a zero baseline is undefined.
        return "N/A (original predicted 0 views)"
    uplift_pct = (new_views - initial_views) / initial_views * 100
    return f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"


def predict_and_optimize(user_input):
    """Score a draft, retrieve similar viral videos, and score a Gemini rewrite.

    Args:
        user_input: the draft video description typed by the user.

    Returns:
        A 5-tuple of strings for the UI: (original predicted views, similar
        videos, improved description or error message, new predicted views,
        uplift). Fields that could not be computed are "N/A".
    """
    if not user_input or not user_input.strip():
        return "Please enter a video description.", "", "", "", ""
    if MODEL is None:
        # Allow use when the module is served without running __main__.
        initialize_app()

    # --- 1. INITIAL PREDICTION ---
    initial_views = _predict_views(user_input)

    # --- 2. VECTOR SEARCH ---
    top_videos = _find_similar_viral(user_input)
    similar_videos_str = "\n\n".join(f"{i}. {v}" for i, v in enumerate(top_videos, 1))

    # --- 3. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
    genai.configure(api_key=api_key)

    try:
        improved_idea = _generate_improved_description(_build_prompt(user_input, top_videos))
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"

    # --- 4. RE-SCORING ---
    new_views = _predict_views(improved_idea)
    return (
        f"{initial_views:,}",
        similar_videos_str,
        improved_idea,
        f"{new_views:,}",
        _format_uplift(initial_views, new_views),
    )


# --- GRADIO UI ---
demo_text = "POV: You realize you forgot to turn off your mic during the all-hands meeting 💀 #fail #fyp #corporate"

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Viral Content Optimizer")
    gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3,
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚡", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📊 Predictions")
                initial_views_box = gr.Textbox(label="Predicted Views (Original)", interactive=False)

            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views_box = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)

            with gr.Accordion("🔍 Similar Viral Videos (Reference)", open=False):
                similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    # Output order must match the tuple returned by predict_and_optimize.
    result_outputs = [initial_views_box, similar_videos, improved_text, new_views_box, uplift]

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text],
        outputs=result_outputs,
    )

    # Demo Button Logic: 1. Fill Text -> 2. Run Prediction
    demo_btn.click(
        fn=lambda: demo_text,
        inputs=None,
        outputs=input_text,
    ).then(
        fn=predict_and_optimize,
        inputs=gr.State(demo_text),  # Pass directly to avoid race condition with UI update
        outputs=result_outputs,
    )

# Run initialization
if __name__ == "__main__":
    initialize_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)