kimlay1 committed on
Commit
4b30156
·
verified ·
1 Parent(s): ac6799a

Upload 15 files

Browse files
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import dill
4
+ import json
5
+ import re
6
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
+ import time
8
+ from models.spam_model import SpamNaiveBayes
9
+
10
+
11
+ # use cpu
12
+ device = torch.device("cpu")
13
+
14
+ # for tokenizing text
15
+ def tokenize(text):
16
+ return re.findall(r"\w+|[!?.]", str(text).lower())
17
+
18
+ # for encoding text
19
+ def encode_text(text, vocab, max_len=40):
20
+ toks = tokenize(text)
21
+ ids = [vocab.get(t, 1) for t in toks[:max_len]]
22
+ return ids + [0]*(max_len - len(ids))
23
+
24
+ # function for loading all 4 models
25
+ @st.cache_resource # cache so it doesn't reload everytime
26
+ def load_models():
27
+ # load vocab for CNN and BiLSTM
28
+ with open('./models/vocab.json', 'r') as f:
29
+ vocab = json.load(f)
30
+
31
+ # load Naive Bayes
32
+ with open('./models/model_nb.pkl', 'rb') as f:
33
+ nb_model = dill.load(f)
34
+
35
+ # load CNN and BiLSTM
36
+ cnn_model = torch.jit.load("./models/model_cnn.pt", map_location=device)
37
+ lstm_model = torch.jit.load("./models/model_bilstm.pt", map_location=device)
38
+
39
+ # load distilBert
40
+ bert_tokenizer = AutoTokenizer.from_pretrained("./models/DistilBert")
41
+ bert_model = AutoModelForSequenceClassification.from_pretrained("./models/DistilBert")
42
+
43
+ return vocab, nb_model, cnn_model, lstm_model, bert_tokenizer, bert_model
44
+
45
+ # load everything
46
+ try:
47
+ vocab, nb_model, cnn, lstm, bert_tok, bert = load_models()
48
+ st.toast("System Ready!", icon="✅")
49
+ except Exception as e:
50
+ st.error(f"Failed to load models. Error: {e}")
51
+ st.stop()
52
+
53
+ ## Streamlit Logic
54
+ st.title("Spam Message Classifier")
55
+ st.markdown("Compare 4 different AI architectures on the same message.")
56
+
57
+ # textbox
58
+ text = st.text_area("Enter Message:", "Congratulations! You've won a $1000 Walmart gift card. Click here to claim.")
59
+
60
+ # sidebar
61
+ with st.sidebar:
62
+ st.header("About Project")
63
+ st.write("The goal of this project is to compare Traditional Machine Learning vs. Deep Learning models for text classification.")
64
+ st.divider()
65
+ st.link_button("Dataset", "https://huggingface.co/datasets/mshenoda/spam-messages")
66
+
67
+ if st.button("Analyze Message", type="primary"):
68
+ col1, col2 = st.columns(2)
69
+ col3, col4 = st.columns(2)
70
+
71
+ # Naive Bayes
72
+ start = time.time()
73
+ nb_res = nb_model.predict(text)
74
+ end = time.time()
75
+ lbl = "SPAM" if nb_res == 1 else "HAM"
76
+ col1.metric("Naive Bayes", lbl, f"{(end-start)*1000:.1f} ms")
77
+
78
+ # Prepare for CNN and LSTM
79
+ input_ids = torch.tensor([encode_text(text, vocab)]).to(device)
80
+
81
+ # CNN
82
+ start = time.time()
83
+ with torch.no_grad():
84
+ cnn_res = cnn(input_ids).argmax(1).item()
85
+ end = time.time()
86
+ lbl = "SPAM" if cnn_res == 1 else "HAM"
87
+ col2.metric("CNN", lbl, f"{(end-start)*1000:.1f} ms")
88
+
89
+ # BiLSTM
90
+ start = time.time()
91
+ with torch.no_grad():
92
+ lstm_res = lstm(input_ids).argmax(1).item()
93
+ end = time.time()
94
+ lbl = "SPAM" if lstm_res == 1 else "HAM"
95
+ col3.metric("BiLSTM", lbl, f"{(end-start)*1000:.1f} ms")
96
+
97
+ # DistilBert
98
+ start = time.time()
99
+ inputs = bert_tok(text, return_tensors="pt", padding=True, truncation=True).to(device)
100
+ with torch.no_grad():
101
+ logits = bert(**inputs).logits
102
+ bert_res = logits.argmax().item()
103
+ end = time.time()
104
+ lbl = "SPAM" if bert_res == 1 else "HAM"
105
+ col4.metric("DistilBERT", lbl, f"{(end-start)*1000:.1f} ms")
106
+
107
+ with st.expander("View Model Details"):
108
+ st.markdown("""
109
+ * **Naive Bayes:** A traditional **Machine Learning** model that uses probability statistics (Bayes' Theorem) to predict spam based on simple word counts.
110
+ * **CNN:** A **Deep Learning** model that uses sliding "filters" to detect specific patterns of words (like "free prize"), similar to how it detects edges in images.
111
+ * **BiLSTM:** A **Recurrent Neural Network (RNN)** that reads the message forwards and backwards simultaneously to understand the context and sequence of words.
112
+ * **DistilBERT:** A **Transformer** model that uses "Self-Attention" to understand the complex meaning and relationship between every word in the sentence.
113
+ """)
models/DistilBert/config.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation": "gelu",
3
+ "architectures": [
4
+ "DistilBertForSequenceClassification"
5
+ ],
6
+ "attention_dropout": 0.1,
7
+ "dim": 768,
8
+ "dropout": 0.1,
9
+ "hidden_dim": 3072,
10
+ "id2label": {
11
+ "0": "HAM",
12
+ "1": "SPAM"
13
+ },
14
+ "initializer_range": 0.02,
15
+ "label2id": {
16
+ "HAM": 0,
17
+ "SPAM": 1
18
+ },
19
+ "max_position_embeddings": 512,
20
+ "model_type": "distilbert",
21
+ "n_heads": 12,
22
+ "n_layers": 6,
23
+ "pad_token_id": 0,
24
+ "problem_type": "single_label_classification",
25
+ "qa_dropout": 0.1,
26
+ "seq_classif_dropout": 0.2,
27
+ "sinusoidal_pos_embds": false,
28
+ "tie_weights_": true,
29
+ "torch_dtype": "float32",
30
+ "transformers_version": "4.51.3",
31
+ "vocab_size": 30522
32
+ }
models/DistilBert/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:befff6c44b2d61855b90616e026f2b66d99210ae762b3bb52055b8bcfb047fba
3
+ size 267832560
models/DistilBert/special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": "[CLS]",
3
+ "mask_token": "[MASK]",
4
+ "pad_token": "[PAD]",
5
+ "sep_token": "[SEP]",
6
+ "unk_token": "[UNK]"
7
+ }
models/DistilBert/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
models/DistilBert/tokenizer_config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": false,
45
+ "cls_token": "[CLS]",
46
+ "do_lower_case": true,
47
+ "extra_special_tokens": {},
48
+ "mask_token": "[MASK]",
49
+ "model_max_length": 512,
50
+ "pad_token": "[PAD]",
51
+ "sep_token": "[SEP]",
52
+ "strip_accents": null,
53
+ "tokenize_chinese_chars": true,
54
+ "tokenizer_class": "DistilBertTokenizer",
55
+ "unk_token": "[UNK]"
56
+ }
models/DistilBert/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de5de394762643c9f549cfbd77db0ee77f88bc3c4b8567af754d79287de02681
3
+ size 5713
models/DistilBert/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
models/__pycache__/spam_model.cpython-311.pyc ADDED
Binary file (6.18 kB). View file
 
