File size: 11,427 Bytes
cf49347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import torch
import google.generativeai as genai  
from faker import Faker
from datetime import datetime, timedelta
from sklearn.metrics.pairwise import cosine_similarity  
import pickle
from dotenv import load_dotenv  

# Load environment variables from the .env file
load_dotenv()

# Machine Learning Imports
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer

# ---------------------------------------------------------
# 0. SETUP & CONFIGURATION
# ---------------------------------------------------------
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)

# OPTIMIZATION: Check for Apple Silicon (MPS)
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"πŸš€ Optimization: Running on {device.upper()} device")

# exist_ok=True avoids the racy exists()-then-makedirs() pattern
os.makedirs('project_plots', exist_ok=True)

# ---------------------------------------------------------
# 1. DATA GENERATION (With 2025 Trends)
# ---------------------------------------------------------
def generate_enhanced_data(n_rows=10000):
    """Generate a synthetic TikTok-style dataset with 2025 trend vocabulary.

    Each row combines a random format/trend/category description plus meta
    features (duration, upload hour, weekend flag, hashtag count) with a
    view count drawn from a log-normal base boosted by heuristic multipliers
    (weekend, short clips, trendy phrasing, evening uploads).

    Args:
        n_rows: Number of synthetic videos to generate.

    Returns:
        Tuple of (df, threshold): the DataFrame sorted by upload date with
        'is_viral_binary' (top-20% of views) and 'log_views' added, and the
        80th-percentile view threshold used for the binary label.
    """
    print(f"\n[1/8] Generating {n_rows} rows of Real-World 2025 Data...")
    fake = Faker()
    
    trends = [
        'Delulu', 'Girl Dinner', 'Roman Empire', 'Silent Slay', 'Soft Life',
        'Grimace Shake', 'Wes Anderson Style', 'Beige Flag', 'Canon Event',
        'NPC Stream', 'Skibidi', 'Fanum Tax', 'Yapping', 'Glow Up', 'Fit Check'
    ]
    formats = [
        'POV: You realize...', 'GRWM for...', 'Day in the life:', 
        'Storytime:', 'Trying the viral...', 'ASMR packing orders',
        'Rating my exes...', 'Turn the lights off challenge'
    ]
    categories = ['Gaming', 'Beauty', 'Comedy', 'Edutainment', 'Lifestyle', 'Food']
    
    data = []
    start_date = datetime(2024, 1, 1)
    
    for _ in range(n_rows):
        # BUGFIX: np.random.randint's upper bound is exclusive, so (0, 23)
        # could never produce an 11 PM upload; (0, 24) covers all 24 hours.
        upload_time = start_date + timedelta(days=np.random.randint(0, 365), hours=np.random.randint(0, 24))
        trend = np.random.choice(trends)
        fmt = np.random.choice(formats)
        cat = np.random.choice(categories)
        
        description = f"{fmt} {trend} edition! {fake.sentence(nb_words=6)}"
        tags = ['#fyp', '#foryou', '#viral', f'#{trend.replace(" ", "").lower()}', f'#{cat.lower()}']
        if np.random.random() > 0.5: tags.append('#trending2025')
        
        full_text = f"{description} {' '.join(tags)}"
        
        # Meta Features
        duration = np.random.randint(5, 180)
        hour = upload_time.hour
        is_weekend = 1 if upload_time.weekday() >= 5 else 0
        
        # View Count Logic: log-normal base scaled by heuristic boosts
        base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
        multiplier = 1.0
        if is_weekend: multiplier *= 1.2
        if duration < 15: multiplier *= 1.4
        if "Delulu" in full_text or "POV" in full_text: multiplier *= 1.6
        if hour >= 18: multiplier *= 1.1
        
        views = int(base_virality * multiplier)
        
        data.append({
            'upload_date': upload_time,
            'description': full_text,
            'category': cat,
            'video_duration_sec': duration,
            'hour_of_day': hour,
            'is_weekend': is_weekend,
            'hashtag_count': len(tags),
            'views': views
        })
        
    df = pd.DataFrame(data)
    df = df.sort_values('upload_date').reset_index(drop=True)
    # Top 20% of views counts as "viral" for the binary label
    threshold = df['views'].quantile(0.80)
    df['is_viral_binary'] = (df['views'] > threshold).astype(int)
    # log1p stabilizes the heavy-tailed view distribution for regression
    df['log_views'] = np.log1p(df['views'])
    
    return df, threshold

