choephix committed on
Commit
0ce1f9f
·
1 Parent(s): 2aa7f7d

Update Python formatting

Browse files
Files changed (2) hide show
  1. app.py +210 -101
  2. trellis2/pipelines/trellis2_image_to_3d.py +237 -137
app.py CHANGED
@@ -3,11 +3,14 @@ from gradio_client import Client, handle_file
3
  import spaces
4
 
5
  import os
6
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
 
7
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
8
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
9
- os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
10
- os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
 
 
11
  from datetime import datetime
12
  import shutil
13
  import cv2
@@ -26,14 +29,30 @@ import o_voxel
26
 
27
 
28
  MAX_SEED = np.iinfo(np.int32).max
29
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
30
  MODES = [
31
  {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
32
  {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
33
- {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
34
- {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
35
- {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
36
- {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  ]
38
  STEPS = 8
39
  DEFAULT_MODE = 3
@@ -307,16 +326,16 @@ def image_to_base64(image):
307
  def start_session(req: gr.Request):
308
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
309
  os.makedirs(user_dir, exist_ok=True)
310
-
311
-
312
  def end_session(req: gr.Request):
313
  user_dir = os.path.join(TMP_DIR, str(req.session_hash))
314
  shutil.rmtree(user_dir)
315
-
316
 
317
  def remove_background(input: Image.Image) -> Image.Image:
318
- with tempfile.NamedTemporaryFile(suffix='.png') as f:
319
- input = input.convert('RGB')
320
  input.save(f.name)
321
  output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
322
  output = Image.open(output)
@@ -329,14 +348,17 @@ def preprocess_image(input: Image.Image) -> Image.Image:
329
  """
330
  # if has alpha channel, use it directly; otherwise, remove background
331
  has_alpha = False
332
- if input.mode == 'RGBA':
333
  alpha = np.array(input)[:, :, 3]
334
  if not np.all(alpha == 255):
335
  has_alpha = True
336
  max_size = max(input.size)
337
  scale = min(1, 1024 / max_size)
338
  if scale < 1:
339
- input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
 
 
 
340
  if has_alpha:
341
  output = input
342
  else:
@@ -344,11 +366,21 @@ def preprocess_image(input: Image.Image) -> Image.Image:
344
  output_np = np.array(output)
345
  alpha = output_np[:, :, 3]
346
  bbox = np.argwhere(alpha > 0.8 * 255)
347
- bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
 
 
 
 
 
348
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
349
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
350
  size = int(size * 1)
351
- bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
 
 
 
 
 
352
  output = output.crop(bbox) # type: ignore
353
  output = np.array(output).astype(np.float32) / 255
354
  output = output[:, :, :3] * output[:, :, 3:4]
@@ -359,20 +391,20 @@ def preprocess_image(input: Image.Image) -> Image.Image:
359
  def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
360
  shape_slat, tex_slat, res = latents
361
  return {
362
- 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
363
- 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
364
- 'coords': shape_slat.coords.cpu().numpy(),
365
- 'res': res,
366
  }
367
-
368
-
369
  def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
370
  shape_slat = SparseTensor(
371
- feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
372
- coords=torch.from_numpy(state['coords']).cuda(),
373
  )
374
- tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
375
- return shape_slat, tex_slat, state['res']
376
 
377
 
378
  def get_seed(randomize_seed: bool, seed: int) -> int:
@@ -433,11 +465,13 @@ def image_to_3d(
433
  return_latent=True,
434
  )
435
  mesh = outputs[0]
436
- mesh.simplify(16777216) # nvdiffrast limit
437
- images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
 
 
438
  state = pack_state(latents)
439
  torch.cuda.empty_cache()
440
-
441
  # --- HTML Construction ---
442
  # The Stack of 48 Images
443
  images_html = ""
@@ -445,14 +479,16 @@ def image_to_3d(
445
  for s_idx in range(STEPS):
446
  # ID Naming Convention: view-m{mode}-s{step}
447
  unique_id = f"view-m{m_idx}-s{s_idx}"
448
-
449
  # Logic: Only Mode 0, Step 0 is visible initially
450
- is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
451
  vis_class = "visible" if is_visible else ""
452
-
453
  # Image Source
454
- img_base64 = image_to_base64(Image.fromarray(images[mode['render_key']][s_idx]))
455
-
 
 
456
  # Render the Tag
457
  images_html += f"""
458
  <img id="{unique_id}"
@@ -460,19 +496,19 @@ def image_to_3d(
460
  src="{img_base64}"
461
  loading="eager">
462
  """
463
-
464
  # Button Row HTML
465
  btns_html = ""
466
- for idx, mode in enumerate(MODES):
467
  active_class = "active" if idx == DEFAULT_MODE else ""
468
  # Note: onclick calls the JS function defined in Head
469
  btns_html += f"""
470
- <img src="{mode['icon_base64']}"
471
  class="mode-btn {active_class}"
472
  onclick="selectMode({idx})"
473
- title="{mode['name']}">
474
  """
475
-
476
  # Assemble the full component
477
  full_html = f"""
478
  <div class="previewer-container">
@@ -500,7 +536,7 @@ def image_to_3d(
500
  </div>
501
  </div>
502
  """
503
-
504
  return state, full_html
505
 
506
 
@@ -545,7 +581,7 @@ def extract_glb(
545
  now = datetime.now()
546
  timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
547
  os.makedirs(user_dir, exist_ok=True)
548
- glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
549
  glb.export(glb_path, extension_webp=True)
550
  torch.cuda.empty_cache()
551
  return glb_path, glb_path
@@ -557,53 +593,102 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
557
  * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
558
  * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
559
  """)
560
-
561
  with gr.Row():
562
  with gr.Column(scale=1, min_width=360):
563
- image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
564
-
565
- resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="512")
 
 
 
 
 
 
 
 
566
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
567
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
568
- decimation_target = gr.Slider(20000, 500000, label="Decimation Target", value=20000, step=10000)
569
- texture_size = gr.Slider(1024, 4096, label="Texture Size", value=1024, step=1024)
570
-
 
 
 
 
571
  generate_btn = gr.Button("Generate")
572
-
573
- with gr.Accordion(label="Advanced Settings", open=False):
574
  gr.Markdown("Stage 1: Sparse Structure Generation")
575
  with gr.Row():
576
- ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
577
- ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
578
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=8, step=1)
579
- ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
 
 
 
 
 
 
 
 
580
  gr.Markdown("Stage 2: Shape Generation")
581
  with gr.Row():
582
- shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
583
- shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
584
- shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=8, step=1)
585
- shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
 
 
 
 
 
 
 
 
586
  gr.Markdown("Stage 3: Material Generation")
587
  with gr.Row():
588
- tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
589
- tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
590
- tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=8, step=1)
591
- tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
 
 
 
 
 
 
 
 
592
 
593
  with gr.Column(scale=10):
594
  with gr.Walkthrough(selected=0) as walkthrough:
595
  with gr.Step("Preview", id=0):
596
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
 
 
 
 
 
597
  extract_btn = gr.Button("Extract GLB")
598
  with gr.Step("Extract", id=1):
599
- glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
 
 
 
 
 
 
600
  download_btn = gr.DownloadButton(label="Download GLB")
601
- gr.Markdown("*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more and face count is limited.*")
602
-
 
 
603
  with gr.Column(scale=1, min_width=172):
604
  examples = gr.Examples(
605
  examples=[
606
- f'assets/example_image/{image}'
607
  for image in os.listdir("assets/example_image")
608
  ],
609
  inputs=[image_prompt],
@@ -612,14 +697,13 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
612
  run_on_click=True,
613
  examples_per_page=18,
614
  )
615
-
616
  output_buf = gr.State()
617
-
618
 
619
  # Handlers
620
  demo.load(start_session)
621
  demo.unload(end_session)
622
-
623
  image_prompt.upload(
624
  preprocess_image,
625
  inputs=[image_prompt],
@@ -630,27 +714,34 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
630
  get_seed,
631
  inputs=[randomize_seed, seed],
632
  outputs=[seed],
633
- ).then(
634
- lambda: gr.Walkthrough(selected=0), outputs=walkthrough
635
- ).then(
636
  image_to_3d,
637
  inputs=[
638
- image_prompt, seed, resolution,
639
- ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
640
- shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
641
- tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
 
 
 
 
 
 
 
 
 
 
 
642
  ],
643
  outputs=[output_buf, preview_output],
644
  )
645
-
646
- extract_btn.click(
647
- lambda: gr.Walkthrough(selected=1), outputs=walkthrough
648
- ).then(
649
  extract_glb,
650
  inputs=[output_buf, decimation_target, texture_size],
651
  outputs=[glb_output, download_btn],
652
  )
653
-
654
 
655
  # Launch the Gradio app
656
  if __name__ == "__main__":
@@ -659,29 +750,47 @@ if __name__ == "__main__":
659
  # Construct ui components
660
  btn_img_base64_strs = {}
661
  for i in range(len(MODES)):
662
- icon = Image.open(MODES[i]['icon'])
663
- MODES[i]['icon_base64'] = image_to_base64(icon)
664
 
665
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
666
- pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
667
  pipeline.rembg_model = None
668
  pipeline.low_vram = False
669
  pipeline.cuda()
670
-
671
  envmap = {
672
- 'forest': EnvMap(torch.tensor(
673
- cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
674
- dtype=torch.float32, device='cuda'
675
- )),
676
- 'sunset': EnvMap(torch.tensor(
677
- cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
678
- dtype=torch.float32, device='cuda'
679
- )),
680
- 'courtyard': EnvMap(torch.tensor(
681
- cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
682
- dtype=torch.float32, device='cuda'
683
- )),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
684
  }
685
-
686
- #demo.launch(css=css, head=head)
687
  demo.launch(server_name="0.0.0.0", server_port=7860, css=css, head=head)
 
3
  import spaces
4
 
5
  import os
6
+
7
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
8
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
  os.environ["ATTN_BACKEND"] = "flash_attn_3"
10
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
11
+ os.path.dirname(os.path.abspath(__file__)), "autotune_cache.json"
12
+ )
13
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = "1"
14
  from datetime import datetime
15
  import shutil
16
  import cv2
 
29
 
30
 
31
  MAX_SEED = np.iinfo(np.int32).max
32
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
33
  MODES = [
34
  {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
35
  {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
36
+ {
37
+ "name": "Base color",
38
+ "icon": "assets/app/basecolor.png",
39
+ "render_key": "base_color",
40
+ },
41
+ {
42
+ "name": "HDRI forest",
43
+ "icon": "assets/app/hdri_forest.png",
44
+ "render_key": "shaded_forest",
45
+ },
46
+ {
47
+ "name": "HDRI sunset",
48
+ "icon": "assets/app/hdri_sunset.png",
49
+ "render_key": "shaded_sunset",
50
+ },
51
+ {
52
+ "name": "HDRI courtyard",
53
+ "icon": "assets/app/hdri_courtyard.png",
54
+ "render_key": "shaded_courtyard",
55
+ },
56
  ]
57
  STEPS = 8
58
  DEFAULT_MODE = 3
 
326
def start_session(req: gr.Request):
    """Create a per-session scratch directory keyed by the Gradio session hash.

    Called on ``demo.load``; the directory is removed again by ``end_session``.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
329
+
330
+
331
def end_session(req: gr.Request):
    """Remove the per-session scratch directory when the Gradio session ends.

    Args:
        req: The Gradio request carrying the session hash used by start_session.
    """
    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # ignore_errors guards against the directory never having been created
    # (start_session failed / nothing generated) or already having been
    # removed by the delete_cache sweep — a bare rmtree would raise here.
    shutil.rmtree(user_dir, ignore_errors=True)
334
+
335
 
336
  def remove_background(input: Image.Image) -> Image.Image:
337
+ with tempfile.NamedTemporaryFile(suffix=".png") as f:
338
+ input = input.convert("RGB")
339
  input.save(f.name)
340
  output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
341
  output = Image.open(output)
 
348
  """
349
  # if has alpha channel, use it directly; otherwise, remove background
350
  has_alpha = False
351
+ if input.mode == "RGBA":
352
  alpha = np.array(input)[:, :, 3]
353
  if not np.all(alpha == 255):
354
  has_alpha = True
355
  max_size = max(input.size)
356
  scale = min(1, 1024 / max_size)
357
  if scale < 1:
358
+ input = input.resize(
359
+ (int(input.width * scale), int(input.height * scale)),
360
+ Image.Resampling.LANCZOS,
361
+ )
362
  if has_alpha:
363
  output = input
364
  else:
 
366
  output_np = np.array(output)
367
  alpha = output_np[:, :, 3]
368
  bbox = np.argwhere(alpha > 0.8 * 255)
369
+ bbox = (
370
+ np.min(bbox[:, 1]),
371
+ np.min(bbox[:, 0]),
372
+ np.max(bbox[:, 1]),
373
+ np.max(bbox[:, 0]),
374
+ )
375
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
376
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
377
  size = int(size * 1)
378
+ bbox = (
379
+ center[0] - size // 2,
380
+ center[1] - size // 2,
381
+ center[0] + size // 2,
382
+ center[1] + size // 2,
383
+ )
384
  output = output.crop(bbox) # type: ignore
385
  output = np.array(output).astype(np.float32) / 255
386
  output = output[:, :, :3] * output[:, :, 3:4]
 
391
def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
    """Serialize generated latents into a CPU/numpy-only dict for gr.State.

    Args:
        latents: Tuple of (shape latent, texture latent, grid resolution).

    Returns:
        dict holding numpy feature/coordinate arrays plus the resolution,
        safe to keep in Gradio session state off the GPU.
    """
    shape_slat, tex_slat, res = latents
    state = {}
    state["shape_slat_feats"] = shape_slat.feats.cpu().numpy()
    state["tex_slat_feats"] = tex_slat.feats.cpu().numpy()
    # shape and texture latents share one coordinate set, so only one is stored
    state["coords"] = shape_slat.coords.cpu().numpy()
    state["res"] = res
    return state
399
+
400
+
401
def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
    """Rebuild GPU-resident latents from a dict produced by pack_state.

    Args:
        state: dict with numpy arrays under "shape_slat_feats",
            "tex_slat_feats", "coords" and an int under "res".

    Returns:
        Tuple of (shape latent, texture latent, grid resolution).
    """
    shape_feats = torch.from_numpy(state["shape_slat_feats"]).cuda()
    coords = torch.from_numpy(state["coords"]).cuda()
    shape_slat = SparseTensor(feats=shape_feats, coords=coords)
    # texture latent reuses the shape latent's coordinates via replace()
    tex_feats = torch.from_numpy(state["tex_slat_feats"]).cuda()
    tex_slat = shape_slat.replace(tex_feats)
    return shape_slat, tex_slat, state["res"]
408
 
409
 
410
  def get_seed(randomize_seed: bool, seed: int) -> int:
 
465
  return_latent=True,
466
  )
467
  mesh = outputs[0]
468
+ mesh.simplify(16777216) # nvdiffrast limit
469
+ images = render_utils.render_snapshot(
470
+ mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap
471
+ )
472
  state = pack_state(latents)
473
  torch.cuda.empty_cache()
474
+
475
  # --- HTML Construction ---
476
  # The Stack of 48 Images
477
  images_html = ""
 
479
  for s_idx in range(STEPS):
480
  # ID Naming Convention: view-m{mode}-s{step}
481
  unique_id = f"view-m{m_idx}-s{s_idx}"
482
+
483
  # Logic: Only Mode 0, Step 0 is visible initially
484
+ is_visible = m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP
485
  vis_class = "visible" if is_visible else ""
486
+
487
  # Image Source
488
+ img_base64 = image_to_base64(
489
+ Image.fromarray(images[mode["render_key"]][s_idx])
490
+ )
491
+
492
  # Render the Tag
493
  images_html += f"""
494
  <img id="{unique_id}"
 
496
  src="{img_base64}"
497
  loading="eager">
498
  """
499
+
500
  # Button Row HTML
501
  btns_html = ""
502
+ for idx, mode in enumerate(MODES):
503
  active_class = "active" if idx == DEFAULT_MODE else ""
504
  # Note: onclick calls the JS function defined in Head
505
  btns_html += f"""
506
+ <img src="{mode["icon_base64"]}"
507
  class="mode-btn {active_class}"
508
  onclick="selectMode({idx})"
509
+ title="{mode["name"]}">
510
  """
511
+
512
  # Assemble the full component
513
  full_html = f"""
514
  <div class="previewer-container">
 
536
  </div>
537
  </div>
538
  """
539
+
540
  return state, full_html
541
 
542
 
 
581
  now = datetime.now()
582
  timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
583
  os.makedirs(user_dir, exist_ok=True)
584
+ glb_path = os.path.join(user_dir, f"sample_{timestamp}.glb")
585
  glb.export(glb_path, extension_webp=True)
586
  torch.cuda.empty_cache()
587
  return glb_path, glb_path
 
593
  * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
594
  * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
595
  """)
596
+
597
  with gr.Row():
598
  with gr.Column(scale=1, min_width=360):
599
+ image_prompt = gr.Image(
600
+ label="Image Prompt",
601
+ format="png",
602
+ image_mode="RGBA",
603
+ type="pil",
604
+ height=400,
605
+ )
606
+
607
+ resolution = gr.Radio(
608
+ ["512", "1024", "1536"], label="Resolution", value="512"
609
+ )
610
  seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
611
  randomize_seed = gr.Checkbox(label="Randomize Seed", value=False)
612
+ decimation_target = gr.Slider(
613
+ 20000, 500000, label="Decimation Target", value=20000, step=10000
614
+ )
615
+ texture_size = gr.Slider(
616
+ 1024, 4096, label="Texture Size", value=1024, step=1024
617
+ )
618
+
619
  generate_btn = gr.Button("Generate")
620
+
621
+ with gr.Accordion(label="Advanced Settings", open=False):
622
  gr.Markdown("Stage 1: Sparse Structure Generation")
623
  with gr.Row():
624
+ ss_guidance_strength = gr.Slider(
625
+ 1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1
626
+ )
627
+ ss_guidance_rescale = gr.Slider(
628
+ 0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01
629
+ )
630
+ ss_sampling_steps = gr.Slider(
631
+ 1, 50, label="Sampling Steps", value=8, step=1
632
+ )
633
+ ss_rescale_t = gr.Slider(
634
+ 1.0, 6.0, label="Rescale T", value=5.0, step=0.1
635
+ )
636
  gr.Markdown("Stage 2: Shape Generation")
637
  with gr.Row():
638
+ shape_slat_guidance_strength = gr.Slider(
639
+ 1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1
640
+ )
641
+ shape_slat_guidance_rescale = gr.Slider(
642
+ 0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01
643
+ )
644
+ shape_slat_sampling_steps = gr.Slider(
645
+ 1, 50, label="Sampling Steps", value=8, step=1
646
+ )
647
+ shape_slat_rescale_t = gr.Slider(
648
+ 1.0, 6.0, label="Rescale T", value=3.0, step=0.1
649
+ )
650
  gr.Markdown("Stage 3: Material Generation")
651
  with gr.Row():
652
+ tex_slat_guidance_strength = gr.Slider(
653
+ 1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1
654
+ )
655
+ tex_slat_guidance_rescale = gr.Slider(
656
+ 0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01
657
+ )
658
+ tex_slat_sampling_steps = gr.Slider(
659
+ 1, 50, label="Sampling Steps", value=8, step=1
660
+ )
661
+ tex_slat_rescale_t = gr.Slider(
662
+ 1.0, 6.0, label="Rescale T", value=3.0, step=0.1
663
+ )
664
 
665
  with gr.Column(scale=10):
666
  with gr.Walkthrough(selected=0) as walkthrough:
667
  with gr.Step("Preview", id=0):
668
+ preview_output = gr.HTML(
669
+ empty_html,
670
+ label="3D Asset Preview",
671
+ show_label=True,
672
+ container=True,
673
+ )
674
  extract_btn = gr.Button("Extract GLB")
675
  with gr.Step("Extract", id=1):
676
+ glb_output = gr.Model3D(
677
+ label="Extracted GLB",
678
+ height=724,
679
+ show_label=True,
680
+ display_mode="solid",
681
+ clear_color=(0.25, 0.25, 0.25, 1.0),
682
+ )
683
  download_btn = gr.DownloadButton(label="Download GLB")
684
+ gr.Markdown(
685
+ "*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more and face count is limited.*"
686
+ )
687
+
688
  with gr.Column(scale=1, min_width=172):
689
  examples = gr.Examples(
690
  examples=[
691
+ f"assets/example_image/{image}"
692
  for image in os.listdir("assets/example_image")
693
  ],
694
  inputs=[image_prompt],
 
697
  run_on_click=True,
698
  examples_per_page=18,
699
  )
700
+
701
  output_buf = gr.State()
 
702
 
703
  # Handlers
704
  demo.load(start_session)
705
  demo.unload(end_session)
706
+
707
  image_prompt.upload(
708
  preprocess_image,
709
  inputs=[image_prompt],
 
714
  get_seed,
715
  inputs=[randomize_seed, seed],
716
  outputs=[seed],
717
+ ).then(lambda: gr.Walkthrough(selected=0), outputs=walkthrough).then(
 
 
718
  image_to_3d,
719
  inputs=[
720
+ image_prompt,
721
+ seed,
722
+ resolution,
723
+ ss_guidance_strength,
724
+ ss_guidance_rescale,
725
+ ss_sampling_steps,
726
+ ss_rescale_t,
727
+ shape_slat_guidance_strength,
728
+ shape_slat_guidance_rescale,
729
+ shape_slat_sampling_steps,
730
+ shape_slat_rescale_t,
731
+ tex_slat_guidance_strength,
732
+ tex_slat_guidance_rescale,
733
+ tex_slat_sampling_steps,
734
+ tex_slat_rescale_t,
735
  ],
736
  outputs=[output_buf, preview_output],
737
  )
738
+
739
+ extract_btn.click(lambda: gr.Walkthrough(selected=1), outputs=walkthrough).then(
 
 
740
  extract_glb,
741
  inputs=[output_buf, decimation_target, texture_size],
742
  outputs=[glb_output, download_btn],
743
  )
744
+
745
 
746
  # Launch the Gradio app
747
  if __name__ == "__main__":
 
750
  # Construct ui components
751
  btn_img_base64_strs = {}
752
  for i in range(len(MODES)):
753
+ icon = Image.open(MODES[i]["icon"])
754
+ MODES[i]["icon_base64"] = image_to_base64(icon)
755
 
756
  rmbg_client = Client("briaai/BRIA-RMBG-2.0")
757
+ pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
758
  pipeline.rembg_model = None
759
  pipeline.low_vram = False
760
  pipeline.cuda()
761
+
762
  envmap = {
763
+ "forest": EnvMap(
764
+ torch.tensor(
765
+ cv2.cvtColor(
766
+ cv2.imread("assets/hdri/forest.exr", cv2.IMREAD_UNCHANGED),
767
+ cv2.COLOR_BGR2RGB,
768
+ ),
769
+ dtype=torch.float32,
770
+ device="cuda",
771
+ )
772
+ ),
773
+ "sunset": EnvMap(
774
+ torch.tensor(
775
+ cv2.cvtColor(
776
+ cv2.imread("assets/hdri/sunset.exr", cv2.IMREAD_UNCHANGED),
777
+ cv2.COLOR_BGR2RGB,
778
+ ),
779
+ dtype=torch.float32,
780
+ device="cuda",
781
+ )
782
+ ),
783
+ "courtyard": EnvMap(
784
+ torch.tensor(
785
+ cv2.cvtColor(
786
+ cv2.imread("assets/hdri/courtyard.exr", cv2.IMREAD_UNCHANGED),
787
+ cv2.COLOR_BGR2RGB,
788
+ ),
789
+ dtype=torch.float32,
790
+ device="cuda",
791
+ )
792
+ ),
793
  }
794
+
795
+ # demo.launch(css=css, head=head)
796
  demo.launch(server_name="0.0.0.0", server_port=7860, css=css, head=head)
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -28,6 +28,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
28
  rembg_model (Callable): The model for removing background.
29
  low_vram (bool): Whether to use low-VRAM mode.
30
  """
 
31
  def __init__(
32
  self,
33
  models: dict[str, nn.Module] = None,
@@ -42,7 +43,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
42
  image_cond_model: Callable = None,
43
  rembg_model: Callable = None,
44
  low_vram: bool = True,
45
- default_pipeline_type: str = '1024_cascade',
46
  ):
47
  if models is None:
48
  return
@@ -60,12 +61,12 @@ class Trellis2ImageTo3DPipeline(Pipeline):
60
  self.low_vram = low_vram
61
  self.default_pipeline_type = default_pipeline_type
62
  self.pbr_attr_layout = {
63
- 'base_color': slice(0, 3),
64
- 'metallic': slice(3, 4),
65
- 'roughness': slice(4, 5),
66
- 'alpha': slice(5, 6),
67
  }
68
- self._device = 'cpu'
69
 
70
  @staticmethod
71
  def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline":
@@ -75,35 +76,51 @@ class Trellis2ImageTo3DPipeline(Pipeline):
75
  Args:
76
  path (str): The path to the model. Can be either local path or a Hugging Face repository.
77
  """
78
- pipeline = super(Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline).from_pretrained(path)
 
 
79
  new_pipeline = Trellis2ImageTo3DPipeline()
80
  new_pipeline.__dict__ = pipeline.__dict__
81
  args = pipeline._pretrained_args
82
 
83
- new_pipeline.sparse_structure_sampler = getattr(samplers, args['sparse_structure_sampler']['name'])(**args['sparse_structure_sampler']['args'])
84
- new_pipeline.sparse_structure_sampler_params = args['sparse_structure_sampler']['params']
 
 
 
 
 
 
 
 
 
85
 
86
- new_pipeline.shape_slat_sampler = getattr(samplers, args['shape_slat_sampler']['name'])(**args['shape_slat_sampler']['args'])
87
- new_pipeline.shape_slat_sampler_params = args['shape_slat_sampler']['params']
 
 
88
 
89
- new_pipeline.tex_slat_sampler = getattr(samplers, args['tex_slat_sampler']['name'])(**args['tex_slat_sampler']['args'])
90
- new_pipeline.tex_slat_sampler_params = args['tex_slat_sampler']['params']
91
 
92
- new_pipeline.shape_slat_normalization = args['shape_slat_normalization']
93
- new_pipeline.tex_slat_normalization = args['tex_slat_normalization']
 
 
 
 
94
 
95
- new_pipeline.image_cond_model = getattr(image_feature_extractor, args['image_cond_model']['name'])(**args['image_cond_model']['args'])
96
- new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
97
-
98
- new_pipeline.low_vram = args.get('low_vram', True)
99
- new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
100
  new_pipeline.pbr_attr_layout = {
101
- 'base_color': slice(0, 3),
102
- 'metallic': slice(3, 4),
103
- 'roughness': slice(4, 5),
104
- 'alpha': slice(5, 6),
105
  }
106
- new_pipeline._device = 'cpu'
107
 
108
  return new_pipeline
109
 
@@ -121,18 +138,21 @@ class Trellis2ImageTo3DPipeline(Pipeline):
121
  """
122
  # if has alpha channel, use it directly; otherwise, remove background
123
  has_alpha = False
124
- if input.mode == 'RGBA':
125
  alpha = np.array(input)[:, :, 3]
126
  if not np.all(alpha == 255):
127
  has_alpha = True
128
  max_size = max(input.size)
129
  scale = min(1, 1024 / max_size)
130
  if scale < 1:
131
- input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
 
 
 
132
  if has_alpha:
133
  output = input
134
  else:
135
- input = input.convert('RGB')
136
  if self.low_vram:
137
  self.rembg_model.to(self.device)
138
  output = self.rembg_model(input)
@@ -141,18 +161,33 @@ class Trellis2ImageTo3DPipeline(Pipeline):
141
  output_np = np.array(output)
142
  alpha = output_np[:, :, 3]
143
  bbox = np.argwhere(alpha > 0.8 * 255)
144
- bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
 
 
 
 
 
145
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
146
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
147
  size = int(size * 1)
148
- bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
 
 
 
 
 
149
  output = output.crop(bbox) # type: ignore
150
  output = np.array(output).astype(np.float32) / 255
151
  output = output[:, :, :3] * output[:, :, 3:4]
152
  output = Image.fromarray((output * 255).astype(np.uint8))
153
  return output
154
-
155
- def get_cond(self, image: Union[torch.Tensor, list[Image.Image]], resolution: int, include_neg_cond: bool = True) -> dict:
 
 
 
 
 
156
  """
157
  Get the conditioning information for the model.
158
 
@@ -169,11 +204,11 @@ class Trellis2ImageTo3DPipeline(Pipeline):
169
  if self.low_vram:
170
  self.image_cond_model.cpu()
171
  if not include_neg_cond:
172
- return {'cond': cond}
173
  neg_cond = torch.zeros_like(cond)
174
  return {
175
- 'cond': cond,
176
- 'neg_cond': neg_cond,
177
  }
178
 
179
  def sample_sparse_structure(
@@ -185,7 +220,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
185
  ) -> torch.Tensor:
186
  """
187
  Sample sparse structures with the given conditioning.
188
-
189
  Args:
190
  cond (dict): The conditioning information.
191
  resolution (int): The resolution of the sparse structure.
@@ -193,7 +228,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
193
  sampler_params (dict): Additional parameters for the sampler.
194
  """
195
  # Sample sparse structure latent
196
- flow_model = self.models['sparse_structure_flow_model']
197
  reso = flow_model.resolution
198
  in_channels = flow_model.in_channels
199
  noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device)
@@ -210,17 +245,19 @@ class Trellis2ImageTo3DPipeline(Pipeline):
210
  ).samples
211
  if self.low_vram:
212
  flow_model.cpu()
213
-
214
  # Decode sparse structure latent
215
- decoder = self.models['sparse_structure_decoder']
216
  if self.low_vram:
217
  decoder.to(self.device)
218
- decoded = decoder(z_s)>0
219
  if self.low_vram:
220
  decoder.cpu()
221
  if resolution != decoded.shape[2]:
222
  ratio = decoded.shape[2] // resolution
223
- decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
 
 
224
  coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
225
 
226
  return coords
@@ -234,7 +271,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
234
  ) -> SparseTensor:
235
  """
236
  Sample structured latent with the given conditioning.
237
-
238
  Args:
239
  cond (dict): The conditioning information.
240
  coords (torch.Tensor): The coordinates of the sparse structure.
@@ -259,12 +296,12 @@ class Trellis2ImageTo3DPipeline(Pipeline):
259
  if self.low_vram:
260
  flow_model.cpu()
261
 
262
- std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
263
- mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
264
  slat = slat * std + mean
265
-
266
  return slat
267
-
268
  def sample_shape_slat_cascade(
269
  self,
270
  lr_cond: dict,
@@ -279,7 +316,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
279
  ) -> SparseTensor:
280
  """
281
  Sample structured latent with the given conditioning.
282
-
283
  Args:
284
  cond (dict): The conditioning information.
285
  coords (torch.Tensor): The coordinates of the sparse structure.
@@ -287,7 +324,9 @@ class Trellis2ImageTo3DPipeline(Pipeline):
287
  """
288
  # LR
289
  noise = SparseTensor(
290
- feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(self.device),
 
 
291
  coords=coords,
292
  )
293
  sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
@@ -303,32 +342,39 @@ class Trellis2ImageTo3DPipeline(Pipeline):
303
  ).samples
304
  if self.low_vram:
305
  flow_model_lr.cpu()
306
- std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
307
- mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
308
  slat = slat * std + mean
309
-
310
  # Upsample
311
  if self.low_vram:
312
- self.models['shape_slat_decoder'].to(self.device)
313
- self.models['shape_slat_decoder'].low_vram = True
314
- hr_coords = self.models['shape_slat_decoder'].upsample(slat, upsample_times=4)
315
  if self.low_vram:
316
- self.models['shape_slat_decoder'].cpu()
317
- self.models['shape_slat_decoder'].low_vram = False
318
  hr_resolution = resolution
319
  while True:
320
- quant_coords = torch.cat([
321
- hr_coords[:, :1],
322
- ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
323
- ], dim=1)
 
 
 
 
 
324
  coords = quant_coords.unique(dim=0)
325
  num_tokens = coords.shape[0]
326
  if num_tokens < max_num_tokens or hr_resolution == 1024:
327
  if hr_resolution != resolution:
328
- print(f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}.")
 
 
329
  break
330
  hr_resolution -= 128
331
-
332
  # Sample structured latent
333
  noise = SparseTensor(
334
  feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
@@ -348,10 +394,10 @@ class Trellis2ImageTo3DPipeline(Pipeline):
348
  if self.low_vram:
349
  flow_model.cpu()
350
 
351
- std = torch.tensor(self.shape_slat_normalization['std'])[None].to(slat.device)
352
- mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(slat.device)
353
  slat = slat * std + mean
354
-
355
  return slat, hr_resolution
356
 
357
  def decode_shape_slat(
@@ -370,16 +416,16 @@ class Trellis2ImageTo3DPipeline(Pipeline):
370
  List[Mesh]: The decoded meshes.
371
  List[SparseTensor]: The decoded substructures.
372
  """
373
- self.models['shape_slat_decoder'].set_resolution(resolution)
374
  if self.low_vram:
375
- self.models['shape_slat_decoder'].to(self.device)
376
- self.models['shape_slat_decoder'].low_vram = True
377
- ret = self.models['shape_slat_decoder'](slat, return_subs=True)
378
  if self.low_vram:
379
- self.models['shape_slat_decoder'].cpu()
380
- self.models['shape_slat_decoder'].low_vram = False
381
  return ret
382
-
383
  def sample_tex_slat(
384
  self,
385
  cond: dict,
@@ -389,19 +435,31 @@ class Trellis2ImageTo3DPipeline(Pipeline):
389
  ) -> SparseTensor:
390
  """
391
  Sample structured latent with the given conditioning.
392
-
393
  Args:
394
  cond (dict): The conditioning information.
395
  shape_slat (SparseTensor): The structured latent for shape
396
  sampler_params (dict): Additional parameters for the sampler.
397
  """
398
  # Sample structured latent
399
- std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
400
- mean = torch.tensor(self.shape_slat_normalization['mean'])[None].to(shape_slat.device)
 
 
 
 
401
  shape_slat = (shape_slat - mean) / std
402
 
403
- in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
404
- noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
 
 
 
 
 
 
 
 
405
  sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
406
  if self.low_vram:
407
  flow_model.to(self.device)
@@ -417,10 +475,10 @@ class Trellis2ImageTo3DPipeline(Pipeline):
417
  if self.low_vram:
418
  flow_model.cpu()
419
 
420
- std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
421
- mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
422
  slat = slat * std + mean
423
-
424
  return slat
425
 
426
  def decode_tex_slat(
@@ -439,12 +497,12 @@ class Trellis2ImageTo3DPipeline(Pipeline):
439
  List[SparseTensor]: The decoded texture voxels
440
  """
441
  if self.low_vram:
442
- self.models['tex_slat_decoder'].to(self.device)
443
- ret = self.models['tex_slat_decoder'](slat, guide_subs=subs) * 0.5 + 0.5
444
  if self.low_vram:
445
- self.models['tex_slat_decoder'].cpu()
446
  return ret
447
-
448
  @torch.no_grad()
449
  def decode_latent(
450
  self,
@@ -467,17 +525,18 @@ class Trellis2ImageTo3DPipeline(Pipeline):
467
  m.fill_holes()
468
  out_mesh.append(
469
  MeshWithVoxel(
470
- m.vertices, m.faces,
471
- origin = [-0.5, -0.5, -0.5],
472
- voxel_size = 1 / resolution,
473
- coords = v.coords[:, 1:],
474
- attrs = v.feats,
475
- voxel_shape = torch.Size([*v.shape, *v.spatial_shape]),
476
- layout=self.pbr_attr_layout
 
477
  )
478
  )
479
  return out_mesh
480
-
481
  @torch.no_grad()
482
  def run(
483
  self,
@@ -509,76 +568,117 @@ class Trellis2ImageTo3DPipeline(Pipeline):
509
  """
510
  # Check pipeline type
511
  pipeline_type = pipeline_type or self.default_pipeline_type
512
- if pipeline_type == '512':
513
- assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
514
- assert 'tex_slat_flow_model_512' in self.models, "No 512 resolution texture SLat flow model found."
515
- elif pipeline_type == '1024':
516
- assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
517
- assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
518
- elif pipeline_type == '1024_cascade':
519
- assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
520
- assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
521
- assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
522
- elif pipeline_type == '1536_cascade':
523
- assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
524
- assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
525
- assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526
  else:
527
  raise ValueError(f"Invalid pipeline type: {pipeline_type}")
528
-
529
  if preprocess_image:
530
  image = self.preprocess_image(image)
531
  torch.manual_seed(seed)
532
  cond_512 = self.get_cond([image], 512)
533
- cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
534
- ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
 
 
535
  coords = self.sample_sparse_structure(
536
- cond_512, ss_res,
537
- num_samples, sparse_structure_sampler_params
538
  )
539
- if pipeline_type == '512':
540
  shape_slat = self.sample_shape_slat(
541
- cond_512, self.models['shape_slat_flow_model_512'],
542
- coords, shape_slat_sampler_params
 
 
543
  )
544
  tex_slat = self.sample_tex_slat(
545
- cond_512, self.models['tex_slat_flow_model_512'],
546
- shape_slat, tex_slat_sampler_params
 
 
547
  )
548
  res = 512
549
- elif pipeline_type == '1024':
550
  shape_slat = self.sample_shape_slat(
551
- cond_1024, self.models['shape_slat_flow_model_1024'],
552
- coords, shape_slat_sampler_params
 
 
553
  )
554
  tex_slat = self.sample_tex_slat(
555
- cond_1024, self.models['tex_slat_flow_model_1024'],
556
- shape_slat, tex_slat_sampler_params
 
 
557
  )
558
  res = 1024
559
- elif pipeline_type == '1024_cascade':
560
  shape_slat, res = self.sample_shape_slat_cascade(
561
- cond_512, cond_1024,
562
- self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
563
- 512, 1024,
564
- coords, shape_slat_sampler_params,
565
- max_num_tokens
 
 
 
 
566
  )
567
  tex_slat = self.sample_tex_slat(
568
- cond_1024, self.models['tex_slat_flow_model_1024'],
569
- shape_slat, tex_slat_sampler_params
 
 
570
  )
571
- elif pipeline_type == '1536_cascade':
572
  shape_slat, res = self.sample_shape_slat_cascade(
573
- cond_512, cond_1024,
574
- self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
575
- 512, 1536,
576
- coords, shape_slat_sampler_params,
577
- max_num_tokens
 
 
 
 
578
  )
579
  tex_slat = self.sample_tex_slat(
580
- cond_1024, self.models['tex_slat_flow_model_1024'],
581
- shape_slat, tex_slat_sampler_params
 
 
582
  )
583
  torch.cuda.empty_cache()
584
  out_mesh = self.decode_latent(shape_slat, tex_slat, res)
 
28
  rembg_model (Callable): The model for removing background.
29
  low_vram (bool): Whether to use low-VRAM mode.
30
  """
31
+
32
  def __init__(
33
  self,
34
  models: dict[str, nn.Module] = None,
 
43
  image_cond_model: Callable = None,
44
  rembg_model: Callable = None,
45
  low_vram: bool = True,
46
+ default_pipeline_type: str = "1024_cascade",
47
  ):
48
  if models is None:
49
  return
 
61
  self.low_vram = low_vram
62
  self.default_pipeline_type = default_pipeline_type
63
  self.pbr_attr_layout = {
64
+ "base_color": slice(0, 3),
65
+ "metallic": slice(3, 4),
66
+ "roughness": slice(4, 5),
67
+ "alpha": slice(5, 6),
68
  }
69
+ self._device = "cpu"
70
 
71
  @staticmethod
72
  def from_pretrained(path: str) -> "Trellis2ImageTo3DPipeline":
 
76
  Args:
77
  path (str): The path to the model. Can be either local path or a Hugging Face repository.
78
  """
79
+ pipeline = super(
80
+ Trellis2ImageTo3DPipeline, Trellis2ImageTo3DPipeline
81
+ ).from_pretrained(path)
82
  new_pipeline = Trellis2ImageTo3DPipeline()
83
  new_pipeline.__dict__ = pipeline.__dict__
84
  args = pipeline._pretrained_args
85
 
86
+ new_pipeline.sparse_structure_sampler = getattr(
87
+ samplers, args["sparse_structure_sampler"]["name"]
88
+ )(**args["sparse_structure_sampler"]["args"])
89
+ new_pipeline.sparse_structure_sampler_params = args["sparse_structure_sampler"][
90
+ "params"
91
+ ]
92
+
93
+ new_pipeline.shape_slat_sampler = getattr(
94
+ samplers, args["shape_slat_sampler"]["name"]
95
+ )(**args["shape_slat_sampler"]["args"])
96
+ new_pipeline.shape_slat_sampler_params = args["shape_slat_sampler"]["params"]
97
 
98
+ new_pipeline.tex_slat_sampler = getattr(
99
+ samplers, args["tex_slat_sampler"]["name"]
100
+ )(**args["tex_slat_sampler"]["args"])
101
+ new_pipeline.tex_slat_sampler_params = args["tex_slat_sampler"]["params"]
102
 
103
+ new_pipeline.shape_slat_normalization = args["shape_slat_normalization"]
104
+ new_pipeline.tex_slat_normalization = args["tex_slat_normalization"]
105
 
106
+ new_pipeline.image_cond_model = getattr(
107
+ image_feature_extractor, args["image_cond_model"]["name"]
108
+ )(**args["image_cond_model"]["args"])
109
+ new_pipeline.rembg_model = getattr(rembg, args["rembg_model"]["name"])(
110
+ **args["rembg_model"]["args"]
111
+ )
112
 
113
+ new_pipeline.low_vram = args.get("low_vram", True)
114
+ new_pipeline.default_pipeline_type = args.get(
115
+ "default_pipeline_type", "1024_cascade"
116
+ )
 
117
  new_pipeline.pbr_attr_layout = {
118
+ "base_color": slice(0, 3),
119
+ "metallic": slice(3, 4),
120
+ "roughness": slice(4, 5),
121
+ "alpha": slice(5, 6),
122
  }
123
+ new_pipeline._device = "cpu"
124
 
125
  return new_pipeline
126
 
 
138
  """
139
  # if has alpha channel, use it directly; otherwise, remove background
140
  has_alpha = False
141
+ if input.mode == "RGBA":
142
  alpha = np.array(input)[:, :, 3]
143
  if not np.all(alpha == 255):
144
  has_alpha = True
145
  max_size = max(input.size)
146
  scale = min(1, 1024 / max_size)
147
  if scale < 1:
148
+ input = input.resize(
149
+ (int(input.width * scale), int(input.height * scale)),
150
+ Image.Resampling.LANCZOS,
151
+ )
152
  if has_alpha:
153
  output = input
154
  else:
155
+ input = input.convert("RGB")
156
  if self.low_vram:
157
  self.rembg_model.to(self.device)
158
  output = self.rembg_model(input)
 
161
  output_np = np.array(output)
162
  alpha = output_np[:, :, 3]
163
  bbox = np.argwhere(alpha > 0.8 * 255)
164
+ bbox = (
165
+ np.min(bbox[:, 1]),
166
+ np.min(bbox[:, 0]),
167
+ np.max(bbox[:, 1]),
168
+ np.max(bbox[:, 0]),
169
+ )
170
  center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
171
  size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
172
  size = int(size * 1)
173
+ bbox = (
174
+ center[0] - size // 2,
175
+ center[1] - size // 2,
176
+ center[0] + size // 2,
177
+ center[1] + size // 2,
178
+ )
179
  output = output.crop(bbox) # type: ignore
180
  output = np.array(output).astype(np.float32) / 255
181
  output = output[:, :, :3] * output[:, :, 3:4]
182
  output = Image.fromarray((output * 255).astype(np.uint8))
183
  return output
184
+
185
+ def get_cond(
186
+ self,
187
+ image: Union[torch.Tensor, list[Image.Image]],
188
+ resolution: int,
189
+ include_neg_cond: bool = True,
190
+ ) -> dict:
191
  """
192
  Get the conditioning information for the model.
193
 
 
204
  if self.low_vram:
205
  self.image_cond_model.cpu()
206
  if not include_neg_cond:
207
+ return {"cond": cond}
208
  neg_cond = torch.zeros_like(cond)
209
  return {
210
+ "cond": cond,
211
+ "neg_cond": neg_cond,
212
  }
213
 
214
  def sample_sparse_structure(
 
220
  ) -> torch.Tensor:
221
  """
222
  Sample sparse structures with the given conditioning.
223
+
224
  Args:
225
  cond (dict): The conditioning information.
226
  resolution (int): The resolution of the sparse structure.
 
228
  sampler_params (dict): Additional parameters for the sampler.
229
  """
230
  # Sample sparse structure latent
231
+ flow_model = self.models["sparse_structure_flow_model"]
232
  reso = flow_model.resolution
233
  in_channels = flow_model.in_channels
234
  noise = torch.randn(num_samples, in_channels, reso, reso, reso).to(self.device)
 
245
  ).samples
246
  if self.low_vram:
247
  flow_model.cpu()
248
+
249
  # Decode sparse structure latent
250
+ decoder = self.models["sparse_structure_decoder"]
251
  if self.low_vram:
252
  decoder.to(self.device)
253
+ decoded = decoder(z_s) > 0
254
  if self.low_vram:
255
  decoder.cpu()
256
  if resolution != decoded.shape[2]:
257
  ratio = decoded.shape[2] // resolution
258
+ decoded = (
259
+ torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
260
+ )
261
  coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
262
 
263
  return coords
 
271
  ) -> SparseTensor:
272
  """
273
  Sample structured latent with the given conditioning.
274
+
275
  Args:
276
  cond (dict): The conditioning information.
277
  coords (torch.Tensor): The coordinates of the sparse structure.
 
296
  if self.low_vram:
297
  flow_model.cpu()
298
 
299
+ std = torch.tensor(self.shape_slat_normalization["std"])[None].to(slat.device)
300
+ mean = torch.tensor(self.shape_slat_normalization["mean"])[None].to(slat.device)
301
  slat = slat * std + mean
302
+
303
  return slat
304
+
305
  def sample_shape_slat_cascade(
306
  self,
307
  lr_cond: dict,
 
316
  ) -> SparseTensor:
317
  """
318
  Sample structured latent with the given conditioning.
319
+
320
  Args:
321
  cond (dict): The conditioning information.
322
  coords (torch.Tensor): The coordinates of the sparse structure.
 
324
  """
325
  # LR
326
  noise = SparseTensor(
327
+ feats=torch.randn(coords.shape[0], flow_model_lr.in_channels).to(
328
+ self.device
329
+ ),
330
  coords=coords,
331
  )
332
  sampler_params = {**self.shape_slat_sampler_params, **sampler_params}
 
342
  ).samples
343
  if self.low_vram:
344
  flow_model_lr.cpu()
345
+ std = torch.tensor(self.shape_slat_normalization["std"])[None].to(slat.device)
346
+ mean = torch.tensor(self.shape_slat_normalization["mean"])[None].to(slat.device)
347
  slat = slat * std + mean
348
+
349
  # Upsample
350
  if self.low_vram:
351
+ self.models["shape_slat_decoder"].to(self.device)
352
+ self.models["shape_slat_decoder"].low_vram = True
353
+ hr_coords = self.models["shape_slat_decoder"].upsample(slat, upsample_times=4)
354
  if self.low_vram:
355
+ self.models["shape_slat_decoder"].cpu()
356
+ self.models["shape_slat_decoder"].low_vram = False
357
  hr_resolution = resolution
358
  while True:
359
+ quant_coords = torch.cat(
360
+ [
361
+ hr_coords[:, :1],
362
+ (
363
+ (hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)
364
+ ).int(),
365
+ ],
366
+ dim=1,
367
+ )
368
  coords = quant_coords.unique(dim=0)
369
  num_tokens = coords.shape[0]
370
  if num_tokens < max_num_tokens or hr_resolution == 1024:
371
  if hr_resolution != resolution:
372
+ print(
373
+ f"Due to the limited number of tokens, the resolution is reduced to {hr_resolution}."
374
+ )
375
  break
376
  hr_resolution -= 128
377
+
378
  # Sample structured latent
379
  noise = SparseTensor(
380
  feats=torch.randn(coords.shape[0], flow_model.in_channels).to(self.device),
 
394
  if self.low_vram:
395
  flow_model.cpu()
396
 
397
+ std = torch.tensor(self.shape_slat_normalization["std"])[None].to(slat.device)
398
+ mean = torch.tensor(self.shape_slat_normalization["mean"])[None].to(slat.device)
399
  slat = slat * std + mean
400
+
401
  return slat, hr_resolution
402
 
403
  def decode_shape_slat(
 
416
  List[Mesh]: The decoded meshes.
417
  List[SparseTensor]: The decoded substructures.
418
  """
419
+ self.models["shape_slat_decoder"].set_resolution(resolution)
420
  if self.low_vram:
421
+ self.models["shape_slat_decoder"].to(self.device)
422
+ self.models["shape_slat_decoder"].low_vram = True
423
+ ret = self.models["shape_slat_decoder"](slat, return_subs=True)
424
  if self.low_vram:
425
+ self.models["shape_slat_decoder"].cpu()
426
+ self.models["shape_slat_decoder"].low_vram = False
427
  return ret
428
+
429
  def sample_tex_slat(
430
  self,
431
  cond: dict,
 
435
  ) -> SparseTensor:
436
  """
437
  Sample structured latent with the given conditioning.
438
+
439
  Args:
440
  cond (dict): The conditioning information.
441
  shape_slat (SparseTensor): The structured latent for shape
442
  sampler_params (dict): Additional parameters for the sampler.
443
  """
444
  # Sample structured latent
445
+ std = torch.tensor(self.shape_slat_normalization["std"])[None].to(
446
+ shape_slat.device
447
+ )
448
+ mean = torch.tensor(self.shape_slat_normalization["mean"])[None].to(
449
+ shape_slat.device
450
+ )
451
  shape_slat = (shape_slat - mean) / std
452
 
453
+ in_channels = (
454
+ flow_model.in_channels
455
+ if isinstance(flow_model, nn.Module)
456
+ else flow_model[0].in_channels
457
+ )
458
+ noise = shape_slat.replace(
459
+ feats=torch.randn(
460
+ shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]
461
+ ).to(self.device)
462
+ )
463
  sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
464
  if self.low_vram:
465
  flow_model.to(self.device)
 
475
  if self.low_vram:
476
  flow_model.cpu()
477
 
478
+ std = torch.tensor(self.tex_slat_normalization["std"])[None].to(slat.device)
479
+ mean = torch.tensor(self.tex_slat_normalization["mean"])[None].to(slat.device)
480
  slat = slat * std + mean
481
+
482
  return slat
483
 
484
  def decode_tex_slat(
 
497
  List[SparseTensor]: The decoded texture voxels
498
  """
499
  if self.low_vram:
500
+ self.models["tex_slat_decoder"].to(self.device)
501
+ ret = self.models["tex_slat_decoder"](slat, guide_subs=subs) * 0.5 + 0.5
502
  if self.low_vram:
503
+ self.models["tex_slat_decoder"].cpu()
504
  return ret
505
+
506
  @torch.no_grad()
507
  def decode_latent(
508
  self,
 
525
  m.fill_holes()
526
  out_mesh.append(
527
  MeshWithVoxel(
528
+ m.vertices,
529
+ m.faces,
530
+ origin=[-0.5, -0.5, -0.5],
531
+ voxel_size=1 / resolution,
532
+ coords=v.coords[:, 1:],
533
+ attrs=v.feats,
534
+ voxel_shape=torch.Size([*v.shape, *v.spatial_shape]),
535
+ layout=self.pbr_attr_layout,
536
  )
537
  )
538
  return out_mesh
539
+
540
  @torch.no_grad()
541
  def run(
542
  self,
 
568
  """
569
  # Check pipeline type
570
  pipeline_type = pipeline_type or self.default_pipeline_type
571
+ if pipeline_type == "512":
572
+ assert "shape_slat_flow_model_512" in self.models, (
573
+ "No 512 resolution shape SLat flow model found."
574
+ )
575
+ assert "tex_slat_flow_model_512" in self.models, (
576
+ "No 512 resolution texture SLat flow model found."
577
+ )
578
+ elif pipeline_type == "1024":
579
+ assert "shape_slat_flow_model_1024" in self.models, (
580
+ "No 1024 resolution shape SLat flow model found."
581
+ )
582
+ assert "tex_slat_flow_model_1024" in self.models, (
583
+ "No 1024 resolution texture SLat flow model found."
584
+ )
585
+ elif pipeline_type == "1024_cascade":
586
+ assert "shape_slat_flow_model_512" in self.models, (
587
+ "No 512 resolution shape SLat flow model found."
588
+ )
589
+ assert "shape_slat_flow_model_1024" in self.models, (
590
+ "No 1024 resolution shape SLat flow model found."
591
+ )
592
+ assert "tex_slat_flow_model_1024" in self.models, (
593
+ "No 1024 resolution texture SLat flow model found."
594
+ )
595
+ elif pipeline_type == "1536_cascade":
596
+ assert "shape_slat_flow_model_512" in self.models, (
597
+ "No 512 resolution shape SLat flow model found."
598
+ )
599
+ assert "shape_slat_flow_model_1024" in self.models, (
600
+ "No 1024 resolution shape SLat flow model found."
601
+ )
602
+ assert "tex_slat_flow_model_1024" in self.models, (
603
+ "No 1024 resolution texture SLat flow model found."
604
+ )
605
  else:
606
  raise ValueError(f"Invalid pipeline type: {pipeline_type}")
607
+
608
  if preprocess_image:
609
  image = self.preprocess_image(image)
610
  torch.manual_seed(seed)
611
  cond_512 = self.get_cond([image], 512)
612
+ cond_1024 = self.get_cond([image], 1024) if pipeline_type != "512" else None
613
+ ss_res = {"512": 32, "1024": 64, "1024_cascade": 32, "1536_cascade": 32}[
614
+ pipeline_type
615
+ ]
616
  coords = self.sample_sparse_structure(
617
+ cond_512, ss_res, num_samples, sparse_structure_sampler_params
 
618
  )
619
+ if pipeline_type == "512":
620
  shape_slat = self.sample_shape_slat(
621
+ cond_512,
622
+ self.models["shape_slat_flow_model_512"],
623
+ coords,
624
+ shape_slat_sampler_params,
625
  )
626
  tex_slat = self.sample_tex_slat(
627
+ cond_512,
628
+ self.models["tex_slat_flow_model_512"],
629
+ shape_slat,
630
+ tex_slat_sampler_params,
631
  )
632
  res = 512
633
+ elif pipeline_type == "1024":
634
  shape_slat = self.sample_shape_slat(
635
+ cond_1024,
636
+ self.models["shape_slat_flow_model_1024"],
637
+ coords,
638
+ shape_slat_sampler_params,
639
  )
640
  tex_slat = self.sample_tex_slat(
641
+ cond_1024,
642
+ self.models["tex_slat_flow_model_1024"],
643
+ shape_slat,
644
+ tex_slat_sampler_params,
645
  )
646
  res = 1024
647
+ elif pipeline_type == "1024_cascade":
648
  shape_slat, res = self.sample_shape_slat_cascade(
649
+ cond_512,
650
+ cond_1024,
651
+ self.models["shape_slat_flow_model_512"],
652
+ self.models["shape_slat_flow_model_1024"],
653
+ 512,
654
+ 1024,
655
+ coords,
656
+ shape_slat_sampler_params,
657
+ max_num_tokens,
658
  )
659
  tex_slat = self.sample_tex_slat(
660
+ cond_1024,
661
+ self.models["tex_slat_flow_model_1024"],
662
+ shape_slat,
663
+ tex_slat_sampler_params,
664
  )
665
+ elif pipeline_type == "1536_cascade":
666
  shape_slat, res = self.sample_shape_slat_cascade(
667
+ cond_512,
668
+ cond_1024,
669
+ self.models["shape_slat_flow_model_512"],
670
+ self.models["shape_slat_flow_model_1024"],
671
+ 512,
672
+ 1536,
673
+ coords,
674
+ shape_slat_sampler_params,
675
+ max_num_tokens,
676
  )
677
  tex_slat = self.sample_tex_slat(
678
+ cond_1024,
679
+ self.models["tex_slat_flow_model_1024"],
680
+ shape_slat,
681
+ tex_slat_sampler_params,
682
  )
683
  torch.cuda.empty_cache()
684
  out_mesh = self.decode_latent(shape_slat, tex_slat, res)