# app.py — fake-news detection demo (author: WXM2000, commit 4ef25f5)
import re
import warnings
from transformers import AutoModelForTokenClassification,AutoTokenizer,pipeline
import gradio as gr
import torch
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torch import LongTensor, FloatTensor
from torch.autograd import Function
from torchvision.transforms.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor
from transformers import BertModel, BertConfig, BertTokenizer
from torch.utils.data import Dataset, DataLoader
warnings.filterwarnings('ignore')
class Exp(Function):
    """Custom autograd op computing an elementwise exponential.

    Forward computes ``y = exp(i)``; backward reuses the saved forward
    result, since d(exp(i))/di = exp(i), instead of recomputing it.
    """

    @staticmethod
    def forward(ctx, i):
        # Save exp(i) once so backward can reuse it.
        out = i.exp()
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        return saved * grad_output
class ReverseLayerF(Function):
    """Gradient Reversal Layer (GRL) for adversarial domain adaptation.

    Identity in the forward pass; in backward, scales the incoming
    gradient by ``-lambd`` so the shared feature extractor un-learns
    domain-discriminative information.

    Fixes over the original: uses the modern static ``autograd.Function``
    API (the legacy instance-method form no longer runs on current
    PyTorch), and returns a gradient slot for every forward input.
    """

    @staticmethod
    def forward(ctx, x, args=1.0):
        # Backward-compatible: accept either a config object exposing
        # .lambd (the original call style) or a plain float scale.
        ctx.lambd = getattr(args, 'lambd', args)
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Second slot is None: no gradient flows to the 'args' input.
        return grad_output * -ctx.lambd, None
