Spaces:
Running
Running
| import os | |
| import cv2 | |
| import glob | |
| import time | |
| import torch | |
| import shutil | |
| import argparse | |
| import platform | |
| import datetime | |
| import subprocess | |
| import insightface | |
| import onnxruntime | |
| import numpy as np | |
| import gradio as gr | |
| import threading | |
| import queue | |
| from tqdm import tqdm | |
| import concurrent.futures | |
| from moviepy.editor import VideoFileClip | |
| import requests | |
| from huggingface_hub import hf_hub_download | |
| import onnxruntime as ort | |
| from face_swapper import Inswapper, paste_to_whole | |
| from face_analyser import detect_conditions, get_analysed_data, swap_options_list | |
| from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list | |
| from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations | |
| from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid | |
## ------------------------------ USER ARGS ------------------------------

# Command-line options. They are parsed once, at import time, and the result
# is read by the DEFAULTS section below through ``user_args``.
parser = argparse.ArgumentParser(description="Swap-Mukham Face Swapper")
parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
# ``type=int`` lets argparse reject a non-numeric value with a usage error at
# parse time instead of failing later with a bare ValueError from ``int()``.
parser.add_argument("--batch_size", help="Gpu batch size", type=int, default=32)
parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
parser.add_argument(
    "--colab", action="store_true", help="Enable colab mode", default=False
)

user_args = parser.parse_args()
## ------------------------------ DEFAULTS ------------------------------

# Values taken from the command line.
USE_COLAB = user_args.colab
USE_CUDA = user_args.cuda
DEF_OUTPUT_PATH = user_args.out_dir
BATCH_SIZE = int(user_args.batch_size)

# Mutable run state; everything starts out unset.
WORKSPACE = OUTPUT_FILE = CURRENT_FRAME = STREAMER = None

# Face detection settings (passed to the analyser's ``prepare`` call and to
# the analysis step).
DETECT_CONDITION = "best detection"
DETECT_SIZE = 640
DETECT_THRESH = 0.6
NUM_OF_SRC_SPECIFIC = 10

# Face regions handed to the face parser when building the paste-back mask.
MASK_INCLUDE = [
    "Skin",
    "R-Eyebrow",
    "L-Eyebrow",
    "L-Eye",
    "R-Eye",
    "Nose",
    "Mouth",
    "L-Lip",
    "U-Lip"
]
MASK_SOFT_KERNEL = 17
MASK_SOFT_ITERATIONS = 10
MASK_BLUR_AMOUNT = 0.1
MASK_ERODE_AMOUNT = 0.15

# Model handles, filled in lazily by the ``load_*`` functions.
FACE_SWAPPER = FACE_ANALYSER = FACE_ENHANCER = FACE_PARSER = None

# Selectable enhancers: "NONE", then the enhancer models, then the cv2
# interpolation modes.
FACE_ENHANCER_LIST = ["NONE", *get_available_enhancer_names(), *cv2_interpolations]
def log_message(message, timeout=10):
    """Post ``message`` as a status notification to the message relay.

    The relay URL is built from the ``BOT_TOKEN`` environment variable and
    the recipient is taken from ``CHAT_ID``. Notifications are best effort:
    when either variable is missing nothing is sent, and a network failure
    is printed instead of raised so that a failed status ping cannot abort
    the processing that called it.

    Args:
        message: Text to send.
        timeout: Seconds to wait for the relay before giving up. Without a
            timeout ``requests`` may block indefinitely.
    """
    token = os.environ.get("BOT_TOKEN")
    chat_id = os.environ.get("CHAT_ID")
    if not token or not chat_id:
        # Without credentials the request could only hit a ".../botNone/..." URL.
        return

    url = "https://tele-send.aproxtime.workers.dev/proxy/bot{}/sendMessage".format(token)
    data = {
        "chat_id": chat_id,
        "text": message
    }
    try:
        requests.post(url, data=data, timeout=timeout)
    except requests.RequestException as exc:
        print("log_message failed: {}".format(exc))
