| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | import cv2 |
| | import base64 |
| | import subprocess |
| |
|
| | import gradio as gr |
| | import numpy as np |
| | from PIL import Image |
| | import argparse |
| | from omegaconf import OmegaConf |
| |
|
| | import torch |
| | from lam.runners.infer.head_utils import prepare_motion_seqs, preprocess_image |
| | import moviepy.editor as mpy |
| | from lam.utils.ffmpeg_utils import images_to_video |
| | import sys |
| | from flame_tracking_single_image import FlameTrackingSingleImage |
| |
|
| | try: |
| | import spaces |
| | except: |
| | pass |
| |
|
| |
|
def launch_pretrained():
    """Download the runtime archives from the Hugging Face Hub and unpack them.

    Each archive is fetched from the ``DyrusQZ/LHM_Runtime`` model repository
    into the current working directory, extracted there with ``tar``, and the
    archive file is then deleted.

    Raises:
        subprocess.CalledProcessError: if ``tar`` fails to extract an archive.
            The previous ``os.system`` calls ignored such failures and carried
            on with missing assets.
    """
    from huggingface_hub import hf_hub_download

    for archive in ('assets.tar', 'LHM-0.5B.tar', 'LHM_prior_model.tar'):
        archive_path = hf_hub_download(repo_id='DyrusQZ/LHM_Runtime',
                                       repo_type='model',
                                       filename=archive,
                                       local_dir='./')
        # Argument-list form avoids the shell. check=True stops here on a bad
        # extraction, and the archive is only removed after tar succeeded,
        # matching the original `tar ... && rm ...` ordering.
        subprocess.run(['tar', '-xvf', archive_path], check=True)
        os.remove(archive_path)
| |
|
| |
|
def launch_env_not_compile_with_cuda():
    """Install or replace the Python packages the demo needs at runtime.

    Every command runs as ``<sys.executable> -m pip ...``, so packages land in
    the interpreter that is actually running this app. A bare ``pip`` found on
    PATH may belong to a different environment.

    Like the original ``os.system`` calls, each step is best-effort: a failing
    command is reported and the remaining steps still run.
    """
    pip = [sys.executable, '-m', 'pip']
    steps = [
        ['install', 'chumpy'],
        ['uninstall', '-y', 'basicsr'],
        ['install', 'git+https://github.com/hitsz-zuoqi/BasicSR/'],
        ['install', 'numpy==1.23.0'],
        ['install', '--no-index', '--no-cache-dir', 'pytorch3d', '-f',
         'https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/py310_cu121_pyt251/download.html'],
    ]
    for step in steps:
        command = pip + step
        result = subprocess.run(command, check=False)
        if result.returncode != 0:
            # Report the failure but keep going (deliberate best-effort setup).
            print(f"[launch_env_not_compile_with_cuda] exit {result.returncode}: "
                  f"{' '.join(command)}")
| |
|
| |
|
def assert_input_image(input_image):
    """Stop the Gradio event chain when no input image was provided.

    Returns ``None`` for any non-``None`` value. Otherwise it raises
    ``gr.Error``, so the following ``.success`` steps do not run.
    """
    if input_image is not None:
        return
    raise gr.Error('No image selected or uploaded!')
| |
|
| |
|
def prepare_working_dir():
    """Create a fresh per-request temporary directory.

    The returned ``TemporaryDirectory`` object, not just its path, is stored in
    Gradio state. Callers use ``.name`` to build file paths inside it.
    """
    from tempfile import TemporaryDirectory as _TemporaryDirectory
    return _TemporaryDirectory()
| |
|
| |
|
def init_preprocessor():
    """Create the module-level ``preprocessor`` that ``preprocess_fn`` reads."""
    global preprocessor
    from lam.utils.preprocess import Preprocessor as _Preprocessor
    preprocessor = _Preprocessor()
