Mohaddz committed on
Commit
40a7521
·
verified ·
1 Parent(s): 2403d73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -48
app.py CHANGED
@@ -9,6 +9,7 @@ import plotly.express as px
9
  import plotly.graph_objects as go
10
  from collections import defaultdict
11
  import json
 
12
 
13
  class MultiClientThemeClassifier:
14
  def __init__(self):
@@ -20,7 +21,7 @@ class MultiClientThemeClassifier:
20
  """Load the embedding model"""
21
  if not self.model_loaded:
22
  try:
23
- # Using a smaller, faster model for demo
24
  self.model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')
25
  self.model_loaded = True
26
  return "✅ Model loaded successfully!"
@@ -88,47 +89,63 @@ class MultiClientThemeClassifier:
88
  return "❌ Model not loaded!", "", ""
89
 
90
  try:
 
 
91
  # Read CSV
92
  df = pd.read_csv(io.StringIO(csv_content))
 
 
93
 
94
  # Validate CSV format
95
  if 'text' not in df.columns or 'real_tag' not in df.columns:
96
  return "❌ CSV must have 'text' and 'real_tag' columns!", "", ""
97
 
 
 
 
 
 
 
98
  # Get unique themes from CSV
99
  unique_themes = df['real_tag'].unique().tolist()
 
100
 
101
  # Add themes for this client (using theme names as prototypes for demo)
102
- self.add_client_themes(client_id, unique_themes)
 
103
 
104
- # Classify all texts
105
  predictions = []
106
  confidences = []
107
 
108
- for _, row in df.iterrows():
109
- pred_theme, confidence, _ = self.classify_text(row['text'], client_id)
110
- predictions.append(pred_theme)
111
- confidences.append(confidence)
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
  df['predicted_tag'] = predictions
114
  df['confidence'] = confidences
115
 
116
  # Calculate metrics
117
- correct = (df['real_tag'] == df['predicted_tag']).sum()
118
- total = len(df)
119
- accuracy = correct / total
 
120
 
121
- # Create confusion matrix data
122
- confusion_data = []
123
- for real_tag in unique_themes:
124
- for pred_tag in unique_themes + ['UNKNOWN_THEME']:
125
- count = len(df[(df['real_tag'] == real_tag) & (df['predicted_tag'] == pred_tag)])
126
- if count > 0:
127
- confusion_data.append({
128
- 'Real': real_tag,
129
- 'Predicted': pred_tag,
130
- 'Count': count
131
- })
132
 
133
  # Generate results summary
134
  results_summary = f"""
@@ -138,32 +155,41 @@ class MultiClientThemeClassifier:
138
  - Total samples: {total}
139
  - Correct predictions: {correct}
140
  - **Accuracy: {accuracy:.2%}**
141
- - Average confidence: {np.mean(confidences):.3f}
142
 
143
  **Per-Theme Breakdown:**
144
  """
145
 
146
  for theme in unique_themes:
147
- theme_df = df[df['real_tag'] == theme]
148
- theme_correct = (theme_df['real_tag'] == theme_df['predicted_tag']).sum()
149
- theme_total = len(theme_df)
150
- theme_acc = theme_correct / theme_total if theme_total > 0 else 0
151
- avg_conf = theme_df['confidence'].mean()
152
-
153
- results_summary += f"- **{theme}**: {theme_acc:.2%} ({theme_correct}/{theme_total}) - Avg conf: {avg_conf:.3f}\n"
 
154
 
155
- # Create visualization
156
- fig = px.bar(
157
- x=unique_themes,
158
- y=[len(df[df['real_tag'] == theme]) for theme in unique_themes],
159
- title="Theme Distribution in Dataset",
160
- labels={'x': 'Themes', 'y': 'Count'}
161
- )
 
 
 
 
 
 
162
 
163
- return results_summary, df.to_csv(index=False), fig.to_html()
164
 
165
  except Exception as e:
166
- return f"❌ Error during benchmarking: {str(e)}", "", ""
 
 
167
 
168
  # Initialize the classifier
169
  classifier = MultiClientThemeClassifier()
@@ -204,21 +230,30 @@ def benchmark_interface(csv_file, client_id: str):
204
  return "Please upload a CSV file!", "", ""
205
 
206
  try:
207
- csv_content = csv_file.decode('utf-8')
 
 
 
 
 
 
 
 
208
  return classifier.benchmark_csv(csv_content, client_id)
209
  except Exception as e:
210
- return f"❌ Error reading CSV: {str(e)}", "", ""
 
211
 
212
  # Create the Gradio interface
