# Hugging Face Space: Viral Content Optimizer (Gradio app)
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import os | |
| import google.generativeai as genai | |
| from sentence_transformers import SentenceTransformer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from dotenv import load_dotenv | |
| # Import functions from model-prep | |
| from xgboost import XGBRegressor # Use Regressor as per model-prep | |
| import pickle | |
| from importlib.util import spec_from_file_location | |
| import sys | |
| # Since we are loading artifacts, we don't strictly need model-prep.py logic anymore. | |
| # But keeping basic imports is fine. | |
# Load environment variables (e.g. GEMINI_API_KEY, read later in
# predict_and_optimize) from a local .env file, if present.
load_dotenv()

# --- GLOBAL STATE ---
# Populated exactly once by initialize_app() at startup and read by
# predict_and_optimize(). All stay None until initialization runs.
MODEL = None          # XGBRegressor; outputs log1p(views) (inverted with expm1 at predict time)
VECTORIZER = None     # fitted TF-IDF vectorizer loaded from tfidf_vectorizer.pkl
KNOWLEDGE_DF = None   # DataFrame with 'views', 'description', 'embedding' columns — TODO confirm schema against model-prep.py
ST_MODEL = None       # SentenceTransformer used for semantic similarity search
def initialize_app():
    """Load pre-computed inference artifacts into module-level globals.

    Loads the Parquet knowledge base, the XGBoost regressor, the TF-IDF
    vectorizer, and the SentenceTransformer embedding model produced by
    model-prep.py. Must be called once before predict_and_optimize().

    Raises:
        FileNotFoundError: if any required artifact file is missing.
    """
    global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL
    print("β³ initializing app: Loading pre-computed artifacts...")

    # Fail fast with a clear message for EVERY artifact, not just the parquet
    # file — previously a missing model/vectorizer surfaced as an opaque
    # loader error deep inside xgboost/pickle.
    required_files = {
        "knowledge base": "tiktok_knowledge_base.parquet",
        "XGBoost model": "viral_model.json",
        "TF-IDF vectorizer": "tfidf_vectorizer.pkl",
    }
    for label, path in required_files.items():
        if not os.path.exists(path):
            raise FileNotFoundError(
                f"Required {label} file '{path}' not found! Run model-prep.py first."
            )

    # 1. Load Parquet Data (Knowledge Base)
    parquet_path = required_files["knowledge base"]
    print(f"π Loading data from {parquet_path}...")
    knowledge_df = pd.read_parquet(parquet_path)

    # 2. Load Model (regressor over log1p(views); inverted with expm1 at predict time)
    print("π§ Loading XGBoost Model...")
    model = XGBRegressor()
    model.load_model("viral_model.json")

    # 3. Load Vectorizer.
    # NOTE: pickle is only safe because this is a locally produced artifact;
    # never load untrusted pickle files.
    print("π€ Loading TF-IDF Vectorizer...")
    with open("tfidf_vectorizer.pkl", "rb") as f:
        tfidf = pickle.load(f)

    # 4. Load Sentence Transformer on the best available device.
    # Prefer CUDA, then Apple MPS, then CPU; guard the mps attribute so this
    # also works on torch builds without MPS support.
    print("π Loading SentenceTransformer...")
    import torch
    if torch.cuda.is_available():
        device = "cuda"
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    MODEL = model
    VECTORIZER = tfidf
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("β App initialized (Inference Mode)!")
def _score_description(text):
    """Predict a view count for *text* using the global model/vectorizer.

    Builds the same feature layout used at training time: TF-IDF text
    features followed by default metadata (15s duration, 18:00 post hour,
    weekday flag 0, hashtag count taken from the text). The model predicts
    log1p(views), so the output is inverted with expm1.
    """
    text_vec = VECTORIZER.transform([text]).toarray()
    meta_vec = np.array([[15, 18, 0, text.count('#')]])
    feat_vec = np.hstack((text_vec, meta_vec))
    return int(np.expm1(MODEL.predict(feat_vec)[0]))


