# Uploaded by: opsiclear-admin
# Commit: Restore premultiplied alpha preprocessing (matches training data)
# Revision: 458c383 (verified)
import warnings
warnings.filterwarnings("ignore", message=".*torch.distributed.reduce_op.*")
warnings.filterwarnings("ignore", message=".*torch.cuda.amp.autocast.*")
warnings.filterwarnings("ignore", message=".*Default grid_sample and affine_grid behavior.*")
import gradio as gr
from gradio_client import Client, handle_file
import spaces
from concurrent.futures import ThreadPoolExecutor
import os
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["ATTN_BACKEND"] = "flash_attn_3"
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
from datetime import datetime
import shutil
import cv2
from typing import *
import torch
import numpy as np
from PIL import Image
import base64
import io
import tempfile
from trellis2.modules.sparse import SparseTensor
from trellis2.pipelines import Trellis2ImageTo3DPipeline
from trellis2.renderers import EnvMap
from trellis2.utils import render_utils
import o_voxel
# Patch postprocess module with local fix for cumesh.fill_holes() bug
import importlib.util
import sys
# Location of the vendored replacement module, resolved relative to this file.
_local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
# Only patch when the vendored copy is present; otherwise the installed
# o_voxel.postprocess is left untouched.
if os.path.exists(_local_postprocess):
    # Load the local file under the canonical module name.
    _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
    _mod = importlib.util.module_from_spec(_spec)
    _spec.loader.exec_module(_mod)
    # Rebind both the package attribute and the sys.modules entry so that
    # attribute access (o_voxel.postprocess.to_glb) and any later
    # `import o_voxel.postprocess` resolve to the patched module.
    o_voxel.postprocess = _mod
    sys.modules['o_voxel.postprocess'] = _mod
