Spaces:
Build error
Build error
| import re | |
| import warnings | |
| from transformers import AutoModelForTokenClassification,AutoTokenizer,pipeline | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from torch import LongTensor, FloatTensor | |
| from torch.autograd import Function | |
| from torchvision.transforms.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor | |
| from transformers import BertModel, BertConfig, BertTokenizer | |
| from torch.utils.data import Dataset, DataLoader | |
| warnings.filterwarnings('ignore') | |
class Exp(Function):
    """Custom autograd function computing elementwise exp(x).

    Forward returns exp(i); backward uses the saved result, since
    d/dx exp(x) = exp(x).
    """

    # Bug fix: modern PyTorch (>= 1.5) requires @staticmethod on custom
    # Function.forward/backward — the legacy non-static form raises a
    # RuntimeError when .apply() is called.
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # Save the output itself: it doubles as the local derivative.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
class ReverseLayerF(Function):
    """Gradient reversal layer (DANN-style).

    Forward is the identity; backward multiplies the incoming gradient by
    -lambd, so the upstream features are trained adversarially against the
    domain classifier.
    """

    # Bug fix: the @staticmethod decorators were commented out, which breaks
    # .apply() on modern PyTorch; lambd is now stashed on ctx instead of self.
    @staticmethod
    def forward(ctx, x, args):
        # `args` is expected to carry a numeric `lambd` attribute
        ctx.lambd = args.lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Bug fix: backward must return one gradient per forward input;
        # `args` is not a tensor, so its gradient slot is None.
        return grad_output * -ctx.lambd, None
def grad_reverse(x):
    """Pass *x* through the Exp autograd function and return the result.

    NOTE(review): despite the name, this applies Exp (exponentiation in the
    forward pass), not ReverseLayerF. Changing it to a true gradient-reversal
    would alter the behavior of the trained checkpoint — confirm intent with
    the training code before "fixing".
    """
    result = Exp.apply(x)
    return result
class Config():
    """Hyper-parameter bundle used by the training/inference code."""

    def __init__(self):
        self.batch_size = 16                  # samples per mini-batch
        self.epochs = 200                     # training epochs
        self.bert_path = "./fake-news-bert/"  # local pretrained-BERT directory
        # Prefer GPU whenever one is visible to torch.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        self.event_num = 30                   # number of event/domain classes
class FakeNewsDataset(Dataset):
    """Dataset pairing tokenized text, an image, an event id and a label.

    Parameters:
        input_three: sequence of three parallel list-likes —
            [input_ids, token_type_ids, attention_masks], one row per sample.
        event: event/domain id per sample.
        image: per-sample image arrays (converted to LongTensor, matching the
            original implementation — presumably pre-quantized integer data;
            TODO confirm with the preprocessing pipeline).
        label: integer class label per sample (0 = real, 1 = fake).
    """

    def __init__(self, input_three, event, image, label):
        self.event = LongTensor(list(event))
        self.image = LongTensor([np.array(i) for i in image])
        self.label = LongTensor(list(label))
        # Bug fix: original read `self.input_three` before assigning it
        # (AttributeError on every construction); bind the argument instead.
        # Building a fresh list also avoids mutating the caller's list.
        self.input_three = [LongTensor(part) for part in input_three[:3]]

    def __len__(self):
        return len(self.label)

    def __getitem__(self, idx):
        # Bug fix: original returned input_three[2] twice and dropped [1];
        # return ids, token types, attention masks in order.
        return (self.input_three[0][idx], self.input_three[1][idx],
                self.input_three[2][idx], self.image[idx],
                self.event[idx], self.label[idx])
