| import os |
| import wget |
| import subprocess |
| import sys |
| import torch |
|
|
def pytorch3d_wheel_tag(python_minor, cuda_version, torch_version):
    """Return the directory tag used in the prebuilt PyTorch3D wheel URL.

    The tag has the form ``py3<minor>_cu<cuda digits>_pyt<torch digits>``;
    e.g. Python 3.9, CUDA "11.7" and torch "1.13.1+cu117" give
    ``py39_cu117_pyt1131``.

    Args:
        python_minor: Minor version number of the running Python 3 interpreter.
        cuda_version: CUDA version string, as reported by ``torch.version.cuda``.
        torch_version: Torch version string, as reported by
            ``torch.__version__``; anything after a ``+`` is ignored.

    Returns:
        The tag string.
    """
    pyt_digits = torch_version.split("+")[0].replace(".", "")
    cuda_digits = cuda_version.replace(".", "")
    return f"py3{python_minor}_cu{cuda_digits}_pyt{pyt_digits}"


# When the SYSTEM environment variable is 'spaces', install a prebuilt
# PyTorch3D wheel matching the running Python / CUDA / torch combination.
if os.getenv('SYSTEM') == 'spaces':
    if torch.version.cuda is None:
        # CPU-only torch builds report no CUDA version, so no wheel tag can be
        # formed; skip instead of failing with an AttributeError.
        print('torch reports no CUDA version; skipping prebuilt pytorch3d install.')
    else:
        version_str = pytorch3d_wheel_tag(
            sys.version_info.minor, torch.version.cuda, torch.__version__)
        # Run pip through the current interpreter so the wheel lands in the
        # environment this script is actually running in. The return code is
        # deliberately not checked (best-effort install, as before).
        subprocess.run([
            sys.executable, '-m', 'pip', 'install',
            '--no-index', '--no-cache-dir', 'pytorch3d',
            '-f',
            f'https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html',
        ])
|
|
|
|
| import argparse |
| import gradio as gr |
| from functools import partial |
| from my.config import BaseConf, dispatch_gradio |
| from run_3DFuse import SJC_3DFuse |
| import numpy as np |
| from PIL import Image |
| from pc_project import point_e |
| from diffusers import UnCLIPPipeline, DiffusionPipeline |
| from pc_project import point_e_gradio |
| import numpy as np |
| import plotly.graph_objs as go |
| from my.utils.seed import seed_everything |
|
|
# Markdown/HTML banner rendered at the top of the demo page. It contains no
# interpolated values, so it is written as plain adjacent string literals.
SHARED_UI_WARNING = (
    '### [NOTE] Training may be very slow in this shared UI.\n'
    'You can duplicate and use it with a paid private GPU.\n'
    '<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/jyseo/3DFuse?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-xl-dark.svg" alt="Duplicate Space"></a>\n'
    'Alternatively, you can also use the Colab demo on our project page.\n'
    '<a style="display:inline-block" href="https://ku-cvlab.github.io/3DFuse/"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/Project%20Page-online-brightgreen"></a>\n'
)
|
|
class Intermediate:
    """Mutable holder for state shared between the demo's callbacks.

    Attributes:
        images: Initial image list stored by the point-cloud callbacks;
            ``None`` until one of them has run.
        points: Point cloud stored by the point-cloud callbacks; ``None``
            until one of them has run.
        is_generating: Flag toggled by ``gen_3d`` around a generation run.
    """

    def __init__(self):
        self.is_generating = False
        # Nothing has been generated yet.
        self.images = self.points = None
|
|
|
|
def gen_3d(model, intermediate, prompt, keyword, seed, ti_step, pt_step):
    """Run 3D generation on the stored point cloud, streaming results.

    This is a generator function: every item produced by the model's
    ``run_gradio`` is yielded unchanged.

    Args:
        model: Ignored; a fresh model is built on every call. Kept so that
            existing ``partial(gen_3d, model, intermediate)`` bindings work.
        intermediate: State holder whose ``images`` and ``points`` were set by
            a point-cloud callback; its ``is_generating`` flag is True only
            while the run is in progress.
        prompt: Text prompt forwarded to ``dispatch_gradio``.
        keyword: Keyword forwarded to ``dispatch_gradio``.
        seed: Seed passed to ``seed_everything`` and ``dispatch_gradio``.
        ti_step: Forwarded to ``dispatch_gradio``.
        pt_step: Forwarded to ``dispatch_gradio``.

    Raises:
        gr.Error: If no images / point cloud have been generated yet.
    """
    images, points = intermediate.images, intermediate.points
    if images is None or points is None:
        raise gr.Error("Please generate point cloud first")

    # Set the flag only after validation and always clear it, so a failed or
    # abandoned run cannot leave it stuck at True.
    intermediate.is_generating = True
    try:
        seed_everything(seed)
        model = dispatch_gradio(SJC_3DFuse, prompt, keyword, ti_step, pt_step, seed)
        yield from model.run_gradio(points, images)
    finally:
        intermediate.is_generating = False