| |
|
| |
|
def preprocess_fn(image_in: np.ndarray, remove_bg: bool, recenter: bool,
                  working_dir):
    """Run the module-level ``preprocessor`` on an in-memory image.

    ``init_preprocessor()`` must have been called first; it binds the global
    ``preprocessor`` used here.

    Args:
        image_in: image array, written to ``<working_dir>/raw.png`` via
            ``PIL.Image.fromarray``.
        remove_bg: forwarded to ``preprocessor.preprocess`` as ``rmbg``.
        recenter: forwarded to ``preprocessor.preprocess`` as ``recenter``.
        working_dir: object exposing a ``.name`` directory path, e.g. the
            ``TemporaryDirectory`` returned by ``prepare_working_dir``.

    Returns:
        Path of the preprocessed image, ``<working_dir>/rembg.png``.

    Raises:
        AssertionError: if the preprocessor reports failure. It is raised
            explicitly rather than via ``assert``, so the check is not
            stripped when Python runs with ``-O``.
    """
    image_raw = os.path.join(working_dir.name, 'raw.png')
    with Image.fromarray(image_in) as img:
        img.save(image_raw)
    image_out = os.path.join(working_dir.name, 'rembg.png')
    success = preprocessor.preprocess(image_path=image_raw,
                                      save_path=image_out,
                                      rmbg=remove_bg,
                                      recenter=recenter)
    if not success:
        raise AssertionError('Failed under preprocess_fn!')
    return image_out
| |
|
| |
|
def get_image_base64(path, mime_type=None):
    """Read an image file and return it as a ``data:`` URI string.

    Args:
        path: path to the image file.
        mime_type: MIME type to embed in the URI. When omitted, it is guessed
            from the file extension. If the guess is not an ``image/*`` type,
            ``image/png`` is used, which was previously hard-coded for every
            file.

    Returns:
        ``"data:<mime_type>;base64,<payload>"``.
    """
    import mimetypes

    if mime_type is None:
        guessed = mimetypes.guess_type(path)[0]
        mime_type = guessed if guessed and guessed.startswith('image/') else 'image/png'
    with open(path, 'rb') as image_file:
        encoded_string = base64.b64encode(image_file.read()).decode()
    return f'data:{mime_type};base64,{encoded_string}'
| |
|
| |
|
def save_imgs_2_video(imgs, v_pth, fps):
    """Encode a stack of frames into a video file with moviepy.

    Args:
        imgs: frames indexed along the first axis (an ndarray or a sequence).
            All frames must have the same size; ``core_fn`` passes a ``uint8``
            array.
        v_pth: output video path.
        fps: playback frame rate. Each frame is shown for ``1 / fps`` seconds.

    Raises:
        ValueError: if ``imgs`` contains no frames.
    """
    frames = list(imgs)
    if not frames:
        raise ValueError('save_imgs_2_video: no frames to write')
    # Previously each frame was an ImageClip held for a fixed 0.1 s and then
    # sampled at `fps`. At 30 fps every frame was repeated 3x, so the animation
    # played 3x slower than the motion it was rendered from. ImageSequenceClip
    # gives each frame exactly 1/fps seconds.
    clip = mpy.ImageSequenceClip(frames, fps=fps)
    try:
        clip.write_videofile(v_pth, fps=fps)
    finally:
        clip.close()
