""" app.py – ReconViaGen v0.5 HuggingFace Space (ZeroGPU) ===================================================== Stage 1 (SS) – ReconViaGen VGGT-based sparse structure Stage 2 (Shape) – TRELLIS.2 shape_slat (DINOv3-conditioned) Stage 3 (Texture)– TRELLIS.2 tex_slat (DINOv3-conditioned, PBR) """ import sys, os # ── Path setup (must precede any trellis2 import) ──────────────────────────── _HERE = os.path.dirname(os.path.abspath(__file__)) _TRELLIS2 = os.path.join(_HERE, 'wheels', 'TRELLIS.2') if _TRELLIS2 not in sys.path: sys.path.insert(0, _TRELLIS2) # ── Environment variables (must be set BEFORE module imports) ──────────────── os.environ['SPCONV_ALGO'] = 'native' os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True' os.environ['ATTN_BACKEND'] = 'flash_attn_3' os.environ['FLEX_GEMM_AUTOTUNE_CACHE_PATH'] = os.path.join(_HERE, 'autotune_cache.json') os.environ['FLEX_GEMM_AUTOTUNER_VERBOSE'] = '1' # ── Imports ─────────────────────────────────────────────────────────────────── import spaces import gradio as gr from datetime import datetime import shutil import cv2 import base64, io from typing import * import torch import numpy as np import imageio from PIL import Image import gc from trellis2.modules.sparse import SparseTensor from trellis2.utils import render_utils from trellis2.renderers import EnvMap import o_voxel from trellis.pipelines import TrellisVGGTTo3DPipeline from trellis2.pipelines import Trellis2ImageTo3DPipeline from trellis.pipelines.trellis_hybrid_pipeline import TrellisHybridPipeline # ── Constants ───────────────────────────────────────────────────────────────── MAX_SEED = np.iinfo(np.int32).max TMP_DIR = os.path.join(_HERE, 'tmp') os.makedirs(TMP_DIR, exist_ok=True) LOW_VRAM = False MODES = [ {"name": "Shaded", "render_key": "shaded"}, {"name": "Normal", "render_key": "normal"}, {"name": "Base color", "render_key": "base_color"}, {"name": "Metallic", "render_key": "metallic"}, {"name": 
"Roughness", "render_key": "roughness"}, ] STEPS = 8 DEFAULT_MODE = 0 DEFAULT_STEP = 3 # ── CSS / JS ────────────────────────────────────────────────────────────────── css = """ .badge-row { text-align: center !important; } .badge-row p { display: inline-flex !important; gap: 8px; justify-content: center; align-items: center; } .badge-row a { display: inline-block !important; } .badge-row img { display: inline-block !important; } .previewer-container { position: relative; width: 100%; height: 520px; display: flex; flex-direction: column; align-items: center; justify-content: center; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; } .previewer-container .display-row { width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; min-height: 360px; } .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; object-fit: contain; display: none; } .previewer-container .previewer-main-image.visible { display: block; } .previewer-container .mode-row { width: 100%; display: flex; gap: 10px; justify-content: center; margin-bottom: 10px; flex-wrap: wrap; } .previewer-container .mode-btn { padding: 4px 12px; border-radius: 14px; border: 2px solid #ddd; cursor: pointer; font-size: 13px; background: none; opacity: 0.55; transition: all 0.2s; } .previewer-container .mode-btn:hover { opacity: 0.9; } .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); color: var(--color-accent); font-weight: 600; } .previewer-container .slider-row { width: 100%; display: flex; align-items: center; padding: 0 12px; } .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; background: transparent; } .previewer-container input[type=range]::-webkit-slider-runnable-track { height: 6px; background: #ddd; border-radius: 3px; } .previewer-container input[type=range]::-webkit-slider-thumb { height: 18px; width: 18px; border-radius: 50%; background: var(--color-accent); 
-webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 4px rgba(0,0,0,.2); } """ head = """ """ empty_html = """
""" # ── Helpers ─────────────────────────────────────────────────────────────────── def image_to_base64(image: Image.Image) -> str: buf = io.BytesIO() image.convert("RGB").save(buf, format="jpeg", quality=85) return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode() def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict: shape_slat, tex_slat, res = latents return { 'shape_slat_feats': shape_slat.feats.cpu().numpy(), 'tex_slat_feats': tex_slat.feats.cpu().numpy(), 'coords': shape_slat.coords.cpu().numpy(), 'res': res, } def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]: shape_slat = SparseTensor( feats=torch.from_numpy(state['shape_slat_feats']).cuda(), coords=torch.from_numpy(state['coords']).cuda(), ) tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda()) return shape_slat, tex_slat, state['res'] def get_seed(randomize_seed: bool, seed: int) -> int: return np.random.randint(0, MAX_SEED) if randomize_seed else seed # ── Session management ──────────────────────────────────────────────────────── def start_session(req: gr.Request): os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True) def end_session(req: gr.Request): user_dir = os.path.join(TMP_DIR, str(req.session_hash)) if os.path.exists(user_dir): shutil.rmtree(user_dir) # ── Preprocessing ───────────────────────────────────────────────────────────── def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]: return [pipeline.preprocess_image(img[0]) for img in images] def preprocess_videos(video: str, num_per_second: float) -> List[Image.Image]: vid = imageio.get_reader(video, 'ffmpeg') fps = vid.get_meta_data()['fps'] frames = [] for i, frame in enumerate(vid): if i % max(int(fps/num_per_second), 1) == 0: img = Image.fromarray(frame) W, H = img.size img = img.resize((int(W / H * 512), 512)) frames.append(img) vid.close() return frames # ── 3D generation 
# ── 3D generation ─────────────────────────────────────────────────────────────
def _build_previewer_html(render_views: dict) -> str:
    """Assemble the previewer HTML from rendered snapshot frames.

    ``render_views`` is the dict returned by ``render_utils.render_snapshot``;
    for each mode in ``MODES`` it is indexed by ``mode['render_key']`` and is
    expected to hold ``STEPS`` frames accepted by ``Image.fromarray``.
    Modes whose key is absent are skipped.
    """
    images_html = ""
    for m_idx, mode in enumerate(MODES):
        # The frame list depends only on the mode, not on the step index.
        frames = render_views.get(mode['render_key'])
        if frames is None:
            continue
        for s_idx in range(STEPS):
            uid = f"view-m{m_idx}-s{s_idx}"
            is_vis = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_vis else ""
            img_b64 = image_to_base64(Image.fromarray(frames[s_idx]))
            # NOTE(review): this template has no placeholders, so uid,
            # vis_class and img_b64 are computed but never emitted. The <img>
            # markup appears to be missing — confirm.
            images_html += f""" """

    btns_html = ""
    for idx, mode in enumerate(MODES):
        active = "active" if idx == DEFAULT_MODE else ""
        # NOTE(review): `active` and mode['name'] are not interpolated; the
        # button markup appears to be missing — confirm.
        btns_html += f""" """

    return f"""
{images_html}
{btns_html}
"""


@spaces.GPU(duration=120)
def image_to_3d(
    image_gallery,
    multi_image_strategy: str,
    seed: int,
    pipeline_type: str,
    ss_source: str,
    # SS params
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    # SLat params
    slat_guidance_strength: float,
    slat_guidance_rescale: float,
    slat_sampling_steps: int,
    slat_rescale_t: float,
    # Shape SLat params
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    # Tex SLat params
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a 3D asset from the gallery images and build its preview.

    A single image is sent through ``pipeline.run``; two or more go through
    ``pipeline.run_multi_image`` using *multi_image_strategy*. The first
    output mesh is simplified and rendered from ``STEPS`` viewpoints with the
    module-level ``envmap`` (created under ``__main__``).

    Returns:
        ``(state, html)``: the latents packed by ``pack_state`` (consumed by
        ``extract_glb``) and the previewer HTML.

    Raises:
        gr.Error: if the gallery is empty.
    """
    if not image_gallery:
        raise gr.Error("Please upload at least one image.")

    # Gallery items may be tuples/lists (first element is the image) or bare
    # images/paths; everything is normalised to an RGBA PIL image.
    images = []
    for item in image_gallery:
        img = item[0] if isinstance(item, (tuple, list)) else item
        if isinstance(img, str):
            img = Image.open(img)
        if img.mode != 'RGBA':
            img = img.convert('RGBA')
        images.append(img)

    ss_params = {
        "steps": ss_sampling_steps,
        "cfg_strength": ss_guidance_strength,
        "cfg_interval": [0.6, 1.0],
        "guidance_rescale": ss_guidance_rescale,
        "rescale_t": ss_rescale_t,
    }
    slat_params = {
        "steps": slat_sampling_steps,
        "cfg_strength": slat_guidance_strength,
        "cfg_interval": [0.6, 1.0],
        "guidance_rescale": slat_guidance_rescale,
        "rescale_t": slat_rescale_t,
    }
    shape_slat_params = {
        "steps": shape_slat_sampling_steps,
        "guidance_strength": shape_slat_guidance_strength,
        "guidance_rescale": shape_slat_guidance_rescale,
        "rescale_t": shape_slat_rescale_t,
    }
    tex_slat_params = {
        "steps": tex_slat_sampling_steps,
        "guidance_strength": tex_slat_guidance_strength,
        "guidance_rescale": tex_slat_guidance_rescale,
        "rescale_t": tex_slat_rescale_t,
    }

    # Keyword arguments shared by the single- and multi-image entry points.
    run_kwargs = dict(
        seed=seed,
        ss_sampler_params=ss_params,
        slat_sampler_params=slat_params,
        shape_slat_sampler_params=shape_slat_params,
        tex_slat_sampler_params=tex_slat_params,
        pipeline_type=pipeline_type,
        preprocess_image=True,
        return_latent=True,
        ss_source=ss_source,
    )
    if len(images) == 1:
        out_mesh_list, latents = pipeline.run(images, **run_kwargs)
    else:
        out_mesh_list, latents = pipeline.run_multi_image(
            images, strategy=multi_image_strategy, **run_kwargs
        )

    mesh = out_mesh_list[0]
    # 16777216 == 2**24; presumably an upper bound on mesh size before
    # rendering — confirm against mesh.simplify's contract.
    mesh.simplify(16777216)
    render_views = render_utils.render_snapshot(
        mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap
    )
    state = pack_state(latents)
    torch.cuda.empty_cache()

    full_html = _build_previewer_html(render_views)
    return state, full_html