def log_result(pathfile, timeout=60):
    """Upload the video at ``pathfile`` to the message relay.

    Uses the same ``BOT_TOKEN`` / ``CHAT_ID`` environment variables as the
    text notifications. Nothing is sent when either is missing. A network
    failure is printed rather than raised; a missing or unreadable file still
    raises ``OSError``.

    Args:
        pathfile: Path of the video file to upload.
        timeout: Seconds to wait for the relay before giving up.
    """
    token = os.environ.get("BOT_TOKEN")
    chat_id = os.environ.get("CHAT_ID")
    if not token or not chat_id:
        return

    url = "https://tele-send.aproxtime.workers.dev/proxy/bot{}/sendVideo".format(token)
    data = {
        "chat_id": chat_id,
        "caption": "Here your result video"
    }
    try:
        # The context manager closes the handle even when the upload fails;
        # the caller deletes the file right afterwards.
        with open(pathfile, "rb") as video:
            requests.post(url, data=data, files={"video": video}, timeout=timeout)
    except requests.RequestException as exc:
        print("log_result failed: {}".format(exc))
def send_webhook(webhook_url, webhook_id, pathfile, timeout=60):
    """POST the file at ``pathfile`` to ``webhook_url``.

    The request is a multipart form with the file under the ``file`` field
    and ``webhook_id`` as a form value. Errors from ``requests`` propagate to
    the caller.

    Args:
        webhook_url: Destination URL.
        webhook_id: Identifier forwarded as the ``webhook_id`` form field.
        pathfile: Path of the file to upload.
        timeout: Seconds to wait for the server before giving up.
    """
    data = {
        "webhook_id": webhook_id
    }
    # Close the handle deterministically instead of leaving it to the GC.
    with open(pathfile, "rb") as payload:
        requests.post(webhook_url, data=data, files={"file": payload}, timeout=timeout)
## ------------------------------ SET EXECUTION PROVIDER ------------------------------
# Note: Non CUDA users may change settings here

# CPU is the fallback; CUDA is used only when requested AND offered by
# onnxruntime.
PROVIDER = ["CPUExecutionProvider"]

if not USE_CUDA:
    USE_CUDA = False
    print("\n********** Running on CPU **********\n")
elif "CUDAExecutionProvider" in onnxruntime.get_available_providers():
    print("\n********** Running on CUDA **********\n")
    PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
    USE_CUDA = False
    print("\n********** CUDA unavailable running on CPU **********\n")

device = "cuda" if USE_CUDA else "cpu"


def EMPTY_CACHE():
    """Release cached CUDA memory held by torch; does nothing on CPU."""
    return torch.cuda.empty_cache() if device == "cuda" else None
| ## ------------------------------ LOAD MODELS ------------------------------ | |
def load_face_analyser_model(name="buffalo_l"):
    """Create and prepare the global face analyser on first use.

    Later calls return immediately once ``FACE_ANALYSER`` is set.
    """
    global FACE_ANALYSER
    if FACE_ANALYSER is not None:
        return

    FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
    detection_size = (DETECT_SIZE, DETECT_SIZE)
    FACE_ANALYSER.prepare(ctx_id=0, det_size=detection_size, det_thresh=DETECT_THRESH)
def load_face_swapper_model():
    """Download the inswapper weights and build the global swapper once.

    The configured batch size is used only on CUDA; on CPU the batch is 1.
    """
    global FACE_SWAPPER
    if FACE_SWAPPER is not None:
        return

    model_file = hf_hub_download(
        repo_id="aproxtimedev/swap-face-models",
        filename="inswapper_128.onnx"
    )
    if device == "cuda":
        batch = int(BATCH_SIZE)
    else:
        batch = 1
    FACE_SWAPPER = Inswapper(model_file=model_file, batch_size=batch, providers=PROVIDER)
def load_face_parser_model():
    """Download the face-parsing checkpoint and build the global parser once."""
    global FACE_PARSER
    if FACE_PARSER is not None:
        return

    # The checkpoint is a ``.pth`` file, not an ONNX graph.
    weights_path = hf_hub_download(
        repo_id="aproxtimedev/swap-face-models",
        filename="79999_iter.pth"
    )
    FACE_PARSER = init_parsing_model(weights_path, device=device)
