Mohaddz committed on
Commit
1bb89fc
·
verified ·
1 Parent(s): 2c7390f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -20
app.py CHANGED
@@ -90,16 +90,15 @@ class MultiClientThemeClassifier:
90
  except Exception as e:
91
  return f"Error: {str(e)}", 0.0, {}
92
 
93
- def benchmark_csv(self, csv_content: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
94
- """Benchmark the model on a CSV file. Assumes csv_content is a clean string."""
95
  error_status = self._ensure_model_is_loaded()
96
  if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
97
 
98
  try:
99
- # The string is now clean, so no special encoding is needed here.
100
- df = pd.read_csv(io.StringIO(csv_content))
101
 
102
- # Check for columns after reading
103
  if 'text' not in df.columns or 'real_tag' not in df.columns:
104
  return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
105
 
@@ -110,7 +109,8 @@ class MultiClientThemeClassifier:
110
  unique_themes = df['real_tag'].unique().tolist()
111
  self.add_client_themes(client_id, unique_themes)
112
 
113
- results = [self.classify_text(str(row['text'])[:500], client_id) for _, row in df.iterrows()]
 
114
 
115
  df['predicted_tag'] = [res[0] for res in results]
116
  df['confidence'] = [res[1] for res in results]
@@ -158,24 +158,21 @@ def classify_interface(text: str, client_id: str, confidence_threshold: float):
158
  return result, ""
159
 
160
  @spaces.GPU(duration=300)
161
- def benchmark_interface(csv_file, client_id: str):
162
- if csv_file is None:
 
 
 
163
  return "Please upload a CSV file!", None, None
164
  try:
165
- # CORRECTED AND FINAL FIX: Handle the BOM at the point of file reading.
166
- if hasattr(csv_file, 'read'):
167
- # It's a file-like object (TemporaryFile), read its bytes and decode with utf-8-sig
168
- csv_content = csv_file.read().decode('utf-8-sig')
169
- else:
170
- # It's a string (NamedString), which was likely decoded with 'utf-8'.
171
- # Manually remove the BOM if it exists.
172
- csv_content = str(csv_file).lstrip('\ufeff')
173
-
174
- # Now, pass the clean string to the benchmark function
175
- return classifier.benchmark_csv(csv_content, client_id)
176
  except Exception as e:
177
  error_details = traceback.format_exc()
178
- return f"❌ Error processing CSV file: {str(e)}\n\nDetails:\n{error_details}", None, None
179
 
180
  # --- Gradio Interface ---
181
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo:
 
90
  except Exception as e:
91
  return f"Error: {str(e)}", 0.0, {}
92
 
93
+ def benchmark_csv(self, csv_filepath: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
94
+ """Benchmark the model on a CSV file from a given filepath."""
95
  error_status = self._ensure_model_is_loaded()
96
  if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
97
 
98
  try:
99
+ # CORRECTED: Read directly from the filepath and handle BOM with utf-8-sig
100
+ df = pd.read_csv(csv_filepath, encoding='utf-8-sig')
101
 
 
102
  if 'text' not in df.columns or 'real_tag' not in df.columns:
103
  return f"❌ CSV must have 'text' and 'real_tag' columns! Found: {df.columns.to_list()}", None, None
104
 
 
109
  unique_themes = df['real_tag'].unique().tolist()
110
  self.add_client_themes(client_id, unique_themes)
111
 
112
+ texts_to_classify = df['text'].str.slice(0, 500).tolist()
113
+ results = [self.classify_text(text, client_id) for text in texts_to_classify]
114
 
115
  df['predicted_tag'] = [res[0] for res in results]
116
  df['confidence'] = [res[1] for res in results]
 
158
  return result, ""
159
 
160
  @spaces.GPU(duration=300)
161
+ def benchmark_interface(csv_file_obj, client_id: str):
162
+ """
163
+ Handles the Gradio file object and passes the filepath to the benchmark function.
164
+ """
165
+ if csv_file_obj is None:
166
  return "Please upload a CSV file!", None, None
167
  try:
168
+ # THE FINAL, CORRECT FIX: Get the filepath from the .name attribute of the Gradio file object
169
+ csv_filepath = csv_file_obj.name
170
+
171
+ # Pass the filepath to the actual processing function
172
+ return classifier.benchmark_csv(csv_filepath, client_id)
 
 
 
 
 
 
173
  except Exception as e:
174
  error_details = traceback.format_exc()
175
+ return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
176
 
177
  # --- Gradio Interface ---
178
  with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft()) as demo: