Mohaddz committed on
Commit
2c7390f
·
verified ·
1 Parent(s): 5a8e848

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -13
app.py CHANGED
@@ -26,7 +26,6 @@ class MultiClientThemeClassifier:
26
  model_name = self.default_model
27
 
28
  try:
29
- # Avoid reloading the same model
30
  if self.model_loaded and hasattr(self.model, 'tokenizer') and self.model.tokenizer.name_or_path == model_name:
31
  return f"βœ… Model '{model_name}' is already loaded."
32
 
@@ -92,17 +91,18 @@ class MultiClientThemeClassifier:
92
  return f"Error: {str(e)}", 0.0, {}
93
 
94
  def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
95
- """Benchmark the model on a CSV file"""
96
  error_status = self._ensure_model_is_loaded()
97
  if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
98
 
99
  try:
100
- # CORRECTED: Use encoding 'utf-8-sig' to handle the invisible BOM character
101
- df = pd.read_csv(io.StringIO(csv_content), encoding='utf-8-sig')
102
 
 
103
  if 'text' not in df.columns or 'real_tag' not in df.columns:
104
- return "❌ CSV must have 'text' and 'real_tag' columns!", None, None
105
-
106
  df.dropna(subset=['text', 'real_tag'], inplace=True)
107
  df['text'] = df['text'].astype(str)
108
  df['real_tag'] = df['real_tag'].astype(str)
@@ -121,11 +121,9 @@ class MultiClientThemeClassifier:
121
 
122
  results_summary = f"πŸ“Š **Benchmarking Results**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
123
 
124
- # Create visualization
125
- fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution in Dataset", labels={'index': 'Theme', 'value': 'Count'})
126
  visualization_html = fig.to_html()
127
 
128
- # Save results to a temporary file for download
129
  temp_file_path = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8-sig').name
130
  df.to_csv(temp_file_path, index=False)
131
 
@@ -164,23 +162,28 @@ def benchmark_interface(csv_file, client_id: str):
164
  if csv_file is None:
165
  return "Please upload a CSV file!", None, None
166
  try:
 
167
  if hasattr(csv_file, 'read'):
168
- csv_content = csv_file.read().decode('utf-8')
 
169
  else:
170
- csv_content = csv_file
 
 
171
 
 
172
  return classifier.benchmark_csv(csv_content, client_id)
173
  except Exception as e:
174
  error_details = traceback.format_exc()
175
  return f"❌ Error processing CSV file: {str(e)}\n\nDetails:\n{error_details}", None, None
176
 
177
- # --- Gradio Interface (No Changes Below) ---
178
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
179
  gr.Markdown("# 🎯 Custom Themes Classification - MVP")
180
 
181
  with gr.Tab("πŸš€ Setup & Model"):
182
  gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
183
- gr.Markdown("If you don't load a model, a default one (`Qwen/Qwen3-Embedding-0.6B`) will be loaded automatically on first use.")
184
  with gr.Row():
185
  model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
186
  load_btn = gr.Button("Load Model", variant="primary")
 
26
  model_name = self.default_model
27
 
28
  try:
 
29
  if self.model_loaded and hasattr(self.model, 'tokenizer') and self.model.tokenizer.name_or_path == model_name:
30
  return f"βœ… Model '{model_name}' is already loaded."
31
 
 
91
  return f"Error: {str(e)}", 0.0, {}
92
 
93
  def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
94
+ """Benchmark the model on a CSV file. Assumes csv_content is a clean string."""
95
  error_status = self._ensure_model_is_loaded()
96
  if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
97
 
98
  try:
99
+ # The string is now clean, so no special encoding is needed here.
100
+ df = pd.read_csv(io.StringIO(csv_content))
101
 
102
+ # Check for columns after reading
103
  if 'text' not in df.columns or 'real_tag' not in df.columns:
104
+ return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
105
+
106
  df.dropna(subset=['text', 'real_tag'], inplace=True)
107
  df['text'] = df['text'].astype(str)
108
  df['real_tag'] = df['real_tag'].astype(str)
 
121
 
122
  results_summary = f"πŸ“Š **Benchmarking Results**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
123
 
124
+ fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution", labels={'index': 'Theme', 'value': 'Count'})
 
125
  visualization_html = fig.to_html()
126
 
 
127
  temp_file_path = tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False, encoding='utf-8-sig').name
128
  df.to_csv(temp_file_path, index=False)
129
 
 
162
  if csv_file is None:
163
  return "Please upload a CSV file!", None, None
164
  try:
165
+ # CORRECTED AND FINAL FIX: Handle the BOM at the point of file reading.
166
  if hasattr(csv_file, 'read'):
167
+ # It's a file-like object (TemporaryFile), read its bytes and decode with utf-8-sig
168
+ csv_content = csv_file.read().decode('utf-8-sig')
169
  else:
170
+ # It's a string (NamedString), which was likely decoded with 'utf-8'.
171
+ # Manually remove the BOM if it exists.
172
+ csv_content = str(csv_file).lstrip('\ufeff')
173
 
174
+ # Now, pass the clean string to the benchmark function
175
  return classifier.benchmark_csv(csv_content, client_id)
176
  except Exception as e:
177
  error_details = traceback.format_exc()
178
  return f"❌ Error processing CSV file: {str(e)}\n\nDetails:\n{error_details}", None, None
179
 
180
+ # --- Gradio Interface ---
181
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
182
  gr.Markdown("# 🎯 Custom Themes Classification - MVP")
183
 
184
  with gr.Tab("πŸš€ Setup & Model"):
185
  gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
186
+ gr.Markdown("A default model (`Qwen/Qwen3-Embedding-0.6B`) will load automatically on first use.")
187
  with gr.Row():
188
  model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
189
  load_btn = gr.Button("Load Model", variant="primary")