opsiclear-admin committed on
Commit
c4c6fa7
·
verified ·
1 Parent(s): 35f099b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +467 -366
app.py CHANGED
@@ -13,89 +13,28 @@ from datetime import datetime
13
  import shutil
14
  import cv2
15
  from typing import *
 
16
  import numpy as np
17
  from PIL import Image
18
  import base64
19
  import io
20
  import tempfile
21
-
22
- # Lazy imports - will be loaded when GPU is available
23
- torch = None
24
- SparseTensor = None
25
- Trellis2ImageTo3DPipeline = None
26
- EnvMap = None
27
- render_utils = None
28
- o_voxel = None
29
-
30
- # Global state - initialized on first GPU call
31
- pipeline = None
32
- envmap = None
33
- _initialized = False
34
-
35
-
36
- def _lazy_import():
37
- """Import GPU-dependent modules. Must be called from within a @spaces.GPU function."""
38
- global torch, SparseTensor, Trellis2ImageTo3DPipeline, EnvMap, render_utils, o_voxel
39
- if torch is None:
40
- import torch as _torch
41
- torch = _torch
42
- if SparseTensor is None:
43
- from trellis2.modules.sparse import SparseTensor as _SparseTensor
44
- SparseTensor = _SparseTensor
45
- if Trellis2ImageTo3DPipeline is None:
46
- from trellis2.pipelines import Trellis2ImageTo3DPipeline as _Trellis2ImageTo3DPipeline
47
- Trellis2ImageTo3DPipeline = _Trellis2ImageTo3DPipeline
48
- if EnvMap is None:
49
- from trellis2.renderers import EnvMap as _EnvMap
50
- EnvMap = _EnvMap
51
- if render_utils is None:
52
- from trellis2.utils import render_utils as _render_utils
53
- render_utils = _render_utils
54
- if o_voxel is None:
55
- import o_voxel as _o_voxel
56
- o_voxel = _o_voxel
57
- # Patch postprocess module with local fix for cumesh.fill_holes() bug
58
- import importlib.util
59
- _local_pp = os.path.join(os.path.dirname(os.path.abspath(__file__)),
60
- 'o-voxel', 'o_voxel', 'postprocess.py')
61
- if os.path.exists(_local_pp):
62
- _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_pp)
63
- _mod = importlib.util.module_from_spec(_spec)
64
- _spec.loader.exec_module(_mod)
65
- o_voxel.postprocess = _mod
66
- import sys
67
- sys.modules['o_voxel.postprocess'] = _mod
68
-
69
-
70
- def _initialize_pipeline():
71
- """Initialize the pipeline and environment maps. Must be called from within a @spaces.GPU function."""
72
- global pipeline, envmap, _initialized
73
- if _initialized:
74
- return
75
-
76
- _lazy_import()
77
-
78
- pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
79
- pipeline.rembg_model = None
80
- pipeline.low_vram = False
81
- pipeline.cuda()
82
-
83
- envmap = {
84
- 'forest': EnvMap(torch.tensor(
85
- cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
86
- dtype=torch.float32, device='cuda'
87
- )),
88
- 'sunset': EnvMap(torch.tensor(
89
- cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
90
- dtype=torch.float32, device='cuda'
91
- )),
92
- 'courtyard': EnvMap(torch.tensor(
93
- cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
94
- dtype=torch.float32, device='cuda'
95
- )),
96
- }
97
-
98
- _initialized = True
99
 
100
 
101
  MAX_SEED = np.iinfo(np.int32).max
@@ -114,88 +53,56 @@ DEFAULT_STEP = 3
114
 
115
 
