Spaces:
Build error
Build error
| import os | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import torch | |
| import json | |
| import glob | |
| import numpy as np | |
| from PIL import Image | |
| import time | |
| import copy | |
| import sys | |
| # Mesh imports | |
| from pytorch3d.io import load_objs_as_meshes | |
| from pytorch3d.vis.plotly_vis import AxisArgs, plot_scene | |
| from pytorch3d.transforms import RotateAxisAngle, Translate | |
| from sampling_for_demo import load_and_return_model_and_data, sample, load_base_model | |
# Run on the GPU when available; every mesh, transform, and model below is created on this device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def transform_mesh(mesh, transform, scale=1.0):
    """Return a transformed copy of *mesh*, leaving the input untouched.

    Args:
        mesh: a pytorch3d ``Meshes`` object (cloned before editing).
        transform: an object exposing ``transform_points`` (e.g. a pytorch3d
            ``RotateAxisAngle`` or ``Translate``).
        scale: uniform factor applied to the vertices before the transform.

    Returns:
        A new mesh whose vertices are ``transform(verts * scale)``.
    """
    moved = mesh.clone()
    base_verts = moved.verts_packed()
    target_verts = transform.transform_points(base_verts * scale)
    # offset_verts_ adds a per-vertex delta in place, so pass the displacement
    # from the current (untransformed) vertices to the target positions.
    moved.offset_verts_(target_verts - base_verts)
    return moved
def get_input_pose_fig(category=None):
    """Build the interactive plotly figure showing the object mesh, a ground
    plane, and the currently stored camera pose.

    Reads the module globals ``curr_camera_dict`` (camera/aspect layout),
    ``obj_filename`` (object mesh path) and ``plane_trans`` (vertical offset
    of the ground plane).

    Args:
        category: object category name; only ``"teddybear"`` changes behavior
            (the ground plane gets an extra 15-degree tilt about X).

    Returns:
        A 512x512 plotly figure with the object plotted via pytorch3d's
        ``plot_scene`` and the ground plane appended as a separate trace.
    """
    global curr_camera_dict
    global obj_filename
    global plane_trans
    plane_filename = 'assets/plane.obj'
    mesh_scale = 0.75
    # Load the object mesh and scale it in place to fit the [-1, 1] axis range.
    mesh = load_objs_as_meshes([obj_filename], device=device)
    mesh.scale_verts_(mesh_scale)
    plane = load_objs_as_meshes([plane_filename], device=device)
    ### plane
    # Fixed 90-degree rotation about X — NOTE(review): presumably the plane OBJ
    # is authored vertically and this lays it flat; confirm against the asset.
    rotate_x = RotateAxisAngle(angle=90.0, axis='X', device=device)
    plane = transform_mesh(plane, rotate_x)
    if category == "teddybear":
        # Extra tilt used only by the teddybear preset.
        rotate_teddy = RotateAxisAngle(angle=15.0, axis='X', device=device)
        plane = transform_mesh(plane, rotate_teddy)
    # Shift the plane along Y by the per-category offset (scaled like the mesh).
    translate_y = Translate(0, plane_trans * mesh_scale, 0, device=device)
    plane = transform_mesh(plane, translate_y)
    # Plot only the object here; the plane is added manually below so it can be
    # styled (gray, translucent, hover disabled).
    fig = plot_scene({
        "plot": {
            "object": mesh,
        },
    },
        axis_args=AxisArgs(showgrid=True, backgroundcolor='#cccde0'),
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        zaxis=dict(range=[-1, 1])
    )
    # Move the plane to CPU so its tensors can be handed to plotly.
    plane = plane.detach().cpu()
    verts = plane.verts_packed()
    faces = plane.faces_packed()
    fig.add_trace(
        go.Mesh3d(
            x=verts[:, 0],
            y=verts[:, 1],
            z=verts[:, 2],
            i=faces[:, 0],
            j=faces[:, 1],
            k=faces[:, 2],
            opacity=0.7,
            color='gray',
            hoverinfo='skip',
        ),
    )
    print("fig: curr camera dict")
    print(curr_camera_dict)
    camera_dict = curr_camera_dict
    # Make the axes and tick labels visible.
    fig.update_layout(scene=dict(
        xaxis=dict(showticklabels=True, visible=True),
        yaxis=dict(showticklabels=True, visible=True),
        zaxis=dict(showticklabels=True, visible=True),
    ))
    # show grid
    fig.update_layout(scene=dict(
        xaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        yaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        zaxis=dict(showgrid=True, gridwidth=1, gridcolor='black'),
        bgcolor='#dedede',
    ))
    # Apply the stored camera/aspect settings last so they take precedence.
    fig.update_layout(
        camera_dict,
        width=512, height=512,
    )
    return fig
