Sazzz02 commited on
Commit
d26c8ed
·
verified ·
1 Parent(s): 3acb623

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -44
app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import re
4
  import joblib
5
  import numpy as np
6
- import pandas as pd
7
  import gradio as gr
8
 
9
  # -------------------------
@@ -52,7 +51,6 @@ tfidf, tfidf_path = try_load(tfidf_candidates)
52
  logreg, logreg_path = try_load(logreg_candidates)
53
  lgbm, lgbm_path = try_load(lgbm_candidates)
54
 
55
- # Fallback label order (common mapping)
56
  DEFAULT_LABELS = ['negative', 'neutral', 'positive']
57
 
58
  # -------------------------
@@ -61,8 +59,7 @@ DEFAULT_LABELS = ['negative', 'neutral', 'positive']
61
  def clean_text(t):
62
  if t is None:
63
  return ""
64
- s = str(t)
65
- s = s.lower()
66
  s = re.sub(r"\s+", " ", s)
67
  s = re.sub(r"[^a-z0-9\s']", " ", s)
68
  return s.strip()
@@ -74,10 +71,8 @@ import warnings
74
  warnings.filterwarnings("ignore")
75
 
76
  def get_model_classes(model):
77
- # some models expose .classes_, some don't
78
  if hasattr(model, "classes_"):
79
  return list(model.classes_)
80
- # LightGBM might store classes_ as np.array in classifier
81
  if hasattr(model, "classes"):
82
  return list(model.classes)
83
  return DEFAULT_LABELS
@@ -102,14 +97,12 @@ def predict_one(text, model_choice="Logistic Regression"):
102
  probs = logreg.predict_proba(X)[0]
103
  classes = get_model_classes(logreg)
104
  elif model_choice == "LightGBM" and lgbm is not None:
105
- # LightGBM may want dense arrays in some configs; try both
106
  try:
107
  probs = lgbm.predict_proba(X)[0]
108
  except Exception:
109
  probs = lgbm.predict_proba(X.toarray())[0]
110
  classes = get_model_classes(lgbm)
111
  else:
112
- # fallback to whichever model exists
113
  if logreg is not None:
114
  probs = logreg.predict_proba(X)[0]; classes = get_model_classes(logreg)
115
  elif lgbm is not None:
@@ -121,17 +114,14 @@ def predict_one(text, model_choice="Logistic Regression"):
121
  except Exception as e:
122
  return {"label": None, "confidence": 0.0, "html": "", "error": f"Prediction error: {e}"}
123
 
124
- # Ensure classes + probs align
125
- # If classes are not sorted in expected order, we will display them as the model provides.
126
  idx = int(np.argmax(probs))
127
  label = classes[idx]
128
  confidence = float(probs[idx])
129
 
130
- # Build colored HTML bars for probabilities
131
  colors = {
132
- 'positive': '#16a34a', # green
133
- 'neutral': '#f59e0b', # amber
134
- 'negative': '#ef4444' # red
135
  }
136
  bars_html = ""
137
  for c, p in zip(classes, probs):
@@ -139,8 +129,8 @@ def predict_one(text, model_choice="Logistic Regression"):
139
  pct = float(p) * 100.0
140
  bars_html += f"""
141
  <div style="display:flex;align-items:center;margin-bottom:8px;">
142
- <div style="width:95px;font-weight:600;color:#111;">{c}</div>
143
- <div style="flex:1;margin-left:10px;background:#f1f5f9;border-radius:999px;padding:3px;">
144
  <div style="width:{pct:.2f}%;background:{col};padding:6px 10px;border-radius:999px;color:white;font-weight:700;text-align:right;">
145
  {pct:.1f}%
146
  </div>
@@ -150,7 +140,7 @@ def predict_one(text, model_choice="Logistic Regression"):
150
 
151
  header_html = f"""
152
  <div style="display:flex;align-items:center;gap:12px;">
153
- <div style="font-size:16px;font-weight:700;">Prediction:</div>
154
  <div style="padding:6px 12px;border-radius:999px;background:{colors.get(str(label).lower(),'#3b82f6')};color:white;font-weight:800;">
