Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import DistilBertTokenizer, DistilBertModel | |
| import torch.nn as nn | |
| class NewsClassifier(nn.Module): | |
| def __init__(self): | |
| super(NewsClassifier, self).__init__() | |
| self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased") | |
| for param in self.bert.parameters(): | |
| param.requires_grad = False | |
| self.classifier = nn.Sequential( | |
| nn.Linear(self.bert.config.hidden_size, 256), | |
| nn.BatchNorm1d(256), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(256, 128), | |
| nn.BatchNorm1d(128), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(128, 64), | |
| nn.BatchNorm1d(64), | |
| nn.ReLU(), | |
| nn.Dropout(0.3), | |
| nn.Linear(64, 2) | |
| ) | |
| def forward(self, input_ids, attention_mask): | |
| bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask) | |
| sentence_embeddings = bert_output.last_hidden_state[:, 0, :] | |
| return self.classifier(sentence_embeddings) | |
| def load_model_and_tokenizer(): | |
| tokenizer = DistilBertTokenizer.from_pretrained("tokenizer_distilbert") | |
| model = NewsClassifier() | |
| model.load_state_dict(torch.load("News_classifier.pt", map_location="cpu")) | |
| model.eval() | |
| return tokenizer, model | |
| tokenizer, model = load_model_and_tokenizer() | |
| class_names = ["True", "Fake"] | |
| st.title("Fake News Detection App") | |
| st.write("Paste a news article/text below to check if it is **Fake** or **True**.") | |
| news_text = st.text_area("Enter News Text", height=200) | |
| if st.button("Predict"): | |
| if news_text.strip(): | |
| encoding = tokenizer(news_text, padding="max_length", max_length=200, truncation=True, return_tensors="pt") | |
| input_ids = encoding["input_ids"] | |
| attention_mask = encoding["attention_mask"] | |
| with torch.no_grad(): | |
| outputs = model(input_ids, attention_mask) | |
| prediction = torch.argmax(outputs, dim=1).item() | |
| result = class_names[prediction] | |
| st.success(f"This news is **{result}**.") | |
| else: | |
| st.warning("Please enter some news text!") | |