Spaces:
Running on Zero
Running on Zero
Restore logo layout + fix GLB export with DownloadButton
Browse files
app.py
CHANGED
|
@@ -1,17 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from gradio_client import Client, handle_file
|
| 3 |
import spaces
|
| 4 |
from concurrent.futures import ThreadPoolExecutor
|
| 5 |
|
| 6 |
import os
|
| 7 |
-
import sys
|
| 8 |
-
|
| 9 |
-
_script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 10 |
-
|
| 11 |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
|
| 12 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 13 |
os.environ["ATTN_BACKEND"] = "flash_attn_3"
|
| 14 |
-
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
|
| 15 |
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
|
| 16 |
from datetime import datetime
|
| 17 |
import shutil
|
|
@@ -31,7 +32,8 @@ import o_voxel
|
|
| 31 |
|
| 32 |
# Patch postprocess module with local fix for cumesh.fill_holes() bug
|
| 33 |
import importlib.util
|
| 34 |
-
|
|
|
|
| 35 |
if os.path.exists(_local_postprocess):
|
| 36 |
_spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
|
| 37 |
_mod = importlib.util.module_from_spec(_spec)
|
|
@@ -326,7 +328,8 @@ def start_session(req: gr.Request):
|
|
| 326 |
|
| 327 |
def end_session(req: gr.Request):
|
| 328 |
user_dir = os.path.join(TMP_DIR, str(req.session_hash))
|
| 329 |
-
|
|
|
|
| 330 |
|
| 331 |
|
| 332 |
def remove_background(input: Image.Image) -> Image.Image:
|
|
@@ -372,9 +375,10 @@ def preprocess_image(input: Image.Image) -> Image.Image:
|
|
| 372 |
size = int(size * 1)
|
| 373 |
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
|
| 374 |
output = output.crop(bbox) # type: ignore
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
|
|
|
| 378 |
return output
|
| 379 |
|
| 380 |
|
|
@@ -431,34 +435,40 @@ def prepare_multi_example() -> List[str]:
|
|
| 431 |
|
| 432 |
def load_multi_example(image) -> List[Image.Image]:
|
| 433 |
"""Load all views for a multi-image case by matching the input image."""
|
| 434 |
-
|
|
|
|
| 435 |
|
| 436 |
-
# Convert
|
| 437 |
if isinstance(image, np.ndarray):
|
| 438 |
image = Image.fromarray(image)
|
| 439 |
|
| 440 |
-
#
|
| 441 |
-
|
| 442 |
|
| 443 |
# Find matching case by comparing with first images
|
| 444 |
-
|
| 445 |
-
for
|
| 446 |
-
|
|
|
|
|
|
|
| 447 |
if os.path.exists(first_img_path):
|
| 448 |
-
first_img = Image.open(first_img_path).convert('
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
|
|
|
|
|
|
| 452 |
images = []
|
| 453 |
for i in range(1, 7):
|
| 454 |
-
img_path = f'
|
| 455 |
if os.path.exists(img_path):
|
| 456 |
-
img = Image.open(img_path)
|
| 457 |
-
images.append(
|
| 458 |
-
|
|
|
|
| 459 |
|
| 460 |
-
# No match found, return the single image
|
| 461 |
-
return [
|
| 462 |
|
| 463 |
|
| 464 |
def split_image(image: Image.Image) -> List[Image.Image]:
|
|
@@ -476,7 +486,7 @@ def split_image(image: Image.Image) -> List[Image.Image]:
|
|
| 476 |
return [preprocess_image(image) for image in images]
|
| 477 |
|
| 478 |
|
| 479 |
-
@spaces.GPU(duration=
|
| 480 |
def image_to_3d(
|
| 481 |
seed: int,
|
| 482 |
resolution: str,
|
|
@@ -493,14 +503,29 @@ def image_to_3d(
|
|
| 493 |
tex_slat_sampling_steps: int,
|
| 494 |
tex_slat_rescale_t: float,
|
| 495 |
multiimages: List[Tuple[Image.Image, str]],
|
| 496 |
-
multiimage_algo: Literal["
|
| 497 |
-
tex_multiimage_algo: Literal["
|
| 498 |
req: gr.Request,
|
| 499 |
progress=gr.Progress(track_tqdm=True),
|
| 500 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
# --- Sampling ---
|
| 502 |
outputs, latents = pipeline.run_multi_image(
|
| 503 |
-
|
| 504 |
seed=seed,
|
| 505 |
preprocess_image=False,
|
| 506 |
sparse_structure_sampler_params={
|
|
@@ -533,8 +558,16 @@ def image_to_3d(
|
|
| 533 |
mesh = outputs[0]
|
| 534 |
mesh.simplify(16777216) # nvdiffrast limit
|
| 535 |
images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 536 |
state = pack_state(latents)
|
| 537 |
-
del outputs, mesh, latents # Free memory
|
| 538 |
torch.cuda.empty_cache()
|
| 539 |
|
| 540 |
# --- HTML Construction ---
|
|
@@ -657,14 +690,25 @@ def extract_glb(
|
|
| 657 |
glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
|
| 658 |
glb.export(glb_path, extension_webp=True)
|
| 659 |
torch.cuda.empty_cache()
|
| 660 |
-
return glb_path, glb_path
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
with gr.Blocks(
|
| 664 |
-
gr.
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
""")
|
| 669 |
|
| 670 |
with gr.Row():
|
|
@@ -677,10 +721,6 @@ with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange
|
|
| 677 |
decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
|
| 678 |
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
|
| 679 |
|
| 680 |
-
with gr.Row():
|
| 681 |
-
generate_btn = gr.Button("Generate", variant="primary")
|
| 682 |
-
extract_btn = gr.Button("Extract GLB")
|
| 683 |
-
|
| 684 |
with gr.Accordion(label="Advanced Settings", open=False):
|
| 685 |
gr.Markdown("Stage 1: Sparse Structure Generation")
|
| 686 |
with gr.Row():
|
|
@@ -701,12 +741,16 @@ with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange
|
|
| 701 |
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
| 702 |
tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
|
| 703 |
multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
|
| 704 |
-
tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="
|
| 705 |
|
| 706 |
with gr.Column(scale=10):
|
| 707 |
preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
|
| 708 |
-
glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
|
| 709 |
-
download_btn = gr.DownloadButton(label="Download GLB")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 710 |
|
| 711 |
example_image = gr.Image(visible=False) # Hidden component for examples
|
| 712 |
examples_multi = gr.Examples(
|
|
@@ -715,7 +759,8 @@ with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange
|
|
| 715 |
fn=load_multi_example,
|
| 716 |
outputs=[multiimage_prompt],
|
| 717 |
run_on_click=True,
|
| 718 |
-
|
|
|
|
| 719 |
)
|
| 720 |
|
| 721 |
output_buf = gr.State()
|
|
@@ -766,7 +811,7 @@ if __name__ == "__main__":
|
|
| 766 |
rmbg_client = Client("briaai/BRIA-RMBG-2.0")
|
| 767 |
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
|
| 768 |
pipeline.rembg_model = None
|
| 769 |
-
pipeline.low_vram =
|
| 770 |
pipeline.cuda()
|
| 771 |
|
| 772 |
envmap = {
|
|
@@ -784,4 +829,4 @@ if __name__ == "__main__":
|
|
| 784 |
)),
|
| 785 |
}
|
| 786 |
|
| 787 |
-
demo.
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*")
|
| 3 |
+
warnings.filterwarnings("ignore", message=".*torch.cuda.amp.autocast.*")
|
| 4 |
+
warnings.filterwarnings("ignore", message=".*Default grid_sample and affine_grid behavior.*")
|
| 5 |
+
|
| 6 |
import gradio as gr
|
| 7 |
from gradio_client import Client, handle_file
|
| 8 |
import spaces
|
| 9 |
from concurrent.futures import ThreadPoolExecutor
|
| 10 |
|
| 11 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
|
| 13 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 14 |
os.environ["ATTN_BACKEND"] = "flash_attn_3"
|
| 15 |
+
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
|
| 16 |
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
|
| 17 |
from datetime import datetime
|
| 18 |
import shutil
|
|
|
|
| 32 |
|
| 33 |
# Patch postprocess module with local fix for cumesh.fill_holes() bug
|
| 34 |
import importlib.util
|
| 35 |
+
import sys
|
| 36 |
+
_local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
|
| 37 |
if os.path.exists(_local_postprocess):
|
| 38 |
_spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
|
| 39 |
_mod = importlib.util.module_from_spec(_spec)
|
|
|
|
| 328 |
|
| 329 |
def end_session(req: gr.Request):
    """Delete the scratch directory that belongs to the given session.

    The directory is ``TMP_DIR/<session_hash>``, where ``session_hash`` is
    read from the incoming Gradio request.

    Args:
        req: The Gradio request whose ``session_hash`` names the directory.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    try:
        shutil.rmtree(user_dir)
    except FileNotFoundError:
        # The directory was never created or has already been removed.
        # Attempting the delete and tolerating absence avoids the race
        # between a separate existence check and the removal itself.
        pass
|
| 333 |
|
| 334 |
|
| 335 |
def remove_background(input: Image.Image) -> Image.Image:
|
|
|
|
| 375 |
size = int(size * 1)
|
| 376 |
bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
|
| 377 |
output = output.crop(bbox) # type: ignore
|
| 378 |
+
output_np = np.array(output)
|
| 379 |
+
alpha = output_np[:, :, 3]
|
| 380 |
+
output_np[:, :, :3][alpha < 0.5 * 255] = [0, 0, 0]
|
| 381 |
+
output = Image.fromarray(output_np[:, :, :3])
|
| 382 |
return output
|
| 383 |
|
| 384 |
|
|
|
|
| 435 |
|
| 436 |
def load_multi_example(image) -> List[Image.Image]:
    """Load all views for a multi-image case by matching the input image.

    The input is compared pixel-for-pixel (in RGB) against the first view
    (``<case>_1.png``) of every case found in ``assets/example_multi_image``.
    On a match, views ``<case>_1.png`` through ``<case>_6.png`` that exist on
    disk are loaded as RGBA and returned without any preprocessing.

    Args:
        image: A PIL image, a NumPy array (converted with
            ``Image.fromarray``), or ``None``.

    Returns:
        An empty list when ``image`` is ``None``; the list of RGBA views of
        the matching case when one is found; otherwise a one-element list
        holding the input image in RGBA mode.
    """
    if image is None:
        return []

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Convert to RGB for consistent comparison
    input_rgb = np.array(image.convert('RGB'))

    example_dir = "assets/example_multi_image"
    if os.path.isdir(example_dir):
        # File names look like "<case>_<index>.png"; strip the trailing index
        # to obtain the distinct case names.
        case_names = sorted({
            f.rsplit('_', 1)[0]
            for f in os.listdir(example_dir)
            if f.endswith('.png')
        })
    else:
        # Without an example directory there is nothing to match against;
        # fall through to the single-image result instead of raising.
        case_names = []

    # Find matching case by comparing with first images
    for case_name in case_names:
        first_img_path = f'{example_dir}/{case_name}_1.png'
        if not os.path.exists(first_img_path):
            continue

        first_rgb = np.array(Image.open(first_img_path).convert('RGB'))

        # Compare images (check if same shape and content)
        if input_rgb.shape != first_rgb.shape or not np.array_equal(input_rgb, first_rgb):
            continue

        # Found match, load all views (without preprocessing - will be done on Generate)
        images = []
        for i in range(1, 7):
            img_path = f'{example_dir}/{case_name}_{i}.png'
            if os.path.exists(img_path):
                images.append(Image.open(img_path).convert('RGBA'))
        if images:
            return images

    # No match found, return the single image
    return [image.convert('RGBA') if image.mode != 'RGBA' else image]
|
| 472 |
|
| 473 |
|
| 474 |
def split_image(image: Image.Image) -> List[Image.Image]:
|
|
|
|
| 486 |
return [preprocess_image(image) for image in images]
|
| 487 |
|
| 488 |
|
| 489 |
+
@spaces.GPU(duration=120)
|
| 490 |
def image_to_3d(
|
| 491 |
seed: int,
|
| 492 |
resolution: str,
|
|
|
|
| 503 |
tex_slat_sampling_steps: int,
|
| 504 |
tex_slat_rescale_t: float,
|
| 505 |
multiimages: List[Tuple[Image.Image, str]],
|
| 506 |
+
multiimage_algo: Literal["multidiffusion", "stochastic"],
|
| 507 |
+
tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
|
| 508 |
req: gr.Request,
|
| 509 |
progress=gr.Progress(track_tqdm=True),
|
| 510 |
) -> str:
|
| 511 |
+
if not multiimages:
|
| 512 |
+
raise gr.Error("Please upload images or select an example first.")
|
| 513 |
+
|
| 514 |
+
# Preprocess images (background removal, cropping, etc.)
|
| 515 |
+
images = [image[0] for image in multiimages]
|
| 516 |
+
processed_images = [preprocess_image(img) for img in images]
|
| 517 |
+
|
| 518 |
+
# Debug: save preprocessed images and log stats
|
| 519 |
+
for i, img in enumerate(processed_images):
|
| 520 |
+
arr = np.array(img)
|
| 521 |
+
print(f"[DEBUG] Preprocessed image {i}: mode={img.mode}, size={img.size}, "
|
| 522 |
+
f"dtype={arr.dtype}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
|
| 523 |
+
img.save(os.path.join(TMP_DIR, f'debug_preprocessed_{i}.png'))
|
| 524 |
+
print(f"[DEBUG] Pipeline params: mode={multiimage_algo}, tex_mode={tex_multiimage_algo}")
|
| 525 |
+
|
| 526 |
# --- Sampling ---
|
| 527 |
outputs, latents = pipeline.run_multi_image(
|
| 528 |
+
processed_images,
|
| 529 |
seed=seed,
|
| 530 |
preprocess_image=False,
|
| 531 |
sparse_structure_sampler_params={
|
|
|
|
| 558 |
mesh = outputs[0]
|
| 559 |
mesh.simplify(16777216) # nvdiffrast limit
|
| 560 |
images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
|
| 561 |
+
|
| 562 |
+
# Debug: save base_color render and log stats for all render modes
|
| 563 |
+
for key in images:
|
| 564 |
+
arr = images[key][0] # first view
|
| 565 |
+
print(f"[DEBUG] Render '{key}': shape={arr.shape}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
|
| 566 |
+
# Save base_color and shaded_forest for inspection
|
| 567 |
+
Image.fromarray(images['base_color'][0]).save(os.path.join(TMP_DIR, 'debug_base_color.png'))
|
| 568 |
+
Image.fromarray(images['shaded_forest'][0]).save(os.path.join(TMP_DIR, 'debug_shaded_forest.png'))
|
| 569 |
+
|
| 570 |
state = pack_state(latents)
|
|
|
|
| 571 |
torch.cuda.empty_cache()
|
| 572 |
|
| 573 |
# --- HTML Construction ---
|
|
|
|
| 690 |
glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
|
| 691 |
glb.export(glb_path, extension_webp=True)
|
| 692 |
torch.cuda.empty_cache()
|
| 693 |
+
return gr.update(value=glb_path, visible=True), gr.update(value=glb_path, visible=True)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
|
| 697 |
+
gr.HTML("""
|
| 698 |
+
<div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;">
|
| 699 |
+
<a href="https://www.opsiclear.com" target="_blank">
|
| 700 |
+
<img src="https://www.opsiclear.com/assets/logos/Logo_v2_compact_name.svg" alt="OpsiClear" style="height: 80px;">
|
| 701 |
+
</a>
|
| 702 |
+
<div>
|
| 703 |
+
<h2 style="margin: 0;">Multi-View to 3D with <a href="https://microsoft.github.io/TRELLIS.2" target="_blank">TRELLIS.2</a></h2>
|
| 704 |
+
<ul style="margin: 5px 0; padding-left: 20px;">
|
| 705 |
+
<li>Upload multiple images from different viewpoints to create a 3D asset with multi-image conditioning.</li>
|
| 706 |
+
<li>Click an example below to load a pre-made multi-view set, or upload your own images.</li>
|
| 707 |
+
<li>Click <b>Generate</b> to create the 3D model, then <b>Extract GLB</b> to export.</li>
|
| 708 |
+
<li style="color: #e67300;"><b>⚠️ Note:</b> Generation quality is highly sensitive to parameters. Adjust settings in Advanced Settings if results are unsatisfactory.</li>
|
| 709 |
+
</ul>
|
| 710 |
+
</div>
|
| 711 |
+
</div>
|
| 712 |
""")
|
| 713 |
|
| 714 |
with gr.Row():
|
|
|
|
| 721 |
decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
|
| 722 |
texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
|
| 723 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
with gr.Accordion(label="Advanced Settings", open=False):
|
| 725 |
gr.Markdown("Stage 1: Sparse Structure Generation")
|
| 726 |
with gr.Row():
|
|
|
|
| 741 |
tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
|
| 742 |
tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
|
| 743 |
multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
|
| 744 |
+
tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="stochastic")
|
| 745 |
|
| 746 |
with gr.Column(scale=10):
|
| 747 |
preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
|
| 748 |
+
glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), visible=False)
|
| 749 |
+
download_btn = gr.DownloadButton(label="Download GLB", visible=False)
|
| 750 |
+
|
| 751 |
+
with gr.Row():
|
| 752 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 753 |
+
extract_btn = gr.Button("Extract GLB")
|
| 754 |
|
| 755 |
example_image = gr.Image(visible=False) # Hidden component for examples
|
| 756 |
examples_multi = gr.Examples(
|
|
|
|
| 759 |
fn=load_multi_example,
|
| 760 |
outputs=[multiimage_prompt],
|
| 761 |
run_on_click=True,
|
| 762 |
+
cache_examples=False,
|
| 763 |
+
examples_per_page=50,
|
| 764 |
)
|
| 765 |
|
| 766 |
output_buf = gr.State()
|
|
|
|
| 811 |
rmbg_client = Client("briaai/BRIA-RMBG-2.0")
|
| 812 |
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
|
| 813 |
pipeline.rembg_model = None
|
| 814 |
+
pipeline.low_vram = False
|
| 815 |
pipeline.cuda()
|
| 816 |
|
| 817 |
envmap = {
|
|
|
|
| 829 |
)),
|
| 830 |
}
|
| 831 |
|
| 832 |
+
demo.launch(css=css, head=head)
|