Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import cv2 | |
| import random | |
| import argparse | |
| import datetime | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from torchvision import transforms | |
| from PIL import Image | |
| from network import MFF_MoE | |
# Expose only the first GPU to CUDA. This has to happen before CUDA is first
# initialised in this process to take effect.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

parser = argparse.ArgumentParser()
parser.add_argument('--local_weight', type=str, default='weights/', help='trained weights path')
# parse_known_args() ignores flags this script does not define (e.g. those
# injected by test runners, notebooks or hosting wrappers) instead of exiting
# with "unrecognized arguments" the way parse_args() does.
args = parser.parse_known_args()[0]
class NetInference():
    """Single-image deepfake scorer built on the MFF_MoE network.

    Weights are loaded from ``args.local_weight`` (the ``--local_weight``
    command-line option, default ``'weights/'``). The network runs on CUDA
    when it is available and falls back to the CPU otherwise, so the app can
    start on hosts without a GPU instead of failing on ``.cuda()``.
    """

    def __init__(self):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net = MFF_MoE(pretrained=False)
        # NOTE(review): the CPU fallback also requires MFF_MoE.load to map
        # CUDA-saved weights onto the CPU; that code is not visible here -- confirm.
        self.net.load(path=args.local_weight)
        # The DataParallel wrapper is kept so the module layout matches the
        # original; when no CUDA device exists it just forwards to the wrapped module.
        self.net = nn.DataParallel(self.net).to(self.device)
        self.net.eval()
        # PIL image -> CHW float tensor in [0, 1] -> per-channel normalisation
        # with the standard ImageNet mean/std -> 512x512 resize.
        # NOTE(review): normalisation runs before the resize; presumably this
        # mirrors the training pipeline, so the order is left unchanged.
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            transforms.Resize((512, 512), antialias=True),
        ])

    def infer(self, input_path=''):
        """Run the network on the image stored at *input_path*.

        Args:
            input_path: Path of an image file readable by OpenCV.

        Returns:
            The raw network output for a batch of one image, as a NumPy array.
            Its exact shape depends on MFF_MoE (presumably one score per
            image -- confirm against the network definition).

        Raises:
            OSError: if OpenCV cannot read the file.
        """
        bgr = cv2.imread(input_path)
        if bgr is None:
            # cv2.imread reports missing/unreadable files by returning None
            # rather than raising, which would otherwise fail later as a TypeError.
            raise OSError(f"Could not read image file: {input_path!r}")
        # OpenCV loads channels as BGR; reverse the last axis to get RGB.
        x = Image.fromarray(np.uint8(bgr[..., ::-1]))
        x = self.transform_val(x).unsqueeze(0).to(self.device)
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            pred = self.net(x)
        return pred.cpu().numpy()
def process_image(image):
    """Save an uploaded image, score it, and return the prediction as text.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the form is submitted without an upload.

    Returns:
        The model prediction formatted with ``str`` for the output textbox,
        or a short message when no image was provided.
    """
    # Gradio passes None when Submit is pressed without an image.
    if image is None:
        return "No image uploaded."

    import uuid

    # Uploads are stored in the "images" directory under the current working
    # directory. NOTE(review): files are never cleaned up and accumulate on disk.
    save_path = os.path.join(os.getcwd(), "images")
    os.makedirs(save_path, exist_ok=True)

    # A second-resolution timestamp alone is not unique: two uploads in the
    # same second would overwrite each other (and one request could score the
    # other's image), so a random suffix is appended.
    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    image_filename = f"uploaded_{timestamp}_{uuid.uuid4().hex[:8]}.png"
    image_path = os.path.join(save_path, image_filename)
    image.save(image_path)

    # Run inference exactly once.
    res = model.infer(image_path)
    # .item() converts the single-element prediction array into a Python
    # float, which is what the %f format expects.
    print('Prediction of [%s] being Deepfake: %10.9f' % (image_path, np.asarray(res).item()))
    # Return the prediction (not the save location) for the output textbox.
    return f"{res}"
# Input widget: an image picker that hands process_image a PIL image.
image_input = gr.Image(type="pil", label="Upload Image")

# Output widget: a text box that displays whatever process_image returns.
text_output = gr.Textbox()

# live=False makes the Interface wait for an explicit Submit button click.
demo = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=text_output,
    live=False,
)

# Build the model; process_image looks up this module-level name at call time.
model = NetInference()

# Start the web UI.
demo.launch()