"""Streamlit app that classifies an email as spam / not spam with a small RNN."""

import json
import re

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace

st.title("Email Classification")

## model
# These must match the hyperparameters used to produce the checkpoint loaded
# below; otherwise load_state_dict fails with tensor-shape mismatches.
vocab_size = 473
embedding_dim = 25
hidden_units = 25
num_classes = 2
max_len = 20  # number of token ids fed to the model (truncated / zero-padded)

# Text-cleaning patterns used by preprocess(), compiled once. Raw strings
# avoid the invalid "\S" / "\." escape warnings of plain string literals.
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_NON_WORD_RE = re.compile(r"\W")
_MULTI_SPACE_RE = re.compile(r" +")

# Candidate unknown-token strings, in order of preference. "" comes first so
# behaviour is unchanged when the vocab contains it.
_UNK_CANDIDATES = ("", "<unk>", "[UNK]", "<UNK>")


class RNNModel(nn.Module):
    """Embedding -> single-layer RNN -> linear head on the last time step."""

    def __init__(self, vocab_size, embedding_dim, hidden_units, num_classes):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # No dropout argument: nn.RNN applies dropout only *between* stacked
        # layers, so with the default num_layers=1 it is a no-op that merely
        # emits a UserWarning. It owns no parameters, so the saved
        # state_dict still loads unchanged.
        self.rnn = nn.RNN(embedding_dim, hidden_units, batch_first=True)
        self.fc = nn.Linear(hidden_units, num_classes)

    def forward(self, x):
        """Map token ids ``(batch, seq)`` to class probabilities ``(batch, num_classes)``."""
        x = self.embedding(x)
        output, _ = self.rnn(x)
        x = output[:, -1, :]  # hidden state at the final time step
        x = self.fc(x)
        return F.softmax(x, dim=1)


@st.cache_resource
def _load_model():
    """Build the RNN, load its weights onto the CPU, and switch to eval mode.

    Cached so Streamlit does not reload the checkpoint on every script rerun
    (every widget interaction reruns the whole file).
    """
    net = RNNModel(vocab_size, embedding_dim, hidden_units, num_classes)
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints. Consider weights_only=True if the torch version allows it.
    net.load_state_dict(
        torch.load("email_classfication.pth", map_location=torch.device("cpu"))
    )
    net.eval()
    return net


@st.cache_resource
def _load_tokenizer():
    """Load ``vocab.json`` and build a whitespace-split WordLevel tokenizer.

    Returns:
        ``(vocab, tokenizer)`` -- the raw vocab mapping and the tokenizer.
    """
    with open("vocab.json", "r", encoding="utf-8") as f:
        vocab_map = json.load(f)
    # WordLevel can only encode out-of-vocab words if its unk_token is itself
    # a vocab key. NOTE(review): the original literal was "" -- possibly a
    # stripped "<unk>"; confirm against how vocab.json was built.
    unk_token = next((tok for tok in _UNK_CANDIDATES if tok in vocab_map), "")
    tok = Tokenizer(WordLevel(vocab_map, unk_token=unk_token))
    tok.pre_tokenizer = Whitespace()
    return vocab_map, tok


model = _load_model()
vocab, tokenizer = _load_tokenizer()


def preprocess(words):
    """Turn a list of words into a ``(1, max_len)`` LongTensor of token ids.

    Each word is lower-cased, stripped of URLs, has non-word characters
    replaced by spaces (this also covers newlines), and has runs of spaces
    collapsed and trimmed. The token ids of all words are concatenated, then
    truncated or right-padded with id 0 to exactly ``max_len``.
    """
    ids = []
    for word in words:
        cleaned = _URL_RE.sub("", word.lower())
        cleaned = _NON_WORD_RE.sub(" ", cleaned)
        cleaned = _MULTI_SPACE_RE.sub(" ", cleaned).strip(" ")
        ids.extend(tokenizer.encode(cleaned).ids)
    # When len(ids) > max_len the multiplier is negative, which yields [].
    ids = ids[:max_len] + [0] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0)


text = st.text_input(
    "Enter the Email Text ",
    value="Happy holidays from our team! Wishing you joy and prosperity this season.",
)
if st.button("submit"):
    v = preprocess(text.split())
    # Prediction only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(v)
    # Class index 1 is treated as spam.
    if output.argmax(dim=1).item() == 1:
        st.write("It's a Spam Mail")
    else:
        st.write("It's not a Spam Mail")