import warnings

warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*")
warnings.filterwarnings("ignore", message=".*torch.cuda.amp.autocast.*")
warnings.filterwarnings("ignore", message=".*Default grid_sample and affine_grid behavior.*")

import gradio as gr
from gradio_client import Client, handle_file
import spaces
from concurrent.futures import ThreadPoolExecutor

import os

APP_DIR = os.path.dirname(os.path.abspath(__file__))

# Environment configuration is set before the heavy imports below
# (cv2 / torch / trellis2), presumably because those libraries read these
# variables when loaded or first used.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(APP_DIR, 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

from datetime import datetime
import shutil
import cv2
from typing import *
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel

# Patch postprocess module with local fix for cumesh.fill_holes() bug
import importlib.util
import sys

_local_postprocess = os.path.join(APP_DIR, 'o-voxel', 'o_voxel', 'postprocess.py')
if os.path.exists(_local_postprocess):
    # Load the local postprocess.py and install it both as the attribute on the
    # o_voxel package and in sys.modules, so later lookups resolve to it.
    _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
    _mod = importlib.util.module_from_spec(_spec)
    _spec.loader.exec_module(_mod)
    o_voxel.postprocess = _mod
    sys.modules['o_voxel.postprocess'] = _mod


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(APP_DIR, 'tmp')
MULTI_EXAMPLE_DIR = "assets/example_multi_image"

# Render modes shown in the previewer; "render_key" indexes the dict returned
# by render_utils.render_snapshot.
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
STEPS = 8         # number of view angles rendered per mode
DEFAULT_MODE = 3  # index into MODES shown first
DEFAULT_STEP = 3  # view index shown first


css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper {
    padding: 0;
}
.stepper-container {
    padding: 0;
    align-items: center;
}
.step-button {
    flex-direction: row;
}
.step-connector {
    transform: none;
}
.step-number {
    width: 16px;
    height: 16px;
}
.step-label {
    position: relative;
    bottom: 0;
}
.wrap.center.full {
    inset: 0;
    height: 100%;
}
.wrap.center.full.translucent {
    background: var(--block-background-fill);
}
.meta-text-center {
    display: block !important;
    position: absolute !important;
    top: unset !important;
    bottom: 0 !important;
    right: 0 !important;
    transform: unset !important;
}

/* Previewer */
.previewer-container {
    position: relative;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    width: 100%;
    height: 722px;
    margin: 0 auto;
    padding: 20px;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
}
.previewer-container .tips-icon {
    position: absolute;
    right: 10px;
    top: 10px;
    z-index: 10;
    border-radius: 10px;
    color: #fff;
    background-color: var(--color-accent);
    padding: 3px 6px;
    user-select: none;
}
.previewer-container .tips-text {
    position: absolute;
    right: 10px;
    top: 50px;
    color: #fff;
    background-color: var(--color-accent);
    border-radius: 10px;
    padding: 6px;
    text-align: left;
    max-width: 300px;
    z-index: 10;
    transition: all 0.3s;
    opacity: 0%;
    user-select: none;
}
.previewer-container .tips-text p {
    font-size: 14px;
    line-height: 1.2;
}
.tips-icon:hover + .tips-text {
    display: block;
    opacity: 100%;
}

