opsiclear-admin committed on
Commit
ac144c2
·
verified ·
1 Parent(s): be34f5c

Unified image input: single Gallery for 1+ images, no blending

Browse files
Files changed (1) hide show
  1. app.py +121 -371
app.py CHANGED
@@ -104,205 +104,42 @@ DEFAULT_STEP = 3
104
 
105
  css = """
106
  /* Overwrite Gradio Default Style */
107
- .stepper-wrapper {
108
- padding: 0;
109
- }
110
-
111
- .stepper-container {
112
- padding: 0;
113
- align-items: center;
114
- }
115
-
116
- .step-button {
117
- flex-direction: row;
118
- }
119
-
120
- .step-connector {
121
- transform: none;
122
- }
123
-
124
- .step-number {
125
- width: 16px;
126
- height: 16px;
127
- }
128
-
129
- .step-label {
130
- position: relative;
131
- bottom: 0;
132
- }
133
-
134
- .wrap.center.full {
135
- inset: 0;
136
- height: 100%;
137
- }
138
-
139
- .wrap.center.full.translucent {
140
- background: var(--block-background-fill);
141
- }
142
-
143
- .meta-text-center {
144
- display: block !important;
145
- position: absolute !important;
146
- top: unset !important;
147
- bottom: 0 !important;
148
- right: 0 !important;
149
- transform: unset !important;
150
- }
151
 
152
  /* Previewer */
153
- .previewer-container {
154
- position: relative;
155
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
156
- width: 100%;
157
- height: 722px;
158
- margin: 0 auto;
159
- padding: 20px;
160
- display: flex;
161
- flex-direction: column;
162
- align-items: center;
163
- justify-content: center;
164
- }
165
-
166
- .previewer-container .tips-icon {
167
- position: absolute;
168
- right: 10px;
169
- top: 10px;
170
- z-index: 10;
171
- border-radius: 10px;
172
- color: #fff;
173
- background-color: var(--color-accent);
174
- padding: 3px 6px;
175
- user-select: none;
176
- }
177
-
178
- .previewer-container .tips-text {
179
- position: absolute;
180
- right: 10px;
181
- top: 50px;
182
- color: #fff;
183
- background-color: var(--color-accent);
184
- border-radius: 10px;
185
- padding: 6px;
186
- text-align: left;
187
- max-width: 300px;
188
- z-index: 10;
189
- transition: all 0.3s;
190
- opacity: 0%;
191
- user-select: none;
192
- }
193
-
194
- .previewer-container .tips-text p {
195
- font-size: 14px;
196
- line-height: 1.2;
197
- }
198
-
199
- .tips-icon:hover + .tips-text {
200
- display: block;
201
- opacity: 100%;
202
- }
203
-
204
- /* Row 1: Display Modes */
205
- .previewer-container .mode-row {
206
- width: 100%;
207
- display: flex;
208
- gap: 8px;
209
- justify-content: center;
210
- margin-bottom: 20px;
211
- flex-wrap: wrap;
212
- }
213
- .previewer-container .mode-btn {
214
- width: 24px;
215
- height: 24px;
216
- border-radius: 50%;
217
- cursor: pointer;
218
- opacity: 0.5;
219
- transition: all 0.2s;
220
- border: 2px solid #ddd;
221
- object-fit: cover;
222
- }
223
  .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
224
- .previewer-container .mode-btn.active {
225
- opacity: 1;
226
- border-color: var(--color-accent);
227
- transform: scale(1.1);
228
- }
229
-
230
- /* Row 2: Display Image */
231
- .previewer-container .display-row {
232
- margin-bottom: 20px;
233
- min-height: 400px;
234
- width: 100%;
235
- flex-grow: 1;
236
- display: flex;
237
- justify-content: center;
238
- align-items: center;
239
- }
240
- .previewer-container .previewer-main-image {
241
- max-width: 100%;
242
- max-height: 100%;
243
- flex-grow: 1;
244
- object-fit: contain;
245
- display: none;
246
- }
247
- .previewer-container .previewer-main-image.visible {
248
- display: block;
249
- }
250
-
251
- /* Row 3: Custom HTML Slider */
252
- .previewer-container .slider-row {
253
- width: 100%;
254
- display: flex;
255
- flex-direction: column;
256
- align-items: center;
257
- gap: 10px;
258
- padding: 0 10px;
259
- }
260
-
261
- .previewer-container input[type=range] {
262
- -webkit-appearance: none;
263
- width: 100%;
264
- max-width: 400px;
265
- background: transparent;
266
- }
267
- .previewer-container input[type=range]::-webkit-slider-runnable-track {
268
- width: 100%;
269
- height: 8px;
270
- cursor: pointer;
271
- background: #ddd;
272
- border-radius: 5px;
273
- }
274
- .previewer-container input[type=range]::-webkit-slider-thumb {
275
- height: 20px;
276
- width: 20px;
277
- border-radius: 50%;
278
- background: var(--color-accent);
279
- cursor: pointer;
280
- -webkit-appearance: none;
281
- margin-top: -6px;
282
- box-shadow: 0 2px 5px rgba(0,0,0,0.2);
283
- transition: transform 0.1s;
284
- }
285
- .previewer-container input[type=range]::-webkit-slider-thumb:hover {
286
- transform: scale(1.2);
287
- }
288
-
289
- /* Overwrite Previewer Block Style */
290
- .gradio-container .padded:has(.previewer-container) {
291
- padding: 0 !important;
292
- }
293
-
294
- .gradio-container:has(.previewer-container) [data-testid="block-label"] {
295
- position: absolute;
296
- top: 0;
297
- left: 0;
298
- }
299
  """
