| import argparse |
| import gradio as gr |
| import os |
| import shutil |
| from glob import glob |
| from PIL import Image |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from torchvision.utils import make_grid, save_image |
| from torchvision.io import read_image |
| import torchvision.transforms.functional as F |
| from functools import partial |
| from datetime import datetime |
|
|
|
|
# Save matplotlib figures with a tight bounding box (used by the `show` helper).
plt.rcParams["savefig.bbox"] = 'tight'
|
|
def show(imgs):
    """Display one image tensor, or a list of them, side by side on one figure.

    Each entry is detached from the autograd graph, converted to PIL, and
    drawn on its own axis with all ticks and tick labels removed.
    """
    img_list = imgs if isinstance(imgs, list) else [imgs]
    _, axes = plt.subplots(ncols=len(img_list), squeeze=False)
    for col, tensor_img in enumerate(img_list):
        pil_img = F.to_pil_image(tensor_img.detach())
        axis = axes[0, col]
        axis.imshow(np.asarray(pil_img))
        axis.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
|
|
class Intermediate:
    """Mutable holder for state shared across Gradio callback invocations."""

    def __init__(self):
        # Last uploaded input image (a PIL image; None until first upload) and
        # the timestamp tag of the run it belongs to (0 until first upload).
        # Both are updated by TextGuidedImageTo3D when a new image arrives.
        self.input_img = None
        self.input_img_time = 0
|
|
|
|
# Checkpoint filename for every fine-tuned generator, keyed by style name.
# Human-face styles were fine-tuned from the FFHQ generator ("ffhq-" prefix),
# the three animal styles from the cat generator ("cat-" prefix).
_ffhq_styles = ("elf", "greek_statue", "hobbit", "lego", "masquerade",
                "neanderthal", "orc", "pixar", "skeleton", "stone_golem",
                "super_mario", "tekken", "yoda", "zombie")
_cat_styles = ("cat_in_Zootopia", "fox_in_Zootopia", "golden_aluminum_animal")
model_ckpts = {**{style: f"ffhq-{style}.pkl" for style in _ffhq_styles},
               **{style: f"cat-{style}.pkl" for style in _cat_styles}}
|
|
# Subset of the FFHQ-based checkpoints offered in the text-guided
# manipulated-3D-reconstruction tab (cat models are excluded there).
manip_model_ckpts = {style: f"ffhq-{style}.pkl"
                     for style in ("super_mario", "lego", "neanderthal", "orc",
                                   "pixar", "skeleton", "stone_golem", "tekken",
                                   "greek_statue", "yoda", "zombie", "elf")}
|
|
|
|
def _ensure_finetuned_ckpt(ckpt_name):
    """Download a fine-tuned checkpoint from HuggingFace if not cached locally."""
    if not os.path.exists(f'finetuned/{ckpt_name}'):
        command = f"""wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{ckpt_name} -O finetuned/{ckpt_name}
        """
        os.system(command)


