| |
| import json |
| import sys |
| import os |
| import io |
| import argparse |
| import uuid |
| import base64 |
| import logging |
| import time |
| import copy |
| import cv2 |
| import insightface |
| import numpy as np |
| from typing import List, Union |
| from PIL import Image |
| from restoration import * |
| from flask import Flask, request, jsonify, make_response |
| from waitress import serve |
|
|
# Module-wide settings shared by the request handlers in this file.
LOG_LEVEL = logging.DEBUG
TMP_PATH = '/tmp/inswapper'
script_dir = os.path.dirname(os.path.abspath(__file__))

# On Linux the log file is placed under /var/log/; everywhere else the empty
# prefix leaves the file name relative to the current working directory.
log_path = '/var/log/' if sys.platform == 'linux' else ''

# Configure the root logger to write timestamped records to the log file.
logging.basicConfig(
    level=LOG_LEVEL,
    format='%(asctime)s : %(levelname)s : %(message)s',
    filename=f'{log_path}inswapper.log',
)

# Additionally echo every record to stdout. No formatter is attached to this
# handler here, so it does not use the format string configured above.
logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
|
|
|
|
def process_request(request_obj):
    """Run a face swap for ``request_obj`` and wrap the outcome in a response dict.

    ``request_obj`` must provide ``source_image`` and ``target_image``; both
    are passed to ``face_swap`` unchanged. The remaining ``face_swap``
    arguments are read from the optional keys ``source_indexes``,
    ``target_indexes``, ``background_enhance``, ``face_restore``,
    ``face_upsample``, ``upscale``, ``codeformer_fidelity`` and
    ``output_format``, falling back to the same defaults the ``/faceswap``
    endpoint applies.

    Returns:
        ``{'status': 'ok', 'image': ...}`` on success, otherwise
        ``{'status': 'error', 'msg': 'Face swap failed', 'detail': ...}``.
        Exceptions are logged and reported in the response, never raised.
    """
    try:
        logging.debug('Swapping face')
        face_swap_timer = Timer()

        # face_swap() has ten required parameters; calling it with only the
        # two images always raised TypeError, so every option is supplied.
        result_image = face_swap(
            request_obj['source_image'],
            request_obj['target_image'],
            request_obj.get('source_indexes', '-1'),
            request_obj.get('target_indexes', '-1'),
            request_obj.get('background_enhance', True),
            request_obj.get('face_restore', True),
            request_obj.get('face_upsample', True),
            request_obj.get('upscale', 1),
            request_obj.get('codeformer_fidelity', 0.5),
            request_obj.get('output_format', 'JPEG')
        )

        face_swap_time = face_swap_timer.get_elapsed_time()
        logging.info(f'Time taken to swap face: {face_swap_time} seconds')

        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        logging.error(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }

    return response
|
|
|
|
class Timer:
    """Stopwatch that reports the seconds elapsed since it was (re)started.

    Readings come from ``time.monotonic()`` so that adjustments to the system
    clock cannot produce negative or inflated durations. ``start`` therefore
    holds a monotonic reading, not an epoch timestamp.
    """

    def __init__(self):
        self.start = time.monotonic()

    def restart(self):
        """Reset the reference point to now."""
        self.start = time.monotonic()

    def get_elapsed_time(self):
        """Return the seconds since the last (re)start, rounded to 0.1 s."""
        return round(time.monotonic() - self.start, 1)
|
|
|
|
def get_args():
    """Parse the command-line options of the Inswapper REST API.

    Returns:
        The ``argparse`` namespace with ``port`` (int, default 8090) and
        ``host`` (default '0.0.0.0').
    """
    cli = argparse.ArgumentParser(description='Inswapper REST API')
    cli.add_argument('-p', '--port', type=int, default=8090, help='Port to listen on')
    cli.add_argument('-H', '--host', default='0.0.0.0', help='Host to bind to')
    return cli.parse_args()
|
|
|
|
def determine_file_extension(image_data):
    """Guess a file extension from the first characters of base64 image data.

    Args:
        image_data: Base64 text of an image file.

    Returns:
        '.jpg' when the data starts with the base64 form of the JPEG magic
        bytes, '.png' for PNG data and as the fallback for anything else,
        including values that are not strings.
    """
    # Base64 renderings of the file signatures:
    #   FF D8 FF                 -> '/9j/'
    #   89 50 4E 47 0D 0A 1A 0A  -> 'iVBORw0KGgo'
    signatures = (
        ('/9j/', '.jpg'),
        ('iVBORw0KGgo', '.png'),
    )

    try:
        for prefix, extension in signatures:
            if image_data.startswith(prefix):
                return extension
    except (AttributeError, TypeError):
        # Not a str (e.g. None or bytes): fall through to the default.
        pass

    # Unknown formats default to PNG.
    return '.png'