models/model_bilstm.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b72e135a2ad091eda1c9502b5f9e8bdcdf5e57b4da3bd67fb89a0ceafe0c1596
3
+ size 66132208
models/model_cnn.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaaf0fd1e07cb80ad08ebba4040533a979c2c4689c680bc4e174aeb2e4a1a2e2
3
+ size 65668371
models/model_nb.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bfcf52d24f845ec808ddfe0c2cdbcef512508ed3990ad6c900024eaaa8603cd
3
+ size 6348039
models/spam_model.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import re
3
+ import dill
4
+ from collections import Counter
5
+
6
+ class SpamNaiveBayes:
7
+ def __init__(self, alpha=1):
8
+ self.alpha = alpha
9
+ self.vocab = set()
10
+ self.log_spam = {}
11
+ self.log_ham = {}
12
+ self.P_spam = 0
13
+ self.P_ham = 0
14
+ self.unk_spam = 0
15
+ self.unk_ham = 0
16
+
17
+ def tokenize(self, text):
18
+ return re.findall(r"\w+|[!?.]", str(text).lower())
19
+
20
+ def train(self, texts, labels):
21
+ # Build Vocab
22
+ for t in texts:
23
+ self.vocab.update(self.tokenize(t))
24
+ self.vocab = sorted(self.vocab)
25
+
26
+ # Counts
27
+ wc_spam = Counter()
28
+ wc_ham = Counter()
29
+ spam_docs = sum(1 for l in labels if l == 1)
30
+ ham_docs = len(labels) - spam_docs
31
+ total_docs = len(labels)
32
+
33
+ for txt, lab in zip(texts, labels):
34
+ toks = self.tokenize(txt)
35
+ if lab == 1:
36
+ wc_spam.update(toks)
37
+ else:
38
+ wc_ham.update(toks)
39
+
40
+ # Calculate Probabilities
41
+ self.P_spam = spam_docs / total_docs
42
+ self.P_ham = ham_docs / total_docs
43
+
44
+ V = len(self.vocab)
45
+ total_spam = sum(wc_spam.values()) + self.alpha * V
46
+ total_ham = sum(wc_ham.values()) + self.alpha * V
47
+
48
+ self.log_spam = {w: math.log((wc_spam[w] + self.alpha) / total_spam) for w in self.vocab}
49
+ self.log_ham = {w: math.log((wc_ham[w] + self.alpha) / total_ham) for w in self.vocab}
50
+
51
+ self.unk_spam = math.log(self.alpha / total_spam)
52
+ self.unk_ham = math.log(self.alpha / total_ham)
53
+ print("Training Complete.")
54
+
55
+ def predict(self, text):
56
+ toks = self.tokenize(text)
57
+ s_spam = math.log(self.P_spam + 1e-12)
58
+ s_ham = math.log(self.P_ham + 1e-12)
59
+
60
+ for t in toks:
61
+ s_spam += self.log_spam.get(t, self.unk_spam)
62
+ s_ham += self.log_ham.get(t, self.unk_ham)
63
+
64
+ return 1 if s_spam > s_ham else 0
65
+
66
+ if __name__ == "__main__":
67
+ from datasets import load_dataset
68
+
69
+ print("Loading data...")
70
+ ds = load_dataset("mshenoda/spam-messages")
71
+ texts = [x['text'] for x in ds['train']]
72
+
73
+ labels = []
74
+ for x in ds['train']:
75
+ lab = x['label']
76
+ if isinstance(lab, str):
77
+ labels.append(1 if lab.lower() in ['spam', '1'] else 0)
78
+ else:
79
+ labels.append(int(lab))
80
+
81
+ print("Training clean model...")
82
+ model = SpamNaiveBayes()
83
+ model.train(texts, labels)
84
+
85
+ with open("model_nb_clean.pkl", "wb") as f:
86
+ dill.dump(model, f)
87
+ print("✅ Success! 'model_nb_clean.pkl' created. Upload this file to Hugging Face.")
models/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- altair
2
- pandas
3
- streamlit
 
 
1
+ streamlit
2
+ torch
3
+ transformers
4
+ dill