# Hugging Face Space: Viral Content Optimizer (Gradio app)
import os
import pickle

import google.generativeai as genai
import gradio as gr
import numpy as np
import pandas as pd
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import hf_hub_download
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
| # Load environment variables | |
| load_dotenv() | |
| # --- GLOBAL STATE --- | |
| MODEL = None | |
| KNOWLEDGE_DF = None | |
| ST_MODEL = None | |
| ENCODERS = {} # To store label encoders | |
| def initialize_app(): | |
| """Initializes the model and data on app startup from Hugging Face.""" | |
| global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS | |
| print("β³ initializing app: Loading resources from Hugging Face Cloud...") | |
| # 1. Load Dataset from HF | |
| print("π Downloading Dataset (MatanKriel/social-assitent-synthetic-data)...") | |
| try: | |
| dataset = load_dataset("MatanKriel/social-assitent-synthetic-data") | |
| if 'train' in dataset: | |
| knowledge_df = dataset['train'].to_pandas() | |
| else: | |
| knowledge_df = dataset.to_pandas() | |
| print(f" -> Loaded {len(knowledge_df)} rows.") | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load dataset from HF: {e}") | |
| # 2. FIT ENCODERS (For Feature Consistency) | |
| print("π€ Fitting Label Encoders...") | |
| # UPDATED: 'age' removed from here, treated as numeric | |
| cat_cols = ['category', 'gender', 'day_of_week'] | |
| for c in cat_cols: | |
| if c in knowledge_df.columns: | |
| le = LabelEncoder() | |
| # Ensure all values are strings | |
| le.fit(knowledge_df[c].astype(str)) | |
| ENCODERS[c] = le | |
| print(f" -> Encoded '{c}': {len(le.classes_)} classes") | |
| else: | |
| print(f" β οΈ Warning: Column '{c}' missing from dataset!") | |
| # 3. Load Model from HF | |
| print("π§ Downloading Model (MatanKriel/social-assitent-viral-predictor)...") | |
| try: | |
| model_path = hf_hub_download(repo_id="MatanKriel/social-assitent-viral-predictor", filename="viral_model.pkl") | |
| with open(model_path, "rb") as f: | |
| model = pickle.load(f) | |
| print(f" -> Loaded model: {type(model).__name__}") | |
| except Exception as e: | |
| # Fallback to local | |
| if os.path.exists("viral_model.pkl"): | |
| print(f" β οΈ HF Download failed ({e}). Loading local 'viral_model.pkl' instead.") | |
| with open("viral_model.pkl", "rb") as f: | |
| model = pickle.load(f) | |
| else: | |
| raise RuntimeError(f"Failed to load model from HF and no local backup found: {e}") | |
| # 4. Load SentenceTransformer | |
| print("π Loading SentenceTransformer...") | |
| embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2" | |
| print(f" -> Model: {embedding_model_name}") | |
| import torch | |
| device = "mps" if torch.backends.mps.is_available() else "cpu" | |
| st_model = SentenceTransformer(embedding_model_name, device=device) | |
| # 5. Generate Embeddings | |
| print("βοΈ Generating Embeddings for Knowledge Base (One-time setup)...") | |
| if 'embedding' not in knowledge_df.columns: | |
| embeddings = st_model.encode(knowledge_df['description'].fillna("").tolist(), | |
| convert_to_numpy=True, | |
| show_progress_bar=True) | |
| knowledge_df['embedding'] = list(embeddings) | |
| else: | |
| print(" -> Embeddings already present in dataset.") | |
| MODEL = model | |
| KNOWLEDGE_DF = knowledge_df | |
| ST_MODEL = st_model | |
| print("β App initialized (Inference Mode)!") | |
| def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender): | |
| if not user_input: | |
| return "Please enter a video description.", "", "", "", "" | |
| # --- 1. ENCODE INPUTS --- | |
| try: | |
| # Helper to encode safely | |
| def safe_encode(col, val): | |
| le = ENCODERS.get(col) | |
| if le: | |
| # If value not seen, default to first class | |
| if val in le.classes_: | |
| return le.transform([val])[0] | |
| else: | |
| return 0 | |
| return 0 | |
| cat_encoded = safe_encode('category', category) | |
| gender_encoded = safe_encode('gender', gender) | |
| day_encoded = safe_encode('day_of_week', day_of_week) | |
| # FIX: Map Age String to Numeric | |
| age_map = { | |
| "18-24": 21.0, | |
| "25-34": 30.0, | |
| "35-44": 40.0, | |
| "45+": 50.0 | |
| } | |
| age_numeric = age_map.get(str(age), 25.0) # Default to 25 if unknown | |
| except Exception as e: | |
| return f"Encoding Error: {str(e)}", "", "", "", "" | |
| # --- 2. INITIAL PREDICTION --- | |
| # Feature Order MUST match model-prep.py: | |
| # Embeddings + [duration, hour, followers, age_numeric, category_enc, gender_enc, day_enc] | |
| text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True) | |
| # Construct metadata vector | |
| meta_vec = np.array([[duration, hour, followers, age_numeric, cat_encoded, gender_encoded, day_encoded]]) | |
| feat_vec = np.hstack((text_vec, meta_vec)) | |
| initial_log = MODEL.predict(feat_vec)[0] | |
| initial_views = int(np.expm1(initial_log)) | |
| # --- 3. VECTOR SEARCH --- | |
| high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy() | |
| user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True) | |
| target_embeddings = np.stack(high_perf_df['embedding'].values) | |
| similarities = cosine_similarity(user_embedding, target_embeddings) | |
| top_3_indices = similarities[0].argsort()[-3:][::-1] | |
| top_3_videos = high_perf_df.iloc[top_3_indices]['description'].tolist() | |
| similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)]) | |
| # --- 4. GEMINI OPTIMIZATION --- | |
| api_key = os.getenv("GEMINI_API_KEY") | |
| if not api_key: | |
| return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A" | |
| genai.configure(api_key=api_key) | |
| try: | |
| llm = genai.GenerativeModel('gemini-2.5-flash-lite') | |
| except: | |
| llm = genai.GenerativeModel('gemini-1.5-flash') | |
| prompt = f""" | |
| You are a TikTok Virality Expert. | |
| Draft: "{user_input}" | |
| Niche: {category} | Creator: {age}, {gender} with {followers} followers. | |
| Context: {duration}s video posted on {day_of_week} at {hour}:00. | |
| Viral Examples in this niche: | |
| 1. {top_3_videos[0]} | |
| 2. {top_3_videos[1]} | |
| 3. {top_3_videos[2]} | |
| Task: | |
| Rewrite the draft to be more viral. Add hooks and hashtags two hashtags MAX. | |
| Keep it natural and relevant to the creator persona. | |
| Make it short catchy and viral. | |
| Output Format: | |
| [New Description] | |
| Recomendations: | |
| [Upload day of week] | |
| [Upload hour] | |
| [Upload duration] | |
| """ | |
| try: | |
| response = llm.generate_content(prompt) | |
| improved_idea = response.text.strip() | |
| # --- 5. RE-SCORING --- | |
| new_text_vec = ST_MODEL.encode([improved_idea], convert_to_numpy=True) | |
| # Using same metadata for the new prediction | |
| new_feat_vec = np.hstack((new_text_vec, meta_vec)) | |
| new_log = MODEL.predict(new_feat_vec)[0] | |
| new_views = int(np.expm1(new_log)) | |
| uplift_pct = ((new_views - initial_views) / initial_views) * 100 | |
| uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift" | |
| return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str | |
| except Exception as e: | |
| return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A" | |
| # --- GRADIO UI --- | |
| with gr.Blocks(theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# π Viral Content Optimizer") | |
| gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations based on 2025 trends.") | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| input_text = gr.Textbox( | |
| label="Your Video Description", | |
| placeholder="e.g., POV: trying the new grimace shake #viral", | |
| lines=3 | |
| ) | |
| with gr.Row(): | |
| category = gr.Dropdown( | |
| choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech", "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"], | |
| value="Entertainment", label="Niche" | |
| ) | |
| followers = gr.Number(value=1000, label="Follower Count", precision=0) | |
| with gr.Row(): | |
| age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age") | |
| gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender") | |
| with gr.Row(): | |
| duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)") | |
| hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour") | |
| day_dropdown = gr.Dropdown( | |
| choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"], | |
| value="Friday", label="Day" | |
| ) | |
| with gr.Row(): | |
| submit_btn = gr.Button("Analyze & Optimize β‘", variant="primary") | |
| demo_btn = gr.Button("π² Try Demo", variant="secondary") | |
| with gr.Column(scale=1): | |
| with gr.Group(): | |
| gr.Markdown("### π Predictions") | |
| initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False) | |
| with gr.Group(): | |
| gr.Markdown("### β¨ AI Optimization") | |
| improved_text = gr.Textbox(label="Improved Description", interactive=False) | |
| with gr.Row(): | |
| new_views = gr.Textbox(label="New Predicted Views", interactive=False) | |
| uplift = gr.Textbox(label="Potential Uplift", interactive=False) | |
| with gr.Accordion("π Similar Viral Videos (Reference)", open=False): | |
| similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5) | |
| submit_btn.click( | |
| fn=predict_and_optimize, | |
| inputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender], | |
| outputs=[initial_views, similar_videos, improved_text, new_views, uplift] | |
| ) | |
| # Demo Button Logic | |
| demo_btn.click( | |
| fn=lambda: ("My protein shake ended up on the floor", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male"), | |
| inputs=None, | |
| outputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender] | |
| ).then( | |
| fn=predict_and_optimize, | |
| inputs=[gr.State("My protein shake ended up on the floor"), gr.State(15), gr.State(19), gr.State("Monday"), gr.State("Fitness"), gr.State(50000), gr.State("18-24"), gr.State("Male")], | |
| outputs=[initial_views, similar_videos, improved_text, new_views, uplift] | |
| ) | |
| # Run initialization | |
| if __name__ == "__main__": | |
| initialize_app() | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |