# TikTok virality prediction & Gemini-based content-optimization pipeline.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import torch
import google.generativeai as genai
from faker import Faker
from datetime import datetime, timedelta
from sklearn.metrics.pairwise import cosine_similarity
import pickle
from dotenv import load_dotenv

# Load environment variables from the .env file
load_dotenv()

# Machine Learning Imports
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer

# ---------------------------------------------------------
# 0. SETUP & CONFIGURATION
# ---------------------------------------------------------
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)

# OPTIMIZATION: Check for Apple Silicon (MPS)
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"π Optimization: Running on {device.upper()} device")

# FIX: exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs('project_plots', exist_ok=True)
# ---------------------------------------------------------
# 1. DATA GENERATION (With 2025 Trends)
# ---------------------------------------------------------
def generate_enhanced_data(n_rows=10000):
    """Generate a synthetic short-video dataset with 2025 trend slang.

    Each row combines a templated description (format + trend + Faker
    sentence + hashtags) with meta features, and a view count drawn from a
    lognormal base scaled by heuristic multipliers.

    Args:
        n_rows: number of synthetic rows to generate.

    Returns:
        (df, threshold): DataFrame sorted by 'upload_date' with added
        'is_viral_binary' (top 20% by views) and 'log_views' columns, and
        the 80th-percentile view threshold used for the binary label.
    """
    print(f"\n[1/8] Generating {n_rows} rows of Real-World 2025 Data...")
    fake = Faker()
    trends = [
        'Delulu', 'Girl Dinner', 'Roman Empire', 'Silent Slay', 'Soft Life',
        'Grimace Shake', 'Wes Anderson Style', 'Beige Flag', 'Canon Event',
        'NPC Stream', 'Skibidi', 'Fanum Tax', 'Yapping', 'Glow Up', 'Fit Check'
    ]
    formats = [
        'POV: You realize...', 'GRWM for...', 'Day in the life:',
        'Storytime:', 'Trying the viral...', 'ASMR packing orders',
        'Rating my exes...', 'Turn the lights off challenge'
    ]
    categories = ['Gaming', 'Beauty', 'Comedy', 'Edutainment', 'Lifestyle', 'Food']
    data = []
    start_date = datetime(2024, 1, 1)
    for _ in range(n_rows):
        # FIX: np.random.randint's upper bound is exclusive, so hours=(0, 23)
        # could never produce a 23:00 upload; (0, 24) covers all 24 hours.
        upload_time = start_date + timedelta(days=np.random.randint(0, 365),
                                             hours=np.random.randint(0, 24))
        trend = np.random.choice(trends)
        fmt = np.random.choice(formats)
        cat = np.random.choice(categories)
        description = f"{fmt} {trend} edition! {fake.sentence(nb_words=6)}"
        tags = ['#fyp', '#foryou', '#viral', f'#{trend.replace(" ", "").lower()}', f'#{cat.lower()}']
        if np.random.random() > 0.5:
            tags.append('#trending2025')
        full_text = f"{description} {' '.join(tags)}"
        # Meta Features
        duration = np.random.randint(5, 180)
        hour = upload_time.hour
        is_weekend = 1 if upload_time.weekday() >= 5 else 0
        # View Count Logic: lognormal base scaled by hand-tuned boosts for
        # weekends, short clips, hot phrases, and evening uploads.
        base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
        multiplier = 1.0
        if is_weekend:
            multiplier *= 1.2
        if duration < 15:
            multiplier *= 1.4
        if "Delulu" in full_text or "POV" in full_text:
            multiplier *= 1.6
        if hour >= 18:
            multiplier *= 1.1
        views = int(base_virality * multiplier)
        data.append({
            'upload_date': upload_time,
            'description': full_text,
            'category': cat,
            'video_duration_sec': duration,
            'hour_of_day': hour,
            'is_weekend': is_weekend,
            'hashtag_count': len(tags),
            'views': views
        })
    df = pd.DataFrame(data)
    df = df.sort_values('upload_date').reset_index(drop=True)
    # Binary "viral" label: strictly above the 80th percentile of raw views.
    threshold = df['views'].quantile(0.80)
    df['is_viral_binary'] = (df['views'] > threshold).astype(int)
    df['log_views'] = np.log1p(df['views'])
    return df, threshold