/* Row 1: Display Modes */
.previewer-container .mode-row {
    width: 100%;
    display: flex;
    gap: 8px;
    justify-content: center;
    margin-bottom: 20px;
    flex-wrap: wrap;
}
.previewer-container .mode-btn {
    width: 24px;
    height: 24px;
    border-radius: 50%;
    cursor: pointer;
    opacity: 0.5;
    transition: all 0.2s;
    border: 2px solid var(--neutral-600, #555);
    object-fit: cover;
}
.previewer-container .mode-btn:hover {
    opacity: 0.9;
    transform: scale(1.1);
}
.previewer-container .mode-btn.active {
    opacity: 1;
    border-color: var(--color-accent);
    transform: scale(1.1);
}

/* Row 2: Display Image */
.previewer-container .display-row {
    margin-bottom: 20px;
    min-height: 400px;
    width: 100%;
    flex-grow: 1;
    display: flex;
    justify-content: center;
    align-items: center;
}
.previewer-container .previewer-main-image {
    max-width: 100%;
    max-height: 100%;
    flex-grow: 1;
    object-fit: contain;
    display: none;
}
.previewer-container .previewer-main-image.visible {
    display: block;
}

/* Row 3: Custom HTML Slider */
.previewer-container .slider-row {
    width: 100%;
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 10px;
    padding: 0 10px;
}
.previewer-container input[type=range] {
    -webkit-appearance: none;
    width: 100%;
    max-width: 400px;
    background: transparent;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
    width: 100%;
    height: 8px;
    cursor: pointer;
    background: var(--neutral-700, #404040);
    border-radius: 5px;
}
.previewer-container input[type=range]::-webkit-slider-thumb {
    height: 20px;
    width: 20px;
    border-radius: 50%;
    background: var(--color-accent);
    cursor: pointer;
    -webkit-appearance: none;
    margin-top: -6px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    transition: transform 0.1s;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
    transform: scale(1.2);
}

/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) {
    padding: 0 !important;
}
.gradio-container:has(.previewer-container) [data-testid="block-label"] {
    position: absolute;
    top: 0;
    left: 0;
}
"""

# NOTE(review): image_to_3d's button markup says its onclick handlers call a
# JS function "defined in Head", but this string holds only whitespace —
# verify the script was not stripped from this file.
head = """ """

# NOTE(review): placeholder shown before the first generation; it contains no
# markup as it stands — verify against upstream.
empty_html = f"""
"""


def image_to_base64(image):
    """Encode a PIL image as an RGB JPEG (quality 85) ``data:`` URI string."""
    buffered = io.BytesIO()
    image = image.convert("RGB")
    image.save(buffered, format="jpeg", quality=85)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/jpeg;base64,{img_str}"


def start_session(req: gr.Request):
    """Create the per-session scratch directory ``TMP_DIR/<session_hash>``."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Delete the per-session scratch directory, if it exists."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(user_dir):
        # ignore_errors: the directory may be removed concurrently.
        shutil.rmtree(user_dir, ignore_errors=True)


def remove_background(input: Image.Image) -> Image.Image:
    """
    Remove the background of *input* using the remote BRIA RMBG Space.

    The image is written to a temporary RGB PNG, sent through the module-level
    ``rmbg_client`` (created in ``__main__``), and the returned file is opened.
    """
    with tempfile.NamedTemporaryFile(suffix='.png') as f:
        input = input.convert('RGB')
        input.save(f.name)
        # NOTE(review): [0][0] picks the first element of the first output of
        # the "/image" endpoint — presumably the cut-out image path; confirm
        # against the Space's API.
        output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
    return Image.open(output)


def preprocess_image(input: Image.Image) -> Image.Image:
    """
    Preprocess a single input image for the pipeline.

    Steps:
      1. Downscale (LANCZOS) so the longest side is at most 1024 px.
      2. If the image is RGBA with any non-opaque pixel, use its alpha as the
         mask; otherwise call :func:`remove_background`.
      3. Crop a square around the pixels whose alpha exceeds 0.8 * 255.
      4. Premultiply RGB by alpha (i.e. composite onto black).

    Returns:
        An RGB image. If no pixel passes the alpha threshold, the image is
        instead pasted centered onto a black square and returned.
    """
    # if has alpha channel, use it directly; otherwise, remove background
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True

    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize(
            (int(input.width * scale), int(input.height * scale)),
            Image.Resampling.LANCZOS,
        )

    if has_alpha:
        output = input
    else:
        output = remove_background(input)

    # The steps below index channel 3; guarantee an alpha channel exists even if
    # the background-removal service returned an RGB/L image.
    if output.mode != 'RGBA':
        output = output.convert('RGBA')

    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    bbox = np.argwhere(alpha > 0.8 * 255)
    if bbox.size == 0:
        # No visible pixels, center the image in a square
        size = max(output.size)
        square = Image.new('RGB', (size, size), (0, 0, 0))
        output_rgb = output.convert('RGB')
        square.paste(output_rgb, ((size - output.width) // 2, (size - output.height) // 2))
        return square

    # argwhere yields (row, col); reorder to PIL's (left, top, right, bottom).
    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = int(max(bbox[2] - bbox[0], bbox[3] - bbox[1]))
    bbox = (
        center[0] - size // 2,
        center[1] - size // 2,
        center[0] + size // 2,
        center[1] + size // 2,
    )
    output = output.crop(bbox)  # type: ignore
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output


def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a gallery of input images for multi-image conditioning.

    Args:
        images: Gallery value — a list of ``(image, caption)`` pairs.

    Returns:
        The preprocessed images (see :func:`preprocess_image`), in input order.
        An empty or missing gallery yields an empty list.

    Background removal is I/O-bound (remote call), so up to 4 images are
    processed in parallel threads.
    """
    if not images:
        # ThreadPoolExecutor rejects max_workers=0.
        return []
    pil_images = [image[0] for image in images]
    with ThreadPoolExecutor(max_workers=min(4, len(pil_images))) as executor:
        return list(executor.map(preprocess_image, pil_images))


def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
    """
    Move generation latents to CPU numpy arrays so they can live in gr.State.

    Only the shape latent's coords are stored; :func:`unpack_state` reuses
    them for the texture latent.
    """
    shape_slat, tex_slat, res = latents
    return {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }


def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    """Rebuild the CUDA SparseTensor latents from a dict made by :func:`pack_state`."""
    shape_slat = SparseTensor(
        feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
        coords=torch.from_numpy(state['coords']).cuda(),
    )
    # The texture latent shares coords with the shape latent; only feats differ.
    tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
    return shape_slat, tex_slat, state['res']


def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a random seed in ``[0, MAX_SEED)`` if *randomize_seed*, else *seed*."""
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


def _multi_example_cases() -> List[str]:
    """Sorted case names in MULTI_EXAMPLE_DIR, from files named ``<case>_<view>.png``."""
    # rsplit keeps underscores inside the case name itself intact.
    return sorted({
        f.rsplit('_', 1)[0]
        for f in os.listdir(MULTI_EXAMPLE_DIR)
        if f.endswith('.png')
    })


def prepare_multi_example() -> List[str]:
    """
    Prepare multi-image examples.

    Returns the path of each case's first view (``<case>_1.png``), which is
    shown as the representative thumbnail.
    """
    examples = []
    for case in _multi_example_cases():
        first_img = f'{MULTI_EXAMPLE_DIR}/{case}_1.png'
        if os.path.exists(first_img):
            examples.append(first_img)
    return examples


def load_multi_example(image) -> List[Image.Image]:
    """
    Load all views of the multi-image example whose first view equals *image*.

    Args:
        image: The clicked example thumbnail (numpy array or PIL image), or None.

    Returns:
        RGBA images ``<case>_1.png`` .. ``<case>_6.png`` (those that exist) for
        the matching case. If no case matches, returns ``[image]`` as RGBA.
        Returns ``[]`` for None.
    """
    if image is None:
        return []

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Convert to RGB for consistent comparison
    input_rgb = np.array(image.convert('RGB'))

    # Find matching case by comparing with first images
    for case_name in _multi_example_cases():
        first_img_path = f'{MULTI_EXAMPLE_DIR}/{case_name}_1.png'
        if not os.path.exists(first_img_path):
            continue
        with Image.open(first_img_path) as first_img:
            first_rgb = np.array(first_img.convert('RGB'))

        # Compare images (check if same shape and content)
        if input_rgb.shape == first_rgb.shape and np.array_equal(input_rgb, first_rgb):
            # Found match, load all views (without preprocessing - will be done on Generate)
            images = []
            for i in range(1, 7):
                img_path = f'{MULTI_EXAMPLE_DIR}/{case_name}_{i}.png'
                if os.path.exists(img_path):
                    with Image.open(img_path) as img:
                        images.append(img.convert('RGBA'))
            if images:
                return images

    # No match found, return the single image
    return [image.convert('RGBA') if image.mode != 'RGBA' else image]


def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a horizontally concatenated image into its views.

    Views are the maximal runs of columns containing at least one pixel with
    non-zero alpha; each run is cropped out and passed through
    :func:`preprocess_image`.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    pixels = np.array(image)
    occupied = np.any(pixels[..., 3] > 0, axis=0)
    # Pad with transparent columns so runs touching the left/right edge still
    # produce a matching start/end transition (otherwise zip() misaligns them).
    padded = np.pad(occupied, 1, constant_values=False)
    starts = np.where(~padded[:-1] & padded[1:])[0].tolist()
    ends = np.where(padded[:-1] & ~padded[1:])[0].tolist()  # exclusive
    views = [Image.fromarray(pixels[:, s:e]) for s, e in zip(starts, ends)]
    return [preprocess_image(view) for view in views]


@spaces.GPU(duration=120)
def image_to_3d(
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    multiimages: List[Tuple[Image.Image, str]],
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
    """
    Generate a 3D asset from the gallery images and build its preview.

    The gallery images are preprocessed and passed to
    ``pipeline.run_multi_image`` with the per-stage sampler parameters
    (sparse structure, shape latent, texture latent). The resulting mesh is
    then rendered from STEPS view angles in every render mode of MODES.

    Args:
        resolution: "512", "1024" or "1536"; mapped to the pipeline type.
        multiimages: Gallery value — a list of ``(image, caption)`` pairs.
        multiimage_algo / tex_multiimage_algo: Multi-image conditioning mode
            for the structure / texture stages.

    Returns:
        ``(state, html)`` — the packed latents (see :func:`pack_state`) for
        ``extract_glb``, and the preview HTML.

    Raises:
        gr.Error: If no images were provided.
    """
    if not multiimages:
        raise gr.Error("Please upload images or select an example first.")

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    # Preprocess images (background removal, cropping, etc.)
    # NOTE(review): images uploaded through the gallery were already run
    # through preprocess_image by the upload handler. Its RGB output has no
    # alpha, so this second pass calls remove_background again.
    input_images = [image[0] for image in multiimages]
    processed_images = [preprocess_image(img) for img in input_images]

    # Debug: save preprocessed images and log stats. Files go in the
    # per-session directory so concurrent sessions don't overwrite each other
    # and end_session cleans them up.
    for i, img in enumerate(processed_images):
        arr = np.array(img)
        print(f"[DEBUG] Preprocessed image {i}: mode={img.mode}, size={img.size}, "
              f"dtype={arr.dtype}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
        img.save(os.path.join(user_dir, f'debug_preprocessed_{i}.png'))
    print(f"[DEBUG] Pipeline params: mode={multiimage_algo}, tex_mode={tex_multiimage_algo}")

    # --- Sampling ---
    outputs, latents = pipeline.run_multi_image(
        processed_images,
        seed=seed,
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
            "guidance_rescale": ss_guidance_rescale,
            "rescale_t": ss_rescale_t,
        },
        shape_slat_sampler_params={
            "steps": shape_slat_sampling_steps,
            "guidance_strength": shape_slat_guidance_strength,
            "guidance_rescale": shape_slat_guidance_rescale,
            "rescale_t": shape_slat_rescale_t,
        },
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        pipeline_type={
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[resolution],
        return_latent=True,
        mode=multiimage_algo,
        tex_mode=tex_multiimage_algo,
    )
    mesh = outputs[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    renders = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)

    # Debug: log stats for the first view of every render mode, and save
    # base_color and shaded_forest for inspection.
    for key, frames in renders.items():
        arr = frames[0]
        print(f"[DEBUG] Render '{key}': shape={arr.shape}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
    Image.fromarray(renders['base_color'][0]).save(os.path.join(user_dir, 'debug_base_color.png'))
    Image.fromarray(renders['shaded_forest'][0]).save(os.path.join(user_dir, 'debug_shaded_forest.png'))

    state = pack_state(latents)
    torch.cuda.empty_cache()

    # --- HTML Construction ---
    # The stack of len(MODES) * STEPS images - encode in parallel for speed
    def encode_preview_image(args):
        m_idx, s_idx, render_key = args
        img_base64 = image_to_base64(Image.fromarray(renders[render_key][s_idx]))
        return (m_idx, s_idx, img_base64)

    encode_tasks = [
        (m_idx, s_idx, mode['render_key'])
        for m_idx, mode in enumerate(MODES)
        for s_idx in range(STEPS)
    ]
    with ThreadPoolExecutor(max_workers=8) as executor:
        encoded_results = list(executor.map(encode_preview_image, encode_tasks))

    # Build HTML from encoded results
    encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
    images_html = ""
    for m_idx, mode in enumerate(MODES):
        for s_idx in range(STEPS):
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
            img_base64 = encoded_map[(m_idx, s_idx)]
            # NOTE(review): this template holds only whitespace and never uses
            # unique_id / vis_class / img_base64, so the encoded renders never
            # reach the page. The <img> markup appears to be missing — restore it.
            images_html += f""" """

    # Button Row HTML
    btns_html = ""
    for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
        # Note: onclick calls the JS function defined in Head
        # NOTE(review): as with images_html, this template has no markup and
        # ignores active_class / mode — verify against upstream.
        btns_html += f""" """

    # Assemble the full component
    full_html = f"""
💡Tips

Render Mode - Click on the circular buttons to switch between different render modes.

View Angle - Drag the slider to change the view angle.

{images_html}
{btns_html}
"""
    return state, full_html


@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.
        decimation_target (int): The target face count for decimation.
        texture_size (int): The texture resolution.

    Returns:
        Tuple[str, str]: The path to the extracted GLB file (for Model3D and DownloadButton).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shape_slat, tex_slat, res = unpack_state(state)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    # Millisecond-resolution timestamp keeps repeated extractions distinct.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path


def _load_envmap(path: str) -> EnvMap:
    """Read an HDRI (EXR) with OpenCV, convert BGR to RGB, and wrap it as a CUDA float32 EnvMap."""
    hdri = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if hdri is None:
        raise FileNotFoundError(f"Could not read environment map: {path}")
    hdri = cv2.cvtColor(hdri, cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(hdri, dtype=torch.float32, device='cuda'))


with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
    gr.HTML("""
OpsiClear

Multi-View to 3D with TRELLIS.2

""")
    with gr.Row():
        with gr.Column(scale=1, min_width=360):
            multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3)

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
                tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="stochastic")

        with gr.Column(scale=10):
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                extract_btn = gr.Button("Extract GLB")
            preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
            glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB")

    example_image = gr.Image(visible=False)  # Hidden component for examples
    examples_multi = gr.Examples(
        examples=prepare_multi_example(),
        inputs=[example_image],
        fn=load_multi_example,
        outputs=[multiimage_prompt],
        run_on_click=True,
        cache_examples=False,
        examples_per_page=50,
    )

    output_buf = gr.State()

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
            multiimage_prompt, multiimage_algo, tex_multiimage_algo,
        ],
        outputs=[output_buf, preview_output],
    )

    extract_btn.click(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    )


# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # Construct ui components: base64-encode each render-mode icon once.
    for mode in MODES:
        with Image.open(mode['icon']) as icon:
            mode['icon_base64'] = image_to_base64(icon)

    # Module-level globals used by the handlers above.
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")
    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None
    pipeline.low_vram = False
    pipeline.cuda()

    envmap = {
        name: _load_envmap(f'assets/hdri/{name}.exr')
        for name in ('forest', 'sunset', 'courtyard')
    }

    demo.launch(css=css, head=head)