opsiclear-admin committed on
Commit
c9ea1f8
·
verified ·
1 Parent(s): 42fc1a9

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +778 -712
app.py CHANGED
@@ -1,712 +1,778 @@
1
- import gradio as gr
2
- from gradio_client import Client, handle_file
3
- import spaces
4
- from concurrent.futures import ThreadPoolExecutor
5
-
6
- import os
7
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
8
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
- os.environ["ATTN_BACKEND"] = "flash_attn_3"
10
- os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
11
- os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
12
- from datetime import datetime
13
- import shutil
14
- import cv2
15
- from typing import *
16
- import numpy as np
17
- from PIL import Image
18
- import base64
19
- import io
20
- import tempfile
21
-
22
- # Lazy imports - will be loaded when GPU is available
23
- torch = None
24
- SparseTensor = None
25
- Trellis2ImageTo3DPipeline = None
26
- EnvMap = None
27
- render_utils = None
28
- o_voxel = None
29
-
30
- # Global state - initialized on first GPU call
31
- pipeline = None
32
- envmap = None
33
- _initialized = False
34
-
35
-
36
- def _lazy_import():
37
- """Import GPU-dependent modules. Must be called from within a @spaces.GPU function."""
38
- global torch, SparseTensor, Trellis2ImageTo3DPipeline, EnvMap, render_utils, o_voxel
39
- if torch is None:
40
- import torch as _torch
41
- torch = _torch
42
- if SparseTensor is None:
43
- from trellis2.modules.sparse import SparseTensor as _SparseTensor
44
- SparseTensor = _SparseTensor
45
- if Trellis2ImageTo3DPipeline is None:
46
- from trellis2.pipelines import Trellis2ImageTo3DPipeline as _Trellis2ImageTo3DPipeline
47
- Trellis2ImageTo3DPipeline = _Trellis2ImageTo3DPipeline
48
- if EnvMap is None:
49
- from trellis2.renderers import EnvMap as _EnvMap
50
- EnvMap = _EnvMap
51
- if render_utils is None:
52
- from trellis2.utils import render_utils as _render_utils
53
- render_utils = _render_utils
54
- if o_voxel is None:
55
- import o_voxel as _o_voxel
56
- o_voxel = _o_voxel
57
- # Patch postprocess module with local fix for cumesh.fill_holes() bug
58
- import importlib.util
59
- _local_pp = os.path.join(os.path.dirname(os.path.abspath(__file__)),
60
- 'o-voxel', 'o_voxel', 'postprocess.py')
61
- if os.path.exists(_local_pp):
62
- _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_pp)
63
- _mod = importlib.util.module_from_spec(_spec)
64
- _spec.loader.exec_module(_mod)
65
- o_voxel.postprocess = _mod
66
- import sys
67
- sys.modules['o_voxel.postprocess'] = _mod
68
-
69
-
70
- def _initialize_pipeline():
71
- """Initialize the pipeline and environment maps. Must be called from within a @spaces.GPU function."""
72
- global pipeline, envmap, _initialized
73
- if _initialized:
74
- return
75
-
76
- _lazy_import()
77
-
78
- pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
79
- pipeline.rembg_model = None
80
- pipeline.low_vram = False
81
- pipeline.cuda()
82
-
83
- envmap = {
84
- 'forest': EnvMap(torch.tensor(
85
- cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
86
- dtype=torch.float32, device='cuda'
87
- )),
88
- 'sunset': EnvMap(torch.tensor(
89
- cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
90
- dtype=torch.float32, device='cuda'
91
- )),
92
- 'courtyard': EnvMap(torch.tensor(
93
- cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
94
- dtype=torch.float32, device='cuda'
95
- )),
96
- }
97
-
98
- _initialized = True
99
-
100
-
101
- MAX_SEED = np.iinfo(np.int32).max
102
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
103
- MODES = [
104
- {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
105
- {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
106
- {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
107
- {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
108
- {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
109
- {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
110
- ]
111
- STEPS = 8
112
- DEFAULT_MODE = 3
113
- DEFAULT_STEP = 3
114
-
115
-
116
- css = """
117
- /* ColmapView Dark Theme */
118
- :root {
119
- --bg-void: #0a0a0a;
120
- --bg-primary: #0f0f0f;
121
- --bg-secondary: #161616;
122
- --bg-tertiary: #1e1e1e;
123
- --bg-input: #1a1a1a;
124
- --bg-elevated: #242424;
125
- --bg-hover: #262626;
126
- --text-primary: #e8e8e8;
127
- --text-secondary: #8a8a8a;
128
- --text-muted: #5a5a5a;
129
- --border-subtle: #222222;
130
- --border-color: #2a2a2a;
131
- --border-light: #3a3a3a;
132
- --accent: #b8b8b8;
133
- --accent-hover: #d0d0d0;
134
- --accent-dim: rgba(184, 184, 184, 0.12);
135
- --shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.4), 0 1px 2px rgba(0, 0, 0, 0.3);
136
- --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.5), 0 2px 4px rgba(0, 0, 0, 0.3);
137
- --radius: 0.25rem;
138
- --radius-md: 0.375rem;
139
- --radius-lg: 0.5rem;
140
- --font-sans: 'Roboto', -apple-system, BlinkMacSystemFont, sans-serif;
141
- --font-mono: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace;
142
- }
143
-
144
- /* Global Overrides */
145
- .gradio-container { background: var(--bg-primary) !important; color: var(--text-primary) !important; font-family: var(--font-sans) !important; }
146
- .dark { background: var(--bg-primary) !important; }
147
- body { background: var(--bg-void) !important; color-scheme: dark; }
148
-
149
- /* Panels & Blocks */
150
- .block { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; border-radius: var(--radius-lg) !important; }
151
- .panel { background: var(--bg-secondary) !important; }
152
-
153
- /* Inputs */
154
- input, textarea, select { background: var(--bg-input) !important; border: 1px solid var(--border-color) !important; color: var(--text-primary) !important; border-radius: var(--radius) !important; }
155
- input:focus, textarea:focus, select:focus { border-color: var(--accent) !important; outline: none !important; }
156
-
157
- /* Buttons */
158
- .primary { background: var(--accent) !important; color: var(--bg-void) !important; border: none !important; }
159
- .primary:hover { background: var(--accent-hover) !important; }
160
- .secondary { background: var(--bg-hover) !important; color: var(--text-primary) !important; border: 1px solid var(--border-color) !important; }
161
- .secondary:hover { background: var(--bg-tertiary) !important; }
162
- button { transition: all 0.15s ease-out !important; }
163
-
164
- /* Labels & Text */
165
- label, .label-wrap { color: var(--text-secondary) !important; font-size: 0.875rem !important; }
166
- .prose { color: var(--text-primary) !important; }
167
- .prose h2 { color: var(--text-primary) !important; border: none !important; }
168
-
169
- /* Sliders */
170
- input[type="range"] { accent-color: var(--accent) !important; }
171
-
172
- /* Gallery */
173
- .gallery { background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
174
-
175
- /* Accordion */
176
- .accordion { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; }
177
-
178
- /* Scrollbar */
179
- ::-webkit-scrollbar { width: 12px; }
180
- ::-webkit-scrollbar-track { background: var(--bg-secondary); }
181
- ::-webkit-scrollbar-thumb { background: var(--border-light); border-radius: 9999px; }
182
- ::-webkit-scrollbar-thumb:hover { background: var(--text-muted); }
183
-
184
- /* Gradio Overrides */
185
- .stepper-wrapper { padding: 0; }
186
- .stepper-container { padding: 0; align-items: center; }
187
- .step-button { flex-direction: row; }
188
- .step-connector { transform: none; }
189
- .step-number { width: 16px; height: 16px; background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
190
- .step-label { position: relative; bottom: 0; color: var(--text-secondary) !important; }
191
- .wrap.center.full { inset: 0; height: 100%; }
192
- .wrap.center.full.translucent { background: var(--bg-secondary); }
193
- .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
194
-
195
- /* Previewer */
196
- .previewer-container {
197
- position: relative;
198
- font-family: var(--font-sans);
199
- width: 100%;
200
- height: 722px;
201
- margin: 0 auto;
202
- padding: 20px;
203
- display: flex;
204
- flex-direction: column;
205
- align-items: center;
206
- justify-content: center;
207
- background: var(--bg-void);
208
- border-radius: var(--radius-lg);
209
- }
210
- .previewer-container .tips-icon {
211
- position: absolute; right: 10px; top: 10px; z-index: 10;
212
- border-radius: var(--radius-md); color: var(--bg-void);
213
- background-color: var(--accent); padding: 4px 8px;
214
- user-select: none; font-size: 0.75rem; font-weight: 500;
215
- }
216
- .previewer-container .tips-text {
217
- position: absolute; right: 10px; top: 45px;
218
- color: var(--text-primary); background-color: var(--bg-elevated);
219
- border: 1px solid var(--border-light); border-radius: var(--radius-md);
220
- padding: 8px 12px; text-align: left; max-width: 280px; z-index: 10;
221
- transition: opacity 0.15s ease-out; opacity: 0; user-select: none;
222
- box-shadow: var(--shadow-md);
223
- }
224
- .previewer-container .tips-text p { font-size: 0.75rem; line-height: 1.4; color: var(--text-secondary); margin: 4px 0; }
225
- .tips-icon:hover + .tips-text { opacity: 1; }
226
- .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 16px; flex-wrap: wrap; }
227
- .previewer-container .mode-btn {
228
- width: 28px; height: 28px; border-radius: 50%; cursor: pointer;
229
- opacity: 0.5; transition: all 0.15s ease-out;
230
- border: 2px solid var(--border-light); object-fit: cover;
231
- }
232
- .previewer-container .mode-btn:hover { opacity: 0.85; transform: scale(1.1); border-color: var(--text-muted); }
233
- .previewer-container .mode-btn.active { opacity: 1; border-color: var(--accent); transform: scale(1.1); }
234
- .previewer-container .display-row { margin-bottom: 16px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
235
- .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
236
- .previewer-container .previewer-main-image.visible { display: block; }
237
- .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 8px; padding: 0 16px; }
238
- .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; height: 20px; }
239
- .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 4px; cursor: pointer; background: var(--bg-tertiary); border-radius: 9999px; }
240
- .previewer-container input[type=range]::-webkit-slider-thumb {
241
- height: 12px; width: 12px; border-radius: 50%; background: var(--accent);
242
- cursor: pointer; -webkit-appearance: none; margin-top: -4px;
243
- box-shadow: var(--shadow-sm); transition: transform 0.1s ease-out;
244
- }
245
- .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); background: var(--accent-hover); }
246
- .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
247
- .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
248
- """
249
-
250
-
251
- head = """
252
- <script>
253
- function refreshView(mode, step) {
254
- const allImgs = document.querySelectorAll('.previewer-main-image');
255
- for (let i = 0; i < allImgs.length; i++) {
256
- const img = allImgs[i];
257
- if (img.classList.contains('visible')) {
258
- const id = img.id;
259
- const [_, m, s] = id.split('-');
260
- if (mode === -1) mode = parseInt(m.slice(1));
261
- if (step === -1) step = parseInt(s.slice(1));
262
- break;
263
- }
264
- }
265
- allImgs.forEach(img => img.classList.remove('visible'));
266
- const targetId = 'view-m' + mode + '-s' + step;
267
- const targetImg = document.getElementById(targetId);
268
- if (targetImg) { targetImg.classList.add('visible'); }
269
- const allBtns = document.querySelectorAll('.mode-btn');
270
- allBtns.forEach((btn, idx) => {
271
- if (idx === mode) btn.classList.add('active');
272
- else btn.classList.remove('active');
273
- });
274
- }
275
- function selectMode(mode) { refreshView(mode, -1); }
276
- function onSliderChange(val) { refreshView(-1, parseInt(val)); }
277
- </script>
278
- """
279
-
280
-
281
- empty_html = """
282
- <div class="previewer-container">
283
- <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
284
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
285
- </div>
286
- """
287
-
288
-
289
- def image_to_base64(image):
290
- buffered = io.BytesIO()
291
- image = image.convert("RGB")
292
- image.save(buffered, format="jpeg", quality=85)
293
- img_str = base64.b64encode(buffered.getvalue()).decode()
294
- return f"data:image/jpeg;base64,{img_str}"
295
-
296
-
297
- def start_session(req: gr.Request):
298
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
299
- os.makedirs(user_dir, exist_ok=True)
300
-
301
-
302
- def end_session(req: gr.Request):
303
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
304
- if os.path.exists(user_dir):
305
- shutil.rmtree(user_dir)
306
-
307
-
308
- def remove_background(input: Image.Image) -> Image.Image:
309
- with tempfile.NamedTemporaryFile(suffix='.png') as f:
310
- input = input.convert('RGB')
311
- input.save(f.name)
312
- output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
313
- output = Image.open(output)
314
- return output
315
-
316
-
317
- def preprocess_image(input: Image.Image) -> Image.Image:
318
- """Preprocess a single input image."""
319
- has_alpha = False
320
- if input.mode == 'RGBA':
321
- alpha = np.array(input)[:, :, 3]
322
- if not np.all(alpha == 255):
323
- has_alpha = True
324
- max_size = max(input.size)
325
- scale = min(1, 1024 / max_size)
326
- if scale < 1:
327
- input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
328
- if has_alpha:
329
- output = input
330
- else:
331
- output = remove_background(input)
332
- output_np = np.array(output)
333
- alpha = output_np[:, :, 3]
334
- bbox = np.argwhere(alpha > 0.8 * 255)
335
- bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
336
- center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
337
- size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
338
- size = int(size * 1)
339
- bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
340
- output = output.crop(bbox)
341
- output = np.array(output).astype(np.float32) / 255
342
- output = output[:, :, :3] * output[:, :, 3:4]
343
- output = Image.fromarray((output * 255).astype(np.uint8))
344
- return output
345
-
346
-
347
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
348
- """Preprocess a list of input images. Uses parallel processing."""
349
- if not images:
350
- return []
351
- imgs = [img[0] if isinstance(img, tuple) else img for img in images]
352
- with ThreadPoolExecutor(max_workers=min(4, len(imgs))) as executor:
353
- processed_images = list(executor.map(preprocess_image, imgs))
354
- return processed_images
355
-
356
-
357
- def pack_state(latents):
358
- shape_slat, tex_slat, res = latents
359
- return {
360
- 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
361
- 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
362
- 'coords': shape_slat.coords.cpu().numpy(),
363
- 'res': res,
364
- }
365
-
366
-
367
- def unpack_state(state: dict):
368
- _lazy_import()
369
- shape_slat = SparseTensor(
370
- feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
371
- coords=torch.from_numpy(state['coords']).cuda(),
372
- )
373
- tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
374
- return shape_slat, tex_slat, state['res']
375
-
376
-
377
- def get_seed(randomize_seed: bool, seed: int) -> int:
378
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
379
-
380
-
381
- def prepare_examples() -> List[List[str]]:
382
- """Prepare examples as lists of image paths (not concatenated)."""
383
- example_dir = "assets/example_multi_image"
384
- if not os.path.exists(example_dir):
385
- return []
386
- files = os.listdir(example_dir)
387
- cases = list(set([f.split('_')[0] for f in files if '_' in f and f.endswith('.png')]))
388
- examples = []
389
- for case in sorted(cases):
390
- case_images = []
391
- for i in range(1, 10): # Support up to 9 images per example
392
- img_path = f'{example_dir}/{case}_{i}.png'
393
- if os.path.exists(img_path):
394
- case_images.append(img_path)
395
- if case_images:
396
- examples.append(case_images)
397
- return examples
398
-
399
-
400
- def load_example(example_paths: List[str]) -> List[Image.Image]:
401
- """Load example images from paths."""
402
- images = []
403
- for path in example_paths:
404
- img = Image.open(path)
405
- images.append(img)
406
- return images
407
-
408
-
409
- @spaces.GPU(duration=120)
410
- def image_to_3d(
411
- images: List[Tuple[Image.Image, str]],
412
- seed: int,
413
- resolution: str,
414
- ss_guidance_strength: float,
415
- ss_guidance_rescale: float,
416
- ss_sampling_steps: int,
417
- ss_rescale_t: float,
418
- shape_slat_guidance_strength: float,
419
- shape_slat_guidance_rescale: float,
420
- shape_slat_sampling_steps: int,
421
- shape_slat_rescale_t: float,
422
- tex_slat_guidance_strength: float,
423
- tex_slat_guidance_rescale: float,
424
- tex_slat_sampling_steps: int,
425
- tex_slat_rescale_t: float,
426
- multiimage_algo: Literal["multidiffusion", "stochastic"],
427
- tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
428
- req: gr.Request,
429
- progress=gr.Progress(track_tqdm=True),
430
- ) -> str:
431
- # Initialize pipeline on first call
432
- _initialize_pipeline()
433
-
434
- # Extract images from gallery format
435
- if not images:
436
- raise gr.Error("Please upload at least one image")
437
-
438
- imgs = [img[0] if isinstance(img, tuple) else img for img in images]
439
-
440
- # --- Sampling ---
441
- if len(imgs) == 1:
442
- # Single image mode
443
- outputs, latents = pipeline.run(
444
- imgs[0],
445
- seed=seed,
446
- preprocess_image=False,
447
- sparse_structure_sampler_params={
448
- "steps": ss_sampling_steps,
449
- "guidance_strength": ss_guidance_strength,
450
- "guidance_rescale": ss_guidance_rescale,
451
- "rescale_t": ss_rescale_t,
452
- },
453
- shape_slat_sampler_params={
454
- "steps": shape_slat_sampling_steps,
455
- "guidance_strength": shape_slat_guidance_strength,
456
- "guidance_rescale": shape_slat_guidance_rescale,
457
- "rescale_t": shape_slat_rescale_t,
458
- },
459
- tex_slat_sampler_params={
460
- "steps": tex_slat_sampling_steps,
461
- "guidance_strength": tex_slat_guidance_strength,
462
- "guidance_rescale": tex_slat_guidance_rescale,
463
- "rescale_t": tex_slat_rescale_t,
464
- },
465
- pipeline_type={
466
- "512": "512",
467
- "1024": "1024_cascade",
468
- "1536": "1536_cascade",
469
- }[resolution],
470
- return_latent=True,
471
- )
472
- else:
473
- # Multi-image mode
474
- outputs, latents = pipeline.run_multi_image(
475
- imgs,
476
- seed=seed,
477
- preprocess_image=False,
478
- sparse_structure_sampler_params={
479
- "steps": ss_sampling_steps,
480
- "guidance_strength": ss_guidance_strength,
481
- "guidance_rescale": ss_guidance_rescale,
482
- "rescale_t": ss_rescale_t,
483
- },
484
- shape_slat_sampler_params={
485
- "steps": shape_slat_sampling_steps,
486
- "guidance_strength": shape_slat_guidance_strength,
487
- "guidance_rescale": shape_slat_guidance_rescale,
488
- "rescale_t": shape_slat_rescale_t,
489
- },
490
- tex_slat_sampler_params={
491
- "steps": tex_slat_sampling_steps,
492
- "guidance_strength": tex_slat_guidance_strength,
493
- "guidance_rescale": tex_slat_guidance_rescale,
494
- "rescale_t": tex_slat_rescale_t,
495
- },
496
- pipeline_type={
497
- "512": "512",
498
- "1024": "1024_cascade",
499
- "1536": "1536_cascade",
500
- }[resolution],
501
- return_latent=True,
502
- mode=multiimage_algo,
503
- tex_mode=tex_multiimage_algo,
504
- )
505
-
506
- mesh = outputs[0]
507
- mesh.simplify(16777216)
508
- render_images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
509
- state = pack_state(latents)
510
- torch.cuda.empty_cache()
511
-
512
- # --- HTML Construction ---
513
- def encode_preview_image(args):
514
- m_idx, s_idx, render_key = args
515
- img_base64 = image_to_base64(Image.fromarray(render_images[render_key][s_idx]))
516
- return (m_idx, s_idx, img_base64)
517
-
518
- encode_tasks = [(m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS)]
519
-
520
- with ThreadPoolExecutor(max_workers=8) as executor:
521
- encoded_results = list(executor.map(encode_preview_image, encode_tasks))
522
-
523
- encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
524
- images_html = ""
525
- for m_idx, mode in enumerate(MODES):
526
- for s_idx in range(STEPS):
527
- unique_id = f"view-m{m_idx}-s{s_idx}"
528
- is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
529
- vis_class = "visible" if is_visible else ""
530
- img_base64 = encoded_map[(m_idx, s_idx)]
531
- images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
532
-
533
- btns_html = ""
534
- for idx, mode in enumerate(MODES):
535
- active_class = "active" if idx == DEFAULT_MODE else ""
536
- btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
537
-
538
- full_html = f"""
539
- <div class="previewer-container">
540
- <div class="tips-wrapper">
541
- <div class="tips-icon">Tips</div>
542
- <div class="tips-text">
543
- <p>Render Mode - Click buttons to switch render modes.</p>
544
- <p>View Angle - Drag slider to change view.</p>
545
- </div>
546
- </div>
547
- <div class="display-row">{images_html}</div>
548
- <div class="mode-row" id="btn-group">{btns_html}</div>
549
- <div class="slider-row">
550
- <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
551
- </div>
552
- </div>
553
- """
554
- return state, full_html
555
-
556
-
557
- @spaces.GPU(duration=120)
558
- def extract_glb(
559
- state: dict,
560
- decimation_target: int,
561
- texture_size: int,
562
- req: gr.Request,
563
- progress=gr.Progress(track_tqdm=True),
564
- ) -> Tuple[str, str]:
565
- _initialize_pipeline()
566
-
567
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
568
- shape_slat, tex_slat, res = unpack_state(state)
569
- mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
570
- mesh.simplify(16777216)
571
- glb = o_voxel.postprocess.to_glb(
572
- vertices=mesh.vertices,
573
- faces=mesh.faces,
574
- attr_volume=mesh.attrs,
575
- coords=mesh.coords,
576
- attr_layout=pipeline.pbr_attr_layout,
577
- grid_size=res,
578
- aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
579
- decimation_target=decimation_target,
580
- texture_size=texture_size,
581
- remesh=True,
582
- remesh_band=1,
583
- remesh_project=0,
584
- use_tqdm=True,
585
- )
586
- now = datetime.now()
587
- timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
588
- os.makedirs(user_dir, exist_ok=True)
589
- glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
590
- glb.export(glb_path, extension_webp=True)
591
- torch.cuda.empty_cache()
592
- return glb_path, glb_path
593
-
594
-
595
- with gr.Blocks(delete_cache=(600, 600)) as demo:
596
- gr.Markdown("""
597
- ## Multi-View Image to 3D with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
598
- Upload one or more images of an object and click Generate to create a 3D asset.
599
- Multiple views from different angles will produce better results.
600
- """)
601
-
602
- with gr.Row():
603
- with gr.Column(scale=1, min_width=360):
604
- image_prompt = gr.Gallery(
605
- label="Input Images (upload 1 or more views)",
606
- format="png",
607
- type="pil",
608
- height=400,
609
- columns=3,
610
- object_fit="contain"
611
- )
612
- gr.Markdown("*Upload multiple views of the same object for better 3D reconstruction.*")
613
-
614
- resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
615
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
616
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
617
- decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
618
- texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
619
-
620
- generate_btn = gr.Button("Generate", variant="primary")
621
-
622
- with gr.Accordion(label="Advanced Settings", open=False):
623
- gr.Markdown("Stage 1: Sparse Structure Generation")
624
- with gr.Row():
625
- ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
626
- ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
627
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
628
- ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
629
- gr.Markdown("Stage 2: Shape Generation")
630
- with gr.Row():
631
- shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
632
- shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
633
- shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
634
- shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
635
- gr.Markdown("Stage 3: Material Generation")
636
- with gr.Row():
637
- tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
638
- tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
639
- tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
640
- tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
641
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
642
- tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
643
-
644
- with gr.Column(scale=10):
645
- with gr.Walkthrough(selected=0) as walkthrough:
646
- with gr.Step("Preview", id=0):
647
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
648
- extract_btn = gr.Button("Extract GLB")
649
- with gr.Step("Extract", id=1):
650
- glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
651
- download_btn = gr.DownloadButton(label="Download GLB")
652
- gr.Markdown("*GLB extraction may take 30+ seconds.*")
653
-
654
- with gr.Column(scale=1, min_width=200):
655
- gr.Markdown("### Examples")
656
- # Create example buttons that load images into gallery
657
- example_data = prepare_examples()
658
- for i, example_paths in enumerate(example_data[:12]): # Limit to 12 examples
659
- case_name = os.path.basename(example_paths[0]).split('_')[0]
660
- btn = gr.Button(f"{case_name} ({len(example_paths)} views)", size="sm")
661
- btn.click(
662
- fn=lambda paths=example_paths: load_example(paths),
663
- outputs=[image_prompt]
664
- )
665
-
666
- output_buf = gr.State()
667
-
668
- # Handlers
669
- demo.load(start_session)
670
- demo.unload(end_session)
671
-
672
- image_prompt.upload(
673
- preprocess_images,
674
- inputs=[image_prompt],
675
- outputs=[image_prompt],
676
- )
677
-
678
- generate_btn.click(
679
- get_seed, inputs=[randomize_seed, seed], outputs=[seed],
680
- ).then(
681
- lambda: gr.Walkthrough(selected=0), outputs=walkthrough
682
- ).then(
683
- image_to_3d,
684
- inputs=[
685
- image_prompt, seed, resolution,
686
- ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
687
- shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
688
- tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
689
- multiimage_algo, tex_multiimage_algo
690
- ],
691
- outputs=[output_buf, preview_output],
692
- )
693
-
694
- extract_btn.click(
695
- lambda: gr.Walkthrough(selected=1), outputs=walkthrough
696
- ).then(
697
- extract_glb,
698
- inputs=[output_buf, decimation_target, texture_size],
699
- outputs=[glb_output, download_btn],
700
- )
701
-
702
-
703
- if __name__ == "__main__":
704
- os.makedirs(TMP_DIR, exist_ok=True)
705
-
706
- for i in range(len(MODES)):
707
- icon = Image.open(MODES[i]['icon'])
708
- MODES[i]['icon_base64'] = image_to_base64(icon)
709
-
710
- rmbg_client = Client("briaai/BRIA-RMBG-2.0")
711
-
712
- demo.launch(css=css, head=head)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client, handle_file
3
+ import spaces
4
+ from concurrent.futures import ThreadPoolExecutor
5
+
6
+ import os
7
+ import sys
8
+
9
+ # Prioritize local o-voxel submodule (with cumesh.fill_holes() fix) over prebuilt wheel
10
+ _script_dir = os.path.dirname(os.path.abspath(__file__))
11
+ sys.path.insert(0, os.path.join(_script_dir, 'o-voxel'))
12
+
13
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
14
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
15
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
16
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(_script_dir, 'autotune_cache.json')
17
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
18
+ from datetime import datetime
19
+ import shutil
20
+ import cv2
21
+ from typing import *
22
+ import torch
23
+ import numpy as np
24
+ from PIL import Image
25
+ import base64
26
+ import io
27
+ import tempfile
28
+ from trellis2.modules.sparse import SparseTensor
29
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline
30
+ from trellis2.renderers import EnvMap
31
+ from trellis2.utils import render_utils
32
+ import o_voxel
33
+
34
+
35
# Largest user-settable seed; kept within int32 range for the samplers.
MAX_SEED = np.iinfo(np.int32).max
# Root of the per-session scratch directories (one subdir per Gradio session).
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# Render modes shown in the previewer. `render_key` must match the keys
# returned by render_utils.render_snapshot(); `icon_base64` is filled in at
# startup (in __main__) from the `icon` file path.
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
# Number of pre-rendered camera angles per mode (positions of the HTML slider).
STEPS = 8
# Mode/step shown first when a preview is generated (indices into MODES/STEPS).
DEFAULT_MODE = 3
DEFAULT_STEP = 3


# Custom CSS injected at launch: overrides some Gradio defaults and styles the
# static HTML previewer (mode buttons, stacked preview images, angle slider).
css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper {
    padding: 0;
}

.stepper-container {
    padding: 0;
    align-items: center;
}

.step-button {
    flex-direction: row;
}

.step-connector {
    transform: none;
}

.step-number {
    width: 16px;
    height: 16px;
}

.step-label {
    position: relative;
    bottom: 0;
}

.wrap.center.full {
    inset: 0;
    height: 100%;
}

.wrap.center.full.translucent {
    background: var(--block-background-fill);
}

.meta-text-center {
    display: block !important;
    position: absolute !important;
    top: unset !important;
    bottom: 0 !important;
    right: 0 !important;
    transform: unset !important;
}

/* Previewer */
.previewer-container {
    position: relative;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    width: 100%;
    height: 722px;
    margin: 0 auto;
    padding: 20px;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
}

.previewer-container .tips-icon {
    position: absolute;
    right: 10px;
    top: 10px;
    z-index: 10;
    border-radius: 10px;
    color: #fff;
    background-color: var(--color-accent);
    padding: 3px 6px;
    user-select: none;
}

.previewer-container .tips-text {
    position: absolute;
    right: 10px;
    top: 50px;
    color: #fff;
    background-color: var(--color-accent);
    border-radius: 10px;
    padding: 6px;
    text-align: left;
    max-width: 300px;
    z-index: 10;
    transition: all 0.3s;
    opacity: 0%;
    user-select: none;
}

.previewer-container .tips-text p {
    font-size: 14px;
    line-height: 1.2;
}

.tips-icon:hover + .tips-text {
    display: block;
    opacity: 100%;
}

/* Row 1: Display Modes */
.previewer-container .mode-row {
    width: 100%;
    display: flex;
    gap: 8px;
    justify-content: center;
    margin-bottom: 20px;
    flex-wrap: wrap;
}
.previewer-container .mode-btn {
    width: 24px;
    height: 24px;
    border-radius: 50%;
    cursor: pointer;
    opacity: 0.5;
    transition: all 0.2s;
    border: 2px solid var(--neutral-600, #555);
    object-fit: cover;
}
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active {
    opacity: 1;
    border-color: var(--color-accent);
    transform: scale(1.1);
}

/* Row 2: Display Image */
.previewer-container .display-row {
    margin-bottom: 20px;
    min-height: 400px;
    width: 100%;
    flex-grow: 1;
    display: flex;
    justify-content: center;
    align-items: center;
}
.previewer-container .previewer-main-image {
    max-width: 100%;
    max-height: 100%;
    flex-grow: 1;
    object-fit: contain;
    display: none;
}
.previewer-container .previewer-main-image.visible {
    display: block;
}

/* Row 3: Custom HTML Slider */
.previewer-container .slider-row {
    width: 100%;
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 10px;
    padding: 0 10px;
}

.previewer-container input[type=range] {
    -webkit-appearance: none;
    width: 100%;
    max-width: 400px;
    background: transparent;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
    width: 100%;
    height: 8px;
    cursor: pointer;
    background: var(--neutral-700, #404040);
    border-radius: 5px;
}
.previewer-container input[type=range]::-webkit-slider-thumb {
    height: 20px;
    width: 20px;
    border-radius: 50%;
    background: var(--color-accent);
    cursor: pointer;
    -webkit-appearance: none;
    margin-top: -6px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    transition: transform 0.1s;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
    transform: scale(1.2);
}

/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) {
    padding: 0 !important;
}

.gradio-container:has(.previewer-container) [data-testid="block-label"] {
    position: absolute;
    top: 0;
    left: 0;
}
"""


# JavaScript injected into <head>: the previewer is a stack of pre-rendered
# <img> tags, and these handlers toggle which one carries the `visible` class.
# refreshView(-1, -1) means "keep the current mode/step for that axis".
head = """
<script>
function refreshView(mode, step) {
    // 1. Find current mode and step
    const allImgs = document.querySelectorAll('.previewer-main-image');
    for (let i = 0; i < allImgs.length; i++) {
        const img = allImgs[i];
        if (img.classList.contains('visible')) {
            const id = img.id;
            const [_, m, s] = id.split('-');
            if (mode === -1) mode = parseInt(m.slice(1));
            if (step === -1) step = parseInt(s.slice(1));
            break;
        }
    }

    // 2. Hide ALL images
    // We select all elements with class 'previewer-main-image'
    allImgs.forEach(img => img.classList.remove('visible'));

    // 3. Construct the specific ID for the current state
    // Format: view-m{mode}-s{step}
    const targetId = 'view-m' + mode + '-s' + step;
    const targetImg = document.getElementById(targetId);

    // 4. Show ONLY the target
    if (targetImg) {
        targetImg.classList.add('visible');
    }

    // 5. Update Button Highlights
    const allBtns = document.querySelectorAll('.mode-btn');
    allBtns.forEach((btn, idx) => {
        if (idx === mode) btn.classList.add('active');
        else btn.classList.remove('active');
    });
}

// --- Action: Switch Mode ---
function selectMode(mode) {
    refreshView(mode, -1);
}

// --- Action: Slider Change ---
function onSliderChange(val) {
    refreshView(-1, parseInt(val));
}
</script>
"""


# Placeholder shown in the preview pane before any asset has been generated
# (a simple "image" glyph, mimicking Gradio's empty-component look).
empty_html = f"""
<div class="previewer-container">
    <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
        xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
304
+
305
+
306
def image_to_base64(image):
    """Encode a PIL image as a `data:image/jpeg;base64,...` URI string.

    The image is flattened to RGB and compressed as JPEG (quality 85) so the
    resulting string stays small enough to inline into the previewer HTML.
    """
    buf = io.BytesIO()
    image.convert("RGB").save(buf, format="jpeg", quality=85)
    encoded = base64.b64encode(buf.getvalue()).decode()
    return "data:image/jpeg;base64," + encoded
312
+
313
+
314
def start_session(req: gr.Request):
    """Create this session's scratch directory under TMP_DIR on page load."""
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
317
+
318
+
319
def end_session(req: gr.Request):
    """Delete this session's scratch directory when the browser session ends.

    `ignore_errors=True` guards against the directory never having been
    created (e.g. start_session failed) or having already been removed by
    Gradio's cache cleanup — a bare rmtree would raise FileNotFoundError.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shutil.rmtree(user_dir, ignore_errors=True)
322
+
323
+
324
def remove_background(input: Image.Image) -> Image.Image:
    """Strip the background from an image via the remote BRIA-RMBG space.

    The image is flattened to RGB, written to a temporary PNG (the gradio
    client API wants a file), and the first returned file is loaded back as
    the RGBA result.
    """
    with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
        input.convert('RGB').save(tmp.name)
        result_path = rmbg_client.predict(handle_file(tmp.name), api_name="/image")[0][0]
        return Image.open(result_path)
331
+
332
+
333
def preprocess_image(input: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Steps: detect a usable alpha channel (otherwise remove the background
    remotely), downscale so the longest side is <= 1024, crop a square around
    the visible pixels, and composite onto black (premultiplied alpha).
    Returns an RGB image.
    """
    # if has alpha channel, use it directly; otherwise, remove background
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        # A fully opaque alpha channel carries no mask info, so ignore it.
        if not np.all(alpha == 255):
            has_alpha = True
    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
    if has_alpha:
        output = input
    else:
        # remove_background returns RGBA; the alpha indexing below relies on that.
        output = remove_background(input)
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    # Visible pixels = alpha above 80% opacity; rows are (y, x) coordinates.
    bbox = np.argwhere(alpha > 0.8 * 255)
    if bbox.size == 0:
        # No visible pixels, center the image in a square
        size = max(output.size)
        square = Image.new('RGB', (size, size), (0, 0, 0))
        output_rgb = output.convert('RGB') if output.mode == 'RGBA' else output
        square.paste(output_rgb, ((size - output.width) // 2, (size - output.height) // 2))
        return square
    # (left, top, right, bottom) of the visible region.
    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    # NOTE(review): `* 1` is a padding factor of 1.0 (no margin); bump it to
    # add breathing room around the object.
    size = int(size * 1)
    # Square crop centered on the object; PIL pads out-of-bounds areas with
    # zeros, which matches the black background used below.
    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
    output = output.crop(bbox)  # type: ignore
    # Premultiply alpha: composite the object onto black, then drop alpha.
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output
371
+
372
+
373
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images for multi-image conditioning.

    Args:
        images: Gallery entries as (PIL image, caption) tuples.

    Returns:
        The preprocessed images, in the same order as the input.

    Background removal is network-bound, so the images are processed in
    parallel worker threads.
    """
    pil_images = [image[0] for image in images]
    if not pil_images:
        # ThreadPoolExecutor raises ValueError for max_workers=0, so an
        # empty gallery must be short-circuited.
        return []
    with ThreadPoolExecutor(max_workers=min(4, len(pil_images))) as executor:
        return list(executor.map(preprocess_image, pil_images))
382
+
383
+
384
def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
    """Serialize generated latents into a CPU/numpy dict for gr.State.

    Moves the sparse feature tensors and their shared coordinates off the
    GPU so the state survives between ZeroGPU invocations.
    """
    shape_slat, tex_slat, res = latents
    packed = {}
    packed['shape_slat_feats'] = shape_slat.feats.cpu().numpy()
    packed['tex_slat_feats'] = tex_slat.feats.cpu().numpy()
    packed['coords'] = shape_slat.coords.cpu().numpy()
    packed['res'] = res
    return packed
392
+
393
+
394
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    """Rebuild CUDA sparse latents from a dict produced by pack_state."""
    coords = torch.from_numpy(state['coords']).cuda()
    shape_feats = torch.from_numpy(state['shape_slat_feats']).cuda()
    shape_slat = SparseTensor(feats=shape_feats, coords=coords)
    # The texture latent shares the shape latent's coordinates.
    tex_feats = torch.from_numpy(state['tex_slat_feats']).cuda()
    tex_slat = shape_slat.replace(tex_feats)
    return shape_slat, tex_slat, state['res']
401
+
402
+
403
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return a fresh random seed if randomization is enabled, else `seed`."""
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
408
+
409
+
410
def prepare_multi_example() -> List[str]:
    """
    Collect multi-image example cases from assets/example_multi_image.

    Files are named `{case}_{view}.png`; only the first view of each case is
    returned as its representative thumbnail path.
    """
    case_names = {fname.split('_')[0] for fname in os.listdir("assets/example_multi_image")}
    thumbnails = []
    for case in sorted(case_names):
        candidate = f'assets/example_multi_image/{case}_1.png'
        if os.path.exists(candidate):
            thumbnails.append(candidate)
    return thumbnails
422
+
423
+
424
def load_multi_example(image) -> List[Image.Image]:
    """Resolve a clicked example thumbnail back to its full set of views.

    The example gallery only carries the first view of each case, so the
    input is matched (by raw RGBA pixel hash) against every case's first
    view; on a match all available views (up to 6) are loaded and
    preprocessed. Falls back to preprocessing the single input image when
    nothing matches.
    """
    import hashlib

    # Gradio may hand over a numpy array instead of a PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    def rgba_digest(img):
        # Hash the raw RGBA pixels so pixel-identical files compare equal.
        return hashlib.md5(np.array(img.convert('RGBA')).tobytes()).hexdigest()

    target = rgba_digest(image)

    case_names = sorted(set(name.split('_')[0] for name in os.listdir("assets/example_multi_image")))
    for case in case_names:
        first_path = f'assets/example_multi_image/{case}_1.png'
        if not os.path.exists(first_path):
            continue
        if rgba_digest(Image.open(first_path)) != target:
            continue
        # Matched: gather and preprocess every available view of this case.
        views = []
        for idx in range(1, 7):
            view_path = f'assets/example_multi_image/{case}_{idx}.png'
            if os.path.exists(view_path):
                views.append(preprocess_image(Image.open(view_path)))
        return views

    # No match found, return the single image preprocessed
    return [preprocess_image(image)]
454
+
455
+
456
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a horizontally concatenated RGBA image into individual views.

    A view is a maximal run of columns containing at least one visible
    (alpha > 0) pixel; runs are separated by fully transparent columns.

    Args:
        image: RGBA image whose views are separated by transparent gaps.

    Returns:
        One preprocessed image per detected view, left to right.
    """
    pixels = np.array(image)
    # Column occupancy: True where the column has any non-transparent pixel.
    occupied = np.any(pixels[..., 3] > 0, axis=0)
    # Pad with False on both sides so runs touching the image border are
    # detected too — a bare transition scan (~a[:-1] & a[1:]) misses a run
    # starting at column 0 or ending at the last column, which then mis-pairs
    # the remaining starts and ends.
    flags = np.concatenate(([False], occupied, [False]))
    starts = np.where(~flags[:-1] & flags[1:])[0]    # first occupied column of each run
    ends = np.where(flags[:-1] & ~flags[1:])[0] - 1  # last occupied column of each run
    views = [Image.fromarray(pixels[:, s:e + 1]) for s, e in zip(starts, ends)]
    return [preprocess_image(view) for view in views]
469
+
470
+
471
@spaces.GPU(duration=90)
def image_to_3d(
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    multiimages: List[Tuple[Image.Image, str]],
    multiimage_algo: Literal["stochastic", "multidiffusion"],
    tex_multiimage_algo: Literal["stochastic", "multidiffusion"],
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
    """
    Generate a 3D asset from the uploaded views and build the HTML previewer.

    Args:
        seed: RNG seed used by all three diffusion stages.
        resolution: Target voxel resolution ("512", "1024" or "1536").
        ss_*: Sampler settings for stage 1 (sparse structure).
        shape_slat_*: Sampler settings for stage 2 (shape latent).
        tex_slat_*: Sampler settings for stage 3 (material latent).
        multiimages: Gallery entries as (PIL image, caption) tuples
            (already preprocessed on upload).
        multiimage_algo: Multi-view fusion algorithm for structure/shape.
        tex_multiimage_algo: Multi-view fusion algorithm for texture.

    Returns:
        (state, full_html): a numpy-only latent state for later GLB
        extraction, and previewer HTML with every pre-rendered view inlined
        as a base64 data URI (len(MODES) * STEPS images).
    """
    # --- Sampling ---
    outputs, latents = pipeline.run_multi_image(
        [image[0] for image in multiimages],
        seed=seed,
        preprocess_image=False,  # gallery images were preprocessed on upload
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
            "guidance_rescale": ss_guidance_rescale,
            "rescale_t": ss_rescale_t,
        },
        shape_slat_sampler_params={
            "steps": shape_slat_sampling_steps,
            "guidance_strength": shape_slat_guidance_strength,
            "guidance_rescale": shape_slat_guidance_rescale,
            "rescale_t": shape_slat_rescale_t,
        },
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        # Higher resolutions run cascaded pipeline variants.
        pipeline_type={
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[resolution],
        return_latent=True,
        mode=multiimage_algo,
        tex_mode=tex_multiimage_algo,
    )
    mesh = outputs[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    # Pre-render STEPS camera angles for every display mode in one pass;
    # `images` maps render_key -> list of frames (one per view angle).
    images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    # Keep only the compact latent state for the later "Extract GLB" step.
    state = pack_state(latents)
    del outputs, mesh, latents  # Free memory
    torch.cuda.empty_cache()

    # --- HTML Construction ---
    # The Stack of 48 Images - encode in parallel for speed
    def encode_preview_image(args):
        # Encode one (mode, step) render as a base64 JPEG data URI.
        m_idx, s_idx, render_key = args
        img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx]))
        return (m_idx, s_idx, img_base64)

    encode_tasks = [
        (m_idx, s_idx, mode['render_key'])
        for m_idx, mode in enumerate(MODES)
        for s_idx in range(STEPS)
    ]

    with ThreadPoolExecutor(max_workers=8) as executor:
        encoded_results = list(executor.map(encode_preview_image, encode_tasks))

    # Build HTML from encoded results. All images are stacked in the DOM;
    # the JS in `head` toggles the `visible` class (see refreshView()).
    encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
    images_html = ""
    for m_idx, mode in enumerate(MODES):
        for s_idx in range(STEPS):
            # ID format must match the JS parser: view-m{mode}-s{step}.
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
            img_base64 = encoded_map[(m_idx, s_idx)]

            images_html += f"""
                <img id="{unique_id}"
                     class="previewer-main-image {vis_class}"
                     src="{img_base64}"
                     loading="eager">
            """

    # Button Row HTML
    btns_html = ""
    for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
        # Note: onclick calls the JS function defined in Head
        btns_html += f"""
            <img src="{mode['icon_base64']}"
                 class="mode-btn {active_class}"
                 onclick="selectMode({idx})"
                 title="{mode['name']}">
        """

    # Assemble the full component
    full_html = f"""
    <div class="previewer-container">
        <div class="tips-wrapper">
            <div class="tips-icon">💡Tips</div>
            <div class="tips-text">
                <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
                <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
            </div>
        </div>

        <!-- Row 1: Viewport containing 48 static <img> tags -->
        <div class="display-row">
            {images_html}
        </div>

        <!-- Row 2 -->
        <div class="mode-row" id="btn-group">
            {btns_html}
        </div>

        <!-- Row 3: Slider -->
        <div class="slider-row">
            <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
        </div>
    </div>
    """

    return state, full_html
606
+
607
+
608
@spaces.GPU(duration=60)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> str:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model (from pack_state).
        decimation_target (int): The target face count for decimation.
        texture_size (int): The texture resolution.

    Returns:
        str: The path to the extracted GLB file (inside the session dir).
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Restore the latents saved by image_to_3d and decode them to a mesh.
    shape_slat, tex_slat, res = unpack_state(state)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    # Bake the voxel attribute volume into a decimated, textured GLB.
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        # Unit cube centered at the origin — the pipeline's canonical space.
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    # Millisecond-precision timestamped filename to avoid collisions.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path
653
+
654
+
655
# --- UI layout and event wiring ---
# Left column: inputs + sampler settings; right column: HTML previewer and
# (hidden) GLB viewer. Session dirs are created/destroyed via load/unload.
with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload an image and click Generate to create a 3D asset. If the image has alpha channel, it will be used as the mask. Otherwise, background is automatically removed.
    * Click Extract GLB to export the GLB file if you're satisfied with the preview.
    """)

    with gr.Row():
        with gr.Column(scale=1, min_width=360):
            multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3)

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary")
                extract_btn = gr.Button("Extract GLB")

            # Per-stage sampler knobs; defaults mirror the pipeline defaults.
            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
                tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")

        with gr.Column(scale=10):
            preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
            # NOTE(review): created with visible=False and never toggled
            # visible by any handler — confirm the extracted GLB is meant to
            # be shown here.
            glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), visible=False)

    example_image = gr.Image(visible=False)  # Hidden component for examples
    examples_multi = gr.Examples(
        examples=prepare_multi_example(),
        inputs=[example_image],
        fn=load_multi_example,
        outputs=[multiimage_prompt],
        run_on_click=True,
        examples_per_page=24,
    )

    # Holds the pack_state() dict between Generate and Extract GLB.
    output_buf = gr.State()


    # Handlers
    demo.load(start_session)
    demo.unload(end_session)
    # Preprocess (resize / background removal) as soon as images are uploaded.
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    # Generate = resolve the seed first, then run the full pipeline.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        image_to_3d,
        inputs=[
            seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
            multiimage_prompt, multiimage_algo, tex_multiimage_algo
        ],
        outputs=[output_buf, preview_output],
    )

    extract_btn.click(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output],
    )
745
+
746
+
747
+ # Launch the Gradio app
748
# Launch the Gradio app
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # Pre-encode the mode-button icons as data URIs for the previewer HTML.
    # (Iterating the dicts directly replaces the old index loop and drops the
    # unused `btn_img_base64_strs` accumulator.)
    for mode in MODES:
        mode['icon_base64'] = image_to_base64(Image.open(mode['icon']))

    # Remote background-removal service used by preprocess_image().
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")

    # Main image-to-3D pipeline; background removal is delegated to the
    # remote service above, so the built-in rembg model is disabled.
    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None
    pipeline.low_vram = True  # Enable low VRAM mode for better memory efficiency
    pipeline.cuda()

    def _load_envmap(path: str) -> EnvMap:
        # Read an HDR .exr (BGR on disk) into a CUDA float32 EnvMap.
        hdr = cv2.cvtColor(cv2.imread(path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
        return EnvMap(torch.tensor(hdr, dtype=torch.float32, device='cuda'))

    # HDRI environment maps backing the shaded preview modes.
    envmap = {
        'forest': _load_envmap('assets/hdri/forest.exr'),
        'sunset': _load_envmap('assets/hdri/sunset.exr'),
        'courtyard': _load_envmap('assets/hdri/courtyard.exr'),
    }

    demo.queue(max_size=10, default_concurrency_limit=1).launch(css=css, head=head)