Mohaddz committed on
Commit
18a59a4
Β·
verified Β·
1 Parent(s): 12f0a4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +365 -125
app.py CHANGED
@@ -3,183 +3,423 @@ import pandas as pd
3
  import torch
4
  from sentence_transformers import SentenceTransformer, util
5
  import numpy as np
6
- from typing import Dict, List, Tuple
7
  import io
8
  import plotly.express as px
 
 
 
9
  import traceback
10
- import spaces
11
- import os
12
-
13
- # THE FIX IS HERE: Use a relative path for the cache directory.
14
- # This creates a writable 'persistent_cache' folder inside the app's own directory.
15
- os.environ['SENTENCE_TRANSFORMERS_HOME'] = './persistent_cache'
16
 
17
  class MultiClientThemeClassifier:
18
  def __init__(self):
19
- self.model: SentenceTransformer | None = None
20
- self.model_name: str | None = None
21
- self.client_themes = {}
22
-
23
- def _ensure_model_loaded(self):
24
- """
25
- Checks if the model is loaded in the current process.
26
- If not, it reloads it using the saved model_name.
27
- """
28
- if self.model is None:
29
- if self.model_name is None:
30
- raise ValueError("Model name not set. Please go to the 'Setup & Model' tab and load a model first.")
31
-
32
- print(f"Model not found in current process. Reloading '{self.model_name}' from cache...")
33
- try:
34
- self.model = SentenceTransformer(self.model_name)
35
- print(f"Model '{self.model_name}' reloaded successfully.")
36
- except Exception as e:
37
- print(f"FATAL: Failed to reload model '{self.model_name}': {e}")
38
- raise e
39
-
40
- @spaces.GPU
41
- def load_model(self, model_name: str):
42
- """Loads the model and saves its name to the state."""
43
  try:
 
 
 
 
 
 
44
  print(f"Loading model: {model_name}")
45
- self.model = SentenceTransformer(model_name)
46
- self.model_name = model_name
47
- self.client_themes = {}
48
  return f"βœ… Model '{model_name}' loaded successfully!"
49
  except Exception as e:
50
- self.model = None
51
- self.model_name = None
52
- return f"❌ Error loading model '{model_name}': {traceback.format_exc()}"
53
-
54
- @spaces.GPU
55
- def add_client_themes(self, client_id: str, themes: List[str]):
56
- """Adds themes for a client, ensuring the model is loaded first."""
 
 
57
  try:
58
- self._ensure_model_loaded()
59
  self.client_themes[client_id] = {}
60
- prototypes = self.model.encode(themes, convert_to_tensor=True)
61
- for theme, prototype in zip(themes, prototypes):
 
 
 
 
 
 
 
 
 
62
  self.client_themes[client_id][theme] = prototype
 
63
  return f"βœ… Added {len(themes)} themes for client '{client_id}'"
64
  except Exception as e:
65
  return f"❌ Error adding themes: {str(e)}"
