# Source: Hugging Face Space upload by kimlay1 ("Upload 15 files", commit 4b30156)
import streamlit as st
import torch
import dill
import json
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import time
from models.spam_model import SpamNaiveBayes
# Run all inference on CPU — the deployment environment is assumed to have no GPU.
device = torch.device("cpu")
def tokenize(text):
    """Split *text* into lowercase word tokens and sentence punctuation (!, ?, .)."""
    lowered = str(text).lower()  # str() guards against non-string input (e.g. NaN)
    return re.findall(r"\w+|[!?.]", lowered)
def encode_text(text, vocab, max_len=40):
    """Encode *text* as a fixed-length list of vocabulary ids.

    Tokens are looked up in *vocab* (1 presumably = <UNK> — confirm against
    training code), truncated to *max_len*, then right-padded with 0 (= <PAD>).
    """
    tokens = re.findall(r"\w+|[!?.]", str(text).lower())  # inlined tokenize()
    ids = [vocab.get(tok, 1) for tok in tokens[:max_len]]
    padding = [0] * (max_len - len(ids))
    return ids + padding
# Load every classifier exactly once; Streamlit caches the result across reruns.
@st.cache_resource
def load_models():
    """Load the shared vocab and all four spam classifiers from ./models.

    Returns a 6-tuple:
    (vocab, nb_model, cnn_model, lstm_model, bert_tokenizer, bert_model).
    """
    # Token -> id vocabulary shared by the CNN and BiLSTM encoders.
    with open('./models/vocab.json', 'r') as fh:
        vocabulary = json.load(fh)
    # Naive Bayes classifier, serialized with dill rather than plain pickle.
    with open('./models/model_nb.pkl', 'rb') as fh:
        naive_bayes = dill.load(fh)
    # TorchScript archives for the two neural nets, mapped onto the CPU.
    cnn_net = torch.jit.load("./models/model_cnn.pt", map_location=device)
    bilstm_net = torch.jit.load("./models/model_bilstm.pt", map_location=device)
    # Fine-tuned DistilBERT checkpoint plus its matching subword tokenizer.
    distilbert_tok = AutoTokenizer.from_pretrained("./models/DistilBert")
    distilbert = AutoModelForSequenceClassification.from_pretrained("./models/DistilBert")
    return vocabulary, naive_bayes, cnn_net, bilstm_net, distilbert_tok, distilbert
# Load everything up front; if any artifact is missing or corrupt, show the
# error and halt the app instead of failing later mid-interaction.
try:
    vocab, nb_model, cnn, lstm, bert_tok, bert = load_models()
    st.toast("System Ready!", icon="✅")
except Exception as e:
    # Broad catch is deliberate: this is the app's top-level load boundary.
    st.error(f"Failed to load models. Error: {e}")
    st.stop()
## Streamlit Logic
# Page header and description.
st.title("Spam Message Classifier")
st.markdown("Compare 4 different AI architectures on the same message.")
# Input textbox, pre-filled with a typical spam example for quick demoing.
text = st.text_area("Enter Message:", "Congratulations! You've won a $1000 Walmart gift card. Click here to claim.")
# Sidebar: project blurb and a link to the training dataset.
with st.sidebar:
    st.header("About Project")
    st.write("The goal of this project is to compare Traditional Machine Learning vs. Deep Learning models for text classification.")
    st.divider()
    st.link_button("Dataset", "https://huggingface.co/datasets/mshenoda/spam-messages")
# Run all four classifiers on the current message and show label + latency.
if st.button("Analyze Message", type="primary"):
    col1, col2 = st.columns(2)
    col3, col4 = st.columns(2)

    def show_result(col, model_name, prediction, started, finished):
        # Class index 1 is treated as spam — presumably matches the training
        # label encoding; confirm against the training scripts.
        verdict = "SPAM" if prediction == 1 else "HAM"
        col.metric(model_name, verdict, f"{(finished - started)*1000:.1f} ms")

    # Naive Bayes: predicts straight from the raw string.
    t0 = time.time()
    nb_pred = nb_model.predict(text)
    show_result(col1, "Naive Bayes", nb_pred, t0, time.time())

    # CNN and BiLSTM share the same fixed-length id encoding (batch of 1).
    encoded = torch.tensor([encode_text(text, vocab)]).to(device)

    t0 = time.time()
    with torch.no_grad():
        cnn_pred = cnn(encoded).argmax(1).item()
    show_result(col2, "CNN", cnn_pred, t0, time.time())

    t0 = time.time()
    with torch.no_grad():
        lstm_pred = lstm(encoded).argmax(1).item()
    show_result(col3, "BiLSTM", lstm_pred, t0, time.time())

    # DistilBERT: timing deliberately includes its own subword tokenization.
    t0 = time.time()
    bert_inputs = bert_tok(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        bert_logits = bert(**bert_inputs).logits
    bert_pred = bert_logits.argmax().item()
    show_result(col4, "DistilBERT", bert_pred, t0, time.time())
# Collapsible reference card summarizing each of the four architectures.
with st.expander("View Model Details"):
    st.markdown("""
* **Naive Bayes:** A traditional **Machine Learning** model that uses probability statistics (Bayes' Theorem) to predict spam based on simple word counts.
* **CNN:** A **Deep Learning** model that uses sliding "filters" to detect specific patterns of words (like "free prize"), similar to how it detects edges in images.
* **BiLSTM:** A **Recurrent Neural Network (RNN)** that reads the message forwards and backwards simultaneously to understand the context and sequence of words.
* **DistilBERT:** A **Transformer** model that uses "Self-Attention" to understand the complex meaning and relationship between every word in the sentence.
""")