Spaces:
Build error
Build error
| import sys | |
| import torch | |
| import pickle | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| from collections import defaultdict | |
| from glob import glob | |
| from matplotlib import pyplot as plt | |
| from matplotlib import animation | |
| from easydict import EasyDict as edict | |
| from huggingface_hub import hf_hub_download | |
| sys.path.append("./rome/") | |
| sys.path.append('./DECA') | |
| from rome.infer import Infer | |
| from rome.src.utils.processing import process_black_shape, tensor2image | |
| from rome.src.utils.visuals import mask_errosion | |
# Fetch the pretrained checkpoints from the 'Pie31415/rome' Hub repo: the MODNet
# portrait-matting model and the ROME model weights. hf_hub_download returns the
# local filesystem path of each downloaded file.
default_modnet_path = hf_hub_download(
    repo_id='Pie31415/rome',
    filename='modnet_photographic_portrait_matting.ckpt',
)
default_model_path = hf_hub_download(repo_id='Pie31415/rome', filename='rome.pth')
# Inference configuration handed to ROME's Infer wrapper. The keys appear to
# mirror the option names of ROME's own inference entry point (TODO confirm
# against rome/infer.py); values are fixed for this demo.
args = edict({
    "save_dir": ".",
    "save_render": True,
    "model_checkpoint": default_model_path,
    "modnet_path": default_modnet_path,
    "random_seed": 0,
    "debug": False,
    "verbose": False,
    "model_image_size": 256,
    "align_source": True,
    "align_target": False,
    "align_scale": 1.25,
    "use_mesh_deformations": False,
    "subdivide_mesh": False,
    "renderer_sigma": 1e-08,
    "renderer_zfar": 100.0,
    "renderer_type": "soft_mesh",
    "renderer_texture_type": "texture_uv",
    "renderer_normalized_alphas": False,
    "deca_path": "DECA",
    "rome_data_dir": "rome/data",
    "autoenc_cat_alphas": False,
    "autoenc_align_inputs": False,
    "autoenc_use_warp": False,
    "autoenc_num_channels": 64,
    "autoenc_max_channels": 512,
    "autoenc_num_groups": 4,
    "autoenc_num_bottleneck_groups": 0,
    "autoenc_num_blocks": 2,
    "autoenc_num_layers": 4,
    "autoenc_block_type": "bottleneck",
    "neural_texture_channels": 8,
    "num_harmonic_encoding_funcs": 6,
    "unet_num_channels": 64,
    "unet_max_channels": 512,
    "unet_num_groups": 4,
    "unet_num_blocks": 1,
    "unet_num_layers": 2,
    "unet_block_type": "conv",
    "unet_skip_connection_type": "cat",
    "unet_use_normals_cond": True,
    "unet_use_vertex_cond": False,
    "unet_use_uvs_cond": False,
    "unet_pred_mask": False,
    "use_separate_seg_unet": True,
    "norm_layer_type": "gn",
    "activation_type": "relu",
    "conv_layer_type": "ws_conv",
    "deform_norm_layer_type": "gn",
    "deform_activation_type": "relu",
    "deform_conv_layer_type": "ws_conv",
    "unet_seg_weight": 0.0,
    "unet_seg_type": "bce_with_logits",
    "deform_face_tightness": 0.0001,
    "use_whole_segmentation": False,
    "mask_hair_for_neck": False,
    "use_hair_from_avatar": False,
    "use_scalp_deforms": True,
    "use_neck_deforms": True,
    "use_basis_deformer": False,
    "use_unet_deformer": True,
    "pretrained_encoder_basis_path": "",
    "pretrained_vertex_basis_path": "",
    "num_basis": 50,
    "basis_init": "pca",
    "num_vertex": 5023,
    "train_basis": True,
    "path_to_deca": "DECA",
    "path_to_linear_hair_model": "data/linear_hair.pth",  # N/A
    "path_to_mobile_model": "data/disp_model.pth",  # N/A
    "n_scalp": 60,
    "use_distill": False,
    "use_mobile_version": False,
    "deformer_path": "data/rome.pth",
    "output_unet_deformer_feats": 32,
    "use_deca_details": False,
    "use_flametex": False,
    "upsample_type": "nearest",
    "num_frequencies": 6,
    "deform_face_scale_coef": 0.0,
    # Previously hard-coded to "cuda", which cannot work on CPU-only hardware.
    # Use the GPU when present; otherwise fall back to CPU.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
})
# Download the FLAME generic model and the DECA checkpoint, then place copies
# under ./DECA/data where the DECA code on sys.path reads them from.
import os
import shutil

