ma4389 committed on
Commit
6eddda0
·
verified ·
1 Parent(s): 4886174

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +76 -0
  2. best_bi_model.pth +3 -0
  3. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import DistilBertTokenizer
4
+ import gradio as gr
5
+ import re
6
+ import nltk
7
+ from nltk.corpus import stopwords
8
+ from nltk.tokenize import word_tokenize
9
+ from nltk.stem import WordNetLemmatizer
10
+
11
+ # Download necessary NLTK data
12
+ nltk.download('punkt_tab')
13
+ nltk.download('stopwords')
14
+ nltk.download('wordnet')
15
+
16
+ # Preprocessing
17
+ stop_words = set(stopwords.words("english"))
18
+ lemmatizer = WordNetLemmatizer()
19
+
20
+ def preprocess_text(text):
21
+ text = re.sub(r'[^A-Za-z\s]', '', text)
22
+ text = re.sub(r'https?://\S+|www\.\S+', '', text)
23
+ text = text.lower()
24
+ tokens = word_tokenize(text)
25
+ tokens = [word for word in tokens if word not in stop_words]
26
+ tokens = [lemmatizer.lemmatize(word) for word in tokens]
27
+ return ' '.join(tokens)
28
+
29
+ # Tokenizer
30
+ tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
31
+ max_len = 32
32
+ vocab_size = tokenizer.vocab_size
33
+
34
+ # Model definition
35
+ class BiLSTMClassifier(nn.Module):
36
+ def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
37
+ super(BiLSTMClassifier, self).__init__()
38
+ self.embedding = nn.Embedding(vocab_size, embed_dim)
39
+ self.bilstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
40
+ self.fc = nn.Linear(hidden_dim * 2, num_classes)
41
+
42
+ def forward(self, x):
43
+ x = self.embedding(x)
44
+ out, _ = self.bilstm(x)
45
+ out = out[:, -1, :]
46
+ return self.fc(out)
47
+
48
+ # Load model
49
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
50
+ model = BiLSTMClassifier(vocab_size, embed_dim=128, hidden_dim=64, num_classes=2)
51
+ model.load_state_dict(torch.load("best_bi_model.pth", map_location=device))
52
+ model.to(device)
53
+ model.eval()
54
+
55
+ # Inference function
56
+ def predict_spam(text):
57
+ cleaned = preprocess_text(text)
58
+ encoded = tokenizer(cleaned, truncation=True, padding='max_length', max_length=max_len, return_tensors='pt')
59
+ input_ids = encoded['input_ids'].to(device)
60
+
61
+ with torch.no_grad():
62
+ output = model(input_ids)
63
+ prediction = torch.argmax(output, dim=1).item()
64
+
65
+ return "Spam 🚫" if prediction == 1 else "Ham ✅"
66
+
67
+ # Gradio Interface
68
+ interface = gr.Interface(
69
+ fn=predict_spam,
70
+ inputs=gr.Textbox(lines=5, label="Enter Email Text"),
71
+ outputs=gr.Label(num_top_classes=2, label="Prediction"),
72
+ title="📧 Spam or Ham Classifier (BiLSTM)",
73
+ description="Enter an email message to predict whether it is Spam or Ham using a trained BiLSTM model."
74
+ )
75
+
76
+ interface.launch()
best_bi_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:56c69c29aacc53fd27d8378ac966e373204a2b770c02f5d3304b70e2ddaff1b2
3
+ size 16028416
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.38.0
3
+ gradio>=4.0.0
4
+ nltk>=3.8.1
5
+ pandas>=2.0.0
6
+ scikit-learn>=1.2.0
7
+ tqdm>=4.65.0