| import gradio as gr |
| import torch |
| import numpy as np |
| from transformers import BertTokenizer, BertModel |
| import torch.nn as nn |
|
|
| |
class BERT_Arch(nn.Module):
    """Binary classification head on top of a pretrained BERT encoder.

    Feeds BERT's pooled [CLS] representation (768-dim for bert-base) through
    a small MLP and returns per-class log-probabilities.

    Args:
        bert: a pretrained BERT model whose forward returns a mapping with a
            'pooler_output' entry of shape (batch, 768).
        hidden_size: width of the intermediate fully-connected layer.
        num_classes: number of output classes (2 = fake/real by default).
        dropout: dropout probability applied after the hidden layer.
    """

    def __init__(self, bert, hidden_size=512, num_classes=2, dropout=0.1):
        super().__init__()
        self.bert = bert
        # Layer creation order kept stable (dropout, relu, fc1, fc2) so that
        # random weight initialization matches the original implementation.
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(768, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)
        # LogSoftmax pairs with nn.NLLLoss during training; argmax at
        # inference is unaffected by the log.
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, sent_id, mask):
        """Return log-probabilities of shape (batch, num_classes).

        Args:
            sent_id: token-id tensor, shape (batch, seq_len).
            mask: attention mask tensor, shape (batch, seq_len).
        """
        cls_hs = self.bert(sent_id, attention_mask=mask)['pooler_output']
        x = self.fc1(cls_hs)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return self.softmax(x)
|
|
| |
# --- Model assembly: pretrained encoder + fine-tuned classification head ---
# Downloads (or loads from cache) the uncased BERT-base tokenizer and encoder,
# then restores the fine-tuned classifier weights from the local checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained('bert-base-uncased')

model = BERT_Arch(bert)
# Checkpoint is loaded onto CPU for portable inference. NOTE(review):
# strict=False silently skips mismatched/missing keys — confirm 'ak.pt'
# actually matches this architecture.
state = torch.load('ak.pt', map_location=torch.device('cpu'))
model.load_state_dict(state, strict=False)
model.eval()
|
|
| |
def predict_news(text, max_length=15):
    """Classify a news headline as "Real" or "Fake".

    Args:
        text: the headline to classify.
        max_length: token budget; input is padded/truncated to this length
            (defaults to 15, matching the training setup).

    Returns:
        "Real" if the model predicts class 1, otherwise "Fake".
    """
    # `pad_to_max_length=True` was deprecated and later removed from
    # transformers; `padding='max_length'` is the supported equivalent.
    tokens = tokenizer.encode_plus(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    seq = tokens['input_ids']
    mask = tokens['attention_mask']

    # Inference only: no_grad avoids building the autograd graph, so the
    # extra .detach() from the original is unnecessary.
    with torch.no_grad():
        preds = model(seq, mask).cpu().numpy()

    pred_class = np.argmax(preds, axis=1)[0]
    return "Real" if pred_class == 1 else "Fake"
|
|
| |
# --- Gradio UI ---
# Single textbox in, single text label out. Each `examples` row must contain
# exactly one value per input component; the original rows carried a second
# "Fake"/"Real" element (the expected label), which Gradio would reject for a
# one-input interface, so the label column is dropped here.
iface = gr.Interface(
    fn=predict_news,
    inputs=gr.Textbox(lines=2, placeholder="Enter news headline here..."),
    outputs="text",
    title="Fake News Detector",
    description="Enter a news headline to check if it's fake or real.",
    examples=[
        ["Donald Trump Sends Out Embarrassing New Year’s Eve Message; This is Disturbing"],
        ["Trump administration issues new rules on U.S. visa waivers"],
    ],
)


if __name__ == "__main__":
    # Launch the local web server only when run as a script (not on import).
    iface.launch()