66
-
67
- @spaces.GPU
68
- def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, pd.DataFrame | None, str | None]:
69
- """Benchmarks a CSV, ensuring the model is loaded first."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  try:
71
- self._ensure_model_loaded()
72
- print("Model confirmed loaded in benchmark process. Starting benchmark...")
73
 
 
74
  df = pd.read_csv(io.StringIO(csv_content))
 
 
 
 
75
  if 'text' not in df.columns or 'real_tag' not in df.columns:
76
- return "❌ CSV must have 'text' and 'real_tag' columns!", None, ""
77
-
78
- df.dropna(subset=['text', 'real_tag'], inplace=True)
79
 
 
 
 
 
 
 
 
80
  unique_themes = df['real_tag'].unique().tolist()
81
- self.add_client_themes(client_id, unique_themes)
82
 
83
- texts_to_classify = df['text'].astype(str).tolist()
84
- text_embeddings = self.model.encode(texts_to_classify, convert_to_tensor=True, show_progress_bar=True)
 
85
 
86
- themes = list(self.client_themes[client_id].keys())
87
- prototypes = torch.stack(list(self.client_themes[client_id].values()))
 
88
 
89
- similarities_matrix = util.cos_sim(text_embeddings, prototypes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- best_scores, best_indices = torch.max(similarities_matrix, dim=1)
92
 
93
- df['predicted_tag'] = [themes[i] for i in best_indices]
94
- df['confidence'] = best_scores.tolist()
95
-
96
- correct = (df['real_tag'] == df['predicted_tag']).sum()
97
- total = len(df)
 
 
98
  accuracy = correct / total if total > 0 else 0
 
 
 
 
 
 
99
 
100
- results_summary = f"πŸ“Š **Benchmarking Results**\n\n- **Accuracy: {accuracy:.2%}** ({correct} / {total} correct)"
101
- fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution in Dataset")
102
- visualization_html = fig.to_html()
103
-
104
- return results_summary, df, visualization_html
105
 
106
- except ValueError as e:
107
- return f"❌ {str(e)}", None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- return f"❌ Error during benchmarking: {traceback.format_exc()}", None, ""
110
-
 
111
 
112
- # --- Interface Functions ---
 
113
 
114
- def load_model_interface(classifier, model_name: str):
115
  if not model_name.strip():
116
- # Fallback to a default model if input is empty
117
- model_name = 'sentence-transformers/all-MiniLM-L6-v2'
118
- status = classifier.load_model(model_name.strip())
119
- return status, classifier
120
 
121
- def add_themes_interface(classifier, client_id: str, themes_text: str):
122
- if not client_id.strip() or not themes_text.strip():
123
- return "❌ Client ID and Themes cannot be empty.", classifier
 
124
  themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
125
- status = classifier.add_client_themes(client_id, themes)
126
- return status, classifier
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- def benchmark_interface(classifier, csv_file, client_id: str):
129
  if csv_file is None:
130
- return "Please upload a CSV file!", None, "", classifier
131
- if not client_id.strip():
132
- return "❌ Please enter a Client ID for the benchmark.", None, "", classifier
133
  try:
134
- with open(csv_file.name, 'r', encoding='utf-8') as f:
135
- csv_content = f.read()
136
- results, df, viz = classifier.benchmark_csv(csv_content, client_id)
137
- return results, df, viz, classifier
 
 
 
 
 
 
138
  except Exception as e:
139
- return f"❌ Error processing CSV: {traceback.format_exc()}", None, "", classifier
 
140
 
141
- # --- Gradio UI ---
142
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
143
- classifier_state = gr.State(MultiClientThemeClassifier())
 
144
 
145
- gr.Markdown("# 🎯 Custom Themes Classification - MVP")
 
 
 
 
 
 
 
146
 
147
  with gr.Tab("πŸš€ Setup & Model"):
148
- model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
149
- load_btn = gr.Button("Load Model", variant="primary")
 
 
 
 
 
 
 
 
 
150
  load_status = gr.Textbox(label="Status", interactive=False)
151
 
152
- gr.Markdown("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
153
 
154
- client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
155
- themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
156
  add_themes_btn = gr.Button("Add Themes", variant="secondary")
157
  themes_status = gr.Textbox(label="Status", interactive=False)
158
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  with gr.Tab("πŸ“Š CSV Benchmarking"):
160
- csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
161
- benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
162
- benchmark_btn = gr.Button("Run Benchmark", variant="primary")
163
- benchmark_results = gr.Markdown(label="Benchmark Results")
164
- results_dataframe = gr.Dataframe(label="Detailed Results", interactive=False, wrap=True)
165
- visualization = gr.HTML(label="Visualization")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
- # --- Event Handlers ---
168
- load_btn.click(
169
- load_model_interface,
170
- inputs=[classifier_state, model_input],
171
- outputs=[load_status, classifier_state]
172
- )
173
- add_themes_btn.click(
174
- add_themes_interface,
175
- inputs=[classifier_state, client_input, themes_input],
176
- outputs=[themes_status, classifier_state]
177
- )
178
- benchmark_btn.click(
179
- benchmark_interface,
180
- inputs=[classifier_state, csv_upload, benchmark_client],
181
- outputs=[benchmark_results, results_dataframe, visualization, classifier_state]
182
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
 
184
  if __name__ == "__main__":
185
  demo.launch(share=True)
 
3
  import torch
4
  from sentence_transformers import SentenceTransformer, util
5
  import numpy as np
6
+ from typing import Dict, List, Tuple, Optional
7
  import io
8
  import plotly.express as px
9
+ import plotly.graph_objects as go
10
+ from collections import defaultdict
11
+ import json
12
  import traceback
 
 
 
 
 
 
13
 
14
class MultiClientThemeClassifier:
    """Embedding-based theme classifier with per-client theme prototypes.

    Each client gets its own {theme: prototype_embedding} table; texts are
    classified by cosine similarity against the client's prototypes.
    All public methods return human-readable status strings instead of
    raising, because they back Gradio UI callbacks.
    """

    def __init__(self):
        self.model = None  # SentenceTransformer instance once loaded
        self.client_themes = {}  # {client_id: {theme: prototype_embedding}}
        self.model_loaded = False

    def load_model(self, model_name: str = 'Qwen/Qwen3-Embedding-0.6B'):
        """Load (or switch to) the embedding model named `model_name`.

        Returns a status string; never raises.
        """
        try:
            if self.model_loaded:
                # If switching models, reset everything: prototypes computed
                # by the previous model live in a different embedding space.
                self.model = None
                self.client_themes = {}
                self.model_loaded = False

            print(f"Loading model: {model_name}")
            self.model = SentenceTransformer(model_name, trust_remote_code=True)
            self.model_loaded = True
            return f"βœ… Model '{model_name}' loaded successfully!"
        except Exception as e:
            self.model_loaded = False
            error_details = traceback.format_exc()
            return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"

    def add_client_themes(self, client_id: str, themes: List[str],
                          examples_per_theme: Optional[Dict[str, List[str]]] = None):
        """Register theme prototypes for a specific client.

        When example texts are given for a theme, its prototype is the mean
        embedding of those examples; otherwise the theme name itself is
        embedded as a fallback. Returns a status string.
        """
        if not self.model_loaded:
            return "❌ Please load the model first!"

        try:
            self.client_themes[client_id] = {}

            for theme in themes:
                if examples_per_theme and theme in examples_per_theme:
                    # Use provided examples to create prototype
                    examples = examples_per_theme[theme]
                    embeddings = self.model.encode(examples, convert_to_tensor=True)
                    prototype = torch.mean(embeddings, dim=0)
                else:
                    # Use theme name itself as prototype (fallback)
                    prototype = self.model.encode(theme, convert_to_tensor=True)

                self.client_themes[client_id][theme] = prototype

            return f"βœ… Added {len(themes)} themes for client '{client_id}'"
        except Exception as e:
            return f"❌ Error adding themes: {str(e)}"

    def classify_text(self, text: str, client_id: str,
                      confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
        """Classify one text against a client's registered themes.

        Returns (theme, confidence, {theme: similarity}). Predictions below
        `confidence_threshold` come back as "UNKNOWN_THEME"; failures return
        an explanatory string in the theme slot rather than raising.
        """
        if not self.model_loaded:
            return "Model not loaded", 0.0, {}

        if client_id not in self.client_themes:
            return "Client not found", 0.0, {}

        try:
            # Encode input text
            text_embedding = self.model.encode(text, convert_to_tensor=True)

            # Calculate similarities with all themes
            similarities = {}
            for theme, prototype in self.client_themes[client_id].items():
                similarities[theme] = util.cos_sim(text_embedding, prototype).item()

            # FIX: a client registered with an empty theme list used to make
            # max() raise ValueError (swallowed into an "Error:" string);
            # report UNKNOWN_THEME instead.
            if not similarities:
                return "UNKNOWN_THEME", 0.0, similarities

            # Get best match
            best_theme = max(similarities, key=similarities.get)
            best_score = similarities[best_theme]

            # Apply confidence threshold
            if best_score < confidence_threshold:
                return "UNKNOWN_THEME", best_score, similarities

            return best_theme, best_score, similarities
        except Exception as e:
            return f"Error: {str(e)}", 0.0, {}

    def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, str, str]:
        """Benchmark classification on CSV text with 'text' and 'real_tag' columns.

        Unique `real_tag` values are auto-registered as themes for
        `client_id`, every row is classified, and accuracy is computed.
        Returns (markdown summary, path to a detailed-results CSV, plot HTML).
        """
        if not self.model_loaded:
            return "❌ Model not loaded!", "", ""

        try:
            print("Starting CSV benchmark...")

            # Read CSV
            df = pd.read_csv(io.StringIO(csv_content))
            print(f"CSV loaded with shape: {df.shape}")
            print(f"CSV columns: {df.columns.tolist()}")

            # Validate CSV format
            if 'text' not in df.columns or 'real_tag' not in df.columns:
                return "❌ CSV must have 'text' and 'real_tag' columns!", "", ""

            # Clean data. FIX: reset the index after dropna so the progress
            # counter below reports sequential 1..N positions instead of the
            # original (possibly gappy) index labels.
            df = df.dropna(subset=['text', 'real_tag']).reset_index(drop=True)
            df['text'] = df['text'].astype(str)
            df['real_tag'] = df['real_tag'].astype(str)
            print(f"After cleaning: {df.shape}")

            # Get unique themes from CSV
            unique_themes = df['real_tag'].unique().tolist()
            print(f"Unique themes found: {len(unique_themes)} - {unique_themes}")

            # Add themes for this client (theme names double as prototypes here)
            theme_add_result = self.add_client_themes(client_id, unique_themes)
            print(f"Theme addition result: {theme_add_result}")

            # Classify all texts with progress
            predictions = []
            confidences = []

            print("Starting classification...")
            for idx, row in df.iterrows():
                try:
                    text = str(row['text'])[:500]  # Limit text length
                    pred_theme, confidence, _ = self.classify_text(text, client_id)
                    predictions.append(pred_theme)
                    confidences.append(confidence)

                    if idx % 10 == 0:  # Progress logging
                        print(f"Processed {idx + 1}/{len(df)} samples")

                except Exception as e:
                    print(f"Error classifying row {idx}: {str(e)}")
                    predictions.append("ERROR")
                    confidences.append(0.0)

            print("Classification complete!")

            df['predicted_tag'] = predictions
            df['confidence'] = confidences

            # Calculate metrics (rows that errored out are excluded)
            valid_predictions = df[df['predicted_tag'] != 'ERROR']
            correct = (valid_predictions['real_tag'] == valid_predictions['predicted_tag']).sum()
            total = len(valid_predictions)
            accuracy = correct / total if total > 0 else 0

            print(f"Metrics calculated: {correct}/{total} = {accuracy:.2%}")

            # FIX: np.mean([]) is NaN (plus a RuntimeWarning) when no row got a
            # positive confidence; report 0.0 in that case.
            positive_confs = [c for c in confidences if c > 0]
            avg_confidence = np.mean(positive_confs) if positive_confs else 0.0

            # Generate results summary
            results_summary = f"""
