# Streamlit app (Hugging Face Space): email spam classification demo.
| import streamlit as st | |
| import torch | |
| import re | |
| import torch.nn as nn | |
| import json | |
| from tokenizers.models import WordLevel | |
| from tokenizers.pre_tokenizers import Whitespace | |
| from tokenizers import Tokenizer | |
| import torch.nn.functional as F | |
st.title("Email Classification")

# Model hyperparameters — these must match the values the checkpoint
# "email_classfication.pth" was trained with, or load_state_dict will fail.
vocab_size = 473      # vocabulary size — presumably len(vocab.json); verify
embedding_dim = 25    # size of each token embedding vector
hidden_units = 25     # RNN hidden-state size
num_classes = 2       # binary output: spam vs. not spam
max_len = 20          # fixed input sequence length in tokens
class RNNModel(nn.Module):
    """Single-layer RNN text classifier: embedding -> RNN -> linear -> softmax.

    Args:
        vocab_size: number of distinct token ids in the vocabulary.
        embedding_dim: size of each token embedding vector.
        hidden_units: hidden-state size of the RNN.
        num_classes: number of output classes (2 here: spam / not spam).
    """

    def __init__(self, vocab_size, embedding_dim, hidden_units, num_classes):
        super(RNNModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # NOTE: the original passed dropout=0.2, but nn.RNN applies dropout
        # only *between* stacked layers, so with the default num_layers=1 it
        # was a no-op that merely triggered a UserWarning. Removed; dropout
        # has no parameters, so the state_dict layout is unchanged and the
        # existing checkpoint still loads.
        self.rnn = nn.RNN(embedding_dim, hidden_units, batch_first=True)
        self.fc = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Classify a batch of token-id sequences.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len).

        Returns:
            FloatTensor of class probabilities, shape (batch, num_classes),
            each row summing to 1.
        """
        x = self.embedding(x)        # (batch, seq_len, embedding_dim)
        output, _ = self.rnn(x)      # (batch, seq_len, hidden_units)
        x = output[:, -1, :]         # hidden state at the last time step
        x = self.fc(x)               # (batch, num_classes) logits
        return F.softmax(x, dim=1)   # probabilities per class
# Instantiate the classifier with the hyperparameters above.
model = RNNModel(vocab_size, embedding_dim, hidden_units, num_classes)
## load the weights
# NOTE(review): the filename spelling "classfication" matches the checkpoint
# shipped with the app — do not "fix" the string. map_location forces CPU so
# the Space runs without a GPU.
model.load_state_dict(torch.load( "email_classfication.pth", map_location=torch.device("cpu")))
model.eval()  # inference mode (disables training-only behavior)
# Rebuild the word-level tokenizer from the saved vocabulary; out-of-vocab
# words map to "<unk>", and input is split on whitespace before lookup.
with open("vocab.json", "r") as f:
    vocab = json.load(f)
tokenizer = Tokenizer(WordLevel(vocab, unk_token="<unk>"))
tokenizer.pre_tokenizer = Whitespace()
def preprocess(words):
    """Normalize a list of word strings into a fixed-length model input.

    Each word is lowercased, stripped of URLs, has non-word characters
    replaced by spaces, and has whitespace collapsed/trimmed. The cleaned
    chunks are encoded with the module-level `tokenizer`, flattened into one
    id sequence, then truncated or zero-padded to `max_len` (20).

    Args:
        words: list of whitespace-split strings from the raw email text.

    Returns:
        LongTensor of shape (1, max_len) — a batch of one example.
    """
    normalized = []
    for word in words:
        word = word.lower()
        # Remove URLs (http(s)://... or www....). Raw strings fix the
        # invalid-escape-sequence warnings the original '\S'/'\n' patterns
        # raise on modern Python.
        word = re.sub(r'https?://\S+|www\.\S+', '', word)
        # Replace non-word characters (punctuation, newlines) with spaces,
        # then collapse runs of whitespace and trim the ends — equivalent to
        # the original chain of four re.sub calls.
        word = re.sub(r'\W', ' ', word)
        word = ' '.join(word.split())
        normalized.append(word)
    # Encode each cleaned chunk (already lowercase) and flatten the per-chunk
    # id lists into a single token-id sequence.
    encoded = [tokenizer.encode(chunk).ids for chunk in normalized]
    flat = [token for ids in encoded for token in ids]
    # Truncate or zero-pad to the fixed model input length.
    if len(flat) > max_len:
        flat = flat[:max_len]
    else:
        flat += [0] * (max_len - len(flat))
    # unsqueeze adds the batch dimension: (max_len,) -> (1, max_len).
    return torch.tensor(flat, dtype=torch.long).unsqueeze(0)
# Input box pre-filled with a benign example email.
text = st.text_input(
    "Enter the Email Text ",
    value="Happy holidays from our team! Wishing you joy and prosperity this season.",
)

if st.button("submit"):
    # Whitespace-split the raw text, build the (1, max_len) id tensor,
    # and run it through the classifier.
    input_tensor = preprocess(text.split())
    probs = model(input_tensor)
    # Class index 1 = spam, 0 = not spam.
    verdict = "Its a Spam Mail" if probs.argmax() == 1 else "Its not a Spam Mail"
    st.write(verdict)