| | |
| | import json |
| | import sys |
| | import os |
| | import io |
| | import argparse |
| | import uuid |
| | import base64 |
| | import logging |
| | import time |
| | import copy |
| | import cv2 |
| | import insightface |
| | import numpy as np |
| | from typing import List, Union |
| | from PIL import Image |
| | from restoration import * |
| | from flask import Flask, request, jsonify, make_response |
| | from waitress import serve |
| |
|
LOG_LEVEL = logging.DEBUG
TMP_PATH = '/tmp/inswapper'
script_dir = os.path.dirname(os.path.abspath(__file__))
log_path = ''

# On Linux the log file lives in /var/log/; elsewhere it is created in the
# current working directory.
if sys.platform == 'linux':
    log_path = '/var/log/'

_log_file = f'{log_path}inswapper.log'
_log_file_error = None

try:
    logging.basicConfig(
        filename=_log_file,
        format='%(asctime)s : %(levelname)s : %(message)s',
        level=LOG_LEVEL
    )
except OSError as err:
    # /var/log/ is normally writable only by root. Instead of crashing at
    # import time, keep the configured level and rely on the stdout handler
    # added below.
    logging.getLogger().setLevel(LOG_LEVEL)
    _log_file_error = err

logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))

if _log_file_error is not None:
    logging.warning(
        f'Could not open log file {_log_file} ({_log_file_error}); logging to stdout only'
    )
| |
|
| |
|
def process_request(request_obj):
    """Run a face swap for a request mapping and wrap the outcome in a response dict.

    ``request_obj`` must provide ``source_image`` and ``target_image``. Both are
    handed to ``face_swap``, which opens them with PIL, so they are file paths;
    ``source_image`` may join several paths with ';'. The optional keys
    ``source_indexes``, ``target_indexes``, ``background_enhance``,
    ``face_restore``, ``face_upsample``, ``upscale``, ``codeformer_fidelity``
    and ``output_format`` default to the same values the ``/faceswap``
    endpoint uses.

    Returns:
        dict: ``{'status': 'ok', 'image': <base64 str>}`` on success, or
        ``{'status': 'error', 'msg': 'Face swap failed', 'detail': <str>}``
        if anything raises.
    """
    # face_swap requires every one of these arguments, so they must be
    # supplied even when the caller does not set them.
    defaults = {
        'source_indexes': '-1',
        'target_indexes': '-1',
        'background_enhance': True,
        'face_restore': True,
        'face_upsample': True,
        'upscale': 1,
        'codeformer_fidelity': 0.5,
        'output_format': 'JPEG',
    }

    try:
        logging.debug('Swapping face')
        options = {key: request_obj.get(key, default) for key, default in defaults.items()}
        face_swap_timer = Timer()
        result_image = face_swap(
            request_obj['source_image'],
            request_obj['target_image'],
            options['source_indexes'],
            options['target_indexes'],
            options['background_enhance'],
            options['face_restore'],
            options['face_upsample'],
            options['upscale'],
            options['codeformer_fidelity'],
            options['output_format']
        )
        face_swap_time = face_swap_timer.get_elapsed_time()
        logging.info(f'Time taken to swap face: {face_swap_time} seconds')

        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        # logging.exception records the traceback, not just the message.
        logging.exception(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }

    return response
| |
|
| |
|
class Timer:
    """Minimal wall-clock stopwatch based on ``time.time()``."""

    def __init__(self):
        self.restart()

    def restart(self):
        """Reset the reference point (``self.start``) to now."""
        self.start = time.time()

    def get_elapsed_time(self):
        """Return seconds elapsed since the last (re)start, rounded to 0.1 s."""
        elapsed = time.time() - self.start
        return round(elapsed, 1)
| |
|
| |
|
def get_args():
    """Parse the server's command-line options.

    Returns:
        argparse.Namespace: ``port`` (int, default 80) and ``host``
        (str, default '0.0.0.0').
    """
    cli = argparse.ArgumentParser(description='Inswapper REST API')
    cli.add_argument('-p', '--port', type=int, default=80, help='Port to listen on')
    cli.add_argument('-H', '--host', default='0.0.0.0', help='Host to bind to')
    return cli.parse_args()
| |
|
| |
|
def determine_file_extension(image_data):
    """Guess a file extension from the first characters of base64 image data.

    '/9j/' is how a JPEG header starts once base64-encoded, and 'iVBORw0Kg'
    is the encoded PNG signature. Anything else, including values that do not
    support ``startswith`` with a str prefix, falls back to '.png'.
    """
    try:
        if image_data.startswith('/9j/'):
            return '.jpg'
        if image_data.startswith('iVBORw0Kg'):
            return '.png'
    except Exception:
        # Deliberate best effort: unreadable input just gets the default.
        pass
    return '.png'
