odeyaaa committed on
Commit cf49347 · verified · Parent: 373c82c

Upload 10 files

Files changed (10):
  1. README.md +107 -5
  2. app.py +192 -0
  3. env +1 -0
  4. gitattributes +3 -0
  5. model-prep.py +302 -0
  6. model-search.py +21 -0
  7. requirements.txt +9 -0
  8. tfidf_vectorizer.pkl +3 -0
  9. tiktok_knowledge_base.parquet +3 -0
  10. viral_model.json +0 -0
README.md CHANGED
@@ -1,12 +1,114 @@
  ---
- title: Social Assistent
- emoji: 📉
  colorFrom: green
- colorTo: red
  sdk: gradio
- sdk_version: 6.3.0
  app_file: app.py
  pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
  ---
+ title: Social Media Virality
+ emoji: 📺
  colorFrom: green
+ colorTo: yellow
  sdk: gradio
+ sdk_version: 5.9.0
  app_file: app.py
  pinned: false
  ---
+ # Social Media Virality Prediction & Optimization Project
+ 
+ **Course**: Data Science & Machine Learning Applications
+ **Project**: Viral Content Assistant
+ 
+ ## 1. Project Overview
+ This project develops a data-driven system that predicts the viral potential of short-form video content (e.g., TikTok) and optimizes it with Generative AI. Using Natural Language Processing (NLP) and Machine Learning (ML), the system analyzes video descriptions and metadata to forecast view counts and prescribes actionable improvements to maximize engagement.
+ 
+ The core solution consists of a machine learning pipeline for virality prediction and a Gradio web application for real-time user interaction.
+ 
+ ## 2. Data Science Methodology
+ 
+ ### 2.1 Data Acquisition & Generation
+ Due to privacy restrictions and API limitations of social platforms, we simulated a realistic dataset reflecting 2025 social media trends.
+ * **Source**: Synthetic generation using the `Faker` library and `numpy` probabilistic distributions.
+ * **Volume**: 10,000 samples.
+ * **Features**:
+     * **Textual**: Video descriptions rich in slang (e.g., "Skibidi", "Girl Dinner"), hashtags, and emojis.
+     * **Temporal**: Upload hour, day of week.
+     * **Meta**: Video duration, category (Gaming, Beauty, etc.).
+ * **Target Variable**: `views` (log-normally distributed to mimic real-world viral discrepancies).
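As an illustration, the view-count generation above can be sketched with `numpy` alone. This is a minimal stand-in for the full generator in `model-prep.py` (which also uses `Faker` to fill in sentence text); the multiplier values mirror the ones used there.

```python
import numpy as np

rng = np.random.default_rng(42)
n_rows = 1_000  # 10,000 in the full pipeline

# Heavy-tailed base view count, then trend/timing multipliers
base = rng.lognormal(mean=9.5, sigma=1.8, size=n_rows)
is_weekend = rng.integers(0, 2, size=n_rows)
duration = rng.integers(5, 180, size=n_rows)

multiplier = np.ones(n_rows)
multiplier *= np.where(is_weekend == 1, 1.2, 1.0)   # weekend boost
multiplier *= np.where(duration < 15, 1.4, 1.0)     # short clips travel further

views = (base * multiplier).astype(int)
log_views = np.log1p(views)  # regression target
```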
+ 
+ ### 2.2 Exploratory Data Analysis (EDA)
+ We analyzed the distribution of the target variable and feature correlations.
+ * **Observation**: View counts follow a power-law-like distribution: most videos receive few views, while a handful of viral hits capture the majority.
+ * **Preprocessing**: We applied a log transformation (`np.log1p`) to the `views` variable to normalize the distribution for the regression models.
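The same transform pair appears at inference time: the model predicts in log space, and `np.expm1` inverts the prediction back into a raw view count. A tiny round-trip sketch:

```python
import numpy as np

views = np.array([120, 3_500, 89_000, 2_400_000])  # heavy-tailed raw counts
log_views = np.log1p(views)        # training target: log(1 + views)
recovered = np.expm1(log_views)    # inference: invert back to a view count
```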
+ 
+ ![Views Distribution](project_plots/eda_distribution.png)
+ *Figure 1: Distribution of log-transformed view counts.*
+ 
+ ### 2.3 Feature Engineering
+ * **Text Embeddings**: We used **TF-IDF Vectorization** (top 2,000 features) to convert unstructured text descriptions into numerical vectors.
+ * **Meta Features**: Encoded `is_weekend`, `hour_of_day`, and `video_duration_sec`.
+ * **Data Splitting**: A **temporal split** (80/20) was used instead of a random split to prevent data leakage, ensuring the model predicts future videos from past trends.
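The feature-engineering steps above can be sketched on a toy corpus (four example descriptions standing in for the 10,000 generated rows; the meta column order matches `model-prep.py`):

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer

descriptions = [
    "POV: Girl Dinner edition! #fyp #viral",
    "GRWM for the Grimace Shake trend #foryou",
    "Day in the life: NPC Stream #gaming",
    "Trying the viral Skibidi challenge #trending2025",
]
meta = np.array([  # [duration_sec, hour_of_day, is_weekend, hashtag_count]
    [12, 18, 0, 2],
    [45, 20, 1, 1],
    [90,  9, 0, 1],
    [10, 22, 1, 1],
])

tfidf = TfidfVectorizer(max_features=2000, stop_words="english")
X_text = tfidf.fit_transform(descriptions).toarray()
X = np.hstack((X_text, meta))  # text features + meta features

# Temporal split: rows are already time-ordered, so slice instead of shuffling
split = int(len(X) * 0.8)
X_train, X_test = X[:split], X[split:]
```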
+ 
+ ## 3. Model Development & Evaluation
+ 
+ We evaluated three distinct algorithms for the regression problem (predicting log-views):
+ 
+ 1. **Linear Regression**: Baseline model for interpretability.
+ 2. **Random Forest Regressor**: Ensemble method to capture non-linear relationships.
+ 3. **XGBoost Regressor**: Gradient boosting machine known for state-of-the-art performance on tabular data.
+ 
+ ### 3.1 Comparative Metrics
+ Models were assessed using:
+ * **RMSE (Root Mean Squared Error)**: The primary metric for regression accuracy.
+ * **R² (Coefficient of Determination)**: The share of variance explained by the model.
+ * **F1-Score**: A proxy for classification performance, scoring whether a video is predicted to clear the "viral threshold" (the top 20% of views).
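A minimal sketch of how these three metrics combine on synthetic log-view predictions (illustrative data only, not the project's actual results):

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score, f1_score

rng = np.random.default_rng(0)
y_true = rng.lognormal(1.5, 0.5, size=200)        # log-views (ground truth)
y_pred = y_true + rng.normal(0, 0.3, size=200)    # a hypothetical model's output

rmse = np.sqrt(mean_squared_error(y_true, y_pred))
r2 = r2_score(y_true, y_pred)

# Proxy classification: does the model place a video above the top-20% threshold?
threshold = np.quantile(y_true, 0.80)
f1 = f1_score(y_true > threshold, y_pred > threshold)
```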
+ 
+ ![Model Leaderboard](project_plots/model_leaderboard.png)
+ *Figure 2: Performance comparison across different architectures.*
+ 
+ ### 3.2 Results
+ The **XGBoost Regressor** outperformed the other models, achieving the lowest RMSE on the test set, and was selected for the final deployment.
+ 
+ ## 4. Advanced Analysis: Embeddings & Semantic Search
+ 
+ Beyond simple regression, we implemented a semantic search engine using **SentenceTransformers** (`all-MiniLM-L6-v2`).
+ * **Purpose**: Retrieve historical viral hits conceptually similar to the user's new idea.
+ * **Clustering**: We visualized the semantic space using PCA (Principal Component Analysis).
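The retrieval step reduces to a cosine-similarity ranking over stored embeddings. A self-contained sketch with toy 4-dimensional vectors (the real pipeline produces 384-dimensional vectors via `SentenceTransformer.encode()`):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy embeddings standing in for all-MiniLM-L6-v2 output
kb_embeddings = np.array([
    [0.9, 0.1, 0.0, 0.0],   # "Grimace Shake taste test"
    [0.0, 0.8, 0.2, 0.0],   # "GRWM soft life morning"
    [0.7, 0.2, 0.1, 0.0],   # "Trying the viral shake at midnight"
])
query = np.array([[0.8, 0.1, 0.1, 0.0]])  # user's draft idea, embedded

sims = cosine_similarity(query, kb_embeddings)[0]
top_2 = sims.argsort()[-2:][::-1]  # indices of the 2 closest hits
```

The app uses the same `argsort()[-k:][::-1]` idiom to pull the top-3 matches from the Parquet knowledge base.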
+ 
+ ![Embedding Clusters](project_plots/embedding_clusters.png)
+ *Figure 3: Semantic clustering of video descriptions.*
+ 
+ ## 5. Application & Deployment
+ 
+ The final deliverable is an interactive web application built with **Gradio**.
+ 
+ ### 5.1 System Architecture
+ The system is decoupled into two main components:
+ 1. **Training Pipeline (`model-prep.py`)**: Runs offline to generate synthetic data, train the XGBoost model, and build the vector database, saving the artifacts (`viral_model.json`, `tfidf_vectorizer.pkl`, `tiktok_knowledge_base.parquet`).
+ 2. **Inference App (`app.py`)**: A lightweight Gradio app that loads the pre-trained artifacts to serve real-time predictions without retraining.
+ 
+ **Data Flow**:
+ 1. **Input**: User-provided video description.
+ 2. **Inference**: The loaded XGBoost model predicts a view count.
+ 3. **Retrieval**: The app searches the pre-computed Parquet knowledge base for similar viral videos.
+ 4. **Generative Optimization**: **Google Gemini 2.5 Flash Lite** rewrites the draft.
+ 5. **Output**: Predictions, similar videos, and AI-optimized content.
+ 
+ ### 5.2 Usage Instructions
+ 
+ To run the project locally for assessment:
+ 
+ 1. **Environment Setup**:
+ ```bash
+ python3 -m venv .venv
+ source .venv/bin/activate
+ pip install -r requirements.txt
+ ```
+ 2. **Configuration**:
+ Ensure the `.env` file contains a valid `GEMINI_API_KEY`.
+ 3. **Execution**:
+ ```bash
+ python app.py
+ ```
+ Access the UI at `http://localhost:7860`.
+ 
+ ## 6. Conclusion
+ This project demonstrates a complete end-to-end data science workflow: from synthetic data creation and rigorous model evaluation to the deployment of a user-facing AI application. The integration of predictive analytics (XGBoost) with generative AI (Gemini) provides a robust tool for content creators.
+ 
+ ## 🏆 Credits
+ * **Project Author:** Matan Kriel
+ * **Project Author:** Odeya Shmuel
app.py ADDED
@@ -0,0 +1,192 @@
import os
import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
import google.generativeai as genai
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from xgboost import XGBRegressor  # regressor, matching the model trained in model-prep.py

# Load environment variables (GEMINI_API_KEY)
load_dotenv()

# --- GLOBAL STATE ---
MODEL = None
VECTORIZER = None
KNOWLEDGE_DF = None
ST_MODEL = None

def initialize_app():
    """Loads the pre-computed artifacts (model, vectorizer, knowledge base) on startup."""
    global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL

    print("⏳ Initializing app: loading pre-computed artifacts...")

    # 1. Load the Parquet knowledge base (produced by model-prep.py)
    parquet_path = 'tiktok_knowledge_base.parquet'
    if not os.path.exists(parquet_path):
        raise FileNotFoundError(f"Required file '{parquet_path}' not found! Run model-prep.py first.")

    print(f"📂 Loading data from {parquet_path}...")
    knowledge_df = pd.read_parquet(parquet_path)

    # 2. Load the XGBoost model
    print("🧠 Loading XGBoost Model...")
    model = XGBRegressor()
    model.load_model("viral_model.json")

    # 3. Load the TF-IDF vectorizer
    print("🔤 Loading TF-IDF Vectorizer...")
    with open("tfidf_vectorizer.pkl", "rb") as f:
        tfidf = pickle.load(f)

    # 4. Load the sentence transformer (MPS on Apple Silicon, else CPU)
    print("🔌 Loading SentenceTransformer...")
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    MODEL = model
    VECTORIZER = tfidf
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("✅ App initialized (Inference Mode)!")
def predict_and_optimize(user_input):
    if not user_input:
        return "Please enter a video description.", "", "", "", ""

    # --- 1. INITIAL PREDICTION ---
    text_vec = VECTORIZER.transform([user_input]).toarray()
    # Assume default meta features: 15 s duration, 18:00 upload, weekday (0),
    # hashtag count taken from the input text
    meta_vec = np.array([[15, 18, 0, user_input.count('#')]])
    feat_vec = np.hstack((text_vec, meta_vec))

    initial_log = MODEL.predict(feat_vec)[0]
    initial_views = int(np.expm1(initial_log))

    # --- 2. VECTOR SEARCH ---
    # Restrict the search to viral hits (top 25% of views) in the knowledge base
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()

    user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
    target_embeddings = np.stack(high_perf_df['embedding'].values)

    similarities = cosine_similarity(user_embedding, target_embeddings)
    top_3_indices = similarities[0].argsort()[-3:][::-1]
    top_3_videos = high_perf_df.iloc[top_3_indices]['description'].tolist()

    similar_videos_str = "\n\n".join([f"{i+1}. {v}" for i, v in enumerate(top_3_videos)])

    # --- 3. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"

    genai.configure(api_key=api_key)
    # Prefer the newer model; fall back to 1.5 Flash if it is unavailable
    try:
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        llm = genai.GenerativeModel('gemini-1.5-flash')

    prompt = f"""
    You are a TikTok Virality Expert.

    My Draft Description: "{user_input}"

    Here are 3 successful, viral videos that are similar to my topic:
    1. {top_3_videos[0]}
    2. {top_3_videos[1]}
    3. {top_3_videos[2]}

    Task: Rewrite my draft description to make it go viral.
    Use the slang, hashtag style, and structure of the successful examples provided.
    Keep it under 60 words plus hashtags. Return ONLY the new description.
    """

    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()

        # --- 4. RE-SCORING ---
        new_text_vec = VECTORIZER.transform([improved_idea]).toarray()
        new_meta_vec = np.array([[15, 18, 0, improved_idea.count('#')]])
        new_feat_vec = np.hstack((new_text_vec, new_meta_vec))

        new_log = MODEL.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))

        uplift_pct = ((new_views - initial_views) / max(initial_views, 1)) * 100
        uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"

        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str

    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"
# --- GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Viral Content Optimizer")
    gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚡", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📊 Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)

            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)

    with gr.Accordion("🔍 Similar Viral Videos (Reference)", open=False):
        similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

    # Demo button: 1. fill the textbox -> 2. run the prediction
    demo_text = "POV: You realize you forgot to turn off your mic during the all-hands meeting 💀 #fail #fyp #corporate"
    demo_btn.click(
        fn=lambda: demo_text,
        inputs=None,
        outputs=input_text
    ).then(
        fn=predict_and_optimize,
        inputs=gr.State(demo_text),  # pass the text directly to avoid racing the UI update
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

# Run initialization, then launch the app
if __name__ == "__main__":
    initialize_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)
env ADDED
@@ -0,0 +1 @@
GEMINI_API_KEY=AIzaSyDv0m8cjeMuN5ue_VtSz9sMiQfsJ_GpvKI
gitattributes ADDED
@@ -0,0 +1,3 @@
*.png filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
model-prep.py ADDED
@@ -0,0 +1,302 @@
import os
import pickle
import warnings
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import google.generativeai as genai
from dotenv import load_dotenv
from faker import Faker
from sklearn.metrics.pairwise import cosine_similarity

# Load environment variables from the .env file
load_dotenv()

# Machine Learning Imports
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer

# ---------------------------------------------------------
# 0. SETUP & CONFIGURATION
# ---------------------------------------------------------
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)

# OPTIMIZATION: use Apple Silicon (MPS) when available
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"🚀 Optimization: Running on {device.upper()} device")

if not os.path.exists('project_plots'):
    os.makedirs('project_plots')

# ---------------------------------------------------------
# 1. DATA GENERATION (with 2025 trends)
# ---------------------------------------------------------
def generate_enhanced_data(n_rows=10000):
    print(f"\n[1/8] Generating {n_rows} rows of simulated 2025 data...")
    fake = Faker()

    trends = [
        'Delulu', 'Girl Dinner', 'Roman Empire', 'Silent Slay', 'Soft Life',
        'Grimace Shake', 'Wes Anderson Style', 'Beige Flag', 'Canon Event',
        'NPC Stream', 'Skibidi', 'Fanum Tax', 'Yapping', 'Glow Up', 'Fit Check'
    ]
    formats = [
        'POV: You realize...', 'GRWM for...', 'Day in the life:',
        'Storytime:', 'Trying the viral...', 'ASMR packing orders',
        'Rating my exes...', 'Turn the lights off challenge'
    ]
    categories = ['Gaming', 'Beauty', 'Comedy', 'Edutainment', 'Lifestyle', 'Food']

    data = []
    start_date = datetime(2024, 1, 1)

    for _ in range(n_rows):
        upload_time = start_date + timedelta(days=np.random.randint(0, 365), hours=np.random.randint(0, 23))
        trend = np.random.choice(trends)
        fmt = np.random.choice(formats)
        cat = np.random.choice(categories)

        description = f"{fmt} {trend} edition! {fake.sentence(nb_words=6)}"
        tags = ['#fyp', '#foryou', '#viral', f'#{trend.replace(" ", "").lower()}', f'#{cat.lower()}']
        if np.random.random() > 0.5:
            tags.append('#trending2025')

        full_text = f"{description} {' '.join(tags)}"

        # Meta features
        duration = np.random.randint(5, 180)
        hour = upload_time.hour
        is_weekend = 1 if upload_time.weekday() >= 5 else 0

        # View-count logic: heavy-tailed base draw plus trend/timing multipliers
        base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
        multiplier = 1.0
        if is_weekend:
            multiplier *= 1.2
        if duration < 15:
            multiplier *= 1.4
        if "Delulu" in full_text or "POV" in full_text:
            multiplier *= 1.6
        if hour >= 18:
            multiplier *= 1.1

        views = int(base_virality * multiplier)

        data.append({
            'upload_date': upload_time,
            'description': full_text,
            'category': cat,
            'video_duration_sec': duration,
            'hour_of_day': hour,
            'is_weekend': is_weekend,
            'hashtag_count': len(tags),
            'views': views
        })

    df = pd.DataFrame(data)
    df = df.sort_values('upload_date').reset_index(drop=True)
    threshold = df['views'].quantile(0.80)
    df['is_viral_binary'] = (df['views'] > threshold).astype(int)
    df['log_views'] = np.log1p(df['views'])

    return df, threshold

# ---------------------------------------------------------
# 2. EDA & PREPROCESSING
# ---------------------------------------------------------
def process_data_pipeline(df):
    print("\n[2/8] Processing Data Pipeline...")

    # Quick EDA plot of the log-views distribution
    clean_df = df[df['video_duration_sec'] > 0].copy()
    plt.figure(figsize=(6, 4))
    sns.histplot(clean_df['log_views'], color='teal')
    plt.title('Log Views Distribution')
    plt.savefig('project_plots/eda_distribution.png')
    plt.close()

    # TF-IDF features + temporal split
    tfidf = TfidfVectorizer(max_features=2000, stop_words='english')
    X_text = tfidf.fit_transform(df['description']).toarray()

    num_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
    X_num = df[num_cols].values

    X = np.hstack((X_text, X_num))
    y = df['log_views'].values
    y_bin = df['is_viral_binary'].values

    split_idx = int(len(df) * 0.80)
    return X[:split_idx], X[split_idx:], y[:split_idx], y[split_idx:], y_bin[split_idx:], tfidf

# ---------------------------------------------------------
# 3. TRAINING
# ---------------------------------------------------------
def train_best_model(X_train, y_train, X_test, y_test):
    print("\n[3/8] Training Model (XGBoost)...")
    model = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=6, n_jobs=-1)
    model.fit(X_train, y_train)

    rmse = np.sqrt(mean_squared_error(y_test, model.predict(X_test)))
    print(f"   - Model RMSE: {rmse:.3f}")
    return model

# ---------------------------------------------------------
# 4. EMBEDDINGS GENERATION (for search)
# ---------------------------------------------------------
def create_search_index(df):
    print("\n[4/8] Creating Vector Search Index...")
    # Embed ALL rows so the whole history is searchable
    st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    embeddings = st_model.encode(df['description'].tolist(), convert_to_numpy=True, show_progress_bar=True)

    df['embedding'] = list(embeddings)

    # Save to Parquet (the knowledge base)
    save_path = 'tiktok_knowledge_base.parquet'
    df.to_parquet(save_path)
    print(f"   - Knowledge Base saved to {save_path}")
    return df, st_model

# ---------------------------------------------------------
# 5. RETRIEVAL & IMPROVEMENT ENGINE
# ---------------------------------------------------------
def optimize_content_with_gemini(user_input, model, vectorizer, knowledge_df, st_model):
    """
    1. Scores the original idea.
    2. Finds the top 3 similar VIRAL videos.
    3. Asks Gemini to rewrite the idea.
    4. Re-scores the new idea.
    """
    print("\n" + "=" * 50)
    print("🚀 VIRAL OPTIMIZATION ENGINE")
    print("=" * 50)

    # --- STEP 1: INITIAL SCORE ---
    text_vec = vectorizer.transform([user_input]).toarray()
    # Assume default meta features for prediction (15 s, 6 PM, weekday)
    meta_vec = np.array([[15, 18, 0, user_input.count('#')]])
    feat_vec = np.hstack((text_vec, meta_vec))

    initial_log = model.predict(feat_vec)[0]
    initial_views = int(np.expm1(initial_log))

    print(f"\n📝 ORIGINAL IDEA: {user_input}")
    print(f"📊 Predicted Views: {initial_views:,}")

    # --- STEP 2: VECTOR SEARCH (find similar successful videos) ---
    print("\n🔍 Searching for similar viral hits in the Parquet knowledge base...")

    # Keep only successful videos (top 25% of views)
    high_performance_df = knowledge_df[knowledge_df['views'] > knowledge_df['views'].quantile(0.75)].copy()

    # Encode the user input
    user_embedding = st_model.encode([user_input], convert_to_numpy=True)

    # Stack stored embeddings into a matrix and rank by cosine similarity
    target_embeddings = np.stack(high_performance_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)

    # Top 3 matches
    top_3_indices = similarities[0].argsort()[-3:][::-1]
    top_3_videos = high_performance_df.iloc[top_3_indices]['description'].tolist()

    print(" -> Found 3 similar viral videos to learn from:")
    for i, vid in enumerate(top_3_videos, 1):
        print(f"    {i}. {vid[:80]}...")

    # --- STEP 3: GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        print("\n⚠️ SKIPPING AI REWRITE: no 'GEMINI_API_KEY' found in environment variables.")
        print("   (Set it via 'export GEMINI_API_KEY=your_key' in the terminal)")
        return

    print("\n🤖 Sending context to Gemini LLM for optimization...")
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel('gemini-2.5-flash-lite')

    prompt = f"""
    You are a TikTok Virality Expert.

    My Draft Description: "{user_input}"

    Here are 3 successful, viral videos that are similar to my topic:
    1. {top_3_videos[0]}
    2. {top_3_videos[1]}
    3. {top_3_videos[2]}

    Task: Rewrite my draft description to make it go viral.
    Use the slang, hashtag style, and structure of the successful examples provided.
    Keep it under 20 words plus hashtags. Return ONLY the new description.
    """

    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()

        print(f"\n✨ IMPROVED IDEA (by Gemini): {improved_idea}")

        # --- STEP 4: RE-EVALUATION ---
        new_text_vec = vectorizer.transform([improved_idea]).toarray()
        # Recompute the hashtag count for the rewritten text
        new_meta_vec = np.array([[15, 18, 0, improved_idea.count('#')]])
        new_feat_vec = np.hstack((new_text_vec, new_meta_vec))

        new_log = model.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))

        print(f"📊 New Predicted Views: {new_views:,}")

        improvement = ((new_views - initial_views) / max(initial_views, 1)) * 100
        if improvement > 0:
            print(f"🚀 POTENTIAL UPLIFT: +{improvement:.1f}%")
        else:
            print("😐 No significant uplift predicted (the model is strict!).")

    except Exception as e:
        print(f"❌ Error calling Gemini API: {e}")

# ---------------------------------------------------------
# MAIN EXECUTION
# ---------------------------------------------------------
if __name__ == "__main__":
    # 1. Pipeline
    df, _ = generate_enhanced_data(10000)
    X_train, X_test, y_train, y_test, _, tfidf = process_data_pipeline(df)

    # 2. Train the prediction model
    best_model = train_best_model(X_train, y_train, X_test, y_test)

    # 3. Create the knowledge base (embeddings)
    knowledge_df, st_model = create_search_index(df)

    # 4. Save artifacts for the app
    print("\n[5/8] Saving Model Artifacts for Production...")
    best_model.save_model("viral_model.json")
    print("   - Model saved to 'viral_model.json'")

    with open("tfidf_vectorizer.pkl", "wb") as f:
        pickle.dump(tfidf, f)
    print("   - Vectorizer saved to 'tfidf_vectorizer.pkl'")

    # 5. User interaction loop
    while True:
        print("\n" + "-" * 30)
        user_input = input("Enter your video idea (or 'q' to quit): ")
        if user_input.lower() == 'q':
            break

        optimize_content_with_gemini(
            user_input=user_input,
            model=best_model,
            vectorizer=tfidf,
            knowledge_df=knowledge_df,
            st_model=st_model
        )
model-search.py ADDED
@@ -0,0 +1,21 @@
import os

import google.generativeai as genai
from dotenv import load_dotenv

# 1. Load the API key
load_dotenv()
api_key = os.getenv("GEMINI_API_KEY")

if not api_key:
    print("Error: API key not found. Make sure it is in your .env file.")
else:
    genai.configure(api_key=api_key)

    print("--- Available Gemini Models ---")
    # 2. List all models that support content generation (text/chat)
    for m in genai.list_models():
        if 'generateContent' in m.supported_generation_methods:
            print(f"Name: {m.name}")
            print(f"  - Display Name: {m.display_name}")
            print(f"  - Input Limit: {m.input_token_limit} tokens")
            print("-" * 30)
requirements.txt ADDED
@@ -0,0 +1,9 @@
gradio>=5.0
pandas
numpy
xgboost
scikit-learn
sentence-transformers
google-generativeai
python-dotenv
faker
tfidf_vectorizer.pkl ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:96a8c9e0af89e51756ab34cd1219571dea1cd2ca9f2558895fcd00250a3a6c8b
size 29989
tiktok_knowledge_base.parquet ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12a2bf4192ada1bb32d08a87f9e811c28a933c40ae0d4831143f1f1fdccd6579
size 16651296
viral_model.json ADDED
The diff for this file is too large to render. See raw diff