# Matan Kriel
# updated clustering metric in model test
# 2f9170f
import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
from dotenv import load_dotenv
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import pickle
# Load environment variables (.env supplies GEMINI_API_KEY, read in predict_and_optimize)
load_dotenv()
# --- GLOBAL STATE ---
# Populated once by initialize_app() at startup; read by predict_and_optimize().
MODEL = None  # prediction model unpickled from viral_model.pkl (output inverted with expm1 at inference)
KNOWLEDGE_DF = None  # pandas DataFrame of historical videos (description, views, embedding, ...)
ST_MODEL = None  # SentenceTransformer used to embed descriptions
ENCODERS = {} # To store label encoders (column name -> fitted LabelEncoder)
def initialize_app():
    """Initializes the model and data on app startup from Hugging Face.

    Populates the module-level globals MODEL, KNOWLEDGE_DF, ST_MODEL and
    ENCODERS so predict_and_optimize can serve requests.

    Raises:
        RuntimeError: if the dataset cannot be downloaded, or the model is
            unavailable both on the Hub and as a local 'viral_model.pkl'.
    """
    global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS
    print("⏳ initializing app: Loading resources from Hugging Face Cloud...")

    # 1. Load Dataset from HF
    print("πŸ“‚ Downloading Dataset (MatanKriel/social-assitent-synthetic-data)...")
    try:
        dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
        # A DatasetDict exposes named splits; a bare Dataset converts directly.
        if 'train' in dataset:
            knowledge_df = dataset['train'].to_pandas()
        else:
            knowledge_df = dataset.to_pandas()
        print(f" -> Loaded {len(knowledge_df)} rows.")
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from HF: {e}")

    # 2. FIT ENCODERS (For Feature Consistency)
    print("πŸ”€ Fitting Label Encoders...")
    # NOTE: 'age' is intentionally absent here — it is treated as numeric
    # (see the age_map in predict_and_optimize).
    cat_cols = ['category', 'gender', 'day_of_week']
    for c in cat_cols:
        if c in knowledge_df.columns:
            le = LabelEncoder()
            # Ensure all values are strings so fit/transform are consistent.
            le.fit(knowledge_df[c].astype(str))
            ENCODERS[c] = le
            print(f" -> Encoded '{c}': {len(le.classes_)} classes")
        else:
            print(f" ⚠️ Warning: Column '{c}' missing from dataset!")

    # 3. Load Model from HF
    # SECURITY NOTE: pickle.load executes arbitrary code from the artifact;
    # only load model files from repos you control/trust.
    print("🧠 Downloading Model (MatanKriel/social-assitent-viral-predictor)...")
    try:
        model_path = hf_hub_download(repo_id="MatanKriel/social-assitent-viral-predictor", filename="viral_model.pkl")
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        print(f" -> Loaded model: {type(model).__name__}")
    except Exception as e:
        # Fallback to local copy when the Hub is unreachable.
        if os.path.exists("viral_model.pkl"):
            print(f" ⚠️ HF Download failed ({e}). Loading local 'viral_model.pkl' instead.")
            with open("viral_model.pkl", "rb") as f:
                model = pickle.load(f)
        else:
            raise RuntimeError(f"Failed to load model from HF and no local backup found: {e}")

    # 4. Load SentenceTransformer
    print("πŸ”Œ Loading SentenceTransformer...")
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    print(f" -> Model: {embedding_model_name}")
    import torch
    # FIX: original only probed MPS, so CUDA machines silently ran on CPU.
    # Prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    st_model = SentenceTransformer(embedding_model_name, device=device)

    # 5. Generate Embeddings
    print("βš™οΈ Generating Embeddings for Knowledge Base (One-time setup)...")
    if 'embedding' not in knowledge_df.columns:
        embeddings = st_model.encode(knowledge_df['description'].fillna("").tolist(),
                                     convert_to_numpy=True,
                                     show_progress_bar=True)
        knowledge_df['embedding'] = list(embeddings)
    else:
        print(" -> Embeddings already present in dataset.")

    # Publish to globals only after everything loaded successfully.
    MODEL = model
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("βœ… App initialized (Inference Mode)!")
def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender):
    """Predicts views for a draft description, retrieves similar viral videos,
    asks Gemini for an optimized rewrite, and re-scores the rewrite.

    Returns:
        A 5-tuple of display strings:
        (initial_views, similar_videos, improved_text, new_views, uplift).
        Errors are reported inside the tuple (never raised) so the Gradio UI
        always receives renderable text.
    """
    if not user_input:
        return "Please enter a video description.", "", "", "", ""

    # --- 1. ENCODE INPUTS ---
    try:
        def safe_encode(col, val):
            # Map a categorical value through its fitted LabelEncoder;
            # unseen values (or a missing encoder) fall back to class 0.
            le = ENCODERS.get(col)
            if le is not None and val in le.classes_:
                return le.transform([val])[0]
            return 0

        cat_encoded = safe_encode('category', category)
        gender_encoded = safe_encode('gender', gender)
        day_encoded = safe_encode('day_of_week', day_of_week)

        # Age arrives as a UI bucket string; map to a numeric midpoint because
        # the model treats age as a continuous feature.
        age_map = {
            "18-24": 21.0,
            "25-34": 30.0,
            "35-44": 40.0,
            "45+": 50.0
        }
        age_numeric = age_map.get(str(age), 25.0)  # Default to 25 if unknown
    except Exception as e:
        return f"Encoding Error: {str(e)}", "", "", "", ""

    # --- 2. INITIAL PREDICTION ---
    # Feature Order MUST match model-prep.py:
    # Embeddings + [duration, hour, followers, age_numeric, category_enc, gender_enc, day_enc]
    text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True)
    meta_vec = np.array([[duration, hour, followers, age_numeric, cat_encoded, gender_encoded, day_encoded]])
    feat_vec = np.hstack((text_vec, meta_vec))
    initial_log = MODEL.predict(feat_vec)[0]
    # Model output is log-scale; expm1 inverts the log1p applied at training.
    initial_views = int(np.expm1(initial_log))

    # --- 3. VECTOR SEARCH ---
    # Reference set: top quartile of historical performers by views.
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
    # FIX: np.stack on an empty selection would raise; degrade gracefully.
    if high_perf_df.empty:
        return f"{initial_views:,}", "No reference videos available.", "Error: no high-performing videos in knowledge base.", "N/A", "N/A"
    # FIX: reuse the embedding computed above instead of encoding the same text twice.
    user_embedding = text_vec
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)
    # Up to 3 most similar descriptions, most similar first.
    top_3_indices = similarities[0].argsort()[-3:][::-1]
    top_3_videos = high_perf_df.iloc[top_3_indices]['description'].tolist()
    similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)])

    # --- 4. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
    genai.configure(api_key=api_key)
    try:
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:  # FIX: was a bare except; fall back to the older model.
        llm = genai.GenerativeModel('gemini-1.5-flash')

    # FIX: build the examples list dynamically so fewer than 3 matches no
    # longer raises IndexError inside the f-string.
    examples = "\n".join(f"{i+1}. {v}" for i, v in enumerate(top_3_videos))
    prompt = f"""
You are a TikTok Virality Expert.
Draft: "{user_input}"
Niche: {category} | Creator: {age}, {gender} with {followers} followers.
Context: {duration}s video posted on {day_of_week} at {hour}:00.
Viral Examples in this niche:
{examples}
Task:
Rewrite the draft to be more viral. Add hooks and hashtags two hashtags MAX.
Keep it natural and relevant to the creator persona.
Make it short catchy and viral.
Output Format:
[New Description]
Recomendations:
[Upload day of week]
[Upload hour]
[Upload duration]
"""
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()

        # --- 5. RE-SCORING ---
        new_text_vec = ST_MODEL.encode([improved_idea], convert_to_numpy=True)
        # Same metadata features; only the text embedding changes.
        new_feat_vec = np.hstack((new_text_vec, meta_vec))
        new_log = MODEL.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))
        # FIX: guard against ZeroDivisionError when the initial prediction is 0.
        if initial_views > 0:
            uplift_pct = ((new_views - initial_views) / initial_views) * 100
            uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
        else:
            uplift_str = "N/A"
        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"