|
|
|
|
def write_base64_to_disk(file_b64: str, file_path: str):
    """Decode ``file_b64`` from base64 and write the raw bytes to ``file_path``."""
    with open(file_path, mode='wb') as output_file:
        raw_bytes = base64.b64decode(file_b64)
        output_file.write(raw_bytes)
|
|
|
|
def get_face_swap_model(model_path: str):
    """Load the model stored at ``model_path`` through insightface's model zoo."""
    return insightface.model_zoo.get_model(model_path)
|
|
|
|
def get_face_analyser(model_path: str,
                      det_size=(320, 320)):
    """Create an insightface ``FaceAnalysis`` object and prepare it for use.

    ``model_path`` is accepted but not read by this function: the analyser is
    always built from the "buffalo_l" pack under "./checkpoints". ``det_size``
    is forwarded to ``prepare``.
    """
    analyser = insightface.app.FaceAnalysis(root="./checkpoints", name="buffalo_l")
    analyser.prepare(det_size=det_size, ctx_id=0)
    return analyser
|
|
|
|
def get_one_face(face_analyser,
                 frame:np.ndarray):
    """Return the detected face with the smallest ``bbox[0]`` value.

    Returns None when ``face_analyser.get(frame)`` yields no faces.
    """
    detected = face_analyser.get(frame)

    def left_edge(candidate):
        return candidate.bbox[0]

    try:
        leftmost = min(detected, key=left_edge)
    except ValueError:
        # min() raises ValueError for an empty sequence.
        leftmost = None

    return leftmost
|
|
|
|
def get_many_faces(face_analyser,
                   frame:np.ndarray):
    """
    Detect the faces in ``frame`` and return them ordered from left to right
    (ascending ``bbox[0]``).

    An empty detection result yields an empty list. None is returned only
    when an IndexError is raised while detecting or sorting.
    """
    def left_edge(candidate):
        return candidate.bbox[0]

    try:
        ordered = list(face_analyser.get(frame))
        ordered.sort(key=left_edge)
    except IndexError:
        return None

    return ordered
|
|
|
|
def swap_face(face_swapper,
              source_faces,
              target_faces,
              source_index,
              target_index,
              temp_frame):
    """
    Paste one source face over one target face.

    Picks ``source_faces[source_index]`` and ``target_faces[target_index]``
    and returns whatever ``face_swapper.get`` produces for ``temp_frame``
    with ``paste_back=True``.
    """
    donor_face = source_faces[source_index]
    recipient_face = target_faces[target_index]

    swapped_frame = face_swapper.get(
        temp_frame,
        recipient_face,
        donor_face,
        paste_back=True,
    )
    return swapped_frame
|
|
|
|
def _pil_to_bgr(image):
    """Convert a PIL image into a 3-channel BGR array for OpenCV."""
    # convert('RGB') normalises modes such as palette or greyscale, whose
    # arrays are not 3-channel RGB and would make the colour conversion fail.
    return cv2.cvtColor(np.array(image.convert('RGB')), cv2.COLOR_RGB2BGR)


