# app.py — Viral Content Optimizer (commit 997cd7e)
# chore: Remove "App initialized!" print statement from `app.py`.
import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv
# Import functions from model-prep
from xgboost import XGBRegressor # Use Regressor as per model-prep
import pickle
from importlib.util import spec_from_file_location
import sys
# Since we are loading artifacts, we don't strictly need model-prep.py logic anymore.
# But keeping basic imports is fine.
# Load environment variables (GEMINI_API_KEY is read later in predict_and_optimize)
load_dotenv()
# --- GLOBAL STATE ---
# All four are populated once by initialize_app() and read by predict_and_optimize().
MODEL = None  # XGBRegressor loaded from viral_model.json
VECTORIZER = None  # fitted TF-IDF vectorizer unpickled from tfidf_vectorizer.pkl
KNOWLEDGE_DF = None  # DataFrame from tiktok_knowledge_base.parquet (views, description, embedding)
ST_MODEL = None  # SentenceTransformer used for semantic similarity search
def initialize_app():
    """Load all pre-computed artifacts into module globals (inference mode).

    Expects model-prep.py to have produced:
      - tiktok_knowledge_base.parquet (knowledge base with embeddings)
      - viral_model.json (XGBoost regressor)
      - tfidf_vectorizer.pkl (fitted TF-IDF vectorizer)

    Raises:
        FileNotFoundError: if the parquet knowledge base is missing.
    """
    global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL
    print("⏳ initializing app: Loading pre-computed artifacts...")

    # 1. Knowledge base — hard requirement; fail fast with a pointer to model-prep.py.
    parquet_path = 'tiktok_knowledge_base.parquet'
    if not os.path.exists(parquet_path):
        raise FileNotFoundError(f"Required file '{parquet_path}' not found! Run model-prep.py first.")
    print(f"📂 Loading data from {parquet_path}...")
    kb_frame = pd.read_parquet(parquet_path)

    # 2. Regression model (predicts log-views; see predict_and_optimize).
    print("🧠 Loading XGBoost Model...")
    regressor = XGBRegressor()
    regressor.load_model("viral_model.json")

    # 3. TF-IDF vectorizer fitted during model prep.
    print("🔤 Loading TF-IDF Vectorizer...")
    with open("tfidf_vectorizer.pkl", "rb") as f:
        fitted_tfidf = pickle.load(f)

    # 4. Sentence embedder for the vector search. Pick Apple-GPU (mps) when
    # available, otherwise CPU — avoids depending on model-prep's device choice.
    print("🔌 Loading SentenceTransformer...")
    import torch
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    embedder = SentenceTransformer('all-MiniLM-L6-v2', device=device)

    # Publish everything at once so partially-initialized state is less likely.
    MODEL = regressor
    VECTORIZER = fitted_tfidf
    KNOWLEDGE_DF = kb_frame
    ST_MODEL = embedder
    print("✅ App initialized (Inference Mode)!")
def _score_description(text):
    """Predict the view count for *text* using default video metadata.

    Builds the same feature layout the model was trained on: TF-IDF text
    features concatenated with [duration, hour, weekday, hashtag_count].
    Returns the predicted view count as an int.
    """
    text_vec = VECTORIZER.transform([text]).toarray()
    # Default meta: 15s duration, 18:00 (6 PM), weekday (0), hashtag count from the text.
    meta_vec = np.array([[15, 18, 0, text.count('#')]])
    feat_vec = np.hstack((text_vec, meta_vec))
    log_views = MODEL.predict(feat_vec)[0]
    # Model predicts log1p(views); invert with expm1.
    return int(np.expm1(log_views))


def predict_and_optimize(user_input):
    """Predict views for a draft description, find similar viral videos, and
    ask Gemini for an optimized rewrite, then re-score the rewrite.

    Returns a 5-tuple of display strings:
        (initial_views, similar_videos, improved_text, new_views, uplift)
    """
    if not user_input:
        return "Please enter a video description.", "", "", "", ""

    # --- 1. INITIAL PREDICTION ---
    initial_views = _score_description(user_input)

    # --- 2. VECTOR SEARCH ---
    # Restrict the reference pool to viral hits (top 25% by views).
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
    user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    similarities = cosine_similarity(user_embedding, target_embeddings)
    # argsort ascending, take last 3, reverse -> top-3 most similar first.
    top_3_indices = similarities[0].argsort()[-3:][::-1]
    top_3_videos = high_perf_df.iloc[top_3_indices]['description'].tolist()
    similar_videos_str = "\n\n".join(f"{i+1}. {v}" for i, v in enumerate(top_3_videos))

    # --- 3. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
    genai.configure(api_key=api_key)
    # Prefer the newer model; fall back if the SDK rejects it.
    # (Was a bare `except:` — narrowed so Ctrl-C / SystemExit aren't swallowed.)
    try:
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        llm = genai.GenerativeModel('gemini-1.5-flash')

    # Build the example list dynamically: the pool can hold fewer than 3 videos,
    # and indexing top_3_videos[2] unconditionally would raise IndexError.
    examples = "\n".join(f"{i+1}. {v}" for i, v in enumerate(top_3_videos))
    prompt = f"""
You are a TikTok Virality Expert.
My Draft Description: "{user_input}"
Here are successful, viral videos that are similar to my topic:
{examples}
Task: Rewrite my draft description to make it go viral and full video plan.
Use the slang, hashtag style, and structure of the successful examples provided.
Keep it under 60 words plus hashtags. Return ONLY the new description.
"""
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()

        # --- 4. RE-SCORING ---
        new_views = _score_description(improved_idea)
        if initial_views > 0:
            uplift_pct = ((new_views - initial_views) / initial_views) * 100
            uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
        else:
            # Baseline prediction rounded to 0 views — a percentage uplift is undefined.
            uplift_str = "N/A"
        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"
# --- GRADIO UI ---
# Left column: input + actions. Right column: prediction and optimization outputs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🚀 Viral Content Optimizer")
    gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")

    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description",
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3,
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚡", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")

        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 📊 Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)

    with gr.Accordion("🔍 Similar Viral Videos (Reference)", open=False):
        similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    # All five outputs are updated by a single prediction call.
    result_outputs = [initial_views, similar_videos, improved_text, new_views, uplift]

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text],
        outputs=result_outputs,
    )

    # Demo Button Logic: 1. Fill Text -> 2. Run Prediction
    demo_text = "POV: You realize you forgot to turn off your mic during the all-hands meeting 💀 #fail #fyp #corporate"
    demo_btn.click(
        fn=lambda: demo_text,
        inputs=None,
        outputs=input_text,
    ).then(
        fn=predict_and_optimize,
        inputs=gr.State(demo_text),  # Pass directly to avoid race condition with UI update
        outputs=result_outputs,
    )
# Entry point: load artifacts first, then serve the UI on all interfaces.
if __name__ == "__main__":
    initialize_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)