import os
import pickle
import re

import gradio as gr
import google.generativeai as genai
import numpy as np
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder

# Load environment variables (GEMINI_API_KEY is read from the environment below).
load_dotenv()

# --- GLOBAL STATE ---
# Populated once by initialize_app() and read by predict_and_optimize().
MODEL = None         # unpickled regressor; its output is inverted with np.expm1 (treated as log1p(views))
KNOWLEDGE_DF = None  # dataset rows plus an 'embedding' column used for similarity search
ST_MODEL = None      # SentenceTransformer used for both model features and similarity search
ENCODERS = {}        # column name -> fitted LabelEncoder

# Matches the "Recommendations:" header the prompt asks the LLM to emit
# (also tolerates the common "Recomendations" misspelling, case-insensitive).
_RECOMMENDATIONS_RE = re.compile(r"\brecomm?endations?\s*:", re.IGNORECASE)


def initialize_app():
    """Load the dataset, label encoders, regressor and embedding model into globals.

    Steps:
      1. Download the knowledge-base dataset from the Hugging Face Hub.
      2. Fit one LabelEncoder per categorical column ('category', 'gender',
         'day_of_week') so request values are encoded like the dataset.
      3. Download the pickled regressor, falling back to a local
         'viral_model.pkl' if the download fails.
      4. Load the SentenceTransformer embedding model.
      5. Embed every 'description' unless the dataset already has an
         'embedding' column.

    Raises:
        RuntimeError: if the dataset cannot be loaded, or the model can be
            neither downloaded nor found locally.
    """
    global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS

    print("⏳ initializing app: Loading resources from Hugging Face Cloud...")

    # 1. Load Dataset from HF
    print("📂 Downloading Dataset (MatanKriel/social-assitent-synthetic-data)...")
    try:
        dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
        # load_dataset returns a DatasetDict keyed by split name. A DatasetDict
        # has no .to_pandas(), so without a 'train' split use the first split.
        if "train" in dataset:
            split = dataset["train"]
        else:
            split = next(iter(dataset.values()))
        knowledge_df = split.to_pandas()
        print(f" -> Loaded {len(knowledge_df)} rows.")
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from HF: {e}") from e

    # 2. FIT ENCODERS (For Feature Consistency)
    print("🔤 Fitting Label Encoders...")
    # 'age' is intentionally not here: it is fed to the model as a number.
    cat_cols = ['category', 'gender', 'day_of_week']
    for c in cat_cols:
        if c in knowledge_df.columns:
            le = LabelEncoder()
            # Fit on strings so request values (also strings) match the classes.
            le.fit(knowledge_df[c].astype(str))
            ENCODERS[c] = le
            print(f" -> Encoded '{c}': {len(le.classes_)} classes")
        else:
            print(f" ⚠️ Warning: Column '{c}' missing from dataset!")

    # 3. Load Model from HF
    # SECURITY NOTE: pickle.load executes arbitrary code from the file. Only
    # load artifacts from a repo you control; consider pinning a `revision=`.
    print("🧠 Downloading Model (MatanKriel/social-assitent-viral-predictor)...")
    try:
        model_path = hf_hub_download(
            repo_id="MatanKriel/social-assitent-viral-predictor",
            filename="viral_model.pkl",
        )
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        print(f" -> Loaded model: {type(model).__name__}")
    except Exception as e:
        # Fallback to local
        if os.path.exists("viral_model.pkl"):
            print(f" ⚠️ HF Download failed ({e}). Loading local 'viral_model.pkl' instead.")
            with open("viral_model.pkl", "rb") as f:
                model = pickle.load(f)
        else:
            raise RuntimeError(
                f"Failed to load model from HF and no local backup found: {e}"
            ) from e

    # 4. Load SentenceTransformer
    print("🔌 Loading SentenceTransformer...")
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    print(f" -> Model: {embedding_model_name}")
    import torch  # local import: only needed for device selection

    device = "mps" if torch.backends.mps.is_available() else "cpu"
    st_model = SentenceTransformer(embedding_model_name, device=device)

    # 5. Generate Embeddings
    print("⚙️ Generating Embeddings for Knowledge Base (One-time setup)...")
    if 'embedding' not in knowledge_df.columns:
        embeddings = st_model.encode(
            knowledge_df['description'].fillna("").tolist(),
            convert_to_numpy=True,
            show_progress_bar=True,
        )
        knowledge_df['embedding'] = list(embeddings)
    else:
        print(" -> Embeddings already present in dataset.")

    # Publish globals only after every step succeeded.
    MODEL = model
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("✅ App initialized (Inference Mode)!")


