# social-assistent / model-prep.py
# (Hugging Face upload metadata: odeyaaa, "Upload 10 files", commit cf49347 verified)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
import os
import torch
import google.generativeai as genai
from faker import Faker
from datetime import datetime, timedelta
from sklearn.metrics.pairwise import cosine_similarity
import pickle
from dotenv import load_dotenv
# Load environment variables from the .env file
load_dotenv()
# Machine Learning Imports
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.decomposition import PCA
from sentence_transformers import SentenceTransformer
# ---------------------------------------------------------
# 0. SETUP & CONFIGURATION
# ---------------------------------------------------------
# Silence noisy library warnings and show every DataFrame column when printing.
warnings.filterwarnings('ignore')
pd.set_option('display.max_columns', None)

# OPTIMIZATION: Check for Apple Silicon (MPS); fall back to CPU otherwise.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"🚀 Optimization: Running on {device.upper()} device")

# BUGFIX: exist_ok=True replaces the racy exists()-then-makedirs() pattern
# (the directory could appear between the check and the creation).
os.makedirs('project_plots', exist_ok=True)
# ---------------------------------------------------------
# 1. DATA GENERATION (With 2025 Trends)
# ---------------------------------------------------------
def generate_enhanced_data(n_rows=10000):
    """Generate a synthetic dataset of TikTok-style videos with 2025 trends.

    Args:
        n_rows: number of synthetic rows to generate.

    Returns:
        (df, threshold): DataFrame sorted by upload date with added
        'is_viral_binary' (top-20% views) and 'log_views' columns, plus the
        80th-percentile view-count threshold used for the viral label.
    """
    print(f"\n[1/8] Generating {n_rows} rows of Real-World 2025 Data...")
    fake = Faker()
    trends = [
        'Delulu', 'Girl Dinner', 'Roman Empire', 'Silent Slay', 'Soft Life',
        'Grimace Shake', 'Wes Anderson Style', 'Beige Flag', 'Canon Event',
        'NPC Stream', 'Skibidi', 'Fanum Tax', 'Yapping', 'Glow Up', 'Fit Check'
    ]
    formats = [
        'POV: You realize...', 'GRWM for...', 'Day in the life:',
        'Storytime:', 'Trying the viral...', 'ASMR packing orders',
        'Rating my exes...', 'Turn the lights off challenge'
    ]
    categories = ['Gaming', 'Beauty', 'Comedy', 'Edutainment', 'Lifestyle', 'Food']
    data = []
    start_date = datetime(2024, 1, 1)
    for _ in range(n_rows):
        # BUGFIX: np.random.randint excludes the high bound, so the original
        # (0, 365)/(0, 23) could never produce Dec 31 (2024 is a leap year with
        # 366 days) or an upload at 23:00 — the hour matters for the
        # prime-time (hour >= 18) multiplier below.
        upload_time = start_date + timedelta(days=int(np.random.randint(0, 366)),
                                             hours=int(np.random.randint(0, 24)))
        trend = np.random.choice(trends)
        fmt = np.random.choice(formats)
        cat = np.random.choice(categories)
        description = f"{fmt} {trend} edition! {fake.sentence(nb_words=6)}"
        tags = ['#fyp', '#foryou', '#viral', f'#{trend.replace(" ", "").lower()}', f'#{cat.lower()}']
        if np.random.random() > 0.5:
            tags.append('#trending2025')
        full_text = f"{description} {' '.join(tags)}"
        # Meta Features
        duration = np.random.randint(5, 180)
        hour = upload_time.hour
        is_weekend = 1 if upload_time.weekday() >= 5 else 0
        # View Count Logic: heavy-tailed base popularity times heuristic boosts.
        base_virality = np.random.lognormal(mean=9.5, sigma=1.8)
        multiplier = 1.0
        if is_weekend:
            multiplier *= 1.2
        if duration < 15:
            multiplier *= 1.4   # short-form bonus
        if "Delulu" in full_text or "POV" in full_text:
            multiplier *= 1.6   # hot trend/format bonus
        if hour >= 18:
            multiplier *= 1.1   # evening prime-time bonus
        views = int(base_virality * multiplier)
        data.append({
            'upload_date': upload_time,
            'description': full_text,
            'category': cat,
            'video_duration_sec': duration,
            'hour_of_day': hour,
            'is_weekend': is_weekend,
            'hashtag_count': len(tags),
            'views': views
        })
    df = pd.DataFrame(data)
    # Chronological order matters: the downstream train/test split is time-based.
    df = df.sort_values('upload_date').reset_index(drop=True)
    # Top 20% of view counts are labelled viral; log1p tames the heavy tail.
    threshold = df['views'].quantile(0.80)
    df['is_viral_binary'] = (df['views'] > threshold).astype(int)
    df['log_views'] = np.log1p(df['views'])
    return df, threshold
