# Spam Message Classifier — Streamlit demo comparing four text-classification models.
import json
import re
import time

import dill
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

from models.spam_model import SpamNaiveBayes
# Run all model inference on CPU (no GPU assumed for this deployment).
device = torch.device("cpu")
# --- text preprocessing helpers for the CNN / BiLSTM inputs ---

def tokenize(text):
    """Lowercase `text` and split it into word tokens plus !, ? and . marks."""
    return re.findall(r"\w+|[!?.]", str(text).lower())


def encode_text(text, vocab, max_len=40):
    """Turn a message into a fixed-length list of `max_len` vocab ids.

    Tokens missing from `vocab` map to id 1 (presumably the OOV slot —
    confirm against how vocab.json was built); the sequence is truncated
    to `max_len` and right-padded with id 0.
    """
    token_ids = [vocab.get(tok, 1) for tok in tokenize(text)[:max_len]]
    padding = [0] * (max_len - len(token_ids))
    return token_ids + padding
# Load the vocab and all four models.  Decorated with st.cache_resource so
# Streamlit loads them once per process — the original comment promised
# caching ("cache so it doesn't reload everytime") but no decorator was
# applied, so every script rerun reloaded everything from disk.
@st.cache_resource
def load_models():
    """Return (vocab, nb_model, cnn_model, lstm_model, bert_tokenizer, bert_model).

    All artifacts are read from ./models and kept on CPU.
    """
    # Shared token -> id mapping used by encode_text for the CNN and BiLSTM.
    with open('./models/vocab.json', 'r') as f:
        vocab = json.load(f)
    # Naive Bayes is a plain Python object serialized with dill.
    # NOTE(review): dill.load executes arbitrary code on load — acceptable
    # only because this pickle ships with the app and is never user-supplied.
    with open('./models/model_nb.pkl', 'rb') as f:
        nb_model = dill.load(f)
    # CNN and BiLSTM were exported as TorchScript; force inference mode so
    # dropout/batch-norm behave deterministically regardless of export state.
    cnn_model = torch.jit.load("./models/model_cnn.pt", map_location=device)
    cnn_model.eval()
    lstm_model = torch.jit.load("./models/model_bilstm.pt", map_location=device)
    lstm_model.eval()
    # Fine-tuned DistilBERT checkpoint in Hugging Face directory format.
    bert_tokenizer = AutoTokenizer.from_pretrained("./models/DistilBert")
    bert_model = AutoModelForSequenceClassification.from_pretrained("./models/DistilBert")
    return vocab, nb_model, cnn_model, lstm_model, bert_tokenizer, bert_model
# Load every artifact once at startup; if anything is missing or fails to
# deserialize, surface the error in the UI and halt the script.
try:
    vocab, nb_model, cnn, lstm, bert_tok, bert = load_models()
except Exception as e:
    st.error(f"Failed to load models. Error: {e}")
    st.stop()
else:
    st.toast("System Ready!", icon="✅")
## Streamlit Logic
st.title("Spam Message Classifier")
st.markdown("Compare 4 different AI architectures on the same message.")

# Input textbox, pre-filled with a typical spam example.
text = st.text_area(
    "Enter Message:",
    "Congratulations! You've won a $1000 Walmart gift card. Click here to claim.",
)

# Sidebar: project blurb and a link to the training dataset.
sidebar = st.sidebar
sidebar.header("About Project")
sidebar.write("The goal of this project is to compare Traditional Machine Learning vs. Deep Learning models for text classification.")
sidebar.divider()
sidebar.link_button("Dataset", "https://huggingface.co/datasets/mshenoda/spam-messages")
def _classify_and_show(column, model_name, predict_fn):
    """Time `predict_fn`, map its 0/1 output to HAM/SPAM and render a metric.

    `predict_fn` is a no-arg callable returning the predicted class id.
    Label 1 is treated as SPAM — assumed consistent across all four
    models' training label encoding; TODO confirm.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time()
    # (a wall clock that can jump), so it is the right clock for latency.
    start = time.perf_counter()
    pred = predict_fn()
    elapsed_ms = (time.perf_counter() - start) * 1000
    label = "SPAM" if pred == 1 else "HAM"
    column.metric(model_name, label, f"{elapsed_ms:.1f} ms")


if st.button("Analyze Message", type="primary"):
    col1, col2 = st.columns(2)
    col3, col4 = st.columns(2)

    # Naive Bayes predicts straight from the raw string.
    _classify_and_show(col1, "Naive Bayes", lambda: nb_model.predict(text))

    # CNN and BiLSTM share the same fixed-length id encoding (built once).
    input_ids = torch.tensor([encode_text(text, vocab)]).to(device)

    def _cnn_predict():
        with torch.no_grad():
            return cnn(input_ids).argmax(1).item()

    def _lstm_predict():
        with torch.no_grad():
            return lstm(input_ids).argmax(1).item()

    _classify_and_show(col2, "CNN", _cnn_predict)
    _classify_and_show(col3, "BiLSTM", _lstm_predict)

    # DistilBERT: its own tokenizer runs inside the timed callable, so the
    # reported latency includes tokenization (matching the original code).
    def _bert_predict():
        inputs = bert_tok(text, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            return bert(**inputs).logits.argmax().item()

    _classify_and_show(col4, "DistilBERT", _bert_predict)
# Collapsible explainer summarizing each of the four architectures.
_MODEL_DETAILS = """
* **Naive Bayes:** A traditional **Machine Learning** model that uses probability statistics (Bayes' Theorem) to predict spam based on simple word counts.
* **CNN:** A **Deep Learning** model that uses sliding "filters" to detect specific patterns of words (like "free prize"), similar to how it detects edges in images.
* **BiLSTM:** A **Recurrent Neural Network (RNN)** that reads the message forwards and backwards simultaneously to understand the context and sequence of words.
* **DistilBERT:** A **Transformer** model that uses "Self-Attention" to understand the complex meaning and relationship between every word in the sentence.
"""

with st.expander("View Model Details"):
    st.markdown(_MODEL_DETAILS)