| import spaces |
| import gradio as gr |
| import glob |
| import hashlib |
| from PIL import Image |
| import os |
| import shlex |
| import subprocess |
|
|
# One-time environment setup: fetch the SAM checkpoint and install the
# prebuilt nvdiffrast wheel shipped with the repo.
CKPT_DIR = "./ckpt"
SAM_CKPT_PATH = os.path.join(CKPT_DIR, "sam_vit_h_4b8939.pth")

os.makedirs(CKPT_DIR, exist_ok=True)

# Download the SAM checkpoint only if it is not already present (the original
# re-downloaded it on every restart), and fail fast on a broken download
# instead of discovering a missing/corrupt file deep inside inference.
if not os.path.exists(SAM_CKPT_PATH):
    subprocess.run(
        [
            "wget",
            "-q",
            "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
            "-O",
            SAM_CKPT_PATH,
        ],
        check=True,
    )

# pip is pinned before installing the local wheel (kept from the original;
# presumably required for wheel-tag compatibility — confirm before changing).
subprocess.run(shlex.split("pip install pip==24.0"), check=True)
subprocess.run(
    shlex.split(
        "pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
    ),
    check=True,
)
|
|
| from infer_api import InferAPI |
|
|
# Per-stage inference configuration; stages without a YAML take an empty dict.
config_canocalize = {'config_path': './configs/canonicalization-infer.yaml'}
config_multiview = {}
config_slrm = {'config_path': './configs/mesh-slrm-infer.yaml'}
config_refine = {}

# Sample inputs displayed by the gr.Examples widgets below.
EXAMPLE_IMAGES = glob.glob("./input_cases/*")
EXAMPLE_APOSE_IMAGES = glob.glob("./input_cases_apose/*")
|
|
# Single API object driving all four inference stages (canonicalization,
# multi-view generation, SLRM reconstruction, refinement).
infer_api = InferAPI(config_canocalize, config_multiview, config_slrm, config_refine)
|
|
| _HEADER_ = ''' |
| <h2><b>[CVPR 2025] StdGEN 🤗 Gradio Demo</b></h2> |
| This is official demo for our CVPR 2025 paper <a href="">StdGEN: Semantic-Decomposed 3D Character Generation from Single Images</a>. |
| |
| Code: <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2411.05738' target='_blank'>ArXiv</a>. |
| |
| ❗️❗️❗️**Important Notes:** This is only a **PREVIEW** version with **coarse precision geometry and texture** due to limited online resource. We skip some refinement process and perform only color back-projection to clothes and hair. Please refer to GitHub repo for complete version. |
| 1. Refinement stage takes about ~2.5min, and the mesh result may possibly delayed due to the server load, please wait patiently. |
| |
| 2. You can upload any reference image (with or without background), A-pose images are also supported (white bkg required). If the image has an alpha channel (transparency), background segmentation will be automatically performed. Alternatively, you can pre-segment the background using other tools and upload the result directly. |
| |
| 3. Real person images generally work well, but note that normals may appear smoother than expected. You can try to use other monocular normal estimation models. |
| |
| 4. The base human model in the output is uncolored due to potential NSFW concerns. If you need colored results, please refer to the official GitHub repository for instructions. |
| ''' |
|
|
| _CITE_ = r""" |
| If StdGEN is helpful, please help to ⭐ the <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub Repo</a>. Thanks! [](https://github.com/hyz317/StdGEN) |
| --- |
| 📝 **Citation** |
| If you find our work useful for your research or applications, please cite using this bibtex: |
| ```bibtex |
| @article{he2024stdgen, |
| title={StdGEN: Semantic-Decomposed 3D Character Generation from Single Images}, |
| author={He, Yuze and Zhou, Yanning and Zhao, Wang and Wu, Zhongkai and Xiao, Kaiwen and Yang, Wei and Liu, Yong-Jin and Han, Xiao}, |
| journal={arXiv preprint arXiv:2411.05738}, |
| year={2024} |
| } |
| ``` |
| 📧 **Contact** |
| If you have any questions, feel free to open a discussion or contact us at <b>hyz22@mails.tsinghua.edu.cn</b>. |
| """ |
|
|
# In-memory result caches. Stage 1/2 entries are keyed by
# md5(image bytes) + seed; stages 3/4 reuse that same hash string.
cache_arbitrary = {}            # stage 1: hash -> A-pose image path
cache_multiview = [{}, {}, {}]  # stage 2: one dict per detail level
cache_slrm = {}                 # stage 3: hash -> reconstructed mesh files
cache_refine = {}               # stage 4: hash -> refined mesh files

# Directory where cached images are written.
tmp_path = '/tmp'
|
|
| |
def arbitrary_to_apose(image, seed):
    """Convert an arbitrary-pose reference image (numpy array) to an A-pose image.

    Results are cached on disk keyed by md5(image bytes) + seed, so repeated
    requests with the same input skip stage-1 inference entirely.
    """
    image = Image.fromarray(image)
    cache_key = f"{hashlib.md5(image.tobytes()).hexdigest()}_{seed}"

    cached_path = cache_arbitrary.get(cache_key)
    if cached_path is not None:
        print(f'loaded cached apose image: {cache_key}')
        return Image.open(cached_path)

    apose_img = infer_api.genStage1(image, seed)
    out_path = f'{tmp_path}/{cache_key}.png'
    apose_img.save(out_path)
    cache_arbitrary[cache_key] = out_path
    print(f'cached apose image: {cache_key}')
    return apose_img
