Spaces:
Sleeping
Sleeping
File size: 2,265 Bytes
08b7be7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertModel
import torch.nn as nn
class NewsClassifier(nn.Module):
    """Two-class news classifier: a frozen DistilBERT encoder plus an MLP head.

    The encoder is created from the ``distilbert-base-uncased`` checkpoint and
    every encoder parameter has ``requires_grad`` switched off, so only the
    ``classifier`` head can receive gradient updates. The head narrows the
    encoder's hidden size through 256 -> 128 -> 64 units down to 2 outputs.
    """

    def __init__(self):
        """Build the frozen encoder and the classification head."""
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")

        # Freeze the encoder: its weights are used as a fixed feature extractor.
        for param in self.bert.parameters():
            param.requires_grad = False

        # nn.Sequential registers children under their positional index, so the
        # order (and number) of layers below determines the state-dict keys.
        # Keep it unchanged if previously saved weights must still load.
        self.classifier = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 2),
        )

    def forward(self, input_ids, attention_mask):
        """Return the classifier head's raw scores for a batch of inputs.

        Args:
            input_ids: Token ids, forwarded to the encoder as ``input_ids``.
            attention_mask: Mask forwarded to the encoder as ``attention_mask``.

        Returns:
            The output of ``self.classifier`` applied to the encoder's hidden
            state at sequence position 0 of every batch item.
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the sequence axis is used as the sentence embedding
        # (presumably the [CLS] token -- confirm against the tokenizer config).
        sentence_embeddings = bert_output.last_hidden_state[:, 0, :]
        return self.classifier(sentence_embeddings)
@st.cache_resource
def load_model_and_tokenizer():
    """Load the tokenizer and the trained classifier, ready for inference.

    Wrapped in ``st.cache_resource`` so Streamlit reuses the returned objects
    instead of reloading them each time the script is re-executed.

    Returns:
        tuple: ``(tokenizer, model)`` -- the ``DistilBertTokenizer`` loaded from
        ``tokenizer_distilbert`` and a ``NewsClassifier`` in eval mode with the
        weights stored in ``News_classifier.pt``.
    """
    tokenizer = DistilBertTokenizer.from_pretrained("tokenizer_distilbert")

    model = NewsClassifier()
    # map_location="cpu" places the stored tensors on the CPU regardless of the
    # device they were saved from.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints from
    # a trusted source.
    state_dict = torch.load("News_classifier.pt", map_location="cpu")
    model.load_state_dict(state_dict)

    # Inference mode: Dropout is disabled and BatchNorm uses running statistics.
    model.eval()
    return tokenizer, model
# Load the tokenizer and model once at start-up.
tokenizer, model = load_model_and_tokenizer()

# Display label for each predicted class index.
# NOTE(review): confirm this order matches the label encoding used in training.
class_names = ["True", "Fake"]

st.title("Fake News Detection App")
st.write("Paste a news article/text below to check if it is **Fake** or **True**.")
news_text = st.text_area("Enter News Text", height=200)

if st.button("Predict"):
    if news_text.strip():
        # Pad or truncate every input to exactly 200 tokens and return
        # PyTorch tensors.
        encoding = tokenizer(
            news_text,
            padding="max_length",
            max_length=200,
            truncation=True,
            return_tensors="pt",
        )
        input_ids = encoding["input_ids"]
        attention_mask = encoding["attention_mask"]

        # No gradients are needed for a prediction.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask)
            # Single input, so argmax over dim 1 yields one class index.
            prediction = torch.argmax(outputs, dim=1).item()

        result = class_names[prediction]
        st.success(f"This news is **{result}**.")
    else:
        # Empty or whitespace-only input: ask for text instead of predicting.
        st.warning("Please enter some news text!")
|