# Largest 32-bit signed integer; upper bound of the seed slider and of the
# random seed drawn in get_seed().
MAX_SEED = np.iinfo(np.int32).max
# Scratch directory next to this file; per-session sub-directories are created
# and removed by start_session() / end_session().
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# Render modes offered by the previewer. For each entry:
#   name       - tooltip text of the mode button
#   icon       - path of the button image (base64-encoded at startup)
#   render_key - key into the dict returned by render_utils.render_snapshot()
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
# Number of views rendered per mode (passed as nviews to render_snapshot and
# used as the slider range in the previewer).
STEPS = 8
# Index into MODES of the mode shown first ("HDRI forest").
DEFAULT_MODE = 3
# View index shown first (initial slider value).
DEFAULT_STEP = 3
# Custom stylesheet handed to demo.launch(css=...). It overrides a few Gradio
# default styles and styles the HTML previewer built in image_to_3d():
# the tips overlay, the mode-button row, the image display row and the
# range slider.
css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper {
padding: 0;
}
.stepper-container {
padding: 0;
align-items: center;
}
.step-button {
flex-direction: row;
}
.step-connector {
transform: none;
}
.step-number {
width: 16px;
height: 16px;
}
.step-label {
position: relative;
bottom: 0;
}
.wrap.center.full {
inset: 0;
height: 100%;
}
.wrap.center.full.translucent {
background: var(--block-background-fill);
}
.meta-text-center {
display: block !important;
position: absolute !important;
top: unset !important;
bottom: 0 !important;
right: 0 !important;
transform: unset !important;
}
/* Previewer */
.previewer-container {
position: relative;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
width: 100%;
height: 722px;
margin: 0 auto;
padding: 20px;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
}
.previewer-container .tips-icon {
position: absolute;
right: 10px;
top: 10px;
z-index: 10;
border-radius: 10px;
color: #fff;
background-color: var(--color-accent);
padding: 3px 6px;
user-select: none;
}
.previewer-container .tips-text {
position: absolute;
right: 10px;
top: 50px;
color: #fff;
background-color: var(--color-accent);
border-radius: 10px;
padding: 6px;
text-align: left;
max-width: 300px;
z-index: 10;
transition: all 0.3s;
opacity: 0%;
user-select: none;
}
.previewer-container .tips-text p {
font-size: 14px;
line-height: 1.2;
}
.tips-icon:hover + .tips-text {
display: block;
opacity: 100%;
}
/* Row 1: Display Modes */
.previewer-container .mode-row {
width: 100%;
display: flex;
gap: 8px;
justify-content: center;
margin-bottom: 20px;
flex-wrap: wrap;
}
.previewer-container .mode-btn {
width: 24px;
height: 24px;
border-radius: 50%;
cursor: pointer;
opacity: 0.5;
transition: all 0.2s;
border: 2px solid var(--neutral-600, #555);
object-fit: cover;
}
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active {
opacity: 1;
border-color: var(--color-accent);
transform: scale(1.1);
}
/* Row 2: Display Image */
.previewer-container .display-row {
margin-bottom: 20px;
min-height: 400px;
width: 100%;
flex-grow: 1;
display: flex;
justify-content: center;
align-items: center;
}
.previewer-container .previewer-main-image {
max-width: 100%;
max-height: 100%;
flex-grow: 1;
object-fit: contain;
display: none;
}
.previewer-container .previewer-main-image.visible {
display: block;
}
/* Row 3: Custom HTML Slider */
.previewer-container .slider-row {
width: 100%;
display: flex;
flex-direction: column;
align-items: center;
gap: 10px;
padding: 0 10px;
}
.previewer-container input[type=range] {
-webkit-appearance: none;
width: 100%;
max-width: 400px;
background: transparent;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
width: 100%;
height: 8px;
cursor: pointer;
background: var(--neutral-700, #404040);
border-radius: 5px;
}
.previewer-container input[type=range]::-webkit-slider-thumb {
height: 20px;
width: 20px;
border-radius: 50%;
background: var(--color-accent);
cursor: pointer;
-webkit-appearance: none;
margin-top: -6px;
box-shadow: 0 2px 5px rgba(0,0,0,0.2);
transition: transform 0.1s;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
transform: scale(1.2);
}
/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) {
padding: 0 !important;
}
.gradio-container:has(.previewer-container) [data-testid="block-label"] {
position: absolute;
top: 0;
left: 0;
}
"""
# Extra <head> markup handed to demo.launch(head=...). It defines the
# JavaScript helpers that the previewer HTML built in image_to_3d() calls from
# its inline handlers: selectMode() from the mode buttons' onclick and
# onSliderChange() from the slider's oninput. Both delegate to refreshView(),
# which toggles the 'visible' class on the <img id="view-m{mode}-s{step}">
# elements and the 'active' class on the mode buttons.
head = """
<script>
function refreshView(mode, step) {
// 1. Find current mode and step
const allImgs = document.querySelectorAll('.previewer-main-image');
for (let i = 0; i < allImgs.length; i++) {
const img = allImgs[i];
if (img.classList.contains('visible')) {
const id = img.id;
const [_, m, s] = id.split('-');
if (mode === -1) mode = parseInt(m.slice(1));
if (step === -1) step = parseInt(s.slice(1));
break;
}
}
// 2. Hide ALL images
// We select all elements with class 'previewer-main-image'
allImgs.forEach(img => img.classList.remove('visible'));
// 3. Construct the specific ID for the current state
// Format: view-m{mode}-s{step}
const targetId = 'view-m' + mode + '-s' + step;
const targetImg = document.getElementById(targetId);
// 4. Show ONLY the target
if (targetImg) {
targetImg.classList.add('visible');
}
// 5. Update Button Highlights
const allBtns = document.querySelectorAll('.mode-btn');
allBtns.forEach((btn, idx) => {
if (idx === mode) btn.classList.add('active');
else btn.classList.remove('active');
});
}
// --- Action: Switch Mode ---
function selectMode(mode) {
refreshView(mode, -1);
}
// --- Action: Slider Change ---
function onSliderChange(val) {
refreshView(-1, parseInt(val));
}
</script>
"""
# Placeholder shown in the preview panel before anything has been generated:
# an empty previewer container holding a dimmed "image" icon.
# A plain string is used because the markup contains no placeholders.
empty_html = """
<div class="previewer-container">
<svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
def image_to_base64(image):
    """Encode *image* as a JPEG data URI.

    The image is converted to RGB, serialised as JPEG (quality 85) into an
    in-memory buffer and returned as a ``data:image/jpeg;base64,...`` string
    suitable for an ``<img src>`` attribute.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as stream:
        rgb_image.save(stream, format="jpeg", quality=85)
        jpeg_bytes = stream.getvalue()
    encoded = base64.b64encode(jpeg_bytes).decode()
    return "data:image/jpeg;base64," + encoded
def start_session(req: gr.Request):
    """Create the scratch directory for the session identified by *req*.

    The directory lives under TMP_DIR and is named after the request's
    session hash; it is a no-op if the directory already exists.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
def end_session(req: gr.Request):
    """Delete the scratch directory of the session identified by *req*.

    Does nothing when the directory is not present.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not os.path.exists(session_dir):
        return
    shutil.rmtree(session_dir)
def remove_background(input: Image.Image) -> Image.Image:
    """Run the remote background-removal service on *input*.

    The image is converted to RGB, written to a temporary PNG file and sent
    to ``rmbg_client`` via its ``/image`` endpoint. The first element of the
    first item of the response is opened as an image and returned.
    """
    with tempfile.NamedTemporaryFile(suffix='.png') as tmp_file:
        rgb_input = input.convert('RGB')
        rgb_input.save(tmp_file.name)
        response = rmbg_client.predict(handle_file(tmp_file.name), api_name="/image")
        # The service response is nested; [0][0] is what Image.open() consumes.
        result = Image.open(response[0][0])
        return result
def preprocess_image(input: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Steps:
      1. Detect whether the image already carries a meaningful alpha channel
         (RGBA with at least one non-opaque pixel).
      2. Downscale so that the longer side is at most 1024 pixels.
      3. Use the existing alpha as the foreground mask, or obtain one from
         remove_background().
      4. Crop to a square centred on the bounding box of pixels whose alpha
         exceeds 80%.
      5. Premultiply RGB by alpha and return the result as an RGB image.

    If no pixel passes the alpha threshold, the image is returned pasted
    (as RGB) in the centre of a black square instead.

    Args:
        input: The image to preprocess.

    Returns:
        The preprocessed RGB image.
    """
    # if has alpha channel, use it directly; otherwise, remove background
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        has_alpha = not np.all(alpha == 255)
    longest_side = max(input.size)
    scale = min(1, 1024 / longest_side)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
    output = input if has_alpha else remove_background(input)
    # The background-removal result is indexed as a 4-channel array below;
    # normalise the mode so a non-RGBA result does not raise IndexError.
    if output.mode != 'RGBA':
        output = output.convert('RGBA')
    alpha = np.array(output)[:, :, 3]
    bbox = np.argwhere(alpha > 0.8 * 255)
    if bbox.size == 0:
        # No visible pixels, center the image in a square
        size = max(output.size)
        square = Image.new('RGB', (size, size), (0, 0, 0))
        square.paste(output.convert('RGB'), ((size - output.width) // 2, (size - output.height) // 2))
        return square
    # argwhere yields (row, col) pairs; reorder to (left, top, right, bottom).
    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = int(max(bbox[2] - bbox[0], bbox[3] - bbox[1]))
    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
    output = output.crop(bbox)  # type: ignore
    # Premultiply colour by alpha so the background becomes black.
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images for multi-image conditioning.
    Uses parallel processing for faster background removal.

    Args:
        images: Gallery items; the image is the first element of each item.
            May be empty or None (e.g. a cleared gallery).

    Returns:
        The preprocessed images in input order; an empty list when there is
        nothing to process.
    """
    # ThreadPoolExecutor rejects max_workers=0, so bail out before building it.
    if not images:
        return []
    pil_images = [item[0] for item in images]
    with ThreadPoolExecutor(max_workers=min(4, len(pil_images))) as executor:
        return list(executor.map(preprocess_image, pil_images))
def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
    """Convert the sampled latents into a plain dict of NumPy arrays.

    Args:
        latents: ``(shape_slat, tex_slat, res)`` as returned by the pipeline.

    Returns:
        A dict with the shape/texture features, the coordinates of the shape
        latent, and the resolution, all moved to host memory.
    """
    shape_slat, tex_slat, res = latents

    def _to_numpy(tensor):
        return tensor.cpu().numpy()

    packed = {}
    packed['shape_slat_feats'] = _to_numpy(shape_slat.feats)
    packed['tex_slat_feats'] = _to_numpy(tex_slat.feats)
    packed['coords'] = _to_numpy(shape_slat.coords)
    packed['res'] = res
    return packed
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    """Rebuild the latents on the GPU from a dict produced by pack_state().

    Returns:
        ``(shape_slat, tex_slat, res)``; the texture latent is derived from the
        shape latent via ``replace`` with the stored texture features.
    """
    def _on_gpu(key):
        return torch.from_numpy(state[key]).cuda()

    shape_feats = _on_gpu('shape_slat_feats')
    coords = _on_gpu('coords')
    shape_slat = SparseTensor(feats=shape_feats, coords=coords)
    tex_slat = shape_slat.replace(_on_gpu('tex_slat_feats'))
    return shape_slat, tex_slat, state['res']
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed for a generation run.

    Returns a fresh random value in ``[0, MAX_SEED)`` when *randomize_seed*
    is set, otherwise *seed* unchanged.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
def prepare_multi_example() -> List[str]:
    """
    Prepare multi-image examples. Returns list of image paths.
    Shows only the first view as representative thumbnail.

    Case names are derived exactly as in load_multi_example(): only ``.png``
    files are considered and the name is everything before the last
    underscore, so case names that themselves contain underscores are kept.
    """
    example_dir = "assets/example_multi_image"
    case_names = sorted({
        name.rsplit('_', 1)[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png')
    })
    first_views = (f'{example_dir}/{case}_1.png' for case in case_names)
    return [path for path in first_views if os.path.exists(path)]
def load_multi_example(image) -> List[Image.Image]:
    """Load all views for a multi-image case by matching the input image.

    The input is compared pixel-for-pixel (in RGB) against the first view of
    every example case. On a match, views 1-6 of that case that exist on disk
    are returned as RGBA images, without preprocessing. If nothing matches,
    the input itself is returned as a single-element list (converted to RGBA
    when necessary). ``None`` yields an empty list.
    """
    if image is None:
        return []
    # Gradio may hand over a NumPy array; work with PIL images throughout.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    query_pixels = np.array(image.convert('RGB'))
    example_dir = "assets/example_multi_image"
    png_stems = {
        name.rsplit('_', 1)[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png')
    }
    for case_name in sorted(png_stems):
        reference_path = f'{example_dir}/{case_name}_1.png'
        if not os.path.exists(reference_path):
            continue
        reference_pixels = np.array(Image.open(reference_path).convert('RGB'))
        # A case matches only when both size and content are identical.
        if query_pixels.shape != reference_pixels.shape:
            continue
        if not np.array_equal(query_pixels, reference_pixels):
            continue
        views = []
        for view_idx in range(1, 7):
            view_path = f'{example_dir}/{case_name}_{view_idx}.png'
            if os.path.exists(view_path):
                views.append(Image.open(view_path).convert('RGBA'))
        if views:
            return views
    # No example matched: fall back to the single input image.
    if image.mode != 'RGBA':
        return [image.convert('RGBA')]
    return [image]
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a concatenated image into multiple views.

    Views are the maximal runs of columns that contain at least one pixel
    with non-zero alpha. Runs that touch the left or right image edge are
    handled as well. Each view is passed through preprocess_image().

    Args:
        image: An image with an alpha channel (channel index 3 is read).

    Returns:
        The preprocessed views, ordered left to right.
    """
    pixels = np.array(image)
    column_has_content = np.any(pixels[..., 3] > 0, axis=0)
    # Pad with an empty column on both sides so that runs touching an image
    # edge still produce a start and an end transition.
    padded = np.concatenate(([False], column_has_content, [False]))
    # run_starts[k]: index of the first non-empty column of run k.
    run_starts = np.flatnonzero(~padded[:-1] & padded[1:]).tolist()
    # run_stops[k]: index one past the last non-empty column of run k.
    run_stops = np.flatnonzero(padded[:-1] & ~padded[1:]).tolist()
    views = [
        # Keep one empty column of margin on the left where available.
        Image.fromarray(pixels[:, max(start - 1, 0):stop])
        for start, stop in zip(run_starts, run_stops)
    ]
    return [preprocess_image(view) for view in views]
def _build_preview_html(renders: dict) -> str:
    """Build the previewer HTML from the rendered snapshots.

    Every (mode, view) pair is encoded as a base64 JPEG and emitted as its own
    ``<img id="view-m{mode}-s{step}">`` tag; only the default mode/view is
    marked visible. Mode buttons and the view slider call the JS helpers
    defined in ``head``.

    Args:
        renders: Mapping from render key (see MODES) to a sequence of STEPS
            image arrays.

    Returns:
        The complete previewer HTML fragment.
    """
    # The Stack of 48 Images - encode in parallel for speed
    def encode_preview_image(args):
        m_idx, s_idx, render_key = args
        img_base64 = image_to_base64(Image.fromarray(renders[render_key][s_idx]))
        return (m_idx, s_idx, img_base64)

    encode_tasks = [
        (m_idx, s_idx, mode['render_key'])
        for m_idx, mode in enumerate(MODES)
        for s_idx in range(STEPS)
    ]
    with ThreadPoolExecutor(max_workers=8) as executor:
        encoded_results = list(executor.map(encode_preview_image, encode_tasks))
    # Build HTML from encoded results
    encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
    # Collect fragments and join once; repeated += on large base64 strings
    # would copy the growing string on every iteration.
    image_tags = []
    for m_idx in range(len(MODES)):
        for s_idx in range(STEPS):
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
            img_base64 = encoded_map[(m_idx, s_idx)]
            image_tags.append(f"""
<img id="{unique_id}"
class="previewer-main-image {vis_class}"
src="{img_base64}"
loading="eager">
""")
    images_html = "".join(image_tags)
    # Button Row HTML
    button_tags = []
    for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
        # Note: onclick calls the JS function defined in Head
        button_tags.append(f"""
<img src="{mode['icon_base64']}"
class="mode-btn {active_class}"
onclick="selectMode({idx})"
title="{mode['name']}">
""")
    btns_html = "".join(button_tags)
    # Assemble the full component
    return f"""
<div class="previewer-container">
<div class="tips-wrapper">
<div class="tips-icon">💡Tips</div>
<div class="tips-text">
<p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
<p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
</div>
</div>
<!-- Row 1: Viewport containing 48 static <img> tags -->
<div class="display-row">
{images_html}
</div>
<!-- Row 2 -->
<div class="mode-row" id="btn-group">
{btns_html}
</div>
<!-- Row 3: Slider -->
<div class="slider-row">
<input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
</div>
</div>
"""


@spaces.GPU(duration=120)
def image_to_3d(
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    multiimages: List[Tuple[Image.Image, str]],
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
    """
    Generate a 3D asset from multiple images and build its HTML preview.

    Args:
        seed: The random seed.
        resolution: One of "512", "1024", "1536"; selects the pipeline type.
        ss_*: Sampler parameters for sparse structure generation.
        shape_slat_*: Sampler parameters for shape generation.
        tex_slat_*: Sampler parameters for material generation.
        multiimages: Gallery items; the image is the first element of each.
        multiimage_algo: Multi-image algorithm for structure/shape sampling.
        tex_multiimage_algo: Multi-image algorithm for texture sampling.
        req: The Gradio request; its session hash selects the scratch directory.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        ``(state, html)``: the packed latents (see pack_state) and the
        previewer HTML.

    Raises:
        gr.Error: If no images were provided.
    """
    if not multiimages:
        raise gr.Error("Please upload images or select an example first.")
    # Debug artifacts are written to the per-session directory so concurrent
    # sessions do not overwrite each other's files and end_session() removes
    # them together with the rest of the session data.
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(user_dir, exist_ok=True)
    # Preprocess images (background removal, cropping, etc.)
    processed_images = [preprocess_image(item[0]) for item in multiimages]
    # Debug: save preprocessed images and log stats
    for i, img in enumerate(processed_images):
        arr = np.array(img)
        print(f"[DEBUG] Preprocessed image {i}: mode={img.mode}, size={img.size}, "
              f"dtype={arr.dtype}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
        img.save(os.path.join(user_dir, f'debug_preprocessed_{i}.png'))
    print(f"[DEBUG] Pipeline params: mode={multiimage_algo}, tex_mode={tex_multiimage_algo}")
    # --- Sampling ---
    outputs, latents = pipeline.run_multi_image(
        processed_images,
        seed=seed,
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
            "guidance_rescale": ss_guidance_rescale,
            "rescale_t": ss_rescale_t,
        },
        shape_slat_sampler_params={
            "steps": shape_slat_sampling_steps,
            "guidance_strength": shape_slat_guidance_strength,
            "guidance_rescale": shape_slat_guidance_rescale,
            "rescale_t": shape_slat_rescale_t,
        },
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        pipeline_type={
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[resolution],
        return_latent=True,
        mode=multiimage_algo,
        tex_mode=tex_multiimage_algo,
    )
    mesh = outputs[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    renders = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    # Debug: save base_color render and log stats for all render modes
    for key in renders:
        arr = renders[key][0]  # first view
        print(f"[DEBUG] Render '{key}': shape={arr.shape}, min={arr.min()}, max={arr.max()}, mean={arr.mean():.1f}")
    # Save base_color and shaded_forest for inspection
    Image.fromarray(renders['base_color'][0]).save(os.path.join(user_dir, 'debug_base_color.png'))
    Image.fromarray(renders['shaded_forest'][0]).save(os.path.join(user_dir, 'debug_shaded_forest.png'))
    state = pack_state(latents)
    torch.cuda.empty_cache()
    # --- HTML Construction ---
    full_html = _build_preview_html(renders)
    return state, full_html
@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.
    Args:
        state (dict): The state of the generated 3D model.
        decimation_target (int): The target face count for decimation.
        texture_size (int): The texture resolution.
        req (gr.Request): The Gradio request; its session hash selects the
            output directory.
        progress: Gradio progress tracker (tracks tqdm bars).
    Returns:
        Tuple[str, str]: The path to the extracted GLB file (for Model3D and DownloadButton).
    Raises:
        gr.Error: If no model has been generated yet (``state`` is None).
    """
    # The state buffer is empty until image_to_3d() has run; report that to
    # the user instead of failing inside unpack_state().
    if state is None:
        raise gr.Error("Please generate a 3D model first.")
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shape_slat, tex_slat, res = unpack_state(state)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    # Millisecond-resolution timestamp keeps successive exports of one
    # session from overwriting each other.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path
# UI definition: a two-column layout with inputs/settings on the left and
# actions, preview, extracted model and examples on the right, followed by
# the event wiring.
# NOTE(review): the nesting of the layout containers below was inferred from
# the statement order - confirm against the deployed UI.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
    # Header: logo, title and usage notes.
    gr.HTML("""
<div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;">
<a href="https://www.opsiclear.com" target="_blank">
<img src="https://www.opsiclear.com/assets/logos/Logo_v2_compact_name.svg" alt="OpsiClear" style="height: 80px;">
</a>
<div>
<h2 style="margin: 0;">Multi-View to 3D with <a href="https://microsoft.github.io/TRELLIS.2" target="_blank">TRELLIS.2</a></h2>
<ul style="margin: 5px 0; padding-left: 20px;">
<li>Upload multiple images from different viewpoints to create a 3D asset with multi-image conditioning.</li>
<li>Click an example below to load a pre-made multi-view set, or upload your own images.</li>
<li>Click <b>Generate</b> to create the 3D model, then <b>Extract GLB</b> to export.</li>
<li style="color: #e67300;"><b>⚠️ Note:</b> Generation quality is highly sensitive to parameters. Adjust settings in Advanced Settings if results are unsatisfactory.</li>
</ul>
</div>
</div>
""")
    with gr.Row():
        # Left column: input images and generation/export settings.
        with gr.Column(scale=1, min_width=360):
            multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3)
            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
            # Sampler parameters, one group per pipeline stage.
            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
                tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="stochastic")
        # Right column: actions, preview, exported model and examples.
        with gr.Column(scale=10):
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                extract_btn = gr.Button("Extract GLB")
            preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
            glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
            download_btn = gr.DownloadButton(label="Download GLB")
            example_image = gr.Image(visible=False)  # Hidden component for examples
            # Clicking an example thumbnail runs load_multi_example(), which
            # fills the gallery with the views of the matching case.
            examples_multi = gr.Examples(
                examples=prepare_multi_example(),
                inputs=[example_image],
                fn=load_multi_example,
                outputs=[multiimage_prompt],
                run_on_click=True,
                cache_examples=False,
                examples_per_page=50,
            )
    # Holds the packed latents returned by image_to_3d() for extract_glb().
    output_buf = gr.State()
    # Handlers
    # Create / remove the per-session scratch directory.
    demo.load(start_session)
    demo.unload(end_session)
    # Uploaded images are preprocessed and written back into the gallery.
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )
    # Generate: resolve the seed first, then run generation with it.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
            multiimage_prompt, multiimage_algo, tex_multiimage_algo
        ],
        outputs=[output_buf, preview_output],
    )
    # Extract: decode the stored latents and export a GLB for viewing/download.
    extract_btn.click(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    )
# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)
    # Construct ui components
    btn_img_base64_strs = {}
    # Pre-encode every mode icon once; image_to_3d() embeds them in the
    # previewer HTML via MODES[...]['icon_base64'].
    for mode_entry in MODES:
        mode_entry['icon_base64'] = image_to_base64(Image.open(mode_entry['icon']))
    # Remote background-removal client used by remove_background().
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")
    # Load the generation pipeline and move it to the GPU.
    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None
    pipeline.low_vram = False
    pipeline.cuda()
    # Environment maps for the shaded render modes, keyed by HDRI name.
    # OpenCV reads BGR, so each image is converted to RGB before upload.
    envmap = {
        hdri_name: EnvMap(torch.tensor(
            cv2.cvtColor(cv2.imread(exr_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
            dtype=torch.float32, device='cuda'
        ))
        for hdri_name, exr_path in (
            ('forest', 'assets/hdri/forest.exr'),
            ('sunset', 'assets/hdri/sunset.exr'),
            ('courtyard', 'assets/hdri/courtyard.exr'),
        )
    }
    demo.launch(css=css, head=head)