def run_inference(cam_pose_json, prompt, scale_im, scale, steps, seed):
    """Generate one image with the currently loaded model for the given
    camera pose and prompt.

    Args:
        cam_pose_json: camera pose as a JSON string (forwarded to ``sample``).
        prompt: text prompt (expects the ``<new1>`` custom token).
        scale_im: image guidance scale.
        scale: text guidance scale.
        steps: number of inference steps.
        seed: sampling seed.

    Returns:
        A PIL image built from the first sampled result.
    """
    global current_data, current_model
    print("prompt is ", prompt)
    # Draw a single sample from the loaded delta model.
    images = sample(
        current_model,
        current_data,
        num_images=1,
        prompt=prompt,
        appendpath="",
        camera_json=cam_pose_json,
        train=False,
        scale=scale,
        scale_im=scale_im,
        beta=1.0,
        num_ref=8,
        skipreflater=False,
        num_steps=steps,
        valid=False,
        max_images=20,
        seed=seed,
    )
    tensor = images[0]
    print(tensor.shape)
    # Map the [-1, 1] CHW tensor to an 8-bit HWC array, then to a PIL image.
    hwc = ((tensor + 1.0) / 2.0).permute(1, 2, 0).cpu().numpy()
    pixels = (np.clip(hwc, 0., 1.) * 255).astype(np.uint8)
    result = Image.fromarray(pixels)
    print('result obtained')
    return result
def update_curr_camera_dict(camera_json):
    """Normalize the camera JSON sent by the front end and store it as the
    current camera dict.

    Args:
        camera_json: camera state as a (possibly single-quoted) JSON string,
            or ``None`` to fall back to ``prev_camera_dict``.

    Returns:
        The normalized (double-quoted) JSON string that was parsed.
    """
    # TODO: this does not always update the figure, also there's always flashes
    global curr_camera_dict
    global prev_camera_dict
    if camera_json is None:
        # Nothing arrived from the front end: reuse the last known pose.
        camera_json = json.dumps(prev_camera_dict)
    # The JS side may serialize with single quotes; json requires double quotes.
    normalized = camera_json.replace("'", "\"")
    curr_camera_dict = json.loads(normalized)  # ["scene.camera"]
    print("update curr camera dict")
    print(curr_camera_dict)
    return normalized
# Root directory holding one subdirectory per fine-tuned (delta) checkpoint.
MODELS_DIR = "pretrained-models/"
def select_and_load_model(category, category_single_id):
    """Load the fine-tuned (delta) checkpoint for the selected category/id.

    Frees the previously loaded model, deep-copies the shared ``base_model``,
    and patches it with the matching delta checkpoint, updating the globals
    ``current_model`` and ``current_data``.

    Args:
        category: one of the category names shown in the UI dropdown.
        category_single_id: object id within the category.

    Returns:
        Tuple ``(status_markdown, default_prompt)`` for the Gradio outputs.

    Raises:
        FileNotFoundError: if no checkpoint matches the selection.
    """
    global current_data, current_model, base_model
    # Drop the old model/data before deep-copying the base model so two full
    # models are not resident on the GPU at the same time.
    del current_model
    del current_data
    torch.cuda.empty_cache()
    current_model = copy.deepcopy(base_model)
    ### choose model checkpoint and config
    pattern = f"{MODELS_DIR}/*{category}{category_single_id}*/checkpoints/step=*.ckpt"
    matches = glob.glob(pattern)
    if not matches:
        # Fail with a clear message instead of an opaque IndexError.
        raise FileNotFoundError(f"No checkpoint found matching {pattern}")
    delta_ckpt = matches[0]
    print(f"Loading model from {delta_ckpt}")
    logdir = delta_ckpt.split('/checkpoints')[0]
    # Pick the most recent config written next to the checkpoint.
    config = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))[-1]
    start_time = time.time()
    current_model, current_data = load_and_return_model_and_data(
        config, current_model, delta_ckpt=delta_ckpt
    )
    print(f"Time taken to load delta model: {time.time() - start_time:.2f}s")
    print("!!! model loaded")
    # Default prompt per category; the .get() fallback prevents the original
    # UnboundLocalError when an unexpected category slips through.
    default_prompts = {
        "car": "A <new1> car parked by a snowy mountain range",
        "chair": "A <new1> chair in a garden surrounded by flowers",
        "motorcycle": "A <new1> motorcycle beside a calm lake",
        "teddybear": "A <new1> teddy bear on the sand at the beach",
    }
    input_prompt = default_prompts.get(category, f"A <new1> {category}")
    return "### Model loaded!", input_prompt
