Spaces:
Sleeping
Sleeping
| import insightface | |
| import os | |
| import onnxruntime | |
| import cv2 | |
| import gfpgan | |
| import tempfile | |
| import time | |
| import gradio as gr | |
class Predictor:
    """Face-swap pipeline: insightface detection, inswapper_128 swap, GFPGAN v1.4 restoration.

    Model weights are downloaded into ``MODEL_DIR`` (relative to the current
    working directory) if missing and then loaded once, so a single instance
    can serve many ``predict`` calls.
    """

    # Directory (relative to the CWD) where model weights are stored.
    MODEL_DIR = 'models'
    # Weight file name -> download URL; files already present are not re-downloaded.
    MODEL_URLS = {
        'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        'inswapper_128.onnx': 'https://huggingface.co/ashleykleynhans/inswapper/resolve/main/inswapper_128.onnx',
    }

    def __init__(self):
        self.setup()

    def setup(self):
        """Load the model into memory to make running multiple predictions efficient.

        Downloads any missing weights first, then builds the inswapper model
        (on every available onnxruntime provider), the GFPGAN enhancer
        (``upscale=1``) and the ``buffalo_l`` face analyser with a 640x640
        detection size.
        """
        self._download_weights()
        self.face_swapper = insightface.model_zoo.get_model(
            os.path.join(self.MODEL_DIR, 'inswapper_128.onnx'),
            providers=onnxruntime.get_available_providers(),
        )
        self.face_enhancer = gfpgan.GFPGANer(
            model_path=os.path.join(self.MODEL_DIR, 'GFPGANv1.4.pth'),
            upscale=1,
        )
        self.face_analyser = insightface.app.FaceAnalysis(name='buffalo_l')
        self.face_analyser.prepare(ctx_id=0, det_size=(640, 640))

    def _download_weights(self):
        """Download every file in ``MODEL_URLS`` that is missing from ``MODEL_DIR``.

        Each download is written to a ``.part`` file and renamed only after it
        completes, so an interrupted download is retried on the next start
        instead of being mistaken for a complete weight file. Network errors
        propagate to the caller. The process working directory is not changed.
        """
        import urllib.request  # local import: only needed on first start-up

        os.makedirs(self.MODEL_DIR, exist_ok=True)
        for filename, url in self.MODEL_URLS.items():
            dest = os.path.join(self.MODEL_DIR, filename)
            if os.path.exists(dest):
                continue
            partial = dest + '.part'
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, dest)

    @staticmethod
    def _bbox_area(face):
        """Return the area of ``face.bbox``, read as ``(x1, y1, x2, y2)``."""
        x1, y1, x2, y2 = face.bbox[:4]
        return (x2 - x1) * (y2 - y1)

    @staticmethod
    def _image_path(image):
        """Return a filesystem path for a Gradio image input.

        Accepts a file-like object exposing ``.name`` (what the ``type="file"``
        components in this app pass) or an already-resolved path string.
        """
        return getattr(image, 'name', image)

    def get_face(self, img_data):
        """Return the detected face with the largest bounding box, or ``None``.

        Args:
            img_data: image passed straight to ``face_analyser.get``.

        Returns:
            The largest face object, or ``None`` (after printing a notice)
            when the analyser finds no face.
        """
        faces = self.face_analyser.get(img_data)
        largest = max(faces, key=self._bbox_area, default=None)
        if largest is None:
            print("No face found")
        return largest

    def predict(self, input_image, swap_image):
        """Run a single prediction on the model.

        Swaps the largest face found in ``swap_image`` onto the largest face
        in ``input_image``, restores the result with GFPGAN and writes it to a
        fresh temporary directory as a JPEG.

        Args:
            input_image: target image (object with ``.name`` or a path string).
            swap_image: source image that supplies the face (same forms).

        Returns:
            Path of the written result JPEG.

        Raises:
            gr.Error: when an input is missing or unreadable, when no face is
                detected, or wrapping any other unexpected failure.
        """
        # User-facing message: "Please make sure both the target and source images are uploaded."
        if input_image is None or swap_image is None:
            raise gr.Error("请确保同时上传了目标图片和源图片。")
        try:
            target_img = cv2.imread(self._image_path(input_image))
            swap_img = cv2.imread(self._image_path(swap_image))
            # cv2.imread signals failure by returning None rather than raising.
            if target_img is None:
                raise gr.Error("无法读取目标图片,请确认上传的是有效的图片文件。")
            if swap_img is None:
                raise gr.Error("无法读取源图片,请确认上传的是有效的图片文件。")

            target_face = self.get_face(target_img)
            source_face = self.get_face(swap_img)
            # Only run the swap once a face has been found in both images.
            if target_face is None:
                raise gr.Error("在目标图片中未能检测到人脸,请更换图片后重试。")
            if source_face is None:
                raise gr.Error("在源图片中未能检测到人脸,请更换图片后重试。")

            result = self.face_swapper.get(target_img, target_face, source_face, paste_back=True)
            # enhance() returns a 3-tuple; the last item is the full restored image.
            _, _, result = self.face_enhancer.enhance(result, paste_back=True)

            out_path = os.path.join(tempfile.mkdtemp(), f"{int(time.time())}.jpg")
            if not cv2.imwrite(out_path, result):
                raise OSError(f"cv2.imwrite failed to write {out_path}")
            return out_path
        except gr.Error:
            # Errors raised above already carry a specific user-facing message;
            # re-raise them unchanged instead of re-wrapping them as "unknown".
            raise
        except Exception as e:
            print(f"An unexpected error occurred: {e}")
            raise gr.Error(f"发生未知错误: {e}") from e
# Build one shared Predictor at import time; its constructor downloads any
# missing weights and loads every model, so all requests reuse the same
# in-memory networks.
predictor = Predictor()

title = "Swap Faces Using Our Model!!!"

# UI components, created in the original order: the image that receives the
# face first, the image that donates the face second, then the result.
input_components = [
    gr.inputs.Image(type="file", label="Target Image"),
    gr.inputs.Image(type="file", label="Swap Image"),
]
output_component = gr.outputs.Image(type="file", label="Result")
example_rows = [["input.jpg", "swap img.jpg"]]

# Wire the predictor into a Gradio interface and start serving it.
iface = gr.Interface(
    fn=predictor.predict,
    inputs=input_components,
    outputs=output_component,
    title=title,
    examples=example_rows,
)
iface.launch()