generic_model_path = hf_hub_download('Pie31415/rome', 'generic_model.pkl')
deca_model_path = hf_hub_download('Pie31415/rome', 'deca_model.tar')

os.makedirs('./DECA/data', exist_ok=True)

# Re-serialise the FLAME pickle with the running interpreter's pickle.
# encoding='latin1' lets Python 3 read pickles written by Python 2 that contain
# NumPy arrays (per the pickle docs); presumably that is why it is used here.
# NOTE(review): pickle.load can execute arbitrary code. This is acceptable only
# because the file comes from a fixed, trusted Hub repo, not from user input.
with open(generic_model_path, 'rb') as f:
    ss = pickle.load(f, encoding='latin1')
with open('./DECA/data/generic_model.pkl', 'wb') as out:
    pickle.dump(ss, out)

# Plain binary copy of the checkpoint. The previous version iterated over
# "lines" of a binary file and shadowed the builtin `input`.
shutil.copyfile(deca_model_path, './DECA/data/deca_model.tar')
# Build the ROME inference wrapper once, at import time, from the `args` config.
# Both Gradio callbacks (image_inference and video_inference) share this instance.
infer = Infer(args)
def image_inference(
    source_img: Image.Image = None,
    driver_img: Image.Image = None
):
    """Re-pose the avatar built from ``source_img`` using ``driver_img``.

    Args:
        source_img: Source portrait, delivered as a PIL image by
            ``gr.Image(type="pil")``.
        driver_img: Driving frame, also a PIL image.

    Returns:
        A numpy image built by concatenating four panels along dim 2: the
        source image, the target (driver) image, the masked render, and the
        predicted shape image. If the tensors are (C, H, W), which is presumed
        but not confirmed, the panels sit side by side.

    Raises:
        ValueError: if either input image is missing.
    """
    # Fail with a clear message instead of an opaque error deep inside the model.
    if source_img is None or driver_img is None:
        raise ValueError("Please provide both a source image and a driver image.")

    out = infer.evaluate(source_img, driver_img, crop_center=False)
    data_dict = out['source_information']['data_dict']
    panels = [
        data_dict['source_img'][0].cpu(),
        data_dict['target_img'][0].cpu(),
        out['render_masked'].cpu(),
        out['pred_target_shape_img'][0].cpu(),
    ]
    res = tensor2image(torch.cat(panels, dim=2))
    # Reverse the channel axis. Presumably tensor2image yields BGR (OpenCV order)
    # while Gradio expects RGB -- TODO confirm.
    return res[..., ::-1]
def extract_frames(
    driver_vid: str = None,
    step: int = 1,
):
    """Decode a video file into a list of RGB PIL images.

    Args:
        driver_vid: Path to the video file (what ``gr.Video`` hands to the
            callbacks). ``None`` yields an empty list.
        step: Keep every ``step``-th frame, starting with the first. The
            default of 1 keeps every frame.

    Returns:
        List of ``PIL.Image.Image`` frames in playback order. It is empty if
        the video cannot be opened or has no frames.

    Raises:
        ValueError: if ``step`` is smaller than 1.
    """
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    image_frames = []
    if driver_vid is None:
        return image_frames

    vid = cv2.VideoCapture(driver_vid)
    try:
        index = 0
        while True:
            success, img = vid.read()
            if not success:
                break
            if index % step == 0:
                # OpenCV decodes frames as BGR; PIL expects RGB.
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                image_frames.append(Image.fromarray(img))
            index += 1
    finally:
        # Release the capture handle even if decoding fails part-way.
        vid.release()
    return image_frames
