import re

import gradio as gr
import nltk
import torch
import torch.nn as nn
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from transformers import DistilBertTokenizer

# Download NLTK resources used by preprocess_text (stop-word list,
# word_tokenize tables, WordNet data for the lemmatizer).
nltk.download("stopwords")
nltk.download("punkt_tab")
nltk.download("wordnet")

# Preprocessing setup
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()

# Compiled once at import time; applied on every prediction request.
_URL_RE = re.compile(r"https?://\S+|www\.\S+")
_NON_ALPHA_RE = re.compile(r"[^A-Za-z\s]")

# Token sequence length fed to the model (pad/truncate to this many ids).
_MAX_LENGTH = 32


def preprocess_text(text):
    """Normalize a headline before subword tokenization.

    Removes URLs, drops every character that is not an ASCII letter or
    whitespace, lower-cases, removes English stop words and lemmatizes the
    remaining tokens.

    Args:
        text: Raw headline string.

    Returns:
        The cleaned tokens joined by single spaces (may be an empty string).
    """
    # URLs must be removed BEFORE stripping punctuation: once ":", "/" and
    # "." are gone the URL pattern can no longer match anything.
    # NOTE(review): keep this in sync with the preprocessing that was used
    # when best_gru_model.pth was trained.
    text = _URL_RE.sub("", text)
    text = _NON_ALPHA_RE.sub("", text)
    text = text.lower()
    tokens = word_tokenize(text)
    tokens = [lemmatizer.lemmatize(word) for word in tokens if word not in stop_words]
    return " ".join(tokens)


# GRU Classifier
class GRUClassifier(nn.Module):
    """Single-layer, unidirectional GRU classifier over token ids.

    Token ids are embedded, run through a GRU, and the GRU output at the
    final sequence position is projected to ``num_classes`` scores.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        """Build the layers.

        Args:
            vocab_size: Number of rows in the embedding table.
            embed_dim: Embedding dimension (GRU input size).
            hidden_dim: GRU hidden size.
            num_classes: Number of output scores.
        """
        super().__init__()
        # Id 0 is the padding index: its embedding stays zero and gets no
        # gradient. Presumably matches the tokenizer's [PAD] id — confirm.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids):
        """Return unnormalized class scores of shape (batch, num_classes).

        Args:
            input_ids: Integer tensor of shape (batch, seq_len).
        """
        x = self.embedding(input_ids)  # (batch, seq_len, embed_dim)
        out, _ = self.gru(x)  # (batch, seq_len, hidden_dim)
        # NOTE(review): this reads the last position even when it is a
        # padding token (inputs are padded to max_length). Kept as-is because
        # the trained weights presumably depend on this readout.
        out = out[:, -1, :]
        return self.fc(out)


# Load tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# These hyper-parameters must match those used to train best_gru_model.pth.
model = GRUClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=64,
    num_classes=2,
).to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
model.load_state_dict(torch.load("best_gru_model.pth", map_location=device))
model.eval()


# Prediction function
def predict_clickbait(title):
    """Classify a headline as spam (clickbait) or ham (non-clickbait).

    Args:
        title: Headline text from the Gradio textbox.

    Returns:
        A label string with the softmax probability of the predicted class,
        or a prompt asking for input when ``title`` is empty/blank.
    """
    # Nothing meaningful to classify; also avoids re.sub(None) TypeError.
    if not title or not title.strip():
        return "Please enter a headline."

    preprocessed = preprocess_text(title)
    encoding = tokenizer(
        preprocessed,
        truncation=True,
        padding="max_length",
        max_length=_MAX_LENGTH,
        return_tensors="pt",
    )
    input_ids = encoding["input_ids"].to(device)

    with torch.no_grad():
        logits = model(input_ids)

    # Batch of one: (1, num_classes) -> (num_classes,).
    probs = torch.softmax(logits, dim=1).squeeze(0)
    # argmax over probabilities equals argmax over logits (softmax is monotonic).
    pred = int(torch.argmax(probs).item())
    confidence = probs[pred].item()

    # Class index 1 is the clickbait/spam class.
    label = "📢 Spam (Clickbait)" if pred == 1 else "✅ Ham (Non-Clickbait)"
    return f"{label} (Confidence: {confidence:.2f})"


# Gradio Interface
interface = gr.Interface(
    fn=predict_clickbait,
    inputs=gr.Textbox(lines=2, placeholder="Enter a news title or headline..."),
    outputs="text",
    title="📰 Clickbait Detector (Ham vs Spam)",
    description="Enter a headline to detect whether it's ham (non-clickbait) or spam (clickbait) using a GRU-based model.",
)

if __name__ == "__main__":
    interface.launch()