ma4389 committed on
Commit
0a285b8
·
verified ·
1 Parent(s): 4d5ff45

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +81 -81
app.py CHANGED
@@ -1,81 +1,81 @@
1
- import torch
2
- import torch.nn as nn
3
- from transformers import DistilBertTokenizer
4
- import gradio as gr
5
- import re
6
- import nltk
7
- from nltk.corpus import stopwords
8
- from nltk.tokenize import word_tokenize
9
- from nltk.stem import WordNetLemmatizer
10
-
11
- # Download NLTK resources
12
- nltk.download("stopwords")
13
- nltk.download("punkt_tab")
14
- nltk.download("wordnet")
15
-
16
- # Preprocessing setup
17
- stop_words = set(stopwords.words("english"))
18
- lemmatizer = WordNetLemmatizer()
19
-
20
- def preprocess_text(text):
21
- text = re.sub(r'[^A-Za-z\s]', '', text)
22
- text = re.sub(r'https?://\S+|www\.\S+', '', text)
23
- text = text.lower()
24
- tokens = word_tokenize(text)
25
- tokens = [word for word in tokens if word not in stop_words]
26
- tokens = [lemmatizer.lemmatize(word) for word in tokens]
27
- return ' '.join(tokens)
28
-
29
- # GRU Classifier
30
- class GRUClassifier(nn.Module):
31
- def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
32
- super(GRUClassifier, self).__init__()
33
- self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
34
- self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
35
- self.fc = nn.Linear(hidden_dim, num_classes)
36
-
37
- def forward(self, input_ids):
38
- x = self.embedding(input_ids)
39
- out, _ = self.gru(x)
40
- out = out[:, -1, :]
41
- return self.fc(out)
42
-
43
- # Load tokenizer and model
44
- tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
45
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
46
-
47
- model = GRUClassifier(
48
- vocab_size=tokenizer.vocab_size,
49
- embed_dim=128,
50
- hidden_dim=64,
51
- num_classes=2
52
- ).to(device)
53
-
54
- model.load_state_dict(torch.load("best_gru_model.pth", map_location=device))
55
- model.eval()
56
-
57
- # Prediction function
58
- def predict_clickbait(title):
59
- preprocessed = preprocess_text(title)
60
- encoding = tokenizer(preprocessed, truncation=True, padding='max_length', max_length=32, return_tensors='pt')
61
- input_ids = encoding['input_ids'].to(device)
62
-
63
- with torch.no_grad():
64
- output = model(input_ids)
65
- pred = torch.argmax(output, dim=1).item()
66
- confidence = torch.softmax(output, dim=1).squeeze()[pred].item()
67
-
68
- label = "Clickbait" if pred == 1 else "Not Clickbait"
69
- return f"{label} (Confidence: {confidence:.2f})"
70
-
71
- # Gradio Interface
72
- interface = gr.Interface(
73
- fn=predict_clickbait,
74
- inputs=gr.Textbox(lines=2, placeholder="Enter a headline..."),
75
- outputs="text",
76
- title="Clickbait Title Classifier",
77
- description="Detect whether a news headline is clickbait using a GRU classifier."
78
- )
79
-
80
- if __name__ == "__main__":
81
- interface.launch()
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import DistilBertTokenizer
4
+ import gradio as gr
5
+ import re
6
+ import nltk
7
+ from nltk.corpus import stopwords
8
+ from nltk.tokenize import word_tokenize
9
+ from nltk.stem import WordNetLemmatizer
10
+
11
# Download NLTK resources
# NOTE(review): these run at import time on every start; in a deployed Space the
# downloads are cached after the first run, so repeat cost is a no-op check.
nltk.download("stopwords")
nltk.download("punkt_tab")  # tokenizer data required by word_tokenize (NLTK >= 3.8.2)
nltk.download("wordnet")    # lemma database required by WordNetLemmatizer

# Preprocessing setup
# Built once at module level so preprocess_text() doesn't rebuild them per call.
stop_words = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
19
+
20
def preprocess_text(text):
    """Normalize a headline for the classifier.

    Steps: strip URLs, drop non-letter characters, lowercase, tokenize,
    remove English stopwords, lemmatize, and re-join with single spaces.

    Args:
        text: Raw headline string.

    Returns:
        The cleaned, space-joined token string (may be empty).
    """
    # Fix: remove URLs FIRST. The original ran the punctuation strip before
    # this, which deleted ':' '/' '.' and left the URL pattern unable to
    # ever match (e.g. "https://x.com" had already become "httpsxcom").
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'[^A-Za-z\s]', '', text)  # keep letters and whitespace only
    text = text.lower()
    tokens = word_tokenize(text)
    tokens = [word for word in tokens if word not in stop_words]
    tokens = [lemmatizer.lemmatize(word) for word in tokens]
    return ' '.join(tokens)
28
+
29
+ # GRU Classifier
30
class GRUClassifier(nn.Module):
    """Single-layer GRU text classifier: embedding -> GRU -> linear head.

    Produces raw (unsoftmaxed) class logits of shape (batch, num_classes).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes):
        super().__init__()
        # Attribute names are load-bearing: they are the keys of the saved
        # state_dict ("best_gru_model.pth"), so they must not be renamed.
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, input_ids):
        embedded = self.embedding(input_ids)   # (batch, seq, embed_dim)
        states, _ = self.gru(embedded)         # (batch, seq, hidden_dim)
        # Classify from the final timestep's hidden state.
        # NOTE(review): with right-padded input the last timestep may be a pad
        # position; the trained checkpoint expects this behavior — confirm
        # before "fixing".
        final_state = states[:, -1, :]
        return self.fc(final_state)
42
+
43
# Load tokenizer and model
# The DistilBERT tokenizer is used only for its vocabulary/wordpiece ids;
# the classifier itself is the small GRU defined above, not DistilBERT.
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters must match those used at training time, since the
# checkpoint's tensor shapes are fixed.
model = GRUClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=64,
    num_classes=2
).to(device)

# map_location lets a GPU-trained checkpoint load on CPU-only hosts.
model.load_state_dict(torch.load("best_gru_model.pth", map_location=device))
model.eval()  # disable dropout/batchnorm training behavior for inference
56
+
57
+ # Prediction function
58
+ def predict_clickbait(title):
59
+ preprocessed = preprocess_text(title)
60
+ encoding = tokenizer(preprocessed, truncation=True, padding='max_length', max_length=32, return_tensors='pt')
61
+ input_ids = encoding['input_ids'].to(device)
62
+
63
+ with torch.no_grad():
64
+ output = model(input_ids)
65
+ pred = torch.argmax(output, dim=1).item()
66
+ confidence = torch.softmax(output, dim=1).squeeze()[pred].item()
67
+
68
+ label = "📢 Spam (Clickbait)" if pred == 1 else " Ham (Non-Clickbait)"
69
+ return f"{label} (Confidence: {confidence:.2f})"
70
+
71
# Gradio Interface
# Simple single-input/single-output text UI around predict_clickbait.
interface = gr.Interface(
    fn=predict_clickbait,
    inputs=gr.Textbox(lines=2, placeholder="Enter a news title or headline..."),
    outputs="text",
    title="📰 Clickbait Detector (Ham vs Spam)",
    description="Enter a headline to detect whether it's ham (non-clickbait) or spam (clickbait) using a GRU-based model."
)

# Launch only when run as a script; importing this module elsewhere
# (e.g. for testing) will not start the server.
if __name__ == "__main__":
    interface.launch()