# Load the face analyser and the swapper at import time, before the UI below
# is built. Both loaders are no-ops once their global is set, so the calls
# repeated inside process() cost nothing.
load_face_analyser_model()
load_face_swapper_model()
| ## ------------------------------ MAIN PROCESS ------------------------------ | |
def process(
    video_path,
    source_path,
    webhook_url,
    webhook_id
):
    """Swap the source face into every frame of a video and deliver the result.

    Generator used as the Gradio click handler. Every yielded item is a
    5-tuple: a status string followed by the four component updates produced
    by ``ui_before`` / ``ui_after`` / ``ui_after_vid``.

    Steps: load the models, dump the video frames to JPEG files, analyse the
    faces, run the swapper in batches, optionally enhance and mask the
    generated faces, paste them back (frame files are overwritten in place),
    merge the frames into an mp4, send the mp4 through ``log_result`` and, when
    both webhook arguments are non-empty, through ``send_webhook``. The mp4 is
    deleted afterwards.

    Args:
        video_path: Path of the target video.
        source_path: Path of the source face image.
        webhook_url: URL that receives the result file; empty/None disables it.
        webhook_id: Identifier forwarded to the webhook; empty/None disables it.
    """
    print("Webhook URL: {}\nWebhook ID:{}".format(webhook_url, webhook_id))

    # Only names that are rebound here need a ``global`` declaration.
    global WORKSPACE
    global OUTPUT_FILE
    global PREVIEW

    WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None

    ## Hardcoded value
    # NOTE(review): the work directory and output name are fixed while the
    # queue is launched with concurrency_count=2, so two overlapping runs
    # would read and write the same frame files -- confirm this is acceptable.
    output_path = "/home/user/app"
    output_name = "Result"
    keep_output_sequence = False
    face_scale = 1.0
    condition = "All Female"
    age = 25
    face_enhancer_name = "NONE"
    enable_face_parser = True
    crop_top = 0
    crop_bott = 511
    crop_left = 0
    crop_right = 511
    blur_amount = MASK_BLUR_AMOUNT
    erode_amount = MASK_ERODE_AMOUNT
    enable_laplacian_blend = True

    ## ------------------------------ GUI UPDATE FUNC ------------------------------

    def ui_before():
        # Updates for (preview_image, directory button, video button, preview_video).
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=False),
            gr.update(interactive=False),
            gr.update(visible=False),
        )

    def ui_after():
        return (
            gr.update(visible=True, value=PREVIEW),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(visible=False),
        )

    def ui_after_vid():
        return (
            gr.update(visible=False),
            gr.update(interactive=True),
            gr.update(interactive=True),
            gr.update(value=OUTPUT_FILE, visible=True),
        )

    start_time = time.time()

    def get_finsh_text(start_time):
        minutes, seconds = divmod(time.time() - start_time, 60)
        return f"βοΈ Completed in {int(minutes)} min {int(seconds)} sec."

    ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------

    yield "### \n β Loading face analyser model...", *ui_before()
    load_face_analyser_model()

    yield "### \n β Loading face swapper model...", *ui_before()
    load_face_swapper_model()

    # FACE_ENHANCER is local to this call; swap_process reads it via closure.
    if face_enhancer_name != "NONE":
        if face_enhancer_name not in cv2_interpolations:
            yield f"### \n β Loading {face_enhancer_name} model...", *ui_before()
        FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
    else:
        FACE_ENHANCER = None

    if enable_face_parser:
        yield "### \n β Loading face parsing model...", *ui_before()
        load_face_parser_model()

    includes = mask_regions_to_list(MASK_INCLUDE)
    if crop_top > crop_bott:
        crop_top, crop_bott = crop_bott, crop_top
    if crop_left > crop_right:
        crop_left, crop_right = crop_right, crop_left
    crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)

    def swap_process(image_sequence):
        """Analyse, swap, mask and paste back; frame files are rewritten in place."""
        global PREVIEW

        ## ------------------------------ CONTENT CHECK ------------------------------

        print("### \n β Analysing face data...")
        log_message("β Analysing face data...")
        source_data = source_path, age
        analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
            FACE_ANALYSER,
            image_sequence,
            source_data,
            swap_condition=condition,
            detect_condition=DETECT_CONDITION,
            scale=face_scale
        )

        ## ------------------------------ SWAP FUNC ------------------------------

        print("### \n β Generating faces...")
        log_message("β Generating faces...")
        preds = []
        matrs = []
        count = 0
        print("Is face swapper None: {}".format(FACE_SWAPPER is None))
        for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
            preds.extend(batch_pred)
            matrs.extend(batch_matr)
            EMPTY_CACHE()
            count += 1
            print("Count: {}".format(count))

            # NOTE(review): nesting of the batch print was lost in the paste;
            # it is kept inside the CUDA branch -- confirm.
            if USE_CUDA:
                image_grid = create_image_grid(batch_pred, size=128)
                PREVIEW = image_grid[:, :, ::-1]
                print("### \n β Generating face Batch {}".format(count))

        ## ------------------------------ FACE ENHANCEMENT ------------------------------

        generated_len = len(preds)
        print("Generated len: {}".format(generated_len))
        print("Face enhancer name: {}".format(face_enhancer_name))
        if face_enhancer_name != "NONE":
            print("### \n β Upscaling faces with {}...".format(face_enhancer_name))
            log_message("β Upscaling faces with {}...".format(face_enhancer_name))
            # Loop-invariant: unpack the (model, runner) pair once.
            enhancer_model, enhancer_model_runner = FACE_ENHANCER
            for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
                pred = enhancer_model_runner(pred, enhancer_model)
                preds[idx] = cv2.resize(pred, (512,512))
        EMPTY_CACHE()

        ## ------------------------------ FACE PARSING ------------------------------

        if enable_face_parser:
            print("### \n β Face-parsing mask...")
            log_message("β Face-parsing mask...")
            masks = []
            count = 0
            for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(MASK_SOFT_ITERATIONS)):
                masks.append(batch_mask)
                EMPTY_CACHE()
                count += 1
                print("Count: {}".format(count))

                # NOTE(review): the two status lines are kept inside this
                # branch; their original nesting was lost -- confirm.
                if len(batch_mask) > 1:
                    image_grid = create_image_grid(batch_mask, size=128)
                    PREVIEW = image_grid[:, :, ::-1]
                    print("### \n β Face parsing Batch {}".format(count))
                    log_message("β Face parsing Batch {}".format(count))
            masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
        else:
            masks = [None] * generated_len

        ## ------------------------------ SPLIT LIST ------------------------------

        # Regroup the flat per-face lists into one list per frame.
        split_preds = split_list_by_lengths(preds, num_faces_per_frame)
        del preds
        split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
        del matrs
        split_masks = split_list_by_lengths(masks, num_faces_per_frame)
        del masks

        ## ------------------------------ PASTE-BACK ------------------------------

        print("### \n β Pasting back...")
        log_message("β Pasting back...")

        def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
            """Paste every generated face of one frame and overwrite the frame file."""
            print("Entering post process")
            whole_img_path = frame_img
            print("Whole image path: {}".format(whole_img_path))
            whole_img = cv2.imread(whole_img_path)
            blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
            for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
                p = cv2.resize(p, (512,512))
                mask = cv2.resize(mask, (512,512)) if mask is not None else None
                # In-place rescale of the matrix before pasting.
                m /= 0.25
                whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
            cv2.imwrite(whole_img_path, whole_img)
            print("Done writing")

        def concurrent_post_process(image_sequence, *args):
            """Run post_process for every frame on a thread pool."""
            print("Entering concurrent_post_process")
            with concurrent.futures.ThreadPoolExecutor() as executor:
                futures = [
                    executor.submit(post_process, idx, frame_img, *args)
                    for idx, frame_img in enumerate(image_sequence)
                ]
                for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
                    # result() re-raises any exception from the worker.
                    future.result()

        concurrent_post_process(
            image_sequence,
            split_preds,
            split_matrs,
            split_masks,
            enable_laplacian_blend,
            crop_mask,
            blur_amount,
            erode_amount
        )
        print("Done do concurrent_post_process")

    ## ------------------------------ VIDEO ------------------------------

    temp_path = os.path.join(output_path, output_name, "sequence")
    os.makedirs(temp_path, exist_ok=True)

    print("### \n β Extracting video frames...")
    log_message("β Extracting video frames...")
    image_sequence = []
    cap = cv2.VideoCapture(video_path)
    try:
        curr_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
            cv2.imwrite(frame_path, frame)
            image_sequence.append(frame_path)
            curr_idx += 1
            print("Curr IDX: {}".format(curr_idx))
    finally:
        # Release the capture even when reading or writing a frame fails.
        # No window is ever opened here, so there is nothing to destroy.
        cap.release()
    print("Total image sequence: {}".format(len(image_sequence)))

    if not image_sequence:
        # Nothing could be decoded (missing or unreadable video): stop here
        # instead of running the pipeline on an empty frame list.
        if not keep_output_sequence:
            shutil.rmtree(temp_path, ignore_errors=True)
        yield "### \n No frames could be read from the input video.", *ui_after()
        return

    swap_process(image_sequence)
    print("End swap_process")

    print("### \n β Merging sequence...")
    log_message("β Merging sequence...")
    output_video_path = os.path.join(output_path, output_name + ".mp4")
    merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)

    if os.path.exists(temp_path) and not keep_output_sequence:
        print("### \n β Removing temporary files...")
        print("β Removing temporary files...")
        shutil.rmtree(temp_path)

    log_result(output_video_path)
    # Truthiness also covers None, which would otherwise reach requests.post.
    if webhook_url and webhook_id:
        print("Sent to webhook")
        send_webhook(webhook_url, webhook_id, output_video_path)
    print("### \n β Finished!")

    # The result has been delivered; remove the local copy. OUTPUT_FILE stays
    # None, so the final update carries no video value.
    os.remove(output_video_path)

    yield get_finsh_text(start_time), *ui_after_vid()
