JeffreyXiang committed on
Commit
163fb7c
·
1 Parent(s): 65e9322
app.py CHANGED
@@ -4,6 +4,7 @@ import spaces
4
  import os
5
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
6
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
 
7
  os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
8
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
9
  from datetime import datetime
@@ -330,7 +331,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
330
 
331
  # Launch the Gradio app
332
  if __name__ == "__main__":
333
- pipeline = Trellis2ImageTo3DPipeline.from_pretrained('JeffreyXiang/TRELLIS.2-4B')
334
  pipeline.low_vram = False
335
  pipeline.cuda()
336
 
 
4
  import os
5
  os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
6
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
7
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
8
  os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
9
  os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
10
  from datetime import datetime
 
331
 
332
  # Launch the Gradio app
333
  if __name__ == "__main__":
334
+ pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
335
  pipeline.low_vram = False
336
  pipeline.cuda()
337
 
requirements.txt CHANGED
@@ -16,7 +16,7 @@ kornia==0.8.2
16
  timm==1.0.22
17
  git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
18
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
19
-
20
  https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
21
  https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
22
  https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
 
16
  timm==1.0.22
17
  git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
18
  https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
19
+ https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
20
  https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
21
  https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
22
  https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -43,7 +43,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
43
  image_cond_model: Callable = None,
44
  rembg_model: Callable = None,
45
  low_vram: bool = True,
46
- default_pipeline_type: str = '512->1024',
47
  ):
48
  if models is None:
49
  return
@@ -97,7 +97,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
97
  new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
98
 
99
  new_pipeline.low_vram = args.get('low_vram', True)
100
- new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '512->1024')
101
  new_pipeline.pbr_attr_layout = {
102
  'base_color': slice(0, 3),
103
  'metallic': slice(3, 4),
@@ -114,7 +114,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
114
  super().to(device)
115
  self.image_cond_model.to(device)
116
  self.rembg_model.to(device)
117
-
118
  @spaces.GPU()
119
  def remove_background(self, input: Image.Image) -> Image.Image:
120
  input = input.convert('RGB')
@@ -509,7 +509,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
509
  tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
510
  preprocess_image (bool): Whether to preprocess the image.
511
  return_latent (bool): Whether to return the latent codes.
512
- pipeline_type (str): The type of the pipeline. Options: '512', '1024', '512->1024', '512->1536'.
513
  max_num_tokens (int): The maximum number of tokens to use.
514
  """
515
  # Check pipeline type
@@ -520,11 +520,11 @@ class Trellis2ImageTo3DPipeline(Pipeline):
520
  elif pipeline_type == '1024':
521
  assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
522
  assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
523
- elif pipeline_type == '512->1024':
524
  assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
525
  assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
526
  assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
527
- elif pipeline_type == '512->1536':
528
  assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
529
  assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
530
  assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
@@ -536,7 +536,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
536
  torch.manual_seed(seed)
537
  cond_512 = self.get_cond([image], 512)
538
  cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
539
- ss_res = {'512': 32, '1024': 64, '512->1024': 32, '512->1536': 32}[pipeline_type]
540
  coords = self.sample_sparse_structure(
541
  cond_512, ss_res,
542
  num_samples, sparse_structure_sampler_params
@@ -561,7 +561,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
561
  shape_slat, tex_slat_sampler_params
562
  )
563
  res = 1024
564
- elif pipeline_type == '512->1024':
565
  shape_slat, res = self.sample_shape_slat_cascade(
566
  cond_512, cond_1024,
567
  self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
@@ -573,7 +573,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
573
  cond_1024, self.models['tex_slat_flow_model_1024'],
574
  shape_slat, tex_slat_sampler_params
575
  )
576
- elif pipeline_type == '512->1536':
577
  shape_slat, res = self.sample_shape_slat_cascade(
578
  cond_512, cond_1024,
579
  self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
 
43
  image_cond_model: Callable = None,
44
  rembg_model: Callable = None,
45
  low_vram: bool = True,
46
+ default_pipeline_type: str = '1024_cascade',
47
  ):
48
  if models is None:
49
  return
 
97
  new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
98
 
99
  new_pipeline.low_vram = args.get('low_vram', True)
100
+ new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
101
  new_pipeline.pbr_attr_layout = {
102
  'base_color': slice(0, 3),
103
  'metallic': slice(3, 4),
 
114
  super().to(device)
115
  self.image_cond_model.to(device)
116
  self.rembg_model.to(device)
117
+
118
  @spaces.GPU()
119
  def remove_background(self, input: Image.Image) -> Image.Image:
120
  input = input.convert('RGB')
 
509
  tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
510
  preprocess_image (bool): Whether to preprocess the image.
511
  return_latent (bool): Whether to return the latent codes.
512
+ pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
513
  max_num_tokens (int): The maximum number of tokens to use.
514
  """
515
  # Check pipeline type
 
520
  elif pipeline_type == '1024':
521
  assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
522
  assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
523
+ elif pipeline_type == '1024_cascade':
524
  assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
525
  assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
526
  assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
527
+ elif pipeline_type == '1536_cascade':
528
  assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
529
  assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
530
  assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
 
536
  torch.manual_seed(seed)
537
  cond_512 = self.get_cond([image], 512)
538
  cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
539
+ ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
540
  coords = self.sample_sparse_structure(
541
  cond_512, ss_res,
542
  num_samples, sparse_structure_sampler_params
 
561
  shape_slat, tex_slat_sampler_params
562
  )
563
  res = 1024
564
+ elif pipeline_type == '1024_cascade':
565
  shape_slat, res = self.sample_shape_slat_cascade(
566
  cond_512, cond_1024,
567
  self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
 
573
  cond_1024, self.models['tex_slat_flow_model_1024'],
574
  shape_slat, tex_slat_sampler_params
575
  )
576
+ elif pipeline_type == '1536_cascade':
577
  shape_slat, res = self.sample_shape_slat_cascade(
578
  cond_512, cond_1024,
579
  self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],