Mohaddz committed on
Commit
7b53477
·
verified ·
1 Parent(s): 92204cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -127
app.py CHANGED
@@ -11,6 +11,7 @@ from collections import defaultdict
11
  import json
12
  import traceback
13
  import spaces # Import the spaces library
 
14
 
15
  class MultiClientThemeClassifier:
16
  def __init__(self):
@@ -25,7 +26,8 @@ class MultiClientThemeClassifier:
25
  model_name = self.default_model
26
 
27
  try:
28
- if self.model_loaded and self.model.name_or_path == model_name:
 
29
  return f"✅ Model '{model_name}' is already loaded."
30
 
31
  self.model = None
@@ -52,10 +54,8 @@ class MultiClientThemeClassifier:
52
 
53
  def add_client_themes(self, client_id: str, themes: List[str], examples_per_theme: Dict[str, List[str]] = None):
54
  """Add themes for a specific client"""
55
- # Automatically load model if needed
56
  error_status = self._ensure_model_is_loaded()
57
- if error_status:
58
- return error_status
59
 
60
  try:
61
  self.client_themes[client_id] = {}
@@ -68,10 +68,8 @@ class MultiClientThemeClassifier:
68
 
69
  def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
70
  """Classify a single text for a specific client"""
71
- # Automatically load model if needed
72
  error_status = self._ensure_model_is_loaded()
73
- if error_status:
74
- return f"Error: {error_status}", 0.0, {}
75
 
76
  if client_id not in self.client_themes:
77
  return "Client not found", 0.0, {}
@@ -81,8 +79,7 @@ class MultiClientThemeClassifier:
81
  similarities = {theme: util.cos_sim(text_embedding, prototype).item()
82
  for theme, prototype in self.client_themes[client_id].items()}
83
 
84
- if not similarities:
85
- return "No themes for client", 0.0, {}
86
 
87
  best_theme = max(similarities, key=similarities.get)
88
  best_score = similarities[best_theme]
@@ -96,10 +93,8 @@ class MultiClientThemeClassifier:
96
 
97
  def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
98
  """Benchmark the model on a CSV file"""
99
- # Automatically load model if needed
100
  error_status = self._ensure_model_is_loaded()
101
- if error_status:
102
- return f"❌ Model could not be loaded: {error_status}", None, None
103
 
104
  try:
105
  df = pd.read_csv(io.StringIO(csv_content))
@@ -124,10 +119,15 @@ class MultiClientThemeClassifier:
124
 
125
  results_summary = f"📊 **Benchmarking Results**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
126
 
127
- with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8') as temp_file:
128
- df.to_csv(temp_file.name, index=False)
129
- fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution")
130
- return results_summary, temp_file.name, fig.to_html()
 
 
 
 
 
131
 
132
  except Exception as e:
133
  error_details = traceback.format_exc()
@@ -142,15 +142,13 @@ def load_model_interface(model_name: str):
142
 
143
  @spaces.GPU
144
  def add_themes_interface(client_id: str, themes_text: str):
145
- if not themes_text.strip():
146
- return "❌ Please enter themes!"
147
  themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
148
  return classifier.add_client_themes(client_id, themes)
149
 
150
  @spaces.GPU
151
  def classify_interface(text: str, client_id: str, confidence_threshold: float):
152
- if not text.strip():
153
- return "Please enter text to classify!", ""
154
 
155
  pred_theme, confidence, similarities = classifier.classify_text(text, client_id, confidence_threshold)
156
 
@@ -164,150 +162,66 @@ def benchmark_interface(csv_file, client_id: str):
164
  if csv_file is None:
165
  return "Please upload a CSV file!", None, None
166
  try:
167
- csv_content = csv_file.read().decode('utf-8')
 
 
 
 
 
 
 
168
  return classifier.benchmark_csv(csv_content, client_id)
169
  except Exception as e:
170
- return f"❌ Error reading CSV: {str(e)}", None, None
 
171
 
172
  # --- Gradio Interface (No Changes Below) ---
