Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,7 +19,6 @@ class MultiClientThemeClassifier:
|
|
| 19 |
self.client_themes = {}
|
| 20 |
self.model_loaded = False
|
| 21 |
self.default_model = 'Qwen/Qwen3-Embedding-0.6B'
|
| 22 |
-
# CORRECTED: Add attribute to remember the last loaded model's name
|
| 23 |
self.current_model_name = self.default_model
|
| 24 |
|
| 25 |
def load_model(self, model_name: str):
|
|
@@ -36,7 +35,6 @@ class MultiClientThemeClassifier:
|
|
| 36 |
print(f"Loading model: {model_name} onto CUDA device")
|
| 37 |
self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
|
| 38 |
self.model_loaded = True
|
| 39 |
-
# CORRECTED: Remember the name of the successfully loaded model
|
| 40 |
self.current_model_name = model_name
|
| 41 |
return f"✅ Model '{model_name}' loaded successfully onto GPU!"
|
| 42 |
except Exception as e:
|
|
@@ -48,7 +46,6 @@ class MultiClientThemeClassifier:
|
|
| 48 |
"""Internal helper to load the correct model if it's not already loaded."""
|
| 49 |
if not self.model_loaded:
|
| 50 |
print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
|
| 51 |
-
# CORRECTED: Load the last selected model, not the default one
|
| 52 |
status = self.load_model(self.current_model_name)
|
| 53 |
if "Error" in status:
|
| 54 |
return status
|
|
@@ -93,11 +90,16 @@ class MultiClientThemeClassifier:
|
|
| 93 |
except Exception as e:
|
| 94 |
return f"Error: {str(e)}", 0.0, {}
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
|
|
|
|
| 101 |
encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
|
| 102 |
df = None
|
| 103 |
for encoding in encodings_to_try:
|
|
@@ -170,13 +172,17 @@ def classify_interface(text: str, client_id: str, confidence_threshold: float):
|
|
| 170 |
|
| 171 |
return result, ""
|
| 172 |
|
|
|
|
| 173 |
@spaces.GPU(duration=300)
|
| 174 |
-
def benchmark_interface(csv_file_obj, client_id: str):
|
| 175 |
if csv_file_obj is None:
|
| 176 |
return "Please upload a CSV file!", None, None
|
|
|
|
|
|
|
| 177 |
try:
|
| 178 |
csv_filepath = csv_file_obj.name
|
| 179 |
-
|
|
|
|
| 180 |
except Exception as e:
|
| 181 |
error_details = traceback.format_exc()
|
| 182 |
return f"β Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
|
|
@@ -187,8 +193,9 @@ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft())
|
|
| 187 |
|
| 188 |
with gr.Tab("π Setup & Model"):
|
| 189 |
gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
|
| 190 |
-
gr.Markdown("A default model (`Qwen/Qwen3-Embedding-0.6B`) will load automatically on first use.")
|
| 191 |
with gr.Row():
|
|
|
|
| 192 |
model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
|
| 193 |
load_btn = gr.Button("Load Model", variant="primary")
|
| 194 |
load_status = gr.Textbox(label="Status", interactive=False)
|
|
@@ -215,7 +222,7 @@ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft())
|
|
| 215 |
classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
|
| 216 |
|
| 217 |
with gr.Tab("π CSV Benchmarking"):
|
| 218 |
-
gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns.")
|
| 219 |
with gr.Row():
|
| 220 |
with gr.Column():
|
| 221 |
csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
|
|
@@ -226,7 +233,13 @@ with gr.Blocks(title="Custom Themes Classification MVP", theme=gr.themes.Soft())
|
|
| 226 |
with gr.Row():
|
| 227 |
results_csv = gr.File(label="Download Detailed Results", interactive=False)
|
| 228 |
visualization = gr.HTML(label="Visualization")
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
# Launch the app
|
| 232 |
if __name__ == "__main__":
|
|
|
|
| 19 |
self.client_themes = {}
|
| 20 |
self.model_loaded = False
|
| 21 |
self.default_model = 'Qwen/Qwen3-Embedding-0.6B'
|
|
|
|
| 22 |
self.current_model_name = self.default_model
|
| 23 |
|
| 24 |
def load_model(self, model_name: str):
|
|
|
|
| 35 |
print(f"Loading model: {model_name} onto CUDA device")
|
| 36 |
self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
|
| 37 |
self.model_loaded = True
|
|
|
|
| 38 |
self.current_model_name = model_name
|
| 39 |
return f"✅ Model '{model_name}' loaded successfully onto GPU!"
|
| 40 |
except Exception as e:
|
|
|
|
| 46 |
"""Internal helper to load the correct model if it's not already loaded."""
|
| 47 |
if not self.model_loaded:
|
| 48 |
print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
|
|
|
|
| 49 |
status = self.load_model(self.current_model_name)
|
| 50 |
if "Error" in status:
|
| 51 |
return status
|
|
|
|
| 90 |
except Exception as e:
|
| 91 |
return f"Error: {str(e)}", 0.0, {}
|
| 92 |
|
| 93 |
+
# CORRECTED: The benchmark function now takes the model_name as an argument
|
| 94 |
+
def benchmark_csv(self, csv_filepath: str, client_id: str, model_name: str) -> Tuple[str, Optional[str], Optional[str]]:
|
| 95 |
+
"""Benchmark a specific model on a CSV file."""
|
| 96 |
+
# Step 1: Explicitly load the model requested by the user for this benchmark run.
|
| 97 |
+
load_status = self.load_model(model_name)
|
| 98 |
+
# We allow the function to proceed if the model is "already loaded", but stop for any other error.
|
| 99 |
+
if "β" in load_status:
|
| 100 |
+
return f"β Model '{model_name}' could not be loaded for benchmarking.\n\nError: {load_status}", None, None
|
| 101 |
|
| 102 |
+
# Step 2: Proceed with the benchmark logic as before.
|
| 103 |
encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
|
| 104 |
df = None
|
| 105 |
for encoding in encodings_to_try:
|
|
|
|
| 172 |
|
| 173 |
return result, ""
|
| 174 |
|
| 175 |
+
# CORRECTED: The interface now accepts model_name
|
| 176 |
@spaces.GPU(duration=300)
def benchmark_interface(csv_file_obj, client_id: str, model_name: str):
    """Validate the benchmark inputs from the UI and run the CSV benchmark.

    Args:
        csv_file_obj: The uploaded CSV as delivered by the file-upload
            component. ``None`` means nothing was uploaded. Either a plain
            path string or an object exposing the path as ``.name`` is
            accepted.
        client_id: Identifier of the client whose themes are benchmarked;
            passed through unchanged to ``classifier.benchmark_csv``.
        model_name: Name of the embedding model to benchmark. Surrounding
            whitespace is ignored; an empty or missing value is rejected.

    Returns:
        A 3-tuple. On success it is whatever ``classifier.benchmark_csv``
        returns. On a validation or processing failure it is
        ``(message, None, None)``.
    """
    if csv_file_obj is None:
        return "Please upload a CSV file!", None, None

    # A cleared textbox may arrive as None (or another non-string); treat that
    # as "no model given" instead of letting `.strip()` raise outside the try.
    requested_model = model_name.strip() if isinstance(model_name, str) else ""
    if not requested_model:
        return "Please enter a model name for the benchmark!", None, None

    try:
        # NOTE(review): presumably the upload component may hand over a path
        # string instead of a file-like wrapper depending on its version or
        # configuration -- confirm against the installed Gradio release.
        if isinstance(csv_file_obj, str):
            csv_filepath = csv_file_obj
        else:
            csv_filepath = csv_file_obj.name
        # Pass the model name from the UI down to the classifier method.
        return classifier.benchmark_csv(csv_filepath, client_id, requested_model)
    except Exception as e:
        # Top-level UI boundary: report the failure in the results box rather
        # than letting the exception escape the event handler.
        # NOTE(review): the leading "β" looks like a mis-decoded emoji
        # (likely a cross mark); confirm against the original file.
        error_details = traceback.format_exc()
        return f"β Error processing CSV file object: {str(e)}\n\nDetails:\n{error_details}", None, None
