import os

import gradio as gr
from gradio_client import Client, handle_file

# Conditional spaces import for RunPod/Local support
# If on RunPod, WE MUST MOCK SPACES to avoid ZeroGPU quota/crashes
if os.environ.get('SPACE_ID') and not os.environ.get('RUNPOD_POD_ID'):
    import spaces
else:
    # Mock spaces if not on HF (RunPod/Local)
    class spaces:
        @staticmethod
        def GPU(duration=120):
            def decorator(func):
                return func
            return decorator

# These must be set BEFORE cv2 / torch / trellis2 are imported:
# - OPENCV_IO_ENABLE_OPENEXR enables EXR reading for the HDRI envmaps
# - PYTORCH_CUDA_ALLOC_CONF reduces CUDA fragmentation
# - ATTN_BACKEND / FLEX_GEMM_* configure the trellis2 kernels
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'

import gc
import shutil
import tempfile
import threading
from datetime import datetime
from typing import *

import cv2
import imageio
import numpy as np
import torch
from PIL import Image

from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel


MAX_SEED = np.iinfo(np.int32).max
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')

# Serializes GPU work so concurrent Gradio sessions don't OOM on RunPod.
GPU_LOCK = threading.Lock()


def start_session(req: gr.Request):
    """Create a per-session temp directory keyed by the Gradio session hash."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)


def end_session(req: gr.Request):
    """Remove the per-session temp directory when the browser tab closes."""
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(user_dir):
        shutil.rmtree(user_dir)


# BiRefNet Global Instance (Lazy Loaded)
from utils_birefnet import BiRefNet
birefnet = None


@spaces.GPU(duration=20)
def remove_background(input: Image.Image) -> Image.Image:
    """Remove the background of *input* with a lazily-instantiated BiRefNet.

    The model is created and moved to CUDA on first call and cached in the
    module-level `birefnet` global for all subsequent calls.
    """
    global birefnet
    if birefnet is None:
        print("[LAZY LOAD] Initializing BiRefNet...", flush=True)
        birefnet = BiRefNet()
        birefnet.cuda()
    # Use local BiRefNet
    return birefnet(input)


def preprocess_image(input: Image.Image) -> Image.Image:
    """Preprocess the input image.

    Steps: reuse an existing alpha channel (or run background removal),
    downscale so the longest side is <= 1024 px, crop to the alpha bounding
    box, and composite onto black.

    Returns:
        An RGB image of the foreground object, premultiplied by alpha.
    """
    # if has alpha channel, use it directly; otherwise, remove background
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
    if has_alpha:
        output = input
    else:
        # NOTE(review): assumes remove_background returns RGBA — confirm with utils_birefnet
        output = remove_background(input)
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    # Square crop centered on the (mostly) opaque region.
    bbox = np.argwhere(alpha > 0.8 * 255)
    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    size = int(size * 1)
    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
    output = output.crop(bbox)  # type: ignore
    # Premultiply RGB by alpha and drop the alpha channel (black background).
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output


def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
    """Serialize (shape latent, texture latent, resolution) to CPU numpy for gr.State."""
    shape_slat, tex_slat, res = latents
    return {
        'shape_slat_feats': shape_slat.feats.cpu().numpy(),
        'tex_slat_feats': tex_slat.feats.cpu().numpy(),
        'coords': shape_slat.coords.cpu().numpy(),
        'res': res,
    }


def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    """Inverse of pack_state: rebuild CUDA sparse tensors from the stored numpy arrays."""
    shape_slat = SparseTensor(
        feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
        coords=torch.from_numpy(state['coords']).cuda(),
    )
    tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
    return shape_slat, tex_slat, state['res']


def get_seed(randomize_seed: bool, seed: int) -> int:
    """Get the random seed: a fresh random one if requested, else the given seed."""
    return np.random.randint(0, MAX_SEED) if randomize_seed else seed


@spaces.GPU(duration=120)
def image_to_3d(
    image: Image.Image,
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    decimation_target: int,  # unused here; kept to match the Gradio inputs list
    texture_size: int,       # unused here; kept to match the Gradio inputs list
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
    """Convert an image to a 3D model.

    Runs the three-stage TRELLIS.2 pipeline (sparse structure, shape latent,
    texture latent) and renders a turntable preview video.

    Returns:
        dict: The state of the generated 3D model (packed latents).
        str: The path to the video of the 3D model.
    """
    # Start timing and cleanup
    torch.cuda.empty_cache()
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    # Use Global Lock to prevent concurrent runs on RunPod
    with GPU_LOCK:
        try:
            # --- Sampling ---
            outputs, latents = pipeline.run(
                image,
                seed=seed,
                preprocess_image=False,
                sparse_structure_sampler_params={
                    "steps": ss_sampling_steps,
                    "guidance_strength": ss_guidance_strength,
                    "guidance_rescale": ss_guidance_rescale,
                    "rescale_t": ss_rescale_t,
                },
                shape_slat_sampler_params={
                    "steps": shape_slat_sampling_steps,
                    "guidance_strength": shape_slat_guidance_strength,
                    "guidance_rescale": shape_slat_guidance_rescale,
                    "rescale_t": shape_slat_rescale_t,
                },
                tex_slat_sampler_params={
                    "steps": tex_slat_sampling_steps,
                    "guidance_strength": tex_slat_guidance_strength,
                    "guidance_rescale": tex_slat_guidance_rescale,
                    "rescale_t": tex_slat_rescale_t,
                },
                pipeline_type={
                    "512": "512",
                    "1024": "1024_cascade",
                    "1536": "1536_cascade",
                }[resolution],
                return_latent=True,
            )
            mesh = outputs[0]
            mesh.simplify(16777216)  # nvdiffrast limit

            # --- Render Video Preview (8 seconds at 15fps = 120 frames) ---
            render_result = render_utils.render_video(mesh, num_frames=120, resolution=720, envmap=envmap, verbose=False)
            video = render_result.get('shaded_forest', render_result.get('shaded'))
            download_video = os.path.join(user_dir, 'sample.mp4')
            imageio.mimsave(download_video, video, fps=15, pixelformat='yuv420p')
            state = pack_state(latents)
            return state, download_video
        except Exception as e:
            print(f"Error during generation: {e}")
            raise
        finally:
            # Cleanup latent tensors and mesh to free VRAM for mesh processing.
            # NOTE: `del locals()[var]` does NOT delete function locals in
            # CPython (locals() is a snapshot), so we rebind to None instead —
            # this actually drops the references before empty_cache().
            latents = mesh = outputs = render_result = video = None
            gc.collect()
            torch.cuda.empty_cache()


@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.
        decimation_target (int): The target face count for decimation.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file (returned twice: for the
            Model3D viewer and the download button).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    with GPU_LOCK:
        try:
            shape_slat, tex_slat, res = unpack_state(state)
            mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
            mesh.simplify(16777216)
            glb = o_voxel.postprocess.to_glb(
                vertices=mesh.vertices,
                faces=mesh.faces,
                attr_volume=mesh.attrs,
                coords=mesh.coords,
                attr_layout=pipeline.pbr_attr_layout,
                grid_size=res,
                aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
                decimation_target=decimation_target,
                texture_size=texture_size,
                remesh=True,
                remesh_band=1,
                remesh_project=0,
                use_tqdm=True,
            )
            now = datetime.now()
            timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
            os.makedirs(user_dir, exist_ok=True)
            glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
            glb.export(glb_path, extension_webp=True)
            return glb_path, glb_path
        except Exception as e:
            print(f"Error during extraction: {e}")
            raise
        finally:
            # Rebind to None (not `del locals()[...]`, which is a no-op for
            # function locals) so VRAM is actually reclaimable.
            shape_slat = tex_slat = mesh = glb = None
            gc.collect()
            torch.cuda.empty_cache()


with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
    * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
    """)

    with gr.Row():
        with gr.Column():
            image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=300)
            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            generate_btn = gr.Button("Generate")

            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=8.5, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

            with gr.Accordion(label="GLB Extraction Settings", open=False):
                decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
                texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            extract_glb_btn = gr.Button("Extract GLB", interactive=False)

        with gr.Column():
            video_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
            model_output = gr.Model3D(label="Extracted GLB", height=300, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_glb = gr.DownloadButton(label="Download GLB", interactive=False)

    with gr.Row():
        examples = gr.Examples(
            examples=[
                f'assets/example_image/{image}'
                # sorted() for a deterministic gallery order across platforms
                for image in sorted(os.listdir("assets/example_image"))
            ],
            inputs=[image_prompt],
            fn=preprocess_image,
            outputs=[image_prompt],
            run_on_click=True,
            examples_per_page=18,
        )

    output_buf = gr.State()

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            image_prompt, seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
            decimation_target, texture_size,
        ],
        outputs=[output_buf, video_output],
    ).then(
        lambda: gr.update(interactive=True),
        outputs=[extract_glb_btn],
    )

    video_output.clear(
        lambda: gr.update(interactive=False),
        outputs=[extract_glb_btn],
    )

    extract_glb_btn.click(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[model_output, download_glb],
    ).then(
        lambda: gr.update(interactive=True),
        outputs=[download_glb],
    )

    model_output.clear(
        lambda: gr.update(interactive=False),
        outputs=[download_glb],
    )


def _load_envmap(path: str) -> EnvMap:
    """Load an EXR HDRI from *path* as a CUDA EnvMap (OpenCV reads BGR; convert to RGB)."""
    hdri = cv2.cvtColor(cv2.imread(path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(hdri, dtype=torch.float32, device='cuda'))


# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)
    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None
    pipeline.low_vram = True
    pipeline.cuda()
    # Load HDRI envmaps for beautiful lighting
    envmap = {
        name: _load_envmap(f'assets/hdri/{name}.exr')
        for name in ('forest', 'sunset', 'courtyard')
    }
    demo.launch()