| |
| |
| import shutil |
| import gradio as gr |
|
|
| |
|
|
| import insightface |
| import onnxruntime |
| from insightface.app import FaceAnalysis |
| import cv2 |
| import gfpgan |
| import tempfile |
| import time |
| import uuid |
| from typing import Any, Union |
| from loggers import logger, request_id as _request_id |
| import ssl |
| from datetime import datetime |
| import traceback |
| import torch |
| import os |
| import requests |
| import subprocess |
| import sys |
| from PIL import Image |
| import numpy as np |
|
|
| ssl._create_default_https_context = ssl._create_unverified_context |
|
|
| if sys.platform == 'darwin': |
| cache_file_dir = '/tmp/file' |
| else: |
| cache_file_dir = '/src/file' |
| |
|
|
|
|
| def img_url_to_local_path(img_url, file_path=None): |
| filename = img_url.split('/')[-1] |
| max_count = 3 |
| count = 0 |
| if file_path is None: |
| temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=filename) |
| temp_file_name = temp_file.name |
| else: |
| temp_file_name = file_path |
| while True: |
| count += 1 |
| try: |
| res = requests.get(img_url, timeout=60) |
| res.raise_for_status() |
| with open(temp_file_name, "wb") as f: |
| f.write(res.content) |
| return temp_file_name |
| except Exception as e: |
| logger.error(e) |
| if count >= max_count: |
| msg = f'request {max_count} time url: {img_url} failed, please check' |
| logger.error(msg) |
| raise Exception(msg) |
|
|
|
|
| def delete_files_day_ago(cache_days=10): |
| command = f"find {cache_file_dir} -type f -ctime +{cache_days} -exec rm {{}} \;" |
| result = subprocess.run(command, shell=True, capture_output=True, text=True) |
| logger.info(result.stdout) |
|
|
|
|
| def image_format_by_path(image_path): |
| image = Image.open(image_path) |
| image_format = image.format |
| if not image_format: |
| image_format = 'jpg' |
| elif image_format == "JPEG": |
| image_format = 'jpg' |
| else: |
| image_format = image_format.lower() |
| return image_format |
|
|
|
|
| def local_file_for_url(url, cache_days=10): |
| filename = url.split('/')[-1] |
| _, ext = filename.split('.') |
| file_path = f'{cache_file_dir}/{filename}' |
| if not os.path.exists(file_path): |
| img_url_to_local_path(url, file_path) |
| logger.info(f'download file to {file_path}') |
| delete_files_day_ago(cache_days) |
| else: |
| logger.info(f'cache file {file_path}') |
| return file_path |
|
|
|
|
| class Predictor: |
| def __init__(self): |
| self.det_thresh = 0.1 |
|
|
| def setup(self): |
| self.face_swapper = insightface.model_zoo.get_model('cache/inswapper_128.onnx', providers=onnxruntime.get_available_providers()) |
| self.face_enhancer = gfpgan.GFPGANer(model_path='cache/GFPGANv1.4.pth', upscale=1) |
| self.face_analyser = FaceAnalysis(name='buffalo_l') |
|
|
| def get_face(self, img_data, image_type='target'): |
| try: |
| logger.info(self.det_thresh) |
| self.face_analyser.prepare(ctx_id=0, det_thresh=0.5) |
| if image_type == 'source': |
| self.face_analyser.prepare(ctx_id=0, det_thresh=self.det_thresh) |
| analysed = self.face_analyser.get(img_data) |
| logger.info(f'face num: {len(analysed)}') |
| if len(analysed) == 0: |
| msg = 'no face' |
| logger.error(msg) |
| raise Exception(msg) |
| largest = max(analysed, key=lambda x: (x.bbox[2] - x.bbox[0]) * (x.bbox[3] - x.bbox[1])) |
| return largest |
| except Exception as e: |
| logger.error(str(e)) |
| raise Exception(str(e)) |
|
|
| def enhance_face(self, target_face, target_frame, weight=0.5): |
| start_x, start_y, end_x, end_y = map(int, target_face['bbox']) |
| padding_x = int((end_x - start_x) * 0.5) |
| padding_y = int((end_y - start_y) * 0.5) |
| start_x = max(0, start_x - padding_x) |
| start_y = max(0, start_y - padding_y) |
| end_x = max(0, end_x + padding_x) |
| end_y = max(0, end_y + padding_y) |
| temp_face = target_frame[start_y:end_y, start_x:end_x] |
| if temp_face.size: |
| _, _, temp_face = self.face_enhancer.enhance( |
| temp_face, |
| paste_back=True, |
| weight=weight |
| ) |
| target_frame[start_y:end_y, start_x:end_x] = temp_face |
| return target_frame |
|
|
| def predict( |
| self, |
| source_image_path, |
| target_image_path, |
| enhance_face, |
| |
| |
| |
| |
| |
| |
| |
| ) -> Any: |
| """Run a single prediction on the model""" |
| request_id = None |
| det_thresh = 0.1 |
| cache_days = 10 |
| weight = 0.5 |
|
|
| device = 'cuda' if torch.cuda.is_available() else 'mps' |
| logger.info(f'device: {device}, det_thresh:{det_thresh}') |
|
|
| try: |
| self.det_thresh = det_thresh |
| start_time = time.time() |
| if not request_id: |
| request_id = str(uuid.uuid4()) |
| _request_id.set(request_id) |
| frame = cv2.imread(str(target_image_path)) |
| source_frame = cv2.imread(str(source_image_path)) |
| source_face = self.get_face(source_frame, image_type='source') |
| target_face = self.get_face(frame) |
| try: |
| logger.info(f'{frame.shape}, {target_face.shape}, {source_face.shape}') |
| except Exception as e: |
| logger.error(f"printing shapes failed, error:{str(e)}") |
| raise Exception(str(e)) |
| ext = image_format_by_path(target_image_path) |
| size = os.path.getsize(target_image_path) |
| logger.info(f'origin {size/1024}k') |
| result = self.face_swapper.get(frame, target_face, source_face, paste_back=True) |
| if enhance_face: |
| result = self.enhance_face(target_face, result, weight) |
| |
| |
| |
| |
| out_path = f"{tempfile.mkdtemp()}/{uuid.uuid4()}.{ext}" |
| cv2.imwrite(str(out_path), result) |
| return Image.open(out_path) |
|
|
| size = os.path.getsize(out_path) |
| logger.info(f'result {size / 1024}k') |
| cost_time = time.time() - start_time |
| logger.info(f'total time: {cost_time * 1000} ms') |
| data = {'code': 200, 'msg': 'succeed', 'image': out_path, 'status': 'succeed'} |
| return data |
| except Exception as e: |
| logger.error(traceback.format_exc()) |
| data = {'code': 500, 'msg': str(e), 'image': '', 'status': 'failed'} |
| logger.error(f"{str(e)}") |
| return data |
|
|
| def swap_faces(source_image_path, target_image_path, enhance_face): |
| predictor = Predictor() |
| predictor.setup() |
| return predictor.predict( |
| source_image_path, |
| target_image_path, |
| enhance_face |
| ) |
|
|
| if __name__ == "__main__": |
| demo = gr.Interface( |
| fn=swap_faces, |
| inputs=[ |
| gr.Image(type="filepath"), |
| gr.Image(type="filepath"), |
| gr.Checkbox(label="Enhance Face", value=True), |
| |
| ], |
| outputs=[ |
| gr.Image( |
| type="pil", |
| show_download_button=True, |
| ) |
| ], |
| title="Swap Faces", |
| allow_flagging="never" |
| ) |
| demo.launch() |