# Clickbait headline detector — Gradio app.
# (Hugging Face Spaces page chrome — status, file size, commit hash,
# and line-number gutter — removed from this scraped copy.)
import torch
import torch.nn as nn
from transformers import DistilBertTokenizer
import gradio as gr
import re
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
# Fetch the NLTK corpora/models the preprocessing pipeline relies on.
for _resource in ("stopwords", "punkt_tab", "wordnet"):
    nltk.download(_resource)

# Shared preprocessing state: English stopword set and WordNet lemmatizer.
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
def preprocess_text(text):
    """Normalize a headline for the GRU model.

    Steps: strip URLs, drop non-alphabetic characters, lowercase,
    tokenize, remove English stopwords, and lemmatize.

    Args:
        text: Raw headline string.

    Returns:
        Space-joined string of cleaned, lemmatized tokens.
    """
    # BUG FIX: remove URLs BEFORE stripping punctuation. The original ran
    # the URL regex after '[^A-Za-z\s]' had already deleted ':', '/', and
    # '.', so 'https?://\S+|www\.\S+' could never match anything.
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'[^A-Za-z\s]', '', text)
    text = text.lower()
    tokens = word_tokenize(text)
    tokens = [word for word in tokens if word not in stop_words]
    tokens = [lemmatizer.lemmatize(word) for word in tokens]
    return ' '.join(tokens)
# GRU Classifier
class GRUClassifier(nn.Module):
    """Single-layer GRU text classifier.

    Pipeline: token ids -> embedding -> GRU -> hidden state at the final
    time step -> linear projection to class logits.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super(GRUClassifier, self).__init__()
        # padding_idx=0 pins the pad token's embedding to zeros.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids):
        """Return class logits of shape (batch, num_classes)."""
        embedded = self.embedding(input_ids)
        sequence_out, _ = self.gru(embedded)
        # Keep only the GRU output at the last time step of each sequence.
        last_step = sequence_out[:, -1, :]
        return self.fc(last_step)
# --- Tokenizer and model initialisation ---
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Prefer GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GRUClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=64,
    num_classes=2,
).to(device)

# Restore trained weights and switch to inference mode.
# NOTE(review): torch.load uses pickle under the hood — only load trusted
# checkpoint files.
model.load_state_dict(torch.load("best_gru_model.pth", map_location=device))
model.eval()
# Prediction function
def predict_clickbait(title):
    """Classify a headline as clickbait (spam) or non-clickbait (ham).

    Args:
        title: Raw headline text entered by the user.

    Returns:
        Human-readable label string including the softmax confidence.
    """
    cleaned = preprocess_text(title)
    encoded = tokenizer(
        cleaned,
        truncation=True,
        padding='max_length',
        max_length=32,
        return_tensors='pt',
    )
    input_ids = encoded['input_ids'].to(device)

    with torch.no_grad():
        logits = model(input_ids)

    pred = torch.argmax(logits, dim=1).item()
    confidence = torch.softmax(logits, dim=1).squeeze()[pred].item()

    if pred == 1:
        label = "📢 Spam (Clickbait)"
    else:
        label = "✅ Ham (Non-Clickbait)"
    return f"{label} (Confidence: {confidence:.2f})"
# --- Gradio UI ---
interface = gr.Interface(
    fn=predict_clickbait,
    inputs=gr.Textbox(lines=2, placeholder="Enter a news title or headline..."),
    outputs="text",
    title="📰 Clickbait Detector (Ham vs Spam)",
    description=(
        "Enter a headline to detect whether it's ham (non-clickbait) or "
        "spam (clickbait) using a GRU-based model."
    ),
)

# Launch the web app only when executed as a script (not on import).
if __name__ == "__main__":
    interface.launch()
# (end of file)