| |
|
|
|
|
def _check_prompt_keyword(prompt, keyword):
    """Raise ``gr.Error`` unless ``keyword`` is one word contained in ``prompt``."""
    if keyword not in prompt:
        raise gr.Error("Prompt should contain keyword!")
    # An empty keyword is a substring of every prompt, so reject it explicitly.
    if not keyword or " " in keyword:
        raise gr.Error("Keyword should be one word!")


def _hidden_axis():
    """Return Plotly axis settings that hide title, grid, lines and ticks."""
    return dict(
        title="",
        showgrid=False,
        zeroline=False,
        showline=False,
        ticks='',
        showticklabels=False,
    )


def _point_cloud_figure(points):
    """Build a 3D scatter figure from the first three columns of ``points.coords``."""
    coords = np.array(points.coords)
    trace = go.Scatter3d(x=coords[:, 0], y=coords[:, 1], z=coords[:, 2],
                         mode='markers', marker=dict(size=2))
    layout = go.Layout(
        scene=dict(
            xaxis=_hidden_axis(),
            yaxis=_hidden_axis(),
            zaxis=_hidden_axis(),
        ),
        margin=dict(l=0, r=0, b=0, t=0),
        showlegend=False,
    )
    return go.Figure(data=[trace], layout=layout)


def _whiten_background(image):
    """Return a new image in which carvekit-masked-out pixels are set to white."""
    from carvekit.api.high import HiInterface
    interface = HiInterface(object_type="object",
                            batch_size_seg=5,
                            batch_size_matting=1,
                            device='cuda' if torch.cuda.is_available() else 'cpu',
                            seg_mask_size=640,
                            matting_mask_size=2048,
                            trimap_prob_threshold=231,
                            trimap_dilation=30,
                            trimap_erosion_iters=5,
                            fp16=False)

    img_without_background = interface([image])
    # NOTE(review): assigning an RGB triple through the mask presumes the mask
    # selects whole pixels of a 3-channel image -- confirm the carvekit output shape.
    mask = np.array(img_without_background[0]) > 127
    pixels = np.array(image)
    pixels[~mask] = [255., 255., 255.]
    return Image.fromarray(pixels)


def gen_pc_from_prompt(intermediate, num_initial_image, prompt, keyword, type, bg_preprocess, seed):
    """Generate initial images from a prompt and infer a point cloud from the first.

    Args:
        intermediate: State holder; receives the generated ``images`` and ``points``.
        num_initial_image: Number of initial images requested from ``gen_init``.
        prompt: Text prompt; must contain ``keyword``.
        keyword: Single non-empty word contained in ``prompt``.
        type: Backbone choice forwarded to ``gen_init`` (the name shadows the
            builtin but is part of the existing call interface).
        bg_preprocess: Forwarded to ``gen_init``.
        seed: Seed for ``seed_everything`` and ``gen_init``.

    Returns:
        Tuple of (first generated image, point-cloud figure, a ``gr.update``
        that makes the receiving component interactive).

    Raises:
        gr.Error: If the keyword is empty, contains a space, or is not in the prompt.
    """
    _check_prompt_keyword(prompt, keyword)
    seed_everything(seed=seed)

    images = gen_init(num_initial_image=num_initial_image, prompt=prompt, seed=seed,
                      type=type, bg_preprocess=bg_preprocess)
    points = point_e_gradio(images[0], 'cuda')

    intermediate.images = images
    intermediate.points = points

    return images[0], _point_cloud_figure(points), gr.update(interactive=True)


