Sazzz02 committed on
Commit
3acb623
·
verified ·
1 Parent(s): c7d7de6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +246 -0
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import re
4
+ import joblib
5
+ import numpy as np
6
+ import pandas as pd
7
+ import gradio as gr
8
+
9
+ # -------------------------
10
+ # Helper: safe-loading
11
+ # -------------------------
12
def try_load(path_options):
    """Load the first usable serialized object from a list of candidate paths.

    Candidates are tried in order. ``None`` entries and paths that do not
    exist are skipped; a path that exists but fails to deserialize is
    reported on stdout and the search moves on to the next candidate.

    Args:
        path_options: Iterable of filesystem paths (entries may be ``None``).
            ``None`` or an empty iterable means "no candidates".

    Returns:
        A ``(obj, path)`` tuple for the first candidate that loads, or
        ``(None, None)`` when nothing could be loaded.
    """
    # SECURITY: joblib.load unpickles the file, which can execute arbitrary
    # code. Only point this at model files you produced yourself or trust.
    for candidate in path_options or ():
        if candidate is None or not os.path.exists(candidate):
            continue
        try:
            loaded = joblib.load(candidate)
        except Exception as exc:
            # Deliberately best-effort: one corrupt or incompatible file must
            # not prevent the remaining candidates from being tried.
            print(f"Failed to load {candidate}: {exc}")
            continue
        print(f"Loaded: {candidate}")
        return loaded, candidate
    return None, None
24
+
25
# Directory holding this script; falls back to the current working directory
# when __file__ is not defined in the module globals.
ROOT = os.getcwd() if "__file__" not in globals() else os.path.dirname(__file__)
MODEL_DIR = os.path.join(ROOT, "models")

# Each artifact may live under several plausible names/locations; the lists
# below are ordered by preference and the first loadable file wins.
tfidf_candidates = [
    os.path.join(folder, filename)
    for folder, filename in (
        (MODEL_DIR, "tfidf_vectorizer.pkl"),
        (MODEL_DIR, "tfidf.pkl"),
        (ROOT, "tfidf_vectorizer.pkl"),
        (ROOT, "tfidf.joblib"),
        (MODEL_DIR, "tfidf_vectorizer.joblib"),
        (MODEL_DIR, "tfidf.joblib"),
    )
]
logreg_candidates = [
    os.path.join(folder, filename)
    for folder, filename in (
        (MODEL_DIR, "logreg_model.pkl"),
        (MODEL_DIR, "logreg.pkl"),
        (ROOT, "logreg_model.pkl"),
        (ROOT, "logreg.pkl"),
        (MODEL_DIR, "logreg.joblib"),
    )
]
lgbm_candidates = [
    os.path.join(folder, filename)
    for folder, filename in (
        (MODEL_DIR, "lgbm_model.pkl"),
        (MODEL_DIR, "lgbm.pkl"),
        (ROOT, "lgbm_model.pkl"),
        (ROOT, "lgbm.pkl"),
    )
]

# Each name is None (and its *_path is None) when no candidate could be loaded.
tfidf, tfidf_path = try_load(tfidf_candidates)
logreg, logreg_path = try_load(logreg_candidates)
lgbm, lgbm_path = try_load(lgbm_candidates)

# Label order used when a model does not expose its own class list.
DEFAULT_LABELS = ['negative', 'neutral', 'positive']
57
+
58
+ # -------------------------
59
+ # Text preprocessing
60
+ # -------------------------
61
def clean_text(t):
    """Normalize raw reply text before vectorization.

    The value is converted to ``str``, lowercased, every character other
    than ``a-z``, ``0-9``, whitespace and the apostrophe is replaced by a
    space, whitespace runs are collapsed to a single space, and the result
    is stripped.

    Args:
        t: Any value; ``None`` is treated as empty text.

    Returns:
        The normalized string (``""`` for ``None`` or text with no
        retained characters).
    """
    if t is None:
        return ""
    lowered = str(t).lower()
    # Replace disallowed characters BEFORE collapsing whitespace: each removed
    # character becomes a space, so collapsing first would leave runs of
    # multiple spaces behind (e.g. "hello,   world" -> "hello  world").
    # NOTE(review): this only changes the number of spaces between tokens;
    # confirm the vectorizer was not fitted on text where that matters.
    only_allowed = re.sub(r"[^a-z0-9\s']", " ", lowered)
    return re.sub(r"\s+", " ", only_allowed).strip()
69
+
70
+ # -------------------------
71
+ # Prediction logic
72
+ # -------------------------
73
import warnings

# Install a catch-all "ignore" filter: from here on, warnings of every
# category are silenced for the whole process.
warnings.filterwarnings(action="ignore")
75
+
76
def get_model_classes(model):
    """Return the class labels known to *model* as a list.

    The ``classes_`` attribute is preferred; ``classes`` is consulted next.
    When the model exposes neither, the module-level ``DEFAULT_LABELS`` is
    returned instead.
    """
    for attr_name in ("classes_", "classes"):
        if hasattr(model, attr_name):
            return list(getattr(model, attr_name))
    return DEFAULT_LABELS