### ---- Module-level state shared by the Gradio callbacks ---- ###
# NOTE(review): `global` at module scope is a no-op in Python; these statements
# are kept from the original, apparently to flag which names the callbacks
# mutate.
global current_data
global current_model
# Model/data pair for the currently selected category; populated by
# select_and_load_model().
current_data = None
current_model = None
global base_model
# Base SDXL config and weights that each delta checkpoint is applied on top of.
BASE_CONFIG = "custom-diffusion360/configs/train_co3d_concept.yaml"
BASE_CKPT = "pretrained-models/sd_xl_base_1.0.safetensors"
base_model = None
ORIGINAL_SPACE_ID = "customdiffusion360/customdiffusion360"
SPACE_ID = os.getenv("SPACE_ID")
# Skip the expensive base-model load on the original Space — per the UI text,
# users must duplicate the Space (and upgrade the GPU) to actually run it.
if SPACE_ID != ORIGINAL_SPACE_ID:
    start_time = time.time()
    base_model = load_base_model(BASE_CONFIG, ckpt=BASE_CKPT, verbose=False)
    print(f"Time taken to load base model: {time.time() - start_time:.2f}s")
global curr_camera_dict
# Initial plotly camera/aspect layout (matches the "car" preset in
# update_category_single_id).
curr_camera_dict = {
    "scene.camera": {
        "up": {"x": -0.13227683305740356,
               "y": -0.9911391735076904,
               "z": -0.013464212417602539},
        "center": {"x": -0.005292057991027832,
                   "y": 0.020704858005046844,
                   "z": 0.0873757004737854},
        "eye": {"x": 0.8585731983184814,
                "y": -0.08790968358516693,
                "z": -0.40458938479423523},
    },
    "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
    "scene.aspectmode": "manual"
}
global prev_camera_dict
# Last known-good camera dict; update_curr_camera_dict falls back to this when
# the front end sends no camera JSON.
prev_camera_dict = copy.deepcopy(curr_camera_dict)
global obj_filename
# Mesh shown in the pose-visualization plot (car preset by default).
obj_filename = "assets/car0_mesh_centered_flipped.obj"
global plane_trans
# Vertical offset of the ground plane under the mesh (car preset value).
plane_trans = 0.16
# Build the initial pose figure and load the front-end helper script.
my_fig = get_input_pose_fig()
scripts = open("scripts.js", "r").read()
def update_category_single_id(category):
    """Switch the demo to *category*: update the camera pose, ground-plane
    offset, and mesh-path globals, and refresh the Object-ID dropdown.

    Args:
        category: one of "car", "chair", "motorcycle", "teddybear".

    Returns:
        A ``gr.Dropdown`` listing the object ids for the category.

    Raises:
        ValueError: for an unknown category (the original left ``choices`` as
        ``None`` and crashed later with an obscure TypeError).
    """
    global curr_camera_dict
    global prev_camera_dict
    global obj_filename
    global plane_trans
    # Per-category presets: (object ids, initial plotly camera layout,
    # vertical ground-plane offset). Built fresh on every call so the globals
    # never alias a shared literal.
    presets = {
        "car": (
            ["0"],
            {
                "scene.camera": {
                    "up": {"x": -0.13227683305740356,
                           "y": -0.9911391735076904,
                           "z": -0.013464212417602539},
                    "center": {"x": -0.005292057991027832,
                               "y": 0.020704858005046844,
                               "z": 0.0873757004737854},
                    "eye": {"x": 0.8585731983184814,
                            "y": -0.08790968358516693,
                            "z": -0.40458938479423523},
                },
                "scene.aspectratio": {"x": 1.974, "y": 1.974, "z": 1.974},
                "scene.aspectmode": "manual"
            },
            0.16,
        ),
        "chair": (
            ["191"],
            {
                "scene.camera": {
                    "up": {"x": 1.0477e-04,
                           "y": -9.9995e-01,
                           "z": 1.0288e-02},
                    "center": {"x": 0.0539,
                               "y": 0.0015,
                               "z": 0.0007},
                    "eye": {"x": 0.0410,
                            "y": -0.0091,
                            "z": -0.9991},
                },
                "scene.aspectratio": {"x": 0.9084, "y": 0.9084, "z": 0.9084},
                "scene.aspectmode": "manual"
            },
            0.38,
        ),
        "motorcycle": (
            ["12"],
            {
                "scene.camera": {
                    "up": {"x": 0.0308,
                           "y": -0.9994,
                           "z": -0.0147},
                    "center": {"x": 0.0240,
                               "y": -0.0310,
                               "z": -0.0016},
                    "eye": {"x": -0.0580,
                            "y": -0.0188,
                            "z": -0.9981},
                },
                "scene.aspectratio": {"x": 1.5786, "y": 1.5786, "z": 1.5786},
                "scene.aspectmode": "manual"
            },
            0.2,
        ),
        "teddybear": (
            ["31"],
            {
                "scene.camera": {
                    "up": {"x": 0.4304,
                           "y": -0.9023,
                           "z": -0.0221},
                    "center": {"x": -0.0658,
                               "y": 0.2081,
                               "z": 0.0175},
                    "eye": {"x": -0.4456,
                            "y": 0.0493,
                            "z": -0.8939},
                },
                "scene.aspectratio": {"x": 1.8052, "y": 1.8052, "z": 1.8052},
                "scene.aspectmode": "manual",
            },
            0.3,
        ),
    }
    if category not in presets:
        raise ValueError(f"Unknown category: {category!r}")
    choices, curr_camera_dict, plane_trans = presets[category]
    obj_filename = f"assets/{category}{choices[0]}_mesh_centered_flipped.obj"
    prev_camera_dict = copy.deepcopy(curr_camera_dict)
    return gr.Dropdown(choices=choices, label="Object ID", value=choices[0])
