Matan Kriel committed
Commit 2f9170f · 1 Parent(s): 7803d6a

updated clustering metric in model test
README.md CHANGED

```diff
@@ -21,9 +21,9 @@
 the `model-prep.py` script handles the end-to-end model creation process:
 
 1. **Cloud Data Loading**: It fetches the latest synthetic dataset directly from **Hugging Face** (`MatanKriel/social-assitent-synthetic-data`).
-2. **Embedding Benchmark**: It evaluates 3 state-of-the-art models (`MiniLM`, `mpnet-base`, `bge-small`) to find the best text encoder.
-   * *Metrics*: Encoding Speed vs. Clustering Quality (Silhouette Score).
-   * *Winner*: Defaults to `sentence-transformers/all-mpnet-base-v2`.
+2. **Embedding Benchmark**: It evaluates 3 state-of-the-art models (`MiniLM`, `mpnet-base`, `bge-small`) using **Silhouette Score** on **Composite Labels** (`Category_ViralClass`).
+   * *Why?* Instead of just clustering by topic (e.g., "Gaming"), this forces the model to distinguish between "Viral Gaming Videos" and "Average Gaming Videos".
+   * *Selection*: Automatically picks the best model for this high-resolution task.
 3. **Feature Engineering**:
    * Encodes categorical inputs: `category`, `gender`, `day_of_week`, `age`.
    * Combines text embeddings with metadata (`followers`, `duration`, `hour`).
```
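The composite-label idea in the README can be sketched in plain Python. This is an illustrative toy, not the pipeline code: the sample `views` and `categories` values are invented, and `statistics.quantiles` only approximates the pandas `quantile(0.75)` threshold that `model-prep.py` actually uses.

```python
# Minimal sketch (illustrative, not pipeline code): build the
# "Category_ViralClass" composite labels that the benchmark clusters on.
from statistics import quantiles

views = [100, 200, 300, 1000, 5000, 80, 60, 40]        # toy data
categories = ["Gaming"] * 4 + ["Fitness"] * 4          # toy data

# 75th-percentile cut, mirroring the quantile(0.75) threshold in model-prep.py
# (statistics.quantiles interpolates slightly differently than pandas)
q3 = quantiles(views, n=4)[2]
viral_class = ["High" if v > q3 else "Low" for v in views]

# Composite label: topic AND performance tier count as one cluster identity,
# so "Gaming_High" and "Gaming_Low" are treated as different clusters.
composite = [f"{c}_{v}" for c, v in zip(categories, viral_class)]
```

With these toy numbers only the 1000- and 5000-view rows clear the 75th-percentile cut, so the same "Gaming" topic splits into `Gaming_High` and `Gaming_Low` labels.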
app.py CHANGED

```diff
@@ -70,7 +70,7 @@ def initialize_app():
 
     # 4. Load SentenceTransformer
     print("🔌 Loading SentenceTransformer...")
-    embedding_model_name = "sentence-transformers/all-mpnet-base-v2"
+    embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    print(f" -> Model: {embedding_model_name}")
 
     import torch
```
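Since the model name is hard-coded in `app.py`, one way to keep future swaps cheap is an environment override. This is a hedged sketch only: the `EMBEDDING_MODEL` variable name and `resolve_embedding_model` helper are hypothetical, not part of the repo; the default simply matches this commit's choice.

```python
import os

# Hypothetical override hook (NOT in app.py): lets a deployment pick the
# encoder without editing code. Default matches this commit's selection.
DEFAULT_EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

def resolve_embedding_model() -> str:
    # "EMBEDDING_MODEL" is an assumed env-var name for this sketch
    return os.environ.get("EMBEDDING_MODEL", DEFAULT_EMBEDDING_MODEL)
```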
model-prep.py CHANGED

```diff
@@ -68,16 +68,38 @@ def benchmark_and_select_model(df):
 
     results = []
 
-    if 'category' not in df.columns:
-        print("⚠️ No 'category' column. Skipping quality metric.")
-        labels = np.zeros(len(df))
+    # Create composite labels for the Silhouette Score.
+    # Goal: use "Category_ViralClass" (e.g., "Fitness_High") to measure separation.
+
+    # 1. Ensure viral_class exists for benchmarking
+    if 'viral_class' not in df.columns and 'views' in df.columns:
+        threshold = df['views'].quantile(0.75)
+        df['viral_class'] = np.where(df['views'] > threshold, 'High', 'Low')
+        print(" -> ℹ️ Created temporary 'viral_class' (High/Low) for benchmarking.")
+
+    # 2. Define labels
+    if 'category' in df.columns and 'viral_class' in df.columns:
+        print(" -> 🏷️ Using Composite Labels (Category + Viral Class) for metrics.")
+        pass  # Labels are built on the sample below, after sampling.
+    elif 'category' in df.columns:
+        print(" -> ⚠️ 'viral_class' missing. Falling back to 'category' only.")
     else:
-        labels = df['category'].values
+        print(" -> ⚠️ No categories found. Skipping quality metric.")
 
-    # Sample for speed
+    # Sample for speed (using the updated df, which may now have viral_class)
     sample_df = df.sample(min(len(df), 3000), random_state=42)
     sample_texts = sample_df['description'].fillna("").tolist()
-    sample_labels = sample_df['category'].values
+
+    if 'category' in sample_df.columns and 'viral_class' in sample_df.columns:
+        # Composite label formula
+        sample_labels = sample_df['category'].astype(str) + "_" + sample_df['viral_class'].astype(str)
+        sample_labels = sample_labels.values
+    elif 'category' in sample_df.columns:
+        sample_labels = sample_df['category'].values
+    else:
+        sample_labels = np.zeros(len(sample_df))
 
     print(f"{'Model':<40} | {'Time (s)':<10} | {'Silhouette':<10}")
     print("-" * 65)
```
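To make the benchmark's quality metric concrete, here is a tiny pure-Python silhouette computation on toy 1-D "points" standing in for sentence embeddings. It is only a sketch of the metric's definition, not the pipeline's implementation (which scores real high-dimensional embeddings); the toy values and label names are invented.

```python
# Illustrative sketch: mean silhouette coefficient on 1-D points, showing why
# composite labels reward tight, well-separated clusters. Assumes >= 2 clusters.
from collections import defaultdict

def silhouette(points, labels):
    by_label = defaultdict(list)
    for i, lab in enumerate(labels):
        by_label[lab].append(i)
    scores = []
    for i, (p, lab) in enumerate(zip(points, labels)):
        same = [abs(p - points[j]) for j in by_label[lab] if j != i]
        if not same:
            continue  # singleton cluster: silhouette undefined for this point
        a = sum(same) / len(same)                       # mean intra-cluster distance
        b = min(                                        # nearest other cluster
            sum(abs(p - points[j]) for j in idxs) / len(idxs)
            for lab2, idxs in by_label.items() if lab2 != lab
        )
        scores.append((b - a) / max(a, b))
    return sum(scores) / len(scores)

# Toy data: "Gaming_High" and "Gaming_Low" form two tight, distant groups,
# so the score approaches the maximum of 1.0.
points = [0.0, 0.1, 0.2, 5.0, 5.1, 5.2]
labels = ["Gaming_High"] * 3 + ["Gaming_Low"] * 3
```

An encoder whose embeddings separate high-performing posts from average ones within the same topic gets a score near 1; an encoder that mixes them scores near 0 or below, which is exactly what the composite-label benchmark penalizes.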
project_plots/embedding_benchmark.png CHANGED

Git LFS Details (before)
  • SHA256: da6fc1de241e564e73cc87eb57cd8b71d5a84c829e2975e476f937ac28f78b06
  • Pointer size: 130 Bytes
  • Size of remote file: 26.7 kB

Git LFS Details (after)
  • SHA256: 5da26c856671fcedc44d92d1de23252ea6c37ad159a0db9d38384cb795d0a027
  • Pointer size: 130 Bytes
  • Size of remote file: 29.8 kB
project_plots/regression_comparison.png CHANGED

Git LFS Details (before)
  • SHA256: d5c0e245d4cdcbc3124c2ab43c89c956099acceacaae041d889c7b2b5627d01a
  • Pointer size: 130 Bytes
  • Size of remote file: 27.4 kB

Git LFS Details (after)
  • SHA256: 4fcee525946abe3cec8bec7416efd6d0345fdd980f635b945c27627a362f5d02
  • Pointer size: 130 Bytes
  • Size of remote file: 28.3 kB
viral_model.pkl CHANGED

```diff
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4dc2e1d1bfe66f8970d7f2770e10f70ad426f78d7493ac4d0250f18ef878e9d7
-size 327349
+oid sha256:fede1adc51df6bc4148b7fdd758625ea3bfd17bbea07f435b0416acfef28c9e9
+size 337752
```