173
- # Create the Gradio interface
174
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
175
- gr.Markdown("""
176
- # 🎯 Custom Themes Classification - MVP
177
-
178
- **A scalable, cost-effective solution for multi-client theme classification**
179
-
180
- This demo showcases an embedding-based approach that can:
181
- - ✅ Handle multiple clients with different themes
182
- - ✅ Distinguish between similar themes (e.g., "Real Estate Financing" vs "Personal Financing")
183
- - ✅ Process ~1M posts/day at low cost (~$500/month vs $30k/month for pure LLM)
184
- - ✅ Provide confidence scores and similarity breakdowns
185
- """)
186
 
187
  with gr.Tab("🚀 Setup & Model"):
188
  gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
189
- gr.Markdown("If you don't load a model, a default one will be loaded automatically on first use.")
190
-
191
  with gr.Row():
192
- model_input = gr.Textbox(
193
- label="HuggingFace Model Name",
194
- value="Qwen/Qwen3-Embedding-0.6B",
195
- placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2",
196
- info="Enter any SentenceTransformer-compatible model from HuggingFace"
197
- )
198
  load_btn = gr.Button("Load Model", variant="primary")
199
-
200
  load_status = gr.Textbox(label="Status", interactive=False)
201
-
202
- gr.Markdown("""
203
- **Popular Models:**
204
- - `Qwen/Qwen3-Embedding-0.6B` - High quality, multilingual
205
- - `sentence-transformers/all-MiniLM-L6-v2` - Fast, lightweight
206
- - `sentence-transformers/all-mpnet-base-v2` - High accuracy
207
- """)
208
-
209
  load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)
210
 
211
  gr.Markdown("### Step 2: Add Themes for a Client")
212
  with gr.Row():
213
  client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
214
- themes_input = gr.Textbox(
215
- label="Themes (one per line)",
216
- lines=5,
217
- placeholder="e.g.:\nReal Estate Financing\nPersonal Financing\nPrivate Education\nSports"
218
- )
219
-
220
  add_themes_btn = gr.Button("Add Themes", variant="secondary")
221
  themes_status = gr.Textbox(label="Status", interactive=False)
222
-
223
- add_themes_btn.click(
224
- add_themes_interface,
225
- inputs=[client_input, themes_input],
226
- outputs=themes_status
227
- )
228
 
229
  with gr.Tab("🔍 Single Text Classification"):
230
  gr.Markdown("### Classify Individual Posts")
231
-
232
  with gr.Row():
233
  with gr.Column():
234
- text_input = gr.Textbox(
235
- label="Text to Classify",
236
- lines=3,
237
- placeholder="Enter text to classify..."
238
- )
239
- client_select = gr.Textbox(
240
- label="Client ID",
241
- placeholder="e.g., client_1"
242
- )
243
- confidence_slider = gr.Slider(
244
- minimum=0.0,
245
- maximum=1.0,
246
- value=0.3,
247
- step=0.1,
248
- label="Confidence Threshold"
249
- )
250
  classify_btn = gr.Button("Classify", variant="primary")
251
-
252
  with gr.Column():
253
  classification_result = gr.Markdown(label="Results")
254
-
255
- classify_btn.click(
256
- classify_interface,
257
- inputs=[text_input, client_select, confidence_slider],
258
- outputs=[classification_result, gr.Textbox(visible=False)]
259
- )
260
 
261
  with gr.Tab("📊 CSV Benchmarking"):
262
- gr.Markdown("""
263
- ### Benchmark on Your Dataset
264
-
265
- Upload a CSV file with columns:
266
- - `text`: The posts/content to classify
267
- - `real_tag`: The correct theme labels
268
-
269
- The system will automatically extract unique themes and evaluate performance.
270
- """)
271
-
272
  with gr.Row():
273
  with gr.Column():
274
- csv_upload = gr.File(
275
- label="Upload CSV File",
276
- file_types=[".csv"]
277
- )
278
- benchmark_client = gr.Textbox(
279
- label="Client ID for Benchmark",
280
- placeholder="e.g., benchmark_client"
281
- )
282
  benchmark_btn = gr.Button("Run Benchmark", variant="primary")
