# app.py — Viral Content Optimizer (Gradio app for Hugging Face Spaces)
import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
from dotenv import load_dotenv
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import pickle
# Load environment variables (e.g. GEMINI_API_KEY) from a local .env file.
load_dotenv()
# --- GLOBAL STATE ---
# Populated once by initialize_app(); read by predict_and_optimize().
MODEL = None  # pickled viral-views regression model
KNOWLEDGE_DF = None  # pandas DataFrame knowledge base (gains an 'embedding' column)
ST_MODEL = None  # SentenceTransformer used for text embeddings
ENCODERS = {} # To store label encoders
def initialize_app():
    """Load the model, dataset, encoders and embedder on app startup.

    Populates the module-level globals MODEL, KNOWLEDGE_DF, ST_MODEL and
    ENCODERS from Hugging Face Hub resources, with a local fallback for
    the model pickle.

    Raises:
        RuntimeError: if the dataset cannot be downloaded, or the model is
            unavailable both remotely and locally.
    """
    global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS
    print("Initializing app: loading resources from Hugging Face Hub...")

    # 1. Load dataset from HF.
    print("Downloading dataset (MatanKriel/social-assitent-synthetic-data)...")
    try:
        dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
        # A DatasetDict exposes splits; a bare Dataset converts directly.
        if 'train' in dataset:
            knowledge_df = dataset['train'].to_pandas()
        else:
            knowledge_df = dataset.to_pandas()
        print(f" -> Loaded {len(knowledge_df)} rows.")
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from HF: {e}") from e

    # 2. Fit label encoders so inference-time features match training.
    print("Fitting Label Encoders...")
    # 'age' is intentionally excluded here: it is mapped to a numeric
    # midpoint at prediction time instead of being label-encoded.
    cat_cols = ['category', 'gender', 'day_of_week']
    for c in cat_cols:
        if c in knowledge_df.columns:
            le = LabelEncoder()
            # Cast to str so mixed/NaN values encode consistently.
            le.fit(knowledge_df[c].astype(str))
            ENCODERS[c] = le
            print(f" -> Encoded '{c}': {len(le.classes_)} classes")
        else:
            print(f" WARNING: Column '{c}' missing from dataset!")

    # 3. Load the pickled model from HF, falling back to a local copy.
    print("Downloading Model (MatanKriel/social-assitent-viral-predictor)...")
    try:
        model_path = hf_hub_download(
            repo_id="MatanKriel/social-assitent-viral-predictor",
            filename="viral_model.pkl",
        )
        # NOTE(review): pickle.load executes arbitrary code from the
        # artifact; acceptable only because the repo is trusted.
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        print(f" -> Loaded model: {type(model).__name__}")
    except Exception as e:
        if os.path.exists("viral_model.pkl"):
            print(f" WARNING: HF download failed ({e}). Loading local 'viral_model.pkl' instead.")
            with open("viral_model.pkl", "rb") as f:
                model = pickle.load(f)
        else:
            raise RuntimeError(f"Failed to load model from HF and no local backup found: {e}") from e

    # 4. Load the sentence embedder on the best available device.
    print("Loading SentenceTransformer...")
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    print(f" -> Model: {embedding_model_name}")
    import torch
    # Prefer CUDA, then Apple-silicon MPS, then CPU (original only
    # checked MPS, leaving CUDA hosts on CPU).
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    st_model = SentenceTransformer(embedding_model_name, device=device)

    # 5. Embed the knowledge base once, unless embeddings ship with the data.
    print("Generating embeddings for knowledge base (one-time setup)...")
    if 'embedding' not in knowledge_df.columns:
        embeddings = st_model.encode(
            knowledge_df['description'].fillna("").tolist(),
            convert_to_numpy=True,
            show_progress_bar=True,
        )
        knowledge_df['embedding'] = list(embeddings)
    else:
        print(" -> Embeddings already present in dataset.")

    MODEL = model
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("App initialized (Inference Mode)!")
def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender):
    """Predict views for a draft description and produce an AI-optimized rewrite.

    Args:
        user_input: Draft video description text.
        duration: Video length in seconds.
        hour: Planned upload hour (0-23).
        day_of_week: Day name, e.g. "Friday".
        category: Content niche, e.g. "Fitness".
        followers: Creator follower count.
        age: Age bracket string ("18-24", "25-34", "35-44", "45+").
        gender: Creator gender string.

    Returns:
        A 5-tuple of strings for the Gradio outputs:
        (initial views, similar videos, improved description, new views, uplift).
    """
    if not user_input:
        return "Please enter a video description.", "", "", "", ""

    # --- 1. ENCODE INPUTS ---
    try:
        def safe_encode(col, val):
            # Map a categorical value through its fitted encoder; unseen
            # values (and missing encoders) default to class 0.
            le = ENCODERS.get(col)
            if le and val in le.classes_:
                return le.transform([val])[0]
            return 0

        cat_encoded = safe_encode('category', category)
        gender_encoded = safe_encode('gender', gender)
        day_encoded = safe_encode('day_of_week', day_of_week)

        # Age brackets are mapped to a representative numeric value,
        # matching the numeric 'age' feature used at training time.
        age_map = {
            "18-24": 21.0,
            "25-34": 30.0,
            "35-44": 40.0,
            "45+": 50.0,
        }
        age_numeric = age_map.get(str(age), 25.0)  # default when unknown
    except Exception as e:
        return f"Encoding Error: {str(e)}", "", "", "", ""

    # --- 2. INITIAL PREDICTION ---
    # Feature Order MUST match model-prep.py:
    # Embeddings + [duration, hour, followers, age_numeric, category_enc, gender_enc, day_enc]
    text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True)
    meta_vec = np.array([[duration, hour, followers, age_numeric,
                          cat_encoded, gender_encoded, day_encoded]])
    feat_vec = np.hstack((text_vec, meta_vec))
    # Model predicts log1p(views); invert with expm1.
    initial_log = MODEL.predict(feat_vec)[0]
    initial_views = int(np.expm1(initial_log))

    # --- 3. VECTOR SEARCH ---
    # Reference set: top-quartile videos by view count.
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    # Reuse text_vec: the original re-encoded the same input a second time.
    similarities = cosine_similarity(text_vec, target_embeddings)
    top_3_indices = similarities[0].argsort()[-3:][::-1]
    top_3_videos = high_perf_df.iloc[top_3_indices]['description'].tolist()
    similar_videos_str = "\n\n".join(f"{i+1}. {v}" for i, v in enumerate(top_3_videos))
    # Pad so the prompt below never indexes past the end when fewer than
    # 3 high-performing matches exist.
    while len(top_3_videos) < 3:
        top_3_videos.append("(no further example available)")

    # --- 4. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
    genai.configure(api_key=api_key)
    try:
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        # Fall back to an older model if the preferred one is unavailable.
        llm = genai.GenerativeModel('gemini-1.5-flash')

    prompt = f"""
You are a TikTok Virality Expert.
Draft: "{user_input}"
Niche: {category} | Creator: {age}, {gender} with {followers} followers.
Context: {duration}s video posted on {day_of_week} at {hour}:00.
Viral Examples in this niche:
1. {top_3_videos[0]}
2. {top_3_videos[1]}
3. {top_3_videos[2]}
Task:
Rewrite the draft to be more viral. Add hooks and hashtags two hashtags MAX.
Keep it natural and relevant to the creator persona.
Make it short catchy and viral.
Output Format:
[New Description]
Recomendations:
[Upload day of week]
[Upload hour]
[Upload duration]
"""
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()

        # --- 5. RE-SCORING ---
        new_text_vec = ST_MODEL.encode([improved_idea], convert_to_numpy=True)
        # Same metadata vector: only the text changed.
        new_feat_vec = np.hstack((new_text_vec, meta_vec))
        new_log = MODEL.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))

        # Guard against division by zero when the initial prediction is 0.
        if initial_views > 0:
            uplift_pct = ((new_views - initial_views) / initial_views) * 100
            uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
        else:
            uplift_str = "N/A"
        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"