|
|
|
|
| 193 |
|
| 194 |
with gr.Tab("π Setup & Model"):
|
| 195 |
gr.Markdown("### Step 1: Load the Embedding Model (Optional)")
|
| 196 |
+
gr.Markdown("A default model (`Qwen/Qwen3-Embedding-0.6B`) will load automatically on first use. You can specify a different model here to use it in other tabs.")
|
| 197 |
with gr.Row():
|
| 198 |
+
# This input is now used by the benchmark tab as well
|
| 199 |
model_input = gr.Textbox(label="HuggingFace Model Name", value="Qwen/Qwen3-Embedding-0.6B")
|
| 200 |
load_btn = gr.Button("Load Model", variant="primary")
|
| 201 |
load_status = gr.Textbox(label="Status", interactive=False)
|
|
|
|
| 222 |
classify_btn.click(classify_interface, inputs=[text_input, client_select, confidence_slider], outputs=[classification_result, gr.Textbox(visible=False)])
|
| 223 |
|
| 224 |
with gr.Tab("π CSV Benchmarking"):
|
| 225 |
+
gr.Markdown("### Benchmark on Your Dataset\nUpload a CSV with `text` and `real_tag` columns. The model from the 'Setup & Model' tab will be loaded and used for the benchmark.")
|
| 226 |
with gr.Row():
|
| 227 |
with gr.Column():
|
| 228 |
csv_upload = gr.File(label="Upload CSV File", file_types=[".csv"])
|
|
|
|
| 233 |
with gr.Row():
|
| 234 |
results_csv = gr.File(label="Download Detailed Results", interactive=False)
|
| 235 |
visualization = gr.HTML(label="Visualization")
|
| 236 |
+
|
| 237 |
+
# CORRECTED: The button now sends the model_input value to the benchmark function
|
| 238 |
+
benchmark_btn.click(
|
| 239 |
+
benchmark_interface,
|
| 240 |
+
inputs=[csv_upload, benchmark_client, model_input],
|
| 241 |
+
outputs=[benchmark_results, results_csv, visualization]
|
| 242 |
+
)
|
| 243 |
|
| 244 |
# Launch the app
|
| 245 |
if __name__ == "__main__":
|