|
|
def apose_to_multiview(apose_img, seed):
    """Generate multi-view RGB and normal images from an A-pose image (stage 2, level 0).

    Returns (list of view images, cache key). Generated views are saved to
    disk and indexed in ``cache_multiview[0]`` keyed by md5(image) + seed.
    """
    apose_img = Image.fromarray(apose_img)
    cache_key = f"{hashlib.md5(apose_img.tobytes()).hexdigest()}_{seed}"
    level0_cache = cache_multiview[0]

    if cache_key in level0_cache:
        print(f'loaded cached multiview images: {cache_key}')
        return [Image.open(p) for p in level0_cache[cache_key]["images"]], cache_key

    results = infer_api.genStage2(apose_img, seed, num_levels=1)
    # Persist both the RGB views and the normal maps, recording their paths.
    entry = {"images": [], "normals": []}
    for kind in ("images", "normals"):
        for idx, img in enumerate(results[0][kind]):
            path = f'{tmp_path}/{cache_key}_{kind}_{idx}.png'
            img.save(path)
            entry[kind].append(path)
    level0_cache[cache_key] = entry
    print(f'cached multiview images: {cache_key}')
    return results[0]["images"], cache_key
|
|
def multiview_to_mesh(images, image_hash):
    """Reconstruct semantically decomposed meshes from multi-view images (stage 3).

    Returns the mesh files unpacked, followed by the cache key so the UI can
    carry it forward to the refinement step.
    """
    mesh_files = cache_slrm.get(image_hash)
    if mesh_files is None:
        mesh_files = infer_api.genStage3(images)
        cache_slrm[image_hash] = mesh_files
        print(f'cached slrm files: {image_hash}')
    else:
        print(f'loaded cached slrm files: {image_hash}')
    return *mesh_files, image_hash
|
|
def refine_mesh(mesh1, mesh2, mesh3, seed, image_hash):
    """Refine the three reconstructed meshes (stage 4), with caching.

    Requires that ``apose_to_multiview`` has already populated
    ``cache_multiview[0][image_hash]``; its level-0 views are reused so
    only the higher detail levels are regenerated.
    """
    level0 = cache_multiview[0][image_hash]
    apose_img = Image.open(level0["images"][0])

    if image_hash in cache_refine:
        print(f'loaded cached refined mesh: {image_hash}')
        return cache_refine[image_hash]

    results = infer_api.genStage2(apose_img, seed, num_levels=2)
    # Replace the freshly generated level-0 views with the cached ones so the
    # refinement sees exactly the views the meshes were reconstructed from.
    results[0] = {
        "images": [Image.open(p) for p in level0["images"]],
        "normals": [Image.open(p) for p in level0["normals"]],
    }
    refined = infer_api.genStage4([mesh1, mesh2, mesh3], results)
    cache_refine[image_hash] = refined
    print(f'cached refined mesh: {image_hash}')
    return refined
|
|
| with gr.Blocks(title="StdGEN: Semantically Decomposed 3D Character Generation from Single Images") as demo: |
| gr.Markdown(_HEADER_) |
| with gr.Row(): |
| with gr.Column(): |
| gr.Markdown("## 1. Reference Image to A-pose Image") |
| input_image = gr.Image(label="Input Reference Image", type="numpy", width=384, height=384) |
| gr.Examples( |
| examples=EXAMPLE_IMAGES, |
| inputs=input_image, |
| label="Click to use sample images", |
| ) |
| seed_input = gr.Number( |
| label="Seed", |
| value=52, |
| precision=0, |
| interactive=True |
| ) |
| pose_btn = gr.Button("Convert") |
| with gr.Column(): |
| gr.Markdown("## 2. Multi-view Generation") |
| a_pose_image = gr.Image(label="A-pose Result", type="numpy", width=384, height=384) |
| gr.Examples( |
| examples=EXAMPLE_APOSE_IMAGES, |
| inputs=a_pose_image, |
| label="Click to use sample A-pose images", |
| ) |
| seed_input2 = gr.Number( |
| label="Seed", |
| value=50, |
| precision=0, |
| interactive=True |
| ) |
| state2 = gr.State(value="") |
| view_btn = gr.Button("Generate Multi-view Images") |
|
|
| with gr.Column(): |
| gr.Markdown("## 3. Semantic-aware Reconstruction") |
| multiview_gallery = gr.Gallery( |
| label="Multi-view results", |
| columns=2, |
| interactive=False, |
| height="None" |
| ) |
| state3 = gr.State(value="") |
| mesh_btn = gr.Button("Reconstruct") |
|
|
| with gr.Row(): |
| mesh_cols = [gr.Model3D(label=f"Mesh {i+1}", interactive=False, height=384) for i in range(3)] |
| full_mesh = gr.Model3D(label="Whole Mesh", height=384) |
| refine_btn = gr.Button("Refine") |
|
|
| gr.Markdown("## 4. Mesh refinement") |
| with gr.Row(): |
| refined_meshes = [gr.Model3D(label=f"refined mesh {i+1}", height=384) for i in range(3)] |
| refined_full_mesh = gr.Model3D(label="refined whole mesh", height=384) |
|
|
| gr.Markdown(_CITE_) |
|
|
| |
| pose_btn.click( |
| arbitrary_to_apose, |
| inputs=[input_image, seed_input], |
| outputs=a_pose_image |
| ) |
|
|
| view_btn.click( |
| apose_to_multiview, |
| inputs=[a_pose_image, seed_input2], |
| outputs=[multiview_gallery, state2] |
| ) |
|
|
| mesh_btn.click( |
| multiview_to_mesh, |
| inputs=[multiview_gallery, state2], |
| outputs=[*mesh_cols, full_mesh, state3] |
| ) |
|
|
| refine_btn.click( |
| refine_mesh, |
| inputs=[*mesh_cols, seed_input2, state3], |
| outputs=[refined_meshes[2], refined_meshes[0], refined_meshes[1], refined_full_mesh] |
| ) |
|
|
| if __name__ == "__main__": |
| demo.launch(ssr_mode=False) |