---
title: Social Media Virality Assistant
emoji: 🚀
colorFrom: indigo
colorTo: purple
sdk: gradio
sdk_version: 5.9.0
app_file: app.py
pinned: false
---

# 🚀 Social Media Virality Assistant

A machine learning-powered tool that helps content creators predict and optimize the virality potential of their videos, using a trained **XGBoost** model and **Google Gemini AI**.

## 🏗️ Architecture & Pipeline

This project consists of two main components: a training pipeline (`model-prep.py`) and an inference application (`app.py`).

### 1. Training Pipeline (`model-prep.py`)
The `model-prep.py` script handles the end-to-end model creation process:

1.  **Cloud Data Loading**: It fetches the latest synthetic dataset directly from **Hugging Face** (`MatanKriel/social-assitent-synthetic-data`).
2.  **Embedding Benchmark**: It evaluates three embedding models (`MiniLM`, `mpnet-base`, `bge-small`) by computing the **Silhouette Score** against **Composite Labels** (`Category_ViralClass`).
    *   *Why?* Instead of just clustering by topic (e.g., "Gaming"), this forces the model to distinguish between "Viral Gaming Videos" and "Average Gaming Videos".
    *   *Selection*: Automatically picks the best model for this high-resolution task.
3.  **Feature Engineering**:
    *   Encodes categorical inputs: `category`, `gender`, `day_of_week`, `age`.
    *   Combines text embeddings with metadata (`followers`, `duration`, `hour`).
4.  **Model Training**: Trains and compares three regression algorithms:
    *   Linear Regression
    *   Random Forest
    *   **XGBoost (Winner)**: Selected for having the lowest RMSE.
5.  **Artifact Generation**: Saves the trained model locally (`viral_model.pkl`) and generates performance plots (`project_plots/`).
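
The feature engineering step above can be sketched roughly as follows. This is a minimal illustration, not the code in `model-prep.py`: the helper name `build_feature_matrix`, the use of one-hot encoding, and the sample rows are all assumptions made for the example.

```python
import numpy as np
import pandas as pd


def build_feature_matrix(df: pd.DataFrame, embeddings: np.ndarray) -> np.ndarray:
    """Concatenate text embeddings with encoded categorical and numeric columns."""
    categorical = ["category", "gender", "day_of_week", "age"]
    numeric = ["followers", "duration", "hour"]
    # Assumption: one-hot encoding. The real script may use a different encoder.
    encoded = pd.get_dummies(df[categorical].astype(str), dtype=float)
    return np.hstack(
        [embeddings, encoded.to_numpy(), df[numeric].to_numpy(dtype=float)]
    )


# Made-up rows, only to show the resulting shape.
df = pd.DataFrame(
    {
        "category": ["Gaming", "Food"],
        "gender": ["female", "male"],
        "day_of_week": ["Mon", "Sat"],
        "age": ["18-24", "25-34"],
        "followers": [1200, 56000],
        "duration": [30, 45],
        "hour": [18, 9],
    }
)
embeddings = np.zeros((2, 4))  # stand-in for sentence-transformer vectors
X = build_feature_matrix(df, embeddings)
print(X.shape)
```

Each row ends up as one flat vector (embedding dimensions, then encoded categories, then raw numeric metadata), which is the shape a tabular regressor such as XGBoost expects.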

### 2. Inference Application (`app.py`)
The `app.py` script runs a **Gradio** web interface that pulls artifacts from the cloud at startup:

1.  **Initialization**:
    *   Downloads the trained `viral_model.pkl` from Hugging Face (`MatanKriel/social-assitent-viral-predictor`).
    *   Downloads the dataset to build a Knowledge Base.
    *   Generates embeddings on-the-fly for the Knowledge Base.
2.  **Core Features**:
    *   **Virality Prediction**: Predicts raw view counts based on your draft description and stats.
    *   **AI Optimization**: Uses **Google Gemini** to rewrite your description with viral hooks and hashtags, using the top 3 similar videos from the dataset as context.
    *   **Semantic Search**: Finds similar successful videos from the knowledge base using Cosine Similarity.
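
As a rough sketch of the semantic search step, the snippet below ranks knowledge-base vectors by cosine similarity to a query vector. The function name `top_k_similar` and the tiny 2-dimensional vectors are hypothetical; in the app the vectors would come from the selected embedding model.

```python
import numpy as np


def top_k_similar(query: np.ndarray, kb: np.ndarray, k: int = 3) -> list[tuple[int, float]]:
    """Return (row index, cosine similarity) for the k rows of kb closest to query."""
    query_n = query / np.linalg.norm(query)
    kb_n = kb / np.linalg.norm(kb, axis=1, keepdims=True)
    scores = kb_n @ query_n  # cosine similarity, since both sides are unit length
    top = np.argsort(scores)[::-1][:k]
    return [(int(i), float(scores[i])) for i in top]


# Toy knowledge base: each row stands in for one video's embedding.
kb = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
matches = top_k_similar(np.array([1.0, 0.1]), kb, k=3)
print(matches)
```

The returned indices can then be used to look up the matching rows in the dataset, for example to pass the top 3 descriptions to Gemini as context.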

---

## 📊 Model Performance

The training script (`model-prep.py`) automatically generates these benchmarks:

### Embedding Model Comparison
We selected the embedding model that best balances speed and semantic understanding.
![Embedding Benchmark](project_plots/embedding_benchmark.png)
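
To illustrate how a silhouette score on composite labels can rank embedding models, here is a minimal sketch using scikit-learn's `silhouette_score`. The helper `score_embeddings`, the cosine metric, and the synthetic vectors are assumptions for the example; they are not taken from `model-prep.py`.

```python
import numpy as np
from sklearn.metrics import silhouette_score


def score_embeddings(embeddings, categories, viral_classes) -> float:
    """Silhouette score of embeddings against Category_ViralClass labels."""
    composite = [f"{c}_{v}" for c, v in zip(categories, viral_classes)]
    # Map each composite label string to an integer id.
    _, label_ids = np.unique(composite, return_inverse=True)
    return float(silhouette_score(embeddings, label_ids, metric="cosine"))


rng = np.random.default_rng(0)
categories = ["Gaming"] * 20 + ["Food"] * 20
viral_classes = (["Viral"] * 10 + ["Average"] * 10) * 2
group_ids = np.repeat(np.arange(4), 10)  # one id per composite group

# Two synthetic "models": one separates the four groups, one is pure noise.
candidates = {
    "separating_model": np.eye(4)[group_ids] * 5 + rng.normal(0, 0.1, (40, 4)),
    "noise_model": rng.normal(0, 1, (40, 4)),
}
scores = {
    name: score_embeddings(emb, categories, viral_classes)
    for name, emb in candidates.items()
}
best = max(scores, key=scores.get)
print(best)
```

A higher score means the embedding places videos with the same category and virality class closer together than videos from other groups.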

### Regression Model Comparison
We chose the regressor with the lowest error (RMSE) and highest explained variance (R²).
![Model Comparison](project_plots/regression_comparison.png)
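
For reference, the two metrics can be computed as in the sketch below. The model names and numbers are invented purely to show the arithmetic; they are not results from this project.

```python
import numpy as np


def rmse(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Root mean squared error: lower is better."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


def r2(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Coefficient of determination: share of variance explained."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return float(1 - ss_res / ss_tot)


# Made-up targets and predictions for two hypothetical regressors.
y_true = np.array([100.0, 200.0, 300.0, 400.0])
predictions = {
    "model_a": np.array([110.0, 190.0, 310.0, 390.0]),
    "model_b": np.array([150.0, 150.0, 350.0, 350.0]),
}
results = {
    name: {"rmse": rmse(y_true, pred), "r2": r2(y_true, pred)}
    for name, pred in predictions.items()
}
winner = min(results, key=lambda name: results[name]["rmse"])
print(winner, results[winner])
```

Selecting by lowest RMSE, as done here, usually agrees with highest R² on the same test set, because both are driven by the same squared residuals.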

---

## 🛠️ Tech Stack
This project is built using:
*   **App**: `gradio`, `google-generativeai`
*   **ML**: `xgboost`, `scikit-learn`, `sentence-transformers`
*   **Data**: `pandas`, `numpy`
*   **Cloud**: `huggingface_hub`, `datasets`

---