def TextGuidedImageTo3D(intermediate, img, model_name, num_inversion_steps, truncation):
    """Translate an uploaded face photo into a stylized 3D-consistent result.

    Runs pose alignment, GAN inversion and manipulation via ``datid3d_test.py``
    (invoked through ``os.system``) and returns the triple the Gradio tab
    displays: (aligned input image, result image, result video path).

    Parameters
    ----------
    intermediate : Intermediate
        Session state caching the last input image and its run timestamp.
    img : PIL image
        Uploaded input photo.
    model_name : str
        A key of ``manip_model_ckpts``, or ``'all'`` to run every model.
    num_inversion_steps : int
        Optimization steps for the inversion stage.
    truncation : float
        Truncation psi passed to the generator.
    """
    # A new upload resets the input directory and stamps a fresh run tag so
    # cached intermediate results of a previous image are never reused.
    if img != intermediate.input_img:
        if os.path.exists('input_imgs_gradio'):
            shutil.rmtree('input_imgs_gradio')
        os.makedirs('input_imgs_gradio', exist_ok=True)
        img.save('input_imgs_gradio/input.png')
        intermediate.input_img = img
        intermediate.input_img_time = datetime.now().strftime('%Y-%m-%d_%H:%M:%S')

    run_tag = intermediate.input_img_time
    all_model_names = manip_model_ckpts.keys()
    generator_type = 'ffhq'

    if model_name == 'all':
        # Only (re)run the models whose result video is not cached yet.
        _no_video_models = [
            _model_name for _model_name in all_model_names
            if not os.path.exists(f'test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/finetuned___{model_ckpts[_model_name]}__input_inv.mp4')
        ]

        model_names_command = ''
        for _model_name in _no_video_models:
            _ensure_finetuned_ckpt(model_ckpts[_model_name])
            model_names_command += f"finetuned/{model_ckpts[_model_name]} "

        # Reuse previously computed inversion latents when they exist.
        w_pths = sorted(glob(f'test_runs/manip_3D_recon_gradio_{run_tag}/3_inversion_result/*.pt'))
        mode = 'manip' if len(w_pths) == 0 else 'manip_from_inv'

        if len(_no_video_models) > 0:
            command = f"""python datid3d_test.py --mode {mode} \
        --indir='input_imgs_gradio' \
        --generator_type={generator_type} \
        --outdir='test_runs' \
        --trunc={truncation} \
        --network {model_names_command} \
        --num_inv_steps={num_inversion_steps} \
        --down_src_eg3d_from_nvidia=False \
        --name_tag='_gradio_{run_tag}' \
        --shape=False \
        --w_frames 60
        """
            print(command)
            os.system(command)

        aligned_img_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{run_tag}/2_pose_result/*.png'))[0]
        aligned_img = Image.open(aligned_img_pth)

        # Stack every model's result image into one vertical grid.
        result_imgs = [
            read_image(f'test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/finetuned___{model_ckpts[_model_name]}__input_inv.png')
            for _model_name in all_model_names
        ]
        result_img = F.to_pil_image(make_grid(result_imgs, nrow=1))
    else:
        _ensure_finetuned_ckpt(model_ckpts[model_name])

        if not os.path.exists(f'test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/finetuned___{model_ckpts[model_name]}__input_inv.mp4'):
            w_pths = sorted(glob(f'test_runs/manip_3D_recon_gradio_{run_tag}/3_inversion_result/*.pt'))
            mode = 'manip' if len(w_pths) == 0 else 'manip_from_inv'

            # NOTE: every option line needs a trailing backslash so the shell
            # sees one command; the original dropped the one before
            # --w_frames, silently losing that option.
            command = f"""python datid3d_test.py --mode {mode} \
        --indir='input_imgs_gradio' \
        --generator_type={generator_type} \
        --outdir='test_runs' \
        --trunc={truncation} \
        --network finetuned/{model_ckpts[model_name]} \
        --num_inv_steps={num_inversion_steps} \
        --down_src_eg3d_from_nvidia=False \
        --name_tag='_gradio_{run_tag}' \
        --shape=False \
        --w_frames 60"""
            print(command)
            os.system(command)

        aligned_img_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{run_tag}/2_pose_result/*.png'))[0]
        aligned_img = Image.open(aligned_img_pth)

        result_img_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/*{model_ckpts[model_name]}*.png'))[0]
        result_img = Image.open(result_img_pth)

    if model_name == 'all':
        # Concatenate the 12 per-model videos into a single clip with ffmpeg.
        result_video_pth = f'test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/finetuned___ffhq-all__input_inv.mp4'
        if os.path.exists(result_video_pth):
            os.remove(result_video_pth)
        command = 'ffmpeg '
        for _model_name in all_model_names:
            command += f'-i test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/finetuned___ffhq-{_model_name}.pkl__input_inv.mp4 '
        command += '-filter_complex "[v0][v1][v2][v3][v4][v5][v6][v7][v8][v9][v10][v11]concat=n=12:v=1:a=0[output]"'
        # A labelled filtergraph output pad must be mapped explicitly,
        # otherwise ffmpeg aborts with an "unconnected output" error.
        command += ' -map "[output]"'
        command += f" -vcodec libx264 {result_video_pth}"
        print(command)
        os.system(command)
    else:
        result_video_pth = sorted(glob(f'test_runs/manip_3D_recon_gradio_{run_tag}/4_manip_result/*{model_ckpts[model_name]}*.mp4'))[0]

    return aligned_img, result_img, result_video_pth
|
|
|
|
def SampleImage(model_name, num_samples, truncation, seed):
    """Sample images from a fine-tuned generator and tile them in one grid.

    Parameters
    ----------
    model_name : str
        Key of ``model_ckpts`` selecting the fine-tuned checkpoint.
    num_samples : int
        Number of images to sample.
    truncation : float
        Truncation psi passed to the generator.
    seed : int
        Seed for the RNG that draws the per-image generator seeds.

    Returns
    -------
    PIL image
        All samples stacked vertically in a single grid image.
    """
    # Draw `num_samples` generator seeds deterministically from `seed` and
    # format them as the comma-separated list datid3d_test.py expects.
    # (The original loop shadowed the `seed` parameter and built the string
    # with += plus a trailing-comma slice.)
    seed_list = np.random.RandomState(seed).choice(np.arange(10000), num_samples).tolist()
    seeds = ','.join(str(s) for s in seed_list)

    # The three animal styles were fine-tuned from the cat generator.
    if model_name in ("fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"):
        generator_type = 'cat'
    else:
        generator_type = 'ffhq'

    # Download the checkpoint from HuggingFace on first use.
    if not os.path.exists(f'finetuned/{model_ckpts[model_name]}'):
        command = f"""wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{model_ckpts[model_name]} -O finetuned/{model_ckpts[model_name]}
        """
        os.system(command)

    command = f"""python datid3d_test.py --mode image \
        --generator_type={generator_type} \
        --outdir='test_runs' \
        --seeds={seeds} \
        --trunc={truncation} \
        --network=finetuned/{model_ckpts[model_name]} \
        --shape=False"""
    print(command)
    os.system(command)

    # Collect every sampled PNG for this checkpoint and tile them vertically.
    result_imgs = [read_image(pth)
                   for pth in sorted(glob(f'test_runs/image/*{model_ckpts[model_name]}*.png'))]
    return F.to_pil_image(make_grid(result_imgs, nrow=1))
|
|
|
|
|
|
|
|
def SampleVideo(model_name, grid_height, truncation, seed):
    """Render a video of a grid_height x grid_height grid of generator samples.

    Parameters
    ----------
    model_name : str
        Key of ``model_ckpts`` selecting the fine-tuned checkpoint.
    grid_height : int
        The video shows a grid_height x grid_height grid (grid_height**2 samples).
    truncation : float
        Truncation psi passed to the generator.
    seed : int
        Seed for the RNG that draws the per-sample generator seeds.

    Returns
    -------
    str
        Path of the generated mp4 file.
    """
    # One generator seed per grid cell, drawn deterministically from `seed`,
    # formatted as the comma-separated list datid3d_test.py expects.
    seed_list = np.random.RandomState(seed).choice(np.arange(10000), grid_height ** 2).tolist()
    seeds = ','.join(str(s) for s in seed_list)

    # The three animal styles were fine-tuned from the cat generator.
    if model_name in ("fox_in_Zootopia", "cat_in_Zootopia", "golden_aluminum_animal"):
        generator_type = 'cat'
    else:
        generator_type = 'ffhq'

    # Download the checkpoint from HuggingFace on first use.
    if not os.path.exists(f'finetuned/{model_ckpts[model_name]}'):
        command = f"""wget https://huggingface.co/gwang-kim/datid3d-finetuned-eg3d-models/resolve/main/finetuned_models/{model_ckpts[model_name]} -O finetuned/{model_ckpts[model_name]}
        """
        os.system(command)

    command = f"""python datid3d_test.py --mode video \
        --generator_type={generator_type} \
        --outdir='test_runs' \
        --seeds={seeds} \
        --trunc={truncation} \
        --grid={grid_height}x{grid_height} \
        --network=finetuned/{model_ckpts[model_name]} \
        --shape=False"""
    print(command)
    os.system(command)

    return sorted(glob(f'test_runs/video/*{model_ckpts[model_name]}*.mp4'))[0]
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="public url")
    args = parser.parse_args()

    demo = gr.Blocks(title="DATID-3D Interactive Demo")
    os.makedirs('finetuned', exist_ok=True)  # local checkpoint cache
    intermediate = Intermediate()  # state shared by all callback invocations
    with demo:
        gr.Markdown("# DATID-3D Interactive Demo")
        gr.Markdown(
            "### Demo of the CVPR 2023 paper \"DATID-3D: Diversity-Preserved Domain Adaptation Using Text-to-Image Diffusion for 3D Generative Model\"")

        # --- Tab 1: image -> manipulated 3D reconstruction -----------------
        with gr.Tab("Text-guided Manipulated 3D reconstruction"):
            gr.Markdown("Text-guided Image-to-3D Translation")
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    t_image_input = gr.Image(source='upload', type="pil", interactive=True)

                    t_model_name = gr.Radio(["super_mario", "lego", "neanderthal", "orc",
                                             "pixar", "skeleton", "stone_golem", "tekken",
                                             "greek_statue", "yoda", "zombie", "elf", "all"],
                                            label="Model fine-tuned through DATID-3D",
                                            value="super_mario", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        t_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0, step=0.01, randomize=False, value=0.8)
                        t_num_inversion_steps = gr.Slider(200, 1000, value=200, step=1, label='Number of steps for the inversion')
                    with gr.Row():
                        t_button_gen_result = gr.Button("Generate Result", variant='primary')

                    with gr.Row():
                        t_align_image_result = gr.Image(label="Alignment result", interactive=False)
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        t_video_result = gr.Video(label="Video result", interactive=False)

                    with gr.Row():
                        t_image_result = gr.Image(label="Image result", interactive=False)

        # --- Tab 2: sample still images from a fine-tuned model ------------
        with gr.Tab("Sample Images"):
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    i_model_name = gr.Radio(
                        ["elf", "greek_statue", "hobbit", "lego", "masquerade", "neanderthal", "orc", "pixar",
                         "skeleton", "stone_golem", "super_mario", "tekken", "yoda", "zombie", "fox_in_Zootopia",
                         "cat_in_Zootopia", "golden_aluminum_animal"],
                        label="Model fine-tuned through DATID-3D",
                        value="super_mario", interactive=True)
                    i_num_samples = gr.Slider(0, 20, value=4, step=1, label='Number of samples')
                    i_seed = gr.Slider(label="Seed", minimum=0, maximum=1000000000, step=1, value=1235)
                    with gr.Accordion("Advanced Options", open=False):
                        i_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0, step=0.01, randomize=False, value=0.8)
                    with gr.Row():
                        i_button_gen_image = gr.Button("Generate Image", variant='primary')
                with gr.Column(scale=1, variant='panel'):
                    with gr.Row():
                        i_image_result = gr.Image(label="Image result", interactive=False)

        # --- Tab 3: sample rotating-camera videos --------------------------
        with gr.Tab("Sample Videos"):
            with gr.Row():
                with gr.Column(scale=1, variant='panel'):
                    v_model_name = gr.Radio(
                        ["elf", "greek_statue", "hobbit", "lego", "masquerade", "neanderthal", "orc", "pixar",
                         "skeleton", "stone_golem", "super_mario", "tekken", "yoda", "zombie", "fox_in_Zootopia",
                         "cat_in_Zootopia", "golden_aluminum_animal"],
                        label="Model fine-tuned through DATID-3D",
                        value="super_mario", interactive=True)
                    v_grid_height = gr.Slider(0, 5, value=2, step=1, label='Height of the grid')
                    v_seed = gr.Slider(label="Seed", minimum=0, maximum=1000000000, step=1, value=1235)
                    with gr.Accordion("Advanced Options", open=False):
                        v_truncation = gr.Slider(label="Truncation psi", minimum=0, maximum=1.0, step=0.01, randomize=False,
                                                 value=0.8)

                    with gr.Row():
                        v_button_gen_video = gr.Button("Generate Video", variant='primary')

                with gr.Column(scale=1, variant='panel'):

                    with gr.Row():
                        v_video_result = gr.Video(label="Video result", interactive=False)

        # Wire the buttons to their callbacks; `partial` injects the shared
        # per-process state into the reconstruction callback.
        t_button_gen_result.click(fn=partial(TextGuidedImageTo3D, intermediate),
                                  inputs=[t_image_input, t_model_name, t_num_inversion_steps, t_truncation],
                                  outputs=[t_align_image_result, t_image_result, t_video_result])
        i_button_gen_image.click(fn=SampleImage,
                                 inputs=[i_model_name, i_num_samples, i_truncation, i_seed],
                                 outputs=[i_image_result])
        # BUG FIX: this previously passed i_model_name (the *image* tab's
        # radio), so the video tab's own model selection was ignored.
        v_button_gen_video.click(fn=SampleVideo,
                                 inputs=[v_model_name, v_grid_height, v_truncation, v_seed],
                                 outputs=[v_video_result])

    demo.queue(concurrency_count=1)  # serialize jobs: one GPU-heavy run at a time
    demo.launch(share=args.share)
|
|
|
|