File size: 11,439 Bytes
58ec65e
 
 
 
 
 
 
33daa3c
58ec65e
33daa3c
 
58ec65e
 
 
 
 
 
 
 
 
33daa3c
58ec65e
 
33daa3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb866b9
33daa3c
7803d6a
 
33daa3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58ec65e
2f9170f
33daa3c
 
58ec65e
 
33daa3c
 
bb866b9
33daa3c
 
 
 
 
 
 
 
58ec65e
 
 
 
 
 
33daa3c
58ec65e
 
 
33daa3c
 
 
 
 
 
7803d6a
33daa3c
 
 
7803d6a
33daa3c
 
 
 
 
7803d6a
 
 
 
 
 
 
 
 
33daa3c
 
 
 
 
 
7803d6a
33daa3c
f56b40a
7803d6a
 
f56b40a
58ec65e
 
 
 
 
33daa3c
58ec65e
 
 
 
 
 
 
 
 
 
33daa3c
58ec65e
 
 
 
 
 
 
 
 
 
 
 
 
33daa3c
 
 
58ec65e
33daa3c
58ec65e
 
 
 
18699f2
a840261
33daa3c
a840261
18699f2
 
 
a840261
 
 
 
 
58ec65e
a840261
58ec65e
 
 
 
33daa3c
 
7803d6a
33daa3c
58ec65e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7803d6a
58ec65e
 
 
 
7803d6a
 
58ec65e
 
f56b40a
 
33daa3c
 
 
 
 
f56b40a
33daa3c
 
 
 
 
 
 
 
 
 
 
f56b40a
58ec65e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33daa3c
58ec65e
 
 
33daa3c
58ec65e
bb866b9
58ec65e
33daa3c
58ec65e
 
bb866b9
58ec65e
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
import gradio as gr
import pandas as pd
import numpy as np
import os
import google.generativeai as genai
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import LabelEncoder
from dotenv import load_dotenv
from datasets import load_dataset
from huggingface_hub import hf_hub_download
import pickle

# Load environment variables
load_dotenv()  # reads a local .env file (e.g. GEMINI_API_KEY) into os.environ

# --- GLOBAL STATE ---
# Populated once by initialize_app(); read by predict_and_optimize().
MODEL = None  # pickled regressor; predicts log1p(views) (see np.expm1 at scoring time)
KNOWLEDGE_DF = None  # DataFrame of past videos with 'description', 'views', 'embedding' columns
ST_MODEL = None  # SentenceTransformer used to embed descriptions
ENCODERS = {} # To store label encoders

def initialize_app():
    """Initializes the model and data on app startup from Hugging Face.

    Populates the module-level globals MODEL, KNOWLEDGE_DF, ST_MODEL and
    ENCODERS. Network access is required for the first run; the pickled
    model falls back to a local 'viral_model.pkl' if the HF download fails.

    Raises:
        RuntimeError: if the dataset or the model cannot be loaded.
    """
    global MODEL, KNOWLEDGE_DF, ST_MODEL, ENCODERS
    
    print("⏳ initializing app: Loading resources from Hugging Face Cloud...")
    
    # 1. Load Dataset from HF
    print("📂 Downloading Dataset (MatanKriel/social-assitent-synthetic-data)...")
    try:
        dataset = load_dataset("MatanKriel/social-assitent-synthetic-data")
        if 'train' in dataset:
            knowledge_df = dataset['train'].to_pandas()
        else:
            # FIX: load_dataset() without a split returns a DatasetDict, which
            # has no .to_pandas(); convert the first available split instead.
            first_split = next(iter(dataset.values()))
            knowledge_df = first_split.to_pandas()
        print(f"    -> Loaded {len(knowledge_df)} rows.")
    except Exception as e:
        raise RuntimeError(f"Failed to load dataset from HF: {e}") from e
    
    # 2. FIT ENCODERS (For Feature Consistency)
    print("🔤 Fitting Label Encoders...")
    # UPDATED: 'age' removed from here, treated as numeric
    cat_cols = ['category', 'gender', 'day_of_week'] 
    for c in cat_cols:
        if c in knowledge_df.columns:
            le = LabelEncoder()
            # Ensure all values are strings so mixed-type columns encode stably
            le.fit(knowledge_df[c].astype(str))
            ENCODERS[c] = le
            print(f"    -> Encoded '{c}': {len(le.classes_)} classes")
        else:
            print(f"    ⚠️ Warning: Column '{c}' missing from dataset!")
            
    # 3. Load Model from HF
    print("🧠 Downloading Model (MatanKriel/social-assitent-viral-predictor)...")
    try:
        model_path = hf_hub_download(repo_id="MatanKriel/social-assitent-viral-predictor", filename="viral_model.pkl")
        # NOTE(review): pickle.load on a downloaded artifact executes arbitrary
        # code — acceptable only because the repo is owner-controlled.
        with open(model_path, "rb") as f:
            model = pickle.load(f)
        print(f"    -> Loaded model: {type(model).__name__}")
    except Exception as e:
        # Fallback to local copy if the Hub is unreachable
        if os.path.exists("viral_model.pkl"):
            print(f"    ⚠️ HF Download failed ({e}). Loading local 'viral_model.pkl' instead.")
            with open("viral_model.pkl", "rb") as f:
                model = pickle.load(f)
        else:
            raise RuntimeError(f"Failed to load model from HF and no local backup found: {e}") from e
    
    # 4. Load SentenceTransformer
    print("🔌 Loading SentenceTransformer...")
    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    print(f"    -> Model: {embedding_model_name}")
    
    import torch
    # FIX: prefer any available GPU backend — CUDA first, then Apple MPS —
    # instead of only checking MPS.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    st_model = SentenceTransformer(embedding_model_name, device=device)
    
    # 5. Generate Embeddings 
    print("⚙️ Generating Embeddings for Knowledge Base (One-time setup)...")
    if 'embedding' not in knowledge_df.columns:
        embeddings = st_model.encode(knowledge_df['description'].fillna("").tolist(), 
                                   convert_to_numpy=True, 
                                   show_progress_bar=True)
        knowledge_df['embedding'] = list(embeddings)
    else:
        print("    -> Embeddings already present in dataset.")
    
    MODEL = model
    KNOWLEDGE_DF = knowledge_df
    ST_MODEL = st_model
    print("✅ App initialized (Inference Mode)!")