| |
|
| |
|
def parse_configs():
    """Build the runtime config from CLI flags, environment overrides and YAML.

    Inputs:
      * ``--config``: training YAML. It falls back to ``--infer`` when omitted,
        and provides ``source_size``, ``src_head_size``, ``render_size`` and
        the dump directories.
      * ``--infer``: inference YAML, merged into the config, plus default dump
        directories for any that are still unset.
      * Remaining ``key=value`` tokens, parsed with ``OmegaConf.from_cli`` and
        merged last.

    The ``APP_INFER`` and ``APP_MODEL_NAME`` environment variables, when set,
    override ``--infer`` and ``model_name``.

    Returns:
        ``(cfg, cfg_train)``. ``cfg_train`` is the loaded training config, or
        ``None`` when neither ``--config`` nor ``--infer`` was given.

    Raises:
        AssertionError: if no ``model_name`` ends up in the config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str)
    parser.add_argument("--infer", type=str)
    args, unknown = parser.parse_known_args()

    cfg = OmegaConf.create()
    cli_cfg = OmegaConf.from_cli(unknown)

    # Environment overrides (set e.g. by launch_gradio_app) win over the CLI.
    if os.environ.get("APP_INFER") is not None:
        args.infer = os.environ.get("APP_INFER")
    if os.environ.get("APP_MODEL_NAME") is not None:
        cli_cfg.model_name = os.environ.get("APP_MODEL_NAME")

    args.config = args.infer if args.config is None else args.config

    # Bound up front: previously this name only existed inside the branch
    # below, so running without any config raised UnboundLocalError at return.
    cfg_train = None
    if args.config is not None:
        cfg_train = OmegaConf.load(args.config)
        cfg.source_size = cfg_train.dataset.source_image_res
        # Optional key, with 112 as the fallback. This replaces a bare
        # `except:` that could also hide unrelated errors.
        cfg.src_head_size = cfg_train.dataset.get("src_head_size", 112)
        cfg.render_size = cfg_train.dataset.render_image.high
        _relative_path = os.path.join(
            cfg_train.experiment.parent,
            cfg_train.experiment.child,
            os.path.basename(cli_cfg.model_name).split("_")[-1],
        )

        cfg.save_tmp_dump = os.path.join("exps", "save_tmp", _relative_path)
        cfg.image_dump = os.path.join("exps", "images", _relative_path)
        cfg.video_dump = os.path.join("exps", "videos", _relative_path)

    if args.infer is not None:
        cfg_infer = OmegaConf.load(args.infer)
        cfg.merge_with(cfg_infer)
        cfg.setdefault(
            "save_tmp_dump", os.path.join("exps", cli_cfg.model_name, "save_tmp")
        )
        cfg.setdefault("image_dump", os.path.join("exps", cli_cfg.model_name, "images"))
        cfg.setdefault(
            "video_dump", os.path.join("dumps", cli_cfg.model_name, "videos")
        )
        cfg.setdefault("mesh_dump", os.path.join("dumps", cli_cfg.model_name, "meshes"))

    cfg.motion_video_read_fps = 6
    cfg.merge_with(cli_cfg)

    cfg.setdefault("logger", "INFO")

    # cfg.get: attribute access on a missing key raises on newer OmegaConf
    # versions, which would mask the intended message below.
    assert cfg.get("model_name") is not None, "model_name is required"

    return cfg, cfg_train
| |
|
| |
|
def demo_lam(flametracking, lam, cfg):
    """Build the LAM Gradio demo UI, wire the Generate button, and launch it.

    Args:
        flametracking: tracker whose ``preprocess``/``optimize``/``export``
            methods are called per request and must return code 0 on success
            (see ``FlameTrackingSingleImage``).
        lam: model exposing ``infer_single_view``. Inputs are moved to
            ``"cuda"`` before the call.
        cfg: config from ``parse_configs``. Reads ``source_size`` and
            ``render_size``, plus optional ``motion_img_need_mask`` and
            ``vis_motion``.

    After building the UI, this calls ``demo.queue()`` and ``demo.launch()``.
    """

    def core_fn(image_path: str, video_params, working_dir):
        """Animate the head in ``image_path`` with the motion of ``video_params``.

        Args:
            image_path: path of the uploaded or selected source image.
            video_params: path of the selected driving video. Its basename,
                without extension, selects the pre-exported FLAME sequence in
                ``./assets/sample_motion/export/<name>/flame_param``.
            working_dir: per-request ``TemporaryDirectory`` from
                ``prepare_working_dir``. Outputs are written there.

        Returns:
            ``(processed_image_path, rendered_video_path)``.

        Raises:
            gr.Error: if no driving video is selected or a tracking step fails.
        """
        if video_params is None:
            # Without this guard, os.path.basename(None) raises an opaque TypeError.
            raise gr.Error('No video selected!')

        image_raw = os.path.join(working_dir.name, "raw.png")
        # Use the opened image as the context manager so its file handle is
        # closed. Previously only the converted copy was closed.
        with Image.open(image_path) as img:
            img.convert('RGB').save(image_raw)

        base_vid = os.path.basename(video_params).split(".")[0]
        motion_seqs_dir = os.path.join("./assets/sample_motion/export", base_vid, "flame_param")

        dump_video_path = os.path.join(working_dir.name, "output.mp4")
        dump_image_path = os.path.join(working_dir.name, "output.png")
        dump_tmp_dir = os.path.dirname(dump_image_path)
        os.makedirs(dump_tmp_dir, exist_ok=True)

        print(image_raw, motion_seqs_dir, dump_tmp_dir, dump_video_path)

        # A previous run in this working dir already produced the result.
        if os.path.exists(dump_video_path):
            return dump_image_path, dump_video_path

        motion_img_need_mask = cfg.get("motion_img_need_mask", False)
        vis_motion = cfg.get("vis_motion", False)

        # FLAME tracking of the input image: preprocess -> optimize -> export.
        # Explicit gr.Error checks survive `python -O` and show in the UI.
        return_code = flametracking.preprocess(image_raw)
        if return_code != 0:
            raise gr.Error("flametracking preprocess failed!")
        return_code = flametracking.optimize()
        if return_code != 0:
            raise gr.Error("flametracking optimize failed!")
        return_code, output_dir = flametracking.export()
        if return_code != 0:
            raise gr.Error("flametracking export failed!")

        image_path = os.path.join(output_dir, "images/00000_00.png")
        mask_path = image_path.replace("/images/", "/fg_masks/").replace(".jpg", ".png")
        print(image_path, mask_path)

        aspect_standard = 1.0
        source_size = cfg.source_size
        render_size = cfg.render_size
        render_fps = 30

        # Reference image (plus mask and FLAME shape parameters) for the model.
        image, _, _, shape_param = preprocess_image(image_path, mask_path=mask_path, intr=None, pad_ratio=0, bg_color=1.,
                                                    max_tgt_size=None, aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
                                                    render_tgt_size=source_size, multiply=14, need_mask=True, get_shape_param=True)

        # Save the processed reference image; this path equals dump_image_path.
        save_ref_img_path = os.path.join(dump_tmp_dir, "output.png")
        vis_ref_img = (image[0].permute(1, 2, 0).cpu().detach().numpy() * 255).astype(np.uint8)
        Image.fromarray(vis_ref_img).save(save_ref_img_path)

        # Driving motion sequence.
        src = image_path.split('/')[-3]
        driven = motion_seqs_dir.split('/')[-2]
        src_driven = [src, driven]
        # enlarge_ratio was previously the typo `[1.0, 1,0]` (a 3-element list).
        motion_seq = prepare_motion_seqs(motion_seqs_dir, None, save_root=dump_tmp_dir, fps=render_fps,
                                         bg_color=1., aspect_standard=aspect_standard, enlarge_ratio=[1.0, 1.0],
                                         render_image_res=render_size, multiply=16,
                                         need_mask=motion_img_need_mask, vis_motion=vis_motion,
                                         shape_param=shape_param, test_sample=False, cross_id=False, src_driven=src_driven)

        # Inference: drive the reference identity with the tracked motion.
        motion_seq["flame_params"]["betas"] = shape_param.unsqueeze(0)
        device, dtype = "cuda", torch.float32
        print("start to inference...................")
        with torch.no_grad():
            res = lam.infer_single_view(image.unsqueeze(0).to(device, dtype), None, None,
                                        render_c2ws=motion_seq["render_c2ws"].to(device),
                                        render_intrs=motion_seq["render_intrs"].to(device),
                                        render_bg_colors=motion_seq["render_bg_colors"].to(device),
                                        flame_params={k: v.to(device) for k, v in motion_seq["flame_params"].items()})

        # Composite onto a white background using the thresholded mask.
        rgb = res["comp_rgb"].detach().cpu().numpy()
        mask = res["comp_mask"].detach().cpu().numpy()
        mask[mask < 0.5] = 0.0
        rgb = rgb * mask + (1 - mask) * 1
        rgb = (np.clip(rgb, 0, 1.0) * 255).astype(np.uint8)
        if vis_motion:
            # Side-by-side: reference image | render | motion visualisation.
            vis_ref_img = np.tile(
                cv2.resize(vis_ref_img, (rgb[0].shape[1], rgb[0].shape[0]), interpolation=cv2.INTER_AREA)[None, :, :, :],
                (rgb.shape[0], 1, 1, 1),
            )
            rgb = np.concatenate([vis_ref_img, rgb, motion_seq["vis_motion_render"]], axis=2)

        os.makedirs(os.path.dirname(dump_video_path), exist_ok=True)
        save_imgs_2_video(rgb, dump_video_path, render_fps)

        return dump_image_path, dump_video_path

    with gr.Blocks(analytics_enabled=False) as demo:

        logo_url = './assets/images/logo.png'
        logo_base64 = get_image_base64(logo_url)
        gr.HTML(f"""
        <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
        <div>
            <h1> <img src="{logo_base64}" style='height:35px; display:inline-block;'/> LAM: Large Avatar Model for One-shot Animatable Gaussian Head</h1>
        </div>
        </div>
        """)
        gr.HTML(
            """<p><h4 style="color: red;"> Notes: Inputing front-face images or face orientation close to the driven signal gets better results.</h4></p>"""
        )

        # Inputs and outputs, side by side.
        with gr.Row():

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_input_image'):
                    with gr.TabItem('Input Image'):
                        with gr.Row():
                            input_image = gr.Image(label='Input Image',
                                                   image_mode='RGB',
                                                   height=480,
                                                   width=270,
                                                   sources='upload',
                                                   type='filepath',
                                                   elem_id='content_image')
                # Sample source images.
                with gr.Row():
                    examples = [
                        ['assets/sample_input/2w01/images/2w01.png'],
                        ['assets/sample_input/2w02/images/2w02.png'],
                        ['assets/sample_input/2w03/images/2w03.png'],
                        ['assets/sample_input/2w04/images/2w04.png'],
                    ]
                    gr.Examples(
                        examples=examples,
                        inputs=[input_image],
                        examples_per_page=20,
                    )

            with gr.Column():
                with gr.Tabs(elem_id='lam_input_video'):
                    with gr.TabItem('Input Video'):
                        with gr.Row():
                            video_input = gr.Video(label='Input Video',
                                                   height=480,
                                                   width=270,
                                                   interactive=False)

                # Driving clips; core_fn maps each basename to its exported FLAME sequence.
                examples = [
                    './assets/sample_motion/export/clip1/clip1.mp4',
                    './assets/sample_motion/export/clip2/clip2.mp4',
                    './assets/sample_motion/export/clip3/clip3.mp4',
                ]

                gr.Examples(
                    examples=examples,
                    inputs=[video_input],
                    examples_per_page=20,
                )

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_processed_image'):
                    with gr.TabItem('Processed Image'):
                        with gr.Row():
                            processed_image = gr.Image(
                                label='Processed Image',
                                image_mode='RGBA',
                                type='filepath',
                                elem_id='processed_image',
                                height=480,
                                width=270,
                                interactive=False)

            with gr.Column(variant='panel', scale=1):
                with gr.Tabs(elem_id='lam_render_video'):
                    with gr.TabItem('Rendered Video'):
                        with gr.Row():
                            output_video = gr.Video(label='Rendered Video',
                                                    format='mp4',
                                                    height=480,
                                                    width=270,
                                                    autoplay=True)

        with gr.Row():
            with gr.Column(variant='panel', scale=1):
                submit = gr.Button('Generate',
                                   elem_id='lam_generate',
                                   variant='primary')

        # Validate input -> fresh temp dir -> run the pipeline.
        working_dir = gr.State()
        submit.click(
            fn=assert_input_image,
            inputs=[input_image],
            queue=False,
        ).success(
            fn=prepare_working_dir,
            outputs=[working_dir],
            queue=False,
        ).success(
            fn=core_fn,
            inputs=[input_image, video_input,
                    working_dir],
            outputs=[processed_image, output_video],
        )

    demo.queue()
    demo.launch()
| |
|
| |
|
def _build_model(cfg):
    """Construct the LAM model and load its weights from ``cfg.model_name``.

    ``model_dict["lam"]`` is wrapped with ``wrap_model_hub``. The wrapped
    class's ``from_pretrained`` is then called with ``cfg.model_name`` as-is.
    """
    from lam.models import model_dict
    from lam.utils.hf_hub import wrap_model_hub

    lam_cls = model_dict["lam"]
    return wrap_model_hub(lam_cls).from_pretrained(cfg.model_name)
| |
|
def launch_gradio_app():
    """Set app environment defaults, load LAM and the FLAME tracker, start the UI.

    ``APP_MODEL_NAME`` and ``APP_INFER`` written here are read back by
    ``parse_configs`` to select the checkpoint and the inference config.
    """
    app_env = {
        'APP_ENABLED': '1',
        'APP_MODEL_NAME': './exps/releases/lam/lam-20k/step_045500/',
        'APP_INFER': './configs/inference/lam-20k-8gpu.yaml',
        'APP_TYPE': 'infer.lam',
        'NUMBA_THREADING_LAYER': 'omp',
    }
    os.environ.update(app_env)

    cfg, _unused_train_cfg = parse_configs()
    lam_model = _build_model(cfg)
    lam_model.to('cuda')

    tracker = FlameTrackingSingleImage(
        output_dir='tracking_output',
        alignment_model_path='./pretrain_model/68_keypoints_model.pkl',
        vgghead_model_path='./pretrain_model/vgghead/vgg_heads_l.trcd',
        human_matting_path='./pretrain_model/matting/stylematte_synth.pt',
        facebox_model_path='./pretrain_model/FaceBoxesV2.pth',
        detect_iris_landmarks=True,
    )

    demo_lam(tracker, lam_model, cfg)
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point. Only the demo is started here;
    # launch_pretrained / launch_env_not_compile_with_cuda are not invoked.
    launch_gradio_app()
| |
|