# fake-news / app.py
# Commit cfc9d1b (verified) by Sumit404: "Update app.py"
import gradio as gr
import torch
import numpy as np
from transformers import BertTokenizer, BertModel
import torch.nn as nn
# Define the BERT architecture
class BERT_Arch(nn.Module):
    """Two-class classification head stacked on a pretrained BERT encoder.

    The encoder's ``pooler_output`` is passed through
    ``Linear(768 -> 512) -> ReLU -> Dropout(0.1) -> Linear(512 -> 2)`` and
    then ``LogSoftmax`` over dim 1, so the module returns log-probabilities.

    The attribute names (``bert``, ``fc1``, ``fc2``, ...) determine the keys
    of this module's ``state_dict``. They must therefore stay in line with
    the saved checkpoint that gets loaded into it.
    """

    def __init__(self, bert):
        """Wrap ``bert`` with the classification layers.

        NOTE(review): ``fc1`` expects a 768-wide pooled vector, which matches
        ``bert-base`` models; a different encoder width would need a new head.
        """
        super().__init__()
        self.bert = bert
        self.dropout = nn.Dropout(0.1)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, 512)
        self.fc2 = nn.Linear(512, 2)
        # Despite its name this is LogSoftmax: forward() yields log-probabilities.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Compute class log-probabilities.

        Args:
            sent_id: token ids, given to the encoder as its first argument.
            mask: attention mask, given to the encoder as ``attention_mask``.

        Returns:
            Log-probabilities over the two classes, normalised along dim 1
            (assumes a ``(batch, 768)`` pooled output — TODO confirm).
        """
        encoded = self.bert(sent_id, attention_mask=mask)
        pooled = encoded['pooler_output']
        hidden = self.dropout(self.relu(self.fc1(pooled)))
        return self.softmax(self.fc2(hidden))
# Load tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')
model = BERT_Arch(bert)

# Load the trained model weights, ignoring unexpected keys.
# ``strict=False`` tolerates checkpoint entries the model does not have
# (NOTE(review): presumably leftovers from a different transformers version —
# confirm). However, ``strict=False`` ALSO silently skips weights the model
# needs but the checkpoint lacks. That would leave layers such as the
# classifier head randomly initialised and make every prediction meaningless,
# so missing keys are treated as a hard error.
_missing_keys, _unexpected_keys = model.load_state_dict(
    torch.load('ak.pt', map_location=torch.device('cpu')), strict=False
)
if _missing_keys:
    raise RuntimeError(
        "Checkpoint 'ak.pt' is missing weights for: " + ", ".join(_missing_keys)
    )
# Inference mode: disables dropout.
model.eval()
# Prediction function
def predict_news(text):
    """Classify a news headline as "Real" or "Fake".

    Args:
        text: The headline entered in the Gradio textbox.

    Returns:
        "Real" if the model's highest-scoring class is index 1, else "Fake".
    """
    # Headlines are truncated/padded to exactly 15 tokens.
    # NOTE(review): presumably the sequence length used in training — confirm
    # before changing.
    MAX_LENGTH = 15
    # ``padding='max_length'`` is the supported replacement for the deprecated
    # ``pad_to_max_length=True`` and yields the same fixed-length encoding.
    tokens = tokenizer(
        text,
        max_length=MAX_LENGTH,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    seq = tokens['input_ids']
    mask = tokens['attention_mask']
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        log_probs = model(seq, mask)
    # Single-example batch: pick the arg-max class of row 0.
    pred_class = int(log_probs.argmax(dim=1)[0])
    return "Real" if pred_class == 1 else "Fake"
# Gradio interface
iface = gr.Interface(
    fn=predict_news,
    inputs=gr.Textbox(lines=2, placeholder="Enter news headline here..."),
    outputs="text",
    title="Fake News Detector",
    description="Enter a news headline to check if it's fake or real.",
    # Each example row supplies one value per input component. The interface
    # has a single Textbox, so the expected labels are kept as comments rather
    # than as a second column that no component would receive.
    examples=[
        # expected: Fake
        ["Donald Trump Sends Out Embarrassing New Year’s Eve Message; This is Disturbing"],
        # expected: Real
        ["Trump administration issues new rules on U.S. visa waivers"],
    ],
)

# Launch the interface only when this file is run directly, so importing the
# module builds `iface` without starting the web server.
if __name__ == "__main__":
    iface.launch()