class Multi_Model(nn.Module):
    """Multi-modal (text + image) fake-news classifier with a domain branch.

    Text goes through BERT and a linear projection to p dims; the image goes
    through a small CNN plus a 2-layer MLP to p dims. The concatenated 2p
    feature feeds (a) a 2-way class head and (b) an event/domain head behind
    grad_reverse, DANN-style.
    """

    def __init__(self, bert_path, event_num, classes=2, p=10):
        # bert_path: directory holding the pretrained BERT weights.
        # event_num: output size of the domain classifier head.
        # classes: unused here — the class head is hard-wired to 2 outputs.
        # p: per-modality projected feature size.
        super(Multi_Model, self).__init__()
        # NOTE(review): the config path is hard-coded and ignores `bert_path`,
        # so the two can silently diverge — confirm they always match.
        self.config = BertConfig.from_pretrained("./fake-news-bert/config.json")  # load model hyper-parameters
        self.bert = BertModel.from_pretrained(bert_path, config=self.config)  # load pretrained weights
        self.fc = nn.Linear(self.config.hidden_size, p)  # project pooled text features to p dims
        self.event_num = event_num
        '''
        vgg_19 = torchvision.models.vgg19(pretrained=True)
        for param in vgg_19.parameters():
            param.requires_grad = False
        num_ftrs = vgg_19.classifier._modules['6'].out_features
        self.vgg = vgg_19
        '''
        # self.image_fc1 = nn.Linear(num_ftrs, p)
        # input 3*224*224
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 1, kernel_size=5, stride=2, padding=2),  # 1 * 112*112
            nn.ReLU(),
            nn.MaxPool2d(2),  # 1*56*56
            nn.Conv2d(1, 1, kernel_size=5, stride=2, padding=0),  # 1*26*26
            nn.ReLU(),
        )
        # Flatten the 1x26x26 CNN output down to a p-dim image feature.
        self.image_fc = nn.Sequential(
            nn.Linear(1 * 26 * 26, 26),
            nn.Linear(26, p),
        )
        # self.image_classifier = nn.Sequential(
        #     K.VisionTransformer(image_size=224, patch_size=16),
        #     K.ClassificationHead(num_classes=10)#adjust needed when p change
        # )
        # NOTE(review): defined but never used in forward().
        self.softmax = nn.Softmax(dim=1)
        # Real/fake head over the concatenated (text, image) feature.
        self.class_classifier = nn.Sequential()
        self.class_classifier.add_module(
            'c_fc1', nn.Linear(2 * p, p))
        self.class_classifier.add_module('c_fc2', nn.Linear(p, 2))
        self.class_classifier.add_module('c_softmax', nn.Softmax(dim=1))
        # Event/domain head (adversarial branch).
        self.domain_classifier = nn.Sequential()
        self.domain_classifier.add_module(
            'd_fc1', nn.Linear(2 * p, p))
        # self.domain_classifier.add_module('d_bn1', nn.BatchNorm2d(self.hidden_size))
        # NOTE(review): nn.LeakyReLU(True) sets negative_slope=True (== 1.0),
        # i.e. effectively linear — probably meant inplace=True. Kept as-is
        # because the saved checkpoint was trained with it.
        self.domain_classifier.add_module('d_relu1', nn.LeakyReLU(True))
        self.domain_classifier.add_module(
            'd_fc2', nn.Linear(p, self.event_num))
        self.domain_classifier.add_module('d_softmax', nn.Softmax(dim=1))

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, image=None):
        # Positional BERT call: (input_ids, attention_mask, token_type_ids).
        outputs = self.bert(input_ids, attention_mask, token_type_ids)
        out_pool = outputs[1]  # pooled output, [bs, config.hidden_size]
        text = self.fc(out_pool)  # [bs, p]
        # image = self.vgg(image) # [N, 512]
        # image = F.leaky_relu(self.image_fc1(image))
        image = self.cnn(image)
        # image = self.image_classifier(image)
        image = self.image_fc(image.view(image.size(0), -1))
        text_image = torch.cat((text, image), 1)  # fused feature, [bs, 2p]
        class_output = self.class_classifier(text_image)
        # grad_reverse currently applies Exp (see its definition), not a true
        # gradient reversal — flagged there as well.
        reverse_feature = grad_reverse(text_image)
        domain_output = self.domain_classifier(reverse_feature)
        return class_output, domain_output
def cleanSST(string):
    """Remove punctuation/noise characters, then trim and lowercase.

    The character class mixes full-width and ASCII punctuation (plus a few
    literal letters such as 'n', 'b', 's', 'p' from a stray "nbsp"); it is
    reproduced exactly as trained with.
    """
    cleaned = re.sub(u"[,。 :,.;|-“”——_/nbsp+&;@、《》~()())#O!:【】]", "", string)
    return cleaned.strip().lower()