116
  css = """
117
- /* ColmapView Dark Theme */
118
- :root {
119
- --bg-void: #0a0a0a;
120
- --bg-primary: #0f0f0f;
121
- --bg-secondary: #161616;
122
- --bg-tertiary: #1e1e1e;
123
- --bg-input: #1a1a1a;
124
- --bg-elevated: #242424;
125
- --bg-hover: #262626;
126
- --text-primary: #e8e8e8;
127
- --text-secondary: #8a8a8a;
128
- --text-muted: #5a5a5a;
129
- --border-subtle: #222222;
130
- --border-color: #2a2a2a;
131
- --border-light: #3a3a3a;
132
- --accent: #b8b8b8;
133
- --accent-hover: #d0d0d0;
134
- --accent-dim: rgba(184, 184, 184, 0.12);
135
- --shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.4), 0 1px 2px rgba(0, 0, 0, 0.3);
136
- --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.5), 0 2px 4px rgba(0, 0, 0, 0.3);
137
- --radius: 0.25rem;
138
- --radius-md: 0.375rem;
139
- --radius-lg: 0.5rem;
140
- --font-sans: 'Roboto', -apple-system, BlinkMacSystemFont, sans-serif;
141
- --font-mono: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace;
 
 
 
 
 
 
 
 
 
 
142
  }
143
 
144
- /* Global Overrides */
145
- .gradio-container { background: var(--bg-primary) !important; color: var(--text-primary) !important; font-family: var(--font-sans) !important; }
146
- .dark { background: var(--bg-primary) !important; }
147
- body { background: var(--bg-void) !important; color-scheme: dark; }
148
-
149
- /* Panels & Blocks */
150
- .block { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; border-radius: var(--radius-lg) !important; }
151
- .panel { background: var(--bg-secondary) !important; }
152
-
153
- /* Inputs */
154
- input, textarea, select { background: var(--bg-input) !important; border: 1px solid var(--border-color) !important; color: var(--text-primary) !important; border-radius: var(--radius) !important; }
155
- input:focus, textarea:focus, select:focus { border-color: var(--accent) !important; outline: none !important; }
156
-
157
- /* Buttons */
158
- .primary { background: var(--accent) !important; color: var(--bg-void) !important; border: none !important; }
159
- .primary:hover { background: var(--accent-hover) !important; }
160
- .secondary { background: var(--bg-hover) !important; color: var(--text-primary) !important; border: 1px solid var(--border-color) !important; }
161
- .secondary:hover { background: var(--bg-tertiary) !important; }
162
- button { transition: all 0.15s ease-out !important; }
163
-
164
- /* Labels & Text */
165
- label, .label-wrap { color: var(--text-secondary) !important; font-size: 0.875rem !important; }
166
- .prose { color: var(--text-primary) !important; }
167
- .prose h2 { color: var(--text-primary) !important; border: none !important; }
168
-
169
- /* Sliders */
170
- input[type="range"] { accent-color: var(--accent) !important; }
171
-
172
- /* Gallery */
173
- .gallery { background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
174
-
175
- /* Accordion */
176
- .accordion { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; }
177
-
178
- /* Scrollbar */
179
- ::-webkit-scrollbar { width: 12px; }
180
- ::-webkit-scrollbar-track { background: var(--bg-secondary); }
181
- ::-webkit-scrollbar-thumb { background: var(--border-light); border-radius: 9999px; }
182
- ::-webkit-scrollbar-thumb:hover { background: var(--text-muted); }
183
-
184
- /* Gradio Overrides */
185
- .stepper-wrapper { padding: 0; }
186
- .stepper-container { padding: 0; align-items: center; }
187
- .step-button { flex-direction: row; }
188
- .step-connector { transform: none; }
189
- .step-number { width: 16px; height: 16px; background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
190
- .step-label { position: relative; bottom: 0; color: var(--text-secondary) !important; }
191
- .wrap.center.full { inset: 0; height: 100%; }
192
- .wrap.center.full.translucent { background: var(--bg-secondary); }
193
- .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
194
 
195
  /* Previewer */
196
  .previewer-container {
197
  position: relative;
198
- font-family: var(--font-sans);
199
  width: 100%;
200
  height: 722px;
201
  margin: 0 auto;
@@ -204,53 +111,148 @@ input[type="range"] { accent-color: var(--accent) !important; }
204
  flex-direction: column;
205
  align-items: center;
206
  justify-content: center;
207
- background: var(--bg-void);
208
- border-radius: var(--radius-lg);
209
  }
 
210
  .previewer-container .tips-icon {
211
- position: absolute; right: 10px; top: 10px; z-index: 10;
212
- border-radius: var(--radius-md); color: var(--bg-void);
213
- background-color: var(--accent); padding: 4px 8px;
214
- user-select: none; font-size: 0.75rem; font-weight: 500;
 
 
 
 
 
215
  }
 
216
  .previewer-container .tips-text {
217
- position: absolute; right: 10px; top: 45px;
218
- color: var(--text-primary); background-color: var(--bg-elevated);
219
- border: 1px solid var(--border-light); border-radius: var(--radius-md);
220
- padding: 8px 12px; text-align: left; max-width: 280px; z-index: 10;
221
- transition: opacity 0.15s ease-out; opacity: 0; user-select: none;
222
- box-shadow: var(--shadow-md);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  }
224
- .previewer-container .tips-text p { font-size: 0.75rem; line-height: 1.4; color: var(--text-secondary); margin: 4px 0; }
225
- .tips-icon:hover + .tips-text { opacity: 1; }
226
- .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 16px; flex-wrap: wrap; }
227
  .previewer-container .mode-btn {
228
- width: 28px; height: 28px; border-radius: 50%; cursor: pointer;
229
- opacity: 0.5; transition: all 0.15s ease-out;
230
- border: 2px solid var(--border-light); object-fit: cover;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  }
232
- .previewer-container .mode-btn:hover { opacity: 0.85; transform: scale(1.1); border-color: var(--text-muted); }
233
- .previewer-container .mode-btn.active { opacity: 1; border-color: var(--accent); transform: scale(1.1); }
234
- .previewer-container .display-row { margin-bottom: 16px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
235
- .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
236
- .previewer-container .previewer-main-image.visible { display: block; }
237
- .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 8px; padding: 0 16px; }
238
- .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; height: 20px; }
239
- .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 4px; cursor: pointer; background: var(--bg-tertiary); border-radius: 9999px; }
240
  .previewer-container input[type=range]::-webkit-slider-thumb {
241
- height: 12px; width: 12px; border-radius: 50%; background: var(--accent);
242
- cursor: pointer; -webkit-appearance: none; margin-top: -4px;
243
- box-shadow: var(--shadow-sm); transition: transform 0.1s ease-out;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  }
245
- .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); background: var(--accent-hover); }
246
- .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
247
- .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
248
  """
249
 
250
 
251
  head = """
252
  <script>
253
  function refreshView(mode, step) {
 
254
  const allImgs = document.querySelectorAll('.previewer-main-image');
255
  for (let i = 0; i < allImgs.length; i++) {
256
  const img = allImgs[i];
@@ -262,26 +264,46 @@ head = """
262
  break;
263
  }
264
  }
 
 
 
265
  allImgs.forEach(img => img.classList.remove('visible'));
 
 
 
266
  const targetId = 'view-m' + mode + '-s' + step;
267
  const targetImg = document.getElementById(targetId);
268
- if (targetImg) { targetImg.classList.add('visible'); }
 
 
 
 
 
 
269
  const allBtns = document.querySelectorAll('.mode-btn');
270
  allBtns.forEach((btn, idx) => {
271
  if (idx === mode) btn.classList.add('active');
272
  else btn.classList.remove('active');
273
  });
274
  }
275
- function selectMode(mode) { refreshView(mode, -1); }
276
- function onSliderChange(val) { refreshView(-1, parseInt(val)); }
 
 
 
 
 
 
 
 
277
  </script>
278
  """
279
 
280
 
281
- empty_html = """
282
  <div class="previewer-container">
283
- <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
284
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
285
  </div>