| |
|
| |
|
def write_base64_to_disk(file_b64: str, file_path: str):
    """Decode ``file_b64`` and write the raw bytes to ``file_path``.

    The file is opened (and truncated) before decoding, so it exists even if
    decoding then fails.
    """
    with open(file_path, 'wb') as out_file:
        decoded = base64.b64decode(file_b64)
        out_file.write(decoded)
| |
|
| |
|
def get_face_swap_model(model_path: str):
    """Load the face-swap model stored at ``model_path`` via insightface's model zoo."""
    return insightface.model_zoo.get_model(model_path)
| |
|
| |
|
def get_face_analyser(model_path: str,
                      det_size=(320, 320)):
    """Create and prepare an insightface ``FaceAnalysis`` app using ``buffalo_l``.

    ``model_path`` is accepted but not used. Models are resolved with
    ``root="./checkpoints"``, a path relative to the current working directory.
    ``det_size`` is forwarded to ``prepare``; ``ctx_id=0`` is passed as-is
    (presumably selecting the first device — confirm against insightface docs).
    """
    analyser = insightface.app.FaceAnalysis(name="buffalo_l", root="./checkpoints")
    analyser.prepare(ctx_id=0, det_size=det_size)
    return analyser
| |
|
| |
|
def get_one_face(face_analyser,
                 frame: np.ndarray):
    """Return the left-most face detected in ``frame`` (smallest ``bbox[0]``).

    Returns None when the analyser finds no faces: ``min`` raises ValueError
    on an empty sequence.
    """
    candidates = face_analyser.get(frame)
    try:
        leftmost = min(candidates, key=lambda face: face.bbox[0])
    except ValueError:
        leftmost = None
    return leftmost
| |
|
| |
|
def get_many_faces(face_analyser,
                   frame: np.ndarray):
    """Return every face detected in ``frame`` ordered left to right.

    Faces are ordered by the left edge of their bounding box (``bbox[0]``).
    An empty list is returned when nothing is detected. None is returned only
    if detection or reading a bbox raises IndexError.
    """
    try:
        detections = face_analyser.get(frame)
        ordered = sorted(detections, key=lambda det: det.bbox[0])
    except IndexError:
        return None
    return ordered
| |
|
| |
|
def swap_face(face_swapper,
              source_faces,
              target_faces,
              source_index,
              target_index,
              temp_frame):
    """Paste ``source_faces[source_index]`` over ``target_faces[target_index]``.

    Delegates to ``face_swapper.get(temp_frame, target, source, paste_back=True)``
    and returns its result. The caller in this file uses that result as the
    updated frame.
    """
    picked_source = source_faces[source_index]
    picked_target = target_faces[target_index]
    return face_swapper.get(
        temp_frame,
        picked_target,
        picked_source,
        paste_back=True,
    )