84
+
85
def _predict_proba_row(model, X, allow_dense_retry=False):
    """Return the probability vector for the single row in *X*.

    When *allow_dense_retry* is true and the first call fails, the call is
    repeated once with ``X.toarray()``; otherwise the error propagates.
    """
    try:
        return model.predict_proba(X)[0]
    except Exception:
        if not allow_dense_retry:
            raise
        # LightGBM may want dense arrays in some configs; retry densified.
        return model.predict_proba(X.toarray())[0]


def _render_prediction_html(label, confidence, classes, probs):
    """Build the result HTML: a colored prediction badge plus one bar per class."""
    colors = {
        'positive': '#16a34a',  # green
        'neutral': '#f59e0b',   # amber
        'negative': '#ef4444',  # red
    }
    fallback_color = "#3b82f6"  # used for any label outside the three above

    bar_rows = []
    for c, p in zip(classes, probs):
        col = colors.get(str(c).lower(), fallback_color)
        pct = float(p) * 100.0
        bar_rows.append(f"""
        <div style="display:flex;align-items:center;margin-bottom:8px;">
        <div style="width:95px;font-weight:600;color:#111;">{c}</div>
        <div style="flex:1;margin-left:10px;background:#f1f5f9;border-radius:999px;padding:3px;">
        <div style="width:{pct:.2f}%;background:{col};padding:6px 10px;border-radius:999px;color:white;font-weight:700;text-align:right;">
        {pct:.1f}%
        </div>
        </div>
        </div>
        """)
    bars_html = "".join(bar_rows)
    badge_color = colors.get(label.lower(), fallback_color)

    return f"""
    <div style="display:flex;align-items:center;gap:12px;">
    <div style="font-size:16px;font-weight:700;">Prediction:</div>
    <div style="padding:6px 12px;border-radius:999px;background:{badge_color};color:white;font-weight:800;">
    {label.upper()} ({confidence:.2f})
    </div>
    </div>
    <div style="margin-top:12px;">{bars_html}</div>
    """


def predict_one(text, model_choice="Logistic Regression"):
    """Classify one piece of text and describe the outcome.

    The text is cleaned with ``clean_text``, vectorized with the loaded
    TF-IDF vectorizer and scored with the requested model. If the requested
    model is not loaded (or *model_choice* is not recognized), whichever
    model is available is used instead, logistic regression first.

    Args:
        text: Raw input text; ``None``/blank input short-circuits.
        model_choice: ``"Logistic Regression"`` or ``"LightGBM"``.

    Returns:
        A dict with the keys ``"label"`` (``str`` or ``None``),
        ``"confidence"`` (``float``), ``"html"`` (rendered result markup)
        and ``"error"`` (message string, or ``None`` on success). Failures
        are reported through ``"error"`` rather than raised.
    """
    text_clean = clean_text(text)
    if not text_clean:
        return {
            "label": "neutral",
            "confidence": 0.0,
            "html": "<i>No text provided</i>",
            "error": None
        }

    if tfidf is None:
        return {"label": None, "confidence": 0.0, "html": "", "error": "Vectorizer (tfidf) not found. Upload tfidf_vectorizer.pkl to models/."}

    # Honor the requested model when it is loaded; otherwise fall back to
    # whichever model exists.
    if model_choice == "Logistic Regression":
        requested = logreg
    elif model_choice == "LightGBM":
        requested = lgbm
    else:
        requested = None
    model = next((m for m in (requested, logreg, lgbm) if m is not None), None)
    if model is None:
        return {"label": None, "confidence": 0.0, "html": "", "error": "No model found. Upload logreg_model.pkl or lgbm_model.pkl to models/."}

    try:
        # Vectorizing sits inside the try so that a transform failure is
        # reported in the result dict like any other prediction failure.
        X = tfidf.transform([text_clean])
        probs = _predict_proba_row(model, X, allow_dense_retry=model is lgbm)
        classes = get_model_classes(model)
    except Exception as e:
        return {"label": None, "confidence": 0.0, "html": "", "error": f"Prediction error: {e}"}

    # Classes and probabilities must line up one-to-one; a mismatch (e.g. the
    # default label list paired with a model of a different size) would
    # otherwise mislabel the bars or raise an IndexError below.
    if len(classes) != len(probs):
        return {"label": None, "confidence": 0.0, "html": "", "error": f"Prediction error: model returned {len(probs)} probabilities for {len(classes)} class labels"}

    idx = int(np.argmax(probs))
    # Labels may be non-string (e.g. integer-encoded classes); normalize to
    # str so that .upper()/.lower() here and in callers cannot fail.
    label = str(classes[idx])
    confidence = float(probs[idx])

    html = _render_prediction_html(label, confidence, classes, probs)
    return {"label": label, "confidence": confidence, "html": html, "error": None}