# --- GRADIO UI ---
# Declarative layout: left column = inputs, right column = prediction outputs.
# Component handles created here are wired to predict_and_optimize below.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸš€ Viral Content Optimizer")
    gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations based on 2025 trends.")
    with gr.Row():
        with gr.Column(scale=1):
            # Free-text draft description — the primary model input.
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3
            )
            with gr.Row():
                category = gr.Dropdown(
                    choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech", "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"],
                    value="Entertainment", label="Niche"
                )
                followers = gr.Number(value=1000, label="Follower Count", precision=0)
            with gr.Row():
                # Age buckets map to numeric midpoints inside predict_and_optimize.
                age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age")
                gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender")
            with gr.Row():
                duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)")
                hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour")
                day_dropdown = gr.Dropdown(
                    choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
                    value="Friday", label="Day"
                )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚑", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")
        with gr.Column(scale=1):
            # Output panels (all read-only textboxes).
            with gr.Group():
                gr.Markdown("### πŸ“Š Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
            with gr.Accordion("πŸ” Similar Viral Videos (Reference)", open=False):
                similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)
    # Main action: inputs are read from the UI components at click time.
    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )
    # Demo Button Logic
    # Step 1 fills the form with canned values; step 2 runs the prediction.
    # NOTE(review): the .then() step feeds the same canned values via gr.State
    # instead of reading back the components — the two lists must stay in sync.
    demo_btn.click(
        fn=lambda: ("My protein shake ended up on the floor", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male"),
        inputs=None,
        outputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender]
    ).then(
        fn=predict_and_optimize,
        inputs=[gr.State("My protein shake ended up on the floor"), gr.State(15), gr.State(19), gr.State("Monday"), gr.State("Fitness"), gr.State(50000), gr.State("18-24"), gr.State("Male")],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )
# Run initialization
if __name__ == "__main__":
    # Load dataset, encoders, model and embedder before serving any requests.
    initialize_app()
    # Bind on all interfaces, port 7860 (the conventional Hugging Face Spaces port).
    demo.launch(server_name="0.0.0.0", server_port=7860)