283
-
284
  with gr.Column():
285
  benchmark_results = gr.Markdown(label="Benchmark Results")
286
-
287
  with gr.Row():
288
  results_csv = gr.File(label="Download Detailed Results", interactive=False)
289
  visualization = gr.HTML(label="Visualization")
290
-
291
- benchmark_btn.click(
292
- benchmark_interface,
293
- inputs=[csv_upload, benchmark_client],
294
- outputs=[benchmark_results, results_csv, visualization]
295
- )
296
-
297
- with gr.Tab("📋 About & Usage"):
298
- gr.Markdown("""
299
- ## 🎯 Solution Overview
300
-
301
- This MVP demonstrates a **hybrid embedding-based approach** for Custom Themes classification.
302
-
303
- ### 🏗️ Architecture:
304
- 1. **Embedding Model**: Customizable SentenceTransformer models from HuggingFace
305
- 2. **Theme Prototypes**: Each client's themes represented as embedding vectors
306
- 3. **Similarity Matching**: Cosine similarity for classification
307
- 4. **Automatic Loading**: The application will automatically load a default model if one is not present, making it resilient to platform hibernation.
308
- """)
309
 
310
  # Launch the app
311
  if __name__ == "__main__":
312
- import tempfile
313
  demo.launch(share=True)
 
11
  import json
12
  import traceback
13
  import spaces # Import the spaces library
14
+ import tempfile
15
 
16
  class MultiClientThemeClassifier:
17
  def __init__(self):
 
26
  model_name = self.default_model
27
 
28
  try:
29
+ # Avoid reloading the same model
30
+ if self.model_loaded and hasattr(self.model, 'tokenizer') and self.model.tokenizer.name_or_path == model_name:
31
  return f"✅ Model '{model_name}' is already loaded."
32
 
33
  self.model = None
 
54
 
55
  def add_client_themes(self, client_id: str, themes: List[str], examples_per_theme: Dict[str, List[str]] = None):
56
  """Add themes for a specific client"""
 
57
  error_status = self._ensure_model_is_loaded()
58
+ if error_status: return error_status
 
59
 
60
  try:
61
  self.client_themes[client_id] = {}
 
68
 
69
  def classify_text(self, text: str, client_id: str, confidence_threshold: float = 0.3) -> Tuple[str, float, Dict[str, float]]:
70
  """Classify a single text for a specific client"""
 
71
  error_status = self._ensure_model_is_loaded()
72
+ if error_status: return f"Error: {error_status}", 0.0, {}
 
73
 
74
  if client_id not in self.client_themes:
75
  return "Client not found", 0.0, {}
 
79
  similarities = {theme: util.cos_sim(text_embedding, prototype).item()
80
  for theme, prototype in self.client_themes[client_id].items()}
81
 
82
+ if not similarities: return "No themes for client", 0.0, {}
 
83
 
84
  best_theme = max(similarities, key=similarities.get)
85
  best_score = similarities[best_theme]
 
93
 
94
  def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
95
  """Benchmark the model on a CSV file"""
 
96
  error_status = self._ensure_model_is_loaded()
97
+ if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
 
98
 
99
  try:
100
  df = pd.read_csv(io.StringIO(csv_content))
 
119
 
120
  results_summary = f"📊 **Benchmarking Results**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
121
 
122
+ # Create visualization
123
+ fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution in Dataset", labels={'index': 'Theme', 'value': 'Count'})
124
+ visualization_html = fig.to_html()
125
+
126
+ # Save results to a temporary file for download
127
+ temp_file_path = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8').name
128
+ df.to_csv(temp_file_path, index=False)
129
+
130
+ return results_summary, temp_file_path, visualization_html
131
 
132
  except Exception as e:
133
  error_details = traceback.format_exc()
 
142
 
143
  @spaces.GPU
144
  def add_themes_interface(client_id: str, themes_text: str):
