"""Streamlit app: compare four spam classifiers (Naive Bayes, CNN, BiLSTM,
DistilBERT) on the same input message, showing each model's label and latency."""

import json
import re
import time

import dill
import streamlit as st
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Imported for its side effect: dill needs SpamNaiveBayes importable to unpickle
# the Naive Bayes model.
from models.spam_model import SpamNaiveBayes

# All inference runs on CPU.
device = torch.device("cpu")


def tokenize(text):
    """Lower-case *text* and split it into word tokens plus !, ? and . marks."""
    return re.findall(r"\w+|[!?.]", str(text).lower())


def encode_text(text, vocab, max_len=40):
    """Encode *text* as a fixed-length list of vocab ids.

    Unknown tokens map to id 1 and padding uses id 0; inputs longer than
    *max_len* tokens are truncated. Always returns exactly *max_len* ids.
    """
    toks = tokenize(text)
    ids = [vocab.get(t, 1) for t in toks[:max_len]]
    return ids + [0] * (max_len - len(ids))


@st.cache_resource  # cache so the models load once per server process, not per rerun
def load_models():
    """Load the vocab, all four models, and the DistilBERT tokenizer from disk."""
    # Vocabulary shared by the CNN and BiLSTM encoders.
    with open('./models/vocab.json', 'r') as f:
        vocab = json.load(f)
    # Naive Bayes, pickled with dill (relies on the SpamNaiveBayes import above).
    with open('./models/model_nb.pkl', 'rb') as f:
        nb_model = dill.load(f)
    # TorchScript CNN and BiLSTM.
    cnn_model = torch.jit.load("./models/model_cnn.pt", map_location=device)
    lstm_model = torch.jit.load("./models/model_bilstm.pt", map_location=device)
    # Fine-tuned DistilBERT checkpoint (tokenizer + classification head).
    bert_tokenizer = AutoTokenizer.from_pretrained("./models/DistilBert")
    bert_model = AutoModelForSequenceClassification.from_pretrained("./models/DistilBert")
    return vocab, nb_model, cnn_model, lstm_model, bert_tokenizer, bert_model


def _label(pred):
    """Map a 0/1 prediction to its display label (1 == spam)."""
    return "SPAM" if pred == 1 else "HAM"


# Load everything up front; abort the page render if anything is missing.
try:
    vocab, nb_model, cnn, lstm, bert_tok, bert = load_models()
    st.toast("System Ready!", icon="✅")
except Exception as e:
    st.error(f"Failed to load models. Error: {e}")
    st.stop()

## Streamlit Logic
st.title("Spam Message Classifier")
st.markdown("Compare 4 different AI architectures on the same message.")

# Text box, pre-filled with a classic spam example.
text = st.text_area(
    "Enter Message:",
    "Congratulations! You've won a $1000 Walmart gift card. Click here to claim.",
)

# Sidebar: project blurb and dataset link.
with st.sidebar:
    st.header("About Project")
    st.write(
        "The goal of this project is to compare Traditional Machine Learning vs. "
        "Deep Learning models for text classification."
    )
    st.divider()
    st.link_button("Dataset", "https://huggingface.co/datasets/mshenoda/spam-messages")

if st.button("Analyze Message", type="primary"):
    col1, col2 = st.columns(2)
    col3, col4 = st.columns(2)

    # Naive Bayes — predict() is assumed to return 0/1; verify against the
    # SpamNaiveBayes implementation.
    start = time.time()
    nb_res = nb_model.predict(text)
    elapsed = time.time() - start
    col1.metric("Naive Bayes", _label(nb_res), f"{elapsed * 1000:.1f} ms")

    # Shared encoded input (batch of one) for the CNN and BiLSTM.
    input_ids = torch.tensor([encode_text(text, vocab)]).to(device)

    # CNN
    start = time.time()
    with torch.no_grad():
        cnn_res = cnn(input_ids).argmax(1).item()
    elapsed = time.time() - start
    col2.metric("CNN", _label(cnn_res), f"{elapsed * 1000:.1f} ms")

    # BiLSTM
    start = time.time()
    with torch.no_grad():
        lstm_res = lstm(input_ids).argmax(1).item()
    elapsed = time.time() - start
    col3.metric("BiLSTM", _label(lstm_res), f"{elapsed * 1000:.1f} ms")

    # DistilBERT — NOTE: this timing includes tokenization, unlike the others.
    start = time.time()
    inputs = bert_tok(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = bert(**inputs).logits
    bert_res = logits.argmax().item()
    elapsed = time.time() - start
    col4.metric("DistilBERT", _label(bert_res), f"{elapsed * 1000:.1f} ms")

with st.expander("View Model Details"):
    st.markdown("""
    * **Naive Bayes:** A traditional **Machine Learning** model that uses probability statistics (Bayes' Theorem) to predict spam based on simple word counts.
    * **CNN:** A **Deep Learning** model that uses sliding "filters" to detect specific patterns of words (like "free prize"), similar to how it detects edges in images.
    * **BiLSTM:** A **Recurrent Neural Network (RNN)** that reads the message forwards and backwards simultaneously to understand the context and sequence of words.
    * **DistilBERT:** A **Transformer** model that uses "Self-Attention" to understand the complex meaning and relationship between every word in the sentence.
    """)