def process(source_img: Union[Image.Image, List],
            target_img: Image.Image,
            source_indexes: str,
            target_indexes: str,
            model: str):
    """Swap faces from the source image(s) into ``target_img``.

    Faces are ordered left to right by ``get_many_faces``; all indexes refer
    to that ordering.

    Modes, checked in this order:

    * ``source_img`` is a list with exactly one image per target face: each
      image is applied to the target face at the same position, and the
      index arguments are ignored.
    * One source image and ``target_indexes == "-1"``: a single source face
      replaces every target face; otherwise faces are paired by position for
      as many pairs as both images provide.
    * One source image and explicit ``target_indexes``: the comma-separated
      ``source_indexes`` (``"-1"`` means every source face) are paired
      one-to-one with the comma-separated ``target_indexes``.

    Args:
        source_img: A PIL image or a list of PIL images.
        target_img: The PIL image whose faces are replaced.
        source_indexes: ``"-1"`` or comma-separated source face indexes.
        target_indexes: ``"-1"`` or comma-separated target face indexes.
        model: Path of the face-swap model, joined onto this file's directory.

    Returns:
        The swapped result as an RGB PIL image.

    Raises:
        ValueError: An explicit index is beyond the available faces.
        Exception: No faces were found in the target or a source image, more
            indexes than faces were given, the numbers of source and target
            indexes differ, or the number of source images is unsupported.
    """
    # The annotation allows a bare image; treat it like a one-element list
    # instead of failing on len().
    source_images = [source_img] if isinstance(source_img, Image.Image) else source_img

    # Load the face analyser.
    face_analyser = get_face_analyser(model)

    # Load the face swapper.
    model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
    face_swapper = get_face_swap_model(model_path)

    # Detect the faces that will be replaced in the target image.
    target_frame = _pil_to_bgr(target_img)
    target_faces = get_many_faces(face_analyser, target_frame)

    # get_many_faces yields an empty list (not None) when nothing is found,
    # so test for emptiness before taking any length.
    if not target_faces:
        logging.error('No target faces found')
        raise Exception('No target faces found!')

    num_target_faces = len(target_faces)
    num_source_images = len(source_images)
    temp_frame = copy.deepcopy(target_frame)

    if isinstance(source_images, list) and num_source_images == num_target_faces:
        logging.debug('Replacing the faces in the target image from left to right by order')

        for i in range(num_target_faces):
            source_faces = get_many_faces(face_analyser, _pil_to_bgr(source_images[i]))

            if not source_faces:
                raise Exception('No source faces found!')

            # NOTE(review): face i of source image i is used, so the i-th
            # source image must contain at least i + 1 faces. Confirm this is
            # intended rather than always taking face 0 of each image.
            temp_frame = swap_face(
                face_swapper,
                source_faces,
                target_faces,
                i,
                i,
                temp_frame
            )
    elif num_source_images == 1:
        # Detect the source faces that will be pasted into the target image.
        source_faces = get_many_faces(face_analyser, _pil_to_bgr(source_images[0]))

        if not source_faces:
            raise Exception('No source faces found!')

        num_source_faces = len(source_faces)
        logging.debug(f'Source faces: {num_source_faces}')
        logging.debug(f'Target faces: {num_target_faces}')

        if target_indexes == "-1":
            if num_source_faces == 1:
                logging.debug('Replacing all faces in target image with the same face from the source image')
                num_iterations = num_target_faces
            elif num_source_faces < num_target_faces:
                logging.debug('There are less faces in the source image than the target image, replacing as many as we can')
                num_iterations = num_source_faces
            elif num_target_faces < num_source_faces:
                logging.debug('There are less faces in the target image than the source image, replacing as many as we can')
                num_iterations = num_target_faces
            else:
                logging.debug('Replacing all faces in the target image with the faces from the source image')
                num_iterations = num_target_faces

            for i in range(num_iterations):
                temp_frame = swap_face(
                    face_swapper,
                    source_faces,
                    target_faces,
                    0 if num_source_faces == 1 else i,
                    i,
                    temp_frame
                )
        else:
            logging.debug('Replacing specific face(s) in the target image with specific face(s) from the source image')

            if source_indexes == "-1":
                source_indexes = ','.join(map(str, range(num_source_faces)))

            # target_indexes cannot be "-1" in this branch.
            source_index_list = source_indexes.split(',')
            target_index_list = target_indexes.split(',')

            if len(source_index_list) > num_source_faces:
                raise Exception('Number of source indexes is greater than the number of faces in the source image')

            if len(target_index_list) > num_target_faces:
                raise Exception('Number of target indexes is greater than the number of faces in the target image')

            # Indexes are paired one-to-one; a count mismatch is reported
            # instead of returning the target image untouched.
            if len(source_index_list) != len(target_index_list):
                logging.error('Unsupported face configuration')
                raise Exception('Unsupported face configuration')

            for raw_source_index, raw_target_index in zip(source_index_list, target_index_list):
                source_index = int(raw_source_index)
                target_index = int(raw_target_index)

                if source_index > num_source_faces-1:
                    raise ValueError(f'Source index {source_index} is higher than the number of faces in the source image')

                if target_index > num_target_faces-1:
                    raise ValueError(f'Target index {target_index} is higher than the number of faces in the target image')

                temp_frame = swap_face(
                    face_swapper,
                    source_faces,
                    target_faces,
                    source_index,
                    target_index,
                    temp_frame
                )
    else:
        logging.error('Unsupported face configuration')
        raise Exception('Unsupported face configuration')

    return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
|
|
|
|
def _open_image(image_path):
    """Read the image at ``image_path`` fully into memory and close the file."""
    # Image.open() is lazy and keeps the file handle open; copying inside the
    # context manager loads the pixels and lets the handle be released.
    with Image.open(image_path) as image_file:
        return image_file.copy()


