Matan Kriel committed on
Commit
33daa3c
Β·
1 Parent(s): 18699f2

updated the app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -84
app.py CHANGED
@@ -5,88 +5,131 @@ import os
5
  import google.generativeai as genai
6
  from sentence_transformers import SentenceTransformer
7
  from sklearn.metrics.pairwise import cosine_similarity
 
8
  from dotenv import load_dotenv
9
-
10
- # Import functions from model-prep
11
- from xgboost import XGBRegressor # Use Regressor as per model-prep
12
  import pickle
13
- from importlib.util import spec_from_file_location
14
- import sys
15
- # Since we are loading artifacts, we don't strictly need model-prep.py logic anymore.
16
- # But keeping basic imports is fine.
17
 
18
  # Load environment variables
19
  load_dotenv()
20
 
21
  # --- GLOBAL STATE ---
22
  MODEL = None
23
- VECTORIZER = None
24
  KNOWLEDGE_DF = None
25
  ST_MODEL = None
 
26
 
27
  def initialize_app():
28
- """Initializes the model and data on app startup."""
29
- global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL
30
-
31
- print("⏳ initializing app: Loading pre-computed artifacts...")
32
-
33
- # 1. Load Parquet Data (Knowledge Base)
34
- # We expect this file to exist now.
35
- parquet_path = 'tiktok_knowledge_base.parquet'
36
- if not os.path.exists(parquet_path):
37
- raise FileNotFoundError(f"Required file '{parquet_path}' not found! Run model-prep.py first.")
38
-
39
- print(f"πŸ“‚ Loading data from {parquet_path}...")
40
- knowledge_df = pd.read_parquet(parquet_path)
41
-
42
- # 2. Load Model
43
- print("🧠 Loading Prediction Model (Pickle)...")
44
- with open("viral_model.pkl", "rb") as f:
45
- model = pickle.load(f)
46
- print(f" -> Loaded model type: {type(model).__name__}")
47
-
48
- # 3. Load Vectorizer
49
- print("πŸ”€ Loading TF-IDF Vectorizer...")
50
- with open("tfidf_vectorizer.pkl", "rb") as f:
51
- tfidf = pickle.load(f)
52
-
53
- # 4. Load Sentence Transformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  print("πŸ”Œ Loading SentenceTransformer...")
55
- # device=model_prep.device might fail if we don't import model_prep executed.
56
- # Just use defaults or check pytorch standardly.
 
57
  import torch
58
  device = "mps" if torch.backends.mps.is_available() else "cpu"
59
- st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
 
 
 
 
 
 
 
 
 
 
60
 
61
  MODEL = model
62
- VECTORIZER = tfidf
63
  KNOWLEDGE_DF = knowledge_df
64
  ST_MODEL = st_model
65
  print("βœ… App initialized (Inference Mode)!")
66
 
67
- def predict_and_optimize(user_input, duration, hour, day_of_week):
68
  if not user_input:
69
  return "Please enter a video description.", "", "", "", ""
70
 
71
- # --- 1. INITIAL PREDICTION ---
72
- text_vec = VECTORIZER.transform([user_input]).toarray()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
- # Process metadata inputs
75
- # Map Day of Week to is_weekend
76
- weekend_days = ["Saturday", "Sunday"]
77
- is_weekend = 1 if day_of_week in weekend_days else 0
78
 
79
- # Construct metadata vector: [duration, hour, is_weekend, hashtag_count]
80
- meta_vec = np.array([[duration, hour, is_weekend, user_input.count('#')]])
81
  feat_vec = np.hstack((text_vec, meta_vec))
82
 
83
  initial_log = MODEL.predict(feat_vec)[0]
84
  initial_views = int(np.expm1(initial_log))
85
 
86
- # --- 2. VECTOR SEARCH ---
87
- # Filter for viral hits in knowledge base (top 25%)
88
  high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
89
-
90
  user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
91
  target_embeddings = np.stack(high_perf_df['embedding'].values)
92
 
@@ -96,13 +139,12 @@ def predict_and_optimize(user_input, duration, hour, day_of_week):
96
 
97
  similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)])
98
 
99
- # --- 3. GEMINI OPTIMIZATION ---
100
  api_key = os.getenv("GEMINI_API_KEY")
101
  if not api_key:
102
  return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
103
 
104
  genai.configure(api_key=api_key)
105
- # Using the updated model from the user's latest change
106
  try:
107
  llm = genai.GenerativeModel('gemini-2.5-flash-lite')
108
  except:
@@ -111,33 +153,31 @@ def predict_and_optimize(user_input, duration, hour, day_of_week):
111
  prompt = f"""
112
  You are a TikTok Virality Expert.
113
 
114
- My Draft Description: "{user_input}"
115
- My Metadata Plan: Duration={duration}s, Upload Time={hour}:00, Day={day_of_week}.
 
116
 
117
- Here are 3 successful, viral videos that are similar to my topic:
118
  1. {top_3_videos[0]}
119
  2. {top_3_videos[1]}
120
  3. {top_3_videos[2]}
121
 
122
  Task:
123
- 1. Rewrite my draft description to make it go viral.
124
- 2. Analyze my metadata (Duration, Time, Day) and provide a specific recommendation to improve it (e.g. "Try posting at 18:00 instead" or "Make it shorter").
125
 
126
  Output Format:
127
  [New Description]
128
-
129
- [Metadata Tip]
130
  """
131
 
132
  try:
133
  response = llm.generate_content(prompt)
134
  improved_idea = response.text.strip()
135
 
136
- # --- 4. RE-SCORING ---
137
- # Use same metadata for the new prediction (assuming user doesn't change time/duration)
138
- new_text_vec = VECTORIZER.transform([improved_idea]).toarray()
139
- new_meta_vec = np.array([[duration, hour, is_weekend, improved_idea.count('#')]])
140
- new_feat_vec = np.hstack((new_text_vec, new_meta_vec))
141
 
142
  new_log = MODEL.predict(new_feat_vec)[0]
143
  new_views = int(np.expm1(new_log))
@@ -154,25 +194,34 @@ def predict_and_optimize(user_input, duration, hour, day_of_week):
154
  # --- GRADIO UI ---
155
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
156
  gr.Markdown("# πŸš€ Viral Content Optimizer")
157
- gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")
158
 
159
  with gr.Row():
160
  with gr.Column(scale=1):
161
  input_text = gr.Textbox(
162
- label="Your Video Description",
163
- placeholder="e.g., POV: trying the new grimace shake #viral",
164
  lines=3
165
  )
166
 
167
  with gr.Row():
168
- duration_slider = gr.Slider(minimum=5, maximum=180, value=15, step=1, label="Duration (seconds)")
169
- hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour (24h)")
 
 
 
170
 
171
- day_dropdown = gr.Dropdown(
172
- choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
173
- value="Friday",
174
- label="Day of Week"
175
- )
 
 
 
 
 
 
176
 
177
  with gr.Row():
178
  submit_btn = gr.Button("Analyze & Optimize ⚑", variant="primary")
@@ -195,23 +244,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
195
 
196
  submit_btn.click(
197
  fn=predict_and_optimize,
198
- inputs=[input_text, duration_slider, hour_slider, day_dropdown],
199
  outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
200
  )
201
 
202
- # Demo Button Logic: 1. Fill Form -> 2. Run Prediction
203
- demo_text = "Meetings fails when You realize you forgot to turn off your mic"
204
- demo_duration = 35
205
- demo_hour = 19
206
- demo_day = "Tuesday"
207
-
208
  demo_btn.click(
209
- fn=lambda: (demo_text, demo_duration, demo_hour, demo_day),
210
  inputs=None,
211
- outputs=[input_text, duration_slider, hour_slider, day_dropdown]
212
  ).then(
213
  fn=predict_and_optimize,
214
- inputs=[gr.State(demo_text), gr.State(demo_duration), gr.State(demo_hour), gr.State(demo_day)],
215
  outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
216
  )
217
 
 
5
  import google.generativeai as genai
6
  from sentence_transformers import SentenceTransformer
7
  from sklearn.metrics.pairwise import cosine_similarity
8
+ from sklearn.preprocessing import LabelEncoder
9
  from dotenv import load_dotenv
10
+ from datasets import load_dataset
11
+ from huggingface_hub import hf_hub_download
 
12
  import pickle
 
 
 
 
13
 
14
  # Load environment variables
15
  load_dotenv()
16
 
17
  # --- GLOBAL STATE ---
18
  MODEL = None
 
19
  KNOWLEDGE_DF = None
20
  ST_MODEL = None
21
+ ENCODERS = {} # To store label encoders
22
 