def video_inference(
    source_img: Image.Image = None,
    driver_vid: str = None
):
    """Animate the avatar from ``source_img`` with a driver video; return a GIF.

    Every 4th frame of ``driver_vid`` is used as a driving frame. For each one,
    the masked ROME render is collected, and the sequence is written to an
    animated GIF.

    Args:
        source_img: Source portrait, a PIL image from ``gr.Image(type="pil")``.
        driver_vid: Path to the driver video as delivered by ``gr.Video``.

    Returns:
        Filesystem path of the generated GIF. It is a fresh temporary file per
        call, so concurrent requests cannot overwrite each other's output.

    Raises:
        ValueError: if an input is missing or no frames could be decoded.
    """
    import tempfile  # local import: only this callback needs it

    if source_img is None or driver_vid is None:
        raise ValueError("Please provide both a source image and a driver video.")

    frame_step = 4  # render every 4th driver frame to bound inference time
    mask_hard_threshold = 0.5

    driver_frames = extract_frames(driver_vid)[::frame_step]
    if not driver_frames:
        # The previous version failed here with an IndexError on video[0].
        raise ValueError("Could not read any frames from the driver video.")

    renders = []
    for frame in driver_frames:
        new_out = infer.evaluate(source_img, frame)
        # Binarise the predicted mask, then post-process it with mask_errosion
        # (presumably a morphological erosion, judging by the name).
        mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
        mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
        # Keep the avatar inside the mask and fill everything outside with 1.
        render = new_out['pred_target_img'].cpu() * mask_pred + (1 - mask_pred)
        renders.append(tensor2image(render[0]))
        # Normals and mesh images used to be computed here too, but they were
        # never used, so that per-frame work was removed.

    video = np.array(renders)
    fig = plt.figure()
    # [..., ::-1] flips the channel order, as in image_inference.
    im = plt.imshow(video[0, :, :, ::-1])
    plt.axis('off')
    plt.close(fig)  # this is required to not display the generated image

    def init():
        im.set_data(video[0, :, :, ::-1])

    def animate(i):
        im.set_data(video[i, :, :, ::-1])
        return im

    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=30)

    # Create a unique output file. Previously every request wrote to the same
    # ./avatar.gif, so concurrent users could receive each other's results.
    with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp:
        gif_path = tmp.name
    anim.save(gif_path, dpi=300, writer=animation.PillowWriter(fps=24))
    return gif_path
# HTML blurb rendered under the title: a one-line summary plus links to the
# paper, the project page and the GitHub repository.
description = """<p style='text-align: center'> Create a personal avatar from just a single image using ROME. <br> <a href='https://arxiv.org/abs/2206.08343' target='_blank'>Paper</a> | <a href='https://samsunglabs.github.io/rome' target='_blank'>Project Page</a> | <a href='https://github.com/SamsungLabs/rome' target='_blank'>Github</a> </p>"""
# Markdown block-quote (leading "> ") wrapping a centred HTML paragraph that
# describes the system; it is rendered below `description`.
quote = """
> <p style='text-align: center'> [The] system creates realistic mesh-based avatars from a single <strong>source</strong> photo. These avatars are rigged, i.e., they can be driven by the animation parameters from a different <strong>driving</strong> frame. </p>"""
# Gradio UI: a header followed by three tabs (image, uploaded video, webcam).
# Each tab pairs a source portrait with a driver and has its own Predict button.
with gr.Blocks() as demo:
    gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
    gr.HTML(value="<img src='file/media/tease.gif' alt='Teaser' style='display: block; margin: auto;'>")
    gr.Markdown(description)
    gr.Markdown(quote)

    # Single driver image -> one composite output image.
    with gr.Tab("Image Inference"):
        with gr.Row():
            source_img = gr.Image(type="pil", label="Source image", show_label=True)
            driver_img = gr.Image(type="pil", label="Driver image", show_label=True)
        image_output = gr.Image(label="Rendered avatar")
        image_button = gr.Button("Predict")

    # Uploaded driver video -> animated GIF.
    with gr.Tab("Video Inference"):
        with gr.Row():
            source_img2 = gr.Image(type="pil", label="Source image", show_label=True)
            driver_vid = gr.Video(label="Driver video", source="upload")
        video_output = gr.Image(label="Rendered GIF avatar")
        video_button = gr.Button("Predict")

    # Webcam recording -> animated GIF (same callback as the video tab).
    with gr.Tab("Webcam Inference"):
        with gr.Row():
            source_img3 = gr.Image(type="pil", label="Source image", show_label=True)
            driver_cam = gr.Video(label="Driver video", source="webcam")
        cam_output = gr.Image(label="Rendered GIF avatar")
        cam_button = gr.Button("Predict")

    # Example source/driver pairs for the image tab. Per the Gradio docs,
    # cache_examples=True runs image_inference on them when the app starts.
    gr.Examples(
        examples=[
            ["./examples/lincoln.jpg", "./examples/taras2.jpg"],
            ["./examples/lincoln.jpg", "./examples/taras1.jpg"]
        ],
        inputs=[source_img, driver_img],
        outputs=[image_output],
        fn=image_inference,
        cache_examples=True
    )

    image_button.click(image_inference, inputs=[source_img, driver_img], outputs=image_output)
    video_button.click(video_inference, inputs=[source_img2, driver_vid], outputs=video_output)
    cam_button.click(video_inference, inputs=[source_img3, driver_cam], outputs=cam_output)

# Video inference runs the model over many frames and can take a long time.
# Enabling the queue lets such requests finish instead of hitting request
# timeouts, and it serialises access to the shared GPU model.
demo.queue()
demo.launch()