def gen_pc_from_image(intermediate, image, prompt, keyword, bg_preprocess, seed):
    """Infer a point cloud from a user-supplied image.

    Args:
        intermediate: State holder; receives ``[image]`` and the inferred ``points``.
        image: Input PIL image (``None`` when nothing was uploaded).
        prompt: Text prompt; must contain ``keyword``.
        keyword: Single non-empty word contained in ``prompt``.
        bg_preprocess: When true, background pixels are replaced with white first.
        seed: Seed for ``seed_everything``.

    Returns:
        Tuple of (the possibly preprocessed image, point-cloud figure, a
        ``gr.update`` that makes the receiving component interactive).

    Raises:
        gr.Error: If no image was provided, or the keyword is empty, contains
            a space, or is not in the prompt.
    """
    if image is None:
        raise gr.Error("Please upload an image first")
    _check_prompt_keyword(prompt, keyword)
    seed_everything(seed=seed)

    if bg_preprocess:
        image = _whiten_background(image)

    points = point_e_gradio(image, 'cuda')

    intermediate.images = [image]
    intermediate.points = points

    return image, _point_cloud_figure(points), gr.update(interactive=True)
|
|
def gen_init(num_initial_image, prompt, seed, type="Karlo", bg_preprocess=False):
    """Generate initial view images for ``prompt`` with a text-to-image pipeline.

    Args:
        num_initial_image: Number of images to generate; view prefixes cycle
            through front / overhead / side / back.
        prompt: Text prompt appended to each view prefix.
        seed: Seed for the CUDA random generator shared by all images.
        type: Backbone choice. Any value starting with "Karlo" (including the
            UI label "Karlo (Recommended)") selects Karlo; anything else
            selects Stable Diffusion v1.5. The name shadows the builtin but is
            part of the existing call interface.
        bg_preprocess: When true, background pixels of every generated image
            are replaced with white using carvekit.

    Returns:
        List of generated PIL images, in generation order.
    """
    # Match on the prefix so the signature default "Karlo" selects Karlo too;
    # an exact comparison with the UI label sent it to Stable Diffusion.
    if type.startswith("Karlo"):
        pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    else:
        pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    pipe = pipe.to('cuda')

    view_prompt = ["front view of ", "overhead view of ", "side view of ", "back view of "]

    interface = None
    if bg_preprocess:
        from carvekit.api.high import HiInterface
        interface = HiInterface(object_type="object",
                                batch_size_seg=5,
                                batch_size_matting=1,
                                device='cuda' if torch.cuda.is_available() else 'cpu',
                                seg_mask_size=640,
                                matting_mask_size=2048,
                                trimap_prob_threshold=231,
                                trimap_dilation=30,
                                trimap_erosion_iters=5,
                                fp16=False)

    images = []
    # manual_seed needs an integer; coerce in case the UI hands over a float.
    generator = torch.Generator(device='cuda').manual_seed(int(seed))
    for i in range(num_initial_image):
        view = view_prompt[i % len(view_prompt)]
        # Only the first image gets the white-background suffix, regardless
        # of bg_preprocess (same as before).
        suffix = ", white background" if i == 0 else ""
        prompt_ = f"{view}{prompt}{suffix}"

        image = pipe(prompt_, generator=generator).images[0]

        if interface is not None:
            img_without_background = interface([image])
            # NOTE(review): assigning an RGB triple through the mask presumes
            # the mask selects whole pixels of a 3-channel image -- confirm.
            mask = np.array(img_without_background[0]) > 127
            pixels = np.array(image)
            pixels[~mask] = [255., 255., 255.]
            image = Image.fromarray(pixels)
        images.append(image)

    return images
