Aswin92 committed on
Commit
2d905bf
·
verified ·
1 Parent(s): 4fb6e5b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -117
app.py CHANGED
@@ -47,7 +47,16 @@ class LSTMClassifier(nn.Module):
47
  return logits
48
 
49
 
50
- # ------------------- Cached model loaders -------------------
 
 
 
 
 
 
 
 
 
51
  @st.cache_resource
52
  def load_deberta():
53
  model_name = "Aswin92/deberta-v3-disaster-tweets"
@@ -70,11 +79,8 @@ def load_distilbert():
70
 
71
  @st.cache_resource
72
  def load_bilstm():
73
- import zipfile
74
-
75
- # Reuse the DistilBERT tokenizer instead of local tokenizer.json
76
- distil_tokenizer, _ = load_distilbert()
77
- tokenizer = distil_tokenizer
78
 
79
  model = LSTMClassifier(
80
  vocab_size=tokenizer.vocab_size,
@@ -86,60 +92,20 @@ def load_bilstm():
86
  num_classes=2,
87
  )
88
 
89
- # Extract BiLSTM weights from zip file
90
- zip_filename = "saved_models.zip"
91
- state_filename = "bilstm_state_dict.pt"
92
-
93
- # Debug: Check current directory and files
94
- current_dir = os.getcwd()
95
- print(f"Current directory: {current_dir}")
96
- print(f"Files in current directory: {os.listdir(current_dir)}")
97
-
98
- # Try different possible paths
99
- possible_paths = [
100
- zip_filename,
101
- f"/app/{zip_filename}",
102
- os.path.join(current_dir, zip_filename),
103
- ]
104
 
105
- zip_path_found = None
106
- for path in possible_paths:
107
- print(f"Checking path: {path} - Exists: {os.path.exists(path)}")
108
- if os.path.exists(path):
109
- zip_path_found = path
110
- break
111
-
112
- if zip_path_found is None:
113
  files_in_dir = os.listdir(current_dir)
114
  raise FileNotFoundError(
115
- f"Zip file '{zip_filename}' not found in any expected location.\n"
116
  f"Current directory: {current_dir}\n"
117
- f"Files in directory: {files_in_dir}\n"
118
- f"Paths checked: {possible_paths}"
119
  )
120
 
121
- print(f"Found zip at: {zip_path_found}")
122
-
123
- with zipfile.ZipFile(zip_path_found, 'r') as zip_ref:
124
- # Check if file exists in zip
125
- zip_contents = zip_ref.namelist()
126
- print(f"Contents of zip: {zip_contents}")
127
-
128
- if state_filename not in zip_contents:
129
- raise FileNotFoundError(
130
- f"'{state_filename}' not found in {zip_path_found}. "
131
- f"Available files: {zip_contents}"
132
- )
133
-
134
- # Extract to temporary location and load
135
- zip_ref.extract(state_filename, path="/tmp")
136
- temp_state_path = f"/tmp/{state_filename}"
137
-
138
- state_dict = torch.load(temp_state_path, map_location=DEVICE, weights_only=True)
139
-
140
- # Clean up temporary file
141
- os.remove(temp_state_path)
142
-
143
  model.load_state_dict(state_dict)
144
  model.to(DEVICE)
145
  model.eval()
@@ -214,13 +180,21 @@ st.set_page_config(page_title="Disaster Tweet Classifier", layout="centered")
214
  st.title("🌪️ Disaster Tweet Classifier")
215
  st.write(
216
  "NLP project on the Kaggle **Disaster Tweets** dataset.\n\n"
217
- "Type a tweet once and compare how **DeBERTa-v3**, **DistilBERT**, and a custom **BiLSTM (RNN)** "
218
- "decide whether it describes a real disaster."
219
  )
220
 
221
  # -------- Sidebar controls --------
222
  with st.sidebar:
223
- st.header("⚙️ Thresholds per model")
 
 
 
 
 
 
 
 
224
 
225
  thr_deberta = st.slider(
226
  "DeBERTa-v3 threshold",
@@ -228,6 +202,7 @@ with st.sidebar:
228
  max_value=0.95,
229
  value=0.60,
230
  step=0.05,
 
231
  )
232
  thr_distil = st.slider(
233
  "DistilBERT threshold",
@@ -235,6 +210,7 @@ with st.sidebar:
235
  max_value=0.95,
236
  value=0.80,
237
  step=0.05,
 
238
  )
239
  thr_bilstm = st.slider(
240
  "BiLSTM (RNN) threshold",
@@ -242,11 +218,12 @@ with st.sidebar:
242
  max_value=0.95,
243
  value=0.35,
244
  step=0.05,
 
245
  )
246
 
247
  st.caption(
248
  "Each model predicts `P(disaster)`. If that probability is "
249
- "≥ its threshold, we classify it as **disaster (1)**."
250
  )
251
 
