Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import os | |
| os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" | |
| from datetime import datetime | |
| import shutil | |
| from typing import * | |
| import torch | |
| import numpy as np | |
| import trimesh | |
| from PIL import Image | |
| from trellis2.pipelines import Trellis2TexturingPipeline | |
# Largest seed value: upper bound of the seed slider and of get_seed's random
# draw (the maximum of a signed 32-bit integer).
MAX_SEED = np.iinfo(np.int32).max
# Per-session output directories are created under ./tmp next to this script.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
TMP_DIR = os.path.join(_APP_DIR, 'tmp')
def start_session(req: gr.Request):
    """Create the scratch directory for the current browser session.

    The directory is ``TMP_DIR/<session_hash>``. Calling this again for the
    same session is harmless because of ``exist_ok=True``.
    """
    session_id = str(req.session_hash)
    os.makedirs(os.path.join(TMP_DIR, session_id), exist_ok=True)
def end_session(req: gr.Request):
    """Remove the scratch directory of the current browser session.

    The directory is ``TMP_DIR/<session_hash>``. Cleanup is best-effort: a
    missing directory (never created, or already deleted) is not an error.
    Other failures, such as permission problems, still propagate.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # Nothing left to clean up; don't fail the unload handler over it.
        pass
def preprocess_image(image: Image.Image) -> Image.Image:
    """Apply the texturing pipeline's own preprocessing to an uploaded image.

    This delegates entirely to ``pipeline.preprocess_image``. The module-level
    ``pipeline`` must already be loaded when this runs.

    Args:
        image: The image uploaded to the prompt component.

    Returns:
        The image as returned by ``pipeline.preprocess_image``.
    """
    return pipeline.preprocess_image(image)
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for the next generation.

    If ``randomize_seed`` is falsy, ``seed`` is returned unchanged. Otherwise
    a new seed is drawn from ``[0, MAX_SEED)`` using NumPy's global RNG; the
    upper bound is exclusive, as for ``np.random.randint``.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
def shapeimage_to_tex(
    mesh_file: str,
    image: Image.Image,
    seed: int,
    resolution: str,
    texture_size: int,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """Texture an uploaded mesh from a reference image and export it as GLB.

    Args:
        mesh_file: Path of the uploaded mesh, loaded with ``trimesh.load``.
        image: Reference image. The UI runs ``preprocess_image`` on upload, so
            the pipeline is called with ``preprocess_image=False`` here.
        seed: Random seed forwarded to ``pipeline.run``.
        resolution: Resolution choice from the radio button, delivered as a
            string (e.g. "1024"). It is converted with ``int``.
        texture_size: Output texture size forwarded to ``pipeline.run``.
        tex_slat_guidance_strength: Sampler ``guidance_strength``.
        tex_slat_guidance_rescale: Sampler ``guidance_rescale``.
        tex_slat_sampling_steps: Sampler ``steps``.
        tex_slat_rescale_t: Sampler ``rescale_t``.
        req: Gradio request; its ``session_hash`` selects the output folder.
        progress: Gradio progress tracker that mirrors tqdm progress bars.

    Returns:
        The path of the exported GLB, returned twice: once for the 3D viewer
        and once for the download button.

    Raises:
        gr.Error: If no mesh file or no reference image was provided.
    """
    # Give the user a clear message instead of an opaque trimesh/pipeline error.
    if mesh_file is None:
        raise gr.Error("Please upload a mesh file.")
    if image is None:
        raise gr.Error("Please upload a reference image.")

    mesh = trimesh.load(mesh_file)
    if isinstance(mesh, trimesh.Scene):
        # trimesh.load may return a Scene (e.g. for GLB/GLTF); the pipeline is
        # given a single mesh built from it.
        mesh = mesh.to_mesh()

    try:
        output = pipeline.run(
            mesh,
            image,
            seed=seed,
            preprocess_image=False,
            tex_slat_sampler_params={
                # Slider values are cast defensively in case they arrive as floats.
                "steps": int(tex_slat_sampling_steps),
                "guidance_strength": tex_slat_guidance_strength,
                "guidance_rescale": tex_slat_guidance_rescale,
                "rescale_t": tex_slat_rescale_t,
            },
            resolution=int(resolution),
            texture_size=int(texture_size),
        )

        # A millisecond-resolution timestamp keeps repeated generations in the
        # same session from overwriting each other.
        now = datetime.now()
        timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
        user_dir = os.path.join(TMP_DIR, str(req.session_hash))
        # Recreate the directory in case it was removed after start_session ran.
        os.makedirs(user_dir, exist_ok=True)
        glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
        output.export(glb_path, extension_webp=True)
    finally:
        # Release cached CUDA memory even when generation fails (e.g. out of
        # memory), so a failed request doesn't leave it reserved.
        torch.cuda.empty_cache()

    return glb_path, glb_path
# Gradio UI: layout plus event wiring. Components are placed by the order in
# which they are created inside the nested `with` contexts.
# NOTE(review): per the Gradio docs, delete_cache=(frequency, age) deletes this
# app's cached files older than `age` seconds every `frequency` seconds (here
# 600 s each); confirm against the pinned Gradio version.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Texturing a mesh with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload a mesh and corresponding reference image (preferably with an alpha-masked foreground object) and click Generate to create a textured 3D asset.
    """)

    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1, min_width=360):
            mesh_file = gr.File(label="Upload Mesh", file_types=[".ply", ".obj", ".glb", ".gltf"], file_count="single")
            # RGBA mode keeps an alpha channel, so an alpha-masked foreground survives.
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)

            # Choices are strings; shapeimage_to_tex converts them with int().
            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            # When checked, get_seed replaces the slider value before generation.
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            generate_btn = gr.Button("Generate")

            # shapeimage_to_tex passes these four values to pipeline.run as
            # tex_slat_sampler_params.
            with gr.Accordion(label="Advanced Settings", open=False):
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

        # Right column: 3D preview and download of the exported GLB.
        with gr.Column(scale=10):
            glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB")

    # Handlers
    # The per-session folder TMP_DIR/<session_hash> is created on page load
    # and removed on unload.
    demo.load(start_session)
    demo.unload(end_session)

    # Replace the uploaded image with the pipeline-preprocessed version. This is
    # why shapeimage_to_tex calls pipeline.run with preprocess_image=False.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Resolve the seed first and write it back to the slider. The chained step
    # then reads the seed that is actually used. The GLB path goes to both the
    # viewer and the download button.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        shapeimage_to_tex,
        inputs=[
            mesh_file, image_prompt, seed, resolution, texture_size,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
        ],
        outputs=[glb_output, download_btn],
    )
# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # `pipeline` becomes a module-level global here. preprocess_image and
    # shapeimage_to_tex look it up at call time, so it must be loaded (and moved
    # to the GPU) before demo.launch() starts serving requests.
    pipeline = Trellis2TexturingPipeline.from_pretrained(
        'microsoft/TRELLIS.2-4B',
        config_file="texturing_pipeline.json",
    )
    pipeline.cuda()

    demo.launch()