| |
| |
|
|
if __name__ == '__main__':
    # Script entry point: parse CLI flags, make sure the checkpoint is on
    # disk, build the Gradio UI, wire the callbacks and launch the app.
    parser = argparse.ArgumentParser()
    parser.add_argument('--share', action='store_true', help="public url")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    weights_dir = './weights'
    os.makedirs(weights_dir, exist_ok=True)
    weights_path = os.path.join(weights_dir, '3DFuse_sparse_depth_injector.ckpt')

    # Download the checkpoint only when it is not already present.
    if not os.path.isfile(weights_path):
        url = 'https://huggingface.co/jyseo/3DFuse_weights/resolve/main/models/3DFuse_sparse_depth_injector.ckpt'
        wget.download(url, weights_path)
        print(f'{weights_path} downloaded.')
    else:
        print(f'{weights_path} already exists.')

    # `model` is only a placeholder bound into the gen_3d partials below.
    model = None
    intermediate = Intermediate()
    demo = gr.Blocks(title="3DFuse Interactive Demo")

    # NOTE(review): the nesting of the layout below was reconstructed --
    # confirm it against the rendered UI.
    with demo:
        with gr.Box():
            gr.Markdown(SHARED_UI_WARNING)

        gr.Markdown("# 3DFuse Interactive Demo")
        gr.Markdown("### Official Implementation of the paper \"Let 2D Diffusion Model Know 3D-Consistency for Robust Text-to-3D Generation\"")
        gr.Markdown("Enter your own prompt and enjoy! With this demo, you can preview the point cloud before 3D generation and determine the desired shape.")

        with gr.Row():
            # Left panel: inputs (two tabs) and the point-cloud preview.
            with gr.Column(scale=1., variant='panel'):

                with gr.Tab("Text to 3D"):
                    prompt_input = gr.Textbox(label="Prompt", value="a comfortable bed", interactive=True)
                    word_input = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="bed", interactive=True)
                    semantic_model_choice = gr.Radio(["Karlo (Recommended)", "Stable Diffusion"], label="Backbone for initial image generation", value="Karlo (Recommended)", interactive=True)
                    seed = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice = gr.Checkbox(True, label="Preprocess intially-generated image by removing background", interactive=True)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        # Starts disabled; the point-cloud callback enables it.
                        button_gen_3d = gr.Button("2. Generate 3D", interactive=False, variant='primary')

                with gr.Tab("Image to 3D"):
                    image_input = gr.Image(source='upload', type="pil", interactive=True)
                    prompt_input_2 = gr.Textbox(label="Prompt", value="a dog", interactive=True)
                    word_input_2 = gr.Textbox(label="Keyword for initialization (should be a noun included in the prompt)", value="dog", interactive=True)
                    seed_2 = gr.Slider(label="Seed", minimum=0, maximum=2100000000, step=1, randomize=True)
                    preprocess_choice_2 = gr.Checkbox(True, label="Preprocess intially-generated image by removing background", interactive=False)
                    with gr.Accordion("Advanced Options", open=False):
                        opt_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of text embedding optimization step')
                        pivot_step_2 = gr.Slider(0, 1000, value=500, step=1, label='Number of pivotal tuning step for LoRA')
                    with gr.Row():
                        button_gen_pc_2 = gr.Button("1. Generate Point Cloud", interactive=True, variant='secondary')
                        # Starts disabled; the point-cloud callback enables it.
                        button_gen_3d_2 = gr.Button("2. Generate 3D", interactive=False, variant='primary')
                    gr.Markdown("Note: A photo showing the entire object in a front view is recommended. Also, our framework is not designed with the goal of single shot reconstruction, so it may be difficult to reflect the details of the given image.")

                with gr.Row(scale=1.):
                    with gr.Column(scale=1.):
                        pc_plot = gr.Plot(label="Inferred point cloud")
                    with gr.Column(scale=1.):
                        init_output = gr.Image(label='Generated initial image', interactive=False)

            # Right panel: outputs of the 3D generation run.
            with gr.Column(scale=1., variant='panel'):
                with gr.Row():
                    with gr.Column(scale=1.):
                        intermediate_output = gr.Image(label="Intermediate Output", interactive=False)
                    with gr.Column(scale=1.):
                        logs = gr.Textbox(label="logs", lines=8, max_lines=8, interactive=False)
                with gr.Row():
                    video_result = gr.Video(label="Video result for generated 3D", interactive=False)

        gr.Markdown("Note: Keyword is used for Textual Inversion. Please choose one important noun included in the prompt. This demo may be slower than directly running run_3DFuse.py.")

        # Text tab: always request 4 initial images.
        button_gen_pc.click(
            fn=partial(gen_pc_from_prompt, intermediate, 4),
            inputs=[prompt_input, word_input, semantic_model_choice, preprocess_choice, seed],
            outputs=[init_output, pc_plot, button_gen_3d])
        button_gen_3d.click(
            fn=partial(gen_3d, model, intermediate),
            inputs=[prompt_input, word_input, seed, opt_step, pivot_step],
            outputs=[intermediate_output, logs, video_result])

        # Image tab.
        button_gen_pc_2.click(
            fn=partial(gen_pc_from_image, intermediate),
            inputs=[image_input, prompt_input_2, word_input_2, preprocess_choice_2, seed_2],
            outputs=[init_output, pc_plot, button_gen_3d_2])
        button_gen_3d_2.click(
            fn=partial(gen_3d, model, intermediate),
            inputs=[prompt_input_2, word_input_2, seed_2, opt_step_2, pivot_step_2],
            outputs=[intermediate_output, logs, video_result])

    demo.queue(concurrency_count=1)
    demo.launch(share=args.share)
| |
| |
| |
| |
| |
|
|