class _GradReverse(Function):
    """Identity forward; negates the gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg()


def grad_reverse(x):
    """Apply gradient reversal to ``x`` (identity forward, -grad backward).

    Bug fix: the original returned ``Exp.apply(x)``, i.e. elementwise
    exp(x) — it changed the values fed to the domain classifier instead
    of reversing gradients, contradicting the function's purpose.
    """
    return _GradReverse.apply(x)
class Config():
    """Bag of training/inference hyper-parameters for the fake-news model."""

    def __init__(self):
        # Optimization settings.
        self.batch_size = 16
        self.epochs = 200
        # Directory holding the fine-tuned BERT files.
        self.bert_path = "./fake-news-bert/"
        # Prefer GPU when one is visible, else fall back to CPU.
        use_gpu = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_gpu else "cpu")
        # Number of event domains seen by the adversarial head.
        self.event_num = 30
class FakeNewsDataset(Dataset):
    """Dataset pairing tokenized text, an image, an event id and a label.

    Args:
        input_three: sequence of three parallel lists —
            [input_ids, token_type_ids/attention_masks, ...] as produced
            by the tokenizer (one entry per sample).
        event: event-domain id per sample.
        image: per-sample image arrays (float pixel data).
        label: class label per sample.

    Bug fixes over the original:
      * ``self.input_three = self.input_three`` read the attribute before
        it existed (AttributeError — the class could never be built);
        now assigns from the ``input_three`` parameter.
      * ``__getitem__`` returned ``input_three[2]`` twice, silently
        dropping the second tokenizer output; indices are now 0, 1, 2.
      * images were wrapped in ``LongTensor``, truncating normalized
        float pixels to integers; they are now ``FloatTensor``.
    """

    def __init__(self, input_three, event, image, label):
        self.event = LongTensor(list(event))
        self.image = FloatTensor([np.array(i) for i in image])
        self.label = LongTensor(list(label))
        self.input_three = [LongTensor(input_three[i]) for i in range(3)]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        return (self.input_three[0][idx], self.input_three[1][idx],
                self.input_three[2][idx], self.image[idx],
                self.event[idx], self.label[idx])
class Multi_Model(nn.Module):
    """Multi-modal fake-news classifier (EANN-style).

    Fuses a BERT text embedding with a small CNN image embedding, then
    feeds the joint feature to (a) a real/fake classification head and
    (b) an event-domain head placed behind a gradient-reversal hook.

    Args:
        bert_path: directory holding the pretrained BERT weights/config.
        event_num: number of event domains for the adversarial head.
        classes: kept for interface compatibility — the class head is
            hard-wired to 2 outputs.
        p: width of each modality's projected feature vector.
    """

    def __init__(self, bert_path, event_num, classes=2, p=10):
        super(Multi_Model, self).__init__()
        # Consistency fix: load the config from bert_path instead of the
        # hard-coded "./fake-news-bert/config.json", so the constructor
        # argument is actually honored.
        self.config = BertConfig.from_pretrained(bert_path)
        self.bert = BertModel.from_pretrained(bert_path, config=self.config)
        self.fc = nn.Linear(self.config.hidden_size, p)  # text projection
        self.event_num = event_num

        # Image branch: 3x224x224 -> 1x112x112 -> pool 1x56x56 -> 1x26x26.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(1, 1, kernel_size=5, stride=2, padding=0),
            nn.ReLU(),
        )
        self.image_fc = nn.Sequential(
            nn.Linear(1 * 26 * 26, 26),
            nn.Linear(26, p),
        )

        self.softmax = nn.Softmax(dim=1)

        # Real/fake head over the concatenated (text, image) feature.
        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module('c_fc1', nn.Linear(2 * p, p))
        self.class_classifier.add_module('c_fc2', nn.Linear(p, 2))
        self.class_classifier.add_module('c_softmax', nn.Softmax(dim=1))

        # Event-domain head (adversarial, via grad_reverse in forward).
        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module('d_fc1', nn.Linear(2 * p, p))
        # NOTE(review): LeakyReLU(True) sets negative_slope=1 (effectively
        # linear); likely meant inplace=True — kept byte-identical so the
        # module graph matches the trained checkpoint.
        self.domain_classifier.add_module('d_relu1', nn.LeakyReLU(True))
        self.domain_classifier.add_module('d_fc2', nn.Linear(p, self.event_num))
        self.domain_classifier.add_module('d_softmax', nn.Softmax(dim=1))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, image=None):
        """Return (class_output, domain_output), both softmax distributions.

        class_output:  [bs, 2]          real/fake probabilities.
        domain_output: [bs, event_num]  event probabilities (adversarial).
        """
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        out_pool = outputs[1]        # pooled [CLS] output: [bs, hidden_size]
        text = self.fc(out_pool)     # [bs, p]

        image = self.cnn(image)
        image = self.image_fc(image.view(image.size(0), -1))  # [bs, p]

        text_image = torch.cat((text, image), 1)  # [bs, 2p]
        class_output = self.class_classifier(text_image)
        # Reverse gradients so the shared feature un-learns event identity.
        reverse_feature = grad_reverse(text_image)
        domain_output = self.domain_classifier(reverse_feature)
        return class_output, domain_output
def cleanSST(string):
    """Normalize a raw news string before tokenization.

    Strips a fixed set of Chinese/ASCII punctuation characters, then
    trims surrounding whitespace and lowercases the result.

    NOTE(review): inside the character class, ``/nbsp`` matches the
    individual characters n, b, s, p (not an HTML entity) and ``|-“``
    forms a wide character range — presumably unintended, but kept
    byte-identical because the checkpoint was trained on this cleaning.
    """
    stripped = re.sub(u"[,。 :,.;|-“”——_/nbsp+&;@、《》~()())#O!:【】]", "", string)
    return stripped.strip().lower()
# Example inputs used to pre-populate the Gradio demo below.
image_path = './1.jpg'
# NOTE(review): opened at import time — the app fails fast if the bundled
# sample image is missing from the working directory.
example_image = Image.open(image_path)
# Sample fake headline (Chinese: "2024 is doomsday, we're done for, the
# world will be destroyed") shown as the demo's example text.
example_text = '2024年是世界末日,我们完蛋了,世界要毁灭了'
def predict(input_text, input_image):
    """Classify a (text, image) news pair as real or fake.

    Args:
        input_text: the news text to check (Chinese).
        input_image: a PIL image accompanying the news item.

    Returns:
        '真实新闻' (real news) when the model predicts class 0,
        '虚假新闻' (fake news) otherwise.

    Fixes over the original:
      * ``torch.load`` now passes ``map_location=device`` so a
        GPU-trained checkpoint still loads on CPU-only hosts.
      * inference runs under ``torch.no_grad()`` (no autograd graph).
      * removed the dead ``text = ""`` assignment, debug prints, and the
        tensor->numpy->tensor round-trip for the image batch.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    data_transforms = Compose(transforms=[
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Build the model and load the fine-tuned weights.
    multi_model = Multi_Model("./fake-news-bert/", 30)  # event_num unused at inference
    state = torch.load('./fake-news-bert/best_multi_bert_model.pth',
                       map_location=device)
    multi_model.load_state_dict(state)
    multi_model.to(device)
    multi_model.eval()

    # Preprocess the image into a normalized [1, 3, 224, 224] batch.
    im = data_transforms(input_image.convert('RGB'))
    image_batch = im.unsqueeze(0).to(device)

    # Tokenize the cleaned text to fixed length 50.
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    encode_dict = tokenizer.encode_plus(text=cleanSST(input_text), max_length=50,
                                        padding='max_length', truncation=True)
    input_ids = LongTensor([encode_dict['input_ids']]).to(device)
    input_types = LongTensor([encode_dict['token_type_ids']]).to(device)
    input_masks = LongTensor([encode_dict['attention_mask']]).to(device)

    # NOTE(review): token_type_ids are passed in the attention_mask slot and
    # vice versa, exactly as in the original call — presumably the checkpoint
    # was trained with the same ordering; confirm before "fixing".
    with torch.no_grad():
        label_pred, _domain = multi_model(input_ids, input_types,
                                          input_masks, image_batch)

    y_pred = torch.argmax(label_pred, dim=1).item()
    return '真实新闻' if y_pred == 0 else '虚假新闻'
# Keep the JSON panels scrollable at a fixed height.
css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"

with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1><center>虚假新闻检测</center></h1>")
    with gr.Row():
        with gr.Column():
            # Bug fix: the original misspelled 'placeholder' as
            # 'placeholer', so the input hint text never appeared.
            inp_txt = gr.Textbox(lines=2, placeholder="在这里输入需要检测新闻的文本内容")
            inp_img = gr.Image(type='pil')
            inp = [inp_txt, inp_img]
        with gr.Column():
            out = gr.Textbox(lines=2)
    btn = gr.Button("检测")
    btn.click(fn=predict, inputs=inp, outputs=out)
    # Pre-filled example row (text + bundled sample image).
    examples = [[example_text, image_path]]
    gr.Examples(
        examples=examples,
        inputs=inp,
    )

demo.launch()