# ── GLB extraction ────────────────────────────────────────────────────────────
@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """Decode the stored latents and export a GLB into the session directory.

    Args:
        state: Dict produced by ``pack_state`` (``None`` before generation).
        decimation_target: Passed to ``o_voxel.postprocess.to_glb``.
        texture_size: Passed to ``o_voxel.postprocess.to_glb``.

    Returns:
        The GLB path twice (for the Model3D viewer and the download button).

    Raises:
        gr.Error: if nothing has been generated yet.
    """
    if state is None:
        raise gr.Error("Please generate a 3D model first.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)

    shape_slat, tex_slat, res = unpack_state(state)
    mesh = pipeline.trellis2_pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    # Millisecond-resolution timestamp keeps repeated exports from colliding.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path


# ── Multi-view examples ───────────────────────────────────────────────────────
def prepare_multi_example() -> list:
    """Collect multi-view example sets from ``assets/example_multi_image``.

    Files are grouped by the prefix before the first ``_``; for each group the
    views ``<case>_1.png`` … ``<case>_8.png`` that exist are returned as one
    example row (``[paths]``). Returns ``[]`` if the directory is missing.
    """
    example_dir = os.path.join(_HERE, "assets", "example_multi_image")
    if not os.path.exists(example_dir):
        return []
    cases = sorted(set(i.split('_')[0] for i in os.listdir(example_dir)))
    result = []
    for case in cases:
        paths = []
        for i in range(1, 9):
            p = os.path.join(example_dir, f'{case}_{i}.png')
            if os.path.exists(p):
                paths.append(p)
        if paths:
            result.append([paths])
    return result


# ── Gradio UI ─────────────────────────────────────────────────────────────────
with gr.Blocks(
    title="ReconViaGen v0.5",
    delete_cache=(600, 600),
) as demo:
    gr.Markdown("\n\nReconViaGen-v0.5\n\n")
    gr.Markdown(
        "[![GitHub stars](https://img.shields.io/github/stars/GAP-LAB-CUHK-SZ/ReconViaGen?label=GitHub%20%E2%98%85&logo=github&color=C8C)](https://github.com/GAP-LAB-CUHK-SZ/ReconViaGen/tree/v0.5) "
        "[![Project Page](https://www.obukhov.ai/img/badges/badge-website.svg)](https://jiahao620.github.io/reconviagen/) "
        "[![Paper](https://www.obukhov.ai/img/badges/badge-pdf.svg)](https://arxiv.org/abs/2510.23306)",
        elem_classes=["badge-row"],
    )
    # NOTE(review): Markdown only renders `> **Note:**` as a blockquote when it
    # starts a line — confirm the intended line breaks in this text.
    gr.Markdown(""" **Stage 1 - Sparse Structure**: ReconViaGen (VGGT multi-view aware) **Stage 2 - Shape SLat**: TRELLIS.2 (DINOv3-conditioned) **Stage 3 - Texture SLat**: TRELLIS.2 (DINOv3-conditioned, PBR output) > **Note:** For deployment and runtime efficiency on Hugging Face Spaces, the number of denoising steps has been reduced compared to the full pipeline. This may result in slightly lower visual quality. For best results, consider running locally with more steps. """)

    with gr.Row():
        # ── Left panel ────────────────────────────────────────────────────────
        with gr.Column(scale=1, min_width=380):
            input_video = gr.Video(label="Upload Video", interactive=True, height=220)
            image_prompt = gr.Gallery(
                label="Image Prompts (upload one or more views)",
                columns=3,
                rows=2,
                height=250,
                interactive=True,
                type="pil",
                file_types=["image"],
            )
            with gr.Accordion("Pipeline Settings", open=False):
                multi_image_strategy = gr.Radio(
                    choices=["average_right", "weighted_average", "sequential", "average", "adaptive_guidance_weight", "fixed_guidance_rescale"],
                    value="adaptive_guidance_weight",
                    label="Multi-image fusion strategy",
                )
                pipeline_type = gr.Radio(
                    choices=["512", "1024", "1024_cascade", "1536_cascade"],
                    value="1024_cascade",
                    label="Output Resolution",
                )
                ss_source = gr.Radio(
                    choices=["direct", "mesh", "mvtrellis2"],
                    value="mesh",
                    label="Stage 1 Coords Source",
                )
                seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
                randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
                decimation_target = gr.Slider(100000, 1000000, label="Decimation Target", value=500000, step=10000)
                texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
                num_per_second = gr.Slider(0.1, 30, label="Frames Per Second", value=1, step=0.1)
            generate_btn = gr.Button("Generate", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                gr.Markdown("**Stage 1 - Sparse Structure (ReconViaGen)**")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("**Stage 2 - Structured Latent (ReconViaGen)**")
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("**Stage 3 - Shape SLat (TRELLIS.2)**")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=8, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("**Stage 4 - Texture SLat (TRELLIS.2)**")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=8, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)

        # ── Right panel ───────────────────────────────────────────────────────
        with gr.Column(scale=10):
            preview_output = gr.HTML(empty_html, label="3D Preview", show_label=True)
            extract_btn = gr.Button("Extract GLB")
            glb_output = gr.Model3D(label="Extracted GLB", height=480, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB", interactive=False)

    # Holds the packed latents from image_to_3d for extract_glb.
    output_buf = gr.State()

    # Example row
    with gr.Row():
        examples_multi = gr.Examples(
            examples=prepare_multi_example(),
            inputs=[image_prompt],
            examples_per_page=8,
        )

    # ── Event handlers ────────────────────────────────────────────────────────
    demo.load(start_session)
    demo.unload(end_session)

    input_video.upload(preprocess_videos, inputs=[input_video, num_per_second], outputs=[image_prompt])
    input_video.clear(lambda: (None, None), outputs=[input_video, image_prompt])

    # Resolve the seed first so the value actually used is written back to the
    # slider; image_to_3d's inputs must match its positional signature.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            image_prompt, multi_image_strategy, seed, pipeline_type, ss_source,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            slat_guidance_strength, slat_guidance_rescale, slat_sampling_steps, slat_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
        ],
        outputs=[output_buf, preview_output],
    )

    extract_btn.click(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    ).then(
        lambda: gr.update(interactive=True),
        outputs=[download_btn],
    )


# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Load ReconViaGen pipeline (SS stage)
    print("[1/2] Loading ReconViaGen pipeline (SS stage) ...")
    vggt_pipeline = TrellisVGGTTo3DPipeline.from_pretrained("Stable-X/trellis-vggt-v0-2")
    vggt_pipeline.cuda()
    vggt_pipeline.VGGT_model.cuda()
    vggt_pipeline.birefnet_model.cuda()
    # NOTE(review): presumably this decoder is unused by the hybrid pipeline
    # and dropped to save memory — confirm. Guarded so a checkpoint without
    # it does not abort startup.
    if 'slat_decoder_gs' in vggt_pipeline.models:
        del vggt_pipeline.models['slat_decoder_gs']
    if LOW_VRAM:
        vggt_pipeline.VGGT_model.cpu()
        for model in vggt_pipeline.models.values():
            model.cpu()
    gc.collect()
    torch.cuda.empty_cache()

    # Load TRELLIS.2 pipeline (shape/tex slat + decode)
    print("[2/2] Loading TRELLIS.2 pipeline (shape/tex slat) ...")
    trellis2_pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
    trellis2_pipeline.cuda()
    if LOW_VRAM:
        trellis2_pipeline.low_vram = True
    gc.collect()
    torch.cuda.empty_cache()

    # Combine into hybrid pipeline (module-level global used by the handlers).
    pipeline = TrellisHybridPipeline(vggt_pipeline, trellis2_pipeline, low_vram=LOW_VRAM)

    # Load the HDRI environment map used for PBR preview rendering
    # (module-level global used by image_to_3d).
    _HDRI_DIR = os.path.join(_HERE, 'assets', 'hdri')
    _hdri_path = os.path.join(_HDRI_DIR, 'courtyard.exr')
    _hdri = cv2.imread(_hdri_path, cv2.IMREAD_UNCHANGED)
    if _hdri is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read HDRI environment map: {_hdri_path}")
    envmap = EnvMap(torch.tensor(
        cv2.cvtColor(_hdri, cv2.COLOR_BGR2RGB),
        dtype=torch.float32, device='cuda',
    ))

    demo.launch(css=css, head=head)