File size: 4,207 Bytes
4b30156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import streamlit as st
import torch
import dill
import json
import re
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import time
from models.spam_model import SpamNaiveBayes


# run all inference on CPU so the app works on machines without a GPU
device = torch.device("cpu")

# for tokenizing text
def tokenize(text):
    """Lowercase *text* and split it into word tokens plus !, ? and . punctuation."""
    lowered = str(text).lower()
    return re.findall(r"\w+|[!?.]", lowered)

# for encoding text
def encode_text(text, vocab, max_len=40):
    """Encode *text* as a fixed-length list of *max_len* token ids.

    Unknown tokens map to id 1; the sequence is truncated to *max_len*
    tokens and right-padded with id 0.
    """
    token_ids = [vocab.get(tok, 1) for tok in tokenize(text)[:max_len]]
    padding = [0] * (max_len - len(token_ids))
    return token_ids + padding

# function for loading all 4 models
@st.cache_resource  # cache so the models are loaded only once per session
def load_models():
    """Load the shared vocabulary and all four classifiers from ./models.

    Returns (vocab, naive_bayes, cnn, bilstm, bert_tokenizer, bert_model).
    """
    # vocabulary shared by the CNN and BiLSTM id encoders
    with open('./models/vocab.json', 'r') as fp:
        vocabulary = json.load(fp)

    # classic Naive Bayes model, serialized with dill
    with open('./models/model_nb.pkl', 'rb') as fp:
        naive_bayes = dill.load(fp)

    # TorchScript exports of the two neural networks, forced onto CPU
    cnn_net = torch.jit.load("./models/model_cnn.pt", map_location=device)
    bilstm_net = torch.jit.load("./models/model_bilstm.pt", map_location=device)

    # fine-tuned DistilBERT checkpoint together with its own tokenizer
    bert_tokenizer = AutoTokenizer.from_pretrained("./models/DistilBert")
    bert_model = AutoModelForSequenceClassification.from_pretrained("./models/DistilBert")

    return vocabulary, naive_bayes, cnn_net, bilstm_net, bert_tokenizer, bert_model
    
# load all models once at startup; on failure show the error and halt the app
try:
    vocab, nb_model, cnn, lstm, bert_tok, bert = load_models()
    st.toast("System Ready!", icon="✅")
except Exception as e:
    # surface the failure (e.g. missing/corrupt model files) instead of a blank page
    st.error(f"Failed to load models. Error: {e}")
    st.stop()

## Streamlit page layout
st.title("Spam Message Classifier")
st.markdown("Compare 4 different AI architectures on the same message.")

# input textbox, pre-filled with a typical spam example
text = st.text_area("Enter Message:", "Congratulations! You've won a $1000 Walmart gift card. Click here to claim.")

# sidebar: project description and dataset link
with st.sidebar:
    st.header("About Project")
    st.write("The goal of this project is to compare Traditional Machine Learning vs. Deep Learning models for text classification.")
    st.divider()
    st.link_button("Dataset", "https://huggingface.co/datasets/mshenoda/spam-messages")

if st.button("Analyze Message", type="primary"):
    # 2x2 grid of result cells, one per model
    col1, col2 = st.columns(2)
    col3, col4 = st.columns(2)

    def _report(col, model_name, pred, start):
        """Render one model's SPAM/HAM verdict and its latency since *start* in *col*."""
        lbl = "SPAM" if pred == 1 else "HAM"
        col.metric(model_name, lbl, f"{(time.time() - start)*1000:.1f} ms")

    # Naive Bayes operates directly on the raw string
    start = time.time()
    nb_res = nb_model.predict(text)
    _report(col1, "Naive Bayes", nb_res, start)

    # shared fixed-length id tensor for the CNN and BiLSTM
    input_ids = torch.tensor([encode_text(text, vocab)]).to(device)

    # CNN
    start = time.time()
    with torch.no_grad():
        cnn_res = cnn(input_ids).argmax(1).item()
    _report(col2, "CNN", cnn_res, start)

    # BiLSTM
    start = time.time()
    with torch.no_grad():
        lstm_res = lstm(input_ids).argmax(1).item()
    _report(col3, "BiLSTM", lstm_res, start)

    # DistilBERT uses its own subword tokenizer; timing includes tokenization
    start = time.time()
    inputs = bert_tok(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        bert_res = bert(**inputs).logits.argmax().item()
    _report(col4, "DistilBERT", bert_res, start)

# static reference section: a one-line explanation of each architecture
with st.expander("View Model Details"):
    st.markdown("""

    * **Naive Bayes:** A traditional **Machine Learning** model that uses probability statistics (Bayes' Theorem) to predict spam based on simple word counts.

    * **CNN:** A **Deep Learning** model that uses sliding "filters" to detect specific patterns of words (like "free prize"), similar to how it detects edges in images.

    * **BiLSTM:** A **Recurrent Neural Network (RNN)** that reads the message forwards and backwards simultaneously to understand the context and sequence of words.

    * **DistilBERT:** A **Transformer** model that uses "Self-Attention" to understand the complex meaning and relationship between every word in the sentence.

    """)