23
def initialize_app():
    """Initialize the model, data, encoders and embedder on app startup.

    Loads the knowledge-base dataset and the prediction model from the
    Hugging Face Hub (with a local-file fallback for the model), fits
    label encoders on the dataset's categorical columns, loads a
    SentenceTransformer, and pre-computes description embeddings when the
    dataset does not already ship an 'embedding' column. Results are
    published through the module globals MODEL, KNOWLEDGE_DF, ST_MODEL
    and ENCODERS.

    Raises:
        RuntimeError: if the dataset cannot be downloaded, or if the
            model is unavailable both on the Hub and locally.
    """
    global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS

    print("⏳ initializing app: Loading resources from Hugging Face Cloud...")

    # 1. Load Dataset from HF
    print("📂 Downloading Dataset (MatanKriel/social-assitent-synthetic-data)...")
    try:
        dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
        # A DatasetDict exposes named splits; a plain Dataset converts directly.
        if 'train' in dataset:
            knowledge_df = dataset['train'].to_pandas()
        else:
            knowledge_df = dataset.to_pandas()
        print(f" -> Loaded {len(knowledge_df)} rows.")
    except Exception as e:
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(f"Failed to load dataset from HF: {e}") from e

    # 2. FIT ENCODERS (Crucial for Feature Consistency)
    print("🔤 Fitting Label Encoders...")
    cat_cols = ['category', 'gender', 'day_of_week', 'age']
    for c in cat_cols:
        if c in knowledge_df.columns:
            le = LabelEncoder()
            # Cast to str so mixed-type / NaN values encode consistently.
            le.fit(knowledge_df[c].astype(str))
            ENCODERS[c] = le
            print(f" -> Encoded '{c}': {len(le.classes_)} classes")
        else:
            print(f" ⚠️ Warning: Column '{c}' missing from dataset!")

    # 3. Load Model from HF
    print("🧠 Downloading Model (MatanKriel/social-assitent-viral-predictor)...")
    try:
        model_path = hf_hub_download(repo_id="MatanKriel/social-assitent-viral-predictor", filename="viral_model.pkl")
        # NOTE(security): pickle.load executes arbitrary code on load; only
        # deserialize artifacts from repos you control.
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        print(f" -> Loaded model: {type(model).__name__}")
    except Exception as e:
        # Fallback to local copy if the Hub is unreachable.
        if os.path.exists("viral_model.pkl"):
            print(f" ⚠️ HF Download failed ({e}). Loading local 'viral_model.pkl' instead.")
            with open("viral_model.pkl", "rb") as f:
                model = pickle.load(f)
        else:
            raise RuntimeError(f"Failed to load model from HF and no local backup found: {e}") from e

    # 4. Load SentenceTransformer
    print("🔌 Loading SentenceTransformer...")
    embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
    print(f" -> Model: {embedding_model_name}")

    import torch
    # Prefer Apple-silicon GPU when available; otherwise stay on CPU.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    st_model = SentenceTransformer(embedding_model_name, device=device)

    # 5. Generate Embeddings (On-the-fly)
    print("⚙️ Generating Embeddings for Knowledge Base (One-time setup)...")
    if 'embedding' not in knowledge_df.columns:
        embeddings = st_model.encode(knowledge_df['description'].fillna("").tolist(),
                                     convert_to_numpy=True,
                                     show_progress_bar=True)
        # Store one vector per row so downstream code can np.stack() them.
        knowledge_df['embedding'] = list(embeddings)
    else:
        print(" -> Embeddings already present in dataset.")

    MODEL = model
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("✅ App initialized (Inference Mode)!")
93
 
94
+ def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender):
95
  if not user_input:
96
  return "Please enter a video description.", "", "", "", ""
97
 
98
+ # --- 1. ENCODE INPUTS ---
99
+ try:
100
+ # Helper to encode safely
101
+ def safe_encode(col, val):
102
+ le = ENCODERS.get(col)
103
+ if le:
104
+ # If value not seen, default to first class (mostly 'Unknown' or similar if handled, else 0)
105
+ if val in le.classes_:
106
+ return le.transform([val])[0]
107
+ else:
108
+ return 0 # Fallback
109
+ return 0
110
+
111
+ cat_encoded = safe_encode('category', category)
112
+ gender_encoded = safe_encode('gender', gender)
113
+ day_encoded = safe_encode('day_of_week', day_of_week)
114
+ age_encoded = safe_encode('age', age)
115
+
116
+ except Exception as e:
117
+ return f"Encoding Error: {str(e)}", "", "", "", ""
118
+
119
+ # --- 2. INITIAL PREDICTION ---
120
+ # Feature Order MUST match model-prep.py:
121
+ # Embeddings + [duration, hour, followers, category_enc, gender_enc, day_enc, age_enc]
122
+ text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True)
123
 