| ## ------------------------------ DIRECTORY ------------------------------ | |
| ## ------------------------------ GRADIO FUNC ------------------------------ | |
def video_changed(video_path):
    """Return slider/number updates that match the selected video.

    Returns a 3-tuple of updates: (start slider, end slider, fps number).
    With no video the sliders collapse to a single value. When the video
    cannot be read, the sliders are reset to 0 and the fps to 1.
    """
    sliders_update = gr.Slider.update
    number_update = gr.Number.update

    if video_path is None:
        return (
            sliders_update(minimum=0, maximum=0, value=0),
            sliders_update(minimum=1, maximum=1, value=1),
            number_update(value=1),
        )

    try:
        clip = VideoFileClip(video_path)
        try:
            fps = clip.fps
            total_frames = clip.reader.nframes
        finally:
            # Close the reader even when the metadata lookup fails.
            clip.close()
    except Exception:
        # Best effort: an unreadable video only resets the controls. Unlike a
        # bare ``except``, this lets KeyboardInterrupt/SystemExit through.
        return (
            sliders_update(value=0),
            sliders_update(value=0),
            number_update(value=1),
        )

    return (
        sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
        sliders_update(
            minimum=0, maximum=total_frames, value=total_frames, interactive=True
        ),
        number_update(value=fps),
    )
