Mohaddz commited on
Commit
b60d459
·
verified ·
1 Parent(s): d86cb9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -10,6 +10,7 @@ import plotly.graph_objects as go
10
  from collections import defaultdict
11
  import json
12
  import traceback
 
13
 
14
  class MultiClientThemeClassifier:
15
  def __init__(self):
@@ -18,7 +19,7 @@ class MultiClientThemeClassifier:
18
  self.model_loaded = False
19
 
20
  def load_model(self, model_name: str = 'Qwen/Qwen3-Embedding-0.6B'):
21
- """Load the embedding model"""
22
  try:
23
  if self.model_loaded:
24
  # If switching models, reset everything
@@ -26,10 +27,11 @@ class MultiClientThemeClassifier:
26
  self.client_themes = {}
27
  self.model_loaded = False
28
 
29
- print(f"Loading model: {model_name}")
30
- self.model = SentenceTransformer(model_name,trust_remote_code=True)
 
31
  self.model_loaded = True
32
- return f"✅ Model '{model_name}' loaded successfully!"
33
  except Exception as e:
34
  self.model_loaded = False
35
  error_details = traceback.format_exc()
@@ -208,11 +210,13 @@ class MultiClientThemeClassifier:
208
  # Initialize the classifier
209
  classifier = MultiClientThemeClassifier()
210
 
 
211
  def load_model_interface(model_name: str):
212
  if not model_name.strip():
213
  model_name = 'Qwen/Qwen3-Embedding-0.6B' # Default
214
  return classifier.load_model(model_name.strip())
215
 
 
216
  def add_themes_interface(client_id: str, themes_text: str):
217
  if not themes_text.strip():
218
  return "❌ Please enter themes!"
@@ -220,6 +224,7 @@ def add_themes_interface(client_id: str, themes_text: str):
220
  themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
221
  return classifier.add_client_themes(client_id, themes)
222
 
 
223
  def classify_interface(text: str, client_id: str, confidence_threshold: float):
224
  if not text.strip():
225
  return "Please enter text to classify!", ""
@@ -241,6 +246,7 @@ def classify_interface(text: str, client_id: str, confidence_threshold: float):
241
 
242
  return result, ""
243
 
 
244
  def benchmark_interface(csv_file, client_id: str):
245
  if csv_file is None:
246
  return "Please upload a CSV file!", "", ""
 
10
  from collections import defaultdict
11
  import json
12
  import traceback
13
+ import spaces # Import the spaces library
14
 
15
  class MultiClientThemeClassifier:
16
  def __init__(self):
 
19
  self.model_loaded = False
20
 
21
  def load_model(self, model_name: str = 'Qwen/Qwen3-Embedding-0.6B'):
22
+ """Load the embedding model onto the GPU"""
23
  try:
24
  if self.model_loaded:
25
  # If switching models, reset everything
 
27
  self.client_themes = {}
28
  self.model_loaded = False
29
 
30
+ print(f"Loading model: {model_name} onto CUDA device")
31
+ # Load the model directly onto the GPU
32
+ self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
33
  self.model_loaded = True
34
+ return f"✅ Model '{model_name}' loaded successfully onto GPU!"
35
  except Exception as e:
36
  self.model_loaded = False
37
  error_details = traceback.format_exc()
 
210
  # Initialize the classifier
211
  classifier = MultiClientThemeClassifier()
212
 
213
+ @spaces.GPU
214
  def load_model_interface(model_name: str):
215
  if not model_name.strip():
216
  model_name = 'Qwen/Qwen3-Embedding-0.6B' # Default
217
  return classifier.load_model(model_name.strip())
218
 
219
+ @spaces.GPU
220
  def add_themes_interface(client_id: str, themes_text: str):
221
  if not themes_text.strip():
222
  return "❌ Please enter themes!"
 
224
  themes = [theme.strip() for theme in themes_text.split('\n') if theme.strip()]
225
  return classifier.add_client_themes(client_id, themes)
226
 
227
+ @spaces.GPU
228
  def classify_interface(text: str, client_id: str, confidence_threshold: float):
229
  if not text.strip():
230
  return "Please enter text to classify!", ""
 
246
 
247
  return result, ""
248
 
249
+ @spaces.GPU(duration=300) # Request GPU for 300 seconds for longer benchmark jobs
250
  def benchmark_interface(csv_file, client_id: str):
251
  if csv_file is None:
252
  return "Please upload a CSV file!", "", ""