def face_swap(src_img_path,
              target_img_path,
              source_indexes,
              target_indexes,
              background_enhance,
              face_restore,
              face_upsample,
              upscale,
              codeformer_fidelity,
              output_format):
    """Swap faces between image files and return the result as base64 text.

    Args:
        src_img_path: One source image path, or several separated by ';'.
        target_img_path: Path of the image whose faces are replaced.
        source_indexes: Passed to ``process``.
        target_indexes: Passed to ``process``.
        background_enhance: Passed to ``face_restoration``.
        face_restore: When truthy, the swapped image is post-processed with
            CodeFormer.
        face_upsample: Passed to ``face_restoration``.
        upscale: Passed to ``face_restoration``.
        codeformer_fidelity: Passed to ``face_restoration``.
        output_format: Format name given to ``Image.save``.

    Returns:
        The encoded output image as a base64 string.

    Raises:
        Any exception raised by ``process`` or ``face_restoration`` propagates
        to the caller unchanged.
    """
    source_img = [_open_image(img_path) for img_path in src_img_path.split(';')]
    target_img = _open_image(target_img_path)

    # Location of the face-swap model, relative to this script.
    model = os.path.join(script_dir, 'checkpoints/inswapper_128.onnx')
    logging.debug(f'Face swap model: {model}')

    logging.debug('Performing face swap')
    result_image = process(
        source_img,
        target_img,
        source_indexes,
        target_indexes,
        model
    )
    logging.debug('Face swap complete')

    # Defined outside this file (presumably by the star import from
    # `restoration`); kept unconditional as in the original flow.
    check_ckpts()

    if face_restore:
        logging.debug('Setting upsampler to RealESRGAN_x2plus')
        upsampler = set_realesrgan()

        torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.debug(f'Torch device: {torch_device.upper()}')
        device = torch.device(torch_device)

        codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
            dim_embd=512,
            codebook_size=1024,
            n_head=8,
            n_layers=9,
            connect_list=['32', '64', '128', '256'],
        ).to(device)

        ckpt_path = os.path.join(script_dir, 'CodeFormer/CodeFormer/weights/CodeFormer/codeformer.pth')
        logging.debug(f'Loading CodeFormer model: {ckpt_path}')

        # map_location lets a checkpoint saved from a CUDA device load on a
        # CPU-only host.
        checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
        codeformer_net.load_state_dict(checkpoint)
        codeformer_net.eval()

        result_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
        logging.debug('Performing face restoration using CodeFormer')

        result_image = face_restoration(
            result_image,
            background_enhance,
            face_upsample,
            upscale,
            codeformer_fidelity,
            upsampler,
            codeformer_net,
            device
        )

        logging.debug('CodeFormer face restoration completed successfully')
        result_image = Image.fromarray(result_image)

    # Encode the final image in memory and return it as base64 text.
    output_buffer = io.BytesIO()
    result_image.save(output_buffer, format=output_format)
    image_data = output_buffer.getvalue()

    return base64.b64encode(image_data).decode('utf-8')
|
|
|
|
# Flask application object on which this module's error handlers and routes
# are registered.
app = Flask(__name__)
|
|
|
|
@app.errorhandler(400)
def bad_request(error):
    """Render HTTP 400 errors as a JSON error body.

    Named ``bad_request`` so that it no longer shares the name ``not_found``
    with the 404 handler. Flask registers handlers by status code, so the
    registration itself is unaffected.
    """
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': 'Bad Request',
            'detail': str(error)
        }
    ), 400)
|
|
|
|
@app.errorhandler(404)
def not_found(error):
    """Render HTTP 404 errors as a JSON error body naming the requested URL."""
    body = {
        'status': 'error',
        'msg': f'{request.url} not found',
        'detail': str(error),
    }
    return make_response(jsonify(body), 404)
|
|
|
|
@app.errorhandler(500)
def internal_server_error(error):
    """Render HTTP 500 errors as a JSON error body."""
    body = {
        'status': 'error',
        'msg': 'Internal Server Error',
        'detail': str(error),
    }
    return make_response(jsonify(body), 500)
|
|
|
|
@app.route('/', methods=['GET'])
def ping():
    """Health check: answer GET / with ``{"status": "ok"}`` and HTTP 200."""
    body = {'status': 'ok'}
    return make_response(jsonify(body), 200)
|
|
|
|
def _faceswap_error(msg, detail, status_code):
    """Build a JSON error response for the /faceswap endpoint."""
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': msg,
            'detail': detail
        }
    ), status_code)


