Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from transformers import DistilBertTokenizer | |
| import gradio as gr | |
| import re | |
| import nltk | |
| from nltk.corpus import stopwords | |
| from nltk.tokenize import word_tokenize | |
| from nltk.stem import WordNetLemmatizer | |
# Fetch the NLTK resources this app relies on at startup: the Punkt tables
# needed by word_tokenize, the stop-word corpus, and WordNet for the
# lemmatizer.
for _nltk_resource in ("punkt_tab", "stopwords", "wordnet"):
    nltk.download(_nltk_resource)
# Preprocessing resources shared by every call to preprocess_text:
# a single WordNet lemmatizer instance and the English stop-word list,
# stored as a set so membership checks are constant time.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words("english"))
# Compiled once at import time instead of on every request.
_URL_PATTERN = re.compile(r"https?://\S+|www\.\S+")
_NON_ALPHA_PATTERN = re.compile(r"[^A-Za-z\s]")


def preprocess_text(text):
    """Normalize raw email text before it is handed to the BERT tokenizer.

    Steps, in order:
    1. Remove URLs (``http(s)://...`` and ``www....``).
    2. Drop every character that is not an ASCII letter or whitespace.
       Digits, punctuation and non-ASCII letters are removed.
    3. Lowercase.
    4. Word-tokenize with NLTK.
    5. Remove English stop words.
    6. Lemmatize each remaining token with WordNet.

    Args:
        text: Raw input string. ``None`` is treated as an empty string.

    Returns:
        The cleaned tokens joined by single spaces (may be ``""``).
    """
    if text is None:
        text = ""
    # URLs must be stripped *before* punctuation. Once ':', '/' and '.' are
    # removed, neither alternative of the URL pattern can match, and links
    # would survive as glued-together words such as "httpsexamplecom".
    # NOTE(review): if best_bi_model.pth was trained with the old order
    # (punctuation first), re-evaluating or retraining is advisable so that
    # training and serving preprocessing agree.
    text = _URL_PATTERN.sub("", text)
    text = _NON_ALPHA_PATTERN.sub("", text)
    text = text.lower()
    tokens = word_tokenize(text)
    return " ".join(
        lemmatizer.lemmatize(word) for word in tokens if word not in stop_words
    )
# Tokenizer: token ids come from the DistilBERT (uncased) WordPiece
# vocabulary, and the BiLSTM's embedding table is sized from it.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
vocab_size = tokenizer.vocab_size
max_len = 32  # every input is truncated/padded to this many token ids
# Model definition
class BiLSTMClassifier(nn.Module):
    """Sequence classifier: embedding -> bidirectional LSTM -> linear head.

    Args:
        vocab_size: Number of rows in the embedding table.
        embed_dim: Embedding width, which is also the LSTM input size.
        hidden_dim: Hidden size of each LSTM direction. The concatenated
            forward and backward output is ``2 * hidden_dim`` wide.
        num_classes: Number of output scores produced per example.

    NOTE(review): only the LSTM output at the final time step is classified.
    For the backward direction, that position has seen just one token. If
    the tokenizer pads on the right (TODO confirm padding_side), that
    position is usually padding for short inputs. Pooling over real tokens
    would likely work better, but it changes the function the saved weights
    were trained for, so it would require retraining.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        # Attribute names double as state_dict key prefixes, so they must
        # stay stable for checkpoint loading.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.bilstm = nn.LSTM(
            embed_dim, hidden_dim, batch_first=True, bidirectional=True
        )
        self.fc = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        """Map token ids of shape (batch, seq_len) to (batch, num_classes) scores."""
        embedded = self.embedding(x)
        sequence_out, _ = self.bilstm(embedded)
        final_step = sequence_out[:, -1]
        return self.fc(final_step)
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# These hyperparameters must match the ones used when the checkpoint was
# saved; load_state_dict (strict by default) raises on any shape mismatch.
model = BiLSTMClassifier(vocab_size, embed_dim=128, hidden_dim=64, num_classes=2)
# The checkpoint is loaded directly as a state_dict, so unpickling can be
# restricted to tensors and plain containers. Without weights_only, torch.load
# runs the full pickle machinery, and torch 2.4-2.5 warns about it.
model.load_state_dict(
    torch.load("best_bi_model.pth", map_location=device, weights_only=True)
)
model.to(device)
model.eval()  # put the module in inference mode before serving requests
# Inference function
def predict_spam(text):
    """Classify a single email body as spam or ham.

    The text is cleaned with ``preprocess_text`` and encoded by the
    DistilBERT tokenizer to exactly ``max_len`` ids (truncated or padded).
    It is then scored by the BiLSTM without gradient tracking. Only the
    ``input_ids`` are fed to the model; the tokenizer's attention mask is
    not used.

    Args:
        text: Raw email text from the UI.

    Returns:
        ``"Spam 🚫"`` when the highest-scoring class is index 1, otherwise
        ``"Ham ✅"``.
    """
    batch = tokenizer(
        preprocess_text(text),
        truncation=True,
        padding='max_length',
        max_length=max_len,
        return_tensors='pt',
    )
    token_ids = batch['input_ids'].to(device)
    with torch.no_grad():
        scores = model(token_ids)
    predicted_class = scores.argmax(dim=1).item()
    if predicted_class == 1:
        return "Spam 🚫"
    return "Ham ✅"
# Gradio Interface
# One multi-line text box feeds predict_spam, and its verdict is shown in a
# Label component.
# NOTE(review): num_top_classes only matters when fn returns a
# {label: confidence} mapping; predict_spam returns a plain string.
email_input = gr.Textbox(lines=5, label="Enter Email Text")
prediction_output = gr.Label(num_top_classes=2, label="Prediction")
interface = gr.Interface(
    fn=predict_spam,
    inputs=email_input,
    outputs=prediction_output,
    title="📧 Spam or Ham Classifier (BiLSTM)",
    description="Enter an email message to predict whether it is Spam or Ham using a trained BiLSTM model.",
)
interface.launch()