# ---------------------------------------------------------
# 2. EDA & PREPROCESSING
# ---------------------------------------------------------
def process_data_pipeline(df):
    """Build model features and a chronological 80/20 train/test split.

    Saves a quick EDA histogram of log-views, vectorizes descriptions with
    TF-IDF, and stacks the text features with numeric meta features.

    Args:
        df: DataFrame from generate_enhanced_data (must be date-sorted).

    Returns:
        (X_train, X_test, y_train, y_test, y_bin_test, tfidf) where y is
        log-views and y_bin_test is the binary viral label for the test rows.
    """
    print("\n[2/8] Processing Data Pipeline...")
    
    # Simple EDA Save
    clean_df = df[df['video_duration_sec'] > 0].copy()
    plt.figure(figsize=(6,4))
    sns.histplot(clean_df['log_views'], color='teal')
    plt.title('Log Views Distribution')
    plt.savefig('project_plots/eda_distribution.png')
    plt.close()
    
    # BUGFIX: compute the chronological split first so the TF-IDF
    # vectorizer is fit on training rows only — fitting on the full
    # dataset leaks test-set vocabulary/IDF statistics into training.
    split_idx = int(len(df) * 0.80)
    
    tfidf = TfidfVectorizer(max_features=2000, stop_words='english')
    X_text_train = tfidf.fit_transform(df['description'].iloc[:split_idx]).toarray()
    X_text_test = tfidf.transform(df['description'].iloc[split_idx:]).toarray()
    
    num_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
    X_num = df[num_cols].values
    
    X_train = np.hstack((X_text_train, X_num[:split_idx]))
    X_test = np.hstack((X_text_test, X_num[split_idx:]))
    
    y = df['log_views'].values 
    y_bin = df['is_viral_binary'].values
    return X_train, X_test, y[:split_idx], y[split_idx:], y_bin[split_idx:], tfidf

# ---------------------------------------------------------
# 3. TRAINING
# ---------------------------------------------------------
def train_best_model(X_train, y_train, X_test, y_test):
    """Fit an XGBoost regressor on log-views and report held-out RMSE."""
    print("\n[3/8] Training Model (XGBoost)...")
    regressor = XGBRegressor(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        n_jobs=-1,
    )
    regressor.fit(X_train, y_train)
    
    predictions = regressor.predict(X_test)
    rmse = np.sqrt(mean_squared_error(y_test, predictions))
    print(f"    - Model RMSE: {rmse:.3f}")
    return regressor

# ---------------------------------------------------------
# 4. EMBEDDINGS GENERATION (For Search)
# ---------------------------------------------------------
def create_search_index(df):
    """Embed every description and persist the frame as a parquet knowledge base.

    Returns the DataFrame (now carrying an 'embedding' column) and the
    sentence-transformer model used to encode it.
    """
    print("\n[4/8] Creating Vector Search Index...")
    # Generate embeddings for ALL data so we can search the whole history
    encoder = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    vectors = encoder.encode(
        df['description'].tolist(),
        convert_to_numpy=True,
        show_progress_bar=True,
    )
    df['embedding'] = list(vectors)
    
    # Save to Parquet (The Knowledge Base)
    save_path = 'tiktok_knowledge_base.parquet'
    df.to_parquet(save_path)
    print(f"    - Knowledge Base saved to {save_path}")
    return df, encoder