def _predict_views(text_vec, meta_vec):
    """Score one row of [text embedding | metadata] features and return views as an int.

    The regressor's output is treated as log1p(views), so np.expm1 inverts it.
    """
    feat_vec = np.hstack((text_vec, meta_vec))
    log_views = MODEL.predict(feat_vec)[0]
    return int(np.expm1(log_views))


def _description_for_scoring(llm_text):
    """Return only the rewritten-description part of the LLM response.

    The prompt asks for the description followed by a "Recommendations:"
    section about upload day, hour and duration. That section is not caption
    text, so it is dropped before re-embedding. The whole text is returned if
    no marker is found or nothing precedes it.
    """
    head = _RECOMMENDATIONS_RE.split(llm_text, maxsplit=1)[0].strip()
    return head or llm_text


def _build_prompt(user_input, duration, hour, day_of_week, category, followers, age, gender, examples):
    """Build the Gemini rewrite prompt from the draft, creator context and example captions."""
    if examples:
        examples_block = "\n".join(f"{i}. {v}" for i, v in enumerate(examples, start=1))
    else:
        examples_block = "(no examples available)"
    return "\n".join([
        "You are a TikTok Virality Expert.",
        f'Draft: "{user_input}"',
        f"Niche: {category} | Creator: {age}, {gender} with {followers} followers.",
        f"Context: {duration}s video posted on {day_of_week} at {hour}:00.",
        "",
        "Viral Examples in this niche:",
        examples_block,
        "",
        "Task: Rewrite the draft to be more viral. Add hooks and hashtags (two hashtags MAX).",
        "Keep it natural and relevant to the creator persona. Make it short, catchy and viral.",
        "",
        "Output Format:",
        "[New Description]",
        "Recommendations:",
        "[Upload day of week]",
        "[Upload hour]",
        "[Upload duration]",
    ])