300
 
301
 
302
  head = """
303
  <script>
304
  function refreshView(mode, step) {
305
- // 1. Find current mode and step
306
  const allImgs = document.querySelectorAll('.previewer-main-image');
307
  for (let i = 0; i < allImgs.length; i++) {
308
  const img = allImgs[i];
@@ -314,46 +151,26 @@ head = """
314
  break;
315
  }
316
  }
317
-
318
- // 2. Hide ALL images
319
- // We select all elements with class 'previewer-main-image'
320
  allImgs.forEach(img => img.classList.remove('visible'));
321
-
322
- // 3. Construct the specific ID for the current state
323
- // Format: view-m{mode}-s{step}
324
  const targetId = 'view-m' + mode + '-s' + step;
325
  const targetImg = document.getElementById(targetId);
326
-
327
- // 4. Show ONLY the target
328
- if (targetImg) {
329
- targetImg.classList.add('visible');
330
- }
331
-
332
- // 5. Update Button Highlights
333
  const allBtns = document.querySelectorAll('.mode-btn');
334
  allBtns.forEach((btn, idx) => {
335
  if (idx === mode) btn.classList.add('active');
336
  else btn.classList.remove('active');
337
  });
338
  }
339
-
340
- // --- Action: Switch Mode ---
341
- function selectMode(mode) {
342
- refreshView(mode, -1);
343
- }
344
-
345
- // --- Action: Slider Change ---
346
- function onSliderChange(val) {
347
- refreshView(-1, parseInt(val));
348
- }
349
  </script>
350
  """
351
 
352
 
353
- empty_html = f"""
354
  <div class="previewer-container">
355
- <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
356
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
357
  </div>
358
  """
359
 
@@ -387,10 +204,7 @@ def remove_background(input: Image.Image) -> Image.Image:
387
 
388
 
389
  def preprocess_image(input: Image.Image) -> Image.Image:
390
- """
391
- Preprocess the input image.
392
- """
393
- # if has alpha channel, use it directly; otherwise, remove background
394
  has_alpha = False
395
  if input.mode == 'RGBA':
396
  alpha = np.array(input)[:, :, 3]
@@ -412,7 +226,7 @@ def preprocess_image(input: Image.Image) -> Image.Image:
412
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
413
  size = int(size * 1)
414
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
415
- output = output.crop(bbox) # type: ignore
416
  output = np.array(output).astype(np.float32) / 255
417
  output = output[:, :, :3] * output[:, :, 3:4]
418
  output = Image.fromarray((output * 255).astype(np.uint8))
@@ -420,13 +234,12 @@ def preprocess_image(input: Image.Image) -> Image.Image:
420
 
421
 
422
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
423
- """
424
- Preprocess a list of input images for multi-image conditioning.
425
- Uses parallel processing for faster background removal.
426
- """
427
- images = [image[0] for image in images]
428
- with ThreadPoolExecutor(max_workers=min(4, len(images))) as executor:
429
- processed_images = list(executor.map(preprocess_image, images))
430
  return processed_images
431
 
432
 
@@ -451,47 +264,40 @@ def unpack_state(state: dict):
451
 
452
 
453
  def get_seed(randomize_seed: bool, seed: int) -> int:
454
- """
455
- Get the random seed.
456
- """
457
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
458
 
459
 
460
- def prepare_multi_example() -> List[Image.Image]:
461
- """
462
- Prepare multi-image examples for the gallery.
463
- """
464
- multi_case = list(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
  images = []
466
- for case in multi_case:
467
- _images = []
468
- for i in range(1, 4):
469
- img = Image.open(f'assets/example_multi_image/{case}_{i}.png')
470
- W, H = img.size
471
- img = img.resize((int(W / H * 512), 512))
472
- _images.append(np.array(img))
473
- images.append(Image.fromarray(np.concatenate(_images, axis=1)))
474
  return images
475
 
476
 
477
- def split_image(image: Image.Image) -> List[Image.Image]:
478
- """
479
- Split a concatenated image into multiple views.
480
- """
481
- image = np.array(image)
482
- alpha = image[..., 3]
483
- alpha = np.any(alpha > 0, axis=0)
484
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
485
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
486
- images = []
487
- for s, e in zip(start_pos, end_pos):
488
- images.append(Image.fromarray(image[:, s:e+1]))
489
- return [preprocess_image(image) for image in images]
490
-
491
-
492
  @spaces.GPU(duration=120)
493
  def image_to_3d(
494
- image: Image.Image,
495
  seed: int,
496
  resolution: str,
497
  ss_guidance_strength: float,
@@ -506,19 +312,24 @@ def image_to_3d(
506
  tex_slat_guidance_rescale: float,
507
  tex_slat_sampling_steps: int,
508
  tex_slat_rescale_t: float,
 
509
  req: gr.Request,
510
  progress=gr.Progress(track_tqdm=True),
511
- multiimages: List[Tuple[Image.Image, str]] = None,
512
- is_multiimage: bool = False,
513
- multiimage_algo: Literal["multidiffusion", "stochastic"] = "stochastic",
514
  ) -> str:
515
  # Initialize pipeline on first call
516
  _initialize_pipeline()
517
 
 
 
 
 
 
 
518
  # --- Sampling ---
519
- if not is_multiimage:
 
520
  outputs, latents = pipeline.run(
521
- image,
522
  seed=seed,
523
  preprocess_image=False,
524
  sparse_structure_sampler_params={
@@ -547,8 +358,9 @@ def image_to_3d(
547
  return_latent=True,
548
  )
549
  else:
 
550
  outputs, latents = pipeline.run_multi_image(
551
- [image[0] for image in multiimages],
552
  seed=seed,
553
  preprocess_image=False,
554
  sparse_structure_sampler_params={
@@ -577,29 +389,24 @@ def image_to_3d(
577
  return_latent=True,
578
  mode=multiimage_algo,
579
  )
 
580
  mesh = outputs[0]
581
- mesh.simplify(16777216) # nvdiffrast limit
582
- images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
583
  state = pack_state(latents)
584
  torch.cuda.empty_cache()
585
 
586
  # --- HTML Construction ---
587
- # The Stack of 48 Images - encode in parallel for speed
588
  def encode_preview_image(args):
589
  m_idx, s_idx, render_key = args
590
- img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx]))
591
  return (m_idx, s_idx, img_base64)
592
 
593
- encode_tasks = [
594
- (m_idx, s_idx, mode['render_key'])
595
- for m_idx, mode in enumerate(MODES)
596
- for s_idx in range(STEPS)
597
- ]
598
 
599
  with ThreadPoolExecutor(max_workers=8) as executor:
600
  encoded_results = list(executor.map(encode_preview_image, encode_tasks))
601
 
602
- # Build HTML from encoded results
603
  encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
604
  images_html = ""
605
  for m_idx, mode in enumerate(MODES):
@@ -608,54 +415,29 @@ def image_to_3d(
608
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
609
  vis_class = "visible" if is_visible else ""
610
  img_base64 = encoded_map[(m_idx, s_idx)]
 
611
 
612
- images_html += f"""
613
- <img id="{unique_id}"
614
- class="previewer-main-image {vis_class}"
615
- src="{img_base64}"
616
- loading="eager">
617
- """
618
-
619
- # Button Row HTML
620
  btns_html = ""
621
  for idx, mode in enumerate(MODES):
622
  active_class = "active" if idx == DEFAULT_MODE else ""
623
- # Note: onclick calls the JS function defined in Head
624
- btns_html += f"""
625
- <img src="{mode['icon_base64']}"
626
- class="mode-btn {active_class}"
627
- onclick="selectMode({idx})"
628
- title="{mode['name']}">
629
- """
630
-
631
- # Assemble the full component
632
  full_html = f"""
633
  <div class="previewer-container">
634
  <div class="tips-wrapper">
635
- <div class="tips-icon">💡Tips</div>
636
  <div class="tips-text">
637
- <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
638
- <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
639
  </div>
640
  </div>
641
-
642
- <!-- Row 1: Viewport containing 48 static <img> tags -->
643
- <div class="display-row">
644
- {images_html}
645
- </div>
646
-
647
- <!-- Row 2 -->
648
- <div class="mode-row" id="btn-group">
649
- {btns_html}
650
- </div>
651
-
652
- <!-- Row 3: Slider -->
653
  <div class="slider-row">
654
  <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
655
  </div>
656
  </div>
657
  """
658
-
659
  return state, full_html
660
 
661
 
@@ -667,24 +449,12 @@ def extract_glb(
667
  req: gr.Request,
668
  progress=gr.Progress(track_tqdm=True),
669
  ) -> Tuple[str, str]:
670
- """
671
- Extract a GLB file from the 3D model.
672
-
673
- Args:
674
- state (dict): The state of the generated 3D model.
675
- decimation_target (int): The target face count for decimation.
676
- texture_size (int): The texture resolution.
677
-
678
- Returns:
679
- str: The path to the extracted GLB file.
680
- """
681
- # Initialize pipeline on first call
682
  _initialize_pipeline()
683
 
684
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
685
  shape_slat, tex_slat, res = unpack_state(state)
686
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
687
- mesh.simplify(16777216) # nvdiffrast limit
688
  glb = o_voxel.postprocess.to_glb(
689
  vertices=mesh.vertices,
690
  faces=mesh.faces,
@@ -711,22 +481,22 @@ def extract_glb(
711
 
712
  with gr.Blocks(delete_cache=(600, 600)) as demo:
713
  gr.Markdown("""
714
- ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
715
- * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
716
- * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
717
  """)
718
 
719
  with gr.Row():
720
  with gr.Column(scale=1, min_width=360):
721
- with gr.Tabs() as input_tabs:
722
- with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
723
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
724
- with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
725
- multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=400, columns=3)
726
- gr.Markdown("""
727
- Input different views of the object in separate images.
728
- *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
729
- """)
730
 
731
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
732
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
@@ -734,7 +504,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
734
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
735
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
736
 
737
- generate_btn = gr.Button("Generate")
738
 
739
  with gr.Accordion(label="Advanced Settings", open=False):
740
  gr.Markdown("Stage 1: Sparse Structure Generation")
@@ -765,51 +535,34 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
765
  with gr.Step("Extract", id=1):
766
  glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
767
  download_btn = gr.DownloadButton(label="Download GLB")
768
- gr.Markdown("*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more and face count is limited.*")
769
-
770
- with gr.Column(scale=1, min_width=172) as multiimage_example:
771
- examples_multi = gr.Examples(
772
- examples=prepare_multi_example(),
773
- label="Multi Image Examples",
774
- inputs=[image_prompt],
775
- fn=split_image,
776
- outputs=[multiimage_prompt],
777
- run_on_click=True,
778
- examples_per_page=8,
779
- )
 
780
 
781
- is_multiimage = gr.State(False)
782
  output_buf = gr.State()
783
 
784
-
785
  # Handlers
786
  demo.load(start_session)
787
  demo.unload(end_session)
788
 
789
- single_image_input_tab.select(
790
- lambda: False,
791
- outputs=[is_multiimage]
792
- )
793
- multiimage_input_tab.select(
794
- lambda: True,
795
- outputs=[is_multiimage]
796
- )
797
-
798
  image_prompt.upload(
799
- preprocess_image,
800
  inputs=[image_prompt],
801
  outputs=[image_prompt],
802
  )
803
- multiimage_prompt.upload(
804
- preprocess_images,
805
- inputs=[multiimage_prompt],
806
- outputs=[multiimage_prompt],
807
- )
808
 
809
  generate_btn.click(
810
- get_seed,
811
- inputs=[randomize_seed, seed],
812
- outputs=[seed],
813
  ).then(
814
  lambda: gr.Walkthrough(selected=0), outputs=walkthrough
815
  ).then(
@@ -819,7 +572,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
819
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
820
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
821
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
822
- multiimage_prompt, is_multiimage, multiimage_algo
823
  ],
824
  outputs=[output_buf, preview_output],
825
  )
@@ -833,12 +586,9 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
833
  )
834
 
835
 
836
- # Launch the Gradio app
837
  if __name__ == "__main__":
838
  os.makedirs(TMP_DIR, exist_ok=True)
839
 
840
- # Construct ui components (CPU-only, no GPU needed)
841
- btn_img_base64_strs = {}
842
  for i in range(len(MODES)):
843
  icon = Image.open(MODES[i]['icon'])
844
  MODES[i]['icon_base64'] = image_to_base64(icon)
 
104
 
105
  css = """
106
  /* Overwrite Gradio Default Style */
107
+ .stepper-wrapper { padding: 0; }
108
+ .stepper-container { padding: 0; align-items: center; }
109
+ .step-button { flex-direction: row; }
110
+ .step-connector { transform: none; }
111
+ .step-number { width: 16px; height: 16px; }
112
+ .step-label { position: relative; bottom: 0; }
113
+ .wrap.center.full { inset: 0; height: 100%; }
114
+ .wrap.center.full.translucent { background: var(--block-background-fill); }
115
+ .meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
 
117
  /* Previewer */
118
+ .previewer-container { position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; width: 100%; height: 722px; margin: 0 auto; padding: 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; }
119
+ .previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
120
+ .previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
121
+ .previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
122
+ .tips-icon:hover + .tips-text { display: block; opacity: 100%; }
123
+ .previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
124
+ .previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
126
+ .previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
127
+ .previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
128
+ .previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
129
+ .previewer-container .previewer-main-image.visible { display: block; }
130
+ .previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
131
+ .previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
132
+ .previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
133
+ .previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
134
+ .previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
135
+ .gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
136
+ .gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  """
138
 
139
 
140
  head = """
141
  <script>
142
  function refreshView(mode, step) {
 
143
  const allImgs = document.querySelectorAll('.previewer-main-image');
144
  for (let i = 0; i < allImgs.length; i++) {
145
  const img = allImgs[i];
 
151
  break;
152
  }
153
  }
 
 
 
154
  allImgs.forEach(img => img.classList.remove('visible'));
 
 
 
155
  const targetId = 'view-m' + mode + '-s' + step;
156
  const targetImg = document.getElementById(targetId);
157
+ if (targetImg) { targetImg.classList.add('visible'); }
 
 
 
 
 
 
158
  const allBtns = document.querySelectorAll('.mode-btn');
159
  allBtns.forEach((btn, idx) => {
160
  if (idx === mode) btn.classList.add('active');
161
  else btn.classList.remove('active');
162
  });
163
  }
164
+ function selectMode(mode) { refreshView(mode, -1); }
165
+ function onSliderChange(val) { refreshView(-1, parseInt(val)); }
 
 
 
 
 
 
 
 
166
  </script>
167
  """
168
 
169
 
170
+ empty_html = """
171
  <div class="previewer-container">
172
+ <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
173
+ xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
174
  </div>
175
  """
176
 
 
204
 
205
 
206
  def preprocess_image(input: Image.Image) -> Image.Image:
207
+ """Preprocess a single input image."""
 
 
 
208
  has_alpha = False
209
  if input.mode == 'RGBA':
210
  alpha = np.array(input)[:, :, 3]
 
226
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
227
  size = int(size * 1)
228
  bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
229
+ output = output.crop(bbox)
230
  output = np.array(output).astype(np.float32) / 255
231
  output = output[:, :, :3] * output[:, :, 3:4]
232
  output = Image.fromarray((output * 255).astype(np.uint8))
 
234
 
235
 
236
  def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
237
+ """Preprocess a list of input images. Uses parallel processing."""
238
+ if not images:
239
+ return []
240
+ imgs = [img[0] if isinstance(img, tuple) else img for img in images]
241
+ with ThreadPoolExecutor(max_workers=min(4, len(imgs))) as executor:
242
+ processed_images = list(executor.map(preprocess_image, imgs))
 
243
  return processed_images
244
 
245
 
 
264
 
265
 
266
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
 
 
267
  return np.random.randint(0, MAX_SEED) if randomize_seed else seed
268
 
269
 
270
+ def prepare_examples() -> List[List[str]]:
271
+ """Prepare examples as lists of image paths (not concatenated)."""
272
+ example_dir = "assets/example_multi_image"
273
+ if not os.path.exists(example_dir):
274
+ return []
275
+ files = os.listdir(example_dir)
276
+ cases = list(set([f.split('_')[0] for f in files if '_' in f and f.endswith('.png')]))
277
+ examples = []
278
+ for case in sorted(cases):
279
+ case_images = []
280
+ for i in range(1, 10): # Support up to 9 images per example
281
+ img_path = f'{example_dir}/{case}_{i}.png'
282
+ if os.path.exists(img_path):
283
+ case_images.append(img_path)
284
+ if case_images:
285
+ examples.append(case_images)
286
+ return examples
287
+
288
+
289
+ def load_example(example_paths: List[str]) -> List[Image.Image]:
290
+ """Load example images from paths."""
291
  images = []
292
+ for path in example_paths:
293
+ img = Image.open(path)
294
+ images.append(img)
 
 
 
 
 
295
  return images
296
 
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  @spaces.GPU(duration=120)
299
  def image_to_3d(
300
+ images: List[Tuple[Image.Image, str]],
301
  seed: int,
302
  resolution: str,
303
  ss_guidance_strength: float,
 
312
  tex_slat_guidance_rescale: float,
313
  tex_slat_sampling_steps: int,
314
  tex_slat_rescale_t: float,
315
+ multiimage_algo: Literal["multidiffusion", "stochastic"],
316
  req: gr.Request,
317
  progress=gr.Progress(track_tqdm=True),
 
 
 
318
  ) -> str:
319
  # Initialize pipeline on first call
320
  _initialize_pipeline()
321
 
322
+ # Extract images from gallery format
323
+ if not images:
324
+ raise gr.Error("Please upload at least one image")
325
+
326
+ imgs = [img[0] if isinstance(img, tuple) else img for img in images]
327
+
328
  # --- Sampling ---
329
+ if len(imgs) == 1:
330
+ # Single image mode
331
  outputs, latents = pipeline.run(
332
+ imgs[0],
333
  seed=seed,
334
  preprocess_image=False,
335
  sparse_structure_sampler_params={
 
358
  return_latent=True,
359
  )
360
  else:
361
+ # Multi-image mode
362
  outputs, latents = pipeline.run_multi_image(
363
+ imgs,
364
  seed=seed,
365
  preprocess_image=False,
366
  sparse_structure_sampler_params={
 
389
  return_latent=True,
390
  mode=multiimage_algo,
391
  )
392
+
393
  mesh = outputs[0]
394
+ mesh.simplify(16777216)
395
+ render_images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
396
  state = pack_state(latents)
397
  torch.cuda.empty_cache()
398
 
399
  # --- HTML Construction ---
 
400
  def encode_preview_image(args):
401
  m_idx, s_idx, render_key = args
402
+ img_base64 = image_to_base64(Image.fromarray(render_images[render_key][s_idx]))
403
  return (m_idx, s_idx, img_base64)
404
 
405
+ encode_tasks = [(m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS)]
 
 
 
 
406
 
407
  with ThreadPoolExecutor(max_workers=8) as executor:
408
  encoded_results = list(executor.map(encode_preview_image, encode_tasks))
409
 
 
410
  encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
411
  images_html = ""
412
  for m_idx, mode in enumerate(MODES):
 
415
  is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
416
  vis_class = "visible" if is_visible else ""
417
  img_base64 = encoded_map[(m_idx, s_idx)]
418
+ images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
419
 
 
 
 
 
 
 
 
 
420
  btns_html = ""
421
  for idx, mode in enumerate(MODES):
422
  active_class = "active" if idx == DEFAULT_MODE else ""
423
+ btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
424
+
 
 
 
 
 
 
 
425
  full_html = f"""
426
  <div class="previewer-container">
427
  <div class="tips-wrapper">
428
+ <div class="tips-icon">Tips</div>
429
  <div class="tips-text">
430
+ <p>Render Mode - Click buttons to switch render modes.</p>
431
+ <p>View Angle - Drag slider to change view.</p>
432
  </div>
433
  </div>
434
+ <div class="display-row">{images_html}</div>
435
+ <div class="mode-row" id="btn-group">{btns_html}</div>
 
 
 
 
 
 
 
 
 
 
436
  <div class="slider-row">
437
  <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
438
  </div>
439
  </div>
440
  """
 
441
  return state, full_html
442
 
443
 
 
449
  req: gr.Request,
450
  progress=gr.Progress(track_tqdm=True),
451
  ) -> Tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
452
  _initialize_pipeline()
453
 
454
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
455
  shape_slat, tex_slat, res = unpack_state(state)
456
  mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
457
+ mesh.simplify(16777216)
458
  glb = o_voxel.postprocess.to_glb(
459
  vertices=mesh.vertices,
460
  faces=mesh.faces,
 
481
 
482
  with gr.Blocks(delete_cache=(600, 600)) as demo:
483
  gr.Markdown("""
484
+ ## Multi-View Image to 3D with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
485
+ Upload one or more images of an object and click Generate to create a 3D asset.
486
+ Multiple views from different angles will produce better results.
487
  """)
488
 
489
  with gr.Row():
490
  with gr.Column(scale=1, min_width=360):
491
+ image_prompt = gr.Gallery(
492
+ label="Input Images (upload 1 or more views)",
493
+ format="png",
494
+ type="pil",
495
+ height=400,
496
+ columns=3,
497
+ object_fit="contain"
498
+ )
499
+ gr.Markdown("*Upload multiple views of the same object for better 3D reconstruction.*")
500
 
501
  resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
502
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
 
504
  decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
505
  texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
506
 
507
+ generate_btn = gr.Button("Generate", variant="primary")
508
 
509
  with gr.Accordion(label="Advanced Settings", open=False):
510
  gr.Markdown("Stage 1: Sparse Structure Generation")
 
535
  with gr.Step("Extract", id=1):
536
  glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
537
  download_btn = gr.DownloadButton(label="Download GLB")
538
+ gr.Markdown("*GLB extraction may take 30+ seconds.*")
539
+
540
+ with gr.Column(scale=1, min_width=200):
541
+ gr.Markdown("### Examples")
542
+ # Create example buttons that load images into gallery
543
+ example_data = prepare_examples()
544
+ for i, example_paths in enumerate(example_data[:12]): # Limit to 12 examples
545
+ case_name = os.path.basename(example_paths[0]).split('_')[0]
546
+ btn = gr.Button(f"{case_name} ({len(example_paths)} views)", size="sm")
547
+ btn.click(
548
+ fn=lambda paths=example_paths: load_example(paths),
549
+ outputs=[image_prompt]
550
+ )
551
 
 
552
  output_buf = gr.State()
553
 
 
554
  # Handlers
555
  demo.load(start_session)
556
  demo.unload(end_session)
557
 
 
 
 
 
 
 
 
 
 
558
  image_prompt.upload(
559
+ preprocess_images,
560
  inputs=[image_prompt],
561
  outputs=[image_prompt],
562
  )
 
 
 
 
 
563
 
564
  generate_btn.click(
565
+ get_seed, inputs=[randomize_seed, seed], outputs=[seed],
 
 
566
  ).then(
567
  lambda: gr.Walkthrough(selected=0), outputs=walkthrough
568
  ).then(
 
572
  ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
573
  shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
574
  tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
575
+ multiimage_algo
576
  ],
577
  outputs=[output_buf, preview_output],
578
  )
 
586
  )
587
 
588
 
 
589
  if __name__ == "__main__":
590
  os.makedirs(TMP_DIR, exist_ok=True)
591
 
 
 
592
  for i in range(len(MODES)):
593
  icon = Image.open(MODES[i]['icon'])
594
  MODES[i]['icon_base64'] = image_to_base64(icon)