YAMITEK's picture
Upload 10 files
abbe13f verified
import streamlit as st
import torch
import re
import torch.nn as nn
import json
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers import Tokenizer
import torch.nn.functional as F
st.title("Email Classification")
## model hyperparameters
# These values size the layers of RNNModel; the checkpoint loaded further down
# is loaded into a model built from them, so they have to agree with the
# shapes stored in that checkpoint.
vocab_size = 473  # rows in the embedding table (presumably the size of vocab.json -- confirm)
embedding_dim = 25  # width of each token embedding
hidden_units = 25  # size of the RNN hidden state
num_classes = 2  # the UI treats class index 1 as spam, anything else as not spam
max_len = 20  # NOTE(review): not referenced in the visible code; preprocess() carries its own length of 20
class RNNModel(nn.Module):
    """Single-layer RNN classifier: embedding -> RNN -> linear -> softmax.

    Args:
        vocab_size: Number of rows in the embedding table.
        embedding_dim: Width of each token embedding.
        hidden_units: Size of the RNN hidden state.
        num_classes: Number of output classes.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # No `dropout` argument here: nn.RNN only applies dropout between
        # stacked layers, so with the default single layer a non-zero value
        # does nothing except trigger a UserWarning. Dropping it leaves the
        # parameter names (and therefore checkpoint compatibility) unchanged.
        self.rnn = nn.RNN(embedding_dim, hidden_units, batch_first=True)
        self.fc = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Return per-class probabilities for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids; indexed as (batch, sequence)
                because the RNN is built with ``batch_first=True``.

        Returns:
            Tensor of shape (batch, num_classes) whose rows sum to 1.
        """
        embedded = self.embedding(x)
        outputs, _ = self.rnn(embedded)
        # Classify from the RNN output at the final time step only.
        last_step = outputs[:, -1, :]
        logits = self.fc(last_step)
        return F.softmax(logits, dim=1)
model = RNNModel(vocab_size, embedding_dim, hidden_units, num_classes)
## load the weights
# NOTE(security): torch.load unpickles the file, so only ship a checkpoint
# from a trusted source alongside this app. The file name is kept exactly as
# written because it must match the file on disk.
model.load_state_dict(torch.load("email_classfication.pth", map_location=torch.device("cpu")))
# Inference only: switch the module out of training mode.
model.eval()

## load the vocabulary and build the word-level tokenizer
# Explicit UTF-8 so the vocabulary is decoded the same way on every platform
# instead of depending on the locale's default encoding.
with open("vocab.json", "r", encoding="utf-8") as f:
    vocab = json.load(f)
# Words missing from the vocabulary map to the "<unk>" token; input is split
# by the Whitespace pre-tokenizer before lookup.
tokenizer = Tokenizer(WordLevel(vocab, unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()
# Earlier alternative, left disabled:
# tokenizer=joblib.load("tokenizer.pkl")
# Patterns are compiled once at import time and written as raw strings so the
# regex escapes (\S, \W) are not treated as invalid string escapes.
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_NON_WORD_RE = re.compile(r"\W")
_MULTI_SPACE_RE = re.compile(r" +")


def preprocess(words, max_length: int = 20, pad_id: int = 0) -> torch.Tensor:
    """Normalize words, tokenize them, and return one fixed-length id tensor.

    Each word is lower-cased, stripped of URLs, has every non-word character
    replaced by a space, and has runs of spaces collapsed and trimmed. The
    cleaned words are encoded with the module-level ``tokenizer`` and their
    ids concatenated in order. The result is truncated to ``max_length`` ids
    or right-padded with ``pad_id``.

    Args:
        words: Iterable of strings (the caller passes ``text.split()``).
        max_length: Number of token ids in the returned sequence.
        pad_id: Id used for padding. The default of 0 is the original
            hard-coded value; presumably the padding entry of the vocabulary
            -- confirm against vocab.json.

    Returns:
        ``torch.long`` tensor of shape (1, max_length).

    Raises:
        ValueError: If ``max_length`` is less than 1.
    """
    if max_length < 1:
        raise ValueError("max_length must be at least 1")

    token_ids = []
    for word in words:
        cleaned = _URL_RE.sub("", word.lower())
        # After this substitution only word characters and spaces remain, so
        # no separate newline removal is needed.
        cleaned = _NON_WORD_RE.sub(" ", cleaned)
        cleaned = _MULTI_SPACE_RE.sub(" ", cleaned).strip(" ")
        token_ids.extend(tokenizer.encode(cleaned).ids)

    # Truncate long inputs, then right-pad short ones to the fixed length.
    token_ids = token_ids[:max_length]
    token_ids += [pad_id] * (max_length - len(token_ids))

    # Add the leading batch dimension expected by the model.
    return torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)
# Free-text input pre-filled with a sample message.
text = st.text_input(
    "Enter the Email Text ",
    value="Happy holidays from our team! Wishing you joy and prosperity this season.",
)
if st.button("submit"):
    # Split on whitespace; preprocess() cleans and tokenizes each piece.
    words = text.split()
    v = preprocess(words)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(v)
    # The batch holds a single row, so pick the winning class along dim 1.
    if output.argmax(dim=1).item() == 1:
        st.write("Its a Spam Mail")
    else:
        st.write("Its not a Spam Mail")