def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
    """Rebuild the global face analyser with new detection settings.

    Generator: yields a status string before and after the analyser is
    replaced. ``DETECT_CONDITION`` is updated as well.
    """
    global FACE_ANALYSER, DETECT_CONDITION

    yield "### \n β Applying new values..."

    DETECT_CONDITION = detect_condition
    FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
    edge = int(detection_size)
    FACE_ANALYSER.prepare(
        ctx_id=0,
        det_size=(edge, edge),
        det_thresh=float(detection_threshold),
    )

    summary = f"### \n βοΈ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
    yield summary
def stop_running():
    """Stop the active streamer, if it has a ``stop`` method, and clear it.

    Always returns the string "Cancelled".
    """
    global STREAMER
    if not hasattr(STREAMER, "stop"):
        return "Cancelled"

    STREAMER.stop()
    STREAMER = None
    return "Cancelled"
def slider_changed(show_frame, video_path, frame_index):
    """Return the frame at ``frame_index`` as an image update.

    Returns ``(None, None)`` when the preview is switched off or no video is
    selected. Otherwise returns an update that shows the frame in the image
    component and one that hides the video component.
    """
    if not show_frame or video_path is None:
        return None, None

    clip = VideoFileClip(video_path)
    try:
        # get_frame takes a time in seconds, hence index / fps.
        frame = clip.get_frame(frame_index / clip.fps)
        frame_array = np.array(frame)
    finally:
        # Close the reader even when the frame lookup fails.
        clip.close()

    return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
        visible=False
    )