# Demo fixtures: a bundled example image and example text for gr.Examples.
image_path = './1.jpg'  # path to the sample image shipped with the Space
# NOTE(review): loaded eagerly at import time (raises if 1.jpg is missing)
# but never used below — the examples list references image_path directly.
example_image = Image.open(image_path)
example_text = '2024年是世界末日,我们完蛋了,世界要毁灭了'  # sample fake-news text
def predict(input_text, input_image):
    """Classify a (text, image) news pair as real or fake.

    Parameters:
        input_text: the news text to check (Chinese).
        input_image: a PIL.Image attached to the news item.
    Returns:
        str: '真实新闻' if the class head predicts label 0, else '虚假新闻'.
    """
    # Standard ImageNet preprocessing to match the CNN's 3x224x224 input.
    data_transforms = Compose(transforms=[
        Resize(256),
        CenterCrop(224),
        ToTensor(),
        Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    text = input_text
    multi_model = Multi_Model("./fake-news-bert/", 30)  # event_num irrelevant at inference
    # Bug fix: pass map_location so a GPU-trained checkpoint also loads on a
    # CPU-only host, and switch to eval mode *after* the weights are loaded.
    multi_model.load_state_dict(torch.load('./fake-news-bert/best_multi_bert_model.pth',
                                           map_location=device))
    multi_model.to(device)
    multi_model.eval()
    im = data_transforms(input_image.convert('RGB'))
    # A local folder variant would need 'vocab.txt', 'pytorch_model.bin', 'config.json'.
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    encode_dict = tokenizer.encode_plus(text=cleanSST(text), max_length=50,
                                        padding='max_length', truncation=True)
    input_ids = [encode_dict['input_ids']]
    input_types = [encode_dict['token_type_ids']]
    input_masks = [encode_dict['attention_mask']]
    # NOTE(review): token_type_ids are passed where the model expects
    # attention_mask (and vice versa) — the model signature is
    # (input_ids, attention_mask, token_type_ids, image). Kept as-is because
    # the checkpoint was presumably trained this way; confirm against the
    # training code before changing.
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        label_pred, _ = multi_model(LongTensor(input_ids).to(device),
                                    LongTensor(input_types).to(device),
                                    LongTensor(input_masks).to(device),
                                    FloatTensor([np.array(im)]).to(device))
    print(label_pred.shape)
    print(label_pred)
    y_pred = torch.argmax(label_pred, dim=1).detach().cpu().numpy().tolist()
    print(y_pred)
    print("fake news :", text)
    # Label 0 means real news; anything else is treated as fake.
    return '真实新闻' if y_pred[0] == 0 else '虚假新闻'
| # examples=['2024年是世界末日,我们完蛋了,世界要毁灭了'] | |
| # demo = gr.Interface(predict, | |
| # inputs=[gr.Textbox(lines=2,placeholer="在这里输入需要检测新闻的文本内容"),"image"], | |
| # outputs="text")#, | |
| # # examples=examples) | |
# Gradio UI: text box + image input on the left, verdict text on the right.
css = ".json {height: 527px; overflow: scroll;} .json-holder {height: 527px; overflow: scroll;}"
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1><center>虚假新闻检测</center></h1>")
    with gr.Row():
        with gr.Column():
            # Bug fix: 'placeholer' was misspelled, so the hint text was
            # silently dropped instead of shown in the textbox.
            inp_txt = gr.Textbox(lines=2, placeholder="在这里输入需要检测新闻的文本内容")
            inp_img = gr.Image(type='pil')  # hand predict() a PIL.Image
            inp = [inp_txt, inp_img]
        with gr.Column():
            out = gr.Textbox(lines=2)
    btn = gr.Button("检测")
    btn.click(fn=predict, inputs=inp, outputs=out)
    # One pre-filled example row (text + bundled image path).
    examples = [[example_text, image_path]]
    gr.Examples(
        examples=examples,
        inputs=inp,
    )
demo.launch()