Matan Kriel committed on
Commit 879ede5 · 1 Parent(s): 450331a

updated data generation pipeline
README.md CHANGED
@@ -1,114 +1,78 @@
  ---
- title: Social Media Virality
- emoji: 📺
- colorFrom: green
- colorTo: yellow
  sdk: gradio
  sdk_version: 5.9.0
  app_file: app.py
  pinned: false
  ---
- # Social Media Virality Prediction & Optimization Project

- **Course**: Data Science & Machine Learning Applications
- **Project**: Viral Content Assistant
-
- ## 1. Project Overview
- This project aims to develop a data-driven system capable of predicting the viral potential of short-form video content (e.g., TikTok) and optimizing it using Generative AI. By leveraging Natural Language Processing (NLP) and Machine Learning (ML), the system analyzes video descriptions and metadata to forecast view counts and prescribes actionable improvements to maximize engagement.
-
- The core solution consists of a machine learning pipeline for virality prediction and a web application (Gradio) for real-time user interaction.
-
- ## 2. Data Science Methodology
-
- ### 2.1 Data Acquisition & Generation
- Due to privacy restrictions and API limitations of social platforms, we simulated a realistic dataset reflecting 2025 social media trends.
- * **Source**: Synthetic generation using the `Faker` library and `numpy` probabilistic distributions.
- * **Volume**: 10,000 samples.
- * **Features**:
-   * **Textual**: Video descriptions rich in slang (e.g., "Skibidi", "Girl Dinner"), hashtags, and emojis.
-   * **Temporal**: Upload hour, day of week.
-   * **Meta**: Video duration, category (Gaming, Beauty, etc.).
- * **Target Variable**: `views` (log-normally distributed to mimic real-world viral discrepancies).
-
- ### 2.2 Exploratory Data Analysis (EDA)
- We analyzed the distribution of the target variable and feature correlations.
- * **Observation**: View counts follow a power-law distribution; most videos have few views, while a few "viral hits" capture the majority.
- * **Preprocessing**: We applied a log transformation (`np.log1p`) to the `views` variable to normalize the distribution for regression models.
-
- ![Views Distribution](project_plots/eda_distribution.png)
- *Figure 1: Distribution of Log-Transformed View Counts.*
-
- ### 2.3 Feature Engineering
- * **Text Embeddings**: We used **TF-IDF Vectorization** (top 2,000 features) to convert unstructured text descriptions into numerical vectors.
- * **Meta Features**: Encoded `is_weekend`, `hour_of_day`, and `video_duration_sec`.
- * **Data Splitting**: A **temporal split** (80/20) was used instead of a random split to prevent data leakage, ensuring the model predicts future videos based on past trends.
-
- ## 3. Model Development & Evaluation
-
- We evaluated three distinct algorithms to solve the regression problem (predicting log-views):
-
- 1. **Linear Regression**: Baseline model for interpretability.
- 2. **Random Forest Regressor**: Ensemble method to capture non-linear relationships.
- 3. **XGBoost Regressor**: Gradient boosting machine known for state-of-the-art tabular performance.
-
- ### 3.1 Comparative Metrics
- Models were assessed using:
- * **RMSE (Root Mean Squared Error)**: The primary metric for regression accuracy.
- * **R² (Coefficient of Determination)**: Explains the variance captured by the model.
- * **F1-Score**: Used as a proxy for classification performance (predicting whether a video hits the "viral threshold", i.e., the top 20%).
-
- ![Model Leaderboard](project_plots/model_leaderboard.png)
- *Figure 2: Performance comparison across different architectures.*
-
- ### 3.2 Results
- The **XGBoost Regressor** outperformed the other models, achieving the lowest RMSE on the test set. This model was selected for the final deployment.
-
- ## 4. Advanced Analysis: Embeddings & Semantic Search
-
- Beyond simple regression, we implemented a semantic search engine using **SentenceTransformers** (`all-MiniLM-L6-v2`).
- * **Purpose**: To retrieve historical viral hits conceptually similar to the user's new idea.
- * **Clustering**: We visualized the semantic space using PCA (Principal Component Analysis).
-
- ![Embedding Clusters](project_plots/embedding_clusters.png)
- *Figure 3: Semantic clustering of video descriptions.*
-
- ## 5. Application & Deployment
-
- The final deliverable is an interactive web application built with **Gradio**.
-
- ### 5.1 System Architecture
- The system is decoupled into two main components:
- 1. **Training Pipeline (`model-prep.py`)**: Runs offline to generate synthetic data, train the XGBoost model, and create the vector database. It saves these artifacts (`viral_model.json`, `tfidf_vectorizer.pkl`, `tiktok_knowledge_base.parquet`).
- 2. **Inference App (`app.py`)**: A lightweight Gradio app that loads the pre-trained artifacts to serve real-time predictions without needing to retrain.
-
- **Data Flow**:
- 1. **Input**: User-provided video description.
- 2. **Inference**: Loaded XGBoost model predicts view count.
- 3. **Retrieval**: App searches the pre-computed Parquet knowledge base for similar viral videos.
- 4. **Generative Optimization**: **Google Gemini 2.5 Flash Lite** rewrites the draft.
- 5. **Output**: Predictions, similar videos, and AI-optimized content.
-
- ### 5.2 Usage Instructions
-
- To run the project locally for assessment:
-
- 1. **Environment Setup**:
  ```bash
- python3 -m venv .venv
- source .venv/bin/activate
  pip install -r requirements.txt
  ```
- 2. **Configuration**:
- Ensure the `.env` file contains a valid `GEMINI_API_KEY`.
- 3. **Execution**:
  ```bash
  python app.py
  ```
- Access the UI at `http://localhost:7860`.
-
- ## 6. Conclusion
- This project demonstrates a complete end-to-end Data Science workflow: from synthetic data creation and rigorous model evaluation to the deployment of a user-facing AI application. The integration of predictive analytics (XGBoost) with generative AI (Gemini) provides a robust tool for content creators.
-
- ## 🏆 Credits
- * **Project Author:** Matan Kriel
- * **Project Author:** Odeya Shmuel
  ---
+ title: Social Media Virality Assistant
+ emoji: 🚀
+ colorFrom: indigo
+ colorTo: purple
  sdk: gradio
  sdk_version: 5.9.0
  app_file: app.py
  pinned: false
  ---

+ # 🚀 Social Media Virality Assistant
+
+ A Data Science project that uses **Large Language Models (LLMs)** and **Machine Learning** to predict and optimize social media content virality.
+
+ ## 🌟 Project Overview
+ This tool helps content creators go viral by:
+ 1. **Predicting Views**: Analyzing video descriptions to forecast performance.
+ 2. **Optimizing Content**: Using **Google Gemini AI** to rewrite drafts with viral hooks (slang, hashtags).
+ 3. **Learning from History**: Retrieving similar successful videos using **Semantic Search**.
+
+ ## 🧠 Data Science Methodology
+
+ ### 1. Synthetic Data Generation (LLM-Based)
+ Since real-world TikTok data is private, we simulated a "viral environment":
+ * **Generator**: Used `tiiuae/falcon-rw-1b` (via `transformers`) to generate **10,000 realistic video descriptions**.
+ * **Diversity**: Prompted the LLM with various scenarios ("POV", "GRWM", "Storytime") to ensure distinct content clusters.
+ * **Ground-Truth Logic**: Developed a scoring function that assigns views based on linguistic patterns (e.g., questions, emotional triggers) and metadata (time of day, duration), creating a learnable signal for the ML models.
+
+ ### 2. Model Development & Comparison
+ We treated this as a **regression problem** (predicting log-views) and compared three algorithms to find the best predictor:
+ * **Linear Regression**: Baseline model.
+ * **Random Forest**: Good for non-linear interactions.
+ * **XGBoost (Winner)**: Gradient boosting provided the best accuracy (lowest RMSE).
+
+ **Metrics Used:**
+ * **RMSE (Root Mean Squared Error)**: Primary metric for model selection.
+ * **MAE (Mean Absolute Error)**: Average view-count error.
+ * **MAPE (Mean Absolute Percentage Error)**: Average percentage error.
+
+ ### 3. Advanced Analysis (Plots)
+ * **`diversity_plot.png`**: A PCA visualization showing the semantic spread of the 10,000 generated descriptions.
+ * **`model_comparison.png`**: Bar chart comparing RMSE across models and ROC curves for viral-classification validity.
+ * **`feature_importance.png`**: The top 20 words and metadata features that drive virality in our simulated world.
+
+ ## 🛠️ Tech Stack
+ * **Core**: Python, Pandas, NumPy, Scikit-Learn
+ * **AI/LLM**: `transformers` (Falcon-RW-1B), `google-generativeai` (Gemini 2.5)
+ * **ML**: XGBoost, Sentence-Transformers (embeddings)
+ * **App**: Gradio (web UI)
+ * **Hardware**: Optimized for Apple Silicon (MPS).
+
+ ## 📂 Project Structure
+ ```bash
+ ├── app.py                        # Inference App (Gradio)
+ ├── model-prep.py                 # Training Pipeline (Data Gen -> Train -> Save)
+ ├── requirements.txt              # Dependencies
+ ├── tiktok_knowledge_base.parquet # Semantic Search Index
+ ├── viral_model.pkl               # Trained ML Model (Pickle)
+ ├── tfidf_vectorizer.pkl          # Text Processor
+ └── project_plots/                # Generated Analysis Plots
+ ```
+
+ ## 🚀 How to Run
+ 1. **Install Dependencies**:
  ```bash
  pip install -r requirements.txt
  ```
+ 2. **Train & Generate Data** (downloads the ~2.6 GB Falcon model):
+ ```bash
+ python model-prep.py
+ ```
+ 3. **Run the App**:
  ```bash
+ export GEMINI_API_KEY="your_key_here"
  python app.py
  ```
 
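The "Ground-Truth Logic" described in the README (views assigned from linguistic hooks plus posting metadata) can be sketched as a small standalone scoring function. The multipliers below mirror the values committed in `model-prep.py`; the function name itself is illustrative:

```python
import numpy as np

def score_views(text, hour, duration_sec, is_weekend, rng=None):
    """Synthetic 'ground truth': a heavy-tailed base virality draw scaled by
    multiplicative boosts for hooks and posting metadata (as in model-prep.py)."""
    rng = rng if rng is not None else np.random.default_rng(0)
    base = rng.lognormal(mean=9.5, sigma=1.8)  # log-normal base virality
    mult = 1.0
    lower = text.lower()
    if "?" in text:
        mult *= 1.2                       # question hook
    if "you" in lower or "pov" in lower:
        mult *= 1.4                       # direct address / POV framing
    if is_weekend:
        mult *= 1.25                      # weekend boost
    if duration_sec < 15:
        mult *= 1.3                       # short-form boost
    if 17 <= hour <= 21:
        mult *= 1.15                      # prime-time boost
    return int(base * mult)

views = score_views("POV: you find the viral hack", hour=19,
                    duration_sec=12, is_weekend=True)
```

Because every rule is multiplicative on a log-normal base, the resulting target stays heavy-tailed while remaining learnable from the text and meta features.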
app.py CHANGED
@@ -40,9 +40,10 @@ def initialize_app():
      knowledge_df = pd.read_parquet(parquet_path)

      # 2. Load Model
-     print("🧠 Loading XGBoost Model...")
-     model = XGBRegressor()
-     model.load_model("viral_model.json")
+     print("🧠 Loading Prediction Model (Pickle)...")
+     with open("viral_model.pkl", "rb") as f:
+         model = pickle.load(f)
+     print(f"   -> Loaded model type: {type(model).__name__}")

      # 3. Load Vectorizer
      print("🔀 Loading TF-IDF Vectorizer...")
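The change above swaps XGBoost's format-specific `load_model` for generic pickling, so `app.py` no longer needs to know which of the three candidate models won. A minimal sketch of that round-trip, using `LinearRegression` as a stand-in for whichever estimator `model-prep.py` saves (bytes round-trip shown instead of the file for brevity):

```python
import pickle
import numpy as np
from sklearn.linear_model import LinearRegression

# Fit a tiny stand-in model (y = 2x + 1).
X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([1.0, 3.0, 5.0, 7.0])
model = LinearRegression().fit(X, y)

# model-prep.py writes the winner with pickle.dump(...);
# app.py reads it back with pickle.load(...) without naming the class.
blob = pickle.dumps(model)
loaded = pickle.loads(blob)

pred = float(loaded.predict(np.array([[4.0]]))[0])  # exact linear fit -> 9.0
print(f" -> Loaded model type: {type(loaded).__name__}")
```

Any object exposing `fit`/`predict` round-trips the same way, which is exactly why the comparison step can hand back a `LinearRegression`, `RandomForestRegressor`, or `XGBRegressor` interchangeably.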
model-prep.py CHANGED
@@ -5,6 +5,7 @@ import seaborn as sns
  import warnings
  import os
  import torch
  import google.generativeai as genai
  from faker import Faker
  from datetime import datetime, timedelta
@@ -20,7 +21,7 @@ from sklearn.feature_extraction.text import TfidfVectorizer
  from sklearn.ensemble import RandomForestRegressor
  from xgboost import XGBRegressor
  from sklearn.linear_model import LinearRegression
- from sklearn.metrics import mean_squared_error, f1_score
  from sklearn.decomposition import PCA
  from sentence_transformers import SentenceTransformer

@@ -40,64 +41,165 @@ if not os.path.exists('project_plots'):
  # ---------------------------------------------------------
  # 1. DATA GENERATION (With 2025 Trends)
  # ---------------------------------------------------------
- def generate_enhanced_data(n_rows=10000):
-     print(f"\n[1/8] Generating {n_rows} rows of Real-World 2025 Data...")
-     fake = Faker()
-
-     trends = [
-         'Delulu', 'Girl Dinner', 'Roman Empire', 'Silent Slay', 'Soft Life',
-         'Grimace Shake', 'Wes Anderson Style', 'Beige Flag', 'Canon Event',
-         'NPC Stream', 'Skibidi', 'Fanum Tax', 'Yapping', 'Glow Up', 'Fit Check'
-     ]
-     formats = [
-         'POV: You realize...', 'GRWM for...', 'Day in the life:',
-         'Storytime:', 'Trying the viral...', 'ASMR packing orders',
-         'Rating my exes...', 'Turn the lights off challenge'
      ]
-     categories = ['Gaming', 'Beauty', 'Comedy', 'Edutainment', 'Lifestyle', 'Food']

      data = []
      start_date = datetime(2024, 1, 1)

-     for _ in range(n_rows):
-         upload_time = start_date + timedelta(days=np.random.randint(0, 365), hours=np.random.randint(0, 23))
-         trend = np.random.choice(trends)
-         fmt = np.random.choice(formats)
-         cat = np.random.choice(categories)
-
-         description = f"{fmt} {trend} edition! {fake.sentence(nb_words=6)}"
-         tags = ['#fyp', '#foryou', '#viral', f'#{trend.replace(" ", "").lower()}', f'#{cat.lower()}']
-         if np.random.random() > 0.5: tags.append('#trending2025')
-
-         full_text = f"{description} {' '.join(tags)}"
-
-         # Meta Features
-         duration = np.random.randint(5, 180)
-         hour = upload_time.hour
-         is_weekend = 1 if upload_time.weekday() >= 5 else 0
-
-         # View Count Logic
-         base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
-         multiplier = 1.0
-         if is_weekend: multiplier *= 1.2
-         if duration < 15: multiplier *= 1.4
-         if "Delulu" in full_text or "POV" in full_text: multiplier *= 1.6
-         if hour >= 18: multiplier *= 1.1
-
-         views = int(base_virality * multiplier)
-
-         data.append({
-             'upload_date': upload_time,
-             'description': full_text,
-             'category': cat,
-             'video_duration_sec': duration,
-             'hour_of_day': hour,
-             'is_weekend': is_weekend,
-             'hashtag_count': len(tags),
-             'views': views
-         })
      df = pd.DataFrame(data)
      df = df.sort_values('upload_date').reset_index(drop=True)
      threshold = df['views'].quantile(0.80)
      df['is_viral_binary'] = (df['views'] > threshold).astype(int)
@@ -123,27 +225,130 @@ def process_data_pipeline(df):
      tfidf = TfidfVectorizer(max_features=2000, stop_words='english')
      X_text = tfidf.fit_transform(df['description']).toarray()
      num_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
      X_num = df[num_cols].values

      X = np.hstack((X_text, X_num))
      y = df['log_views'].values
-     y_bin = df['is_viral_binary'].values
-
      split_idx = int(len(df) * 0.80)
-     return X[:split_idx], X[split_idx:], y[:split_idx], y[split_idx:], y_bin[split_idx:], tfidf

  # ---------------------------------------------------------
- # 3. TRAINING
  # ---------------------------------------------------------
- def train_best_model(X_train, y_train, X_test, y_test):
-     print("\n[3/8] Training Model (XGBoost)...")
-     model = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=6, n_jobs=-1)
-     model.fit(X_train, y_train)
-
-     rmse = np.sqrt(mean_squared_error(y_test, model.predict(X_test)))
-     print(f"   - Model RMSE: {rmse:.3f}")
-     return model

  # ---------------------------------------------------------
  # 4. EMBEDDINGS GENERATION (For Search)
@@ -267,20 +472,27 @@ def optimize_content_with_gemini(user_input, model, vectorizer, knowledge_df, st
  # MAIN EXECUTION
  # ---------------------------------------------------------
  if __name__ == "__main__":
-     # 1. Pipeline
-     df, _ = generate_enhanced_data(10000)
-     X_train, X_test, y_train, y_test, _, tfidf = process_data_pipeline(df)

-     # 2. Train Prediction Model
-     best_model = train_best_model(X_train, y_train, X_test, y_test)

      # 3. Create Knowledge Base (Embeddings)
      knowledge_df, st_model = create_search_index(df)

-     # 4. Save Artifacts for App
-     print("\n[5/8] Saving Model Artifacts for Production...")
-     best_model.save_model("viral_model.json")
-     print("   - Model saved to 'viral_model.json'")

      with open("tfidf_vectorizer.pkl", "wb") as f:
          pickle.dump(tfidf, f)
 
  import warnings
  import os
  import torch
+ from transformers import pipeline
  import google.generativeai as genai
  from faker import Faker
  from datetime import datetime, timedelta

  from sklearn.ensemble import RandomForestRegressor
  from xgboost import XGBRegressor
  from sklearn.linear_model import LinearRegression
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, mean_absolute_percentage_error, r2_score
  from sklearn.decomposition import PCA
  from sentence_transformers import SentenceTransformer
 
  # ---------------------------------------------------------
  # 1. DATA GENERATION (With 2025 Trends)
  # ---------------------------------------------------------
+ # ---------------------------------------------------------
+ # 1. DATA GENERATION (With LLM - Falcon-RW-1B)
+ # ---------------------------------------------------------
+ def generate_synthetic_data_llm(n_rows=10000):
+     print(f"\n[1/8] Generating {n_rows} rows of Real-World Data using LLM (Falcon-RW-1B)...")

+     # Setup Pipeline
+     print("   -> Loading Falcon model... (This may take a moment)")
+
+     # MPS Optimization Logic
+     # The 'device' variable is already set globally (cpu or mps).
+     # Pipelines usually take device=0 for GPU, or device="mps".
+
+     pipeline_kwargs = {
+         "task": "text-generation",
+         "model": "tiiuae/falcon-rw-1b",
+         "device": device  # "mps" or "cpu"
+     }
+
+     # Optimizations for Apple Silicon
+     if device == "mps":
+         print("   -> 🍎 Optimization: Using Apple Silicon (MPS) with float16")
+         pipeline_kwargs["torch_dtype"] = torch.float16
+     elif device == "cuda":
+         pipeline_kwargs["device"] = 0  # transformers often prefers an int index for CUDA
+         pipeline_kwargs["torch_dtype"] = torch.float16
+
+     try:
+         generator = pipeline(**pipeline_kwargs)
+     except Exception as e:
+         print(f"   -> Error loading model: {e}")
+         return pd.DataFrame(), 0
+
+     print(f"   -> ✅ Model Loaded on {device.upper()}")
+
+     # Diversity Prompts
+     prompts = [
+         "TikTok Description: POV you realize",
+         "TikTok Description: GRWM for",
+         "TikTok Description: Day in the life of",
+         "TikTok Description: Trying the viral",
+         "TikTok Description: Storytime about",
+         "TikTok Description: ASMR",
+         "TikTok Description: My skincare routine",
+         "TikTok Description: Cooking a healthy",
+         "TikTok Description: Coding a new",
+         "TikTok Description: Travel vlog to"
      ]

      data = []
+     fake = Faker()
      start_date = datetime(2024, 1, 1)

+     # Generate in batches to balance memory use and speed, printing
+     # progress periodically since n_rows is large.
+     print(f"   -> Starting generation of {n_rows} items...")
+
+     # To speed things up, we request multiple sequences per prompt call
+     # and loop until we have n_rows in total.
+     rows_generated = 0
+     batch_size = 5  # Generate 5 variations per prompt call
+
+     while rows_generated < n_rows:
+         prompt = np.random.choice(prompts)

+         try:
+             outputs = generator(
+                 prompt,
+                 max_new_tokens=40,
+                 num_return_sequences=batch_size,
+                 do_sample=True,
+                 temperature=0.9,
+                 top_k=50,
+                 top_p=0.95,
+                 truncation=True,
+                 pad_token_id=50256  # Falcon-RW default pad token, usually
+             )
+
+             for o in outputs:
+                 if rows_generated >= n_rows: break
+
+                 raw_text = o['generated_text']
+                 # Clean up: keep the full description (prompt prefix included)
+                 # and strip newlines.
+                 clean_text = raw_text.replace("\n", " ").strip()
+
+                 # Add some synthetic tags if missing (the LLM might not add enough)
+                 if "#" not in clean_text:
+                     clean_text += " #fyp #viral #trending"
+
+                 # --- SOPHISTICATED VIEW COUNT LOGIC ---
+                 # We inject "ground truth" rules so the model can learn real patterns.
+
+                 # Base distribution
+                 base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
+                 multiplier = 1.0
+
+                 # 1. Linguistic Patterns (The "Text" Signal)
+                 full_lower = clean_text.lower()
+
+                 # Boost for "Hooks" (questions, direct address)
+                 if "?" in clean_text: multiplier *= 1.2
+                 if "you" in full_lower or "pov" in full_lower: multiplier *= 1.4
+
+                 # Boost for Emotional/Urgent words
+                 viral_triggers = ['secret', 'hack', 'wait for it', 'won\'t believe', 'shocking', 'obsessed']
+                 if any(w in full_lower for w in viral_triggers): multiplier *= 1.3
+
+                 # Boost for Niche Keywords (targeting specific audiences)
+                 niche_keywords = ['coding', 'recipe', 'tutorial', 'routine', 'haul']
+                 if any(w in full_lower for w in niche_keywords): multiplier *= 1.2
+
+                 # 2. Metadata Signals
+                 upload_time = start_date + timedelta(days=np.random.randint(0, 365), hours=np.random.randint(0, 23))
+                 duration = np.random.randint(5, 180)
+                 hour = upload_time.hour
+                 is_weekend = 1 if upload_time.weekday() >= 5 else 0
+
+                 if is_weekend: multiplier *= 1.25   # Weekends are slightly better
+                 if duration < 15: multiplier *= 1.3  # Short content is king
+                 if hour >= 17 and hour <= 21: multiplier *= 1.15  # Prime-time boost
+
+                 # Calculate Final Views
+                 views = int(base_virality * multiplier)
+
+                 data.append({
+                     'upload_date': upload_time,
+                     'description': clean_text,
+                     'category': 'General',
+                     'video_duration_sec': duration,
+                     'hour_of_day': hour,
+                     'is_weekend': is_weekend,
+                     'hashtag_count': clean_text.count('#'),
+                     'views': views
+                 })
+                 rows_generated += 1
+
+             # Print one example per batch for quality control
+             if len(outputs) > 0:
+                 print(f"   👀 Sample: {data[-1]['description'][:100]}...")
+
+             if rows_generated % 100 == 0:
+                 print(f"   -> Generated {rows_generated}/{n_rows} rows...")
+
+         except Exception as e:
+             print(f"   ⚠️ Generation Error: {e}")
+             break
+
      df = pd.DataFrame(data)
+
+     # --- SAVE RAW DATA ---
+     raw_save_path = 'raw_social_media_data.parquet'
+     df.to_parquet(raw_save_path)
+     print(f"   -> 💾 Raw Data Saved to {raw_save_path}")
+
+     # Process for training (targets)
      df = df.sort_values('upload_date').reset_index(drop=True)
      threshold = df['views'].quantile(0.80)
      df['is_viral_binary'] = (df['views'] > threshold).astype(int)
 
      tfidf = TfidfVectorizer(max_features=2000, stop_words='english')
      X_text = tfidf.fit_transform(df['description']).toarray()

+     # --- NEW: Data Diversity Plot (PCA) ---
+     print("   -> 🎨 Generating Diversity Plot...")
+     from sklearn.decomposition import PCA
+
+     # 2D projection of text features
+     pca = PCA(n_components=2)
+     X_pca = pca.fit_transform(X_text)
+
+     plt.figure(figsize=(10, 6))
+     plt.scatter(X_pca[:, 0], X_pca[:, 1], c=df['log_views'], cmap='viridis', alpha=0.5)
+     plt.colorbar(label='Log Views')
+     plt.title('Semantic Diversity of Generated Content (PCA)')
+     plt.xlabel('Principal Component 1')
+     plt.ylabel('Principal Component 2')
+     plt.savefig('project_plots/diversity_plot.png')
+     plt.close()
+     print("   -> Plot saved to 'project_plots/diversity_plot.png'")
+     # --------------------------------------
+
      num_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
      X_num = df[num_cols].values

      X = np.hstack((X_text, X_num))
      y = df['log_views'].values
      split_idx = int(len(df) * 0.80)
+     return X[:split_idx], X[split_idx:], y[:split_idx], y[split_idx:], tfidf
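The feature assembly used by `process_data_pipeline` (a dense TF-IDF text matrix horizontally stacked with the four numeric meta columns) can be reproduced in isolation; the toy descriptions below are made up for illustration:

```python
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer

df = pd.DataFrame({
    "description": [
        "POV: trying the viral hack #fyp",
        "ASMR packing orders tonight #viral",
        "GRWM for a night out with friends",
    ],
    "video_duration_sec": [12, 45, 90],
    "hour_of_day": [19, 9, 21],
    "is_weekend": [1, 0, 1],
    "hashtag_count": [1, 1, 0],
})

# Text features: sparse TF-IDF densified, as in the pipeline above.
tfidf = TfidfVectorizer(max_features=2000, stop_words="english")
X_text = tfidf.fit_transform(df["description"]).toarray()

# Meta features appended as extra columns after the text block.
num_cols = ["video_duration_sec", "hour_of_day", "is_weekend", "hashtag_count"]
X = np.hstack((X_text, df[num_cols].values))
```

The column order matters later: `plot_feature_importance` concatenates the TF-IDF vocabulary and the same four meta names in exactly this order to label importances.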

  # ---------------------------------------------------------
+ # 3. MODEL COMPARISON & TRAINING
  # ---------------------------------------------------------
+ def compare_and_train_best_model(X_train, y_train, X_test, y_test):
+     print("\n[3/8] Comparing 3 Models to find the best one...")
+
+     models = {
+         "Linear Regression": LinearRegression(),
+         "Random Forest": RandomForestRegressor(n_estimators=50, max_depth=10, n_jobs=-1),
+         "XGBoost": XGBRegressor(n_estimators=100, learning_rate=0.1, n_jobs=-1)
+     }
+
+     results = {}
+     best_name = None
+     best_score = float('inf')  # RMSE, so lower is better
+     best_model_obj = None
+
+     print(f"{'Model':<20} | {'RMSE':<10} | {'MAE':<10} | {'MAPE':<10} | {'R²':<10}")
+     print("-" * 70)
+
+     for name, model in models.items():
+         model.fit(X_train, y_train)
+         preds_log = model.predict(X_test)
+
+         # Invert the log transform to compute metrics on real view counts
+         preds_real = np.expm1(preds_log)
+         y_test_real = np.expm1(y_test)
+
+         rmse = np.sqrt(mean_squared_error(y_test_real, preds_real))
+         mae = mean_absolute_error(y_test_real, preds_real)
+         mape = mean_absolute_percentage_error(y_test_real, preds_real)
+         r2 = r2_score(y_test, preds_log)
+
+         results[name] = {'RMSE': rmse, 'MAE': mae, 'MAPE': mape, 'R2': r2}
+
+         print(f"{name:<20} | {rmse:.0f} | {mae:.0f} | {mape:.2%} | {r2:.3f}")
+
+         if rmse < best_score:
+             best_score = rmse
+             best_name = name
+             best_model_obj = model
+
+     print("-" * 70)
+     print(f"🏆 Winner: {best_name} (RMSE: {best_score:.0f})")
+
+     # --- PLOTTING ---
+     plt.figure(figsize=(8, 5))
+
+     # Comparison Bar Chart (RMSE)
+     names = list(results.keys())
+     rmse_scores = [results[n]['RMSE'] for n in names]
+     plt.bar(names, rmse_scores, color=['gray', 'gray', 'green'])
+     plt.title('Model Comparison (RMSE - Lower is Better)')
+     plt.ylabel('RMSE (Views)')
+
+     plt.tight_layout()
+     plt.savefig('project_plots/model_comparison.png')
+     plt.close()
+     print("   -> Comparison plot saved to 'project_plots/model_comparison.png'")
+
+     return best_model_obj
+
+ def plot_feature_importance(model, vectorizer, output_path='project_plots/feature_importance.png'):
+     print("   -> 📊 Generating Feature Importance Plot...")
+
+     # 1. Get Feature Names
+     # TF-IDF features
+     tfidf_names = vectorizer.get_feature_names_out()
+     # Numeric features (hardcoded to match process_data_pipeline)
+     meta_names = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
+     all_features = np.concatenate([tfidf_names, meta_names])
+
+     # 2. Get Importances
+     if hasattr(model, 'feature_importances_'):
+         # XGBoost / Random Forest
+         importances = model.feature_importances_
+         title = f"Top 20 Features ({type(model).__name__})"
+     elif hasattr(model, 'coef_'):
+         # Linear Regression
+         importances = np.abs(model.coef_)  # Magnitude matters
+         title = f"Top 20 Feature Coefficients ({type(model).__name__})"
+     else:
+         print("   ⚠️ Model type does not support feature importance extraction.")
+         return
+
+     # 3. Sort and Plot Top 20
+     indices = np.argsort(importances)[-20:]
+
+     plt.figure(figsize=(10, 8))
+     plt.title(title)
+     plt.barh(range(len(indices)), importances[indices], align='center', color='teal')
+     plt.yticks(range(len(indices)), [all_features[i] for i in indices])
+     plt.xlabel('Relative Importance')
+     plt.tight_layout()
+     plt.savefig(output_path)
+     plt.close()
+     print(f"   -> Feature Importance saved to '{output_path}'")

  # ---------------------------------------------------------
  # 4. EMBEDDINGS GENERATION (For Search)
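The `hasattr` dispatch in `plot_feature_importance` (tree ensembles expose `feature_importances_`, linear models expose `coef_`) can be exercised on toy data; the helper name below is illustrative:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
y = 3.0 * X[:, 0] + 0.1 * X[:, 1]  # feature 0 dominates by construction

def importance_vector(model):
    """Generic importance extraction, mirroring the dispatch above."""
    if hasattr(model, "feature_importances_"):
        return model.feature_importances_        # tree ensembles
    if hasattr(model, "coef_"):
        return np.abs(model.coef_)               # linear models: magnitude matters
    raise TypeError("model exposes neither feature_importances_ nor coef_")

forest = RandomForestRegressor(n_estimators=20, random_state=0).fit(X, y)
linear = LinearRegression().fit(X, y)

imp_f = importance_vector(forest)
imp_l = importance_vector(linear)
```

Both paths return one value per input column, so the same sort-and-barh plotting code works regardless of which model won the comparison.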
 
  # MAIN EXECUTION
  # ---------------------------------------------------------
  if __name__ == "__main__":
+     # 1. Pipeline (LLM)
+     print("🚀 Starting Production Run: Generating 10,000 rows...")
+     df, _ = generate_synthetic_data_llm(10000)
+     X_train, X_test, y_train, y_test, tfidf = process_data_pipeline(df)

+     # 2. Train Prediction Model (COMPARISON Step)
+     best_model = compare_and_train_best_model(X_train, y_train, X_test, y_test)

      # 3. Create Knowledge Base (Embeddings)
      knowledge_df, st_model = create_search_index(df)

+     # 4. Save Artifacts for App & Plot Importance
+     print("\n[5/8] Saving Model Artifacts & Finalizing Plots...")
+
+     # Plot Feature Importance (now that we have the winner)
+     plot_feature_importance(best_model, tfidf)
+
+     # Use Pickle for the model (generic across estimator types)
+     with open("viral_model.pkl", "wb") as f:
+         pickle.dump(best_model, f)
+     print("   - Model saved to 'viral_model.pkl'")

      with open("tfidf_vectorizer.pkl", "wb") as f:
          pickle.dump(tfidf, f)
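One subtlety in the comparison loop above: the models are fit on log views, but RMSE and MAE are reported in real view counts after inverting the `log1p` transform with `np.expm1`. A toy check of that inversion, with hypothetical view counts:

```python
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Hypothetical true views and predictions, moved into log space with log1p.
y_test = np.log1p(np.array([100.0, 1000.0, 10000.0]))
preds_log = np.log1p(np.array([110.0, 900.0, 12000.0]))

# Invert back to view counts before computing errors in "views" units.
y_real = np.expm1(y_test)        # recovers 100, 1000, 10000 (up to float error)
preds_real = np.expm1(preds_log)

rmse = float(np.sqrt(mean_squared_error(y_real, preds_real)))
mae = float(mean_absolute_error(y_real, preds_real))
```

Reporting in view units makes RMSE dominated by the largest videos (the 2,000-view miss here), which is exactly why the leaderboard numbers are far larger than log-space errors would suggest.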
model-search.py DELETED
@@ -1,21 +0,0 @@
- import google.generativeai as genai
- import os
- from dotenv import load_dotenv
-
- # 1. Load your API key
- load_dotenv()
- api_key = os.getenv("GEMINI_API_KEY")
-
- if not api_key:
-     print("Error: API key not found. Make sure it is in your .env file.")
- else:
-     genai.configure(api_key=api_key)
-
-     print("--- Available Gemini Models ---")
-     # 2. List all models and filter for those that generate content (text/chat)
-     for m in genai.list_models():
-         if 'generateContent' in m.supported_generation_methods:
-             print(f"Name: {m.name}")
-             print(f"   - Display Name: {m.display_name}")
-             print(f"   - Input Limit: {m.input_token_limit} tokens")
-             print("-" * 30)
project_plots/{model_leaderboard.png → diversity_plot.png} RENAMED
File without changes
project_plots/eda_distribution.png CHANGED

Git LFS Details (before)
  • SHA256: 18de2fc7c92289b75a7c77d7e8ff57ae5d462269f8cdd99d584cad6377367f14
  • Pointer size: 130 Bytes
  • Size of remote file: 13.7 kB

Git LFS Details (after)
  • SHA256: 96cf3d4ff0977409b642050164d9ded8a43e2565437d19bd3fb26234c434ba6f
  • Pointer size: 130 Bytes
  • Size of remote file: 13.9 kB

project_plots/feature_importance.png CHANGED

Git LFS Details (before)
  • SHA256: 215f9eabe41998d7222c3b5104993a60bff7be958cd224edf9d1a42108070db6
  • Pointer size: 130 Bytes
  • Size of remote file: 81.9 kB

Git LFS Details (after)
  • SHA256: a0ac14b476322d1a9d53149a728ff0b3a6002763d1157478379fa68ea701ae04
  • Pointer size: 130 Bytes
  • Size of remote file: 32 kB
project_plots/{embedding_clusters.png → model_comparison.png} RENAMED
File without changes
tiktok_test_data_embeddings.parquet → raw_social_media_data.parquet RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:9bd222284d380a2ddb7f29f27e1ff77eff7b379eed34674c959804f70ef80ede
- size 3862601
+ oid sha256:92939908f14b69157b0a99ee186ef1f0ff70d54974bfcf14235468674f73d450
+ size 1185030
requirements.txt CHANGED
@@ -1,4 +1,3 @@
- gradio>=5.0
  pandas
  numpy
  xgboost
@@ -7,3 +6,6 @@ sentence-transformers
  google-generativeai
  python-dotenv
  faker
+ transformers
+ torch
+ accelerate
tfidf_vectorizer.pkl CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:96a8c9e0af89e51756ab34cd1219571dea1cd2ca9f2558895fcd00250a3a6c8b
- size 29989
+ oid sha256:ebe788182b51f7023d3b94676a566553723f10b0a6795e191f827bda12136339
+ size 73096
tiktok_knowledge_base.parquet CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:12a2bf4192ada1bb32d08a87f9e811c28a933c40ae0d4831143f1f1fdccd6579
- size 16651296
+ oid sha256:82dacd5da6cc1e8f9a62db8e8b6d68f5d5e466300d94dc7707c7afd342a97594
+ size 17274184
viral_model.json DELETED
The diff for this file is too large to render. See raw diff
 
viral_model.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:92f2ca0ca3bf30dd6a5d7e84d8ebff5612134ff895124e03cb51586a000d9527
+ size 214620