# ---------------------------------------------------------
# 2. EDA & PREPROCESSING
# ---------------------------------------------------------
def process_data_pipeline(df):
    """Build the feature matrix and chronological train/test split.

    Saves a quick EDA histogram of log-views, vectorizes descriptions with
    TF-IDF, stacks on the numeric meta features, and splits 80/20 in time
    order (df arrives sorted by upload_date).

    Returns:
        X_train, X_test, y_train, y_test, y_bin_test, fitted_vectorizer
    """
    print("\n[2/8] Processing Data Pipeline...")

    # Simple EDA save: log-view distribution for videos with positive duration.
    positive_duration = df[df['video_duration_sec'] > 0].copy()
    plt.figure(figsize=(6, 4))
    sns.histplot(positive_duration['log_views'], color='teal')
    plt.title('Log Views Distribution')
    plt.savefig('project_plots/eda_distribution.png')
    plt.close()

    # Text features via TF-IDF, concatenated with the numeric meta columns.
    vectorizer = TfidfVectorizer(max_features=2000, stop_words='english')
    text_matrix = vectorizer.fit_transform(df['description']).toarray()
    meta_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
    features = np.hstack((text_matrix, df[meta_cols].values))

    targets = df['log_views'].values
    viral_labels = df['is_viral_binary'].values

    # Chronological 80/20 split; only the test-side binary labels are returned.
    cut = int(len(df) * 0.80)
    return (features[:cut], features[cut:], targets[:cut], targets[cut:],
            viral_labels[cut:], vectorizer)
# ---------------------------------------------------------
# 3. TRAINING
# ---------------------------------------------------------
def train_best_model(X_train, y_train, X_test, y_test):
    """Fit an XGBoost regressor on log-views and report hold-out RMSE."""
    print("\n[3/8] Training Model (XGBoost)...")
    regressor = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=6, n_jobs=-1)
    regressor.fit(X_train, y_train)
    predictions = regressor.predict(X_test)
    rmse = np.sqrt(mean_squared_error(y_test, predictions))
    print(f" - Model RMSE: {rmse:.3f}")
    return regressor
# ---------------------------------------------------------
# 4. EMBEDDINGS GENERATION (For Search)
# ---------------------------------------------------------
def create_search_index(df):
    """Embed every description and persist the frame as a Parquet knowledge base.

    Returns the DataFrame (with a new 'embedding' column) and the loaded
    SentenceTransformer so callers can encode queries with the same model.
    """
    print("\n[4/8] Creating Vector Search Index...")
    # Embed ALL rows so the entire history is searchable, not just a split.
    encoder = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    vectors = encoder.encode(df['description'].tolist(), convert_to_numpy=True, show_progress_bar=True)
    df['embedding'] = list(vectors)

    # Persist as the app's knowledge base.
    save_path = 'tiktok_knowledge_base.parquet'
    df.to_parquet(save_path)
    print(f" - Knowledge Base saved to {save_path}")
    return df, encoder