πŸ“Š **Benchmarking Results**

**Overall Metrics:**
- Total samples: {total}
- Correct predictions: {correct}
- **Accuracy: {accuracy:.2%}**
- Average confidence: {avg_confidence:.3f}

**Per-Theme Breakdown:**
"""

            for theme in unique_themes:
                theme_df = valid_predictions[valid_predictions['real_tag'] == theme]
                if len(theme_df) > 0:
                    theme_correct = (theme_df['real_tag'] == theme_df['predicted_tag']).sum()
                    theme_total = len(theme_df)
                    theme_acc = theme_correct / theme_total if theme_total > 0 else 0
                    avg_conf = theme_df['confidence'].mean()

                    results_summary += f"- **{theme}**: {theme_acc:.2%} ({theme_correct}/{theme_total}) - Avg conf: {avg_conf:.3f}\n"

            # Create simple visualization (non-fatal on failure)
            try:
                theme_counts = [len(df[df['real_tag'] == theme]) for theme in unique_themes]
                fig = px.bar(
                    x=unique_themes,
                    y=theme_counts,
                    title="Theme Distribution in Dataset",
                    labels={'x': 'Themes', 'y': 'Count'}
                )
                visualization_html = fig.to_html()
            except Exception as viz_error:
                print(f"Visualization error: {viz_error}")
                visualization_html = "<p>Visualization error occurred</p>"

            # Save CSV to a temporary file for download
            import tempfile

            temp_file = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8')
            df.to_csv(temp_file.name, index=False)
            temp_file.close()

            return results_summary, temp_file.name, visualization_html

        except Exception as e:
            error_details = traceback.format_exc()
            print(f"Full error: {error_details}")
            return f"❌ Error during benchmarking: {str(e)}\n\nFull traceback:\n{error_details}", "", ""
207
 
208
# Initialize the classifier
# NOTE: module-level singleton shared by every Gradio callback below; all UI
# sessions in this process share the same loaded model and client themes.
classifier = MultiClientThemeClassifier()
210
 
211
def load_model_interface(model_name: str):
    """Gradio callback: load the requested embedding model (default if blank)."""
    name = model_name.strip()
    if not name:
        name = 'Qwen/Qwen3-Embedding-0.6B'  # Default
    return classifier.load_model(name)
 
 
215
 
216
def add_themes_interface(client_id: str, themes_text: str):
    """Gradio callback: register one theme per non-blank line for the client."""
    if not themes_text.strip():
        return "❌ Please enter themes!"

    themes = []
    for raw_line in themes_text.split('\n'):
        cleaned = raw_line.strip()
        if cleaned:
            themes.append(cleaned)
    return classifier.add_client_themes(client_id, themes)
222
+
223
def classify_interface(text: str, client_id: str, confidence_threshold: float):
    """Gradio callback: classify one text and render the similarity breakdown."""
    if not text.strip():
        return "Please enter text to classify!", ""

    pred_theme, confidence, similarities = classifier.classify_text(
        text, client_id, confidence_threshold
    )

    # Format similarities for display, highest score first.
    ranked = sorted(similarities.items(), key=lambda kv: kv[1], reverse=True)
    sim_display = "**Similarity Scores:**\n" + "".join(
        f"- {theme}: {score:.3f}\n" for theme, score in ranked
    )

    result = f"""