124
+ meta_vec = np.array([[duration, hour, followers, cat_encoded, gender_encoded, day_encoded, age_encoded]])
 
 
 
125
 
 
 
126
  feat_vec = np.hstack((text_vec, meta_vec))
127
 
128
  initial_log = MODEL.predict(feat_vec)[0]
129
  initial_views = int(np.expm1(initial_log))
130
 
131
+ # --- 3. VECTOR SEARCH ---
 
132
  high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
 
133
  user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
134
  target_embeddings = np.stack(high_perf_df['embedding'].values)
135
 
 
139
 
140
  similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)])
141
 
142
+ # --- 4. GEMINI OPTIMIZATION ---
143
  api_key = os.getenv("GEMINI_API_KEY")
144
  if not api_key:
145
  return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
146
 
147
  genai.configure(api_key=api_key)
 
148
  try:
149
  llm = genai.GenerativeModel('gemini-2.5-flash-lite')
150
  except:
 
153
  prompt = f"""
154
  You are a TikTok Virality Expert.
155
 
156
+ Draft: "{user_input}"
157
+ Niche: {category} | Creator: {age}, {gender} with {followers} followers.
158
+ Context: {duration}s video posted on {day_of_week} at {hour}:00.
159
 
160
+ Viral Examples in this niche:
161
  1. {top_3_videos[0]}
162
  2. {top_3_videos[1]}
163
  3. {top_3_videos[2]}
164
 
165
  Task:
166
+ Rewrite the draft to be more viral. Add hooks and hashtags.
167
+ Keep it natural and relevant to the creator persona.
168
 
169
  Output Format:
170
  [New Description]
 
 
171
  """
172
 
173
  try:
174
  response = llm.generate_content(prompt)
175
  improved_idea = response.text.strip()
176
 
177
+ # --- 5. RE-SCORING ---
178
+ new_text_vec = ST_MODEL.encode([improved_idea], convert_to_numpy=True)
179
+ # Assuming metadata stays same for optimization check
180
+ new_feat_vec = np.hstack((new_text_vec, meta_vec))
 
181
 
182
  new_log = MODEL.predict(new_feat_vec)[0]
183
  new_views = int(np.expm1(new_log))
 
194
  # --- GRADIO UI ---
195
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
196
  gr.Markdown("# πŸš€ Viral Content Optimizer")
197
+ gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations.")
198
 
199
  with gr.Row():
200
  with gr.Column(scale=1):
201
  input_text = gr.Textbox(
202
+ label="Video Description",
203
+ placeholder="POV: When you forget to mute your mic...",
204
  lines=3
205
  )
206
 
207
  with gr.Row():
208
+ category = gr.Dropdown(
209
+ choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech", "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"],
210
+ value="Entertainment", label="Niche"
211
+ )
212
+ followers = gr.Number(value=1000, label="Follower Count", precision=0)
213
 
214
+ with gr.Row():
215
+ age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age")
216
+ gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender")
217
+
218
+ with gr.Row():
219
+ duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)")
220
+ hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour")
221
+ day_dropdown = gr.Dropdown(
222
+ choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"],
223
+ value="Friday", label="Day"
224
+ )
225
 
226
  with gr.Row():
227
  submit_btn = gr.Button("Analyze & Optimize ⚑", variant="primary")
 
244
 
245
  submit_btn.click(
246
  fn=predict_and_optimize,
247
+ inputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender],
248
  outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
249
  )
250
 
251
+ # Demo Button Logic
 
 
 
 
 
252
  demo_btn.click(
253
+ fn=lambda: ("My protein shake recipe ended up on the floor 😭 #gymfail", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male"),
254
  inputs=None,
255
+ outputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender]
256
  ).then(
257
  fn=predict_and_optimize,
258
+ inputs=[gr.State("My protein shake recipe ended up on the floor 😭 #gymfail"), gr.State(15), gr.State(19), gr.State("Monday"), gr.State("Fitness"), gr.State(50000), gr.State("18-24"), gr.State("Male")],
259
  outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
260
  )
261