# --- GRADIO UI ---
# Mojibake (mis-decoded emoji) in the user-facing labels has been replaced
# with clean text; the layout and wiring are otherwise unchanged.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Viral Content Optimizer")
    gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations based on 2025 trends.")
    with gr.Row():
        # Left column: all user inputs.
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3
            )
            with gr.Row():
                category = gr.Dropdown(
                    choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech", "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"],
                    value="Entertainment", label="Niche"
                )
                followers = gr.Number(value=1000, label="Follower Count", precision=0)
            with gr.Row():
                age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age")
                gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender")
            with gr.Row():
                duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)")
                hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour")
                day_dropdown = gr.Dropdown(
                    choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
                    value="Friday", label="Day"
                )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize", variant="primary")
                demo_btn = gr.Button("Try Demo", variant="secondary")
        # Right column: prediction and optimization outputs.
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            with gr.Group():
                gr.Markdown("### AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
            with gr.Accordion("Similar Viral Videos (Reference)", open=False):
                similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

    # Demo button: first fill the inputs with canned values, then run the
    # prediction with the same values (passed via gr.State so the chained
    # call does not depend on UI update timing).
    demo_btn.click(
        fn=lambda: ("My protein shake ended up on the floor", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male"),
        inputs=None,
        outputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender]
    ).then(
        fn=predict_and_optimize,
        inputs=[gr.State("My protein shake ended up on the floor"), gr.State(15), gr.State(19), gr.State("Monday"), gr.State("Fitness"), gr.State(50000), gr.State("18-24"), gr.State("Male")],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )
# Run initialization
# Run initialization and start the web server when executed as a script.
# (A stray trailing '|' token after the launch call was removed — it broke
# the file's syntax.)
if __name__ == "__main__":
    initialize_app()
    # 0.0.0.0:7860 is the standard binding for Hugging Face Spaces.
    demo.launch(server_name="0.0.0.0", server_port=7860)