Mohaddz committed on
Commit
26f50fc
·
verified ·
1 Parent(s): 30f9702

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -19,7 +19,6 @@ class MultiClientThemeClassifier:
19
  self.client_themes = {}
20
  self.model_loaded = False
21
  self.default_model = 'Qwen/Qwen3-Embedding-0.6B'
22
- # CORRECTED: Add attribute to remember the last loaded model's name
23
  self.current_model_name = self.default_model
24
 
25
  def load_model(self, model_name: str):
@@ -36,7 +35,6 @@ class MultiClientThemeClassifier:
36
  print(f"Loading model: {model_name} onto CUDA device")
37
  self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
38
  self.model_loaded = True
39
- # CORRECTED: Remember the name of the successfully loaded model
40
  self.current_model_name = model_name
41
  return f"✅ Model '{model_name}' loaded successfully onto GPU!"
42
  except Exception as e:
@@ -48,7 +46,6 @@ class MultiClientThemeClassifier:
48
  """Internal helper to load the correct model if it's not already loaded."""
49
  if not self.model_loaded:
50
  print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
51
- # CORRECTED: Load the last selected model, not the default one
52
  status = self.load_model(self.current_model_name)
53
  if "Error" in status:
54
  return status
@@ -93,11 +90,16 @@ class MultiClientThemeClassifier:
93
  except Exception as e:
94
  return f"Error: {str(e)}", 0.0, {}
95
 
96
- def benchmark_csv(self, csv_filepath: str, client_id: str) -> Tuple[str, Optional[str], Optional[str]]:
97
- """Benchmark the model on a CSV file, trying multiple encodings."""
98
- error_status = self._ensure_model_is_loaded()
99
- if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
 
 
 
 
100
 
 
101
  encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
102
  df = None
103
  for encoding in encodings_to_try:
@@ -170,13 +172,17 @@ def classify_interface(text: str, client_id: str, confidence_threshold: float):
170
 
171
  return result, ""
172
 
 
173
  @spaces.GPU(duration=300)
174
- def benchmark_interface(csv_file_obj, client_id: str):
175
  if csv_file_obj is None:
176
  return "Please upload a CSV file!", None, None
 
 
177
  try:
178
  csv_filepath = csv_file_obj.name
179
- return classifier.benchmark_csv(csv_filepath, client_id)
 
180
  except Exception as e:
181
  error_details = traceback.format_exc()
182
  return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
@@ -187,8 +193,9 @@ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft())
187
 
188
  with gr.Tab("🚀 Setup & Model"):
189
  gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
190
- gr.Markdown("A default model (`Qwen/Qwen3-Embedding-0.6B`) will load automatically on first use.")
191
  with gr.Row():
 
192
  model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
193
  load_btn = gr.Button("Load Model", variant="primary")
194
  load_status = gr.Textbox(label="Status", interactive=False)
@@ -215,7 +222,7 @@ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft())
215
  classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
216
 
217
  with gr.Tab("📊 CSV Benchmarking"):
218
- gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns.")
219
  with gr.Row():
220
  with gr.Column():
221
  csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
@@ -226,7 +233,13 @@ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft())
226
  with gr.Row():
227
  results_csv = gr.File(label="Download Detailed Results", interactive=False)
228
  visualization = gr.HTML(label="Visualization")
229
- benchmark_btn.click(benchmark_interface, inputs=[csv_upload, benchmark_client], outputs=[benchmark_results, results_csv, visualization])
 
 
 
 
 
 
230
 
231
  # Launch the app
232
  if __name__ == "__main__":
 
19
  self.client_themes = {}
20
  self.model_loaded = False
21
  self.default_model = 'Qwen/Qwen3-Embedding-0.6B'
 
22
  self.current_model_name = self.default_model
23
 
24
  def load_model(self, model_name: str):
 
35
  print(f"Loading model: {model_name} onto CUDA device")
36
  self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
37
  self.model_loaded = True
 
38
  self.current_model_name = model_name
39
  return f"✅ Model '{model_name}' loaded successfully onto GPU!"
40
  except Exception as e:
 
46
  """Internal helper to load the correct model if it's not already loaded."""
47
  if not self.model_loaded:
48
  print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
 
49
  status = self.load_model(self.current_model_name)
50
  if "Error" in status:
51
  return status
 
90
  except Exception as e:
91
  return f"Error: {str(e)}", 0.0, {}
92
 
93
+ # CORRECTED: The benchmark function now takes the model_name as an argument
94
+ def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
95
+ """Benchmark a specific model on a CSV file."""
96
+ # Step 1: Explicitly load the model requested by the user for this benchmark run.
97
+ load_status = self.load_model(model_name)
98
+ # We allow the function to proceed if the model is "already loaded", but stop for any other error.
99
+ if "❌" in load_status:
100
+ return f"❌ Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None
101
 
102
+ # Step 2: Proceed with the benchmark logic as before.
103
  encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
104
  df = None
105
  for encoding in encodings_to_try:
 
172
 
173
  return result, ""
174
 
175
+ # CORRECTED: The interface now accepts model_name
176
  @spaces.GPU(duration=300)
177
+ def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
178
  if csv_file_obj is None:
179
  return "Please upload a CSV file!", None, None
180
+ if not model_name.strip():
181
+ return "Please enter a model name for the benchmark!", None, None
182
  try:
183
  csv_filepath = csv_file_obj.name
184
+ # Pass the model name from the UI down to the classifier method
185
+ return classifier.benchmark_csv(csv_filepath, client_id, model_name.strip())
186
  except Exception as e:
187
  error_details = traceback.format_exc()
188
  return f"❌ Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
 
193
 
194
  with gr.Tab("🚀 Setup & Model"):
195
  gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
196
+ gr.Markdown("A default model (`Qwen/Qwen3-Embedding-0.6B`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
197
  with gr.Row():
198
+ # This input is now used by the benchmark tab as well
199
  model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
200
  load_btn = gr.Button("Load Model", variant="primary")
201
  load_status = gr.Textbox(label="Status", interactive=False)
 
222
  classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
223
 
224
  with gr.Tab("📊 CSV Benchmarking"):
225
+ gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
226
  with gr.Row():
227
  with gr.Column():
228
  csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
 
233
  with gr.Row():
234
  results_csv = gr.File(label="Download Detailed Results", interactive=False)
235
  visualization = gr.HTML(label="Visualization")
236
+
237
+ # CORRECTED: The button now sends the model_input value to the benchmark function
238
+ benchmark_btn.click(
239
+ benchmark_interface,
240
+ inputs=[csv_upload, benchmark_client, model_input],
241
+ outputs=[benchmark_results, results_csv, visualization]
242
+ )
243
 
244
  # Launch the app
245
  if __name__ == "__main__":