def predict_and_optimize(user_input, duration, hour, day_of_week, category, followers, age, gender):
    """Predicts views for a draft description, finds similar viral videos,
    asks Gemini for an optimized rewrite, and re-scores the rewrite.

    Args:
        user_input: draft video description text.
        duration: video length in seconds.
        hour: upload hour (0-23).
        day_of_week: weekday name string.
        category: niche/category string.
        followers: creator follower count.
        age: age-bucket string ("18-24", "25-34", "35-44", "45+").
        gender: creator gender string.

    Returns:
        A 5-tuple of display strings:
        (initial views, similar videos, improved description, new views, uplift).
    """
    if not user_input:
        return "Please enter a video description.", "", "", "", ""
        
    # --- 1. ENCODE INPUTS ---
    try:
        # Helper to encode safely: unseen values (and missing encoders)
        # default to class index 0 rather than raising.
        def safe_encode(col, val):
            le = ENCODERS.get(col)
            if le:
                if val in le.classes_:
                    return le.transform([val])[0]
                else:
                    return 0 
            return 0

        cat_encoded = safe_encode('category', category)
        gender_encoded = safe_encode('gender', gender)
        day_encoded = safe_encode('day_of_week', day_of_week)
        
        # Map the age-bucket string to a representative numeric value,
        # matching the numeric 'age' feature the model was trained with.
        age_map = {
            "18-24": 21.0, 
            "25-34": 30.0, 
            "35-44": 40.0, 
            "45+": 50.0
        }
        age_numeric = age_map.get(str(age), 25.0) # Default to 25 if unknown
        
    except Exception as e:
        return f"Encoding Error: {str(e)}", "", "", "", ""

    # --- 2. INITIAL PREDICTION ---
    # Feature Order MUST match model-prep.py:
    # Embeddings + [duration, hour, followers, age_numeric, category_enc, gender_enc, day_enc]
    text_vec = ST_MODEL.encode([user_input], convert_to_numpy=True)
    
    # Construct metadata vector
    meta_vec = np.array([[duration, hour, followers, age_numeric, cat_encoded, gender_encoded, day_encoded]])
    
    feat_vec = np.hstack((text_vec, meta_vec))
    
    # Model predicts log1p(views); invert with expm1 for a human-readable count.
    initial_log = MODEL.predict(feat_vec)[0]
    initial_views = int(np.expm1(initial_log))
    
    # --- 3. VECTOR SEARCH ---
    # Restrict retrieval to the top quartile of historical performers.
    high_perf_df = KNOWLEDGE_DF[KNOWLEDGE_DF['views'] > KNOWLEDGE_DF['views'].quantile(0.75)].copy()
    target_embeddings = np.stack(high_perf_df['embedding'].values)
    
    # FIX: reuse text_vec — the input was already encoded above; the second
    # ST_MODEL.encode call was redundant.
    similarities = cosine_similarity(text_vec, target_embeddings)
    # FIX: clamp k so small knowledge bases (<3 high performers) don't IndexError.
    k = min(3, len(high_perf_df))
    top_indices = similarities[0].argsort()[-k:][::-1]
    top_videos = high_perf_df.iloc[top_indices]['description'].tolist()
    
    similar_videos_str = "\n\n".join(f"{i+1}. {v}" for i, v in enumerate(top_videos))

    # --- 4. GEMINI OPTIMIZATION ---
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        return f"{initial_views:,}", similar_videos_str, "Error: GEMINI_API_KEY not found.", "N/A", "N/A"
        
    genai.configure(api_key=api_key)
    # FIX: bare `except:` also swallowed KeyboardInterrupt/SystemExit;
    # catch Exception only.
    try: 
        llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    except Exception:
        llm = genai.GenerativeModel('gemini-1.5-flash')

    # FIX: build the examples section from however many matches were found
    # instead of hard-indexing [0]/[1]/[2].
    examples = "\n    ".join(f"{i+1}. {v}" for i, v in enumerate(top_videos))

    prompt = f"""
    You are a TikTok Virality Expert.
    
    Draft: "{user_input}"
    Niche: {category} | Creator: {age}, {gender} with {followers} followers.
    Context: {duration}s video posted on {day_of_week} at {hour}:00.
    
    Viral Examples in this niche:
    {examples}
    
    Task: 
    Rewrite the draft to be more viral. Add hooks and hashtags two hashtags MAX.
    Keep it natural and relevant to the creator persona.
    Make it short catchy and viral.
    
    Output Format:
    [New Description]

    Recomendations:
    [Upload day of week]
    [Upload hour]
    [Upload duration]
    """

    try:
        response = llm.generate_content(prompt)
        improved_idea = response.text.strip()
        
        # --- 5. RE-SCORING ---
        new_text_vec = ST_MODEL.encode([improved_idea], convert_to_numpy=True)
        # Using same metadata for the new prediction
        new_feat_vec = np.hstack((new_text_vec, meta_vec))
        
        new_log = MODEL.predict(new_feat_vec)[0]
        new_views = int(np.expm1(new_log))
        
        # FIX: guard against ZeroDivisionError when the initial prediction is 0.
        if initial_views > 0:
            uplift_pct = ((new_views - initial_views) / initial_views) * 100
            uplift_str = f"+{uplift_pct:.1f}%" if uplift_pct > 0 else "No significant uplift"
        else:
            uplift_str = "N/A"
        
        return f"{initial_views:,}", similar_videos_str, improved_idea, f"{new_views:,}", uplift_str
        
    except Exception as e:
        return f"{initial_views:,}", similar_videos_str, f"Error calling AI: {str(e)}", "N/A", "N/A"


# --- GRADIO UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸš€ Viral Content Optimizer")
    gr.Markdown("Enter your video idea and stats to predict views and get AI-powered optimizations based on 2025 trends.")
    
    with gr.Row():
        with gr.Column(scale=1):
            input_text = gr.Textbox(
                label="Your Video Description", 
                placeholder="e.g., POV: trying the new grimace shake #viral",
                lines=3
            )
            
            with gr.Row():
                category = gr.Dropdown(
                    choices=["Entertainment", "Gaming", "Fitness", "Food", "Beauty", "Tech", "Travel", "Education", "Fashion", "Health", "DIY", "Pranks"],
                    value="Entertainment", label="Niche"
                )
                followers = gr.Number(value=1000, label="Follower Count", precision=0)
            
            with gr.Row():
                age = gr.Dropdown(choices=["18-24", "25-34", "35-44", "45+"], value="18-24", label="Creator Age")
                gender = gr.Dropdown(choices=["Male", "Female"], value="Female", label="Creator Gender")
            
            with gr.Row():
                duration_slider = gr.Slider(minimum=5, maximum=60, value=15, step=1, label="Duration (s)")
                hour_slider = gr.Slider(minimum=0, maximum=23, value=18, step=1, label="Upload Hour")
                day_dropdown = gr.Dropdown(
                    choices=["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"], 
                    value="Friday", label="Day"
                )
            
            with gr.Row():
                submit_btn = gr.Button("Analyze & Optimize ⚑", variant="primary")
                demo_btn = gr.Button("🎲 Try Demo", variant="secondary")
            
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### πŸ“Š Predictions")
                initial_views = gr.Textbox(label="Predicted Views (Original)", interactive=False)
            
            with gr.Group():
                gr.Markdown("### ✨ AI Optimization")
                improved_text = gr.Textbox(label="Improved Description", interactive=False)
                with gr.Row():
                    new_views = gr.Textbox(label="New Predicted Views", interactive=False)
                    uplift = gr.Textbox(label="Potential Uplift", interactive=False)
    
    with gr.Accordion("πŸ” Similar Viral Videos (Reference)", open=False):
        similar_videos = gr.Textbox(label="Top 3 Context Matches", interactive=False, lines=5)

    submit_btn.click(
        fn=predict_and_optimize,
        inputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

    # Demo Button Logic
    demo_btn.click(
        fn=lambda: ("My protein shake ended up on the floor", 15, 19, "Monday", "Fitness", 50000, "18-24", "Male"),
        inputs=None,
        outputs=[input_text, duration_slider, hour_slider, day_dropdown, category, followers, age, gender]
    ).then(
        fn=predict_and_optimize,
        inputs=[gr.State("My protein shake ended up on the floor"), gr.State(15), gr.State(19), gr.State("Monday"), gr.State("Fitness"), gr.State(50000), gr.State("18-24"), gr.State("Male")],
        outputs=[initial_views, similar_videos, improved_text, new_views, uplift]
    )

# Run initialization
if __name__ == "__main__":
    initialize_app()
    demo.launch(server_name="0.0.0.0", server_port=7860)