import os

# Restrict PyTorch to the first GPU. This is set before torch and the
# project's network module are imported so the CUDA runtime is guaranteed to
# see it no matter what those imports do.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import argparse
import datetime
import random

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from network import MFF_MoE

parser = argparse.ArgumentParser()
parser.add_argument('--local_weight', type=str, default='weights/', help='trained weights path')
args = parser.parse_args()


class NetInference:
    """Wrap the MFF_MoE model for single-image inference on the GPU."""

    def __init__(self):
        # Build the model, load weights from --local_weight, then wrap it in
        # DataParallel on CUDA and switch to eval mode.
        self.net = MFF_MoE(pretrained=False)
        self.net.load(path=args.local_weight)
        self.net = nn.DataParallel(self.net).cuda()
        self.net.eval()
        # ToTensor -> ImageNet mean/std normalisation -> resize to 512x512.
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
            transforms.Resize((512, 512), antialias=True),
        ])

    def infer(self, input_path=''):
        """Run the model on the image stored at ``input_path``.

        Args:
            input_path: Path to an image file readable by OpenCV.

        Returns:
            The network output as a NumPy array on the CPU. Its exact shape
            depends on MFF_MoE (not visible here).

        Raises:
            FileNotFoundError: If OpenCV cannot read ``input_path``.
        """
        bgr = cv2.imread(input_path)
        if bgr is None:
            # cv2.imread signals failure by returning None, not by raising.
            raise FileNotFoundError(f"Could not read image: {input_path!r}")
        # OpenCV loads BGR; reverse the channel axis to get RGB for PIL.
        x = Image.fromarray(np.uint8(bgr[..., ::-1]))
        x = self.transform_val(x).unsqueeze(0).cuda()
        # Inference only: skip building an autograd graph.
        with torch.no_grad():
            pred = self.net(x)
        return pred.detach().cpu().numpy()


def process_image(image):
    """Save the uploaded image under ./images and return the model's prediction.

    Args:
        image: PIL image from the Gradio ``Image`` component (``type="pil"``),
            or None when the user submits without uploading anything.

    Returns:
        The prediction array rendered as a string, or a prompt asking for an
        image when none was provided.
    """
    if image is None:
        return "Please upload an image."

    # Save uploads into the "images" directory of the current working directory.
    save_path = os.path.join(os.getcwd(), "images")
    os.makedirs(save_path, exist_ok=True)

    # Microsecond resolution so two uploads in the same second do not
    # overwrite each other's file.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S%f")
    image_filename = f"uploaded_{timestamp}.png"
    image_path = os.path.join(save_path, image_filename)
    image.save(image_path)

    # Run the model exactly once per request.
    res = model.infer(image_path)
    # '%f' needs a Python scalar; .item() extracts it (raising if the output is
    # not a single value) without relying on deprecated ndarray->float conversion.
    score = np.asarray(res).item()
    print('Prediction of [%s] being Deepfake: %10.9f' % (image_path, score))
    return f"{res}"


# Input: an image picker that hands the upload to process_image as a PIL image.
image_input = gr.components.Image(label="Upload Image", type="pil")
# Output: a text box that shows the prediction returned by process_image.
text_output = gr.components.Textbox()
# live=False gives the interface an explicit Submit button.
demo = gr.Interface(fn=process_image, inputs=image_input, outputs=text_output, live=False)

model = NetInference()

# Start the web UI.
demo.launch()