# ---------------------------------------------------------
# 2. EDA & PREPROCESSING
# ---------------------------------------------------------
def process_data_pipeline(df):
    """Save a quick EDA plot, vectorise text with TF-IDF and split 80/20 by time.

    Returns (X_train, X_test, y_train, y_test, y_bin_test, fitted_tfidf).
    """
    print("\n[2/8] Processing Data Pipeline...")
    # Quick EDA artefact: log-views distribution for positive-duration rows.
    positive_duration = df[df['video_duration_sec'] > 0].copy()
    plt.figure(figsize=(6, 4))
    sns.histplot(positive_duration['log_views'], color='teal')
    plt.title('Log Views Distribution')
    plt.savefig('project_plots/eda_distribution.png')
    plt.close()
    # Text features (TF-IDF) concatenated with the numeric meta features.
    tfidf = TfidfVectorizer(max_features=2000, stop_words='english')
    text_matrix = tfidf.fit_transform(df['description']).toarray()
    meta_cols = ['video_duration_sec', 'hour_of_day', 'is_weekend', 'hashtag_count']
    features = np.hstack((text_matrix, df[meta_cols].values))
    log_views = df['log_views'].values
    viral_flags = df['is_viral_binary'].values
    # Chronological split: df arrives sorted by upload_date, so the last 20%
    # acts as the "future" hold-out set.
    cutoff = int(len(df) * 0.80)
    return (features[:cutoff], features[cutoff:],
            log_views[:cutoff], log_views[cutoff:],
            viral_flags[cutoff:], tfidf)
# ---------------------------------------------------------
# 3. TRAINING
# ---------------------------------------------------------
def train_best_model(X_train, y_train, X_test, y_test):
    """Fit an XGBoost regressor on log-views and report hold-out RMSE."""
    print("\n[3/8] Training Model (XGBoost)...")
    regressor = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=6, n_jobs=-1)
    regressor.fit(X_train, y_train)
    predictions = regressor.predict(X_test)
    rmse = np.sqrt(mean_squared_error(y_test, predictions))
    print(f" - Model RMSE: {rmse:.3f}")
    return regressor
# ---------------------------------------------------------
# 4. EMBEDDINGS GENERATION (For Search)
# ---------------------------------------------------------
def create_search_index(df):
    """Embed all descriptions and persist the DataFrame as a parquet knowledge base.

    Returns the augmented DataFrame (new 'embedding' column) and the encoder.
    """
    print("\n[4/8] Creating Vector Search Index...")
    # Embed the ENTIRE history so searches can match any past video.
    encoder = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    vectors = encoder.encode(df['description'].tolist(), convert_to_numpy=True, show_progress_bar=True)
    df['embedding'] = list(vectors)
    # Persist as the app's knowledge base.
    save_path = 'tiktok_knowledge_base.parquet'
    df.to_parquet(save_path)
    print(f" - Knowledge Base saved to {save_path}")
    return df, encoder
# ---------------------------------------------------------
# 5. RETRIEVAL & IMPROVEMENT ENGINE (The Magic Step)
# ---------------------------------------------------------
def _predict_views(text, model, vectorizer):
    """Predict an absolute view count for a description.

    Assumes default meta features for the prediction: 15s video, 6 PM upload,
    weekday; hashtag count is taken from the text itself. The regressor
    predicts log1p(views), so the result is mapped back with expm1.
    """
    text_vec = vectorizer.transform([text]).toarray()
    meta_vec = np.array([[15, 18, 0, text.count('#')]])
    feat_vec = np.hstack((text_vec, meta_vec))
    return int(np.expm1(model.predict(feat_vec)[0]))


def _find_similar_hits(user_input, knowledge_df, st_model, k=3):
    """Return the k descriptions most similar to user_input among the
    high-performing (top-25%-by-views) videos in the knowledge base."""
    high_performance_df = knowledge_df[knowledge_df['views'] > knowledge_df['views'].quantile(0.75)].copy()
    user_embedding = st_model.encode([user_input], convert_to_numpy=True)
    # Stack the per-row embedding arrays into a single matrix for similarity.
    target_embeddings = np.stack(high_performance_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)
    top_indices = similarities[0].argsort()[-k:][::-1]
    return high_performance_df.iloc[top_indices]['description'].tolist()