145
+ if not themes_text.strip(): return "❌ Please enter themes!"
 
146
  themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
147
  return classifier.add_client_themes(client_id, themes)
148
 
149
  @spaces.GPU
150
  def classify_interface(text: str, client_id: str, confidence_threshold: float):
151
+ if not text.strip(): return "Please enter text to classify!", ""
 
152
 
153
  pred_theme, confidence, similarities = classifier.classify_text(text, client_id, confidence_threshold)
154
 
 
162
  if csv_file is None:
163
  return "Please upload a CSV file!", None, None
164
  try:
165
+ # CORRECTED: Handle both file-like objects and string/NamedString objects from Gradio
166
+ if hasattr(csv_file, 'read'):
167
+ # It's a file-like object, read and decode it
168
+ csv_content = csv_file.read().decode('utf-8')
169
+ else:
170
+ # It's a string or NamedString, use it directly
171
+ csv_content = csv_file
172
+
173
  return classifier.benchmark_csv(csv_content, client_id)
174
  except Exception as e:
175
+ error_details = traceback.format_exc()
176
+ return f"❌ Error processing CSV file: {str(e)}\n\nDetails:\n{error_details}", None, None
177
 
178
  # --- Gradio Interface (No Changes Below) ---
 
179
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
180
+ gr.Markdown("# 🎯 Custom Themes Classification - MVP")
 
 
 
 
 
 
 
 
 
 
181
 
182
  with gr.Tab("🚀 Setup & Model"):
183
  gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
184
+ gr.Markdown("If you don't load a model, a default one (`Qwen/Qwen3-Embedding-0.6B`) will be loaded automatically on first use.")
 
185
  with gr.Row():
186
+ model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
 
 
 
 
 
187
  load_btn = gr.Button("Load Model", variant="primary")
 
188
  load_status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
 
 
 
189
  load_btn.click(load_model_interface, inputs=[model_input], outputs=load_status)
190
 
191
  gr.Markdown("### Step 2: Add Themes for a Client")
192
  with gr.Row():
193
  client_input = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
194
+ themes_input = gr.Textbox(label="Themes (one per line)", lines=5)
 
 
 
 
 
195
  add_themes_btn = gr.Button("Add Themes", variant="secondary")
196
  themes_status = gr.Textbox(label="Status", interactive=False)
197
+ add_themes_btn.click(add_themes_interface, inputs=[client_input, themes_input], outputs=themes_status)
 
 
 
 
 
198
 
199
  with gr.Tab("🔍 Single Text Classification"):
200
  gr.Markdown("### Classify Individual Posts")
 
201
  with gr.Row():
202
  with gr.Column():
203
+ text_input = gr.Textbox(label="Text to Classify", lines=3)
204
+ client_select = gr.Textbox(label="Client ID", placeholder="e.g., client_1")
205
+ confidence_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.1, label="Confidence Threshold")
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  classify_btn = gr.Button("Classify", variant="primary")
 
207
  with gr.Column():
208
  classification_result = gr.Markdown(label="Results")
209
+ classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
 
 
 
 
 
210
 
211
  with gr.Tab("📊 CSV Benchmarking"):
212
+ gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns.")
 
 
 
 
 
 
 
 
 
213
  with gr.Row():
214
  with gr.Column():
215
+ csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
216
+ benchmark_client = gr.Textbox(label="Client ID for Benchmark", placeholder="e.g., benchmark_client")
 
 
 
 
 
 
217
  benchmark_btn = gr.Button("Run Benchmark", variant="primary")
 
218
  with gr.Column():
219
  benchmark_results = gr.Markdown(label="Benchmark Results")
 
220
  with gr.Row():
221
  results_csv = gr.File(label="Download Detailed Results", interactive=False)
222
  visualization = gr.HTML(label="Visualization")
223
+ benchmark_btn.click(benchmark_interface, inputs=[csv_upload, benchmark_client], outputs=[benchmark_results, results_csv, visualization])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  # Launch the app
226
  if __name__ == "__main__":
 
227
  demo.launch(share=True)