Mohaddz committed on
Commit
30f9702
·
verified ·
1 Parent(s): febe156

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -21
app.py CHANGED
@@ -16,17 +16,17 @@ import tempfile
16
  class MultiClientThemeClassifier:
17
  def __init__(self):
18
  self.model = None
19
- self.client_themes = {} # {client_id: {theme: prototype_embedding}}
20
  self.model_loaded = False
21
  self.default_model = 'Qwen/Qwen3-Embedding-0.6B'
 
 
22
 
23
- def load_model(self, model_name: str = None):
24
- """Load the embedding model onto the GPU"""
25
- if model_name is None:
26
- model_name = self.default_model
27
-
28
  try:
29
- if self.model_loaded and hasattr(self.model, 'tokenizer') and self.model.tokenizer.name_or_path == model_name:
 
30
  return f"✅ Model '{model_name}' is already loaded."
31
 
32
  self.model = None
@@ -36,6 +36,8 @@ class MultiClientThemeClassifier:
36
  print(f"Loading model: {model_name} onto CUDA device")
37
  self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
38
  self.model_loaded = True
 
 
39
  return f"✅ Model '{model_name}' loaded successfully onto GPU!"
40
  except Exception as e:
41
  self.model_loaded = False
@@ -43,15 +45,16 @@ class MultiClientThemeClassifier:
43
  return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"
44
 
45
  def _ensure_model_is_loaded(self) -> Optional[str]:
46
- """Internal helper to load model if it's not already loaded."""
47
  if not self.model_loaded:
48
- print("Model not loaded. Automatically loading default model...")
49
- status = self.load_model()
 
50
  if "Error" in status:
51
  return status
52
  return None
53
 
54
- def add_client_themes(self, client_id: str, themes: List[str], examples_per_theme: Dict[str, List[str]] = None):
55
  """Add themes for a specific client"""
56
  error_status = self._ensure_model_is_loaded()
57
  if error_status: return error_status
@@ -95,22 +98,18 @@ class MultiClientThemeClassifier:
95
  error_status = self._ensure_model_is_loaded()
96
  if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
97
 
98
- # FINAL FIX: Try a list of common encodings to handle different file types.
99
- encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1', 'cp1252']
100
  df = None
101
-
102
  for encoding in encodings_to_try:
103
  try:
104
  df = pd.read_csv(csv_filepath, encoding=encoding)
105
  print(f"Successfully read CSV with encoding: {encoding}")
106
- break # Exit loop if successful
107
  except (UnicodeDecodeError, pd.errors.ParserError):
108
- print(f"Failed to read with encoding: {encoding}, trying next...")
109
  continue
110
 
111
  if df is None:
112
- error_message = "❌ Could not decode the CSV file. Please save it in a common format like 'UTF-8' and try again."
113
- return error_message, None, None
114
 
115
  try:
116
  if 'text' not in df.columns or 'real_tag' not in df.columns:
@@ -123,8 +122,8 @@ class MultiClientThemeClassifier:
123
  unique_themes = df['real_tag'].unique().tolist()
124
  self.add_client_themes(client_id, unique_themes)
125
 
126
- texts_to_classify = df['text'].str.slice(0, 500).tolist()
127
- results = [self.classify_text(text, client_id) for text in texts_to_classify]
128
 
129
  df['predicted_tag'] = [res[0] for res in results]
130
  df['confidence'] = [res[1] for res in results]
@@ -133,7 +132,7 @@ class MultiClientThemeClassifier:
133
  total = len(df)
134
  accuracy = correct / total if total > 0 else 0
135
 
136
- results_summary = f"📊 **Benchmarking Results**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
137
 
138
  fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution", labels={'index': 'Theme', 'value': 'Count'})
139
  visualization_html = fig.to_html()
 
16
  class MultiClientThemeClassifier:
17
  def __init__(self):
18
  self.model = None
19
+ self.client_themes = {}
20
  self.model_loaded = False
21
  self.default_model = 'Qwen/Qwen3-Embedding-0.6B'
22
+ # CORRECTED: Add attribute to remember the last loaded model's name
23
+ self.current_model_name = self.default_model
24
 
25
+ def load_model(self, model_name: str):
26
+ """Load the embedding model onto the GPU, remembering the choice."""
 
 
 
27
  try:
28
+ # Prevent reloading the same model
29
+ if self.model_loaded and self.current_model_name == model_name:
30
  return f"✅ Model '{model_name}' is already loaded."
31
 
32
  self.model = None
 
36
  print(f"Loading model: {model_name} onto CUDA device")
37
  self.model = SentenceTransformer(model_name, device='cuda', trust_remote_code=True)
38
  self.model_loaded = True
39
+ # CORRECTED: Remember the name of the successfully loaded model
40
+ self.current_model_name = model_name
41
  return f"✅ Model '{model_name}' loaded successfully onto GPU!"
42
  except Exception as e:
43
  self.model_loaded = False
 
45
  return f"❌ Error loading model '{model_name}': {str(e)}\n\nDetails:\n{error_details}"
46
 
47
  def _ensure_model_is_loaded(self) -> Optional[str]:
48
+ """Internal helper to load the correct model if it's not already loaded."""
49
  if not self.model_loaded:
50
+ print(f"Model not loaded. Automatically loading last selected model: {self.current_model_name}...")
51
+ # CORRECTED: Load the last selected model, not the default one
52
+ status = self.load_model(self.current_model_name)
53
  if "Error" in status:
54
  return status
55
  return None
56
 
57
+ def add_client_themes(self, client_id: str, themes: List[str]):
58
  """Add themes for a specific client"""
59
  error_status = self._ensure_model_is_loaded()
60
  if error_status: return error_status
 
98
  error_status = self._ensure_model_is_loaded()
99
  if error_status: return f"❌ Model could not be loaded: {error_status}", None, None
100
 
101
+ encodings_to_try = ['utf-8-sig', 'utf-8', 'cp1256', 'latin1']
 
102
  df = None
 
103
  for encoding in encodings_to_try:
104
  try:
105
  df = pd.read_csv(csv_filepath, encoding=encoding)
106
  print(f"Successfully read CSV with encoding: {encoding}")
107
+ break
108
  except (UnicodeDecodeError, pd.errors.ParserError):
 
109
  continue
110
 
111
  if df is None:
112
+ return "❌ Could not decode the CSV. Please save it as 'UTF-8' and try again.", None, None
 
113
 
114
  try:
115
  if 'text' not in df.columns or 'real_tag' not in df.columns:
 
122
  unique_themes = df['real_tag'].unique().tolist()
123
  self.add_client_themes(client_id, unique_themes)
124
 
125
+ texts = df['text'].str.slice(0, 500).tolist()
126
+ results = [self.classify_text(text, client_id) for text in texts]
127
 
128
  df['predicted_tag'] = [res[0] for res in results]
129
  df['confidence'] = [res[1] for res in results]
 
132
  total = len(df)
133
  accuracy = correct / total if total > 0 else 0
134
 
135
+ results_summary = f"📊 **Benchmarking Results for `{self.current_model_name}`**\n\n**Accuracy: {accuracy:.2%}** ({correct}/{total})"
136
 
137
  fig = px.bar(df['real_tag'].value_counts(), title="Theme Distribution", labels={'index': 'Theme', 'value': 'Count'})
138
  visualization_html = fig.to_html()