Spaces:
Sleeping
Sleeping
Matan Kriel
commited on
Commit
·
2f9170f
1
Parent(s):
7803d6a
updated clustering metric in model test
Browse files- README.md +3 -3
- app.py +1 -1
- model-prep.py +28 -6
- project_plots/embedding_benchmark.png +2 -2
- project_plots/regression_comparison.png +2 -2
- viral_model.pkl +2 -2
README.md
CHANGED
|
@@ -21,9 +21,9 @@ This project consists of two main components: a training pipeline (`model-prep.p
|
|
| 21 |
the `model-prep.py` script handles the end-to-end model creation process:
|
| 22 |
|
| 23 |
1. **Cloud Data Loading**: It fetches the latest synthetic dataset directly from **Hugging Face** (`MatanKriel/social-assitent-synthetic-data`).
|
| 24 |
-
2. **Embedding Benchmark**: It evaluates 3 state-of-the-art models (`MiniLM`, `mpnet-base`, `bge-small`)
|
| 25 |
-
* *
|
| 26 |
-
* *
|
| 27 |
3. **Feature Engineering**:
|
| 28 |
* Encodes categorical inputs: `category`, `gender`, `day_of_week`, `age`.
|
| 29 |
* Combines text embeddings with metadata (`followers`, `duration`, `hour`).
|
|
|
|
| 21 |
the `model-prep.py` script handles the end-to-end model creation process:
|
| 22 |
|
| 23 |
1. **Cloud Data Loading**: It fetches the latest synthetic dataset directly from **Hugging Face** (`MatanKriel/social-assitent-synthetic-data`).
|
| 24 |
+
2. **Embedding Benchmark**: It evaluates 3 state-of-the-art models (`MiniLM`, `mpnet-base`, `bge-small`) using **Silhouette Score** on **Composite Labels** (`Category_ViralClass`).
|
| 25 |
+
* *Why?* Instead of just clustering by topic (e.g., "Gaming"), this forces the model to distinguish between "Viral Gaming Videos" and "Average Gaming Videos".
|
| 26 |
+
* *Selection*: Automatically picks the best model for this high-resolution task.
|
| 27 |
3. **Feature Engineering**:
|
| 28 |
* Encodes categorical inputs: `category`, `gender`, `day_of_week`, `age`.
|
| 29 |
* Combines text embeddings with metadata (`followers`, `duration`, `hour`).
|
app.py
CHANGED
|
@@ -70,7 +70,7 @@ def initialize_app():
|
|
| 70 |
|
| 71 |
# 4. Load SentenceTransformer
|
| 72 |
print("🔌 Loading SentenceTransformer...")
|
| 73 |
-
embedding_model_name = "sentence-transformers/all-
|
| 74 |
print(f" -> Model: {embedding_model_name}")
|
| 75 |
|
| 76 |
import torch
|
|
|
|
| 70 |
|
| 71 |
# 4. Load SentenceTransformer
|
| 72 |
print("🔌 Loading SentenceTransformer...")
|
| 73 |
+
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
|
| 74 |
print(f" -> Model: {embedding_model_name}")
|
| 75 |
|
| 76 |
import torch
|
model-prep.py
CHANGED
|
@@ -68,16 +68,38 @@ def benchmark_and_select_model(df):
|
|
| 68 |
|
| 69 |
results = []
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
else:
|
| 75 |
-
|
| 76 |
|
| 77 |
-
# Sample for speed
|
| 78 |
sample_df = df.sample(min(len(df), 3000), random_state=42)
|
| 79 |
sample_texts = sample_df['description'].fillna("").tolist()
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
print(f"{'Model':<40} | {'Time (s)':<10} | {'Silhouette':<10}")
|
| 83 |
print("-" * 65)
|
|
|
|
| 68 |
|
| 69 |
results = []
|
| 70 |
|
| 71 |
+
# Create Composite Labels for Silhouette Score
|
| 72 |
+
# Goal: Use "Category_ViralClass" (e.g., "Fitness_High") to measure separation
|
| 73 |
+
|
| 74 |
+
# 1. Ensure viral_class exists for benchmarking
|
| 75 |
+
if 'viral_class' not in df.columns and 'views' in df.columns:
|
| 76 |
+
threshold = df['views'].quantile(0.75)
|
| 77 |
+
df['viral_class'] = np.where(df['views'] > threshold, 'High', 'Low')
|
| 78 |
+
print(f" -> ℹ️ Created temporary 'viral_class' (High/Low) for benchmarking.")
|
| 79 |
+
|
| 80 |
+
# 2. Define Labels
|
| 81 |
+
if 'category' in df.columns and 'viral_class' in df.columns:
|
| 82 |
+
print(" -> 🏷️ Using Composite Labels (Category + Viral Class) for metrics.")
|
| 83 |
+
# We need to perform this on the SAMPLE, not the whole DF if we sample later.
|
| 84 |
+
# But to be safe, let's just use the column if it exists.
|
| 85 |
+
pass # Logic handled after sampling
|
| 86 |
+
elif 'category' in df.columns:
|
| 87 |
+
print(" -> ⚠️ 'viral_class' missing. Falling back to 'category' only.")
|
| 88 |
else:
|
| 89 |
+
print(" -> ⚠️ No categories found. Skipping quality metric.")
|
| 90 |
|
| 91 |
+
# Sample for speed (using the updated df which might have viral_class)
|
| 92 |
sample_df = df.sample(min(len(df), 3000), random_state=42)
|
| 93 |
sample_texts = sample_df['description'].fillna("").tolist()
|
| 94 |
+
|
| 95 |
+
if 'category' in sample_df.columns and 'viral_class' in sample_df.columns:
|
| 96 |
+
# Composite Label Formula
|
| 97 |
+
sample_labels = sample_df['category'].astype(str) + "_" + sample_df['viral_class'].astype(str)
|
| 98 |
+
sample_labels = sample_labels.values
|
| 99 |
+
elif 'category' in sample_df.columns:
|
| 100 |
+
sample_labels = sample_df['category'].values
|
| 101 |
+
else:
|
| 102 |
+
sample_labels = np.zeros(len(sample_df))
|
| 103 |
|
| 104 |
print(f"{'Model':<40} | {'Time (s)':<10} | {'Silhouette':<10}")
|
| 105 |
print("-" * 65)
|
project_plots/embedding_benchmark.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
project_plots/regression_comparison.png
CHANGED
|
Git LFS Details
|
|
Git LFS Details
|
viral_model.pkl
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fede1adc51df6bc4148b7fdd758625ea3bfd17bbea07f435b0416acfef28c9e9
|
| 3 |
+
size 337752
|