Spaces:
Sleeping
Sleeping
Matan Kriel committed on
Commit Β·
33daa3c
1
Parent(s): 18699f2
updated the app.py
Browse files
app.py
CHANGED
|
@@ -5,88 +5,131 @@ import os
|
|
| 5 |
import google.generativeai as genai
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
from sklearn.metrics.pairwise import cosine_similarity
|
|
|
|
| 8 |
from dotenv import load_dotenv
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
from xgboost import XGBRegressor # Use Regressor as per model-prep
|
| 12 |
import pickle
|
| 13 |
-
from importlib.util import spec_from_file_location
|
| 14 |
-
import sys
|
| 15 |
-
# Since we are loading artifacts, we don't strictly need model-prep.py logic anymore.
|
| 16 |
-
# But keeping basic imports is fine.
|
| 17 |
|
| 18 |
# Load environment variables
|
| 19 |
load_dotenv()
|
| 20 |
|
| 21 |
# --- GLOBAL STATE ---
|
| 22 |
MODEL = None
|
| 23 |
-
VECTORIZER = None
|
| 24 |
KNOWLEDGE_DF = None
|
| 25 |
ST_MODEL = None
|
|
|
|
| 26 |
|
| 27 |
def initialize_app():
|
| 28 |
-
"""Initializes the model and data on app startup."""
|
| 29 |
-
global MODEL,
|
| 30 |
-
|
| 31 |
-
print("β³ initializing app: Loading
|
| 32 |
-
|
| 33 |
-
# 1. Load
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
print(
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
print("π Loading SentenceTransformer...")
|
| 55 |
-
|
| 56 |
-
|
|
|
|
| 57 |
import torch
|
| 58 |
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
| 59 |
-
st_model = SentenceTransformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
MODEL = model
|
| 62 |
-
VECTORIZER = tfidf
|
| 63 |
KNOWLEDGE_DF = knowledge_df
|
| 64 |
ST_MODEL = st_model
|
| 65 |
print("β
App initialized (Inference Mode)!")
|
| 66 |
|
| 67 |
-
def predict_and_optimize(user_input, duration, hour, day_of_week):
|
| 68 |
if not user_input:
|
| 69 |
return "Please enter a video description.", "", "", "", ""
|
| 70 |
|
| 71 |
-
# --- 1.
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
|
| 75 |
-
# Map Day of Week to is_weekend
|
| 76 |
-
weekend_days = ["Saturday", "Sunday"]
|
| 77 |
-
is_weekend = 1 if day_of_week in weekend_days else 0
|
| 78 |
|
| 79 |
-
# Construct metadata vector: [duration, hour, is_weekend, hashtag_count]
|
| 80 |
-
meta_vec = np.array([[duration, hour, is_weekend, user_input.count('#')]])
|
| 81 |
feat_vec = np.hstack((text_vec, meta_vec))
|
| 82 |
|
| 83 |
initial_log = MODEL.predict(feat_vec)[0]
|
| 84 |
initial_views = int(np.expm1(initial_log))
|
| 85 |
|
| 86 |
-
# ---
|
| 87 |
-
# Filter for viral hits in knowledge base (top 25%)
|
| 88 |
high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
|
| 89 |
-
|
| 90 |
user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
|
| 91 |
target_embeddings = np.stack(high_perf_df['embedding'].values)
|
| 92 |
|
|
@@ -96,13 +139,12 @@ def predict_and_optimize(user_input, duration, hour, day_of_week):
|
|
| 96 |
|
| 97 |
similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)])
|
| 98 |
|
| 99 |
-
# ---
|
| 100 |
api_key = os.getenv("GEMINI_API_KEY")
|
| 101 |
if not api_key:
|
| 102 |
return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
|
| 103 |
|
| 104 |
genai.configure(api_key=api_key)
|
| 105 |
-
# Using the updated model from the user's latest change
|
| 106 |
try:
|
| 107 |
llm = genai.GenerativeModel('gemini-2.5-flash-lite')
|
| 108 |
except:
|
|
@@ -111,33 +153,31 @@ def predict_and_optimize(user_input, duration, hour, day_of_week):
|
|
| 111 |
prompt = f"""
|
| 112 |
You are a TikTok Virality Expert.
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
|
|
|
| 116 |
|
| 117 |
-
|
| 118 |
1. {top_3_videos[0]}
|
| 119 |
2. {top_3_videos[1]}
|
| 120 |
3. {top_3_videos[2]}
|
| 121 |
|
| 122 |
Task:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
|
| 126 |
Output Format:
|
| 127 |
[New Description]
|
| 128 |
-
|
| 129 |
-
[Metadata Tip]
|
| 130 |
"""
|
| 131 |
|
| 132 |
try:
|
| 133 |
response = llm.generate_content(prompt)
|
| 134 |
improved_idea = response.text.strip()
|
| 135 |
|
| 136 |
-
# ---
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
new_feat_vec = np.hstack((new_text_vec, new_meta_vec))
|
| 141 |
|
| 142 |
new_log = MODEL.predict(new_feat_vec)[0]
|
| 143 |
new_views = int(np.expm1(new_log))
|
|
@@ -154,25 +194,34 @@ def predict_and_optimize(user_input, duration, hour, day_of_week):
|
|
| 154 |
# --- GRADIO UI ---
|
| 155 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 156 |
gr.Markdown("# π Viral Content Optimizer")
|
| 157 |
-
gr.Markdown("Enter your video idea to predict
|
| 158 |
|
| 159 |
with gr.Row():
|
| 160 |
with gr.Column(scale=1):
|
| 161 |
input_text = gr.Textbox(
|
| 162 |
-
label="
|
| 163 |
-
placeholder="
|
| 164 |
lines=3
|
| 165 |
)
|
| 166 |
|
| 167 |
with gr.Row():
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
choices=["
|
| 173 |
-
value="
|
| 174 |
-
|
| 175 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
| 177 |
with gr.Row():
|
| 178 |
submit_btn = gr.Button("Analyze & Optimize β‘", variant="primary")
|
|
@@ -195,23 +244,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
| 195 |
|
| 196 |
submit_btn.click(
|
| 197 |
fn=predict_and_optimize,
|
| 198 |
-
inputs=[input_text, duration_slider, hour_slider, day_dropdown],
|
| 199 |
outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
|
| 200 |
)
|
| 201 |
|
| 202 |
-
# Demo Button Logic
|
| 203 |
-
demo_text = "Meetings fails when You realize you forgot to turn off your mic"
|
| 204 |
-
demo_duration = 35
|
| 205 |
-
demo_hour = 19
|
| 206 |
-
demo_day = "Tuesday"
|
| 207 |
-
|
| 208 |
demo_btn.click(
|
| 209 |
-
fn=lambda: (
|
| 210 |
inputs=None,
|
| 211 |
-
outputs=[input_text, duration_slider, hour_slider, day_dropdown]
|
| 212 |
).then(
|
| 213 |
fn=predict_and_optimize,
|
| 214 |
-
inputs=[gr.State(
|
| 215 |
outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
|
| 216 |
)
|
| 217 |
|
|
|
|
| 5 |
import google.generativeai as genai
|
| 6 |
from sentence_transformers import SentenceTransformer
|
| 7 |
from sklearn.metrics.pairwise import cosine_similarity
|
| 8 |
+
from sklearn.preprocessing import LabelEncoder
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from huggingface_hub import hf_hub_download
|
|
|
|
| 12 |
import pickle
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Load environment variables
|
| 15 |
load_dotenv()
|
| 16 |
|
| 17 |
# --- GLOBAL STATE ---
|
| 18 |
MODEL = None
|
|
|
|
| 19 |
KNOWLEDGE_DF = None
|
| 20 |
ST_MODEL = None
|
| 21 |
+
ENCODERS = {} # To store label encoders
|
| 22 |
|
| 23 |
def initialize_app():
    """Initializes the model and data on app startup from Hugging Face.

    Populates the module-level globals MODEL, KNOWLEDGE_DF, ST_MODEL and
    ENCODERS. Raises RuntimeError when neither the remote dataset/model nor
    a local fallback can be loaded.
    """
    global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS

    print("β³ initializing app: Loading resources from Hugging Face Cloud...")

    # 1. Load the knowledge-base dataset from the Hugging Face Hub.
    print("π Downloading Dataset (MatanKriel/social-assitent-synthetic-data)...")
    try:
        dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
        # DatasetDict exposes splits; fall back to the object itself otherwise.
        if 'train' in dataset:
            knowledge_df = dataset['train'].to_pandas()
        else:
            knowledge_df = dataset.to_pandas()
        print(f" -> Loaded {len(knowledge_df)} rows.")
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from HF: {e}")

    # 2. Fit one LabelEncoder per categorical column so inference-time
    #    encoding matches the feature layout the model was trained with.
    print("π€ Fitting Label Encoders...")
    cat_cols = ['category', 'gender', 'day_of_week', 'age']
    for c in cat_cols:
        if c in knowledge_df.columns:
            le = LabelEncoder()
            # Cast to str so mixed/NaN values encode consistently.
            le.fit(knowledge_df[c].astype(str))
            ENCODERS[c] = le
            print(f" -> Encoded '{c}': {len(le.classes_)} classes")
        else:
            print(f" β οΈ Warning: Column '{c}' missing from dataset!")

    # 3. Download the pickled predictor, falling back to a local copy.
    print("π§ Downloading Model (MatanKriel/social-assitent-viral-predictor)...")
    try:
        model_path = hf_hub_download(repo_id="MatanKriel/social-assitent-viral-predictor", filename="viral_model.pkl")
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        print(f" -> Loaded model: {type(model).__name__}")
    except Exception as e:
        # Fallback to local
        if os.path.exists("viral_model.pkl"):
            print(f" β οΈ HF Download failed ({e}). Loading local 'viral_model.pkl' instead.")
            with open("viral_model.pkl", "rb") as f:
                model = pickle.load(f)
        else:
            raise RuntimeError(f"Failed to load model from HF and no local backup found: {e}")

    # 4. Load the sentence embedding model (Apple-Silicon GPU when available).
    print("π Loading SentenceTransformer...")
    embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
    print(f" -> Model: {embedding_model_name}")

    import torch
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    st_model = SentenceTransformer(embedding_model_name, device=device)

    # 5. Embed every description once at startup unless the dataset
    #    already ships a precomputed 'embedding' column.
    print("βοΈ Generating Embeddings for Knowledge Base (One-time setup)...")
    if 'embedding' not in knowledge_df.columns:
        embeddings = st_model.encode(knowledge_df['description'].fillna("").tolist(),
                                     convert_to_numpy=True,
                                     show_progress_bar=True)
        knowledge_df['embedding'] = list(embeddings)
    else:
        print(" -> Embeddings already present in dataset.")

    # Publish loaded artifacts to the module-level globals.
    MODEL = model
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("β… App initialized (Inference Mode)!")
|
| 93 |
|
| 94 |
+
def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender):
|
| 95 |
if not user_input:
|
| 96 |
return "Please enter a video description.", "", "", "", ""
|
| 97 |
|
| 98 |
+
# --- 1. ENCODE INPUTS ---
|
| 99 |
+
try:
|
| 100 |
+
# Helper to encode safely
|
| 101 |
+
def safe_encode(col, val):
|
| 102 |
+
le = ENCODERS.get(col)
|
| 103 |
+
if le:
|
| 104 |
+
# If value not seen, default to first class (mostly 'Unknown' or similar if handled, else 0)
|
| 105 |
+
if val in le.classes_:
|
| 106 |
+
return le.transform([val])[0]
|
| 107 |
+
else:
|
| 108 |
+
return 0 # Fallback
|
| 109 |
+
return 0
|
| 110 |
+
|
| 111 |
+
cat_encoded = safe_encode('category', category)
|
| 112 |
+
gender_encoded = safe_encode('gender', gender)
|
| 113 |
+
day_encoded = safe_encode('day_of_week', day_of_week)
|
| 114 |
+
age_encoded = safe_encode('age', age)
|
| 115 |
+
|
| 116 |
+
except Exception as e:
|
| 117 |
+
return f"Encoding Error: {str(e)}", "", "", "", ""
|
| 118 |
+
|
| 119 |
+
# --- 2. INITIAL PREDICTION ---
|
| 120 |
+
# Feature Order MUST match model-prep.py:
|
| 121 |
+
# Embeddings + [duration, hour, followers, category_enc, gender_enc, day_enc, age_enc]
|
| 122 |
+
text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True)
|
| 123 |
|
| 124 |
+
meta_vec = np.array([[duration, hour, followers, cat_encoded, gender_encoded, day_encoded, age_encoded]])
|
|
|
|
|
|
|
|
|
|
| 125 |
|
|
|
|
|
|
|
| 126 |
feat_vec = np.hstack((text_vec, meta_vec))
|
| 127 |
|
| 128 |
initial_log = MODEL.predict(feat_vec)[0]
|
| 129 |
initial_views = int(np.expm1(initial_log))
|
| 130 |
|
| 131 |
+
# --- 3. VECTOR SEARCH ---
|
|
|
|
| 132 |
high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
|
|
|
|
| 133 |
user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
|
| 134 |
target_embeddings = np.stack(high_perf_df['embedding'].values)
|
| 135 |
|
|
|
|
| 139 |
|
| 140 |
similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)])
|
| 141 |
|
| 142 |
+
# --- 4. GEMINI OPTIMIZATION ---
|
| 143 |
api_key = os.getenv("GEMINI_API_KEY")
|
| 144 |
if not api_key:
|
| 145 |
return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
|
| 146 |
|
| 147 |
genai.configure(api_key=api_key)
|
|
|
|
| 148 |
try:
|
| 149 |
llm = genai.GenerativeModel('gemini-2.5-flash-lite')
|
| 150 |
except:
|
|
|
|
| 153 |
prompt = f"""
|
| 154 |
You are a TikTok Virality Expert.
|
| 155 |
|
| 156 |
+
Draft: "{user_input}"
|
| 157 |
+
Niche: {category} | Creator: {age}, {gender} with {followers} followers.
|
| 158 |
+
Context: {duration}s video posted on {day_of_week} at {hour}:00.
|
| 159 |
|
| 160 |
+
Viral Examples in this niche:
|
| 161 |
1. {top_3_videos[0]}
|
| 162 |
2. {top_3_videos[1]}
|
| 163 |
3. {top_3_videos[2]}
|
| 164 |
|
| 165 |
Task:
|
| 166 |
+
Rewrite the draft to be more viral. Add hooks and hashtags.
|
| 167 |
+
Keep it natural and relevant to the creator persona.
|
| 168 |
|
| 169 |
Output Format:
|
| 170 |
[New Description]
|
|
|
|
|
|
|
| 171 |
"""
|
| 172 |
|
| 173 |
try:
|
| 174 |
response = llm.generate_content(prompt)
|
| 175 |
improved_idea = response.text.strip()
|
| 176 |
|
| 177 |
+
# --- 5. RE-SCORING ---
|
| 178 |
+
new_text_vec = ST_MODEL.encode([improved_idea], convert_to_numpy=True)
|
| 179 |
+
# Assuming metadata stays same for optimization check
|
| 180 |
+
new_feat_vec = np.hstack((new_text_vec, meta_vec))
|
|
|
|
| 181 |
|
| 182 |
new_log = MODEL.predict(new_feat_vec)[0]
|
| 183 |
new_views = int(np.expm1(new_log))
|
|
|
|
| 194 |
# --- GRADIO UI ---
|
| 195 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
| 196 |
gr.Markdown("# π Viral Content Optimizer")
|
| 197 |
+
gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations.")
|
| 198 |
|
| 199 |
with gr.Row():
|
| 200 |
with gr.Column(scale=1):
|
| 201 |
input_text = gr.Textbox(
|
| 202 |
+
label="Video Description",
|
| 203 |
+
placeholder="POV: When you forget to mute your mic...",
|
| 204 |
lines=3
|
| 205 |
)
|
| 206 |
|
| 207 |
with gr.Row():
|
| 208 |
+
category = gr.Dropdown(
|
| 209 |
+
choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech", "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"],
|
| 210 |
+
value="Entertainment", label="Niche"
|
| 211 |
+
)
|
| 212 |
+
followers = gr.Number(value=1000, label="Follower Count", precision=0)
|
| 213 |
|
| 214 |
+
with gr.Row():
|
| 215 |
+
age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age")
|
| 216 |
+
gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender")
|
| 217 |
+
|
| 218 |
+
with gr.Row():
|
| 219 |
+
duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)")
|
| 220 |
+
hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour")
|
| 221 |
+
day_dropdown = gr.Dropdown(
|
| 222 |
+
choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
|
| 223 |
+
value="Friday", label="Day"
|
| 224 |
+
)
|
| 225 |
|
| 226 |
with gr.Row():
|
| 227 |
submit_btn = gr.Button("Analyze & Optimize β‘", variant="primary")
|
|
|
|
| 244 |
|
| 245 |
submit_btn.click(
|
| 246 |
fn=predict_and_optimize,
|
| 247 |
+
inputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender],
|
| 248 |
outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
|
| 249 |
)
|
| 250 |
|
| 251 |
+
# Demo Button Logic
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
demo_btn.click(
|
| 253 |
+
fn=lambda: ("My protein shake recipe ended up on the floor π #gymfail", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male"),
|
| 254 |
inputs=None,
|
| 255 |
+
outputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender]
|
| 256 |
).then(
|
| 257 |
fn=predict_and_optimize,
|
| 258 |
+
inputs=[gr.State("My protein shake recipe ended up on the floor π #gymfail"), gr.State(15), gr.State(19), gr.State("Monday"), gr.State("Fitness"), gr.State(50000), gr.State("18-24"), gr.State("Male")],
|
| 259 |
outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
|
| 260 |
)
|
| 261 |
|