| |
|
| |
|
def _detect_source_faces(face_analyser, image):
    """Detect faces in an RGB PIL ``image``, left to right; raise if there are none."""
    faces = get_many_faces(face_analyser, cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
    # get_many_faces returns an empty list (not None) when nothing is found.
    if not faces:
        raise Exception('No source faces found!')
    return faces


def _parse_face_indexes(indexes: str, num_faces: int, label: str) -> List[int]:
    """Parse a comma-separated index string and validate it against ``num_faces``.

    ``label`` is 'Source' or 'Target' and is only used in error messages.
    """
    parts = indexes.split(',')
    if len(parts) > num_faces:
        raise Exception(
            f'Number of {label.lower()} indexes is greater than the number of faces in the {label.lower()} image'
        )

    parsed = []
    for part in parts:
        index = int(part)
        # A negative value would silently index from the end of the face list.
        if index < 0:
            raise ValueError(f'{label} index {index} must not be negative')
        if index > num_faces - 1:
            raise ValueError(
                f'{label} index {index} is higher than the number of faces in the {label.lower()} image'
            )
        parsed.append(index)
    return parsed


def process(source_img: Union[Image.Image, List],
            target_img: Image.Image,
            source_indexes: str,
            target_indexes: str,
            model: str):
    """Swap face(s) from the source image(s) into ``target_img``.

    Faces in each image are numbered left to right (see ``get_many_faces``).

    Args:
        source_img: One PIL image, or a list of them. With several images,
            their count must equal the number of target faces; image i then
            supplies its left-most face for target face i.
        target_img: PIL image whose faces are replaced. It is converted with
            ``cv2.COLOR_RGB2BGR``, so it is expected to be RGB.
        source_indexes: Comma-separated source face indexes, or '-1'. Only
            used with a single source image and explicit ``target_indexes``.
        target_indexes: Comma-separated target face indexes, or '-1' for all.
        model: Path of the swap model, joined onto this file's directory.

    Returns:
        PIL.Image.Image: The target image with the selected faces swapped.

    Raises:
        Exception / ValueError: No faces found, mismatched image/face counts,
            or an out-of-range or malformed index.
    """
    # Accept a single image for convenience; the code below works on a list.
    if not isinstance(source_img, list):
        source_img = [source_img]
    # JSON callers may send plain integers (e.g. -1 or 0) instead of strings.
    source_indexes = str(source_indexes)
    target_indexes = str(target_indexes)

    face_analyser = get_face_analyser(model)

    model_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), model)
    face_swapper = get_face_swap_model(model_path)

    # insightface/OpenCV work on BGR frames.
    target_frame = cv2.cvtColor(np.array(target_img), cv2.COLOR_RGB2BGR)
    target_faces = get_many_faces(face_analyser, target_frame)

    if not target_faces:
        logging.error('No target faces found')
        raise Exception('No target faces found!')

    num_target_faces = len(target_faces)
    num_source_images = len(source_img)
    temp_frame = target_frame.copy()

    if num_source_images > 1:
        if num_source_images != num_target_faces:
            logging.error('Unsupported face configuration')
            raise Exception(
                f'Got {num_source_images} source images but found {num_target_faces} faces in the '
                'target image; multiple source images require one image per target face'
            )
        logging.debug('Replacing the faces in the target image from left to right by order')
        for target_index, image in enumerate(source_img):
            source_faces = _detect_source_faces(face_analyser, image)
            # Each source image contributes its left-most face (index 0).
            temp_frame = swap_face(face_swapper, source_faces, target_faces, 0, target_index, temp_frame)

    elif num_source_images == 1:
        source_faces = _detect_source_faces(face_analyser, source_img[0])
        num_source_faces = len(source_faces)
        logging.debug(f'Source faces: {num_source_faces}')
        logging.debug(f'Target faces: {num_target_faces}')

        if target_indexes == '-1':
            if num_source_faces == 1:
                logging.debug('Replacing all faces in target image with the same face from the source image')
                num_iterations = num_target_faces
            elif num_source_faces < num_target_faces:
                logging.debug('There are less faces in the source image than the target image, replacing as many as we can')
                num_iterations = num_source_faces
            elif num_target_faces < num_source_faces:
                logging.debug('There are less faces in the target image than the source image, replacing as many as we can')
                num_iterations = num_target_faces
            else:
                logging.debug('Replacing all faces in the target image with the faces from the source image')
                num_iterations = num_target_faces

            for i in range(num_iterations):
                source_index = 0 if num_source_faces == 1 else i
                temp_frame = swap_face(face_swapper, source_faces, target_faces, source_index, i, temp_frame)

        elif source_indexes == '-1':
            # Explicit target indexes without a source selection: use the
            # left-most source face for every requested target face.
            logging.debug('Replacing specific face(s) in the target image with the face from the source image')
            for target_index in _parse_face_indexes(target_indexes, num_target_faces, 'Target'):
                temp_frame = swap_face(face_swapper, source_faces, target_faces, 0, target_index, temp_frame)

        else:
            logging.debug('Replacing specific face(s) in the target image with specific face(s) from the source image')
            source_list = _parse_face_indexes(source_indexes, num_source_faces, 'Source')
            target_list = _parse_face_indexes(target_indexes, num_target_faces, 'Target')

            if len(source_list) != len(target_list):
                logging.error('Unsupported face configuration')
                raise Exception('Unsupported face configuration')

            for source_index, target_index in zip(source_list, target_list):
                temp_frame = swap_face(face_swapper, source_faces, target_faces, source_index, target_index, temp_frame)

    else:
        raise Exception('No source images provided')

    return Image.fromarray(cv2.cvtColor(temp_frame, cv2.COLOR_BGR2RGB))
| |
|
| |
|
def _load_rgb_image(image_path):
    """Open ``image_path``, read it fully into memory as RGB and close the file.

    ``convert('RGB')`` returns an independent in-memory copy, so the file
    handle can be released immediately. It also normalises greyscale,
    palette, CMYK and RGBA inputs to the 3-channel layout that the
    ``cv2.COLOR_RGB2BGR`` conversions downstream expect.
    """
    with Image.open(image_path) as img:
        return img.convert('RGB')