286
  """
287
 
@@ -315,7 +337,10 @@ def remove_background(input: Image.Image) -> Image.Image:
315
 
316
 
317
  def preprocess_image(input: Image.Image) -> Image.Image:
318
- """Preprocess a single input image."""
 
 
 
319
  has_alpha = False
320
  if input.mode == 'RGBA':
321
  alpha = np.array(input)[:, :, 3]
@@ -332,29 +357,42 @@ def preprocess_image(input: Image.Image) -> Image.Image:
332
  output_np = np.array(output)
333
  alpha = output_np[:, :, 3]
334
  bbox = np.argwhere(alpha > 0.8 * 255)
 
 
 
 
 
 
 
335
  bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
336
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
337
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
338
  size = int(size * 1)
339
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
340
- output = output.crop(bbox)
341
- output = np.array(output).astype(np.float32) / 255
342
- output = output[:, :, :3] * output[:, :, 3:4]
343
- output = Image.fromarray((output * 255).astype(np.uint8))
 
 
 
 
 
344
  return output
345
 
346
 
347
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
348
- """Preprocess a list of input images. Uses parallel processing."""
349
- if not images:
350
- return []
351
- imgs = [img[0] if isinstance(img, tuple) else img for img in images]
352
- with ThreadPoolExecutor(max_workers=min(4, len(imgs))) as executor:
353
- processed_images = list(executor.map(preprocess_image, imgs))
 
354
  return processed_images
355
 
356
 
357
- def pack_state(latents):
358
  shape_slat, tex_slat, res = latents
359
  return {
360
  'shape_slat_feats': shape_slat.feats.cpu().numpy(),
@@ -364,8 +402,7 @@ def pack_state(latents):
364
  }
365
 
366
 
367
- def unpack_state(state: dict):
368
- _lazy_import()
369
  shape_slat = SparseTensor(
370
  feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
371
  coords=torch.from_numpy(state['coords']).cuda(),
@@ -375,40 +412,81 @@ def unpack_state(state: dict):
375
 
376
 
377
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
378
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
379
 
380
 
381
- def prepare_examples() -> List[List[str]]:
382
- """Prepare examples as lists of image paths (not concatenated)."""
383
- example_dir = "assets/example_multi_image"
384
- if not os.path.exists(example_dir):
385
- return []
386
- files = os.listdir(example_dir)
387
- cases = list(set([f.split('_')[0] for f in files if '_' in f and f.endswith('.png')]))
388
  examples = []
389
- for case in sorted(cases):
390
- case_images = []
391
- for i in range(1, 10): # Support up to 9 images per example
392
- img_path = f'{example_dir}/{case}_{i}.png'
393
- if os.path.exists(img_path):
394
- case_images.append(img_path)
395
- if case_images:
396
- examples.append(case_images)
397
  return examples
398
 
399
 
400
- def load_example(example_paths: List[str]) -> List[Image.Image]:
401
- """Load example images from paths."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
  images = []
403
- for path in example_paths:
404
- img = Image.open(path)
405
- images.append(img)
406
- return images
407
 
408
 
409
  @spaces.GPU(duration=120)