## ------------------------------ GRADIO GUI ------------------------------

css = """
footer{display:none !important}
"""

# Layout: left column holds the inputs (source face, target video, trim
# controls, webhook fields); right column holds status, buttons and outputs.
with gr.Blocks(css=css) as interface:
    gr.Markdown("# πΏ API Swap Face")
    gr.Markdown("### Face swap app based on insightface inswapper.")
    with gr.Row():
        with gr.Row():
            with gr.Column(scale=0.4):
                source_image_input = gr.Image(
                    label="Source face", type="filepath", interactive=True
                )

                with gr.Group():
                    with gr.Box(visible=True) as input_video_group:
                        video_input = gr.Video(
                            label="Target Video", interactive=True
                        )
                        with gr.Accordion("βοΈ Trim video", open=False):
                            with gr.Column():
                                with gr.Row():
                                    set_slider_range_btn = gr.Button(
                                        "Set frame range", interactive=True
                                    )
                                    show_trim_preview_btn = gr.Checkbox(
                                        label="Show frame when slider change",
                                        value=True,
                                        interactive=True,
                                    )

                                video_fps = gr.Number(
                                    value=30,
                                    interactive=False,
                                    label="Fps",
                                    visible=False,
                                )
                                start_frame = gr.Slider(
                                    minimum=0,
                                    maximum=1,
                                    value=0,
                                    step=1,
                                    interactive=True,
                                    label="Start Frame",
                                    info="",
                                )
                                end_frame = gr.Slider(
                                    minimum=0,
                                    maximum=1,
                                    value=1,
                                    step=1,
                                    interactive=True,
                                    label="End Frame",
                                    info="",
                                )
                            trim_and_reload_btn = gr.Button(
                                "Trim and Reload", interactive=True
                            )

                # NOTE(review): the original nesting of the two webhook
                # fields was lost; they are placed in the left column, below
                # the video group -- confirm.
                webhook_url = gr.Text(
                    label="Webhook URL",
                    value="",
                    interactive=True
                )
                webhook_id = gr.Text(
                    label="Webhook ID",
                    value="",
                    interactive=True,
                )

            with gr.Column(scale=0.6):
                info = gr.Markdown(value="...")

                with gr.Row():
                    swap_button = gr.Button("β¨ Swap", variant="primary")
                    cancel_button = gr.Button("β Cancel")

                preview_image = gr.Image(label="Output", interactive=False)
                preview_video = gr.Video(
                    label="Output", interactive=False, visible=False
                )

                with gr.Row():
                    output_directory_button = gr.Button(
                        "π", interactive=False, visible=False
                    )
                    output_video_button = gr.Button(
                        "π¬", interactive=False, visible=False
                    )

    ## ------------------------------ GRADIO EVENTS ------------------------------

    # Order matches the parameters of process().
    swap_inputs = [
        video_input,
        source_image_input,
        webhook_url,
        webhook_id
    ]

    # Order matches what process() yields: status text followed by the four
    # component updates.
    swap_outputs = [
        info,
        preview_image,
        output_directory_button,
        output_video_button,
        preview_video,
    ]

    swap_event = swap_button.click(
        fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
    )

    # The Cancel button was rendered without a handler; wire it to
    # stop_running and let it cancel the running swap event.
    cancel_button.click(
        fn=stop_running,
        inputs=None,
        outputs=[info],
        cancels=[swap_event],
        show_progress=True,
    )

    output_directory_button.click(
        lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
    )
    output_video_button.click(
        lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
    )
if __name__ == "__main__":
    if USE_COLAB:
        print("Running in colab mode")

    # Enable the request queue first, then start the server; a public share
    # link is requested only in colab mode.
    queued_interface = interface.queue(concurrency_count=2, max_size=20)
    queued_interface.launch(share=USE_COLAB)