def optimize_content_with_gemini(user_input, model, vectorizer, knowledge_df, st_model):
    """
    1. Scores original idea.
    2. Finds top 3 similar VIRAL videos.
    3. Asks Gemini to rewrite the idea.
    4. Re-scores the new idea.

    Args:
        user_input: draft video description typed by the user.
        model: trained regressor predicting log1p(views).
        vectorizer: fitted TfidfVectorizer matching the model's text features.
        knowledge_df: DataFrame with 'description', 'views', 'embedding' columns.
        st_model: SentenceTransformer used for similarity search.
    """
    print("\n" + "="*50)
    print("πŸš€ VIRAL OPTIMIZATION ENGINE")
    print("="*50)

    # --- STEP 1: INITIAL SCORE ---
    initial_views = _predict_views(user_input, model, vectorizer)
    print(f"\nπŸ“ ORIGINAL IDEA: {user_input}")
    print(f"πŸ“Š Predicted Views: {initial_views:,}")

    # --- STEP 2: VECTOR SEARCH (Find similar successful videos) ---
    print("\nπŸ” Searching for similar viral hits in Parquet file...")
    top_3_videos = _find_similar_hits(user_input, knowledge_df, st_model, k=3)
    print(" -> Found 3 similar viral videos to learn from:")
    for i, vid in enumerate(top_3_videos, 1):
        print(f" {i}. {vid[:80]}...")

    # --- STEP 3: GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        print("\n⚠️ SKIPPING AI REWRITE: No 'GEMINI_API_KEY' found in environment variables.")
        print(" (Set it via 'export GEMINI_API_KEY=your_key' in terminal)")
        return
    print("\nπŸ€– Sending context to Gemini LLM for optimization...")
    genai.configure(api_key=api_key)
    llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    prompt = f"""
You are a TikTok Virality Expert.
My Draft Description: "{user_input}"
Here are 3 successful, viral videos that are similar to my topic:
1. {top_3_videos[0]}
2. {top_3_videos[1]}
3. {top_3_videos[2]}
Task: Rewrite my draft description to make it go viral.
Use the slang, hashtag style, and structure of the successful examples provided.
Keep it under 20 words plus hashtags. Return ONLY the new description.
"""
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()
        print(f"\n✨ IMPROVED IDEA (By Gemini): {improved_idea}")

        # --- STEP 4: RE-EVALUATION ---
        new_views = _predict_views(improved_idea, model, vectorizer)
        print(f"πŸ“Š New Predicted Views: {new_views:,}")
        # BUGFIX: guard against ZeroDivisionError when the original idea
        # scored 0 predicted views (expm1 of a near-zero prediction).
        improvement = ((new_views - initial_views) / max(initial_views, 1)) * 100
        if improvement > 0:
            print(f"πŸš€ POTENTIAL UPLIFT: +{improvement:.1f}%")
        else:
            print(f"😐 No significant uplift predicted (Model is strict!).")
    except Exception as e:
        # Best-effort: network/quota/safety-filter errors are reported, not raised.
        print(f"❌ Error calling Gemini API: {e}")
# ---------------------------------------------------------
# MAIN EXECUTION
# ---------------------------------------------------------
if __name__ == "__main__":
    # 1. Pipeline: synthesize the dataset, then build TF-IDF + meta features
    #    with a chronological 80/20 split.
    df, _ = generate_enhanced_data(10000)
    X_train, X_test, y_train, y_test, _, tfidf = process_data_pipeline(df)
    # 2. Train Prediction Model (XGBoost regressor on log-views).
    best_model = train_best_model(X_train, y_train, X_test, y_test)
    # 3. Create Knowledge Base (sentence embeddings, saved to parquet).
    knowledge_df, st_model = create_search_index(df)
    # 4. Save Artifacts for App: model as JSON, TF-IDF vectorizer pickled.
    print("\n[5/8] Saving Model Artifacts for Production...")
    best_model.save_model("viral_model.json")
    print(" - Model saved to 'viral_model.json'")
    with open("tfidf_vectorizer.pkl", "wb") as f:
        pickle.dump(tfidf, f)
    print(" - Vectorizer saved to 'tfidf_vectorizer.pkl'")
    # 5. User Interaction Loop: score each idea, retrieve similar viral hits,
    #    and ask Gemini for a rewrite, until the user enters 'q'.
    while True:
        print("\n" + "-"*30)
        user_input = input("Enter your video idea (or 'q' to quit): ")
        if user_input.lower() == 'q':
            break
        optimize_content_with_gemini(
            user_input=user_input,
            model=best_model,
            vectorizer=tfidf,
            knowledge_df=knowledge_df,
            st_model=st_model
        )