162
+
163
+ # -------------------------
164
+ # Gradio UI
165
+ # -------------------------
166
# Custom stylesheet handed to gr.Blocks(css=...): page background, the
# "app-card" panels, title/subtitle text and button shape.
css = """
/* page background */
body { background: linear-gradient(135deg,#fdfbfb 0%,#ebf8ff 100%); }

/* central card */
.app-card {
    border-radius: 12px;
    padding: 18px;
    box-shadow: 0 10px 25px rgba(11, 20, 41, 0.06);
    background: linear-gradient(180deg, rgba(255,255,255,0.9), rgba(255,255,255,0.82));
}

/* title */
.title {
    font-weight: 800;
    font-size: 22px;
    margin-bottom: 6px;
}
.subtitle {
    color: #374151;
    margin-bottom: 12px;
}

/* button */
.gr-button {
    border-radius: 10px;
    padding: 10px 16px;
    font-weight:700;
}
"""
196
+
197
# Sample rows for gr.Examples; each row is [reply text, model name] and maps
# onto the (textbox, dropdown) inputs in that order.
examples = [
    [
        "Looking forward to our demo next week! Confirm time please.",
        "Logistic Regression",
    ],
    [
        "Not interested at this time, thanks.",
        "LightGBM",
    ],
    [
        "Can you share pricing and features?",
        "Logistic Regression",
    ],
]
202
+
203
# Gradio page: input + results in the wide left column, an "About" panel with
# the detected model files in the narrow right column.
# NOTE(review): the component nesting below is a reconstruction; confirm it
# against the rendered app.
with gr.Blocks(css=css, theme=gr.themes.Base()) as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.HTML("<div class='app-card'><div class='title'>SvaraAI — Reply Classifier</div>"
                    "<div class='subtitle'>Paste an email reply below to classify it as <b>positive</b> / <b>neutral</b> / <b>negative</b>.</div>"
                    "</div>")
            inp = gr.Textbox(lines=5, placeholder="e.g. Thanks — let's schedule a demo next Tuesday at 10am.", label="Reply text")
            model_choice = gr.Dropdown(choices=["Logistic Regression", "LightGBM"], value="Logistic Regression", label="Model (choose one)")
            with gr.Row():
                btn = gr.Button("Classify", variant="primary")
                clear = gr.Button("Clear")
            output_label = gr.Markdown(value="**Prediction:** _waiting for input_", label="Result")
            output_html = gr.HTML("<i>Probabilities will appear here</i>")
            # Hidden textbox; on success it receives the rounded confidence.
            error_box = gr.Textbox(interactive=False, visible=False)
            gr.Examples(examples=examples, inputs=[inp, model_choice], label="Try these examples")
        with gr.Column(scale=1):
            gr.HTML("<div class='app-card'><div style='font-weight:800;margin-bottom:8px'>About</div>"
                    "<div style='font-size:13px;color:#374151'>This demo uses a TF-IDF vectorizer and a saved classifier (Logistic Regression / LightGBM). "
                    "Upload your saved pickles to <code>models/</code> as described in README.md.</div></div>")
            # small quick test panel
            stats_md = gr.Markdown("**Model files detected:**<br>"
                                   f"- TF-IDF: `{tfidf_path or 'NOT FOUND'}` \n"
                                   f"- LogReg: `{logreg_path or 'NOT FOUND'}` \n"
                                   f"- LGBM: `{lgbm_path or 'NOT FOUND'}` \n")
            download_note = gr.Markdown("<small>If a model is missing upload it to <code>models/</code> or rename files appropriately.</small>")

    def run_and_format(text, model_choice):
        """Run ``predict_one`` and map its result onto the three outputs.

        Returns a 3-tuple for ``(output_label, error_box, output_html)``:
        on error, the markdown error line, an empty string and a red HTML
        message; on success, the markdown summary, the confidence rounded
        to three decimals (as text) and the rendered probability HTML.
        """
        from html import escape  # local import keeps this block self-contained

        res = predict_one(text, model_choice)
        error = res.get("error")
        if error:
            # The message can embed arbitrary exception text; escape it so
            # characters such as "<" cannot break the HTML output.
            safe_error = escape(str(error))
            return f"**Error:** {error}", "", gr.update(value=f"<div style='color:#b91c1c;font-weight:700'>{safe_error}</div>")
        # str() guards against non-string class labels (e.g. integers),
        # which would make .upper() raise AttributeError.
        label = str(res["label"])
        conf = res["confidence"]
        md = f"**Prediction:** **{label.upper()}** — confidence **{conf:.2f}**"
        return md, str(round(conf, 3)), gr.update(value=res["html"])

    btn.click(run_and_format, inputs=[inp, model_choice], outputs=[output_label, error_box, output_html])
    # "Clear" restores the three outputs to their initial placeholder state.
    clear.click(lambda: ("**Prediction:** _waiting for input_", "", gr.update(value="<i>Probabilities will appear here</i>")), [], [output_label, error_box, output_html])

    # footer
    gr.HTML("<div style='margin-top:18px;color:#6b7280;font-size:13px'>Built for the SvaraAI assignment • Upload your model pickles into <code>models/</code></div>")
244
+
245
if __name__ == "__main__":
    # Listen on all interfaces; the port is taken from $PORT and defaults
    # to 7860 when the variable is unset.
    listen_port = int(os.environ.get("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)