def predict_and_optimize(user_input):
    """Predict views for a draft description, retrieve similar viral videos,
    and ask Gemini for an optimized rewrite, then re-score the rewrite.

    Args:
        user_input: Draft TikTok video description entered by the user.

    Returns:
        A 5-tuple of strings for the Gradio outputs:
        (original predicted views, similar videos, improved description,
        new predicted views, uplift).
    """
    # Treat whitespace-only input the same as empty input.
    if not user_input or not user_input.strip():
        return "Please enter a video description.", "", "", "", ""

    # --- 1. INITIAL PREDICTION ---
    initial_views = _score_description(user_input)

    # --- 2. VECTOR SEARCH ---
    # Restrict the reference pool to the top 25% of videos by view count.
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
    user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)
    # Take up to 3 best matches; the pool may legitimately hold fewer than 3.
    top_indices = similarities[0].argsort()[-3:][::-1]
    top_videos = high_perf_df.iloc[top_indices]['description'].tolist()
    similar_videos_str = "\n\n".join(f"{i+1}. {v}" for i, v in enumerate(top_videos))

    # --- 3. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
    genai.configure(api_key=api_key)
    # Prefer the newer model; fall back if constructing it fails.
    # (Narrowed from a bare `except:` which also swallowed KeyboardInterrupt.)
    try:
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        llm = genai.GenerativeModel('gemini-1.5-flash')

    # Build the example list dynamically: the old prompt hard-indexed
    # top_3_videos[0..2] and raised IndexError when fewer than 3 matched.
    examples = "\n".join(f"{i+1}. {v}" for i, v in enumerate(top_videos))
    prompt = f"""
You are a TikTok Virality Expert.
My Draft Description: "{user_input}"
Here are {len(top_videos)} successful, viral videos that are similar to my topic:
{examples}
Task: Rewrite my draft description to make it go viral and full video plan.
Use the slang, hashtag style, and structure of the successful examples provided.
Keep it under 60 words plus hashtags. Return ONLY the new description.
"""
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()

        # --- 4. RE-SCORING ---
        new_views = _score_description(improved_idea)

        # Guard against division by zero when the original prediction is 0.
        if initial_views > 0:
            uplift_pct = ((new_views - initial_views) / initial_views) * 100
            uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
        else:
            uplift_str = "N/A"
        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
    except Exception as e:
        # Surface the API failure in the UI instead of crashing the handler.
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"
# --- GRADIO UI ---
# Two-column layout: input textbox + action buttons on the left, prediction
# and AI-optimization outputs on the right, with the retrieved reference
# videos in a collapsible accordion below.
# NOTE(review): the source formatting was flattened; the nesting below is the
# conventional reconstruction — confirm against the deployed Space.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π Viral Content Optimizer")
    gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")
    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize β‘", variant="primary")
                demo_btn = gr.Button("π² Try Demo", variant="secondary")
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### π Predictions")
                # Outputs are read-only: filled exclusively by predict_and_optimize.
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            with gr.Group():
                gr.Markdown("### β¨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
    with gr.Accordion("π Similar Viral Videos (Reference)", open=False):
        similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)
    # Main action: run the full predict -> retrieve -> optimize pipeline.
    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )
    # Demo Button Logic: 1. Fill Text -> 2. Run Prediction
    demo_text = "POV: You realize you forgot to turn off your mic during the all-hands meeting π #fail #fyp #corporate"
    demo_btn.click(
        fn=lambda: demo_text,
        inputs=None,
        outputs=input_text
    ).then(
        fn=predict_and_optimize,
        inputs=gr.State(demo_text),  # Pass directly to avoid race condition with UI update
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )
# Run initialization
if __name__ == "__main__":
    # Load all artifacts before serving requests; raises (and aborts launch)
    # if any artifact produced by model-prep.py is missing.
    initialize_app()
    # 0.0.0.0 binds all interfaces (required for container/Spaces deployment).
    demo.launch(server_name="0.0.0.0", server_port=7860)