File size: 7,401 Bytes
cf49347
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from dotenv import load_dotenv

# Import functions from model-prep
from xgboost import XGBRegressor # Use Regressor as per model-prep
import pickle
from importlib.util import spec_from_file_location
import sys
# Since we are loading artifacts, we don't strictly need model-prep.py logic anymore.
# But keeping basic imports is fine.

# Load environment variables
load_dotenv()

# --- GLOBAL STATE ---
# All four globals are populated exactly once by initialize_app() and then
# read (never mutated) by predict_and_optimize().
MODEL = None        # XGBRegressor loaded from viral_model.json
VECTORIZER = None   # TF-IDF vectorizer unpickled from tfidf_vectorizer.pkl
KNOWLEDGE_DF = None # DataFrame from tiktok_knowledge_base.parquet ('views', 'embedding', 'description')
ST_MODEL = None     # SentenceTransformer('all-MiniLM-L6-v2') used for similarity search

def initialize_app():
    """Load all pre-computed inference artifacts into the module-level globals.

    Populates MODEL (XGBRegressor), VECTORIZER (TF-IDF), KNOWLEDGE_DF
    (parquet knowledge base) and ST_MODEL (SentenceTransformer).

    Raises:
        FileNotFoundError: if any required artifact produced by model-prep.py
            is missing (checked up front so the error names the exact file).
    """
    global MODEL, VECTORIZER, KNOWLEDGE_DF, ST_MODEL
    
    print("⏳ initializing app: Loading pre-computed artifacts...")
    
    parquet_path = 'tiktok_knowledge_base.parquet'
    model_path = 'viral_model.json'
    vectorizer_path = 'tfidf_vectorizer.pkl'

    # Fail fast on ANY missing artifact (the original only checked the parquet,
    # so a missing model/vectorizer surfaced as an opaque load error).
    for path in (parquet_path, model_path, vectorizer_path):
        if not os.path.exists(path):
            raise FileNotFoundError(f"Required file '{path}' not found! Run model-prep.py first.")
    
    # 1. Load Parquet Data (Knowledge Base)
    print(f"📂 Loading data from {parquet_path}...")
    knowledge_df = pd.read_parquet(parquet_path)
    
    # 2. Load Model
    print("🧠 Loading XGBoost Model...")
    model = XGBRegressor()
    model.load_model(model_path)
    
    # 3. Load Vectorizer
    # NOTE(security): pickle.load executes arbitrary code on load; only ever
    # load artifacts produced locally by model-prep.py, never untrusted files.
    print("🔤 Loading TF-IDF Vectorizer...")
    with open(vectorizer_path, "rb") as f:
        tfidf = pickle.load(f)

    # 4. Load Sentence Transformer on the best available device
    #    (CUDA was previously ignored; fall through CUDA -> MPS -> CPU).
    print("🔌 Loading SentenceTransformer...")
    import torch
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    st_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
    
    MODEL = model
    VECTORIZER = tfidf
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("✅ App initialized (Inference Mode)!")

def predict_and_optimize(user_input):
    """Predict views for a draft description and propose an AI-optimized rewrite.

    Args:
        user_input: The draft TikTok video description (may contain hashtags).

    Returns:
        A 5-tuple of display strings for the Gradio outputs:
        (initial views, similar viral videos, improved description,
         new views, uplift).
    """
    if not user_input:
        return "Please enter a video description.", "", "", "", ""
        
    # --- 1. INITIAL PREDICTION ---
    text_vec = VECTORIZER.transform([user_input]).toarray()
    # Assume default meta: 15s duration, 18:00 (6 PM), weekday (0), hashtag count from input
    meta_vec = np.array([[15, 18, 0, user_input.count('#')]]) 
    feat_vec = np.hstack((text_vec, meta_vec))
    
    initial_log = MODEL.predict(feat_vec)[0]
    # Model was trained on log1p(views), so invert with expm1.
    initial_views = int(np.expm1(initial_log))
    
    # --- 2. VECTOR SEARCH ---
    # Filter for viral hits in knowledge base (top 25%)
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
    
    user_embedding = ST_MODEL.encode([user_input], convert_to_numpy=True)
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    
    similarities = cosine_similarity(user_embedding, target_embeddings)
    # Take up to 3 matches, most similar first. The original indexed
    # top_3_videos[0..2] unconditionally and crashed on small knowledge bases.
    k = min(3, len(high_perf_df))
    top_indices = similarities[0].argsort()[-k:][::-1]
    top_videos = high_perf_df.iloc[top_indices]['description'].tolist()
    
    similar_videos_str = "\n\n".join(f"{i+1}. {v}" for i, v in enumerate(top_videos))

    # --- 3. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
        
    genai.configure(api_key=api_key)
    # Prefer the newer model; fall back if this genai version rejects it.
    # (Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.)
    try: 
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        llm = genai.GenerativeModel('gemini-1.5-flash')

    examples = "\n    ".join(f"{i+1}. {v}" for i, v in enumerate(top_videos))
    prompt = f"""
    You are a TikTok Virality Expert.
    
    My Draft Description: "{user_input}"
    
    Here are {k} successful, viral videos that are similar to my topic:
    {examples}
    
    Task: Rewrite my draft description to make it go viral and full video plan. 
    Use the slang, hashtag style, and structure of the successful examples provided.
    Keep it under 60 words plus hashtags. Return ONLY the new description.
    """
    
    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()
        
        # --- 4. RE-SCORING ---
        new_text_vec = VECTORIZER.transform([improved_idea]).toarray()
        new_meta_vec = np.array([[15, 18, 0, improved_idea.count('#')]])
        new_feat_vec = np.hstack((new_text_vec, new_meta_vec))
        
        new_log = MODEL.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))
        
        # Guard against ZeroDivisionError when the initial prediction rounds to 0.
        if initial_views > 0:
            uplift_pct = ((new_views - initial_views) / initial_views) * 100
            uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
        else:
            uplift_str = f"+{new_views:,} views" if new_views > 0 else "No significant uplift"
        
        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
        
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"


# --- GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸš€ Viral Content Optimizer")
    gr.Markdown("Enter your video idea to predict its views and get AI-powered optimizations based on 2025 trends.")
    
    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description", 
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3
            )
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚑", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")
            
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### πŸ“Š Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            
            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
    
    with gr.Accordion("πŸ” Similar Viral Videos (Reference)", open=False):
        similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

    # Demo Button Logic: 1. Fill Text -> 2. Run Prediction
    demo_text = "POV: You realize you forgot to turn off your mic during the all-hands meeting πŸ’€ #fail #fyp #corporate"
    demo_btn.click(
        fn=lambda: demo_text,
        inputs=None,
        outputs=input_text
    ).then(
        fn=predict_and_optimize,
        inputs=gr.State(demo_text), # Pass directly to avoid race condition with UI update
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

# Run initialization
# Load all artifacts eagerly before serving so the first request does not pay
# the model/embedding load cost.
if __name__ == "__main__":
    initialize_app()
    # Bind to all interfaces (container-friendly) on the conventional Gradio port.
    demo.launch(server_name="0.0.0.0", server_port=7860)