# Preload plotly from the CDN so the camera-pose figure can render in the page.
head = """
<script src="https://cdn.plot.ly/plotly-2.30.0.min.js" charset="utf-8"></script>
"""
# ---- Gradio UI: layout and event wiring ----
with gr.Blocks(head=head,
               css="style.css",
               js=scripts,
               title="Customizing Text-to-Image Diffusion with Camera Viewpoint Control") as demo:
    # Static header: title, badge links, and duplicate-Space instructions.
    gr.HTML("""
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<div>
<h2><a href='https://customdiffusion360.github.io/index.html'>Customizing Text-to-Image Diffusion with Camera Viewpoint Control</a></h2>
</div>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href='https://customdiffusion360.github.io/index.html' style="padding: 10px;">
<img src='https://img.shields.io/badge/Project%20Page-8A2BE2'>
</a>
<a href='https://arxiv.org/abs/2404.12333'>
<img src="https://img.shields.io/badge/arXiv-2404.12333-red">
</a>
<a class="link" href='https://github.com/customdiffusion360/custom-diffusion360' style="padding: 10px;">
<img src='https://img.shields.io/badge/Github-%23121011.svg'>
</a>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<p>
This is a demo for <a href='https://github.com/customdiffusion360/custom-diffusion360'>Custom Diffusion 360</a>.
Please duplicate this space and upgrade the GPU to A10G Large in Settings to run the demo.
</p>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/customdiffusion360/customdiffusion360?duplicate=true"><img style="margin-top:0;margin-bottom:0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>
</div>
<hr></hr>
""",
        visible=True
    )
    with gr.Row():
        with gr.Column(min_width=150):
            # Step 1: pick a category (and object id) and load its checkpoint.
            gr.Markdown("## 1. SELECT CUSTOMIZED MODEL")
            category = gr.Dropdown(choices=["car", "chair", "motorcycle", "teddybear"], label="Category", value="car")
            category_single_id = gr.Dropdown(label="Object ID", choices=["0"], type="value", value="0", visible=False)
            category.change(update_category_single_id, [category], [category_single_id])
            load_model_btn = gr.Button(value="Load Model", elem_id="load_model_button")
            load_model_status = gr.Markdown(elem_id="load_model_status", value="### Please select and load a model.")
        with gr.Column(min_width=512):
            # Step 2: interactive 3D plot used to choose the camera pose.
            gr.Markdown("## 2. CAMERA POSE VISUALIZATION")
            # TODO ? don't use gradio plotly element so we can remove menu buttons
            map = gr.Plot(value=my_fig, min_width=512, elem_id="map")
            ### hidden elements
            # NOTE(review): these invisible widgets appear to be driven by
            # scripts.js (via their elem_ids) to push the plotly camera state
            # back into Gradio — confirm against scripts.js.
            update_pose_btn = gr.Button(value="Update Camera Pose", visible=False, elem_id="update_pose_button")
            input_pose = gr.TextArea(value=curr_camera_dict, label="Input Camera Pose", visible=False, elem_id="input_pose", interactive=False)
            check_pose_btn = gr.Button(value="Check Camera Pose", visible=False, elem_id="check_pose_button")
            ## TODO: track init_camera_dict and with js?
            ### visible elements
            input_prompt = gr.Textbox(value="A <new1> car parked by a snowy mountain range", label="Prompt", interactive=True)
            scale_im = gr.Slider(value=3.5, label="Image guidance scale", minimum=0, maximum=20.0, step=0.1)
            scale = gr.Slider(value=7.5, label="Text guidance scale", minimum=0, maximum=20.0, step=0.1)
            steps = gr.Slider(value=10, label="Inference steps", minimum=1, maximum=50, step=1)
            seed = gr.Textbox(value=42, label="Seed")
        # NOTE(review): recent Gradio versions expect an int for Column scale;
        # 0.3 may warn or be coerced — verify against the pinned gradio version.
        with gr.Column(min_width=50, elem_id="column_process", scale=0.3):
            run_btn = gr.Button(value="Run", elem_id="run_button", min_width=50)
        with gr.Column(min_width=512):
            # Step 3: generated output image.
            gr.Markdown("## 3. OUR OUTPUT")
            result = gr.Image(show_label=False, show_download_button=True, width=512, height=512, elem_id="result")
    # Usage notes shown below the main row.
    gr.Markdown("### Camera Pose Controls:")
    gr.Markdown("* Orbital rotation: Left-click and drag.")
    gr.Markdown("* Zoom: Mouse wheel scroll.")
    gr.Markdown("* Pan (translate the camera): Right-click and drag.")
    gr.Markdown("* Tilt camera: Tilt mouse wheel left/right.")
    gr.Markdown("* Reset to initial camera pose: Hover over the top right corner of the plot and click the camera icon.")
    gr.Markdown("### Note:")
    gr.Markdown("The models only work within a range of elevation angles and distances near the initial camera pose.")
    # Event wiring: loading a model also redraws the pose figure for the category.
    load_model_btn.click(select_and_load_model, [category, category_single_id], [load_model_status, input_prompt])
    load_model_btn.click(get_input_pose_fig, [category], [map])
    update_pose_btn.click(update_curr_camera_dict, [input_pose], [input_pose],)  # js=send_js_camera_to_gradio)
    # check_pose_btn.click(check_curr_camera_dict, [], [input_pose])
    run_btn.click(run_inference, [input_pose, input_prompt, scale_im, scale, steps, seed], result)
    # Re-run the front-end helper script on page load.
    demo.load(js=scripts)
if __name__ == "__main__":
    # Enable request queueing (inference is long-running) and launch with
    # debug output enabled.
    demo.queue().launch(debug=True)