def face_swap(src_img_path,
              target_img_path,
              source_indexes,
              target_indexes,
              background_enhance,
              face_restore,
              face_upsample,
              upscale,
              codeformer_fidelity,
              output_format):
    """Swap faces between image files and return the result as base64 text.

    Args:
        src_img_path: Path of the source image, or several paths joined by ';'.
        target_img_path: Path of the target image.
        source_indexes, target_indexes: Face selections forwarded to ``process``.
        background_enhance, face_upsample, upscale, codeformer_fidelity:
            Forwarded unchanged to ``face_restoration``; only used when
            ``face_restore`` is truthy.
        face_restore: When truthy, run CodeFormer restoration on the result.
        output_format: Pillow format name for ``Image.save`` (e.g. 'JPEG').

    Returns:
        str: The encoded output image as a UTF-8 base64 string.

    Raises:
        Any exception from image loading, ``process`` or restoration propagates.
    """
    source_img = [_load_rgb_image(img_path) for img_path in src_img_path.split(';')]
    target_img = _load_rgb_image(target_img_path)

    model = os.path.join(script_dir, 'checkpoints/inswapper_128.onnx')
    logging.debug(f'Face swap model: {model}')

    logging.debug('Performing face swap')
    result_image = process(source_img, target_img, source_indexes, target_indexes, model)
    logging.debug('Face swap complete')

    # Helper from `restoration`; presumably ensures the restoration
    # checkpoints are present (it runs regardless of face_restore).
    check_ckpts()

    if face_restore:
        logging.debug('Setting upsampler to RealESRGAN_x2plus')
        upsampler = set_realesrgan()

        torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logging.debug(f'Torch device: {torch_device.upper()}')
        device = torch.device(torch_device)

        codeformer_net = ARCH_REGISTRY.get('CodeFormer')(
            dim_embd=512,
            codebook_size=1024,
            n_head=8,
            n_layers=9,
            connect_list=['32', '64', '128', '256'],
        ).to(device)

        ckpt_path = os.path.join(script_dir, 'CodeFormer/CodeFormer/weights/CodeFormer/codeformer.pth')
        logging.debug(f'Loading CodeFormer model: {ckpt_path}')
        # map_location lets a checkpoint saved from a GPU load on a CPU-only host.
        checkpoint = torch.load(ckpt_path, map_location=device)['params_ema']
        codeformer_net.load_state_dict(checkpoint)
        codeformer_net.eval()

        bgr_image = cv2.cvtColor(np.array(result_image), cv2.COLOR_RGB2BGR)
        logging.debug('Performing face restoration using CodeFormer')
        restored = face_restoration(
            bgr_image,
            background_enhance,
            face_upsample,
            upscale,
            codeformer_fidelity,
            upsampler,
            codeformer_net,
            device
        )
        logging.debug('CodeFormer face restoration completed successfully')
        # Assumes face_restoration returns an RGB array — confirm in restoration.py.
        result_image = Image.fromarray(restored)

    output_buffer = io.BytesIO()
    result_image.save(output_buffer, format=output_format)
    return base64.b64encode(output_buffer.getvalue()).decode('utf-8')
| |
|
| |
|
# The Flask application object; routes and error handlers below register on it.
app = Flask(import_name=__name__)
| |
|
| |
|
@app.errorhandler(400)
def bad_request(error):
    """Render 400 errors as JSON instead of Flask's default HTML page.

    Named distinctly from the 404 handler so neither shadows the other at
    module level. Registration happens through the decorator either way.
    """
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': 'Bad Request',
            'detail': str(error)
        }
    ), 400)
| |
|
| |
|
@app.errorhandler(404)
def not_found(error):
    """Render 404 errors as JSON, echoing the URL that was requested."""
    body = {
        'status': 'error',
        'msg': f'{request.url} not found',
        'detail': str(error),
    }
    return make_response(jsonify(body), 404)
| |
|
| |
|
@app.errorhandler(500)
def internal_server_error(error):
    """Render 500 errors as JSON with the error text in ``detail``."""
    body = {
        'status': 'error',
        'msg': 'Internal Server Error',
        'detail': str(error),
    }
    return make_response(jsonify(body), 500)
| |
|
| |
|
@app.route('/', methods=['GET'])
def ping():
    """Health check: always answers 200 with ``{'status': 'ok'}``."""
    body = jsonify({'status': 'ok'})
    return make_response(body, 200)
| |
|
| |
|
def _faceswap_bad_request(detail):
    """Build a 400 JSON response shaped like the app's error handlers."""
    return make_response(jsonify(
        {
            'status': 'error',
            'msg': 'Bad Request',
            'detail': detail
        }
    ), 400)


