import gradio as gr
from gradio_client import Client, handle_file
import spaces
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
from datetime import datetime
import shutil
import cv2
from typing import *
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel
import trimesh


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8
DEFAULT_MODE = 3
DEFAULT_STEP = 3

css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
.step-connector { transform: none; }
.step-number { width: 16px; height: 16px; }
.step-label { position: relative; bottom: 0; }
.wrap.center.full { inset: 0; height: 100%; }
.wrap.center.full.translucent { background: var(--block-background-fill); }
.meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }

/* Previewer */
.previewer-container { position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; width: 100%; height: 722px; margin: 0 auto; padding: 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; }
.previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
.previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
.previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
.tips-icon:hover + .tips-text { display: block; opacity: 100%; }

/* Row 1: Display Modes */
.previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
.previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }

/* Row 2: Display Image */
.previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
.previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
.previewer-container .previewer-main-image.visible { display: block; }

/* Row 3: Custom HTML Slider */
.previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
.previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
.previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
.previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
.previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }

/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
.gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }

/* Viewer toolbar */
.viewer-toolbar { display: flex; gap: 10px; align-items: center; padding: 8px 0; flex-wrap: wrap; }
.viewer-toolbar label { font-size: 13px; color: var(--body-text-color); }
"""

head = """
"""

empty_html = f"""
"""
""" def image_to_base64(image): buffered = io.BytesIO() image = image.convert("RGB") image.save(buffered, format="jpeg", quality=85) img_str = base64.b64encode(buffered.getvalue()).decode() return f"data:image/jpeg;base64,{img_str}" def start_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) os.makedirs(user_dir, exist_ok=True) def end_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shutil.rmtree(user_dir) def remove_background(input: Image.Image) -> Image.Image: with tempfile.NamedTemporaryFile(suffix='.png') as f: input = input.convert('RGB') input.save(f.name) output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0] output = Image.open(output) return output def preprocess_image(input: Image.Image) -> Image.Image: has_alpha = False if input.mode == 'RGBA': alpha = np.array(input)[:, :, 3] if not np.all(alpha == 255): has_alpha = True max_size = max(input.size) scale = min(1, 1024 / max_size) if scale < 1: input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) if has_alpha: output = input else: output = remove_background(input) output_np = np.array(output) alpha = output_np[:, :, 3] bbox = np.argwhere(alpha > 0.8 * 255) bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) size = int(size * 1) bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 output = output.crop(bbox) output = np.array(output).astype(np.float32) / 255 output = output[:, :, :3] * output[:, :, 3:4] output = Image.fromarray((output * 255).astype(np.uint8)) return output def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict: shape_slat, tex_slat, res = latents return { 'shape_slat_feats': shape_slat.feats.cpu().numpy(), 'tex_slat_feats': tex_slat.feats.cpu().numpy(), 'coords': shape_slat.coords.cpu().numpy(), 'res': res, } def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]: shape_slat = SparseTensor( feats=torch.from_numpy(state['shape_slat_feats']).cuda(), coords=torch.from_numpy(state['coords']).cuda(), ) tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) return shape_slat, tex_slat, state['res'] def get_seed(randomize_seed: bool, seed: int) -> int: return np.random.randint(0, MAX_SEED) if randomize_seed else seed @spaces.GPU(duration=120) def image_to_3d( image: Image.Image, seed: int, resolution: str, ss_guidance_strength: float, ss_guidance_rescale: float, ss_sampling_steps: int, ss_rescale_t: float, shape_slat_guidance_strength: float, shape_slat_guidance_rescale: float, shape_slat_sampling_steps: int, shape_slat_rescale_t: float, tex_slat_guidance_strength: float, tex_slat_guidance_rescale: float, tex_slat_sampling_steps: int, tex_slat_rescale_t: float, req: gr.Request, progress=gr.Progress(track_tqdm=True), ) -> str: outputs, latents = pipeline.run( image, seed=seed, preprocess_image=False, sparse_structure_sampler_params={ "steps": ss_sampling_steps, "guidance_strength": ss_guidance_strength, "guidance_rescale": ss_guidance_rescale, "rescale_t": ss_rescale_t, }, shape_slat_sampler_params={ "steps": shape_slat_sampling_steps, "guidance_strength": shape_slat_guidance_strength, "guidance_rescale": shape_slat_guidance_rescale, "rescale_t": shape_slat_rescale_t, }, tex_slat_sampler_params={ "steps": tex_slat_sampling_steps, 
"guidance_strength": tex_slat_guidance_strength, "guidance_rescale": tex_slat_guidance_rescale, "rescale_t": tex_slat_rescale_t, }, pipeline_type={ "512": "512", "1024": "1024_cascade", "1536": "1536_cascade", }[resolution], return_latent=True, ) mesh = outputs[0] mesh.simplify(16777216) # nvdiffrast limit β€” preview only images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap) state = pack_state(latents) torch.cuda.empty_cache() # --- HTML Construction --- images_html = "" for m_idx, mode in enumerate(MODES): for s_idx in range(STEPS): unique_id = f"view-m{m_idx}-s{s_idx}" is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP) vis_class = "visible" if is_visible else "" img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx])) images_html += f""" """ btns_html = "" for idx, mode in enumerate(MODES): active_class = "active" if idx == DEFAULT_MODE else "" btns_html += f""" """ full_html = f"""
💡Tips

● Render Mode - Click on the circular buttons to switch between different render modes.

● View Angle - Drag the slider to change the view angle.

{images_html}
{btns_html}
""" return state, full_html @spaces.GPU(duration=120) def extract_glb( state: dict, decimation_target: int, texture_size: int, req: gr.Request, progress=gr.Progress(track_tqdm=True), ) -> Tuple[str, str, str, str]: """ Extract GLB (with texture) + Raw STL (4M faces, no decimation/texture) in one pass. """ user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shape_slat, tex_slat, res = unpack_state(state) mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] now = datetime.now() timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" os.makedirs(user_dir, exist_ok=True) # --- Raw STL first (before simplify destroys detail) --- mesh.simplify(4000000) # 4M faces β€” preserves detail, manageable file size raw_stl_path = os.path.join(user_dir, f'raw_{timestamp}.stl') raw_mesh = trimesh.Trimesh( vertices=mesh.vertices.cpu().numpy(), faces=mesh.faces.cpu().numpy(), process=False, ) raw_mesh.export(raw_stl_path) # --- GLB with texture (standard extraction) --- glb = o_voxel.postprocess.to_glb( vertices=mesh.vertices, faces=mesh.faces, attr_volume=mesh.attrs, coords=mesh.coords, attr_layout=pipeline.pbr_attr_layout, grid_size=res, aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], decimation_target=decimation_target, texture_size=texture_size, remesh=True, remesh_band=1, remesh_project=0, use_tqdm=True, ) glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb') glb.export(glb_path, extension_webp=True) # --- Decimated STL from GLB --- stl_path = os.path.join(user_dir, f'sample_{timestamp}.stl') try: tm = trimesh.load(glb_path) if isinstance(tm, trimesh.Scene): geometries = list(tm.geometry.values()) if geometries: tm = trimesh.util.concatenate(geometries) else: tm = trimesh.Trimesh() tm.export(stl_path) except Exception: stl_mesh = trimesh.Trimesh( vertices=mesh.vertices.cpu().numpy(), faces=mesh.faces.cpu().numpy(), process=False, ) stl_mesh.export(stl_path) torch.cuda.empty_cache() return glb_path, glb_path, stl_path, raw_stl_path @spaces.GPU(duration=120) def extract_raw( state: dict, req: gr.Request, progress=gr.Progress(track_tqdm=True), ) -> str: """Extract raw mesh data as compressed .npz β€” much smaller than STL for transfer.""" user_dir = os.path.join(TMP_DIR, str(req.session_hash)) shape_slat, tex_slat, res = unpack_state(state) mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0] mesh.simplify(4000000) # 4M faces max now = datetime.now() timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}" os.makedirs(user_dir, exist_ok=True) # Save as compressed numpy β€” ~30-40MB vs ~200MB STL npz_path = os.path.join(user_dir, f'raw_{timestamp}.npz') verts = mesh.vertices.cpu().numpy().astype(np.float32) faces = mesh.faces.cpu().numpy().astype(np.int32) np.savez_compressed(npz_path, vertices=verts, faces=faces) torch.cuda.empty_cache() return npz_path def update_viewer_mode(display_mode: str) -> dict: return gr.Model3D(display_mode=display_mode) with gr.Blocks(delete_cache=(600, 600)) as demo: gr.Markdown(""" ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset. * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time. 
""") with gr.Row(): with gr.Column(scale=1, min_width=360): image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400) resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024") seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1) randomize_seed = gr.Checkbox(label="Randomize Seed", value=True) # Reverted to original Microsoft range/default β€” fits within 120s ZeroGPU limit decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=500000, step=10000) texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024) generate_btn = gr.Button("Generate") with gr.Accordion(label="Advanced Settings", open=False): gr.Markdown("Stage 1: Sparse Structure Generation") with gr.Row(): ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01) ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1) gr.Markdown("Stage 2: Shape Generation") with gr.Row(): shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1) shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01) shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) gr.Markdown("Stage 3: Material Generation") with gr.Row(): tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1) tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01) tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1) tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1) with gr.Column(scale=10): with gr.Walkthrough(selected=0) as walkthrough: with gr.Step("Preview", id=0): preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True) extract_btn = gr.Button("Extract GLB") with gr.Step("Extract", id=1): # Viewer display mode toggle with gr.Row(elem_classes="viewer-toolbar"): viewer_mode = gr.Radio( choices=["solid", "wireframe"], value="solid", label="Viewer Display", info="Toggle texture / wireframe view", ) glb_output = gr.Model3D( label="Extracted GLB", height=620, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), ) with gr.Row(): download_btn = gr.DownloadButton(label="⬇️ Download GLB (with textures)") download_stl_btn = gr.DownloadButton(label="⬇️ Download STL (no texture)") extract_raw_btn = gr.Button("Extract Raw STL (No Decimation β€” Max Detail)") download_raw_btn = gr.DownloadButton(label="⬇️ Download Raw STL") gr.Markdown("*We are actively working on improving the speed of GLB extraction. 
Currently, it may take half a minute or more and face count is limited.*") with gr.Column(scale=1, min_width=172): examples = gr.Examples( examples=[ f'assets/example_image/{image}' for image in os.listdir("assets/example_image") ], inputs=[image_prompt], fn=preprocess_image, outputs=[image_prompt], run_on_click=True, examples_per_page=18, ) output_buf = gr.State() # Handlers demo.load(start_session) demo.unload(end_session) image_prompt.upload( preprocess_image, inputs=[image_prompt], outputs=[image_prompt], ) generate_btn.click( get_seed, inputs=[randomize_seed, seed], outputs=[seed], ).then( lambda: gr.Walkthrough(selected=0), outputs=walkthrough ).then( image_to_3d, inputs=[ image_prompt, seed, resolution, ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t, shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t, tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t, ], outputs=[output_buf, preview_output], ) extract_btn.click( lambda: gr.Walkthrough(selected=1), outputs=walkthrough ).then( extract_glb, inputs=[output_buf, decimation_target, texture_size], outputs=[glb_output, download_btn, download_stl_btn, download_raw_btn], ) extract_raw_btn.click( extract_raw, inputs=[output_buf], outputs=[download_raw_btn], ) viewer_mode.change( fn=update_viewer_mode, inputs=[viewer_mode], outputs=[glb_output], ) # Launch the Gradio app if __name__ == "__main__": os.makedirs(TMP_DIR, exist_ok=True) btn_img_base64_strs = {} for i in range(len(MODES)): icon = Image.open(MODES[i]['icon']) MODES[i]['icon_base64'] = image_to_base64(icon) rmbg_client = Client("briaai/BRIA-RMBG-2.0") pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B') pipeline.rembg_model = None pipeline.low_vram = False pipeline.cuda() envmap = { 'forest': EnvMap(torch.tensor( cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda' )), 'sunset': EnvMap(torch.tensor( cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda' )), 'courtyard': EnvMap(torch.tensor( cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB), dtype=torch.float32, device='cuda' )), } demo.launch(css=css, head=head)
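

# A minimal sketch (kept as comments so it never runs with the app) of consuming the
# .npz written by extract_raw() above on the client side. The 'vertices' and 'faces'
# keys are exactly those saved by np.savez_compressed; the file name here is only an
# illustrative placeholder.
#
#   import numpy as np
#   import trimesh
#
#   data = np.load("raw_<timestamp>.npz")
#   mesh = trimesh.Trimesh(vertices=data["vertices"], faces=data["faces"], process=False)
#   mesh.export("raw.stl")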