# Extend3D / app.py
# Author: Seungwoo-Yoon
# Commit 70bf12b: "zero gpu duration"
import spaces
from extend3d import Extend3D
from trellis.utils import render_utils, postprocessing_utils
import imageio
import random
import uuid
from pathlib import Path
import numpy as np
import torch
import gradio as gr
# Model identifier handed to Extend3D.from_pretrained (a Hugging Face Hub repo id).
MODEL_ID = "microsoft/TRELLIS-image-large"
# Directory (relative to the working directory) where run_extend3d writes the
# preview videos and GLB exports it produces.
DEFAULT_OUTPUT_DIR = "./output"
# ---------------------------------------------------------------------------
# Pipeline loading
# ---------------------------------------------------------------------------
# Loaded once at import time, moved to CUDA, and reused by every request that
# goes through run_extend3d.
# NOTE(review): module-level `.cuda()` is the pattern ZeroGPU Spaces expect (the
# GPU is only really attached inside @spaces.GPU calls) — confirm if this app is
# deployed somewhere other than a ZeroGPU Space.
PIPELINE: Extend3D = Extend3D.from_pretrained(MODEL_ID).cuda()
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _seed_everything(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch with *seed*."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)


def _render_preview_video(gaussian, mesh, video_path: str, progress) -> None:
    """Render a side-by-side (color | normal) preview video and save it to *video_path*.

    Left half of each frame: the "color" render of *gaussian*; right half: the
    "normal" render of *mesh*, both produced by ``render_utils.render_video``.
    """
    progress(0, desc="Rendering video...")
    color_frames = render_utils.render_video(gaussian, r=1.6, resolution=1024)["color"]
    progress(0.5, desc="Rendering video...")
    normal_frames = render_utils.render_video(mesh, r=1.6, resolution=1024)["normal"]
    progress(1.0, desc="Rendering video...")
    # Concatenate along axis 1 so the two views sit next to each other
    # (assumes frames are (H, W, C) arrays — TODO confirm against render_utils).
    video_frames = [
        np.concatenate([color, normal], axis=1)
        for color, normal in zip(color_frames, normal_frames)
    ]
    imageio.mimsave(video_path, video_frames, fps=30)


def _export_glb(gaussian, mesh, glb_path: str, progress) -> None:
    """Convert the Gaussian + mesh outputs to a GLB via ``to_glb`` and write it to *glb_path*."""
    progress(0, desc="Exporting GLB...")
    glb = postprocessing_utils.to_glb(gaussian, mesh, simplify=0.98, texture_size=1024)
    # Force a fully non-metallic material on the exported model.
    glb.visual.material.metallicFactor = 0.0
    glb.export(glb_path)


@spaces.GPU(duration=300)
def run_extend3d(
    image_pil,
    seed: int,
    randomize_seed: bool,
    width: int,
    length: int,
    div: int,
    ss_optim: bool,
    ss_iterations: int,
    ss_steps: int,
    ss_rescale_t: float,
    ss_t_noise: float,
    ss_t_start: float,
    ss_cfg_strength: float,
    ss_alpha: float,
    ss_batch_size: int,
    slat_optim: bool,
    slat_steps: int,
    slat_rescale_t: float,
    slat_cfg_strength: float,
    slat_batch_size: int,
    progress=gr.Progress(),
):
    """Run Extend3D on one image and export a preview video plus a GLB file.

    Decorated with ``@spaces.GPU(duration=300)``, i.e. it requests a GPU for up
    to 300 seconds per call on a ZeroGPU Space.

    Args:
        image_pil: Input image from ``gr.Image(type="pil")``; ``None`` when the
            user clicks Run without uploading anything.
        seed: Seed for ``random``, NumPy and torch; ignored when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, 2**31 - 1]`` instead.
        width, length, div: Forwarded to ``PIPELINE.run`` (presumably the scene
            extent and subdivision — see ``Extend3D.run``).
        ss_*: Sparse-structure stage settings, forwarded to ``PIPELINE.run``.
        slat_*: SLAT stage settings, forwarded to ``PIPELINE.run``.
        progress: Gradio progress tracker, injected by Gradio.

    Returns:
        ``(video_path, glb_path, seed)``: paths of the written MP4 and GLB files
        under ``DEFAULT_OUTPUT_DIR`` and the seed that was actually used.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image_pil is None:
        # Surface a readable message in the UI instead of a pipeline traceback.
        raise gr.Error("Please upload an input image.")

    if randomize_seed:
        seed = random.randint(0, 2147483647)
    # gr.Slider passes its value as a float; np.random.seed rejects floats, and
    # counts/sizes must be real ints for the pipeline, so cast them here.
    seed = int(seed)
    _seed_everything(seed)

    output = PIPELINE.run(
        image_pil,
        int(width), int(length), int(div),
        ss_optim, int(ss_iterations), int(ss_steps),
        ss_rescale_t, ss_t_noise, ss_t_start,
        ss_cfg_strength, ss_alpha, int(ss_batch_size),
        slat_optim, int(slat_steps), slat_rescale_t,
        slat_cfg_strength, int(slat_batch_size),
        progress_callback=lambda frac, desc: progress(frac, desc=desc),
    )
    # Only the first generated sample of each output kind is exported.
    gaussian = output["gaussian"][0]
    mesh = output["mesh"][0]

    out_dir = Path(DEFAULT_OUTPUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
    # Unique per-run file names so concurrent requests never overwrite each other.
    run_id = uuid.uuid4().hex

    video_path = str(out_dir / f"preview_{run_id}.mp4")
    _render_preview_video(gaussian, mesh, video_path, progress)

    glb_path = str(out_dir / f"preview_{run_id}.glb")
    _export_glb(gaussian, mesh, glb_path, progress)

    progress(1.0, desc="Done!")
    return video_path, glb_path, seed
# ---------------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------------
# Custom CSS for the page, passed as `css=` to gr.Blocks. Every rule targets the
# element with elem_id="examples_gallery" (the gr.Examples widget):
#   * each example tile is pinned to 160x160 px;
#   * example images fill their tile, cropped by `object-fit: cover`;
#   * the gallery container spans the full width with its items centered.
# NOTE(review): `object-fit` has no effect on non-replaced elements such as the
# `.gallery` container; it only matters in the `img` rule. Likely a harmless leftover.
css = """
#examples_gallery .gallery-item {
width: 160px !important;
height: 160px !important;
min-width: 160px !important;
}
#examples_gallery img {
width: 100% !important;
height: 100% !important;
object-fit: cover;
}
#examples_gallery .gallery {
width: 100% !important;
height: 100% !important;
object-fit: cover;
max-width: none !important;
justify-content: center;
}
"""
# Build the Gradio UI. The layout is: inputs and sampler settings on the left,
# rendered outputs on the right, and an examples gallery underneath. The Run
# button is wired to run_extend3d.
with gr.Blocks(title="Extend3D Demo", css=css) as demo:
    gr.Markdown("# Extend3D: Town-scale 3D Generation")
    gr.Markdown("[Project Page](https://seungwoo-yoon.github.io/extend3d-page/) | [Code](https://github.com/Seungwoo-Yoon/Extend3D) | [Paper](#)")
    with gr.Row():
        # Left column: inputs and settings
        with gr.Column(scale=4, min_width=420):
            gr.Markdown("### Input")
            # type="pil" hands run_extend3d a PIL image (or None if nothing was uploaded).
            image_in = gr.Image(label="Input Image", type="pil")
            run_btn = gr.Button("Run", variant="primary")
            with gr.Accordion("Settings", open=True):
                # Upper bound 2147483647 (2**31 - 1) matches the range run_extend3d
                # draws from when randomize_seed is checked.
                seed = gr.Slider(0, 2147483647, value=42, step=1, label="seed")
                randomize_seed = gr.Checkbox(value=True, label="randomize_seed")
                with gr.Row():
                    width = gr.Slider(1, 8, value=2, step=1, label="width")
                    length = gr.Slider(1, 8, value=2, step=1, label="length")
                    div = gr.Slider(1, 8, value=4, step=1, label="div")
                # Controls for the sparse-structure ("ss_*") arguments of PIPELINE.run.
                with gr.Accordion("Sparse Structure Settings", open=False):
                    ss_optim = gr.Checkbox(value=True, label="optimize")
                    with gr.Row():
                        ss_iterations = gr.Slider(1, 10, value=3, step=1, label="iterations")
                        ss_steps = gr.Slider(1, 100, value=25, step=1, label="steps")
                    with gr.Row():
                        ss_rescale_t = gr.Slider(1, 10, value=3.0, step=0.1, label="rescale_t")
                        ss_cfg_strength = gr.Slider(1, 10, value=7.5, step=0.1, label="cfg_strength")
                    with gr.Row():
                        ss_t_noise = gr.Slider(0, 1, value=0.6, step=0.1, label="t_noise")
                        ss_t_start = gr.Slider(0, 1, value=0.8, step=0.1, label="t_start")
                    ss_alpha = gr.Slider(1, 10, value=5.0, step=0.1, label="alpha")
                    ss_batch_size = gr.Slider(1, 16, value=1, step=1, label="batch_size")
                # Controls for the SLAT ("slat_*") arguments of PIPELINE.run.
                with gr.Accordion("SLAT Settings", open=False):
                    slat_optim = gr.Checkbox(value=True, label="optimize")
                    with gr.Row():
                        slat_steps = gr.Slider(1, 100, value=25, step=1, label="steps")
                    with gr.Row():
                        slat_rescale_t = gr.Slider(1, 10, value=3.0, step=0.1, label="rescale_t")
                        slat_cfg_strength = gr.Slider(1, 10, value=3.0, step=0.1, label="cfg_strength")
                    slat_batch_size = gr.Slider(1, 16, value=1, step=1, label="batch_size")
        # Right column: outputs
        with gr.Column(scale=5, min_width=420):
            gr.Markdown("### Output")
            preview_video = gr.Video(label="3D Preview (Video)", value=None, autoplay=True, loop=True)
            preview_glb = gr.Model3D(label="3D Preview (GLB)", value=None)
    # No `fn` is given, so clicking an example only fills image_in; the user
    # still has to press Run. elem_id hooks this widget into the custom CSS.
    gr.Examples(
        examples=[
            "assets/examples/0.png",
            "assets/examples/1.png",
            "assets/examples/2.png",
            "assets/examples/3.png",
            "assets/examples/4.png",
            "assets/examples/5.webp",
        ],
        inputs=[image_in],
        label="Examples",
        examples_per_page=6,
        elem_id="examples_gallery",
    )
    # The order of `inputs` must match run_extend3d's positional parameters
    # (image first, then seed settings, scene size, ss_* and slat_*), independent
    # of how the controls are laid out above.
    # run_extend3d returns (video_path, glb_path, seed). Routing the seed back
    # into the slider shows the value that was actually used, including a
    # randomized one.
    run_btn.click(
        fn=run_extend3d,
        inputs=[
            image_in,
            seed, randomize_seed,
            width, length, div,
            ss_optim, ss_iterations, ss_steps, ss_rescale_t, ss_t_noise, ss_t_start,
            ss_cfg_strength, ss_alpha, ss_batch_size,
            slat_optim, slat_steps, slat_rescale_t, slat_cfg_strength, slat_batch_size,
        ],
        outputs=[preview_video, preview_glb, seed],
    )

if __name__ == "__main__":
    demo.launch()