opsiclear-admin commited on
Commit
42fc1a9
·
verified ·
1 Parent(s): 01dc5d4

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +369 -443
app.py CHANGED
@@ -4,40 +4,98 @@ import spaces
4
  from concurrent.futures import ThreadPoolExecutor
5
 
6
  import os
7
- import sys
8
-
9
- _script_dir = os.path.dirname(os.path.abspath(__file__))
10
-
11
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
12
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
13
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
14
- os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(_script_dir, 'autotune_cache.json')
15
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
16
  from datetime import datetime
17
  import shutil
18
  import cv2
19
  from typing import *
20
- import torch
21
  import numpy as np
22
  from PIL import Image
23
  import base64
24
  import io
25
  import tempfile
26
- from trellis2.modules.sparse import SparseTensor
27
- from trellis2.pipelines import Trellis2ImageTo3DPipeline
28
- from trellis2.renderers import EnvMap
29
- from trellis2.utils import render_utils
30
- import o_voxel
31
-
32
- # Patch postprocess module with local fix for cumesh.fill_holes() bug
33
- import importlib.util
34
- _local_postprocess = os.path.join(_script_dir, 'o-voxel', 'o_voxel', 'postprocess.py')
35
- if os.path.exists(_local_postprocess):
36
- _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_postprocess)
37
- _mod = importlib.util.module_from_spec(_spec)
38
- _spec.loader.exec_module(_mod)
39
- o_voxel.postprocess = _mod
40
- sys.modules['o_voxel.postprocess'] = _mod
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
 
43
  MAX_SEED = np.iinfo(np.int32).max
@@ -56,56 +114,88 @@ DEFAULT_STEP = 3
56
 
57
 
58
  css = """
59
- /* Overwrite Gradio Default Style */
60
- .stepper-wrapper {
61
- padding: 0;
62
- }
63
-
64
- .stepper-container {
65
- padding: 0;
66
- align-items: center;
67
- }
68
-
69
- .step-button {
70
- flex-direction: row;
71
- }
72
-
73
- .step-connector {
74
- transform: none;
75
- }
76
-
77
- .step-number {
78
- width: 16px;
79
- height: 16px;
80
- }
81
-
82
- .step-label {
83
- position: relative;
84
- bottom: 0;
85
- }
86
-
87
- .wrap.center.full {
88
- inset: 0;
89
- height: 100%;
90
  }
91
 
92
- .wrap.center.full.translucent {
93
- background: var(--block-background-fill);
94
- }
95
-
96
- .meta-text-center {
97
- display: block !important;
98
- position: absolute !important;
99
- top: unset !important;
100
- bottom: 0 !important;
101
- right: 0 !important;
102
- transform: unset !important;
103
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  /* Previewer */
106
  .previewer-container {
107
  position: relative;
108
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
109
  width: 100%;
110
  height: 722px;
111
  margin: 0 auto;
@@ -114,148 +204,53 @@ css = """
114
  flex-direction: column;
115
  align-items: center;
116
  justify-content: center;
 
 
117
  }
118
-
119
  .previewer-container .tips-icon {
120
- position: absolute;
121
- right: 10px;
122
- top: 10px;
123
- z-index: 10;
124
- border-radius: 10px;
125
- color: #fff;
126
- background-color: var(--color-accent);
127
- padding: 3px 6px;
128
- user-select: none;
129
  }
130
-
131
  .previewer-container .tips-text {
132
- position: absolute;
133
- right: 10px;
134
- top: 50px;
135
- color: #fff;
136
- background-color: var(--color-accent);
137
- border-radius: 10px;
138
- padding: 6px;
139
- text-align: left;
140
- max-width: 300px;
141
- z-index: 10;
142
- transition: all 0.3s;
143
- opacity: 0%;
144
- user-select: none;
145
- }
146
-
147
- .previewer-container .tips-text p {
148
- font-size: 14px;
149
- line-height: 1.2;
150
- }
151
-
152
- .tips-icon:hover + .tips-text {
153
- display: block;
154
- opacity: 100%;
155
- }
156
-
157
- /* Row 1: Display Modes */
158
- .previewer-container .mode-row {
159
- width: 100%;
160
- display: flex;
161
- gap: 8px;
162
- justify-content: center;
163
- margin-bottom: 20px;
164
- flex-wrap: wrap;
165
  }
 
 
 