410
  def image_to_3d(
411
- images: List[Tuple[Image.Image, str]],
412
  seed: int,
413
  resolution: str,
414
  ss_guidance_strength: float,
@@ -423,103 +501,74 @@ def image_to_3d(
423
  tex_slat_guidance_rescale: float,
424
  tex_slat_sampling_steps: int,
425
  tex_slat_rescale_t: float,
 
426
  multiimage_algo: Literal["multidiffusion", "stochastic"],
427
  tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
428
  req: gr.Request,
429
  progress=gr.Progress(track_tqdm=True),
430
  ) -> str:
431
- # Initialize pipeline on first call
432
- _initialize_pipeline()
433
 
434
- # Extract images from gallery format
435
- if not images:
436
- raise gr.Error("Please upload at least one image")
437
-
438
- imgs = [img[0] if isinstance(img, tuple) else img for img in images]
439
 
440
  # --- Sampling ---
441
- if len(imgs) == 1:
442
- # Single image mode
443
- outputs, latents = pipeline.run(
444
- imgs[0],
445
- seed=seed,
446
- preprocess_image=False,
447
- sparse_structure_sampler_params={
448
- "steps": ss_sampling_steps,
449
- "guidance_strength": ss_guidance_strength,
450
- "guidance_rescale": ss_guidance_rescale,
451
- "rescale_t": ss_rescale_t,
452
- },
453
- shape_slat_sampler_params={
454
- "steps": shape_slat_sampling_steps,
455
- "guidance_strength": shape_slat_guidance_strength,
456
- "guidance_rescale": shape_slat_guidance_rescale,
457
- "rescale_t": shape_slat_rescale_t,
458
- },
459
- tex_slat_sampler_params={
460
- "steps": tex_slat_sampling_steps,
461
- "guidance_strength": tex_slat_guidance_strength,
462
- "guidance_rescale": tex_slat_guidance_rescale,
463
- "rescale_t": tex_slat_rescale_t,
464
- },
465
- pipeline_type={
466
- "512": "512",
467
- "1024": "1024_cascade",
468
- "1536": "1536_cascade",
469
- }[resolution],
470
- return_latent=True,
471
- )
472
- else:
473
- # Multi-image mode
474
- outputs, latents = pipeline.run_multi_image(
475
- imgs,
476
- seed=seed,
477
- preprocess_image=False,
478
- sparse_structure_sampler_params={
479
- "steps": ss_sampling_steps,
480
- "guidance_strength": ss_guidance_strength,
481
- "guidance_rescale": ss_guidance_rescale,
482
- "rescale_t": ss_rescale_t,
483
- },
484
- shape_slat_sampler_params={
485
- "steps": shape_slat_sampling_steps,
486
- "guidance_strength": shape_slat_guidance_strength,
487
- "guidance_rescale": shape_slat_guidance_rescale,
488
- "rescale_t": shape_slat_rescale_t,
489
- },
490
- tex_slat_sampler_params={
491
- "steps": tex_slat_sampling_steps,
492
- "guidance_strength": tex_slat_guidance_strength,
493
- "guidance_rescale": tex_slat_guidance_rescale,
494
- "rescale_t": tex_slat_rescale_t,
495
- },
496
- pipeline_type={
497
- "512": "512",
498
- "1024": "1024_cascade",
499
- "1536": "1536_cascade",
500
- }[resolution],
501
- return_latent=True,
502
- mode=multiimage_algo,
503
- tex_mode=tex_multiimage_algo,
504
- )
505
-
506
  mesh = outputs[0]
507
- mesh.simplify(16777216)
508
- render_images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
509
  state = pack_state(latents)
510
  torch.cuda.empty_cache()
511
 
512
  # --- HTML Construction ---
 
513
  def encode_preview_image(args):
514
  m_idx, s_idx, render_key = args
515
- img_base64 = image_to_base64(Image.fromarray(render_images[render_key][s_idx]))
516
  return (m_idx, s_idx, img_base64)
517
 
518
- encode_tasks = [(m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS)]
 
 
 
 
519
 
520
  with ThreadPoolExecutor(max_workers=8) as executor:
521
  encoded_results = list(executor.map(encode_preview_image, encode_tasks))
522
 
 
523
  encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
524
  images_html = ""
525
  for m_idx, mode in enumerate(MODES):
@@ -528,29 +577,54 @@ def image_to_3d(
528
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
529
  vis_class = "visible" if is_visible else ""
530
  img_base64 = encoded_map[(m_idx, s_idx)]
531
- images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
532
 
 
 
 
 
 
 
 
 
533
  btns_html = ""
534
  for idx, mode in enumerate(MODES):
535
  active_class = "active" if idx == DEFAULT_MODE else ""
536
- btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
537
-
 
 
 
 
 
 
 
538
  full_html = f"""
539
  <div class="previewer-container">
540
  <div class="tips-wrapper">
541
- <div class="tips-icon">Tips</div>
542
  <div class="tips-text">
543
- <p>Render Mode - Click buttons to switch render modes.</p>
544
- <p>View Angle - Drag slider to change view.</p>
545
  </div>
546
  </div>
547
- <div class="display-row">{images_html}</div>
548
- <div class="mode-row" id="btn-group">{btns_html}</div>
 
 
 
 
 
 
 
 
 
 
549
  <div class="slider-row">
550
  <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
551
  </div>
552
  </div>
553
  """
 
554
  return state, full_html
555
 
556
 
@@ -561,13 +635,22 @@ def extract_glb(
561
  texture_size: int,
562
  req: gr.Request,
563
  progress=gr.Progress(track_tqdm=True),
564
- ) -> Tuple[str, str]:
565
- _initialize_pipeline()
 
566
 
 
 
 
 
 
 
 
 
567
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
568
  shape_slat, tex_slat, res = unpack_state(state)
569
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
570
- mesh.simplify(16777216)
571
  glb = o_voxel.postprocess.to_glb(
572
  vertices=mesh.vertices,
573
  faces=mesh.faces,
@@ -589,27 +672,30 @@ def extract_glb(
589
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
590
  glb.export(glb_path, extension_webp=True)
591
  torch.cuda.empty_cache()
592
- return glb_path, glb_path
593
-
594
-
595
- with gr.Blocks(delete_cache=(600, 600)) as demo:
596
- gr.Markdown("""
597
- ## Multi-View Image to 3D with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
598
- Upload one or more images of an object and click Generate to create a 3D asset.
599
- Multiple views from different angles will produce better results.
 
 
 
 
 
 
 
 
 
 
 
600
  """)
601
 
602
  with gr.Row():
603
  with gr.Column(scale=1, min_width=360):
604
- image_prompt = gr.Gallery(
605
- label="Input Images (upload 1 or more views)",
606
- format="png",
607
- type="pil",
608
- height=400,
609
- columns=3,
610
- object_fit="contain"
611
- )
612
- gr.Markdown("*Upload multiple views of the same object for better 3D reconstruction.*")
613
 
614
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
615
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -617,8 +703,6 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
617
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
618
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
619
 
620
- generate_btn = gr.Button("Generate", variant="primary")
621
-
622
  with gr.Accordion(label="Advanced Settings", open=False):
623
  gr.Markdown("Stage 1: Sparse Structure Generation")
624
  with gr.Row():
@@ -642,71 +726,88 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
642
  tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
643
 
644
  with gr.Column(scale=10):
645
- with gr.Walkthrough(selected=0) as walkthrough:
646
- with gr.Step("Preview", id=0):
647
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
648
- extract_btn = gr.Button("Extract GLB")
649
- with gr.Step("Extract", id=1):
650
- glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
651
- download_btn = gr.DownloadButton(label="Download GLB")
652
- gr.Markdown("*GLB extraction may take 30+ seconds.*")
653
-
654
- with gr.Column(scale=1, min_width=200):
655
- gr.Markdown("### Examples")
656
- # Create example buttons that load images into gallery
657
- example_data = prepare_examples()
658
- for i, example_paths in enumerate(example_data[:12]): # Limit to 12 examples
659
- case_name = os.path.basename(example_paths[0]).split('_')[0]
660
- btn = gr.Button(f"{case_name} ({len(example_paths)} views)", size="sm")
661
- btn.click(
662
- fn=lambda paths=example_paths: load_example(paths),
663
- outputs=[image_prompt]
664
- )
665
 
666
  output_buf = gr.State()
667
 
 
668
  # Handlers
669
  demo.load(start_session)
670
  demo.unload(end_session)
671
-
672
- image_prompt.upload(
673
  preprocess_images,
674
- inputs=[image_prompt],
675
- outputs=[image_prompt],
676
  )
677
 
678
  generate_btn.click(
679
- get_seed, inputs=[randomize_seed, seed], outputs=[seed],
680
- ).then(
681
- lambda: gr.Walkthrough(selected=0), outputs=walkthrough
682
  ).then(
683
  image_to_3d,
684
  inputs=[
685
- image_prompt, seed, resolution,
686
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
687
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
688
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
689
- multiimage_algo, tex_multiimage_algo
690
  ],
691
  outputs=[output_buf, preview_output],
692
  )
693
 
694
  extract_btn.click(
695
- lambda: gr.Walkthrough(selected=1), outputs=walkthrough
696
- ).then(
697
  extract_glb,
698
  inputs=[output_buf, decimation_target, texture_size],
699
- outputs=[glb_output, download_btn],
700
  )
701
 
702
 
 
703
  if __name__ == "__main__":
704
  os.makedirs(TMP_DIR, exist_ok=True)
705
 
 
 
706
  for i in range(len(MODES)):
707
  icon = Image.open(MODES[i]['icon'])
708
  MODES[i]['icon_base64'] = image_to_base64(icon)
709
 
710
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
 
712
  demo.launch(css=css, head=head)
 
13
  import shutil
14
  import cv2
15
  from typing import *
16
+ import torch
17
  import numpy as np
18
  from PIL import Image
19
  import base64
20
  import io
21
  import tempfile
22
+ from trellis2.modules.sparse import SparseTensor
23
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline
24
+ from trellis2.renderers import EnvMap
25
+ from trellis2.utils import render_utils
26
+ import o_voxel
27
+
28
+ # Patch postprocess module with local fix for cumesh.fill_holes() bug
29
+ import importlib.util
30
+ _local_postprocess = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'o-voxel', 'o_voxel', 'postprocess.py')
31
+ if os.path.exists(_local_postprocess):
32
+ import sys
33
+ _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
34
+ _mod = importlib.util.module_from_spec(_spec)
35
+ _spec.loader.exec_module(_mod)
36
+ o_voxel.postprocess = _mod
37
+ sys.modules['o_voxel.postprocess'] = _mod
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  MAX_SEED = np.iinfo(np.int32).max
 
53
 
54
 
55
  css = """
56
+ /* Overwrite Gradio Default Style */
57
+ .stepper-wrapper {
58
+ padding: 0;
59
+ }
60
+
61
+ .stepper-container {
62
+ padding: 0;
63
+ align-items: center;
64
+ }
65
+
66
+ .step-button {
67
+ flex-direction: row;
68
+ }
69
+
70
+ .step-connector {
71
+ transform: none;
72
+ }
73
+
74
+ .step-number {
75
+ width: 16px;
76
+ height: 16px;
77
+ }
78
+
79
+ .step-label {
80
+ position: relative;
81
+ bottom: 0;
82
+ }
83
+
84
+ .wrap.center.full {
85
+ inset: 0;
86
+ height: 100%;
87
+ }
88
+
89
+ .wrap.center.full.translucent {
90
+ background: var(--block-background-fill);
91
  }
92
 
93
+ .meta-text-center {
94
+ display: block !important;
95
+ position: absolute !important;
96
+ top: unset !important;
97
+ bottom: 0 !important;
98
+ right: 0 !important;
99
+ transform: unset !important;
100
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  /* Previewer */
103
  .previewer-container {
104
  position: relative;
105
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
106
  width: 100%;
107
  height: 722px;
108
  margin: 0 auto;
 
111
  flex-direction: column;
112
  align-items: center;
113
  justify-content: center;
 
 
114
  }
115
+
116
  .previewer-container .tips-icon {
117
+ position: absolute;
118
+ right: 10px;
119
+ top: 10px;
120
+ z-index: 10;
121
+ border-radius: 10px;
122
+ color: #fff;
123
+ background-color: var(--color-accent);
124
+ padding: 3px 6px;
125
+ user-select: none;
126
  }
127
+
128
  .previewer-container .tips-text {
129
+ position: absolute;
130
+ right: 10px;
131
+ top: 50px;
132
+ color: #fff;
133
+ background-color: var(--color-accent);
134
+ border-radius: 10px;
135
+ padding: 6px;
136
+ text-align: left;
137
+ max-width: 300px;
138
+ z-index: 10;
139
+ transition: all 0.3s;
140
+ opacity: 0%;
141
+ user-select: none;
142
+ }
143
+
144
+ .previewer-container .tips-text p {
145
+ font-size: 14px;
146
+ line-height: 1.2;
147
+ }
148
+
149
+ .tips-icon:hover + .tips-text {
150
+ display: block;
151
+ opacity: 100%;
152
+ }
153
+
154
+ /* Row 1: Display Modes */
155
+ .previewer-container .mode-row {
156
+ width: 100%;
157
+ display: flex;
158
+ gap: 8px;
159
+ justify-content: center;
160
+ margin-bottom: 20px;
161
+ flex-wrap: wrap;
162
  }
 
 
 
163
  .previewer-container .mode-btn {
164
+ width: 24px;
165
+ height: 24px;
166
+ border-radius: 50%;
167
+ cursor: pointer;
168
+ opacity: 0.5;
169
+ transition: all 0.2s;
170
+ border: 2px solid var(--neutral-600, #555);
171
+ object-fit: cover;
172
+ }
173
+ .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
174
+ .previewer-container .mode-btn.active {
175
+ opacity: 1;
176
+ border-color: var(--color-accent);
177
+ transform: scale(1.1);
178
+ }
179
+
180
+ /* Row 2: Display Image */
181
+ .previewer-container .display-row {
182
+ margin-bottom: 20px;
183
+ min-height: 400px;
184
+ width: 100%;
185
+ flex-grow: 1;
186
+ display: flex;
187
+ justify-content: center;
188
+ align-items: center;
189
+ }
190
+ .previewer-container .previewer-main-image {
191
+ max-width: 100%;
192
+ max-height: 100%;
193
+ flex-grow: 1;
194
+ object-fit: contain;
195
+ display: none;
196
+ }
197
+ .previewer-container .previewer-main-image.visible {
198
+ display: block;
199
+ }
200
+
201
+ /* Row 3: Custom HTML Slider */
202
+ .previewer-container .slider-row {
203
+ width: 100%;
204
+ display: flex;
205
+ flex-direction: column;
206
+ align-items: center;
207
+ gap: 10px;
208
+ padding: 0 10px;
209
+ }
210
+
211
+ .previewer-container input[type=range] {
212
+ -webkit-appearance: none;
213
+ width: 100%;
214
+ max-width: 400px;
215
+ background: transparent;
216
+ }
217
+ .previewer-container input[type=range]::-webkit-slider-runnable-track {
218
+ width: 100%;
219
+ height: 8px;
220
+ cursor: pointer;
221
+ background: var(--neutral-700, #404040);
222
+ border-radius: 5px;
223
  }
 
 
 
 
 
 
 
 
224
  .previewer-container input[type=range]::-webkit-slider-thumb {
225
+ height: 20px;
226
+ width: 20px;
227
+ border-radius: 50%;
228
+ background: var(--color-accent);
229
+ cursor: pointer;
230
+ -webkit-appearance: none;
231
+ margin-top: -6px;
232
+ box-shadow: 0 2px 5px rgba(0,0,0,0.2);
233
+ transition: transform 0.1s;
234
+ }
235
+ .previewer-container input[type=range]::-webkit-slider-thumb:hover {
236
+ transform: scale(1.2);
237
+ }
238
+
239
+ /* Overwrite Previewer Block Style */
240
+ .gradio-container .padded:has(.previewer-container) {
241
+ padding: 0 !important;
242
+ }
243
+
244
+ .gradio-container:has(.previewer-container) [data-testid="block-label"] {
245
+ position: absolute;
246
+ top: 0;
247
+ left: 0;
248
  }
 
 
 
249
  """
250
 
251
 
252
  head = """
253
  <script>
254
  function refreshView(mode, step) {
255
+ // 1. Find current mode and step
256
  const allImgs = document.querySelectorAll('.previewer-main-image');
257
  for (let i = 0; i < allImgs.length; i++) {
258
  const img = allImgs[i];
 
264
  break;
265
  }
266
  }
267
+
268
+ // 2. Hide ALL images
269
+ // We select all elements with class 'previewer-main-image'
270
  allImgs.forEach(img => img.classList.remove('visible'));
271
+
272
+ // 3. Construct the specific ID for the current state
273
+ // Format: view-m{mode}-s{step}
274
  const targetId = 'view-m' + mode + '-s' + step;
275
  const targetImg = document.getElementById(targetId);
276
+
277
+ // 4. Show ONLY the target
278
+ if (targetImg) {
279
+ targetImg.classList.add('visible');
280
+ }
281
+
282
+ // 5. Update Button Highlights
283
  const allBtns = document.querySelectorAll('.mode-btn');
284
  allBtns.forEach((btn, idx) => {
285
  if (idx === mode) btn.classList.add('active');
286
  else btn.classList.remove('active');
287
  });
288
  }
289
+
290
+ // --- Action: Switch Mode ---
291
+ function selectMode(mode) {
292
+ refreshView(mode, -1);
293
+ }
294
+
295
+ // --- Action: Slider Change ---
296
+ function onSliderChange(val) {
297
+ refreshView(-1, parseInt(val));
298
+ }
299
  </script>
300
  """
301
 
302
 
303
+ empty_html = f"""
304
  <div class="previewer-container">
305
+ <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
306
+ xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
307
  </div>
308
  """
309
 
 
337
 
338
 
339
  def preprocess_image(input: Image.Image) -> Image.Image:
340
+ """
341
+ Preprocess the input image.
342
+ """
343
+ # if has alpha channel, use it directly; otherwise, remove background
344
  has_alpha = False
345
  if input.mode == 'RGBA':
346
  alpha = np.array(input)[:, :, 3]
 
357
  output_np = np.array(output)
358
  alpha = output_np[:, :, 3]
359
  bbox = np.argwhere(alpha > 0.8 * 255)
360
+ if bbox.size == 0:
361
+ # No visible pixels, center the image in a square
362
+ size = max(output.size)
363
+ square = Image.new('RGB', (size, size), (0, 0, 0))
364
+ output_rgb = output.convert('RGB') if output.mode == 'RGBA' else output
365
+ square.paste(output_rgb, ((size - output.width) // 2, (size - output.height) // 2))
366
+ return square
367
  bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
368
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
369
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
370
  size = int(size * 1)
371
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
372
+ output = output.crop(bbox) # type: ignore
373
+ output_np = np.array(output).astype(np.float32)
374
+ rgb = output_np[:, :, :3]
375
+ alpha = output_np[:, :, 3:4] / 255.0
376
+ # Use threshold to avoid darkening foreground pixels with slightly transparent alpha
377
+ # Pixels with alpha > 0.5 keep their full RGB, pixels below are blacked out
378
+ mask = (alpha > 0.5).astype(np.float32)
379
+ rgb = rgb * mask
380
+ output = Image.fromarray(rgb.astype(np.uint8))
381
  return output
382
 
383
 
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images for multi-image conditioning.

    Each entry is an ``(image, caption)`` pair; only the image is used. The
    images are passed through ``preprocess_image`` on a small thread pool.

    Args:
        images: List of ``(image, caption)`` tuples. May be empty or ``None``
            when there is nothing to process.

    Returns:
        The preprocessed images, in the same order as the input, or an empty
        list when no images were given.
    """
    # ThreadPoolExecutor raises ValueError for max_workers <= 0, so an empty
    # input must be handled before the pool is created.
    if not images:
        return []

    pil_images = [entry[0] for entry in images]
    with ThreadPoolExecutor(max_workers=min(4, len(pil_images))) as executor:
        # executor.map yields results in input order.
        return list(executor.map(preprocess_image, pil_images))
393
 
394
 
395
+ def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
396
  shape_slat, tex_slat, res = latents
397
  return {
398
  'shape_slat_feats': shape_slat.feats.cpu().numpy(),
 
402
  }
403
 
404
 
405
+ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
 
406
  shape_slat = SparseTensor(
407
  feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
408
  coords=torch.from_numpy(state['coords']).cuda(),
 
412
 
413
 
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Choose the seed to use for a generation run.

    Args:
        randomize_seed: When true, ignore ``seed`` and draw a new one.
        seed: The seed to return unchanged when randomization is off.

    Returns:
        A value drawn from ``np.random.randint(0, MAX_SEED)`` when
        ``randomize_seed`` is set, otherwise ``seed``.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
419
 
420
 
def prepare_multi_example(example_dir: str = "assets/example_multi_image") -> List[str]:
    """
    Prepare multi-image examples. Returns list of image paths.
    Shows only the first view as representative thumbnail.

    Files are expected to be named ``<case>_<view>.png``. The case name is
    everything before the *last* underscore, which is the same rule
    ``load_multi_example`` uses, so case names that contain underscores
    (e.g. ``toy_car_1.png``) are listed instead of being silently dropped.

    Args:
        example_dir: Directory holding the example images.

    Returns:
        Paths of the ``<case>_1.png`` files that exist, sorted by case name.
    """
    case_names = sorted({
        name.rsplit('_', 1)[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png')
    })
    first_views = (f'{example_dir}/{case}_1.png' for case in case_names)
    # Cases without a first view have no thumbnail to show, so they are skipped.
    return [path for path in first_views if os.path.exists(path)]
433
 
434
 
def load_multi_example(image, max_views: int = 6) -> List[Image.Image]:
    """
    Load all views for a multi-image case by matching the input image.

    The input is compared pixel-for-pixel (in RGB) against the first view of
    every example case. On a match, all existing views of that case are
    returned without preprocessing.

    Args:
        image: The clicked example, as a PIL image or a numpy array. ``None``
            yields an empty list.
        max_views: Highest view index probed for a case; files are looked up
            as ``<case>_1.png`` .. ``<case>_<max_views>.png``.

    Returns:
        The RGBA views of the matching case, or a single-element list holding
        the input image (as RGBA) when no case matches.
    """
    if image is None:
        return []

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Convert to RGB for consistent comparison
    input_rgb = np.array(image.convert('RGB'))

    # Find matching case by comparing with first images
    example_dir = "assets/example_multi_image"
    case_names = sorted({
        name.rsplit('_', 1)[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png')
    })

    for case_name in case_names:
        first_img_path = f'{example_dir}/{case_name}_1.png'
        if not os.path.exists(first_img_path):
            continue

        # convert() returns an independent in-memory image, so the file can be
        # closed as soon as the pixels have been read.
        with Image.open(first_img_path) as first_img:
            first_rgb = np.array(first_img.convert('RGB'))

        # Compare images (check if same shape and content)
        if input_rgb.shape != first_rgb.shape or not np.array_equal(input_rgb, first_rgb):
            continue

        # Found match, load all views (preprocessing is left to the caller)
        views = []
        for view_idx in range(1, max_views + 1):
            view_path = f'{example_dir}/{case_name}_{view_idx}.png'
            if os.path.exists(view_path):
                with Image.open(view_path) as view:
                    views.append(view.convert('RGBA'))
        if views:
            return views

    # No match found, return the single image
    return [image.convert('RGBA') if image.mode != 'RGBA' else image]
471
+
472
+
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a concatenated image into multiple views.

    A view is a maximal run of columns containing at least one pixel with
    non-zero alpha; fully transparent columns separate the views. Each view is
    cropped out and passed through ``preprocess_image``.

    Args:
        image: The concatenated image. Non-RGBA input is converted to RGBA
            first so that an alpha channel is always available.

    Returns:
        The preprocessed views, ordered left to right.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')
    pixels = np.array(image)
    opaque_cols = np.any(pixels[..., 3] > 0, axis=0)

    # Pad with a transparent column on both sides so that a view touching the
    # left or right border still produces a matching start/stop transition.
    # Without this, starts and ends get mis-paired or the last view is lost.
    padded = np.concatenate(([False], opaque_cols, [False]))
    edges = np.diff(padded.astype(np.int8))
    run_starts = np.flatnonzero(edges == 1).tolist()   # first opaque column of each run
    run_stops = np.flatnonzero(edges == -1).tolist()   # one past the last opaque column

    # Keep one transparent column on the left where available; this matches
    # the crop the previous implementation produced for interior views.
    views = [
        Image.fromarray(pixels[:, max(start - 1, 0):stop])
        for start, stop in zip(run_starts, run_stops)
    ]
    return [preprocess_image(view) for view in views]
 
486
 
487
 
488
  @spaces.GPU(duration=120)
489
  def image_to_3d(
 
490
  seed: int,
491
  resolution: str,
492
  ss_guidance_strength: float,
 
501
  tex_slat_guidance_rescale: float,
502
  tex_slat_sampling_steps: int,
503
  tex_slat_rescale_t: float,
504
+ multiimages: List[Tuple[Image.Image, str]],
505
  multiimage_algo: Literal["multidiffusion", "stochastic"],
506
  tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
507
  req: gr.Request,
508
  progress=gr.Progress(track_tqdm=True),
509
  ) -> str:
510
+ if not multiimages:
511
+ raise gr.Error("Please upload images or select an example first.")
512
 
513
+ # Preprocess images (background removal, cropping, etc.)
514
+ images = [image[0] for image in multiimages]
515
+ processed_images = [preprocess_image(img) for img in images]
 
 
516
 
517
  # --- Sampling ---
518
+ outputs, latents = pipeline.run_multi_image(
519
+ processed_images,
520
+ seed=seed,
521
+ preprocess_image=False,
522
+ sparse_structure_sampler_params={
523
+ "steps": ss_sampling_steps,
524
+ "guidance_strength": ss_guidance_strength,
525
+ "guidance_rescale": ss_guidance_rescale,
526
+ "rescale_t": ss_rescale_t,
527
+ },
528
+ shape_slat_sampler_params={
529
+ "steps": shape_slat_sampling_steps,
530
+ "guidance_strength": shape_slat_guidance_strength,
531
+ "guidance_rescale": shape_slat_guidance_rescale,
532
+ "rescale_t": shape_slat_rescale_t,
533
+ },
534
+ tex_slat_sampler_params={
535
+ "steps": tex_slat_sampling_steps,
536
+ "guidance_strength": tex_slat_guidance_strength,
537
+ "guidance_rescale": tex_slat_guidance_rescale,
538
+ "rescale_t": tex_slat_rescale_t,
539
+ },
540
+ pipeline_type={
541
+ "512": "512",
542
+ "1024": "1024_cascade",
543
+ "1536": "1536_cascade",
544
+ }[resolution],
545
+ return_latent=True,
546
+ mode=multiimage_algo,
547
+ tex_mode=tex_multiimage_algo,
548
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
549
  mesh = outputs[0]
550
+ mesh.simplify(16777216) # nvdiffrast limit
551
+ images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
552
  state = pack_state(latents)
553
  torch.cuda.empty_cache()
554
 
555
  # --- HTML Construction ---
556
+ # The Stack of 48 Images - encode in parallel for speed
557
  def encode_preview_image(args):
558
  m_idx, s_idx, render_key = args
559
+ img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx]))
560
  return (m_idx, s_idx, img_base64)
561
 
562
+ encode_tasks = [
563
+ (m_idx, s_idx, mode['render_key'])
564
+ for m_idx, mode in enumerate(MODES)
565
+ for s_idx in range(STEPS)
566
+ ]
567
 
568
  with ThreadPoolExecutor(max_workers=8) as executor:
569
  encoded_results = list(executor.map(encode_preview_image, encode_tasks))
570
 
571
+ # Build HTML from encoded results
572
  encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
573
  images_html = ""
574
  for m_idx, mode in enumerate(MODES):
 
577
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
578
  vis_class = "visible" if is_visible else ""
579
  img_base64 = encoded_map[(m_idx, s_idx)]
 
580
 
581
+ images_html += f"""
582
+ <img id="{unique_id}"
583
+ class="previewer-main-image {vis_class}"
584
+ src="{img_base64}"
585
+ loading="eager">
586
+ """
587
+
588
+ # Button Row HTML
589
  btns_html = ""
590
  for idx, mode in enumerate(MODES):
591
  active_class = "active" if idx == DEFAULT_MODE else ""
592
+ # Note: onclick calls the JS function defined in Head
593
+ btns_html += f"""
594
+ <img src="{mode['icon_base64']}"
595
+ class="mode-btn {active_class}"
596
+ onclick="selectMode({idx})"
597
+ title="{mode['name']}">
598
+ """
599
+
600
+ # Assemble the full component
601
  full_html = f"""
602
  <div class="previewer-container">
603
  <div class="tips-wrapper">
604
+ <div class="tips-icon">💡Tips</div>
605
  <div class="tips-text">
606
+ <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
607
+ <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
608
  </div>
609
  </div>
610
+
611
+ <!-- Row 1: Viewport containing 48 static <img> tags -->
612
+ <div class="display-row">
613
+ {images_html}
614
+ </div>
615
+
616
+ <!-- Row 2 -->
617
+ <div class="mode-row" id="btn-group">
618
+ {btns_html}
619
+ </div>
620
+
621
+ <!-- Row 3: Slider -->
622
  <div class="slider-row">
623
  <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
624
  </div>
625
  </div>
626
  """
627
+
628
  return state, full_html
629
 
630
 
 
635
  texture_size: int,
636
  req: gr.Request,
637
  progress=gr.Progress(track_tqdm=True),
638
+ ) -> str:
639
+ """
640
+ Extract a GLB file from the 3D model.
641
 
642
+ Args:
643
+ state (dict): The state of the generated 3D model.
644
+ decimation_target (int): The target face count for decimation.
645
+ texture_size (int): The texture resolution.
646
+
647
+ Returns:
648
+ str: The path to the extracted GLB file.
649
+ """
650
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
651
  shape_slat, tex_slat, res = unpack_state(state)
652
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
653
+ mesh.simplify(16777216) # nvdiffrast limit
654
  glb = o_voxel.postprocess.to_glb(
655
  vertices=mesh.vertices,
656
  faces=mesh.faces,
 
672
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
673
  glb.export(glb_path, extension_webp=True)
674
  torch.cuda.empty_cache()
675
+ return glb_path
676
+
677
+
678
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
679
+ gr.HTML("""
680
+ <div style="display: flex; align-items: center; gap: 20px; margin-bottom: 10px;">
681
+ <a href="https://www.opsiclear.com" target="_blank">
682
+ <img src="https://www.opsiclear.com/assets/logos/Logo_v2_compact_name.svg" alt="OpsiClear" style="height: 80px;">
683
+ </a>
684
+ <div>
685
+ <h2 style="margin: 0;">Multi-View to 3D with <a href="https://microsoft.github.io/TRELLIS.2" target="_blank">TRELLIS.2</a></h2>
686
+ <ul style="margin: 5px 0; padding-left: 20px;">
687
+ <li>Upload multiple images from different viewpoints to create a 3D asset with multi-image conditioning.</li>
688
+ <li>Click an example below to load a pre-made multi-view set, or upload your own images.</li>
689
+ <li>Click <b>Generate</b> to create the 3D model, then <b>Extract GLB</b> to export.</li>
690
+ <li style="color: #e67300;"><b>⚠️ Note:</b> Generation quality is highly sensitive to parameters. Adjust settings in Advanced Settings if results are unsatisfactory.</li>
691
+ </ul>
692
+ </div>
693
+ </div>
694
  """)
695
 
696
  with gr.Row():
697
  with gr.Column(scale=1, min_width=360):
698
+ multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3)
 
 
 
 
 
 
 
 
699
 
700
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
701
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
703
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
704
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
705
 
 
 
706
  with gr.Accordion(label="Advanced Settings", open=False):
707
  gr.Markdown("Stage 1: Sparse Structure Generation")
708
  with gr.Row():
 
726
  tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
727
 
728
  with gr.Column(scale=10):
729
+ preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
730
+ glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), visible=False)
731
+
732
+ with gr.Row():
733
+ generate_btn = gr.Button("Generate", variant="primary")
734
+ extract_btn = gr.Button("Extract GLB")
735
+
736
+ example_image = gr.Image(visible=False) # Hidden component for examples
737
+ examples_multi = gr.Examples(
738
+ examples=prepare_multi_example(),
739
+ inputs=[example_image],
740
+ fn=load_multi_example,
741
+ outputs=[multiimage_prompt],
742
+ run_on_click=True,
743
+ cache_examples=False,
744
+ examples_per_page=50,
745
+ )
 
 
 
746
 
747
  output_buf = gr.State()
748
 
749
+
750
  # Handlers
751
  demo.load(start_session)
752
  demo.unload(end_session)
753
+ multiimage_prompt.upload(
 
754
  preprocess_images,
755
+ inputs=[multiimage_prompt],
756
+ outputs=[multiimage_prompt],
757
  )
758
 
759
  generate_btn.click(
760
+ get_seed,
761
+ inputs=[randomize_seed, seed],
762
+ outputs=[seed],
763
  ).then(
764
  image_to_3d,
765
  inputs=[
766
+ seed, resolution,
767
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
768
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
769
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
770
+ multiimage_prompt, multiimage_algo, tex_multiimage_algo
771
  ],
772
  outputs=[output_buf, preview_output],
773
  )
774
 
775
  extract_btn.click(
 
 
776
  extract_glb,
777
  inputs=[output_buf, decimation_target, texture_size],
778
+ outputs=[glb_output],
779
  )
780
 
781
 
782
+ # Launch the Gradio app
if __name__ == "__main__":
    # Startup sequence: temp dir, button icons, background-removal client,
    # generation pipeline on the GPU, environment maps, then the UI.
    os.makedirs(TMP_DIR, exist_ok=True)

    # Construct ui components
    btn_img_base64_strs = {}
    for mode_cfg in MODES:
        # Embed each render-mode icon as a base64 string for the preview HTML.
        mode_cfg['icon_base64'] = image_to_base64(Image.open(mode_cfg['icon']))

    rmbg_client = Client("briaai/BRIA-RMBG-2.0")

    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None
    pipeline.low_vram = False
    pipeline.cuda()

    # One environment map per HDRI file; OpenCV loads BGR, so convert to RGB
    # before building the float32 CUDA tensor.
    envmap = {
        hdri_name: EnvMap(torch.tensor(
            cv2.cvtColor(
                cv2.imread(f'assets/hdri/{hdri_name}.exr', cv2.IMREAD_UNCHANGED),
                cv2.COLOR_BGR2RGB,
            ),
            dtype=torch.float32, device='cuda'
        ))
        for hdri_name in ('forest', 'sunset', 'courtyard')
    }

    demo.launch(css=css, head=head)