def _remove_temp_file(path):
    """Delete ``path``, tolerating a file that was never created."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        logging.warning(f'Could not remove temporary file {path}: {e}')


@app.route('/faceswap', methods=['POST'])
def face_swap_api():
    """Handle POST /faceswap.

    The JSON body must contain base64 ``source_image`` and ``target_image``.
    Optional keys and their defaults: ``source_indexes`` ('-1'),
    ``target_indexes`` ('-1'), ``background_enhance`` (True), ``face_restore``
    (True), ``face_upsample`` (True), ``upscale`` (1), ``codeformer_fidelity``
    (0.5) and ``output_format`` ('JPEG').

    Returns:
        200 with ``{'status': 'ok', 'image': <base64>}`` on success,
        400 when the body is not a JSON object, an image field is missing or
        an image is not valid base64, and 500 with
        ``{'status': 'error', 'msg': 'Face swap failed', ...}`` when the swap
        itself fails. Temporary image files are removed in every case.
    """
    total_timer = Timer()
    logging.debug('Received face swap API request')
    payload = request.get_json()

    if not isinstance(payload, dict):
        return _faceswap_error('Bad Request', 'Request body must be a JSON object', 400)

    missing = [key for key in ('source_image', 'target_image') if key not in payload]
    if missing:
        return _faceswap_error('Bad Request', f'Missing required field(s): {", ".join(missing)}', 400)

    source_image_data = payload['source_image']
    target_image_data = payload['target_image']

    # Decode both images up front so malformed input is rejected before
    # anything is written to disk. binascii.Error is a ValueError subclass.
    try:
        source_image = base64.b64decode(source_image_data)
        target_image = base64.b64decode(target_image_data)
    except (ValueError, TypeError) as e:
        return _faceswap_error('Bad Request', f'Image data is not valid base64: {e}', 400)

    # Optional settings, read without mutating the request payload.
    options = {
        'source_indexes': payload.get('source_indexes', '-1'),
        'target_indexes': payload.get('target_indexes', '-1'),
        'background_enhance': payload.get('background_enhance', True),
        'face_restore': payload.get('face_restore', True),
        'face_upsample': payload.get('face_upsample', True),
        'upscale': payload.get('upscale', 1),
        'codeformer_fidelity': payload.get('codeformer_fidelity', 0.5),
        'output_format': payload.get('output_format', 'JPEG'),
    }

    unique_id = uuid.uuid4()
    source_file_extension = determine_file_extension(source_image_data)
    target_file_extension = determine_file_extension(target_image_data)
    source_image_path = f'{TMP_PATH}/source_{unique_id}{source_file_extension}'
    target_image_path = f'{TMP_PATH}/target_{unique_id}{target_file_extension}'

    try:
        if not os.path.exists(TMP_PATH):
            logging.debug(f'Creating temporary directory: {TMP_PATH}')

        # exist_ok avoids a failure when a concurrent request creates the
        # directory between the check above and this call.
        os.makedirs(TMP_PATH, exist_ok=True)

        # Save the source image to disk.
        with open(source_image_path, 'wb') as source_file:
            source_file.write(source_image)

        # Save the target image to disk.
        with open(target_image_path, 'wb') as target_file:
            target_file.write(target_image)

        logging.debug(f'Source indexes: {options["source_indexes"]}')
        logging.debug(f'Target indexes: {options["target_indexes"]}')
        logging.debug(f'Background enhance: {options["background_enhance"]}')
        logging.debug(f'Face Restoration: {options["face_restore"]}')
        logging.debug(f'Face Upsampling: {options["face_upsample"]}')
        logging.debug(f'Upscale: {options["upscale"]}')
        logging.debug(f'Codeformer Fidelity: {options["codeformer_fidelity"]}')
        logging.debug(f'Output Format: {options["output_format"]}')

        result_image = face_swap(
            source_image_path,
            target_image_path,
            options['source_indexes'],
            options['target_indexes'],
            options['background_enhance'],
            options['face_restore'],
            options['face_upsample'],
            options['upscale'],
            options['codeformer_fidelity'],
            options['output_format']
        )

        status_code = 200
        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        # Top-level boundary: log with traceback and report the failure.
        logging.exception(e)

        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }
        status_code = 500
    finally:
        # Clean up the temporary images whether or not the swap succeeded.
        _remove_temp_file(source_image_path)
        _remove_temp_file(target_image_path)

    total_time = total_timer.get_elapsed_time()
    logging.info(f'Total time taken for face swap API call {total_time} seconds')

    return make_response(jsonify(response), status_code)
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: read the CLI options and serve the Flask app
    # through waitress on the requested host and port.
    args = get_args()
    serve(app, host=args.host, port=args.port)
|
|