def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender):
    """Predict views for a draft, ask Gemini for a rewrite, and re-score the rewrite.

    Returns a 5-tuple of strings for the Gradio outputs:
    (initial views, similar high-performing descriptions, improved description,
    new predicted views, uplift). Recoverable problems such as a missing API
    key or an LLM failure are reported in the text fields rather than raised.
    """
    if not user_input:
        return "Please enter a video description.", "", "", "", ""

    # initialize_app() only runs under `if __name__ == "__main__"`; load lazily
    # when the module is served another way (e.g. imported by a launcher).
    if MODEL is None:
        initialize_app()

    # gr.Number yields None when cleared; None in the feature matrix breaks predict().
    if followers is None:
        followers = 0

    # --- 1. ENCODE INPUTS ---
    try:
        def safe_encode(col, val):
            """Label-encode val for col; unseen values or a missing encoder map to 0."""
            le = ENCODERS.get(col)
            val = str(val)  # encoders were fit on str-cast columns
            if le is None or val not in le.classes_:
                return 0
            return le.transform([val])[0]

        cat_encoded = safe_encode('category', category)
        gender_encoded = safe_encode('gender', gender)
        day_encoded = safe_encode('day_of_week', day_of_week)

        # Map the age bucket to a representative numeric age.
        age_map = {"18-24": 21.0, "25-34": 30.0, "35-44": 40.0, "45+": 50.0}
        age_numeric = age_map.get(str(age), 25.0)  # Default to 25 if unknown
    except Exception as e:
        return f"Encoding Error: {str(e)}", "", "", "", ""

    # --- 2. INITIAL PREDICTION ---
    # Feature order MUST match model-prep.py:
    # Embeddings + [duration, hour, followers, age_numeric, category_enc, gender_enc, day_enc]
    text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True)
    meta_vec = np.array([[duration, hour, followers, age_numeric, cat_encoded, gender_encoded, day_encoded]])
    initial_views = _predict_views(text_vec, meta_vec)

    # --- 3. VECTOR SEARCH ---
    # Compare the draft (reusing text_vec) against top-quartile-by-views rows only.
    views = KNOWLEDGE_DF['views']
    high_perf_df = KNOWLEDGE_DF[views > views.quantile(0.75)]
    top_3_videos = []
    if not high_perf_df.empty:
        target_embeddings = np.stack(high_perf_df['embedding'].values)
        similarities = cosine_similarity(text_vec, target_embeddings)[0]
        top_idx = similarities.argsort()[-3:][::-1]  # best match first
        top_3_videos = high_perf_df.iloc[top_idx]['description'].tolist()
    similar_videos_str = "\n\n".join(f"{i + 1}. {v}" for i, v in enumerate(top_3_videos))

    # --- 4. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"

    genai.configure(api_key=api_key)
    # NOTE(review): constructing GenerativeModel may not contact the API, so an
    # unknown model name might only fail inside generate_content(); verify
    # this fallback actually triggers.
    try:
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        llm = genai.GenerativeModel('gemini-1.5-flash')

    prompt = _build_prompt(user_input, duration, hour, day_of_week, category, followers, age, gender, top_3_videos)

    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"

    # --- 5. RE-SCORING ---
    # Same metadata; only the caption changes, with the recommendations section stripped.
    new_text_vec = ST_MODEL.encode([_description_for_scoring(improved_idea)], convert_to_numpy=True)
    new_views = _predict_views(new_text_vec, meta_vec)

    # A baseline of 0 views has no meaningful percentage uplift.
    if initial_views > 0:
        uplift_pct = (new_views - initial_views) / initial_views * 100
        uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
    else:
        uplift_str = "N/A"

    return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str


# Values loaded into the form by the "Try Demo" button, in the same order as the form inputs.
DEMO_EXAMPLE = ("My protein shake ended up on the floor", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male")

# --- GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Viral Content Optimizer")
    gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations based on 2025 trends.")

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3,
            )
            with gr.Row():
                category = gr.Dropdown(
                    choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech",
                             "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"],
                    value="Entertainment",
                    label="Niche",
                )
                followers = gr.Number(value=1000, label="Follower Count", precision=0)
            with gr.Row():
                age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age")
                gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender")
            with gr.Row():
                duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)")
                hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour")
                day_dropdown = gr.Dropdown(
                    choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
                    value="Friday",
                    label="Day",
                )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚡", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📊 Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
            with gr.Accordion("🔍 Similar Viral Videos (Reference)", open=False):
                similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    # Order must match predict_and_optimize's parameters / return tuple.
    form_inputs = [input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender]
    result_outputs = [initial_views, similar_videos, improved_text, new_views, uplift]

    submit_btn.click(fn=predict_and_optimize, inputs=form_inputs, outputs=result_outputs)

    # Demo: fill the form, then score whatever the form now holds, so the
    # demo values live in one place (DEMO_EXAMPLE).
    demo_btn.click(
        fn=lambda: DEMO_EXAMPLE,
        inputs=None,
        outputs=form_inputs,
    ).then(
        fn=predict_and_optimize,
        inputs=form_inputs,
        outputs=result_outputs,
    )

# Run initialization
if __name__ == "__main__":
    initialize_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)