166
  .previewer-container .mode-btn {
167
- width: 24px;
168
- height: 24px;
169
- border-radius: 50%;
170
- cursor: pointer;
171
- opacity: 0.5;
172
- transition: all 0.2s;
173
- border: 2px solid var(--neutral-600, #555);
174
- object-fit: cover;
175
- }
176
- .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
177
- .previewer-container .mode-btn.active {
178
- opacity: 1;
179
- border-color: var(--color-accent);
180
- transform: scale(1.1);
181
- }
182
-
183
- /* Row 2: Display Image */
184
- .previewer-container .display-row {
185
- margin-bottom: 20px;
186
- min-height: 400px;
187
- width: 100%;
188
- flex-grow: 1;
189
- display: flex;
190
- justify-content: center;
191
- align-items: center;
192
- }
193
- .previewer-container .previewer-main-image {
194
- max-width: 100%;
195
- max-height: 100%;
196
- flex-grow: 1;
197
- object-fit: contain;
198
- display: none;
199
- }
200
- .previewer-container .previewer-main-image.visible {
201
- display: block;
202
- }
203
-
204
- /* Row 3: Custom HTML Slider */
205
- .previewer-container .slider-row {
206
- width: 100%;
207
- display: flex;
208
- flex-direction: column;
209
- align-items: center;
210
- gap: 10px;
211
- padding: 0 10px;
212
- }
213
-
214
- .previewer-container input[type=range] {
215
- -webkit-appearance: none;
216
- width: 100%;
217
- max-width: 400px;
218
- background: transparent;
219
- }
220
- .previewer-container input[type=range]::-webkit-slider-runnable-track {
221
- width: 100%;
222
- height: 8px;
223
- cursor: pointer;
224
- background: var(--neutral-700, #404040);
225
- border-radius: 5px;
226
  }
 
 
 
 
 
 
 
 
227
  .previewer-container input[type=range]::-webkit-slider-thumb {
228
- height: 20px;
229
- width: 20px;
230
- border-radius: 50%;
231
- background: var(--color-accent);
232
- cursor: pointer;
233
- -webkit-appearance: none;
234
- margin-top: -6px;
235
- box-shadow: 0 2px 5px rgba(0,0,0,0.2);
236
- transition: transform 0.1s;
237
- }
238
- .previewer-container input[type=range]::-webkit-slider-thumb:hover {
239
- transform: scale(1.2);
240
- }
241
-
242
- /* Overwrite Previewer Block Style */
243
- .gradio-container .padded:has(.previewer-container) {
244
- padding: 0 !important;
245
- }
246
-
247
- .gradio-container:has(.previewer-container) [data-testid="block-label"] {
248
- position: absolute;
249
- top: 0;
250
- left: 0;
251
  }
 
 
 
252
  """
253
 
254
 
255
  head = """
256
  <script>
257
  function refreshView(mode, step) {
258
- // 1. Find current mode and step
259
  const allImgs = document.querySelectorAll('.previewer-main-image');
260
  for (let i = 0; i < allImgs.length; i++) {
261
  const img = allImgs[i];
@@ -267,46 +262,26 @@ head = """
267
  break;
268
  }
269
  }
270
-
271
- // 2. Hide ALL images
272
- // We select all elements with class 'previewer-main-image'
273
  allImgs.forEach(img => img.classList.remove('visible'));
274
-
275
- // 3. Construct the specific ID for the current state
276
- // Format: view-m{mode}-s{step}
277
  const targetId = 'view-m' + mode + '-s' + step;
278
  const targetImg = document.getElementById(targetId);
279
-
280
- // 4. Show ONLY the target
281
- if (targetImg) {
282
- targetImg.classList.add('visible');
283
- }
284
-
285
- // 5. Update Button Highlights
286
  const allBtns = document.querySelectorAll('.mode-btn');
287
  allBtns.forEach((btn, idx) => {
288
  if (idx === mode) btn.classList.add('active');
289
  else btn.classList.remove('active');
290
  });
291
  }
292
-
293
- // --- Action: Switch Mode ---
294
- function selectMode(mode) {
295
- refreshView(mode, -1);
296
- }
297
-
298
- // --- Action: Slider Change ---
299
- function onSliderChange(val) {
300
- refreshView(-1, parseInt(val));
301
- }
302
  </script>
303
  """
304
 
305
 
306
- empty_html = f"""
307
  <div class="previewer-container">
308
- <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
309
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
310
  </div>
311
  """
312
 
@@ -326,7 +301,8 @@ def start_session(req: gr.Request):
326
 
327
  def end_session(req: gr.Request):
328
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
329
- shutil.rmtree(user_dir)
 
330
 
331
 
332
  def remove_background(input: Image.Image) -> Image.Image:
@@ -339,10 +315,7 @@ def remove_background(input: Image.Image) -> Image.Image:
339
 
340
 
341
  def preprocess_image(input: Image.Image) -> Image.Image:
342
- """
343
- Preprocess the input image.
344
- """
345
- # if has alpha channel, use it directly; otherwise, remove background
346
  has_alpha = False
347
  if input.mode == 'RGBA':
348
  alpha = np.array(input)[:, :, 3]
@@ -359,19 +332,12 @@ def preprocess_image(input: Image.Image) -> Image.Image:
359
  output_np = np.array(output)
360
  alpha = output_np[:, :, 3]
361
  bbox = np.argwhere(alpha > 0.8 * 255)
362
- if bbox.size == 0:
363
- # No visible pixels, center the image in a square
364
- size = max(output.size)
365
- square = Image.new('RGB', (size, size), (0, 0, 0))
366
- output_rgb = output.convert('RGB') if output.mode == 'RGBA' else output
367
- square.paste(output_rgb, ((size - output.width) // 2, (size - output.height) // 2))
368
- return square
369
  bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
370
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
371
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
372
  size = int(size * 1)
373
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
374
- output = output.crop(bbox) # type: ignore
375
  output = np.array(output).astype(np.float32) / 255
376
  output = output[:, :, :3] * output[:, :, 3:4]
377
  output = Image.fromarray((output * 255).astype(np.uint8))
@@ -379,17 +345,16 @@ def preprocess_image(input: Image.Image) -> Image.Image:
379
 
380
 
381
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
382
- """
383
- Preprocess a list of input images for multi-image conditioning.
384
- Uses parallel processing for faster background removal.
385
- """
386
- images = [image[0] for image in images]
387
- with ThreadPoolExecutor(max_workers=min(4, len(images))) as executor:
388
- processed_images = list(executor.map(preprocess_image, images))
389
  return processed_images
390
 
391
 
392
- def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
393
  shape_slat, tex_slat, res = latents
394
  return {
395
  'shape_slat_feats': shape_slat.feats.cpu().numpy(),
@@ -399,7 +364,8 @@ def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
399
  }
400
 
401
 
402
- def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
 
403
  shape_slat = SparseTensor(
404
  feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
405
  coords=torch.from_numpy(state['coords']).cuda(),
@@ -409,75 +375,40 @@ def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
409
 
410
 
411
  def get_seed(randomize_seed: bool, seed: int) -> int:
412
- """
413
- Get the random seed.
414
- """
415
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
416
 
417
 
418
- def prepare_multi_example() -> List[str]:
419
- """
420
- Prepare multi-image examples. Returns list of image paths.
421
- Shows only the first view as representative thumbnail.
422
- """
423
- multi_case = sorted(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
 
424
  examples = []
425
- for case in multi_case:
426
- first_img = f'assets/example_multi_image/{case}_1.png'
427
- if os.path.exists(first_img):
428
- examples.append(first_img)
 
 
 
 
429
  return examples
430
 
431
 
432
- def load_multi_example(image) -> List[Image.Image]:
433
- """Load all views for a multi-image case by matching the input image."""
434
- import hashlib
435
-
436
- # Convert numpy array to PIL Image if needed
437
- if isinstance(image, np.ndarray):
438
- image = Image.fromarray(image)
439
-
440
- # Get hash of input image for matching
441
- input_hash = hashlib.md5(np.array(image.convert('RGBA')).tobytes()).hexdigest()
442
-
443
- # Find matching case by comparing with first images
444
- multi_case = sorted(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
445
- for case_name in multi_case:
446
- first_img_path = f'assets/example_multi_image/{case_name}_1.png'
447
- if os.path.exists(first_img_path):
448
- first_img = Image.open(first_img_path).convert('RGBA')
449
- first_hash = hashlib.md5(np.array(first_img).tobytes()).hexdigest()
450
- if first_hash == input_hash:
451
- # Found match, load all views
452
- images = []
453
- for i in range(1, 7):
454
- img_path = f'assets/example_multi_image/{case_name}_{i}.png'
455
- if os.path.exists(img_path):
456
- img = Image.open(img_path)
457
- images.append(preprocess_image(img))
458
- return images
459
-
460
- # No match found, return the single image preprocessed
461
- return [preprocess_image(image)]
462
-
463
-
464
- def split_image(image: Image.Image) -> List[Image.Image]:
465
- """
466
- Split a concatenated image into multiple views.
467
- """
468
- image = np.array(image)
469
- alpha = image[..., 3]
470
- alpha = np.any(alpha > 0, axis=0)
471
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
472
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
473
  images = []
474
- for s, e in zip(start_pos, end_pos):
475
- images.append(Image.fromarray(image[:, s:e+1]))
476
- return [preprocess_image(image) for image in images]
 
477
 
478
 
479
- @spaces.GPU(duration=90)
480
  def image_to_3d(
 
481
  seed: int,
482
  resolution: str,
483
  ss_guidance_strength: float,
@@ -492,68 +423,103 @@ def image_to_3d(
492
  tex_slat_guidance_rescale: float,
493
  tex_slat_sampling_steps: int,
494
  tex_slat_rescale_t: float,
495
- multiimages: List[Tuple[Image.Image, str]],
496
- multiimage_algo: Literal["stochastic", "multidiffusion"],
497
- tex_multiimage_algo: Literal["stochastic", "multidiffusion"],
498
  req: gr.Request,
499
  progress=gr.Progress(track_tqdm=True),
500
  ) -> str:
 
 
 
 
 
 
 
 
 
501
  # --- Sampling ---
502
- outputs, latents = pipeline.run_multi_image(
503
- [image[0] for image in multiimages],
504
- seed=seed,
505
- preprocess_image=False,
506
- sparse_structure_sampler_params={
507
- "steps": ss_sampling_steps,
508
- "guidance_strength": ss_guidance_strength,
509
- "guidance_rescale": ss_guidance_rescale,
510
- "rescale_t": ss_rescale_t,
511
- },
512
- shape_slat_sampler_params={
513
- "steps": shape_slat_sampling_steps,
514
- "guidance_strength": shape_slat_guidance_strength,
515
- "guidance_rescale": shape_slat_guidance_rescale,
516
- "rescale_t": shape_slat_rescale_t,
517
- },
518
- tex_slat_sampler_params={
519
- "steps": tex_slat_sampling_steps,
520
- "guidance_strength": tex_slat_guidance_strength,
521
- "guidance_rescale": tex_slat_guidance_rescale,
522
- "rescale_t": tex_slat_rescale_t,
523
- },
524
- pipeline_type={
525
- "512": "512",
526
- "1024": "1024_cascade",
527
- "1536": "1536_cascade",
528
- }[resolution],
529
- return_latent=True,
530
- mode=multiimage_algo,
531
- tex_mode=tex_multiimage_algo,
532
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
533
  mesh = outputs[0]
534
- mesh.simplify(16777216) # nvdiffrast limit
535
- images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
536
  state = pack_state(latents)
537
- del outputs, mesh, latents # Free memory
538
  torch.cuda.empty_cache()
539
 
540
  # --- HTML Construction ---
541
- # The Stack of 48 Images - encode in parallel for speed
542
  def encode_preview_image(args):
543
  m_idx, s_idx, render_key = args
544
- img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx]))
545
  return (m_idx, s_idx, img_base64)
546
 
547
- encode_tasks = [
548
- (m_idx, s_idx, mode['render_key'])
549
- for m_idx, mode in enumerate(MODES)
550
- for s_idx in range(STEPS)
551
- ]
552
 
553
  with ThreadPoolExecutor(max_workers=8) as executor:
554
  encoded_results = list(executor.map(encode_preview_image, encode_tasks))
555
 
556
- # Build HTML from encoded results
557
  encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
558
  images_html = ""
559
  for m_idx, mode in enumerate(MODES):
@@ -562,80 +528,46 @@ def image_to_3d(
562
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
563
  vis_class = "visible" if is_visible else ""
564
  img_base64 = encoded_map[(m_idx, s_idx)]
 
565
 
566
- images_html += f"""
567
- <img id="{unique_id}"
568
- class="previewer-main-image {vis_class}"
569
- src="{img_base64}"
570
- loading="eager">
571
- """
572
-
573
- # Button Row HTML
574
  btns_html = ""
575
  for idx, mode in enumerate(MODES):
576
  active_class = "active" if idx == DEFAULT_MODE else ""
577
- # Note: onclick calls the JS function defined in Head
578
- btns_html += f"""
579
- <img src="{mode['icon_base64']}"
580
- class="mode-btn {active_class}"
581
- onclick="selectMode({idx})"
582
- title="{mode['name']}">
583
- """
584
-
585
- # Assemble the full component
586
  full_html = f"""
587
  <div class="previewer-container">
588
  <div class="tips-wrapper">
589
- <div class="tips-icon">💡Tips</div>
590
  <div class="tips-text">
591
- <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
592
- <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
593
  </div>
594
  </div>
595
-
596
- <!-- Row 1: Viewport containing 48 static <img> tags -->
597
- <div class="display-row">
598
- {images_html}
599
- </div>
600
-
601
- <!-- Row 2 -->
602
- <div class="mode-row" id="btn-group">
603
- {btns_html}
604
- </div>
605
-
606
- <!-- Row 3: Slider -->
607
  <div class="slider-row">
608
  <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
609
  </div>
610
  </div>
611
  """
612
-
613
  return state, full_html
614
 
615
 
616
- @spaces.GPU(duration=60)
617
  def extract_glb(
618
  state: dict,
619
  decimation_target: int,
620
  texture_size: int,
621
  req: gr.Request,
622
  progress=gr.Progress(track_tqdm=True),
623
- ) -> str:
624
- """
625
- Extract a GLB file from the 3D model.
626
 
627
- Args:
628
- state (dict): The state of the generated 3D model.
629
- decimation_target (int): The target face count for decimation.
630
- texture_size (int): The texture resolution.
631
-
632
- Returns:
633
- str: The path to the extracted GLB file.
634
- """
635
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
636
  shape_slat, tex_slat, res = unpack_state(state)
637
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
638
- mesh.simplify(16777216) # nvdiffrast limit
639
  glb = o_voxel.postprocess.to_glb(
640
  vertices=mesh.vertices,
641
  faces=mesh.faces,
@@ -657,19 +589,27 @@ def extract_glb(
657
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
658
  glb.export(glb_path, extension_webp=True)
659
  torch.cuda.empty_cache()
660
- return glb_path
661
 
662
 
663
- with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
664
  gr.Markdown("""
665
- ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
666
- * Upload an image and click Generate to create a 3D asset. If the image has alpha channel, it will be used as the mask. Otherwise, background is automatically removed.
667
- * Click Extract GLB to export the GLB file if you're satisfied with the preview.
668
  """)
669
 
670
  with gr.Row():
671
  with gr.Column(scale=1, min_width=360):
672
- multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3)
 
 
 
 
 
 
 
 
673
 
674
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
675
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -677,9 +617,7 @@ with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange
677
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
678
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
679
 
680
- with gr.Row():
681
- generate_btn = gr.Button("Generate", variant="primary")
682
- extract_btn = gr.Button("Extract GLB")
683
 
684
  with gr.Accordion(label="Advanced Settings", open=False):
685
  gr.Markdown("Stage 1: Sparse Structure Generation")
@@ -704,83 +642,71 @@ with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange
704
  tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
705
 
706
  with gr.Column(scale=10):
707
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
708
- glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), visible=False)
709
-
710
- example_image = gr.Image(visible=False) # Hidden component for examples
711
- examples_multi = gr.Examples(
712
- examples=prepare_multi_example(),
713
- inputs=[example_image],
714
- fn=load_multi_example,
715
- outputs=[multiimage_prompt],
716
- run_on_click=True,
717
- examples_per_page=24,
718
- )
 
 
 
 
 
 
 
 
719
 
720
  output_buf = gr.State()
721
 
722
-
723
  # Handlers
724
  demo.load(start_session)
725
  demo.unload(end_session)
726
- multiimage_prompt.upload(
 
727
  preprocess_images,
728
- inputs=[multiimage_prompt],
729
- outputs=[multiimage_prompt],
730
  )
731
 
732
  generate_btn.click(
733
- get_seed,
734
- inputs=[randomize_seed, seed],
735
- outputs=[seed],
736
  ).then(
737
  image_to_3d,
738
  inputs=[
739
- seed, resolution,
740
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
741
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
742
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
743
- multiimage_prompt, multiimage_algo, tex_multiimage_algo
744
  ],
745
  outputs=[output_buf, preview_output],
746
  )
747
 
748
  extract_btn.click(
 
 
749
  extract_glb,
750
  inputs=[output_buf, decimation_target, texture_size],
751
- outputs=[glb_output],
752
  )
753
 
754
 
755
- # Launch the Gradio app
756
  if __name__ == "__main__":
757
  os.makedirs(TMP_DIR, exist_ok=True)
758
 
759
- # Construct ui components
760
- btn_img_base64_strs = {}
761
  for i in range(len(MODES)):
762
  icon = Image.open(MODES[i]['icon'])
763
  MODES[i]['icon_base64'] = image_to_base64(icon)
764
 
765
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
766
- pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
767
- pipeline.rembg_model = None
768
- pipeline.low_vram = True # Enable low VRAM mode for better memory efficiency
769
- pipeline.cuda()
770
-
771
- envmap = {
772
- 'forest': EnvMap(torch.tensor(
773
- cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
774
- dtype=torch.float32, device='cuda'
775
- )),
776
- 'sunset': EnvMap(torch.tensor(
777
- cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
778
- dtype=torch.float32, device='cuda'
779
- )),
780
- 'courtyard': EnvMap(torch.tensor(
781
- cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
782
- dtype=torch.float32, device='cuda'
783
- )),
784
- }
785
 
786
- demo.queue(max_size=10, default_concurrency_limit=1).launch(css=css, head=head)
 
4
  from concurrent.futures import ThreadPoolExecutor
5
 
6
  import os
 
 
 
 
7
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
8
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
10
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
11
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
12
  from datetime import datetime
13
  import shutil
14
  import cv2
15
  from typing import *
 
16
  import numpy as np
17
  from PIL import Image
18
  import base64
19
  import io
20
  import tempfile
21
+
22
+ # Lazy imports - will be loaded when GPU is available
23
+ torch = None
24
+ SparseTensor = None
25
+ Trellis2ImageTo3DPipeline = None
26
+ EnvMap = None
27
+ render_utils = None
28
+ o_voxel = None
29
+
30
+ # Global state - initialized on first GPU call
31
+ pipeline = None
32
+ envmap = None
33
+ _initialized = False
34
+
35
+
36
+ def _lazy_import():
37
+ """Import GPU-dependent modules. Must be called from within a @spaces.GPU function."""
38
+ global torch, SparseTensor, Trellis2ImageTo3DPipeline, EnvMap, render_utils, o_voxel
39
+ if torch is None:
40
+ import torch as _torch
41
+ torch = _torch
42
+ if SparseTensor is None:
43
+ from trellis2.modules.sparse import SparseTensor as _SparseTensor
44
+ SparseTensor = _SparseTensor
45
+ if Trellis2ImageTo3DPipeline is None:
46
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline as _Trellis2ImageTo3DPipeline
47
+ Trellis2ImageTo3DPipeline = _Trellis2ImageTo3DPipeline
48
+ if EnvMap is None:
49
+ from trellis2.renderers import EnvMap as _EnvMap
50
+ EnvMap = _EnvMap
51
+ if render_utils is None:
52
+ from trellis2.utils import render_utils as _render_utils
53
+ render_utils = _render_utils
54
+ if o_voxel is None:
55
+ import o_voxel as _o_voxel
56
+ o_voxel = _o_voxel
57
+ # Patch postprocess module with local fix for cumesh.fill_holes() bug
58
+ import importlib.util
59
+ _local_pp = os.path.join(os.path.dirname(os.path.abspath(__file__)),
60
+ 'o-voxel', 'o_voxel', 'postprocess.py')
61
+ if os.path.exists(_local_pp):
62
+ _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_pp)
63
+ _mod = importlib.util.module_from_spec(_spec)
64
+ _spec.loader.exec_module(_mod)
65
+ o_voxel.postprocess = _mod
66
+ import sys
67
+ sys.modules['o_voxel.postprocess'] = _mod
68
+
69
+
70
+ def _initialize_pipeline():
71
+ """Initialize the pipeline and environment maps. Must be called from within a @spaces.GPU function."""
72
+ global pipeline, envmap, _initialized
73
+ if _initialized:
74
+ return
75
+
76
+ _lazy_import()
77
+
78
+ pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
79
+ pipeline.rembg_model = None
80
+ pipeline.low_vram = False
81
+ pipeline.cuda()
82
+
83
+ envmap = {
84
+ 'forest': EnvMap(torch.tensor(
85
+ cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
86
+ dtype=torch.float32, device='cuda'
87
+ )),
88
+ 'sunset': EnvMap(torch.tensor(
89
+ cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
90
+ dtype=torch.float32, device='cuda'
91
+ )),
92
+ 'courtyard': EnvMap(torch.tensor(
93
+ cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
94
+ dtype=torch.float32, device='cuda'
95
+ )),
96
+ }
97
+
98
+ _initialized = True
99
 
100
 
101
  MAX_SEED = np.iinfo(np.int32).max
 
114
 
115
 
116
  css = """
117
+ /* ColmapView Dark Theme */
118
+ :root {
119
+ --bg-void: #0a0a0a;
120
+ --bg-primary: #0f0f0f;
121
+ --bg-secondary: #161616;
122
+ --bg-tertiary: #1e1e1e;
123
+ --bg-input: #1a1a1a;
124
+ --bg-elevated: #242424;
125
+ --bg-hover: #262626;
126
+ --text-primary: #e8e8e8;
127
+ --text-secondary: #8a8a8a;
128
+ --text-muted: #5a5a5a;
129
+ --border-subtle: #222222;
130
+ --border-color: #2a2a2a;
131
+ --border-light: #3a3a3a;
132
+ --accent: #b8b8b8;
133
+ --accent-hover: #d0d0d0;
134
+ --accent-dim: rgba(184, 184, 184, 0.12);
135
+ --shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.4), 0 1px 2px rgba(0, 0, 0, 0.3);
136
+ --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.5), 0 2px 4px rgba(0, 0, 0, 0.3);
137
+ --radius: 0.25rem;
138
+ --radius-md: 0.375rem;
139
+ --radius-lg: 0.5rem;
140
+ --font-sans: 'Roboto', -apple-system, BlinkMacSystemFont, sans-serif;
141
+ --font-mono: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace;
 
 
 
 
 
 
142
  }
143
 
144
+ /* Global Overrides */
145
+ .gradio-container { background: var(--bg-primary) !important; color: var(--text-primary) !important; font-family: var(--font-sans) !important; }
146
+ .dark { background: var(--bg-primary) !important; }
147
+ body { background: var(--bg-void) !important; color-scheme: dark; }
148
+
149
+ /* Panels & Blocks */
150
+ .block { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; border-radius: var(--radius-lg) !important; }
151
+ .panel { background: var(--bg-secondary) !important; }
152
+
153
+ /* Inputs */
154
+ input, textarea, select { background: var(--bg-input) !important; border: 1px solid var(--border-color) !important; color: var(--text-primary) !important; border-radius: var(--radius) !important; }
155
+ input:focus, textarea:focus, select:focus { border-color: var(--accent) !important; outline: none !important; }
156
+
157
+ /* Buttons */
158
+ .primary { background: var(--accent) !important; color: var(--bg-void) !important; border: none !important; }
159
+ .primary:hover { background: var(--accent-hover) !important; }
160
+ .secondary { background: var(--bg-hover) !important; color: var(--text-primary) !important; border: 1px solid var(--border-color) !important; }
161
+ .secondary:hover { background: var(--bg-tertiary) !important; }
162
+ button { transition: all 0.15s ease-out !important; }
163
+
164
+ /* Labels & Text */
165
+ label, .label-wrap { color: var(--text-secondary) !important; font-size: 0.875rem !important; }
166
+ .prose { color: var(--text-primary) !important; }
167
+ .prose h2 { color: var(--text-primary) !important; border: none !important; }
168
+
169
+ /* Sliders */
170
+ input[type="range"] { accent-color: var(--accent) !important; }
171
+
172
+ /* Gallery */
173
+ .gallery { background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
174
+
175
+ /* Accordion */
176
+ .accordion { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; }
177
+
178
+ /* Scrollbar */
179
+ ::-webkit-scrollbar { width: 12px; }
180
+ ::-webkit-scrollbar-track { background: var(--bg-secondary); }
181
+ ::-webkit-scrollbar-thumb { background: var(--border-light); border-radius: 9999px; }
182
+ ::-webkit-scrollbar-thumb:hover { background: var(--text-muted); }
183
+
184
+ /* Gradio Overrides */
185
+ .stepper-wrapper { padding: 0; }
186
+ .stepper-container { padding: 0; align-items: center; }
187
+ .step-button { flex-direction: row; }
188
+ .step-connector { transform: none; }
189
+ .step-number { width: 16px; height: 16px; background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
190
+ .step-label { position: relative; bottom: 0; color: var(--text-secondary) !important; }
191
+ .wrap.center.full { inset: 0; height: 100%; }
192
+ .wrap.center.full.translucent { background: var(--bg-secondary); }
193
+ .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
194
 
195
  /* Previewer */
196
  .previewer-container {
197
  position: relative;
198
+ font-family: var(--font-sans);
199
  width: 100%;
200
  height: 722px;
201
  margin: 0 auto;
 
204
  flex-direction: column;
205
  align-items: center;
206
  justify-content: center;
207
+ background: var(--bg-void);
208
+ border-radius: var(--radius-lg);
209
  }
 
210
  .previewer-container .tips-icon {
211
+ position: absolute; right: 10px; top: 10px; z-index: 10;
212
+ border-radius: var(--radius-md); color: var(--bg-void);
213
+ background-color: var(--accent); padding: 4px 8px;
214
+ user-select: none; font-size: 0.75rem; font-weight: 500;
 
 
 
 
 
215
  }
 
216
  .previewer-container .tips-text {
217
+ position: absolute; right: 10px; top: 45px;
218
+ color: var(--text-primary); background-color: var(--bg-elevated);
219
+ border: 1px solid var(--border-light); border-radius: var(--radius-md);
220
+ padding: 8px 12px; text-align: left; max-width: 280px; z-index: 10;
221
+ transition: opacity 0.15s ease-out; opacity: 0; user-select: none;
222
+ box-shadow: var(--shadow-md);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  }
224
+ .previewer-container .tips-text p { font-size: 0.75rem; line-height: 1.4; color: var(--text-secondary); margin: 4px 0; }
225
+ .tips-icon:hover + .tips-text { opacity: 1; }
226
+ .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 16px; flex-wrap: wrap; }
227
  .previewer-container .mode-btn {
228
+ width: 28px; height: 28px; border-radius: 50%; cursor: pointer;
229
+ opacity: 0.5; transition: all 0.15s ease-out;
230
+ border: 2px solid var(--border-light); object-fit: cover;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  }
232
+ .previewer-container .mode-btn:hover { opacity: 0.85; transform: scale(1.1); border-color: var(--text-muted); }
233
+ .previewer-container .mode-btn.active { opacity: 1; border-color: var(--accent); transform: scale(1.1); }
234
+ .previewer-container .display-row { margin-bottom: 16px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
235
+ .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
236
+ .previewer-container .previewer-main-image.visible { display: block; }
237
+ .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 8px; padding: 0 16px; }
238
+ .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; height: 20px; }
239
+ .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 4px; cursor: pointer; background: var(--bg-tertiary); border-radius: 9999px; }
240
  .previewer-container input[type=range]::-webkit-slider-thumb {
241
+ height: 12px; width: 12px; border-radius: 50%; background: var(--accent);
242
+ cursor: pointer; -webkit-appearance: none; margin-top: -4px;
243
+ box-shadow: var(--shadow-sm); transition: transform 0.1s ease-out;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
  }
245
+ .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); background: var(--accent-hover); }
246
+ .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
247
+ .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
248
  """
249
 
250
 
251
  head = """
252
  <script>
253
  function refreshView(mode, step) {
 
254
  const allImgs = document.querySelectorAll('.previewer-main-image');
255
  for (let i = 0; i < allImgs.length; i++) {
256
  const img = allImgs[i];
 
262
  break;
263
  }
264
  }
 
 
 
265
  allImgs.forEach(img => img.classList.remove('visible'));
 
 
 
266
  const targetId = 'view-m' + mode + '-s' + step;
267
  const targetImg = document.getElementById(targetId);
268
+ if (targetImg) { targetImg.classList.add('visible'); }
 
 
 
 
 
 
269
  const allBtns = document.querySelectorAll('.mode-btn');
270
  allBtns.forEach((btn, idx) => {
271
  if (idx === mode) btn.classList.add('active');
272
  else btn.classList.remove('active');
273
  });
274
  }
275
+ function selectMode(mode) { refreshView(mode, -1); }
276
+ function onSliderChange(val) { refreshView(-1, parseInt(val)); }
 
 
 
 
 
 
 
 
277
  </script>
278
  """
279
 
280
 
281
+ empty_html = """
282
  <div class="previewer-container">
283
+ <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
284
+ xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
285
  </div>
286
  """
287
 
 
301
 
302
  def end_session(req: gr.Request):
303
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
304
+ if os.path.exists(user_dir):
305
+ shutil.rmtree(user_dir)
306
 
307
 
308
  def remove_background(input: Image.Image) -> Image.Image:
 
315
 
316
 
317
  def preprocess_image(input: Image.Image) -> Image.Image:
318
+ """Preprocess a single input image."""
 
 
 
319
  has_alpha = False
320
  if input.mode == 'RGBA':
321
  alpha = np.array(input)[:, :, 3]
 
332
  output_np = np.array(output)
333
  alpha = output_np[:, :, 3]
334
  bbox = np.argwhere(alpha > 0.8 * 255)
 
 
 
 
 
 
 
335
  bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
336
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
337
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
338
  size = int(size * 1)
339
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
340
+ output = output.crop(bbox)
341
  output = np.array(output).astype(np.float32) / 255
342
  output = output[:, :, :3] * output[:, :, 3:4]
343
  output = Image.fromarray((output * 255).astype(np.uint8))
 
345
 
346
 
347
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Run preprocess_image over every uploaded image, in parallel threads.

    Gallery entries may arrive either as bare images or as
    ``(image, caption)`` tuples; both forms are accepted.
    """
    if not images:
        return []
    # Unwrap (image, caption) tuples coming from the gr.Gallery component.
    raw_images = [entry[0] if isinstance(entry, tuple) else entry
                  for entry in images]
    with ThreadPoolExecutor(max_workers=min(4, len(raw_images))) as pool:
        return list(pool.map(preprocess_image, raw_images))
355
 
356
 
357
+ def pack_state(latents):
358
  shape_slat, tex_slat, res = latents
359
  return {
360
  'shape_slat_feats': shape_slat.feats.cpu().numpy(),
 
364
  }
365
 
366
 
367
+ def unpack_state(state: dict):
368
+ _lazy_import()
369
  shape_slat = SparseTensor(
370
  feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
371
  coords=torch.from_numpy(state['coords']).cuda(),
 
375
 
376
 
377
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use: a fresh random one when requested, else the user's."""
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
379
 
380
 
381
def prepare_examples() -> List[List[str]]:
    """Group the bundled multi-view example images by case name.

    Files under ``assets/example_multi_image`` follow the naming scheme
    ``<case>_<i>.png``. Each returned entry is the ordered list of view
    paths that actually exist for one case; an empty list is returned when
    the directory is absent.
    """
    example_dir = "assets/example_multi_image"
    if not os.path.exists(example_dir):
        return []
    # Distinct case names, drawn from files shaped like "<case>_<i>.png".
    case_names = {f.split('_')[0]
                  for f in os.listdir(example_dir)
                  if '_' in f and f.endswith('.png')}
    examples = []
    for case in sorted(case_names):
        # Support up to 9 views per example case.
        views = [f'{example_dir}/{case}_{i}.png'
                 for i in range(1, 10)
                 if os.path.exists(f'{example_dir}/{case}_{i}.png')]
        if views:
            examples.append(views)
    return examples
398
 
399
 
400
def load_example(example_paths: List[str]) -> List[Image.Image]:
    """Open each example file path and return the resulting PIL images."""
    return [Image.open(path) for path in example_paths]
407
 
408
 
409
+ @spaces.GPU(duration=120)
410
  def image_to_3d(
411
+ images: List[Tuple[Image.Image, str]],
412
  seed: int,
413
  resolution: str,
414
  ss_guidance_strength: float,
 
423
  tex_slat_guidance_rescale: float,
424
  tex_slat_sampling_steps: int,
425
  tex_slat_rescale_t: float,
426
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
427
+ tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
 
428
  req: gr.Request,
429
  progress=gr.Progress(track_tqdm=True),
430
  ) -> str:
431
+ # Initialize pipeline on first call
432
+ _initialize_pipeline()
433
+
434
+ # Extract images from gallery format
435
+ if not images:
436
+ raise gr.Error("Please upload at least one image")
437
+
438
+ imgs = [img[0] if isinstance(img, tuple) else img for img in images]
439
+
440
  # --- Sampling ---
441
+ if len(imgs) == 1:
442
+ # Single image mode
443
+ outputs, latents = pipeline.run(
444
+ imgs[0],
445
+ seed=seed,
446
+ preprocess_image=False,
447
+ sparse_structure_sampler_params={
448
+ "steps": ss_sampling_steps,
449
+ "guidance_strength": ss_guidance_strength,
450
+ "guidance_rescale": ss_guidance_rescale,
451
+ "rescale_t": ss_rescale_t,
452
+ },
453
+ shape_slat_sampler_params={
454
+ "steps": shape_slat_sampling_steps,
455
+ "guidance_strength": shape_slat_guidance_strength,
456
+ "guidance_rescale": shape_slat_guidance_rescale,
457
+ "rescale_t": shape_slat_rescale_t,
458
+ },
459
+ tex_slat_sampler_params={
460
+ "steps": tex_slat_sampling_steps,
461
+ "guidance_strength": tex_slat_guidance_strength,
462
+ "guidance_rescale": tex_slat_guidance_rescale,
463
+ "rescale_t": tex_slat_rescale_t,
464
+ },
465
+ pipeline_type={
466
+ "512": "512",
467
+ "1024": "1024_cascade",
468
+ "1536": "1536_cascade",
469
+ }[resolution],
470
+ return_latent=True,
471
+ )
472
+ else:
473
+ # Multi-image mode
474
+ outputs, latents = pipeline.run_multi_image(
475
+ imgs,
476
+ seed=seed,
477
+ preprocess_image=False,
478
+ sparse_structure_sampler_params={
479
+ "steps": ss_sampling_steps,
480
+ "guidance_strength": ss_guidance_strength,
481
+ "guidance_rescale": ss_guidance_rescale,
482
+ "rescale_t": ss_rescale_t,
483
+ },
484
+ shape_slat_sampler_params={
485
+ "steps": shape_slat_sampling_steps,
486
+ "guidance_strength": shape_slat_guidance_strength,
487
+ "guidance_rescale": shape_slat_guidance_rescale,
488
+ "rescale_t": shape_slat_rescale_t,
489
+ },
490
+ tex_slat_sampler_params={
491
+ "steps": tex_slat_sampling_steps,
492
+ "guidance_strength": tex_slat_guidance_strength,
493
+ "guidance_rescale": tex_slat_guidance_rescale,
494
+ "rescale_t": tex_slat_rescale_t,
495
+ },
496
+ pipeline_type={
497
+ "512": "512",
498
+ "1024": "1024_cascade",
499
+ "1536": "1536_cascade",
500
+ }[resolution],
501
+ return_latent=True,
502
+ mode=multiimage_algo,
503
+ tex_mode=tex_multiimage_algo,
504
+ )
505
+
506
  mesh = outputs[0]
507
+ mesh.simplify(16777216)
508
+ render_images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
509
  state = pack_state(latents)
 
510
  torch.cuda.empty_cache()
511
 
512
  # --- HTML Construction ---
 
513
  def encode_preview_image(args):
514
  m_idx, s_idx, render_key = args
515
+ img_base64 = image_to_base64(Image.fromarray(render_images[render_key][s_idx]))
516
  return (m_idx, s_idx, img_base64)
517
 
518
+ encode_tasks = [(m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS)]
 
 
 
 
519
 
520
  with ThreadPoolExecutor(max_workers=8) as executor:
521
  encoded_results = list(executor.map(encode_preview_image, encode_tasks))
522
 
 
523
  encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
524
  images_html = ""
525
  for m_idx, mode in enumerate(MODES):
 
528
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
529
  vis_class = "visible" if is_visible else ""
530
  img_base64 = encoded_map[(m_idx, s_idx)]
531
+ images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
532
 
 
 
 
 
 
 
 
 
533
  btns_html = ""
534
  for idx, mode in enumerate(MODES):
535
  active_class = "active" if idx == DEFAULT_MODE else ""
536
+ btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
537
+
 
 
 
 
 
 
 
538
  full_html = f"""
539
  <div class="previewer-container">
540
  <div class="tips-wrapper">
541
+ <div class="tips-icon">Tips</div>
542
  <div class="tips-text">
543
+ <p>Render Mode - Click buttons to switch render modes.</p>
544
+ <p>View Angle - Drag slider to change view.</p>
545
  </div>
546
  </div>
547
+ <div class="display-row">{images_html}</div>
548
+ <div class="mode-row" id="btn-group">{btns_html}</div>
 
 
 
 
 
 
 
 
 
 
549
  <div class="slider-row">
550
  <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
551
  </div>
552
  </div>
553
  """
 
554
  return state, full_html
555
 
556
 
557
+ @spaces.GPU(duration=120)
558
  def extract_glb(
559
  state: dict,
560
  decimation_target: int,
561
  texture_size: int,
562
  req: gr.Request,
563
  progress=gr.Progress(track_tqdm=True),
564
+ ) -> Tuple[str, str]:
565
+ _initialize_pipeline()
 
566
 
 
 
 
 
 
 
 
 
567
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
568
  shape_slat, tex_slat, res = unpack_state(state)
569
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
570
+ mesh.simplify(16777216)
571
  glb = o_voxel.postprocess.to_glb(
572
  vertices=mesh.vertices,
573
  faces=mesh.faces,
 
589
  glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
590
  glb.export(glb_path, extension_webp=True)
591
  torch.cuda.empty_cache()
592
+ return glb_path, glb_path
593
 
594
 
595
+ with gr.Blocks(delete_cache=(600, 600)) as demo:
596
  gr.Markdown("""
597
+ ## Multi-View Image to 3D with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
598
+ Upload one or more images of an object and click Generate to create a 3D asset.
599
+ Multiple views from different angles will produce better results.
600
  """)
601
 
602
  with gr.Row():
603
  with gr.Column(scale=1, min_width=360):
604
+ image_prompt = gr.Gallery(
605
+ label="Input Images (upload 1 or more views)",
606
+ format="png",
607
+ type="pil",
608
+ height=400,
609
+ columns=3,
610
+ object_fit="contain"
611
+ )
612
+ gr.Markdown("*Upload multiple views of the same object for better 3D reconstruction.*")
613
 
614
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
615
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
617
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
618
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
619
 
620
+ generate_btn = gr.Button("Generate", variant="primary")
 
 
621
 
622
  with gr.Accordion(label="Advanced Settings", open=False):
623
  gr.Markdown("Stage 1: Sparse Structure Generation")
 
642
  tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
643
 
644
  with gr.Column(scale=10):
645
+ with gr.Walkthrough(selected=0) as walkthrough:
646
+ with gr.Step("Preview", id=0):
647
+ preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
648
+ extract_btn = gr.Button("Extract GLB")
649
+ with gr.Step("Extract", id=1):
650
+ glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
651
+ download_btn = gr.DownloadButton(label="Download GLB")
652
+ gr.Markdown("*GLB extraction may take 30+ seconds.*")
653
+
654
+ with gr.Column(scale=1, min_width=200):
655
+ gr.Markdown("### Examples")
656
+ # Create example buttons that load images into gallery
657
+ example_data = prepare_examples()
658
+ for i, example_paths in enumerate(example_data[:12]): # Limit to 12 examples
659
+ case_name = os.path.basename(example_paths[0]).split('_')[0]
660
+ btn = gr.Button(f"{case_name} ({len(example_paths)} views)", size="sm")
661
+ btn.click(
662
+ fn=lambda paths=example_paths: load_example(paths),
663
+ outputs=[image_prompt]
664
+ )
665
 
666
  output_buf = gr.State()
667
 
 
668
  # Handlers
669
  demo.load(start_session)
670
  demo.unload(end_session)
671
+
672
+ image_prompt.upload(
673
  preprocess_images,
674
+ inputs=[image_prompt],
675
+ outputs=[image_prompt],
676
  )
677
 
678
  generate_btn.click(
679
+ get_seed, inputs=[randomize_seed, seed], outputs=[seed],
680
+ ).then(
681
+ lambda: gr.Walkthrough(selected=0), outputs=walkthrough
682
  ).then(
683
  image_to_3d,
684
  inputs=[
685
+ image_prompt, seed, resolution,
686
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
687
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
688
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
689
+ multiimage_algo, tex_multiimage_algo
690
  ],
691
  outputs=[output_buf, preview_output],
692
  )
693
 
694
  extract_btn.click(
695
+ lambda: gr.Walkthrough(selected=1), outputs=walkthrough
696
+ ).then(
697
  extract_glb,
698
  inputs=[output_buf, decimation_target, texture_size],
699
+ outputs=[glb_output, download_btn],
700
  )
701
 
702
 
 
703
  if __name__ == "__main__":
704
  os.makedirs(TMP_DIR, exist_ok=True)
705
 
 
 
706
  for i in range(len(MODES)):
707
  icon = Image.open(MODES[i]['icon'])
708
  MODES[i]['icon_base64'] = image_to_base64(icon)
709
 
710
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
 
712
+ demo.launch(css=css, head=head)