Update app.py
Browse files
app.py
CHANGED
|
@@ -37,7 +37,6 @@ stop_english = set(stopwords.words("english"))
|
|
| 37 |
# -----------------------------
|
| 38 |
st.write("Account Disruption")
|
| 39 |
st.write("""Dear Customer Support Team,
|
| 40 |
-
|
| 41 |
I am writing to report a significant problem with the centralized account management portal...
|
| 42 |
""")
|
| 43 |
|
|
@@ -58,7 +57,15 @@ with col2:
|
|
| 58 |
model_path = "model.h5"
|
| 59 |
model = load_model(model_path, compile=False) # <- works on HF
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
|
|
|
|
|
|
|
|
|
| 62 |
# -----------------------------
|
| 63 |
# Load Tokenizer
|
| 64 |
# -----------------------------
|
|
@@ -114,3 +121,10 @@ if st.button("Submit"):
|
|
| 114 |
|
| 115 |
preds = model.predict(seq)
|
| 116 |
st.write("Model Output:", preds)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
# -----------------------------
|
| 38 |
st.write("Account Disruption")
|
| 39 |
st.write("""Dear Customer Support Team,
|
|
|
|
| 40 |
I am writing to report a significant problem with the centralized account management portal...
|
| 41 |
""")
|
| 42 |
|
|
|
|
| 57 |
model_path = "model.h5"
|
| 58 |
model = load_model(model_path, compile=False) # <- works on HF
|
| 59 |
|
| 60 |
+
with open("le_type.pkl", "rb") as f:
|
| 61 |
+
le_type = pickle.load(f)
|
| 62 |
+
|
| 63 |
+
with open("le_queue.pkl", "rb") as f:
|
| 64 |
+
le_queue = pickle.load(f)
|
| 65 |
|
| 66 |
+
with open("mlb.pkl", "rb") as f:
|
| 67 |
+
mlb = pickle.load(f)
|
| 68 |
+
|
| 69 |
# -----------------------------
|
| 70 |
# Load Tokenizer
|
| 71 |
# -----------------------------
|
|
|
|
| 121 |
|
| 122 |
preds = model.predict(seq)
|
| 123 |
st.write("Model Output:", preds)
|
| 124 |
+
pred_type_probs, pred_queue_probs, pred_tags_probs = preds
|
| 125 |
+
pred_type_labels = le_type.inverse_transform(np.argmax(pred_type_probs, axis=1))
|
| 126 |
+
pred_queue_labels = le_queue.inverse_transform(np.argmax(pred_queue_probs, axis=1))
|
| 127 |
+
pred_tags_binary = (pred_tags_probs >= 0.5).astype(int)
|
| 128 |
+
pred_tags_lists = mlb.inverse_transform(pred_tags_binary)
|
| 129 |
+
|
| 130 |
+
st.write(pred_tags_lists)
|