252
  # -------- Main input area --------
@@ -262,70 +239,77 @@ tweet_text = st.text_area(
262
  height=120,
263
  )
264
 
265
- if st.button("Classify with all models"):
266
  text = tweet_text.strip()
267
  if not text:
268
  st.warning("Please type a tweet first.")
269
  else:
270
- try:
271
- with st.spinner("Running all three models..."):
272
- configs = [
273
- ("DeBERTa-v3", thr_deberta),
274
- ("DistilBERT", thr_distil),
275
- ("BiLSTM (RNN)", thr_bilstm),
276
- ]
277
-
278
- rows = []
279
- for name, thr in configs:
280
- pred_label, prob_not, prob_dis = predict_text(text, name, thr)
281
- rows.append(
282
- {
283
- "Model": name,
284
- "Threshold": thr,
285
- "P_not_disaster": prob_not,
286
- "P_disaster": prob_dis,
287
- "Predicted_label": pred_label,
288
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
  )
290
-
291
- # ---- Table view ----
292
- st.subheader("📋 Model outputs")
293
- df = pd.DataFrame(rows)
294
- # Nice formatting for display
295
- df_display = df.copy()
296
- df_display["P_not_disaster"] = df_display["P_not_disaster"].map(lambda x: f"{x:.3f}")
297
- df_display["P_disaster"] = df_display["P_disaster"].map(lambda x: f"{x:.3f}")
298
- df_display["Threshold"] = df_display["Threshold"].map(lambda x: f"{x:.2f}")
299
- st.dataframe(df_display, use_container_width=True)
300
-
301
- # ---- Interactive bar chart comparing P(disaster) ----
302
- st.subheader("📊 P(disaster) comparison")
303
- chart_df = df[["Model", "P_disaster"]].set_index("Model")
304
- st.bar_chart(chart_df)
305
-
306
- # ---- Per-model summary text ----
307
- st.subheader("🔎 Per-model decisions")
308
- for row in rows:
309
- name = row["Model"]
310
- thr = row["Threshold"]
311
- p_dis = row["P_disaster"]
312
- p_not = row["P_not_disaster"]
313
- label = row["Predicted_label"]
314
-
315
- st.markdown(f"**{name}**")
316
- st.write(
317
- f"- P(disaster = 1): `{p_dis:.3f}`\n"
318
- f"- P(not disaster = 0): `{p_not:.3f}`\n"
319
- f"- Threshold: `{thr:.2f}` → prediction = `{label}`"
320
- )
321
- st.markdown("---")
322
-
323
- except FileNotFoundError as e:
324
- st.error(f"❌ {str(e)}")
325
- st.info("Please upload `bilstm_state_dict.pt` to the root of your Space repository.")
326
- except Exception as e:
327
- st.error(f"❌ An error occurred: {str(e)}")
328
- st.exception(e)
329
 
330
  st.markdown("---")
331
  st.caption(
 
47
  return logits
48
 
49
 
50
+ # ------------------- Shared tokenizer for BiLSTM -------------------
51
+ @st.cache_resource
52
+ def load_shared_tokenizer():
53
+ """Load tokenizer once for BiLSTM (uses DistilBERT tokenizer)"""
54
+ model_name = "Aswin92/distilbert-disaster-tweets"
55
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
56
+ return tokenizer
57
+
58
+
59
+ # ------------------- Individual model loaders -------------------
60
  @st.cache_resource
61
  def load_deberta():
62
  model_name = "Aswin92/deberta-v3-disaster-tweets"
 
79
 
80
  @st.cache_resource
81
  def load_bilstm():
82
+ # Use shared tokenizer instead of loading DistilBERT model
83
+ tokenizer = load_shared_tokenizer()
 
 
 
84
 
85
  model = LSTMClassifier(
86
  vocab_size=tokenizer.vocab_size,
 
92
  num_classes=2,
93
  )
94
 
95
+ # Load BiLSTM weights directly
96
+ state_path = "bilstm_state_dict.pt"
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
+ if not os.path.exists(state_path):
99
+ current_dir = os.getcwd()
 
 
 
 
 
 
100
  files_in_dir = os.listdir(current_dir)
101
  raise FileNotFoundError(
102
+ f"BiLSTM weights file '{state_path}' not found.\n"
103
  f"Current directory: {current_dir}\n"
104
+ f"Files available: {files_in_dir}\n"
105
+ f"Please upload 'bilstm_state_dict.pt' directly to your Space root (not in a zip)."
106
  )
107
 
108
+ state_dict = torch.load(state_path, map_location=DEVICE, weights_only=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  model.load_state_dict(state_dict)
110
  model.to(DEVICE)
111
  model.eval()
 
180
  st.title("🌪️ Disaster Tweet Classifier")
181
  st.write(
182
  "NLP project on the Kaggle **Disaster Tweets** dataset.\n\n"
183
+ "Compare **DeBERTa-v3**, **DistilBERT**, and a custom **BiLSTM (RNN)** "
184
+ "to decide whether a tweet describes a real disaster."
185
  )
186
 
187
  # -------- Sidebar controls --------
188
  with st.sidebar:
189
+ st.header("⚙️ Model Selection")
190
+
191
+ # Let user select which models to run
192
+ run_deberta = st.checkbox("Run DeBERTa-v3", value=True)
193
+ run_distilbert = st.checkbox("Run DistilBERT", value=True)
194
+ run_bilstm = st.checkbox("Run BiLSTM (RNN)", value=True)
195
+
196
+ st.markdown("---")
197
+ st.header("⚙️ Thresholds")
198
 
199
  thr_deberta = st.slider(
200
  "DeBERTa-v3 threshold",
 
202
  max_value=0.95,
203
  value=0.60,
204
  step=0.05,
205
+ disabled=not run_deberta,
206
  )
207
  thr_distil = st.slider(
208
  "DistilBERT threshold",
 
210
  max_value=0.95,
211
  value=0.80,
212
  step=0.05,
213
+ disabled=not run_distilbert,
214
  )
215
  thr_bilstm = st.slider(
216
  "BiLSTM (RNN) threshold",
 
218
  max_value=0.95,
219
  value=0.35,
220
  step=0.05,
221
+ disabled=not run_bilstm,
222
  )
223
 
224
  st.caption(
225
  "Each model predicts `P(disaster)`. If that probability is "
226
+ "≥ its threshold, we classify it as **disaster**."
227
  )
228
 
229
  # -------- Main input area --------
 
239
  height=120,
240
  )
241
 
242
+ if st.button("Classify Tweet"):
243
  text = tweet_text.strip()
244
  if not text:
245
  st.warning("Please type a tweet first.")
246
  else:
247
+ # Build list of models to run based on checkboxes
248
+ configs = []
249
+ if run_deberta:
250
+ configs.append(("DeBERTa-v3", thr_deberta))
251
+ if run_distilbert:
252
+ configs.append(("DistilBERT", thr_distil))
253
+ if run_bilstm:
254
+ configs.append(("BiLSTM (RNN)", thr_bilstm))
255
+
256
+ if not configs:
257
+ st.warning("Please select at least one model to run.")
258
+ else:
259
+ try:
260
+ with st.spinner(f"Running {len(configs)} model(s)..."):
261
+ rows = []
262
+ for name, thr in configs:
263
+ pred_label, prob_not, prob_dis = predict_text(text, name, thr)
264
+ rows.append(
265
+ {
266
+ "Model": name,
267
+ "Threshold": thr,
268
+ "P_not_disaster": prob_not,
269
+ "P_disaster": prob_dis,
270
+ "Predicted_label": pred_label,
271
+ }
272
+ )
273
+
274
+ # ---- Table view ----
275
+ st.subheader("📋 Model outputs")
276
+ df = pd.DataFrame(rows)
277
+ # Nice formatting for display
278
+ df_display = df.copy()
279
+ df_display["P_not_disaster"] = df_display["P_not_disaster"].map(lambda x: f"{x:.3f}")
280
+ df_display["P_disaster"] = df_display["P_disaster"].map(lambda x: f"{x:.3f}")
281
+ df_display["Threshold"] = df_display["Threshold"].map(lambda x: f"{x:.2f}")
282
+ st.dataframe(df_display, use_container_width=True)
283
+
284
+ # ---- Interactive bar chart comparing P(disaster) ----
285
+ if len(rows) > 1:
286
+ st.subheader("📊 P(disaster) comparison")
287
+ chart_df = df[["Model", "P_disaster"]].set_index("Model")
288
+ st.bar_chart(chart_df)
289
+
290
+ # ---- Per-model summary text ----
291
+ st.subheader("🔎 Per-model decisions")
292
+ for row in rows:
293
+ name = row["Model"]
294
+ thr = row["Threshold"]
295
+ p_dis = row["P_disaster"]
296
+ p_not = row["P_not_disaster"]
297
+ label = row["Predicted_label"]
298
+
299
+ st.markdown(f"**{name}**")
300
+ st.write(
301
+ f"- P(disaster = 1): `{p_dis:.3f}`\n"
302
+ f"- P(not disaster = 0): `{p_not:.3f}`\n"
303
+ f"- Threshold: `{thr:.2f}` → prediction = `{label}`"
304
  )
305
+ st.markdown("---")
306
+
307
+ except FileNotFoundError as e:
308
+ st.error(f"❌ {str(e)}")
309
+ st.info("Please upload `bilstm_state_dict.pt` to the root of your Space repository.")
310
+ except Exception as e:
311
+ st.error(f"❌ An error occurred: {str(e)}")
312
+ st.exception(e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
313
 
314
  st.markdown("---")
315
  st.caption(