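# NOTE: the three lines below are residue from the hosting page's file view
# (author, commit message, commit id), kept as comments so the file parses.
# jinyin3's picture
# Update app.py
# e0fffdd verified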
import gradio as gr
import os
import cv2
import random
import argparse
import datetime
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from network import MFF_MoE
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
parser = argparse.ArgumentParser()
parser.add_argument('--local_weight', type=str, default='weights/', help='trained weights path')
args = parser.parse_args()
class NetInference():
    """Single-image inference wrapper around the MFF_MoE detector.

    Weights are loaded from ``args.local_weight`` (the ``--local_weight``
    command-line option) and the network is wrapped in ``nn.DataParallel``
    and moved to CUDA, so a GPU is required.
    """

    def __init__(self):
        self.net = MFF_MoE(pretrained=False)
        self.net.load(path=args.local_weight)
        self.net = nn.DataParallel(self.net).cuda()
        # Evaluation mode for inference (affects layers such as dropout/batch-norm).
        self.net.eval()
        # Preprocessing: PIL image -> tensor, normalise with the standard
        # ImageNet mean/std values, then resize to 512x512.
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((512, 512), antialias=True),
        ])

    def infer(self, input_path=''):
        """Run the detector on the image file at *input_path*.

        Args:
            input_path: path of an image file readable by OpenCV.

        Returns:
            The raw network output, moved to the CPU, as a NumPy array.

        Raises:
            ValueError: if OpenCV cannot read/decode the file (``cv2.imread``
                returns ``None`` in that case instead of raising).
        """
        bgr = cv2.imread(input_path)
        if bgr is None:
            raise ValueError(f"Could not read image file: {input_path!r}")
        # OpenCV returns channels in BGR order; reverse the last axis for RGB.
        rgb = bgr[..., ::-1]
        image = Image.fromarray(np.uint8(rgb))
        x = self.transform_val(image).unsqueeze(0).cuda()
        # Pure inference: skip building the autograd graph to save memory/time.
        with torch.no_grad():
            pred = self.net(x)
        return pred.detach().cpu().numpy()
def process_image(image):
    """Save an uploaded image to disk, score it, and return the prediction text.

    Args:
        image: the uploaded image object from the Gradio ``Image(type="pil")``
            input; it must provide ``save(path)``. ``None`` (nothing uploaded)
            is answered with a message instead of raising.

    Returns:
        The model output formatted with ``f"{res}"`` for the output textbox.

    The saved file is kept on disk after scoring (it is not deleted here).
    """
    import uuid
    from datetime import datetime

    if image is None:
        return "No image uploaded."

    # Uploads go to an "images" directory under the current working directory.
    save_path = os.path.join(os.getcwd(), "images")
    os.makedirs(save_path, exist_ok=True)

    # A second-resolution timestamp alone collides when two uploads arrive in
    # the same second (the later save overwrites the earlier file); the random
    # suffix keeps every filename unique.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    image_filename = f"uploaded_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    image_path = os.path.join(save_path, image_filename)
    image.save(image_path)

    # Run inference exactly once per upload.
    res = model.infer(image_path)
    # .item() extracts the single score explicitly; implicit float() on a
    # 1-element ndarray is deprecated in recent NumPy releases.
    score = np.asarray(res).item()
    print('Prediction of [%s] being Deepfake: %10.9f' % (image_path, score))
    return f"{res}"
# --- UI wiring -------------------------------------------------------------
# Input widget: an image picker; the callback receives the image as PIL.
image_input = gr.components.Image(type="pil", label="Upload Image")

# Output widget: a plain textbox that shows the prediction / status text.
text_output = gr.components.Textbox()

# live=False: the interface gets a Submit button instead of re-running
# on every input change.
demo = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=text_output,
    live=False,
)

# --- Model + launch ----------------------------------------------------------
# process_image() looks up this module-level name when a request arrives.
model = NetInference()

# Start the web interface.
demo.launch()