# -*- coding: utf-8 -*-
import os
import time
from collections import OrderedDict
from PIL import Image
import torch
import trimesh
from typing import Optional, List
from einops import repeat, rearrange
import numpy as np

from michelangelo.models.tsal.tsal_base import Latent2MeshOutput
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer
from michelangelo.utils.visualizers import html_util

import gradio as gr
from omegaconf import OmegaConf
from huggingface_hub import snapshot_download

gradio_cached_dir = "./gradio_cached_dir"
os.makedirs(gradio_cached_dir, exist_ok=True)

save_mesh = False
state = ""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
box_v = 1.1
viewer = PyThreeJSViewer(settings={}, render_mode="WEBSITE")
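# box_v is the half-extent of the cube in which shapes are sampled (see the
# bounds arguments below); the pythreejs viewer above accumulates generated
# meshes and renders them to embeddable HTML.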

image_model_config_dict = OrderedDict({
    "ASLDM-256-obj": {
| # "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml", | |
| # "ckpt_path": "./checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt", | |
| "config": "./configs/image_cond_diffuser_asl/image-ASLDM-256.yaml", | |
| "ckpt_path": "checkpoints/image_cond_diffuser_asl/image-ASLDM-256.ckpt", | |
| }, | |
| }) | |

text_model_config_dict = OrderedDict({
    "ASLDM-256": {
| # "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml", | |
| # "ckpt_path": "./checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt", | |
| "config": "./configs/text_cond_diffuser_asl/text-ASLDM-256.yaml", | |
| "ckpt_path": "checkpoints/text_cond_diffuser_asl/text-ASLDM-256.ckpt", | |
| }, | |
| }) | |
| model_path = snapshot_download(repo_id="Maikou/Michelangelo") | |
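# snapshot_download fetches (and locally caches) the full Hugging Face repo
# and returns the local directory path; the relative "ckpt_path" entries above
# are resolved against it in load_model below.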


class InferenceModel(object):
    # Mutable holder that caches the most recently loaded model, so repeated
    # requests with the same model name skip re-instantiation.
    model = None
    name = ""


text2mesh_model = InferenceModel()
image2mesh_model = InferenceModel()


def set_state(s):
    global state
    state = s
    print(s)


def output_to_html_frame(mesh_outputs: List[Latent2MeshOutput], bbox_size: float,
                         image: Optional[np.ndarray] = None,
                         html_frame: bool = False):
    global viewer

    for i in range(len(mesh_outputs)):
        mesh = mesh_outputs[i]
        if mesh is None:
            continue

        mesh_v = mesh.mesh_v.copy()
        # Lay the samples out side by side along x and push them back along z,
        # so all generated meshes are visible in one scene.
        mesh_v[:, 0] += i * np.max(bbox_size)
        mesh_v[:, 2] += np.max(bbox_size)
        viewer.add_mesh(mesh_v, mesh.mesh_f)

    mesh_tag = viewer.to_html(html_frame=False)

    if image is not None:
        # Place the conditioning image and the 3D view side by side.
        image_tag = html_util.to_image_embed_tag(image)
        frame = f"""
        <table border="1">
            <tr>
                <td>{image_tag}</td>
                <td>{mesh_tag}</td>
            </tr>
        </table>
        """
    else:
        frame = mesh_tag

    if html_frame:
        frame = html_util.to_html_frame(frame)

    viewer.reset()

    return frame


def load_model(model_name: str, model_config_dict: dict, inference_model: InferenceModel):
    global device

    if inference_model.name == model_name:
        # Cache hit: reuse the already-loaded model.
        model = inference_model.model
    else:
        assert model_name in model_config_dict

        if inference_model.model is not None:
            del inference_model.model

        config_ckpt_path = model_config_dict[model_name]

        # An earlier revision rewrote the config so the CLIP checkpoint paths
        # pointed into the downloaded snapshot before loading:
        # raw_config = OmegaConf.load(config_ckpt_path["config"])
        # raw_clip_ckpt_path = raw_config['model']['params']['first_stage_config']['params']['aligned_module_cfg']['params']['clip_model_version']
        # clip_ckpt_path = os.path.join(model_path, raw_clip_ckpt_path)
        # raw_config['model']['params']['first_stage_config']['params']['aligned_module_cfg']['params']['clip_model_version'] = clip_ckpt_path
        # raw_config['model']['params']['cond_stage_config']['params']['version'] = clip_ckpt_path
        # OmegaConf.save(raw_config, 'current_config.yaml')
        # model_config = get_config_from_file('current_config.yaml')

        model_config = get_config_from_file(config_ckpt_path["config"])
        if hasattr(model_config, "model"):
            model_config = model_config.model

        ckpt_path = os.path.join(model_path, config_ckpt_path["ckpt_path"])
        model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
        model = model.to(device)
        model = model.eval()

        inference_model.model = model
        inference_model.name = model_name

    return model
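# For example, load_model("ASLDM-256", text_model_config_dict, text2mesh_model)
# instantiates the text-conditioned model on the first call and returns the
# cached instance on later calls with the same name.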


def prepare_img(image: np.ndarray):
    # Map uint8 RGB in [0, 255] to float in [-1, 1] and move channels first.
    image_pt = torch.tensor(image).float()
    image_pt = image_pt / 255 * 2 - 1
    image_pt = rearrange(image_pt, "h w c -> c h w")
    return image_pt
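# E.g. an (H, W, 3) uint8 RGB array becomes a (3, H, W) float tensor in [-1, 1].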


def prepare_model_viewer(fp):
    # Note: the boolean AR attribute is written as a bare "ar"; the original
    # "ar:true" is not valid HTML attribute syntax.
    content = f"""
    <head>
        <script type="module"
            src="https://ajax.googleapis.com/ajax/libs/model-viewer/3.1.1/model-viewer.min.js">
        </script>
    </head>
    <body>
        <model-viewer
            style="height: 150px; width: 150px;"
            rotation-per-second="10deg"
            id="t1"
            src="file/gradio_cached_dir/{fp}"
            environment-image="neutral"
            camera-target="0m 0m 0m"
            orientation="0deg 90deg 170deg"
            shadow-intensity="1"
            ar
            auto-rotate
            camera-controls>
        </model-viewer>
    </body>
    """
    return content
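# <model-viewer> is Google's web component for rendering 3D assets in the
# browser; the "file/..." src prefix is how Gradio exposes local files to the
# page (see the note in post_process_mesh_outputs below).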


def prepare_html_frame(content):
    frame = f"""
    <html>
    <body>
        {content}
    </body>
    </html>
    """
    return frame


def prepare_html_body(content):
    frame = f"""
    <body>
        {content}
    </body>
    """
    return frame


def post_process_mesh_outputs(mesh_outputs):
    html_content = output_to_html_frame(mesh_outputs, 2 * box_v, image=None, html_frame=False)
    html_frame = prepare_html_frame(html_content)

    filename = f"four-in-one-{time.time()}.html"
    html_filepath = os.path.join(gradio_cached_dir, filename)
    with open(html_filepath, "w") as writer:
        writer.write(html_frame)

    '''
    Bug: a plain <iframe> pointing at the file does not work in Gradio;
    Chrome reports "No resource with given URL found".
    Workaround (see https://github.com/gradio-app/gradio/issues/884):
    for security reasons, the server only serves files that live under the
    directory of gradio_app.py, addressed as "file/TARGET_FILE_PATH".
    '''
    iframe_tag = f'<iframe src="file/gradio_cached_dir/{filename}" width="100%" height="400" frameborder="0"></iframe>'

    filelist = []
    filenames = []
    for i, mesh in enumerate(mesh_outputs):
        # Flip the winding order so the exported faces are front-facing.
        mesh.mesh_f = mesh.mesh_f[:, ::-1]
        mesh_output = trimesh.Trimesh(mesh.mesh_v, mesh.mesh_f)

        name = str(i) + "_out_mesh.obj"
        filepath = os.path.join(gradio_cached_dir, name)
        mesh_output.export(filepath, include_normals=True)
        filelist.append(filepath)
        filenames.append(name)

    filelist.append(html_filepath)
    return iframe_tag, filelist
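# post_process_mesh_outputs thus returns (a) an <iframe> snippet embedding the
# saved viewer page for the gr.HTML output, and (b) the exported OBJ files plus
# the HTML page itself for the download widget.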


def image2mesh(image: np.ndarray,
               model_name: str = "ASLDM-256-obj",
               num_samples: int = 4,
               guidance_scale: float = 7.5,
               octree_depth: int = 7):
    global device, gradio_cached_dir, image_model_config_dict, box_v

    # load model
    model = load_model(model_name, image_model_config_dict, image2mesh_model)

    # prepare image inputs: one copy of the image per requested sample
    image_pt = prepare_img(image)
    image_pt = repeat(image_pt, "c h w -> b c h w", b=num_samples)

    sample_inputs = {
        "image": image_pt
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)

    return iframe_tag, gr.update(value=filelist, visible=True)


def text2mesh(text: str,
              model_name: str = "ASLDM-256",
              num_samples: int = 4,
              guidance_scale: float = 7.5,
              octree_depth: int = 7):
    global device, gradio_cached_dir, text_model_config_dict, text2mesh_model, box_v

    # load model
    model = load_model(model_name, text_model_config_dict, text2mesh_model)

    # prepare text inputs: one copy of the prompt per requested sample
    sample_inputs = {
        "text": [text] * num_samples
    }
    mesh_outputs = model.sample(
        sample_inputs,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=[-box_v, -box_v, -box_v, box_v, box_v, box_v],
        octree_depth=octree_depth,
    )[0]

    iframe_tag, filelist = post_process_mesh_outputs(mesh_outputs)

    return iframe_tag, gr.update(value=filelist, visible=True)
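# A minimal sketch (assumption: gradio_client is installed and the app is
# reachable at this URL) of calling the endpoint registered below remotely:
#
#   from gradio_client import Client
#   client = Client("http://127.0.0.1:7860")
#   iframe_html, files = client.predict(
#       "A 3D model of a chair; Simple Wooden Chair.",
#       "ASLDM-256", 4, 7.5, 7,
#       api_name="/generate_txt2obj",
#   )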


example_dir = './gradio_cached_dir/example/img_example'

first_page_items = [
    'alita.jpg',
    'burger.jpg',
    'loopy.jpg',
    'building.jpg',
    'mario.jpg',
    'car.jpg',
    'airplane.jpg',
    'bag.jpg',
    'bench.jpg',
    'ship.jpg',
]
raw_example_items = [
    os.path.join(example_dir, x)
    for x in os.listdir(example_dir)
    if x.endswith(('.jpg', '.png'))
]
# Show the curated first-page items first, then everything else.
example_items = ([x for x in raw_example_items if os.path.basename(x) in first_page_items]
                 + [x for x in raw_example_items if os.path.basename(x) not in first_page_items])
example_text = [
    ["A 3D model of a car; Audi A6."],
    ["A 3D model of police car; Highway Patrol Charger"],
]


def set_cache(data: gr.SelectData):
    # data.index is the position of the gallery item the user clicked.
    img_name = os.path.basename(example_items[data.index])
    return os.path.join(example_dir, img_name), img_name


def disable_cache():
    return ""


with gr.Blocks() as app:
    gr.Markdown("# Michelangelo")
    gr.Markdown("## [Github](https://github.com/NeuralCarver/Michelangelo) | [Arxiv](https://arxiv.org/abs/2306.17115) | [Project Page](https://neuralcarver.github.io/michelangelo/)")
    gr.Markdown("Michelangelo is a conditional 3D shape generation system trained on a shape-image-text aligned latent representation.")
    gr.Markdown("### Hints:")
    gr.Markdown("1. We provide two APIs: image-conditioned generation and text-conditioned generation.")
    gr.Markdown("2. Note that the image-conditioned model is trained on multiple 3D datasets, such as ShapeNet and Objaverse.")
    gr.Markdown("3. We provide some examples for you to try. You can also upload images or text as input.")
    gr.Markdown("4. To make it easy to take your favorite results home, we provide download buttons for each OBJ file and a combined HTML file.")
    gr.Markdown("5. You are welcome to share suggestions or amazing results with us. Thanks for your interest in our work!")
    gr.Markdown("6. Please note that the model may take some time to download the weights and set up on first launch; we are working on fixing this.")
    with gr.Row():
        with gr.Column():
            with gr.Tab("Image to 3D"):
                img = gr.Image(label="Image")
                gr.Markdown("For the best results, we suggest that uploaded images meet three criteria: 1. the object is centered in the image, 2. the image is square, and 3. the background is relatively clean.")
                btn_generate_img2obj = gr.Button(value="Generate")

                with gr.Accordion("Advanced settings", open=False):
                    image_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256-obj", choices=list(image_model_config_dict.keys()))
                    # Per-tab sliders: the original reused one variable name for
                    # both tabs, so the image button silently read the text tab's
                    # sliders; distinct names wire each button to its own tab.
                    img_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    img_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    img_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                cache_dir = gr.Textbox(value="", visible=False)
                examples = gr.Gallery(label='Examples', value=example_items, elem_id="gallery", allow_preview=False, columns=[4], object_fit="contain")
            with gr.Tab("Text to 3D"):
                prompt = gr.Textbox(label="Prompt", placeholder="A 3D model of motorcar; Porsche Cayenne Turbo.")
                gr.Markdown("For the best results, we suggest that the prompt follows 'A 3D model of CATEGORY; DESCRIPTION'. For example: A 3D model of motorcar; Porsche Cayenne Turbo.")
                btn_generate_txt2obj = gr.Button(value="Generate")

                with gr.Accordion("Advanced settings", open=False):
                    text_dropdown_models = gr.Dropdown(label="Model", value="ASLDM-256", choices=list(text_model_config_dict.keys()))
                    txt_num_samples = gr.Slider(label="samples", value=4, minimum=1, maximum=8, step=1)
                    txt_guidance_scale = gr.Slider(label="Guidance scale", value=7.5, minimum=3.0, maximum=10.0, step=0.1)
                    txt_octree_depth = gr.Slider(label="Octree Depth (for 3D model)", value=7, minimum=4, maximum=8, step=1)

                gr.Markdown("#### Examples:")
                gr.Markdown("1. A 3D model of an airplane; Airbus.")
                gr.Markdown("2. A 3D model of a fighter aircraft; Attack Fighter.")
                gr.Markdown("3. A 3D model of a chair; Simple Wooden Chair.")
                gr.Markdown("4. A 3D model of a laptop computer; Dell Laptop.")
                gr.Markdown("5. A 3D model of a coupe; Audi A6.")
                gr.Markdown("6. A 3D model of a motorcar; Hummer H2 SUT.")
                gr.Markdown("7. A 3D model of a lamp; Light Post.")
                gr.Markdown("8. A 3D model of a rifle; AK47.")
                gr.Markdown("9. A 3D model of a knife; Sword.")
                gr.Markdown("10. A 3D model of a vase; Plant in pot.")
        with gr.Column():
            model_3d = gr.HTML()
            file_out = gr.File(label="Files", visible=False)

    outputs = [model_3d, file_out]

    img.upload(disable_cache, outputs=cache_dir)
    examples.select(set_cache, outputs=[img, cache_dir])
    # Startup diagnostics: working directory and downloaded snapshot contents.
    print(os.path.abspath(os.path.dirname(__file__)), flush=True)
    print(model_path, flush=True)
    print(os.listdir(model_path), flush=True)
    btn_generate_img2obj.click(image2mesh,
                               inputs=[img, image_dropdown_models, img_num_samples,
                                       img_guidance_scale, img_octree_depth],
                               outputs=outputs, api_name="generate_img2obj")
    btn_generate_txt2obj.click(text2mesh,
                               inputs=[prompt, text_dropdown_models, txt_num_samples,
                                       txt_guidance_scale, txt_octree_depth],
                               outputs=outputs, api_name="generate_txt2obj")

app.launch()
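# NOTE (assumption): newer Gradio releases only serve files from allow-listed
# directories; if the cached HTML/OBJ links 404, try
# app.launch(allowed_paths=[gradio_cached_dir]) instead.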