# ---------------------------------------------------------
# 5. RETRIEVAL & IMPROVEMENT ENGINE (The Magic Step)
# ---------------------------------------------------------
def optimize_content_with_gemini(user_input, model, vectorizer, knowledge_df, st_model):
    """
    1. Scores original idea.
    2. Finds up to 3 similar VIRAL videos.
    3. Asks Gemini to rewrite the idea.
    4. Re-scores the new idea.

    Args:
        user_input: Draft video description typed by the user.
        model: Trained regressor predicting log-views from features.
        vectorizer: Fitted TfidfVectorizer matching the model's text features.
        knowledge_df: DataFrame with 'views' and 'embedding' columns.
        st_model: SentenceTransformer used to embed the user's draft.
    """
    print("\n" + "="*50)
    print("πŸš€ VIRAL OPTIMIZATION ENGINE")
    print("="*50)
    
    # --- STEP 1: INITIAL SCORE ---
    text_vec = vectorizer.transform([user_input]).toarray()
    # Assume default meta for prediction (15s, 6 PM, weekday)
    meta_vec = np.array([[15, 18, 0, user_input.count('#')]]) 
    feat_vec = np.hstack((text_vec, meta_vec))
    
    initial_log = model.predict(feat_vec)[0]
    initial_views = int(np.expm1(initial_log))
    
    print(f"\nπŸ“ ORIGINAL IDEA: {user_input}")
    print(f"πŸ“Š Predicted Views: {initial_views:,}")
    
    # --- STEP 2: VECTOR SEARCH (Find similar successful videos) ---
    print("\nπŸ” Searching for similar viral hits in Parquet file...")
    
    # Filter only for successful videos (e.g., top 25% of views)
    high_performance_df = knowledge_df[knowledge_df['views'] > knowledge_df['views'].quantile(0.75)].copy()
    
    # ROBUSTNESS: np.stack below would raise on an empty frame
    if high_performance_df.empty:
        print("    -> No high-performing videos found in the knowledge base.")
        return
    
    # Encode user input
    user_embedding = st_model.encode([user_input], convert_to_numpy=True)
    
    # Stack embeddings from the dataframe into a matrix
    target_embeddings = np.stack(high_performance_df['embedding'].values)
    
    # Calculate Cosine Similarity
    similarities = cosine_similarity(user_embedding, target_embeddings)
    
    # Get up to top-3 indices ([-3:] is safe on shorter arrays)
    top_indices = similarities[0].argsort()[-3:][::-1]
    top_videos = high_performance_df.iloc[top_indices]['description'].tolist()
    
    print(f"    -> Found {len(top_videos)} similar viral videos to learn from:")
    for i, vid in enumerate(top_videos, 1):
        print(f"       {i}. {vid[:80]}...")

    # --- STEP 3: GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        print("\n⚠️  SKIPPING AI REWRITE: No 'GEMINI_API_KEY' found in environment variables.")
        print("    (Set it via 'export GEMINI_API_KEY=your_key' in terminal)")
        return

    print("\nπŸ€– Sending context to Gemini LLM for optimization...")
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    
    # ROBUSTNESS: build the examples list dynamically — the old prompt
    # indexed top_3_videos[2] and crashed when fewer than 3 matches existed.
    examples = "\n    ".join(f"{i}. {vid}" for i, vid in enumerate(top_videos, 1))
    prompt = f"""
    You are a TikTok Virality Expert.
    
    My Draft Description: "{user_input}"
    
    Here are {len(top_videos)} successful, viral videos that are similar to my topic:
    {examples}
    
    Task: Rewrite my draft description to make it go viral. 
    Use the slang, hashtag style, and structure of the successful examples provided.
    Keep it under 20 words plus hashtags. Return ONLY the new description.
    """
    
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()
        
        print(f"\n✨ IMPROVED IDEA (By Gemini): {improved_idea}")
        
        # --- STEP 4: RE-EVALUATION ---
        new_text_vec = vectorizer.transform([improved_idea]).toarray()
        # Update hashtag count for new features
        new_meta_vec = np.array([[15, 18, 0, improved_idea.count('#')]])
        new_feat_vec = np.hstack((new_text_vec, new_meta_vec))
        
        new_log = model.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))
        
        print(f"πŸ“Š New Predicted Views: {new_views:,}")
        
        # BUGFIX: guard against ZeroDivisionError when the baseline
        # prediction rounds down to 0 views.
        improvement = ((new_views - initial_views) / max(initial_views, 1)) * 100
        if improvement > 0:
            print(f"πŸš€ POTENTIAL UPLIFT: +{improvement:.1f}%")
        else:
            print(f"😐 No significant uplift predicted (Model is strict!).")
            
    except Exception as e:
        print(f"❌ Error calling Gemini API: {e}")

# ---------------------------------------------------------
# MAIN EXECUTION
# ---------------------------------------------------------
def main():
    """Run the full pipeline end-to-end, then drop into the interactive loop."""
    # 1. Pipeline
    df, _ = generate_enhanced_data(10000)
    X_train, X_test, y_train, y_test, _, tfidf = process_data_pipeline(df)
    
    # 2. Train Prediction Model
    best_model = train_best_model(X_train, y_train, X_test, y_test)
    
    # 3. Create Knowledge Base (Embeddings)
    knowledge_df, st_model = create_search_index(df)
    
    # 4. Save Artifacts for App
    print("\n[5/8] Saving Model Artifacts for Production...")
    best_model.save_model("viral_model.json")
    print("    - Model saved to 'viral_model.json'")
    
    with open("tfidf_vectorizer.pkl", "wb") as f:
        pickle.dump(tfidf, f)
    print("    - Vectorizer saved to 'tfidf_vectorizer.pkl'")
    
    # 5. User Interaction Loop
    while True:
        print("\n" + "-"*30)
        user_input = input("Enter your video idea (or 'q' to quit): ")
        # Tolerate stray whitespace around the quit command
        if user_input.strip().lower() == 'q':
            break
            
        optimize_content_with_gemini(
            user_input=user_input, 
            model=best_model, 
            vectorizer=tfidf, 
            knowledge_df=knowledge_df, 
            st_model=st_model
        )


if __name__ == "__main__":
    main()