155
  {label.upper()} ({confidence:.2f})
156
  </div>
@@ -161,36 +151,33 @@ def predict_one(text, model_choice="Logistic Regression"):
161
  return {"label": label, "confidence": float(confidence), "html": header_html, "error": None}
162
 
163
  # -------------------------
164
- # Gradio UI
165
  # -------------------------
166
  css = """
167
- /* page background */
168
- body { background: linear-gradient(135deg,#fdfbfb 0%,#ebf8ff 100%); }
169
-
170
- /* central card */
171
  .app-card {
172
  border-radius: 12px;
173
  padding: 18px;
174
- box-shadow: 0 10px 25px rgba(11, 20, 41, 0.06);
175
- background: linear-gradient(180deg, rgba(255,255,255,0.9), rgba(255,255,255,0.82));
 
176
  }
177
-
178
- /* title */
179
  .title {
180
  font-weight: 800;
181
  font-size: 22px;
182
  margin-bottom: 6px;
 
183
  }
184
  .subtitle {
185
- color: #374151;
186
  margin-bottom: 12px;
187
  }
188
-
189
- /* button */
190
  .gr-button {
191
  border-radius: 10px;
192
  padding: 10px 16px;
193
  font-weight:700;
 
 
194
  }
195
  """
196
 
@@ -203,33 +190,32 @@ examples = [
203
  with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
204
  with gr.Row():
205
  with gr.Column(scale=2):
206
- gr.HTML("<div class='app-card'><div class='title'>SvaraAI — Reply Classifier</div>"
207
- "<div class='subtitle'>Paste an email reply below to classify it as <b>positive</b> / <b>neutral</b> / <b>negative</b>.</div>"
208
  "</div>")
209
- inp = gr.Textbox(lines=5, placeholder="e.g. Thanks let's schedule a demo next Tuesday at 10am.", label="Reply text")
210
- model_choice = gr.Dropdown(choices=["Logistic Regression", "LightGBM"], value="Logistic Regression", label="Model (choose one)")
211
  with gr.Row():
212
- btn = gr.Button("Classify", variant="primary")
213
- clear = gr.Button("Clear")
214
  output_label = gr.Markdown(value="**Prediction:** _waiting for input_", label="Result")
215
- output_html = gr.HTML("<i>Probabilities will appear here</i>")
216
  error_box = gr.Textbox(interactive=False, visible=False)
217
  gr.Examples(examples=examples, inputs=[inp, model_choice], label="Try these examples")
218
  with gr.Column(scale=1):
219
- gr.HTML("<div class='app-card'><div style='font-weight:800;margin-bottom:8px'>About</div>"
220
- "<div style='font-size:13px;color:#374151'>This demo uses a TF-IDF vectorizer and a saved classifier (Logistic Regression / LightGBM). "
221
  "Upload your saved pickles to <code>models/</code> as described in README.md.</div></div>")
222
- # small quick test panel
223
  stats_md = gr.Markdown("**Model files detected:**<br>"
224
  f"- TF-IDF: `{tfidf_path or 'NOT FOUND'}` \n"
225
  f"- LogReg: `{logreg_path or 'NOT FOUND'}` \n"
226
  f"- LGBM: `{lgbm_path or 'NOT FOUND'}` \n")
227
- download_note = gr.Markdown("<small>If a model is missing upload it to <code>models/</code> or rename files appropriately.</small>")
228
 
229
  def run_and_format(text, model_choice):
230
  res = predict_one(text, model_choice)
231
  if res.get("error"):
232
- return f"**Error:** {res['error']}", "", gr.update(value=f"<div style='color:#b91c1c;font-weight:700'>{res['error']}</div>")
233
  label = res["label"]
234
  conf = res["confidence"]
235
  html = res["html"]
@@ -237,10 +223,9 @@ with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
237
  return md, str(round(conf, 3)), gr.update(value=html)
238
 
239
  btn.click(run_and_format, inputs=[inp, model_choice], outputs=[output_label, error_box, output_html])
240
- clear.click(lambda: ("**Prediction:** _waiting for input_", "", gr.update(value="<i>Probabilities will appear here</i>")), [], [output_label, error_box, output_html])
241
 
242
- # footer
243
- gr.HTML("<div style='margin-top:18px;color:#6b7280;font-size:13px'>Built for the SvaraAI assignment • Upload your model pickles into <code>models/</code></div>")
244
 
245
  if __name__ == "__main__":
246
  demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
 
3
  import re
4
  import joblib
5
  import numpy as np
 
6
  import gradio as gr
7
 
8
  # -------------------------
 
51
  logreg, logreg_path = try_load(logreg_candidates)
52
  lgbm, lgbm_path = try_load(lgbm_candidates)
53
 
 
54
  DEFAULT_LABELS = ['negative', 'neutral', 'positive']
55
 
56
  # -------------------------
 
59
  def clean_text(t):
60
  if t is None:
61
  return ""
62
+ s = str(t).lower()
 
63
  s = re.sub(r"\s+", " ", s)
64
  s = re.sub(r"[^a-z0-9\s']", " ", s)
65
  return s.strip()
 
71
  warnings.filterwarnings("ignore")
72
 
73
  def get_model_classes(model):
 
74
  if hasattr(model, "classes_"):
75
  return list(model.classes_)
 
76
  if hasattr(model, "classes"):
77
  return list(model.classes)
78
  return DEFAULT_LABELS
 
97
  probs = logreg.predict_proba(X)[0]
98
  classes = get_model_classes(logreg)
99
  elif model_choice == "LightGBM" and lgbm is not None:
 
100
  try:
101
  probs = lgbm.predict_proba(X)[0]
102
  except Exception:
103
  probs = lgbm.predict_proba(X.toarray())[0]
104
  classes = get_model_classes(lgbm)
105
  else:
 
106
  if logreg is not None:
107
  probs = logreg.predict_proba(X)[0]; classes = get_model_classes(logreg)
108
  elif lgbm is not None:
 
114
  except Exception as e:
115
  return {"label": None, "confidence": 0.0, "html": "", "error": f"Prediction error: {e}"}
116
 
 
 
117
  idx = int(np.argmax(probs))
118
  label = classes[idx]
119
  confidence = float(probs[idx])
120
 
 
121
  colors = {
122
+ 'positive': '#16a34a',
123
+ 'neutral': '#f59e0b',
124
+ 'negative': '#ef4444'
125
  }
126
  bars_html = ""
127
  for c, p in zip(classes, probs):
 
129
  pct = float(p) * 100.0
130
  bars_html += f"""
131
  <div style="display:flex;align-items:center;margin-bottom:8px;">
132
+ <div style="width:95px;font-weight:600;color:#e5e7eb;">{c}</div>
133
+ <div style="flex:1;margin-left:10px;background:#1f2937;border-radius:999px;padding:3px;">
134
  <div style="width:{pct:.2f}%;background:{col};padding:6px 10px;border-radius:999px;color:white;font-weight:700;text-align:right;">
135
  {pct:.1f}%
136
  </div>
 
140
 
141
  header_html = f"""
142
  <div style="display:flex;align-items:center;gap:12px;">
143
+ <div style="font-size:16px;font-weight:700;color:#f3f4f6;">Prediction:</div>
144
  <div style="padding:6px 12px;border-radius:999px;background:{colors.get(str(label).lower(),'#3b82f6')};color:white;font-weight:800;">
145
  {label.upper()} ({confidence:.2f})
146
  </div>
 
151
  return {"label": label, "confidence": float(confidence), "html": header_html, "error": None}
152
 
153
  # -------------------------
154
+ # Dark theme CSS
155
  # -------------------------
156
  css = """
157
+ body { background: linear-gradient(135deg,#0f172a 0%,#1e293b 100%); color:#e5e7eb; }
 
 
 
158
  .app-card {
159
  border-radius: 12px;
160
  padding: 18px;
161
+ box-shadow: 0 10px 25px rgba(0, 0, 0, 0.6);
162
+ background: rgba(31,41,55,0.9);
163
+ color: #f9fafb;
164
  }
 
 
165
  .title {
166
  font-weight: 800;
167
  font-size: 22px;
168
  margin-bottom: 6px;
169
+ color: #f9fafb;
170
  }
171
  .subtitle {
172
+ color: #9ca3af;
173
  margin-bottom: 12px;
174
  }
 
 
175
  .gr-button {
176
  border-radius: 10px;
177
  padding: 10px 16px;
178
  font-weight:700;
179
+ background: #3b82f6 !important;
180
+ color: white !important;
181
  }
182
  """
183
 
 
190
  with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
191
  with gr.Row():
192
  with gr.Column(scale=2):
193
+ gr.HTML("<div class='app-card'><div class='title'>🌙 SvaraAI — Reply Classifier</div>"
194
+ "<div class='subtitle'>Classify replies as <b style='color:#16a34a'>positive</b> / <b style='color:#f59e0b'>neutral</b> / <b style='color:#ef4444'>negative</b>.</div>"
195
  "</div>")
196
+ inp = gr.Textbox(lines=5, placeholder="Type your reply here...", label="Reply text")
197
+ model_choice = gr.Dropdown(choices=["Logistic Regression", "LightGBM"], value="Logistic Regression", label="Model")
198
  with gr.Row():
199
+ btn = gr.Button("🚀 Classify", variant="primary")
200
+ clear = gr.Button("🧹 Clear")
201
  output_label = gr.Markdown(value="**Prediction:** _waiting for input_", label="Result")
202
+ output_html = gr.HTML("<i style='color:#9ca3af;'>Probabilities will appear here</i>")
203
  error_box = gr.Textbox(interactive=False, visible=False)
204
  gr.Examples(examples=examples, inputs=[inp, model_choice], label="Try these examples")
205
  with gr.Column(scale=1):
206
+ gr.HTML("<div class='app-card'><div style='font-weight:800;margin-bottom:8px'>ℹ️ About</div>"
207
+ "<div style='font-size:13px;color:#d1d5db'>This demo uses a TF-IDF vectorizer and a saved classifier (Logistic Regression / LightGBM). "
208
  "Upload your saved pickles to <code>models/</code> as described in README.md.</div></div>")
 
209
  stats_md = gr.Markdown("**Model files detected:**<br>"
210
  f"- TF-IDF: `{tfidf_path or 'NOT FOUND'}` \n"
211
  f"- LogReg: `{logreg_path or 'NOT FOUND'}` \n"
212
  f"- LGBM: `{lgbm_path or 'NOT FOUND'}` \n")
213
+ download_note = gr.Markdown("<small style='color:#9ca3af;'>If a model is missing upload it to <code>models/</code> or rename files appropriately.</small>")
214
 
215
  def run_and_format(text, model_choice):
216
  res = predict_one(text, model_choice)
217
  if res.get("error"):
218
+ return f"**Error:** {res['error']}", "", gr.update(value=f"<div style='color:#ef4444;font-weight:700'>{res['error']}</div>")
219
  label = res["label"]
220
  conf = res["confidence"]
221
  html = res["html"]
 
223
  return md, str(round(conf, 3)), gr.update(value=html)
224
 
225
  btn.click(run_and_format, inputs=[inp, model_choice], outputs=[output_label, error_box, output_html])
226
+ clear.click(lambda: ("**Prediction:** _waiting for input_", "", gr.update(value="<i style='color:#9ca3af;'>Probabilities will appear here</i>")), [], [output_label, error_box, output_html])
227
 
228
+ gr.HTML("<div style='margin-top:18px;color:#9ca3af;font-size:13px'>🌌 Built for the SvaraAI assignment • Upload your model pickles into <code>models/</code></div>")
 
229
 
230
  if __name__ == "__main__":
231
  demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))