@app.route('/faceswap', methods=['POST'])
def face_swap_api():
    """Handle ``POST /faceswap``.

    The JSON body must be an object with base64-encoded ``source_image`` and
    ``target_image``. Optional fields (defaults): ``source_indexes`` ('-1'),
    ``target_indexes`` ('-1'), ``background_enhance`` (True),
    ``face_restore`` (True), ``face_upsample`` (True), ``upscale`` (1),
    ``codeformer_fidelity`` (0.5), ``output_format`` ('JPEG').

    Responses:
        200 ``{'status': 'ok', 'image': <base64>}`` on success.
        400 when the body is not a JSON object, a required field is missing,
        or an image is not decodable base64.
        500 ``{'status': 'error', 'msg': 'Face swap failed', ...}`` when the
        swap fails.
    """
    total_timer = Timer()
    logging.debug('Received face swap API request')
    # silent=True yields None for a missing/invalid JSON body, so the client
    # gets a JSON 400 instead of an unrelated crash further down.
    payload = request.get_json(silent=True)

    if not isinstance(payload, dict):
        return _faceswap_bad_request('Request body must be a JSON object')

    missing = [field for field in ('source_image', 'target_image') if field not in payload]
    if missing:
        return _faceswap_bad_request(f'Missing required field(s): {", ".join(missing)}')

    if not os.path.exists(TMP_PATH):
        logging.debug(f'Creating temporary directory: {TMP_PATH}')
        # exist_ok: concurrent requests may race to create the directory.
        os.makedirs(TMP_PATH, exist_ok=True)

    defaults = {
        'source_indexes': '-1',
        'target_indexes': '-1',
        'background_enhance': True,
        'face_restore': True,
        'face_upsample': True,
        'upscale': 1,
        'codeformer_fidelity': 0.5,
        'output_format': 'JPEG',
    }
    for key, value in defaults.items():
        payload.setdefault(key, value)

    source_image_data = payload['source_image']
    target_image_data = payload['target_image']

    try:
        source_image = base64.b64decode(source_image_data)
        target_image = base64.b64decode(target_image_data)
    except (ValueError, TypeError) as e:
        # binascii.Error (bad padding) subclasses ValueError; TypeError covers
        # non-string values such as numbers or objects.
        return _faceswap_bad_request(f'Invalid base64 image data: {e}')

    unique_id = uuid.uuid4()
    source_image_path = f'{TMP_PATH}/source_{unique_id}{determine_file_extension(source_image_data)}'
    target_image_path = f'{TMP_PATH}/target_{unique_id}{determine_file_extension(target_image_data)}'

    try:
        with open(source_image_path, 'wb') as source_file:
            source_file.write(source_image)

        with open(target_image_path, 'wb') as target_file:
            target_file.write(target_image)

        logging.debug(f'Source indexes: {payload["source_indexes"]}')
        logging.debug(f'Target indexes: {payload["target_indexes"]}')
        logging.debug(f'Background enhance: {payload["background_enhance"]}')
        logging.debug(f'Face Restoration: {payload["face_restore"]}')
        logging.debug(f'Face Upsampling: {payload["face_upsample"]}')
        logging.debug(f'Upscale: {payload["upscale"]}')
        logging.debug(f'Codeformer Fidelity: {payload["codeformer_fidelity"]}')
        logging.debug(f'Output Format: {payload["output_format"]}')

        result_image = face_swap(
            source_image_path,
            target_image_path,
            payload['source_indexes'],
            payload['target_indexes'],
            payload['background_enhance'],
            payload['face_restore'],
            payload['face_upsample'],
            payload['upscale'],
            payload['codeformer_fidelity'],
            payload['output_format']
        )

        status_code = 200
        response = {
            'status': 'ok',
            'image': result_image
        }
    except Exception as e:
        logging.exception(e)
        response = {
            'status': 'error',
            'msg': 'Face swap failed',
            'detail': str(e)
        }
        status_code = 500
    finally:
        # Always clean up, even if writing a temp file or the swap failed.
        for tmp_file in (source_image_path, target_image_path):
            try:
                os.remove(tmp_file)
            except FileNotFoundError:
                pass

    total_time = total_timer.get_elapsed_time()
    logging.info(f'Total time taken for face swap API call {total_time} seconds')

    return make_response(jsonify(response), status_code)
| |
|
| |
|
if __name__ == '__main__':
    # Serve the Flask app with waitress on the host/port given on the CLI.
    cli_args = get_args()
    serve(app, host=cli_args.host, port=cli_args.port)