Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,7 @@ import plotly.graph_objects as go
|
|
| 10 |
from collections import defaultdict
|
| 11 |
import json
|
| 12 |
import traceback
|
|
|
|
| 13 |
|
| 14 |
class MultiClientThemeClassifier:
|
| 15 |
def __init__(self):
|
|
@@ -18,7 +19,7 @@ class MultiClientThemeClassifier:
|
|
| 18 |
self.model_loaded = False
|
| 19 |
|
| 20 |
def load_model(self, model_name: str = 'Qwen/Qwen3-Embedding-0.6B'):
|
| 21 |
-
"""Load the embedding model"""
|
| 22 |
try:
|
| 23 |
if self.model_loaded:
|
| 24 |
# If switching models, reset everything
|
|
@@ -26,10 +27,11 @@ class MultiClientThemeClassifier:
|
|
| 26 |
self.client_themes = {}
|
| 27 |
self.model_loaded = False
|
| 28 |
|
| 29 |
-
print(f"Loading model: {model_name}")
|
| 30 |
-
|
|
|
|
| 31 |
self.model_loaded = True
|
| 32 |
-
return f"✅ Model '{model_name}' loaded successfully!"
|
| 33 |
except Exception as e:
|
| 34 |
self.model_loaded = False
|
| 35 |
error_details = traceback.format_exc()
|
|
@@ -208,11 +210,13 @@ class MultiClientThemeClassifier:
|
|
| 208 |
# Initialize the classifier
|
| 209 |
classifier = MultiClientThemeClassifier()
|
| 210 |
|
|
|
|
| 211 |
def load_model_interface(model_name: str):
|
| 212 |
if not model_name.strip():
|
| 213 |
model_name = 'Qwen/Qwen3-Embedding-0.6B' # Default
|
| 214 |
return classifier.load_model(model_name.strip())
|
| 215 |
|
|
|
|
| 216 |
def add_themes_interface(client_id: str, themes_text: str):
|
| 217 |
if not themes_text.strip():
|
| 218 |
return "❌ Please enter themes!"
|
|
@@ -220,6 +224,7 @@ def add_themes_interface(client_id: str, themes_text: str):
|
|
| 220 |
themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
|
| 221 |
return classifier.add_client_themes(client_id, themes)
|
| 222 |
|
|
|
|
| 223 |
def classify_interface(text: str, client_id: str, confidence_threshold: float):
|
| 224 |
if not text.strip():
|
| 225 |
return "Please enter text to classify!", ""
|
|
@@ -241,6 +246,7 @@ def classify_interface(text: str, client_id: str, confidence_threshold: float):
|
|
| 241 |
|
| 242 |
return result, ""
|
| 243 |
|
|
|
|
| 244 |
def benchmark_interface(csv_file, client_id: str):
|
| 245 |
if csv_file is None:
|
| 246 |
return "Please upload a CSV file!", "", ""
|
|
|
|
| 10 |
from collections import defaultdict
|
| 11 |
import json
|
| 12 |
import traceback
|
| 13 |
+
import spaces # Import the spaces library
|
| 14 |
|
| 15 |
class MultiClientThemeClassifier:
|
| 16 |
def __init__(self):
|
|
|
|
| 19 |
self.model_loaded = False
|
| 20 |
|
| 21 |
def load_model(self, model_name: str = 'Qwen/Qwen3-Embedding-0.6B'):
|
| 22 |
+
"""Load the embedding model onto the GPU"""
|
| 23 |
try:
|
| 24 |
if self.model_loaded:
|
| 25 |
# If switching models, reset everything
|
|
|
|
| 27 |
self.client_themes = {}
|
| 28 |
self.model_loaded = False
|
| 29 |
|
| 30 |
+
print(f"Loading model: {model_name} onto CUDA device")
|
| 31 |
+
# Load the model directly onto the GPU
|
| 32 |
+
self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
|
| 33 |
self.model_loaded = True
|
| 34 |
+
return f"✅ Model '{model_name}' loaded successfully onto GPU!"
|
| 35 |
except Exception as e:
|
| 36 |
self.model_loaded = False
|
| 37 |
error_details = traceback.format_exc()
|
|
|
|
| 210 |
# Initialize the classifier
|
| 211 |
classifier = MultiClientThemeClassifier()
|
| 212 |
|
| 213 |
+
@spaces.GPU
|
| 214 |
def load_model_interface(model_name: str):
|
| 215 |
if not model_name.strip():
|
| 216 |
model_name = 'Qwen/Qwen3-Embedding-0.6B' # Default
|
| 217 |
return classifier.load_model(model_name.strip())
|
| 218 |
|
| 219 |
+
@spaces.GPU
|
| 220 |
def add_themes_interface(client_id: str, themes_text: str):
|
| 221 |
if not themes_text.strip():
|
| 222 |
return "❌ Please enter themes!"
|
|
|
|
| 224 |
themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
|
| 225 |
return classifier.add_client_themes(client_id, themes)
|
| 226 |
|
| 227 |
+
@spaces.GPU
|
| 228 |
def classify_interface(text: str, client_id: str, confidence_threshold: float):
|
| 229 |
if not text.strip():
|
| 230 |
return "Please enter text to classify!", ""
|
|
|
|
| 246 |
|
| 247 |
return result, ""
|
| 248 |
|
| 249 |
+
@spaces.GPU(duration=300) # Request GPU for 300 seconds for longer benchmark jobs
|
| 250 |
def benchmark_interface(csv_file, client_id: str):
|
| 251 |
if csv_file is None:
|
| 252 |
return "Please upload a CSV file!", "", ""
|