# ---------------------------------------------------------
# 5. RETRIEVAL & IMPROVEMENT ENGINE (The Magic Step)
# ---------------------------------------------------------
def optimize_content_with_gemini(user_input, model, vectorizer, knowledge_df, st_model):
    """
    1. Scores original idea.
    2. Finds up to 3 similar VIRAL videos.
    3. Asks Gemini to rewrite the idea.
    4. Re-scores the new idea.
    """
    print("\n" + "="*50)
    print("π VIRAL OPTIMIZATION ENGINE")
    print("="*50)

    # --- STEP 1: INITIAL SCORE ---
    text_vec = vectorizer.transform([user_input]).toarray()
    # Assume default meta for prediction (15s, 6 PM, weekday)
    meta_vec = np.array([[15, 18, 0, user_input.count('#')]])
    feat_vec = np.hstack((text_vec, meta_vec))
    initial_log = model.predict(feat_vec)[0]
    initial_views = int(np.expm1(initial_log))
    print(f"\nπ ORIGINAL IDEA: {user_input}")
    print(f"π Predicted Views: {initial_views:,}")

    # --- STEP 2: VECTOR SEARCH (Find similar successful videos) ---
    print("\nπ Searching for similar viral hits in Parquet file...")
    # Filter only for successful videos (e.g., top 25% of views)
    high_performance_df = knowledge_df[knowledge_df['views'] > knowledge_df['views'].quantile(0.75)].copy()
    # FIX: np.stack crashes on an empty frame — bail out gracefully instead.
    if high_performance_df.empty:
        print(" -> No high-performing videos available to learn from.")
        return
    # Encode user input and compare against the stored embeddings.
    user_embedding = st_model.encode([user_input], convert_to_numpy=True)
    target_embeddings = np.stack(high_performance_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)
    # FIX: take at most 3 neighbours; the original hard-indexed [0], [1], [2]
    # in the prompt and raised IndexError when fewer than 3 rows matched.
    k = min(3, len(high_performance_df))
    top_indices = similarities[0].argsort()[-k:][::-1]
    top_videos = high_performance_df.iloc[top_indices]['description'].tolist()
    print(f" -> Found {len(top_videos)} similar viral videos to learn from:")
    for i, vid in enumerate(top_videos, 1):
        print(f" {i}. {vid[:80]}...")

    # --- STEP 3: GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        print("\nβ οΈ SKIPPING AI REWRITE: No 'GEMINI_API_KEY' found in environment variables.")
        print(" (Set it via 'export GEMINI_API_KEY=your_key' in terminal)")
        return
    print("\nπ€ Sending context to Gemini LLM for optimization...")
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    # Build the example list dynamically so the prompt matches however many
    # neighbours were actually retrieved.
    examples = "\n ".join(f"{i}. {vid}" for i, vid in enumerate(top_videos, 1))
    prompt = f"""
 You are a TikTok Virality Expert.
 My Draft Description: "{user_input}"
 Here are {len(top_videos)} successful, viral videos that are similar to my topic:
 {examples}
 Task: Rewrite my draft description to make it go viral.
 Use the slang, hashtag style, and structure of the successful examples provided.
 Keep it under 20 words plus hashtags. Return ONLY the new description.
 """
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()
        print(f"\n⨠IMPROVED IDEA (By Gemini): {improved_idea}")

        # --- STEP 4: RE-EVALUATION ---
        new_text_vec = vectorizer.transform([improved_idea]).toarray()
        # Update hashtag count for new features
        new_meta_vec = np.array([[15, 18, 0, improved_idea.count('#')]])
        new_feat_vec = np.hstack((new_text_vec, new_meta_vec))
        new_log = model.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))
        print(f"π New Predicted Views: {new_views:,}")
        # FIX: guard division by zero when the baseline rounds down to 0 views.
        improvement = ((new_views - initial_views) / max(initial_views, 1)) * 100
        if improvement > 0:
            print(f"π POTENTIAL UPLIFT: +{improvement:.1f}%")
        else:
            print(f"π No significant uplift predicted (Model is strict!).")
    except Exception as e:
        print(f"β Error calling Gemini API: {e}")
# ---------------------------------------------------------
# MAIN EXECUTION
# ---------------------------------------------------------
if __name__ == "__main__":
    # 1. Pipeline: synthesize data, then build features and the time split.
    df, _ = generate_enhanced_data(10000)
    X_train, X_test, y_train, y_test, _, tfidf = process_data_pipeline(df)

    # 2. Train Prediction Model
    best_model = train_best_model(X_train, y_train, X_test, y_test)

    # 3. Create Knowledge Base (Embeddings)
    knowledge_df, st_model = create_search_index(df)

    # 4. Save Artifacts for App
    print("\n[5/8] Saving Model Artifacts for Production...")
    best_model.save_model("viral_model.json")
    print(" - Model saved to 'viral_model.json'")
    with open("tfidf_vectorizer.pkl", "wb") as f:
        pickle.dump(tfidf, f)
    print(" - Vectorizer saved to 'tfidf_vectorizer.pkl'")

    # 5. User Interaction Loop: score & rewrite ideas until the user quits.
    while True:
        print("\n" + "-"*30)
        idea = input("Enter your video idea (or 'q' to quit): ")
        if idea.lower() == 'q':
            break
        optimize_content_with_gemini(
            user_input=idea,
            model=best_model,
            vectorizer=tfidf,
            knowledge_df=knowledge_df,
            st_model=st_model
        )