# ma4389's picture
# Update app.py
# 0a285b8 verified
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer
import gradio as gr
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
# Download the NLTK resources used by the preprocessing step.
for _nltk_resource in ("stopwords", "punkt_tab", "wordnet"):
    nltk.download(_nltk_resource)

# Preprocessing setup: both objects are built once at import time and
# shared by every call to preprocess_text.
lemmatizer = WordNetLemmatizer()
stop_words = set(stopwords.words("english"))
def preprocess_text(text):
    """Normalise a headline into a space-joined string of lemmatised tokens.

    Steps, in order:
      1. Remove URLs (``http://``, ``https://`` and ``www.`` forms).
      2. Drop every character that is not an ASCII letter or whitespace.
      3. Lower-case, tokenise with NLTK, discard English stop words and
         lemmatise the remaining tokens.

    URLs are removed *before* the character filter on purpose: the filter
    deletes ``:``, ``/`` and ``.``, after which the URL pattern can no longer
    match and the link would survive as a single fused pseudo-word.

    NOTE(review): if the checkpoint was trained on text cleaned with the old
    (filter-first) order, confirm that URL-bearing headlines are still scored
    as expected.

    Args:
        text: Raw headline text.

    Returns:
        The cleaned tokens joined by single spaces; an empty string when no
        token survives the filtering.
    """
    # Strip links first, while their punctuation is still present.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    # Then remove digits, punctuation and any other non-letter characters.
    text = re.sub(r'[^A-Za-z\s]', '', text)
    text = text.lower()
    tokens = word_tokenize(text)
    tokens = [word for word in tokens if word not in stop_words]
    tokens = [lemmatizer.lemmatize(word) for word in tokens]
    return ' '.join(tokens)
# GRU Classifier
class GRUClassifier(nn.Module):
    """Text classifier: embedding table -> GRU -> linear output layer.

    Args:
        vocab_size: Number of entries in the embedding table.
        embed_dim: Size of each embedding vector (the GRU input size).
        hidden_dim: Size of the GRU hidden state.
        num_classes: Number of logits produced per input sequence.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        # Token id 0 is registered as the padding index of the embedding.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embed_dim,
            padding_idx=0,
        )
        # batch_first=True: the GRU takes (batch, sequence, feature) input.
        self.gru = nn.GRU(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, input_ids):
        """Return one row of class logits per sequence in ``input_ids``."""
        embedded = self.embedding(input_ids)
        sequence_out, _final_hidden = self.gru(embedded)
        # Classify from the GRU output at the last sequence position.
        last_step = sequence_out[:, -1]
        return self.fc(last_step)
# Load tokenizer and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# The embedding table is sized from the tokenizer vocabulary. The remaining
# sizes are presumably the ones the checkpoint was trained with; verify
# against the training script before changing them.
model = GRUClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=64,
    num_classes=2,
)
model.to(device)

# map_location places the stored tensors on the selected device.
model.load_state_dict(torch.load("best_gru_model.pth", map_location=device))

# Inference only: put the module in evaluation mode.
model.eval()
# Prediction function
def predict_clickbait(title):
    """Classify a headline and report the predicted label with confidence.

    The title is cleaned with ``preprocess_text``, tokenised to a fixed
    length of 32 ids (truncated or padded), and passed through the model.

    Args:
        title: Headline text as entered by the user.

    Returns:
        A string of the form ``"<label> (Confidence: <p>)"``, where ``<p>``
        is the softmax probability of the predicted class formatted to two
        decimal places. Class 1 is reported as spam/clickbait and any other
        class as ham/non-clickbait.
    """
    cleaned = preprocess_text(title)
    batch = tokenizer(
        cleaned,
        truncation=True,
        padding='max_length',
        max_length=32,
        return_tensors='pt',
    )
    token_ids = batch['input_ids'].to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        logits = model(token_ids)
        pred = torch.argmax(logits, dim=1).item()
        probabilities = torch.softmax(logits, dim=1).squeeze()

    confidence = probabilities[pred].item()
    if pred == 1:
        label = "📢 Spam (Clickbait)"
    else:
        label = "✅ Ham (Non-Clickbait)"
    return f"{label} (Confidence: {confidence:.2f})"
# Gradio Interface: one text box in, the prediction string out.
_headline_box = gr.Textbox(
    lines=2,
    placeholder="Enter a news title or headline...",
)

interface = gr.Interface(
    title="📰 Clickbait Detector (Ham vs Spam)",
    description="Enter a headline to detect whether it's ham (non-clickbait) or spam (clickbait) using a GRU-based model.",
    fn=predict_clickbait,
    inputs=_headline_box,
    outputs="text",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()