Matan Kriel committed on
Commit
0c8eab5
·
1 Parent(s): a840261

updated files full data run

README.md CHANGED
@@ -11,77 +11,81 @@ pinned: false
11
 
12
  # 🚀 Social Media Virality Assistant
13
 
14
- A Data Science project that uses **Large Language Models (LLMs)** and **Machine Learning** to predict and optimize social media content virality.
15
 
16
- ## 🌟 Project Overview
17
- This tool helps content creators go viral by:
18
- 1. **Predicting Views**: Analyzing video descriptions to forecast performance.
19
- 2. **Optimizing Content**: Using **Google Gemini AI** to rewrite drafts with viral hooks (slang, hashtags).
20
- 3. **Learning from History**: Retrieving similar successful videos using **Semantic Search**.
21
 
22
- ## 🧠 Data Science Methodology
23
 
24
- ### 1. Synthetic Data Generation (LLM-Based)
25
- Since real-world TikTok data is private, we simulated a "Viral Environment":
26
- * **Generator**: Utilized `tiiuae/falcon-rw-1b` (via `transformers`) to generate **10,000 realistic video descriptions**.
27
- * **Diversity**: Prompted the LLM with various scenarios ("POV", "GRWM", "Storytime") to ensure distinct content clusters.
28
- * **Ground Truth Logic**: Developed a scoring function that assigns "Views" based on linguistic patterns (e.g., questions, emotional triggers) and metadata (time of day, duration), creating a learnable signal for the ML models.
29
 
30
- ### 2. Model Development & Comparison
31
- We treated this as a **Regression Problem** (Predicting Log-Views).
32
- We compared three algorithms to find the best predictor:
33
- * **Linear Regression**: Baseline model.
34
- * **Random Forest**: Good for non-linear interactions.
35
- * **XGBoost (Winner)**: Gradient boosting provided the best accuracy (Lowest RMSE).
36
 
37
- **Metrics Used:**
38
- * **RMSE (Root Mean Squared Error)**: Primary metric for model selection.
39
- * **MAE (Mean Absolute Error)**: Average view count error.
40
- * **MAPE**: Average percentage error.
41
 
42
- ### 3. Advanced Analysis (Plots)
43
 
44
- #### Semantic Diversity (PCA)
45
- ![Diversity Plot](project_plots/diversity_plot.png)
46
- *A PCA visualization showing the semantic spread of the 10,000 generated descriptions.*
47
 
48
- #### Model Performance
49
- ![Model Comparison](project_plots/model_comparison.png)
50
- *Bar chart comparing RMSE across models and ROC curves for viral classification validity.*
51
 
52
- #### Feature Importance
53
- ![Feature Importance](project_plots/feature_importance.png)
54
- *The top 20 words and metadata features that drive virality in our simulated world.*
55
 
56
  ## 🛠️ Tech Stack
57
- * **Core**: Python, Pandas, Numpy, Scikit-Learn
58
- * **AI/LLM**: `transformers` (Falcon-1B), `google-genai` (Gemini 2.5)
59
- * **ML**: XGBoost, Sentence-Transformers (Embeddings)
60
- * **App**: Gradio (Web UI)
61
- * **Hardware**: Optimized for Apple Silicon (MPS).
62
-
63
- ## 📂 Project Structure
64
- ```bash
65
- β”œβ”€β”€ app.py # Inference App (Gradio)
66
- β”œβ”€β”€ model-prep.py # Training Pipeline (Data Gen -> Train -> Save)
67
- β”œβ”€β”€ requirements.txt # Dependencies
68
- β”œβ”€β”€ tiktok_knowledge_base.parquet # Semantic Search Index
69
- β”œβ”€β”€ viral_model.pkl # Trained ML Model (Pickle)
70
- β”œβ”€β”€ tfidf_vectorizer.pkl # Text Processor
71
- └── project_plots/ # Generated Analysis Plots
72
- ```
73
 
74
  ## 🚀 How to Run
 
75
  1. **Install Dependencies**:
76
  ```bash
77
  pip install -r requirements.txt
78
  ```
79
- 2. **Train & Generate Data** (downloads a 2.6 GB model):
80
  ```bash
81
- python model-prep.py
82
  ```
83
- 3. **Run the App**:
84
  ```bash
85
- export GEMINI_API_KEY="your_key_here"
86
- python app.py
87
  ```
 
11
 
12
  # 🚀 Social Media Virality Assistant
13
 
14
+ A machine-learning-powered tool that helps content creators predict and optimize their videos' virality potential using **XGBoost** and **Google Gemini AI**.
15
 
16
+ ## πŸ—οΈ Architecture & Pipeline
 
 
 
 
17
 
18
+ This project consists of two main components: a training pipeline (`model-prep.py`) and an inference application (`app.py`).
19
 
20
+ ### 1. Training Pipeline (`model-prep.py`)
21
+ The `model-prep.py` script handles the end-to-end model creation process:
22
 
23
+ 1. **Cloud Data Loading**: It fetches the latest synthetic dataset directly from **Hugging Face** (`MatanKriel/social-assitent-synthetic-data`).
24
+ 2. **Embedding Benchmark**: It evaluates three state-of-the-art embedding models (`MiniLM`, `mpnet-base`, `bge-small`) to find the best text encoder.
25
+ * *Metrics*: Encoding Speed vs. Clustering Quality (Silhouette Score).
26
+ * *Winner*: Defaults to `sentence-transformers/all-mpnet-base-v2`.
27
+ 3. **Feature Engineering**:
28
+ * Encodes categorical inputs: `category`, `gender`, `day_of_week`, `age`.
29
+ * Combines text embeddings with metadata (`followers`, `duration`, `hour`).
30
+ 4. **Model Training**: Trains and compares three regression algorithms:
31
+ * Linear Regression
32
+ * Random Forest
33
+ * **XGBoost (Winner)**: Selected for having the lowest RMSE.
34
+ 5. **Artifact Generation**: Saves the trained model locally (`viral_model.pkl`) and generates performance plots (`project_plots/`).
35
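The feature-engineering step described above (label-encoded categoricals stacked next to text embeddings) can be sketched as follows. This is a minimal illustration, not code from `model-prep.py`; the tiny random matrix merely stands in for real sentence-transformer embeddings:

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Toy rows standing in for the real dataset
df = pd.DataFrame({
    "category": ["Comedy", "Beauty", "Comedy"],
    "followers": [1200, 50000, 800],
    "duration": [15, 60, 30],
})

# Stand-in for sentence-transformer embeddings (mpnet's real ones are 768-dim)
X_text = np.random.rand(len(df), 4)

# Label-encode the categorical column, then stack it with numeric metadata
df["category_encoded"] = LabelEncoder().fit_transform(df["category"].astype(str))
X_meta = df[["followers", "duration", "category_encoded"]].values
X = np.hstack((X_text, X_meta))
print(X.shape)  # (3, 7): 4 embedding dims + 3 metadata columns
```

Tree-based models such as XGBoost can consume the label-encoded integers directly; one-hot encoding would be the safer alternative for the linear baseline.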
 
36
+ ### 2. Inference Application (`app.py`)
37
+ The `app.py` script runs a **Gradio** web interface that pulls artifacts from the cloud at startup:
38
 
39
+ 1. **Initialization**:
40
+ * Downloads the trained `viral_model.pkl` from Hugging Face (`MatanKriel/social-assitent-viral-predictor`).
41
+ * Downloads the dataset to build a Knowledge Base.
42
+ * Generates embeddings on-the-fly for the Knowledge Base.
43
+ 2. **Core Features**:
44
+ * **Virality Prediction**: Predicts raw view counts based on your draft description and stats.
45
+ * **AI Optimization**: Uses **Google Gemini** to rewrite your description with viral hooks and hashtags.
46
+ * **Semantic Search**: Finds similar successful videos from the knowledge base using Cosine Similarity.
47
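The semantic-search feature boils down to a cosine-similarity lookup over the stored embedding matrix. A minimal sketch with toy 2-D vectors in place of real embeddings (`top_k_similar` is an illustrative helper, not a function from `app.py`):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def top_k_similar(query_vec, kb_vecs, k=3):
    """Indices of the k knowledge-base vectors most similar to the query."""
    sims = cosine_similarity(query_vec.reshape(1, -1), kb_vecs)[0]
    return np.argsort(sims)[-k:][::-1]  # highest similarity first

kb = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1]])  # stored embeddings
query = np.array([1.0, 0.05])                        # encoded user draft
print(top_k_similar(query, kb, k=2))  # [0 2]
```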
 
48
+ ---
49
+
50
+ ## 📊 Model Performance
51
 
52
+ The training script (`model-prep.py`) automatically generates these benchmarks:
53
 
54
+ ### Embedding Model Comparison
55
+ We selected the embedding model that best balances speed and semantic understanding.
56
+ ![Embedding Benchmark](project_plots/embedding_benchmark.png)
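For intuition, the Silhouette Score used in this benchmark rewards embeddings whose points sit tightly within their own category and far from other categories. A toy example with synthetic points (not the project's real scores):

```python
import numpy as np
from sklearn.metrics import silhouette_score

# Two well-separated toy "content clusters" with their category labels
emb = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
labels = np.array(["POV", "POV", "GRWM", "GRWM"])

# Close to 1.0 = tight, well-separated clusters; near 0 = overlapping
print(round(silhouette_score(emb, labels), 3))
```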
57
+
58
+ ### Regression Model Comparison
59
+ We chose the regressor with the lowest error (RMSE) and highest explained variance (RΒ²).
60
+ ![Model Comparison](project_plots/regression_comparison.png)
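The two selection criteria can be computed with scikit-learn as below (toy log-view values, not the project's actual results):

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

y_true = np.array([8.0, 10.0, 12.0, 9.0])  # actual log-views
y_pred = np.array([8.5, 9.5, 11.5, 9.0])   # a model's predictions

rmse = np.sqrt(mean_squared_error(y_true, y_pred))  # penalizes large misses
r2 = r2_score(y_true, y_pred)                       # explained variance
print(f"RMSE={rmse:.3f}, R2={r2:.3f}")  # RMSE=0.433, R2=0.914
```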
61
+
62
+ ---
63
 
64
  ## 🛠️ Tech Stack
65
+ This project is built using:
66
+ * **App**: `gradio`, `google-generativeai`
67
+ * **ML**: `xgboost`, `scikit-learn`, `sentence-transformers`
68
+ * **Data**: `pandas`, `numpy`
69
+ * **Cloud**: `huggingface_hub`, `datasets`
70
+
71
+ ---
72
 
73
  ## 🚀 How to Run
74
+
75
  1. **Install Dependencies**:
76
  ```bash
77
  pip install -r requirements.txt
78
  ```
79
+
80
+ 2. **Run the App**:
81
+ The app will automatically download the necessary data and models from Hugging Face.
82
  ```bash
83
+ export GEMINI_API_KEY="your_api_key_here"
84
+ python app.py
85
  ```
86
+
87
+ 3. **(Optional) Retrain the Model**:
88
+ If you want to re-run the benchmarks and training using the latest data:
89
  ```bash
90
+ python model-prep.py
91
  ```
model-prep.py CHANGED
@@ -5,510 +5,286 @@ import seaborn as sns
5
  import warnings
6
  import os
7
  import torch
8
- from transformers import pipeline
9
- import google.generativeai as genai
10
- from faker import Faker
11
- from datetime import datetime, timedelta
12
- from sklearn.metrics.pairwise import cosine_similarity
13
  import pickle
14
- from dotenv import load_dotenv
15
-
16
- # Load environment variables from the .env filea monk
17
- load_dotenv()
18
-
19
- # Machine Learning Imports
20
- from sklearn.feature_extraction.text import TfidfVectorizer
21
  from sklearn.ensemble import RandomForestRegressor
22
  from xgboost import XGBRegressor
23
  from sklearn.linear_model import LinearRegression
24
- from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score
25
- from sklearn.decomposition import PCA
26
- from sentence_transformers import SentenceTransformer
27
 
28
- # ---------------------------------------------------------
29
- # 0. SETUP & CONFIGURATION
30
- # ---------------------------------------------------------
 
31
  warnings.filterwarnings('ignore')
32
  pd.set_option('display.max_columns', None)
33
-
34
- # OPTIMIZATION: Check for Apple Silicon (MPS)
35
  device = "mps" if torch.backends.mps.is_available() else "cpu"
 
36
  print(f"🚀 Optimization: Running on {device.upper()} device")
37
 
38
  if not os.path.exists('project_plots'):
39
  os.makedirs('project_plots')
40
 
 
41
  # ---------------------------------------------------------
42
- # 1. DATA GENERATION (With 2025 Trends)
43
- # ---------------------------------------------------------
44
- # ---------------------------------------------------------
45
- # 1. DATA GENERATION (With LLM - Falcon-RW-1B)
46
  # ---------------------------------------------------------
47
- def generate_synthetic_data_llm(n_rows=10000):
48
- print(f"\n[1/8] Generating {n_rows} rows of Real-World Data using LLM (Falcon-RW-1B)...")
49
-
50
- # Setup Pipeline
51
- print(" -> Loading Falcon model... (This may take a moment)")
52
-
53
- # MPS Optimization Logic
54
- # 'device' variable is already set globally (cpu or mps)
55
- # Pipelines usually take device=0 for GPU, or device="mps"
56
-
57
- pipeline_kwargs = {
58
- "task": "text-generation",
59
- "model": "tiiuae/falcon-rw-1b",
60
- "device": device # "mps" or "cpu"
61
- }
62
-
63
- # Optimizations for Apple Silicon
64
- if device == "mps":
65
- print(" -> 🍎 Optimization: Using Apple Silicon (MPS) with float16")
66
- pipeline_kwargs["torch_dtype"] = torch.float16
67
- elif device == "cuda":
68
- pipeline_kwargs["device"] = 0 # Transformers often prefers int for CUDA
69
- pipeline_kwargs["torch_dtype"] = torch.float16
70
-
71
  try:
72
- generator = pipeline(**pipeline_kwargs)
73
  except Exception as e:
74
- print(f" -> Error loading model: {e}")
75
- return pd.DataFrame(), 0
76
 
77
- print(f" -> ✅ Model Loaded on {device.upper()}")
78
 
79
- # Diversity Prompts
80
- prompts = [
81
- "TikTok Description: POV you realize",
82
- "TikTok Description: GRWM for",
83
- "TikTok Description: Day in the life of",
84
- "TikTok Description: Trying the viral",
85
- "TikTok Description: Storytime about",
86
- "TikTok Description: ASMR",
87
- "TikTok Description: My skincare routine",
88
- "TikTok Description: Cooking a healthy",
89
- "TikTok Description: Coding a new",
90
- "TikTok Description: Travel vlog to"
91
  ]
92
 
93
- data = []
94
- fake = Faker()
95
- start_date = datetime(2024, 1, 1)
96
-
97
- # We generate in batches to manage memory/speed better or just loop
98
- # Given n_rows is large, a progress bar or simple print every N is good.
99
-
100
- print(f" -> Starting generation of {n_rows} items...")
101
 
102
- # To speed up, we can ask for multiple sequences per prompt,
103
- # but we need total n_rows.
104
-
105
- rows_generated = 0
106
- batch_size = 5 # Generate 5 variations per prompt call
107
-
108
- while rows_generated < n_rows:
109
- prompt = np.random.choice(prompts)
110
-
111
  try:
112
- outputs = generator(
113
- prompt,
114
- max_new_tokens=40,
115
- num_return_sequences=batch_size,
116
- do_sample=True,
117
- temperature=0.9,
118
- top_k=50,
119
- top_p=0.95,
120
- truncation=True,
121
- pad_token_id=50256 # Falcon-RW default pad token usually
122
- )
123
 
124
- for o in outputs:
125
- if rows_generated >= n_rows: break
126
-
127
- raw_text = o['generated_text']
128
- # Clean up: remove the prompt prefix if desired, or keep it.
129
- # Usually we want the full description.
130
- # Let's clean newlines.
131
- clean_text = raw_text.replace("\n", " ").strip()
132
-
133
- # Add some synthetic tags if missing (LLM might not add enough)
134
- if "#" not in clean_text:
135
- clean_text += " #fyp #viral #trending"
136
-
137
- # --- SOPHISTICATED VIEW COUNT LOGIC ---
138
- # We inject "ground truth" rules so the model can learn real patterns.
139
-
140
- # Base distribution
141
- base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
142
- multiplier = 1.0
143
-
144
- # 1. Linguistic Patterns (The "Text" Signal)
145
- full_lower = clean_text.lower()
146
-
147
- # Boost for "Hooks" (Questions, direct address)
148
- if "?" in clean_text: multiplier *= 1.2
149
- if "you" in full_lower or "pov" in full_lower: multiplier *= 1.4
150
-
151
- # Boost for Emotional/Urgent words
152
- viral_triggers = ['secret', 'hack', 'wait for it', 'won\'t believe', 'shocking', 'obsessed']
153
- if any(w in full_lower for w in viral_triggers): multiplier *= 1.3
154
-
155
- # Boost for Niche Keywords (Targeting specific audiences)
156
- niche_keywords = ['coding', 'recipe', 'tutorial', 'routine', 'haul']
157
- if any(w in full_lower for w in niche_keywords): multiplier *= 1.2
158
-
159
- # 2. Metadata Signals
160
- upload_time = start_date + timedelta(days=np.random.randint(0, 365), hours=np.random.randint(0, 23))
161
- duration = np.random.randint(5, 180)
162
- hour = upload_time.hour
163
- is_weekend = 1 if upload_time.weekday() >= 5 else 0
164
-
165
- if is_weekend: multiplier *= 1.25 # Weekends are slightly better
166
- if duration < 15: multiplier *= 1.3 # Short content is king
167
- if hour >= 17 and hour <= 21: multiplier *= 1.15 # Prime time boost
168
-
169
- # Calculate Final Views
170
- views = int(base_virality * multiplier)
171
-
172
- data.append({
173
- 'upload_date': upload_time,
174
- 'description': clean_text,
175
- 'category': 'General',
176
- 'video_duration_sec': duration,
177
- 'hour_of_day': hour,
178
- 'is_weekend': is_weekend,
179
- 'hashtag_count': clean_text.count('#'),
180
- 'views': views
181
- })
182
- rows_generated += 1
183
-
184
- # Print one example per batch for quality control
185
- if len(outputs) > 0:
186
- print(f" 👀 Sample: {data[-1]['description'][:100]}...")
187
-
188
- if rows_generated % 100 == 0:
189
- print(f" -> Generated {rows_generated}/{n_rows} rows...")
190
 
191
  except Exception as e:
192
- print(f" ⚠️ Generation Error: {e}")
193
- break
194
 
195
- df = pd.DataFrame(data)
196
 
197
- # --- SAVE RAW DATA ---
198
- raw_save_path = 'raw_social_media_data.parquet'
199
- df.to_parquet(raw_save_path)
200
- print(f" -> 💾 Raw Data Saved to {raw_save_path}")
201
 
202
- # Process for training (Targets)
203
- df = df.sort_values('upload_date').reset_index(drop=True)
204
- threshold = df['views'].quantile(0.80)
205
- df['is_viral_binary'] = (df['views'] > threshold).astype(int)
206
- df['log_views'] = np.log1p(df['views'])
207
 
208
- return df, threshold
209
 
210
  # ---------------------------------------------------------
211
- # 2. EDA & PREPROCESSING
212
  # ---------------------------------------------------------
213
- def process_data_pipeline(df):
214
- print("\n[2/8] Processing Data Pipeline...")
215
 
216
- # Simple EDA Save
217
- clean_df = df[df['video_duration_sec'] > 0].copy()
218
- plt.figure(figsize=(6,4))
219
- sns.histplot(clean_df['log_views'], color='teal')
220
- plt.title('Log Views Distribution')
221
- plt.savefig('project_plots/eda_distribution.png')
222
- plt.close()
223
 
224
- # TF-IDF & Split
225
- tfidf = TfidfVectorizer(max_features=2000, stop_words='english')
226
- X_text = tfidf.fit_transform(df['description']).toarray()
 
227
 
228
- # --- NEW: Data Diversity Plot (PCA) ---
229
- print(" -> 🎨 Generating Diversity Plot...")
230
- from sklearn.decomposition import PCA
231
 
232
- # 2D Projection of text features
233
- pca = PCA(n_components=2)
234
- X_pca = pca.fit_transform(X_text)
235
 
236
- plt.figure(figsize=(10, 6))
237
- plt.scatter(X_pca[:, 0], X_pca[:, 1], c=df['log_views'], cmap='viridis', alpha=0.5)
238
- plt.colorbar(label='Log Views')
239
- plt.title('Semantic Diversity of Generated Content (PCA)')
240
- plt.xlabel('Principal Component 1')
241
- plt.ylabel('Principal Component 2')
242
- plt.savefig('project_plots/diversity_plot.png')
243
- plt.close()
244
- print(" -> Plot saved to 'project_plots/diversity_plot.png'")
245
- # --------------------------------------
246
 
247
- num_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
248
- X_num = df[num_cols].values
 
249
 
250
- X = np.hstack((X_text, X_num))
251
- y = df['log_views'].values
252
- split_idx = int(len(df) * 0.80)
253
- return X[:split_idx], X[split_idx:], y[:split_idx], y[split_idx:], tfidf
254
-
255
- # ---------------------------------------------------------
256
- # 3. MODEL COMPARISON & TRAINING
257
- # ---------------------------------------------------------
258
- def compare_and_train_best_model(X_train, y_train, X_test, y_test):
259
- print("\n[3/8] Comparing 3 Models to find the best one...")
260
 
 
261
  models = {
262
- "Linear Regression": LinearRegression(),
263
- "Random Forest": RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1),
264
- "XGBoost": XGBRegressor(n_estimators=100, learning_rate=0.1, n_jobs=-1)
265
  }
266
 
267
- results = {}
268
- best_name = None
269
- best_score = float('inf') # RMSE so lower is better
270
- best_model_obj = None
271
 
272
- print(f"{'Model':<20} | {'RMSE':<10} | {'MAE':<10} | {'MAPE':<10} | {'RΒ²':<10}")
273
- print("-" * 70)
274
 
275
  for name, model in models.items():
276
  model.fit(X_train, y_train)
277
  preds_log = model.predict(X_test)
278
 
279
- # Invert log for real metrics
280
  preds_real = np.expm1(preds_log)
281
- y_test_real = np.expm1(y_test)
282
 
283
- rmse = np.sqrt(mean_squared_error(y_test_real, preds_real))
284
- mae = mean_absolute_error(y_test_real, preds_real)
285
- mape = mean_absolute_percentage_error(y_test_real, preds_real)
286
- r2 = r2_score(y_test, preds_log)
287
 
288
- results[name] = {'RMSE': rmse, 'MAE': mae, 'MAPE': mape, 'R2': r2}
289
 
290
- print(f"{name:<20} | {rmse:.0f} | {mae:.0f} | {mape:.2%} | {r2:.3f}")
291
 
292
- if rmse < best_score:
293
- best_score = rmse
294
- best_name = name
295
- best_model_obj = model
296
 
297
- print("-" * 70)
298
- print(f"🏆 Winner: {best_name} (RMSE: {best_score:.0f})")
299
 
300
  # --- PLOTTING ---
301
- plt.figure(figsize=(8, 5))
302
-
303
- # Comparison Bar Chart (RMSE)
304
- names = list(results.keys())
305
- rmse_scores = [results[n]['RMSE'] for n in names]
306
- plt.bar(names, rmse_scores, color=['gray', 'gray', 'green'])
307
- plt.title('Model Comparison (RMSE - Lower is Better)')
308
- plt.ylabel('RMSE (Views)')
309
-
310
- plt.tight_layout()
311
- plt.savefig('project_plots/model_comparison.png')
312
- plt.close()
313
- print(" -> Comparison plot saved to 'project_plots/model_comparison.png'")
314
-
315
- return best_model_obj
316
-
317
- def plot_feature_importance(model, vectorizer, output_path='project_plots/feature_importance.png'):
318
- print(" -> 📊 Generating Feature Importance Plot...")
319
-
320
- # 1. Get Feature Names
321
- # TF-IDF features
322
- tfidf_names = vectorizer.get_feature_names_out()
323
- # Numeric features (Hardcoded based on process_data_pipeline)
324
- meta_names = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
325
- all_features = np.concatenate([tfidf_names, meta_names])
326
-
327
- # 2. Get Importances
328
- if hasattr(model, 'feature_importances_'):
329
- # XGBoost / Random Forest
330
- importances = model.feature_importances_
331
- title = f"Top 20 Features ({type(model).__name__})"
332
- elif hasattr(model, 'coef_'):
333
- # Linear Regression
334
- importances = np.abs(model.coef_) # Magnitude matters
335
- title = f"Top 20 Feature Coefficients ({type(model).__name__})"
336
- else:
337
- print(" ⚠️ Model type does not support feature importance extraction.")
338
- return
339
-
340
- # 3. Sort and Plot Top 20
341
- indices = np.argsort(importances)[-20:]
342
-
343
- plt.figure(figsize=(10, 8))
344
- plt.title(title)
345
- plt.barh(range(len(indices)), importances[indices], align='center', color='teal')
346
- plt.yticks(range(len(indices)), [all_features[i] for i in indices])
347
- plt.xlabel('Relative Importance')
348
- plt.tight_layout()
349
- plt.savefig(output_path)
350
- plt.close()
351
- print(f" -> Feature Importance saved to '{output_path}'")
352
-
353
- # ---------------------------------------------------------
354
- # 4. EMBEDDINGS GENERATION (For Search)
355
- # ---------------------------------------------------------
356
- def create_search_index(df):
357
- print("\n[4/8] Creating Vector Search Index...")
358
- # Generate embeddings for ALL data so we can search the whole history
359
- st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
360
- embeddings = st_model.encode(df['description'].tolist(), convert_to_numpy=True, show_progress_bar=True)
361
-
362
- df['embedding'] = list(embeddings)
363
-
364
- # Save to Parquet (The Knowledge Base)
365
- save_path = 'tiktok_knowledge_base.parquet'
366
- df.to_parquet(save_path)
367
- print(f" - Knowledge Base saved to {save_path}")
368
- return df, st_model
369
-
370
- # ---------------------------------------------------------
371
- # 5. RETRIEVAL & IMPROVEMENT ENGINE (The Magic Step)
372
- # ---------------------------------------------------------
373
- def optimize_content_with_gemini(user_input, model, vectorizer, knowledge_df, st_model):
374
- """
375
- 1. Scores original idea.
376
- 2. Finds top 3 similar VIRAL videos.
377
- 3. Asks Gemini to rewrite the idea.
378
- 4. Re-scores the new idea.
379
- """
380
- print("\n" + "="*50)
381
- print("🚀 VIRAL OPTIMIZATION ENGINE")
382
- print("="*50)
383
-
384
- # --- STEP 1: INITIAL SCORE ---
385
- text_vec = vectorizer.transform([user_input]).toarray()
386
- # Assume default meta for prediction (15s, 6 PM, weekday)
387
- meta_vec = np.array([[15, 18, 0, user_input.count('#')]])
388
- feat_vec = np.hstack((text_vec, meta_vec))
389
-
390
- initial_log = model.predict(feat_vec)[0]
391
- initial_views = int(np.expm1(initial_log))
392
-
393
- print(f"\n📝 ORIGINAL IDEA: {user_input}")
394
- print(f"📊 Predicted Views: {initial_views:,}")
395
-
396
- # --- STEP 2: VECTOR SEARCH (Find similar successful videos) ---
397
- print("\n🔍 Searching for similar viral hits in Parquet file...")
398
-
399
- # Filter only for successful videos (e.g., top 25% of views)
400
- high_performance_df = knowledge_df[knowledge_df['views'] > knowledge_df['views'].quantile(0.75)].copy()
401
-
402
- # Encode user input
403
- user_embedding = st_model.encode([user_input], convert_to_numpy=True)
404
-
405
- # Stack embeddings from the dataframe into a matrix
406
- target_embeddings = np.stack(high_performance_df['embedding'].values)
407
-
408
- # Calculate Cosine Similarity
409
- similarities = cosine_similarity(user_embedding, target_embeddings)
410
-
411
- # Get Top 3 indices
412
- top_3_indices = similarities[0].argsort()[-3:][::-1]
413
- top_3_videos = high_performance_df.iloc[top_3_indices]['description'].tolist()
414
-
415
- print(" -> Found 3 similar viral videos to learn from:")
416
- for i, vid in enumerate(top_3_videos, 1):
417
- print(f" {i}. {vid[:80]}...")
418
-
419
- # --- STEP 3: GEMINI OPTIMIZATION ---
420
- api_key = os.getenv("GEMINI_API_KEY")
421
- if not api_key:
422
- print("\n⚠️ SKIPPING AI REWRITE: No 'GEMINI_API_KEY' found in environment variables.")
423
- print(" (Set it via 'export GEMINI_API_KEY=your_key' in terminal)")
424
- return
425
-
426
- print("\n🤖 Sending context to Gemini LLM for optimization...")
427
- genai.configure(api_key=api_key)
428
- llm = genai.GenerativeModel('gemini-2.5-flash-lite')
429
-
430
- prompt = f"""
431
- You are a TikTok Virality Expert.
432
-
433
- My Draft Description: "{user_input}"
434
-
435
- Here are 3 successful, viral videos that are similar to my topic:
436
- 1. {top_3_videos[0]}
437
- 2. {top_3_videos[1]}
438
- 3. {top_3_videos[2]}
439
-
440
- Task: Rewrite my draft description to make it go viral.
441
- Use the slang, hashtag style, and structure of the successful examples provided.
442
- Keep it under 20 words plus hashtags. Return ONLY the new description.
443
- """
444
-
445
- try:
446
- response = llm.generate_content(prompt)
447
- improved_idea = response.text.strip()
448
-
449
- print(f"\n✨ IMPROVED IDEA (By Gemini): {improved_idea}")
450
 
451
- # --- STEP 4: RE-EVALUATION ---
452
- new_text_vec = vectorizer.transform([improved_idea]).toarray()
453
- # Update hashtag count for new features
454
- new_meta_vec = np.array([[15, 18, 0, improved_idea.count('#')]])
455
- new_feat_vec = np.hstack((new_text_vec, new_meta_vec))
456
 
457
- new_log = model.predict(new_feat_vec)[0]
458
- new_views = int(np.expm1(new_log))
459
 
460
- print(f"📊 New Predicted Views: {new_views:,}")
461
-
462
- improvement = ((new_views - initial_views) / initial_views) * 100
463
- if improvement > 0:
464
- print(f"🚀 POTENTIAL UPLIFT: +{improvement:.1f}%")
465
- else:
466
- print(f"😐 No significant uplift predicted (Model is strict!).")
467
-
468
- except Exception as e:
469
- print(f"❌ Error calling Gemini API: {e}")
470
-
471
- # ---------------------------------------------------------
472
- # MAIN EXECUTION
473
- # ---------------------------------------------------------
474
- if __name__ == "__main__":
475
- # 1. Pipeline (LLM)
476
- print("🚀 Starting Production Run: Generating 10,000 rows...")
477
- df, _ = generate_synthetic_data_llm(10000)
478
- X_train, X_test, y_train, y_test, tfidf = process_data_pipeline(df)
479
 
480
- # 2. Train Prediction Model (COMPARISON Step)
481
- best_model = compare_and_train_best_model(X_train, y_train, X_test, y_test)
 
 
482
 
483
- # 3. Create Knowledge Base (Embeddings)
484
- knowledge_df, st_model = create_search_index(df)
 
 
 
485
 
486
- # 4. Save Artifacts for App & Plot Importance
487
- print("\n[5/8] Saving Model Artifacts & Finalizing Plots...")
 
488
 
489
- # Plot Feature Importance (Now that we have the winner)
490
- plot_feature_importance(best_model, tfidf)
491
-
492
- # Use Pickle for Model (Generic)
493
- with open("viral_model.pkl", "wb") as f:
494
- pickle.dump(best_model, f)
495
- print(" - Model saved to 'viral_model.pkl'")
496
 
497
- with open("tfidf_vectorizer.pkl", "wb") as f:
498
- pickle.dump(tfidf, f)
499
- print(" - Vectorizer saved to 'tfidf_vectorizer.pkl'")
500
 
501
- # 5. User Interaction Loop
502
- while True:
503
- print("\n" + "-"*30)
504
- user_input = input("Enter your video idea (or 'q' to quit): ")
505
- if user_input.lower() == 'q':
506
- break
507
-
508
- optimize_content_with_gemini(
509
- user_input=user_input,
510
- model=best_model,
511
- vectorizer=tfidf,
512
- knowledge_df=knowledge_df,
513
- st_model=st_model
514
- )
 
5
  import warnings
6
  import os
7
  import torch
8
+ import time
9
  import pickle
10
+ import google.generativeai as genai
11
+ from datasets import load_dataset
12
+ from sentence_transformers import SentenceTransformer
13
+ from sklearn.metrics import silhouette_score, mean_squared_error, mean_absolute_error, r2_score
14
  from sklearn.ensemble import RandomForestRegressor
15
+ from sklearn.preprocessing import LabelEncoder
16
+ from sklearn.metrics.pairwise import cosine_similarity
17
  from xgboost import XGBRegressor
18
  from sklearn.linear_model import LinearRegression
19
+ from dotenv import load_dotenv
20
 
21
+ # Load environment variables
22
+ load_dotenv()
23
+
24
+ # Configuration
25
  warnings.filterwarnings('ignore')
26
  pd.set_option('display.max_columns', None)
27
  device = "mps" if torch.backends.mps.is_available() else "cpu"
28
+
29
  print(f"🚀 Optimization: Running on {device.upper()} device")
30
 
31
  if not os.path.exists('project_plots'):
32
  os.makedirs('project_plots')
33
 
34
+
35
  # ---------------------------------------------------------
36
+ # 1. LOAD DATA
37
  # ---------------------------------------------------------
38
+ def load_data():
39
+ print(f"\n[1/5] Loading Dataset from Hugging Face...")
40
  try:
41
+ dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
42
+ if 'train' in dataset:
43
+ df = dataset['train'].to_pandas()
44
+ else:
45
+ df = dataset.to_pandas()
46
+
47
+ print(f" -> ✅ Loaded {len(df)} rows.")
48
+
49
+ # Basic Preprocessing
50
+ if 'views' in df.columns:
51
+ # Create Log Targets for better regression
52
+ df['log_views'] = np.log1p(df['views'])
53
+
54
+ return df
55
  except Exception as e:
56
+ print(f" ❌ Error loading data: {e}")
57
+ return pd.DataFrame()
58
 
59
+ # ---------------------------------------------------------
60
+ # 2. EMBEDDING BENCHMARK
61
+ # ---------------------------------------------------------
62
+ def benchmark_and_select_model(df):
63
+ print("\n[2/5] Benchmarking Embedding Models...")
64
 
65
+ models = [
66
+ "sentence-transformers/all-MiniLM-L6-v2",
67
+ "sentence-transformers/all-mpnet-base-v2",
68
+ "BAAI/bge-small-en-v1.5"
 
 
 
 
 
 
 
 
69
  ]
70
 
71
+ results = []
72
 
73
+ # We need labels for the Silhouette Score (cluster quality);
74
+ # 'category' serves as the ground truth for semantic clusters.
75
+ if 'category' not in df.columns:
76
+ print("⚠️ No 'category' column. Using a single dummy label.")
77
+ df['category'] = 'Unknown'
78
+ labels = df['category'].values
79
+
80
+ # Sample for speed if dataset is huge (>5k)
81
+ sample_df = df.sample(min(len(df), 3000), random_state=42)
82
+ sample_texts = sample_df['description'].fillna("").tolist()
83
+ sample_labels = sample_df['category'].values
85
+
86
+ print(f"{'Model':<40} | {'Time (s)':<10} | {'Silhouette':<10}")
87
+ print("-" * 65)
88
+
89
+ best_score = -2
90
+ best_model_name = models[0] # Default
91
+
92
+ for model_name in models:
93
  try:
94
+ st_model = SentenceTransformer(model_name, device=device)
95
 
96
+ # Measure Encoding Time
97
+ start_t = time.time()
98
+ embeddings = st_model.encode(sample_texts, convert_to_numpy=True, show_progress_bar=False)
99
+ time_taken = time.time() - start_t
100
+
101
+ # Measure Cluster Quality
102
+ score = silhouette_score(embeddings, sample_labels)
103
+
104
+ results.append({
105
+ "Model": model_name.split('/')[-1], # Short name
106
+ "Time (s)": time_taken,
107
+ "Silhouette Score": score
108
+ })
109
+
110
+ print(f"{model_name:<40} | {time_taken:.2f} | {score:.4f}")
111
+
112
+ if score > best_score:
113
+ best_score = score
114
+ best_model_name = model_name
115
 
116
  except Exception as e:
117
+ print(f"❌ Error with {model_name}: {e}")
 
118
 
119
+ print("-" * 65)
120
+ print(f"🏆 Winner: {best_model_name} (Score: {best_score:.4f})")
121
+
122
+ # --- PLOTTING ---
123
+ if results:
124
+ res_df = pd.DataFrame(results)
125
+ fig, axes = plt.subplots(1, 2, figsize=(14, 6))
126
+
127
+ # 1. Time Plot
128
+ sns.barplot(data=res_df, x='Model', y='Time (s)', ax=axes[0], palette='Blues_d')
129
+ axes[0].set_title('Encoding Speed (Lower is Better)')
130
+ axes[0].tick_params(axis='x', rotation=45)
131
+
132
+ # 2. Quality Plot
133
+ sns.barplot(data=res_df, x='Model', y='Silhouette Score', ax=axes[1], palette='Greens_d')
134
+ axes[1].set_title('Clustering Quality (Higher is Better)')
135
+ axes[1].tick_params(axis='x', rotation=45)
136
+
137
+ plt.tight_layout()
138
+ plt.savefig('project_plots/embedding_benchmark.png')
139
+ plt.close()
140
+ print(" -> 📊 Plot saved: 'project_plots/embedding_benchmark.png'")
141
+
142
+ # Save the winner name for app.py
143
+ with open("embedding_model_name.txt", "w") as f:
144
+ f.write(best_model_name)
145
+
146
+ return best_model_name
147
+
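The benchmark above picks the embedding model whose vectors best separate the content clusters, using the silhouette score. As a toy illustration (not code from this repo; the 2-D points and labels are made up), the score approaches 1 when same-label points sit close together and far from other clusters, and drops below 0 when labels straddle clusters:

```python
# Toy sketch: silhouette_score rewards tight, well-separated clusters.
import numpy as np
from sklearn.metrics import silhouette_score

# Two well-separated "clusters" of fake 2-D embeddings
tight = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
good = silhouette_score(tight, [0, 0, 1, 1])   # close to 1

# The same labels assigned across overlapping blobs score poorly
mixed = np.array([[0.0, 0.0], [5.0, 5.0], [0.1, 0.0], [5.1, 5.0]])
bad = silhouette_score(mixed, [0, 0, 1, 1])    # negative

print(f"separated: {good:.2f}, overlapping: {bad:.2f}")
```

This is why a higher bar in the "Clustering Quality" plot indicates a better embedding model for the retrieval step.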
+# ---------------------------------------------------------
+# 3. GENERATE KNOWLEDGE BASE
+# ---------------------------------------------------------
+def generate_knowledge_base(df, model_name):
+    print(f"\n[3/5] Generating Embeddings with Winner ({model_name})...")
+
+    st_model = SentenceTransformer(model_name, device=device)
+
+    # Encode ALL descriptions
+    embeddings = st_model.encode(df['description'].fillna("").tolist(),
+                                 convert_to_numpy=True,
+                                 show_progress_bar=True)
+
+    # Store in DataFrame
+    df['embedding'] = list(embeddings)
+    return df, st_model
 # ---------------------------------------------------------
+# 4. TRAIN REGRESSION MODEL
 # ---------------------------------------------------------
+def train_regressor(df):
+    print("\n[4/5] Training View Prediction Model...")
+
+    # Feature Engineering
+    # 1. Semantic Features (The Embeddings)
+    X_text = np.stack(df['embedding'].values)
+
+    # 2. Meta Features (Duration, etc.)
+    num_cols = ['duration', 'hour_of_day', 'followers']
+    cat_cols = ['category', 'gender', 'day_of_week', 'age']
+
+    # Fill missing numerics
+    for c in num_cols:
+        if c not in df.columns:
+            df[c] = 0
+
+    # Process Categoricals (Label Encoding)
+    for c in cat_cols:
+        if c not in df.columns:
+            df[c] = 'Unknown'
+        le = LabelEncoder()
+        df[c + '_encoded'] = le.fit_transform(df[c].astype(str))
+
+    # Combine all numeric features (original numeric + encoded categorical)
+    final_meta_cols = num_cols + [c + '_encoded' for c in cat_cols]
+    print(f" -> Features used: Embeddings + {final_meta_cols}")
+
+    X_meta = df[final_meta_cols].values
+
+    # Combine
+    X = np.hstack((X_text, X_meta))
+    y = df['log_views'].values
+
+    # Split (80/20)
+    split = int(len(df) * 0.8)
+    X_train, X_test = X[:split], X[split:]
+    y_train, y_test = y[:split], y[split:]
+
+    # Model Comparison
     models = {
+        "RandomForest": RandomForestRegressor(n_estimators=100, max_depth=10, n_jobs=-1),
+        "XGBoost": XGBRegressor(n_estimators=100, learning_rate=0.1, n_jobs=-1),
+        "LinearReg": LinearRegression()
     }
+
+    best_model = None
+    best_rmse = float('inf')
+    results = []  # Store for plotting
+
+    print(f"{'Model':<15} | {'RMSE (Views)':<15} | {'RΒ²':<10}")
+    print("-" * 45)
+
     for name, model in models.items():
         model.fit(X_train, y_train)
         preds_log = model.predict(X_test)
+
+        # Convert log predictions back to real views
         preds_real = np.expm1(preds_log)
+        y_real = np.expm1(y_test)
+
+        rmse = np.sqrt(mean_squared_error(y_real, preds_real))
+        r2 = r2_score(y_test, preds_log)
+
+        results.append({
+            "Model": name,
+            "RMSE": rmse,
+            "R2": r2
+        })
+
+        print(f"{name:<15} | {rmse:,.0f} | {r2:.3f}")
+
+        if rmse < best_rmse:
+            best_rmse = rmse
+            best_model = model
+
+    print("-" * 45)
+    print(f"πŸ† Best Regressor: {type(best_model).__name__}")
+
     # --- PLOTTING ---
+    if results:
+        res_df = pd.DataFrame(results)
+        fig, axes = plt.subplots(1, 2, figsize=(14, 6))
+
+        # 1. RMSE Plot
+        sns.barplot(data=res_df, x='Model', y='RMSE', ax=axes[0], palette='Reds_d')
+        axes[0].set_title('Prediction Error (RMSE) - Lower is Better')
+
+        # 2. R2 Plot
+        sns.barplot(data=res_df, x='Model', y='R2', ax=axes[1], palette='Greens_d')
+        axes[1].set_title('Explained Variance (RΒ²) - Higher is Better')
+        axes[1].set_ylim(0, 1)  # RΒ² is usually 0-1
+
+        plt.tight_layout()
+        plt.savefig('project_plots/regression_comparison.png')
+        plt.close()
+        print(" -> πŸ“Š Plot saved: 'project_plots/regression_comparison.png'")
+
+    # Save Model
+    with open("viral_model.pkl", "wb") as f:
+        pickle.dump(best_model, f)
+    print(" -> βœ… Model saved to 'viral_model.pkl'")
+
+    return best_model
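`train_regressor` fits on `log_views` and maps predictions back to raw views with `np.expm1`. Assuming `log_views` was produced upstream with `np.log1p` (that step is not shown in this hunk), the pair is an exact inverse, which is why RMSE can be reported on the real view scale. A minimal sketch:

```python
# Sketch of the log-view transform assumed by train_regressor:
# view counts are heavy-tailed, so the target is log1p(views),
# and predictions are mapped back with expm1 (the exact inverse).
import numpy as np

views = np.array([100, 1_000, 1_000_000])
log_views = np.log1p(views)      # compressed, learnable scale
recovered = np.expm1(log_views)  # exact round trip

print(np.allclose(recovered, views))  # True
```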
273
+
274
+
275
+ if __name__ == "__main__":
276
+ # EXECUTION PIPELINE
277
 
278
+ # 1. Load
279
+ df = load_data()
280
+ if df.empty: exit()
281
 
282
+ # 2. Benchmark
283
+ best_emb_model = benchmark_and_select_model(df)
 
 
 
 
 
284
 
285
+ # 3. Generate Knowledge Base
286
+ df, st_model = generate_knowledge_base(df, best_emb_model)
 
287
 
288
+ # 4. Train
289
+ reg_model = train_regressor(df)
290
+
 
 
 
 
 
 
 
 
 
 
 
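The pipeline leaves two artifacts behind: `viral_model.pkl` and `embedding_model_name.txt`. A hypothetical sketch of the consuming side (the function name and paths are assumptions, not code from this commit) shows how an inference script such as `app.py` could reload them:

```python
# Hypothetical loader (names assumed, not from this commit):
# reads the pickled regressor and the winning embedding model name
# that the training pipeline wrote to disk.
import pickle

def load_artifacts(model_path="viral_model.pkl",
                   name_path="embedding_model_name.txt"):
    with open(model_path, "rb") as f:
        model = pickle.load(f)
    with open(name_path) as f:
        embedding_model_name = f.read().strip()
    return model, embedding_model_name
```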
project_plots/diversity_plot.png DELETED

Git LFS Details

  • SHA256: 84b06bd4e538740e3057e7f48631ece2f140b31c376f93e602653b92d8cade26
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
tfidf_vectorizer.pkl β†’ project_plots/embedding_benchmark.png RENAMED
File without changes
project_plots/feature_importance.png DELETED

Git LFS Details

  • SHA256: a0ac14b476322d1a9d53149a728ff0b3a6002763d1157478379fa68ea701ae04
  • Pointer size: 130 Bytes
  • Size of remote file: 32 kB
project_plots/model_comparison.png DELETED

Git LFS Details

  • SHA256: a6c178f39a12baa6335e4442ac83bdac07594cfec6e4b2550819f873eda0b2ef
  • Pointer size: 130 Bytes
  • Size of remote file: 17.8 kB
project_plots/{eda_distribution.png β†’ regression_comparison.png} RENAMED
File without changes
raw_social_media_data.parquet DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:92939908f14b69157b0a99ee186ef1f0ff70d54974bfcf14235468674f73d450
3
- size 1185030
 
 
 
 
tiktok_knowledge_base.parquet DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:82dacd5da6cc1e8f9a62db8e8b6d68f5d5e466300d94dc7707c7afd342a97594
- size 17274184
upload_model_to_hf.py ADDED
@@ -0,0 +1,40 @@
+import os
+from huggingface_hub import HfApi, login
+
+# Configuration
+MODEL_FILE = "viral_model.pkl"
+REPO_ID = "MatanKriel/social-assitent-viral-predictor"
+
+def upload_model():
+    print(f"πŸš€ Preparing to upload '{MODEL_FILE}' to Hugging Face...")
+
+    if not os.path.exists(MODEL_FILE):
+        print(f"❌ Error: {MODEL_FILE} not found. Run model-prep.py first.")
+        return
+
+    try:
+        api = HfApi()
+
+        # Create repo if it doesn't exist
+        print(f"πŸ“¦ Checking repository '{REPO_ID}'...")
+        api.create_repo(repo_id=REPO_ID, exist_ok=True, repo_type="model")
+
+        # Upload file
+        print(f"πŸ“€ Uploading {MODEL_FILE}...")
+        api.upload_file(
+            path_or_fileobj=MODEL_FILE,
+            path_in_repo=MODEL_FILE,
+            repo_id=REPO_ID,
+            repo_type="model"
+        )
+
+        print("\nβœ… Upload Complete!")
+        print(f"πŸ”— Model available at: https://huggingface.co/{REPO_ID}")
+        print("πŸ’‘ You can now run 'python app.py' and it will download this model.")
+
+    except Exception as e:
+        print(f"\n❌ Error during upload: {e}")
+        print("πŸ’‘ Tip: Ensure you are logged in. Run 'huggingface-cli login' if needed.")
+
+if __name__ == "__main__":
+    upload_model()
viral_model.pkl CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:92f2ca0ca3bf30dd6a5d7e84d8ebff5612134ff895124e03cb51586a000d9527
- size 214620
+ oid sha256:301463409cf3d6c05f45fe8a31244fafe1ea7bb88619b1fafa35fcab4e207acc
+ size 320476