🎯 **Predicted Theme:** {pred_theme}
πŸ”₯ **Confidence:** {confidence:.3f}

{sim_display}
"""

    return result, ""
243
 
244
def benchmark_interface(csv_file, client_id: str):
    """Gradio callback: read the uploaded CSV (several upload shapes) and benchmark.

    Returns the (summary, results-file-path, visualization-html) triple from
    `classifier.benchmark_csv`, or an error triple on read failure.
    """
    if csv_file is None:
        return "Please upload a CSV file!", "", ""

    import os  # local import: module top-level does not import os

    try:
        # Handle file objects, tempfile wrappers, and plain path strings.
        if hasattr(csv_file, 'read'):
            csv_content = csv_file.read().decode('utf-8')
        elif hasattr(csv_file, 'name'):
            with open(csv_file.name, 'r', encoding='utf-8') as f:
                csv_content = f.read()
        elif isinstance(csv_file, str) and os.path.exists(csv_file):
            # FIX: newer Gradio versions pass the upload as a filesystem path
            # string; the previous str() fallback treated the path itself as
            # CSV content. Read the file at that path instead.
            with open(csv_file, 'r', encoding='utf-8') as f:
                csv_content = f.read()
        else:
            csv_content = str(csv_file)

        return classifier.benchmark_csv(csv_content, client_id)
    except Exception as e:
        error_details = traceback.format_exc()
        return f"❌ Error reading CSV: {str(e)}\n\nDetails:\n{error_details}", "", ""
262
 
263
# Create the Gradio interface
# Three workflow tabs (setup, single classification, CSV benchmark) plus an
# about page; all callbacks close over the module-level `classifier` singleton.
with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎯 Custom Themes Classification - MVP

    **A scalable, cost-effective solution for multi-client theme classification**

    This demo showcases an embedding-based approach that can:
    - βœ… Handle multiple clients with different themes
    - βœ… Distinguish between similar themes (e.g., "Real Estate Financing" vs "Personal Financing")
    - βœ… Process ~1M posts/day at low cost (~$500/month vs $30k/month for pure LLM)
    - βœ… Provide confidence scores and similarity breakdowns
    """)

    with gr.Tab("πŸš€ Setup & Model"):
        gr.Markdown("### Step 1: Load the Embedding Model")

        with gr.Row():
            model_input = gr.Textbox(
                label="HuggingFace Model Name",
                value="Qwen/Qwen3-Embedding-0.6B",
                placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2",
                info="Enter any SentenceTransformer-compatible model from HuggingFace"
            )
            load_btn = gr.Button("Load Model", variant="primary")

        load_status = gr.Textbox(label="Status", interactive=False)

        gr.Markdown("""
        **Popular Models:**
        - `Qwen/Qwen3-Embedding-0.6B` - High quality, multilingual
        - `sentence-transformers/all-MiniLM-L6-v2` - Fast, lightweight
        - `sentence-transformers/all-mpnet-base-v2` - High accuracy
        - `sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2` - Multilingual
        - `intfloat/multilingual-e5-base` - Strong multilingual performance
        """)

        load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)

        gr.Markdown("### Step 2: Add Themes for a Client")
        with gr.Row():
            client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
            themes_input = gr.Textbox(
                label="Themes (one per line)",
                lines=5,
                placeholder="e.g.:\nReal Estate Financing\nPersonal Financing\nPrivate Education\nSports"
            )

        add_themes_btn = gr.Button("Add Themes", variant="secondary")
        themes_status = gr.Textbox(label="Status", interactive=False)

        add_themes_btn.click(
            add_themes_interface,
            inputs=[client_input, themes_input],
            outputs=themes_status
        )

    with gr.Tab("πŸ” Single Text Classification"):
        gr.Markdown("### Classify Individual Posts")

        with gr.Row():
            with gr.Column():
                text_input = gr.Textbox(
                    label="Text to Classify",
                    lines=3,
                    placeholder="Enter text to classify..."
                )
                client_select = gr.Textbox(
                    label="Client ID",
                    placeholder="e.g., client_1"
                )
                confidence_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.3,
                    step=0.1,
                    label="Confidence Threshold"
                )
                classify_btn = gr.Button("Classify", variant="primary")

            with gr.Column():
                classification_result = gr.Markdown(label="Results")

        # NOTE(review): the second output is a throwaway hidden Textbox that
        # absorbs the callback's unused second return value.
        classify_btn.click(
            classify_interface,
            inputs=[text_input, client_select, confidence_slider],
            outputs=[classification_result, gr.Textbox(visible=False)]
        )

    with gr.Tab("πŸ“Š CSV Benchmarking"):
        gr.Markdown("""
        ### Benchmark on Your Dataset

        Upload a CSV file with columns:
        - `text`: The posts/content to classify
        - `real_tag`: The correct theme labels

        The system will automatically extract unique themes and evaluate performance.
        """)

        with gr.Row():
            with gr.Column():
                csv_upload = gr.File(
                    label="Upload CSV File",
                    file_types=[".csv"]
                )
                benchmark_client = gr.Textbox(
                    label="Client ID for Benchmark",
                    placeholder="e.g., benchmark_client"
                )
                benchmark_btn = gr.Button("Run Benchmark", variant="primary")

            with gr.Column():
                benchmark_results = gr.Markdown(label="Benchmark Results")

        with gr.Row():
            # benchmark_csv returns a temp-file path; gr.File serves it for download.
            results_csv = gr.File(label="Download Detailed Results", interactive=False)
            visualization = gr.HTML(label="Visualization")

        benchmark_btn.click(
            benchmark_interface,
            inputs=[csv_upload, benchmark_client],
            outputs=[benchmark_results, results_csv, visualization]
        )

    with gr.Tab("πŸ“‹ About & Usage"):
        gr.Markdown("""
        ## 🎯 Solution Overview

        This MVP demonstrates a **hybrid embedding-based approach** for Custom Themes classification:

        ### βœ… Key Advantages:
        1. **Cost Effective**: ~$500/month vs $30,000/month for pure LLM approach
        2. **Fast**: Can handle 1M+ posts/day with sub-second response times
        3. **Multi-Client**: Each client can have completely different themes
        4. **Disambiguates Similar Themes**: Uses semantic embeddings to distinguish between similar concepts
        5. **Confidence Scoring**: Provides transparency in predictions

        ### πŸ—οΈ Architecture:
        1. **Embedding Model**: Customizable SentenceTransformer models from HuggingFace
        2. **Theme Prototypes**: Each client's themes represented as embedding vectors
        3. **Similarity Matching**: Cosine similarity for classification
        4. **Confidence Thresholding**: Flags uncertain predictions

        ### πŸ“ˆ Scaling Strategy:
        - **Batch Processing**: Process thousands of posts simultaneously
        - **GPU Optimization**: Single GPU can handle 1M posts/day
        - **Caching**: Store client prototypes in memory/Redis
        - **Hybrid Fallback**: LLM backup for ambiguous cases (5-10% of posts)

        ### πŸ”§ Usage Instructions:
        1. **Setup Tab**: Load model and define client themes
        2. **Single Classification**: Test individual posts
        3. **CSV Benchmark**: Evaluate on your datasets

        ---

        **Scalable Theme Classification MVP**
        """)

# Launch the app
if __name__ == "__main__":
    demo.launch(share=True)