213
- with gr.Blocks(title="Company Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
214
  gr.Markdown("""
215
- # 🎯 Company Custom Themes Classification - MVP
216
 
217
  **A scalable, cost-effective solution for multi-client theme classification**
218
 
219
  This demo showcases an embedding-based approach that can:
220
  - ✅ Handle multiple clients with different themes
221
- - ✅ Distinguish between similar themes (e.g., "التمويل العقاري" vs "التمويل الشخصي")
222
  - ✅ Process ~1M posts/day at low cost (~$500/month vs $30k/month for pure LLM)
223
  - ✅ Provide confidence scores and similarity breakdowns
224
  """)
@@ -236,7 +271,7 @@ with gr.Blocks(title="Company Custom Themes Classification MVP", theme=gr.themes
236
  themes_input = gr.Textbox(
237
  label="Themes (one per line)",
238
  lines=5,
239
- placeholder="e.g.:\nالتمويل العقاري\nالتمويل الشخصي\nالتعليم الأهلي\nالرياضة"
240
  )
241
 
242
  add_themes_btn = gr.Button("Add Themes", variant="secondary")
@@ -256,7 +291,7 @@ with gr.Blocks(title="Company Custom Themes Classification MVP", theme=gr.themes
256
  text_input = gr.Textbox(
257
  label="Text to Classify",
258
  lines=3,
259
- placeholder="Enter Arabic or English text..."
260
  )
261
  client_select = gr.Textbox(
262
  label="Client ID",
@@ -320,7 +355,7 @@ with gr.Blocks(title="Company Custom Themes Classification MVP", theme=gr.themes
320
  gr.Markdown("""
321
  ## 🎯 Solution Overview
322
 
323
- This MVP demonstrates a **hybrid embedding-based approach** for Company's Custom Themes feature:
324
 
325
  ### ✅ Key Advantages:
326
  1. **Cost Effective**: ~$500/month vs $30,000/month for pure LLM approach
@@ -330,7 +365,7 @@ with gr.Blocks(title="Company Custom Themes Classification MVP", theme=gr.themes
330
  5. **Confidence Scoring**: Provides transparency in predictions
331
 
332
  ### 🏗️ Architecture:
333
- 1. **Embedding Model**: SentenceTransformers for semantic understanding
334
  2. **Theme Prototypes**: Each client's themes represented as embedding vectors
335
  3. **Similarity Matching**: Cosine similarity for classification
336
  4. **Confidence Thresholding**: Flags uncertain predictions
@@ -348,7 +383,7 @@ with gr.Blocks(title="Company Custom Themes Classification MVP", theme=gr.themes
348
 
349
  ---
350
 
351
- **Built for Company's Case Study | Scalable Theme Classification MVP**
352
  """)
353
 
354
  # Launch the app
 
9
  import plotly.graph_objects as go
10
  from collections import defaultdict
11
  import json
12
+ import traceback
13
 
14
  class MultiClientThemeClassifier:
15
  def __init__(self):
 
21
  """Load the embedding model"""
22
  if not self.model_loaded:
23
  try:
24
+ # Using Qwen embedding model
25
  self.model = SentenceTransformer('Qwen/Qwen3-Embedding-0.6B')
26
  self.model_loaded = True
27
  return "✅ Model loaded successfully!"
 
89
  return "❌ Model not loaded!", "", ""
90
 
91
  try:
92
+ print("Starting CSV benchmark...")
93
+
94
  # Read CSV
95
  df = pd.read_csv(io.StringIO(csv_content))
96
+ print(f"CSV loaded with shape: {df.shape}")
97
+ print(f"CSV columns: {df.columns.tolist()}")
98
 
99
  # Validate CSV format
100
  if 'text' not in df.columns or 'real_tag' not in df.columns:
101
  return "❌ CSV must have 'text' and 'real_tag' columns!", "", ""
102
 
103
+ # Clean data
104
+ df = df.dropna(subset=['text', 'real_tag'])
105
+ df['text'] = df['text'].astype(str)
106
+ df['real_tag'] = df['real_tag'].astype(str)
107
+ print(f"After cleaning: {df.shape}")
108
+
109
  # Get unique themes from CSV
110
  unique_themes = df['real_tag'].unique().tolist()
111
+ print(f"Unique themes found: {len(unique_themes)} - {unique_themes}")
112
 
113
  # Add themes for this client (using theme names as prototypes for demo)
114
+ theme_add_result = self.add_client_themes(client_id, unique_themes)
115
+ print(f"Theme addition result: {theme_add_result}")
116
 
117
+ # Classify all texts with progress
118
  predictions = []
119
  confidences = []
120
 
121
+ print("Starting classification...")
122
+ for idx, row in df.iterrows():
123
+ try:
124
+ text = str(row['text'])[:500] # Limit text length
125
+ pred_theme, confidence, _ = self.classify_text(text, client_id)
126
+ predictions.append(pred_theme)
127
+ confidences.append(confidence)
128
+
129
+ if idx % 10 == 0: # Progress logging
130
+ print(f"Processed {idx + 1}/{len(df)} samples")
131
+
132
+ except Exception as e:
133
+ print(f"Error classifying row {idx}: {str(e)}")
134
+ predictions.append("ERROR")
135
+ confidences.append(0.0)
136
+
137
+ print("Classification complete!")
138
 
139
  df['predicted_tag'] = predictions
140
  df['confidence'] = confidences
141
 
142
  # Calculate metrics
143
+ valid_predictions = df[df['predicted_tag'] != 'ERROR']
144
+ correct = (valid_predictions['real_tag'] == valid_predictions['predicted_tag']).sum()
145
+ total = len(valid_predictions)
146
+ accuracy = correct / total if total > 0 else 0
147
 
148
+ print(f"Metrics calculated: {correct}/{total} = {accuracy:.2%}")
 
 
 
 
 
 
 
 
 
 
149
 
150
  # Generate results summary
151
  results_summary = f"""
 
155
  - Total samples: {total}
156
  - Correct predictions: {correct}
157
  - **Accuracy: {accuracy:.2%}**
158
+ - Average confidence: {np.mean([c for c in confidences if c > 0]):.3f}
159
 
160
  **Per-Theme Breakdown:**
161
  """
162
 
163
  for theme in unique_themes:
164
+ theme_df = valid_predictions[valid_predictions['real_tag'] == theme]
165
+ if len(theme_df) > 0:
166
+ theme_correct = (theme_df['real_tag'] == theme_df['predicted_tag']).sum()
167
+ theme_total = len(theme_df)
168
+ theme_acc = theme_correct / theme_total if theme_total > 0 else 0
169
+ avg_conf = theme_df['confidence'].mean()
170
+
171
+ results_summary += f"- **{theme}**: {theme_acc:.2%} ({theme_correct}/{theme_total}) - Avg conf: {avg_conf:.3f}\n"
172
 
173
+ # Create simple visualization
174
+ try:
175
+ theme_counts = [len(df[df['real_tag'] == theme]) for theme in unique_themes]
176
+ fig = px.bar(
177
+ x=unique_themes,
178
+ y=theme_counts,
179
+ title="Theme Distribution in Dataset",
180
+ labels={'x': 'Themes', 'y': 'Count'}
181
+ )
182
+ visualization_html = fig.to_html()
183
+ except Exception as viz_error:
184
+ print(f"Visualization error: {viz_error}")
185
+ visualization_html = "<p>Visualization error occurred</p>"
186
 
187
+ return results_summary, df.to_csv(index=False), visualization_html
188
 
189
  except Exception as e:
190
+ error_details = traceback.format_exc()
191
+ print(f"Full error: {error_details}")
192
+ return f"❌ Error during benchmarking: {str(e)}\n\nFull traceback:\n{error_details}", "", ""
193
 
194
  # Initialize the classifier
195
  classifier = MultiClientThemeClassifier()
 
230
  return "Please upload a CSV file!", "", ""
231
 
232
  try:
233
+ # Handle both file objects and file paths
234
+ if hasattr(csv_file, 'read'):
235
+ csv_content = csv_file.read().decode('utf-8')
236
+ elif hasattr(csv_file, 'name'):
237
+ with open(csv_file.name, 'r', encoding='utf-8') as f:
238
+ csv_content = f.read()
239
+ else:
240
+ csv_content = str(csv_file)
241
+
242
  return classifier.benchmark_csv(csv_content, client_id)
243
  except Exception as e:
244
+ error_details = traceback.format_exc()
245
+ return f"❌ Error reading CSV: {str(e)}\n\nDetails:\n{error_details}", "", ""
246
 
247
  # Create the Gradio interface
248
+ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
249
  gr.Markdown("""
250
+ # 🎯 Custom Themes Classification - MVP
251
 
252
  **A scalable, cost-effective solution for multi-client theme classification**
253
 
254
  This demo showcases an embedding-based approach that can:
255
  - ✅ Handle multiple clients with different themes
256
+ - ✅ Distinguish between similar themes (e.g., "Real Estate Financing" vs "Personal Financing")
257
  - ✅ Process ~1M posts/day at low cost (~$500/month vs $30k/month for pure LLM)
258
  - ✅ Provide confidence scores and similarity breakdowns
259
  """)
 
271
  themes_input = gr.Textbox(
272
  label="Themes (one per line)",
273
  lines=5,
274
+ placeholder="e.g.:\nReal Estate Financing\nPersonal Financing\nPrivate Education\nSports"
275
  )
276
 
277
  add_themes_btn = gr.Button("Add Themes", variant="secondary")
 
291
  text_input = gr.Textbox(
292
  label="Text to Classify",
293
  lines=3,
294
+ placeholder="Enter text to classify..."
295
  )
296
  client_select = gr.Textbox(
297
  label="Client ID",
 
355
  gr.Markdown("""
356
  ## 🎯 Solution Overview
357
 
358
+ This MVP demonstrates a **hybrid embedding-based approach** for Custom Themes classification:
359
 
360
  ### ✅ Key Advantages:
361
  1. **Cost Effective**: ~$500/month vs $30,000/month for pure LLM approach
 
365
  5. **Confidence Scoring**: Provides transparency in predictions
366
 
367
  ### 🏗️ Architecture:
368
+ 1. **Embedding Model**: Qwen/Qwen3-Embedding-0.6B for semantic understanding
369
  2. **Theme Prototypes**: Each client's themes represented as embedding vectors
370
  3. **Similarity Matching**: Cosine similarity for classification
371
  4. **Confidence Thresholding**: Flags uncertain predictions
 
383
 
384
  ---
385
 
386
+ **Scalable Theme Classification MVP**
387
  """)
388
 
389
  # Launch the app