Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
163fb7c
1 Parent(s): 65e9322
update
Browse files
- app.py +2 -1
- requirements.txt +1 -1
- trellis2/pipelines/trellis2_image_to_3d.py +9 -9
app.py
CHANGED
|
@@ -4,6 +4,7 @@ import spaces
|
|
| 4 |
import os
|
| 5 |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
|
| 6 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
|
|
|
| 7 |
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
|
| 8 |
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
|
| 9 |
from datetime import datetime
|
|
@@ -330,7 +331,7 @@ with gr.Blocks(delete_cache=(600, 600)) as demo:
|
|
| 330 |
|
| 331 |
# Launch the Gradio app
|
| 332 |
if __name__ == "__main__":
|
| 333 |
-
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('…')  [removed line truncated in page extraction; original argument not recoverable from this view]
|
| 334 |
pipeline.low_vram = False
|
| 335 |
pipeline.cuda()
|
| 336 |
|
|
|
|
| 4 |
import os
|
| 5 |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
|
| 6 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 7 |
+
os.environ["ATTN_BACKEND"] = "flash_attn_3"
|
| 8 |
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
|
| 9 |
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
|
| 10 |
from datetime import datetime
|
|
|
|
| 331 |
|
| 332 |
# Launch the Gradio app
|
| 333 |
if __name__ == "__main__":
|
| 334 |
+
pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
|
| 335 |
pipeline.low_vram = False
|
| 336 |
pipeline.cuda()
|
| 337 |
|
requirements.txt
CHANGED
|
@@ -16,7 +16,7 @@ kornia==0.8.2
|
|
| 16 |
timm==1.0.22
|
| 17 |
git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
|
| 18 |
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 19 |
-
|
| 20 |
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
|
| 21 |
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
|
| 22 |
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
|
|
|
|
| 16 |
timm==1.0.22
|
| 17 |
git+https://github.com/EasternJournalist/utils3d.git@9a4eb15e4021b67b12c460c7057d642626897ec8
|
| 18 |
https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 19 |
+
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
|
| 20 |
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/cumesh-0.0.1-cp310-cp310-linux_x86_64.whl
|
| 21 |
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/flex_gemm-0.0.1-cp310-cp310-linux_x86_64.whl
|
| 22 |
https://github.com/JeffreyXiang/Storages/releases/download/Space_Wheels_251210/o_voxel-0.0.1-cp310-cp310-linux_x86_64.whl
|
trellis2/pipelines/trellis2_image_to_3d.py
CHANGED
|
@@ -43,7 +43,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 43 |
image_cond_model: Callable = None,
|
| 44 |
rembg_model: Callable = None,
|
| 45 |
low_vram: bool = True,
|
| 46 |
-
default_pipeline_type: str = '…',  [removed line truncated in page extraction; original default not recoverable from this view]
|
| 47 |
):
|
| 48 |
if models is None:
|
| 49 |
return
|
|
@@ -97,7 +97,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 97 |
new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
|
| 98 |
|
| 99 |
new_pipeline.low_vram = args.get('low_vram', True)
|
| 100 |
-
new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '…')  [removed line truncated in page extraction; original fallback value not recoverable from this view]
|
| 101 |
new_pipeline.pbr_attr_layout = {
|
| 102 |
'base_color': slice(0, 3),
|
| 103 |
'metallic': slice(3, 4),
|
|
@@ -114,7 +114,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 114 |
super().to(device)
|
| 115 |
self.image_cond_model.to(device)
|
| 116 |
self.rembg_model.to(device)
|
| 117 |
-
|
| 118 |
@spaces.GPU()
|
| 119 |
def remove_background(self, input: Image.Image) -> Image.Image:
|
| 120 |
input = input.convert('RGB')
|
|
@@ -509,7 +509,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 509 |
tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
|
| 510 |
preprocess_image (bool): Whether to preprocess the image.
|
| 511 |
return_latent (bool): Whether to return the latent codes.
|
| 512 |
-
pipeline_type (str): The type of the pipeline. Options: '512', '1024', '…'  [removed line truncated in page extraction; remaining options not recoverable from this view]
|
| 513 |
max_num_tokens (int): The maximum number of tokens to use.
|
| 514 |
"""
|
| 515 |
# Check pipeline type
|
|
@@ -520,11 +520,11 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 520 |
elif pipeline_type == '1024':
|
| 521 |
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
|
| 522 |
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
|
| 523 |
-
elif pipeline_type == '…':  [removed line truncated in page extraction; original comparison value not recoverable from this view]
|
| 524 |
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
|
| 525 |
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
|
| 526 |
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
|
| 527 |
-
elif pipeline_type == '…':  [removed line truncated in page extraction; original comparison value not recoverable from this view]
|
| 528 |
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
|
| 529 |
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
|
| 530 |
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
|
|
@@ -536,7 +536,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 536 |
torch.manual_seed(seed)
|
| 537 |
cond_512 = self.get_cond([image], 512)
|
| 538 |
cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
|
| 539 |
-
ss_res = {'512': 32, '1024': 64, '…'}  [removed line truncated in page extraction; remaining dictionary entries not recoverable from this view]
|
| 540 |
coords = self.sample_sparse_structure(
|
| 541 |
cond_512, ss_res,
|
| 542 |
num_samples, sparse_structure_sampler_params
|
|
@@ -561,7 +561,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 561 |
shape_slat, tex_slat_sampler_params
|
| 562 |
)
|
| 563 |
res = 1024
|
| 564 |
-
elif pipeline_type == '…':  [removed line truncated in page extraction; original comparison value not recoverable from this view]
|
| 565 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 566 |
cond_512, cond_1024,
|
| 567 |
self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
|
|
@@ -573,7 +573,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 573 |
cond_1024, self.models['tex_slat_flow_model_1024'],
|
| 574 |
shape_slat, tex_slat_sampler_params
|
| 575 |
)
|
| 576 |
-
elif pipeline_type == '…':  [removed line truncated in page extraction; original comparison value not recoverable from this view]
|
| 577 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 578 |
cond_512, cond_1024,
|
| 579 |
self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
|
|
|
|
| 43 |
image_cond_model: Callable = None,
|
| 44 |
rembg_model: Callable = None,
|
| 45 |
low_vram: bool = True,
|
| 46 |
+
default_pipeline_type: str = '1024_cascade',
|
| 47 |
):
|
| 48 |
if models is None:
|
| 49 |
return
|
|
|
|
| 97 |
new_pipeline.rembg_model = getattr(rembg, args['rembg_model']['name'])(**args['rembg_model']['args'])
|
| 98 |
|
| 99 |
new_pipeline.low_vram = args.get('low_vram', True)
|
| 100 |
+
new_pipeline.default_pipeline_type = args.get('default_pipeline_type', '1024_cascade')
|
| 101 |
new_pipeline.pbr_attr_layout = {
|
| 102 |
'base_color': slice(0, 3),
|
| 103 |
'metallic': slice(3, 4),
|
|
|
|
| 114 |
super().to(device)
|
| 115 |
self.image_cond_model.to(device)
|
| 116 |
self.rembg_model.to(device)
|
| 117 |
+
|
| 118 |
@spaces.GPU()
|
| 119 |
def remove_background(self, input: Image.Image) -> Image.Image:
|
| 120 |
input = input.convert('RGB')
|
|
|
|
| 509 |
tex_slat_sampler_params (dict): Additional parameters for the texture SLat sampler.
|
| 510 |
preprocess_image (bool): Whether to preprocess the image.
|
| 511 |
return_latent (bool): Whether to return the latent codes.
|
| 512 |
+
pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
|
| 513 |
max_num_tokens (int): The maximum number of tokens to use.
|
| 514 |
"""
|
| 515 |
# Check pipeline type
|
|
|
|
| 520 |
elif pipeline_type == '1024':
|
| 521 |
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
|
| 522 |
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
|
| 523 |
+
elif pipeline_type == '1024_cascade':
|
| 524 |
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
|
| 525 |
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
|
| 526 |
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
|
| 527 |
+
elif pipeline_type == '1536_cascade':
|
| 528 |
assert 'shape_slat_flow_model_512' in self.models, "No 512 resolution shape SLat flow model found."
|
| 529 |
assert 'shape_slat_flow_model_1024' in self.models, "No 1024 resolution shape SLat flow model found."
|
| 530 |
assert 'tex_slat_flow_model_1024' in self.models, "No 1024 resolution texture SLat flow model found."
|
|
|
|
| 536 |
torch.manual_seed(seed)
|
| 537 |
cond_512 = self.get_cond([image], 512)
|
| 538 |
cond_1024 = self.get_cond([image], 1024) if pipeline_type != '512' else None
|
| 539 |
+
ss_res = {'512': 32, '1024': 64, '1024_cascade': 32, '1536_cascade': 32}[pipeline_type]
|
| 540 |
coords = self.sample_sparse_structure(
|
| 541 |
cond_512, ss_res,
|
| 542 |
num_samples, sparse_structure_sampler_params
|
|
|
|
| 561 |
shape_slat, tex_slat_sampler_params
|
| 562 |
)
|
| 563 |
res = 1024
|
| 564 |
+
elif pipeline_type == '1024_cascade':
|
| 565 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 566 |
cond_512, cond_1024,
|
| 567 |
self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
|
|
|
|
| 573 |
cond_1024, self.models['tex_slat_flow_model_1024'],
|
| 574 |
shape_slat, tex_slat_sampler_params
|
| 575 |
)
|
| 576 |
+
elif pipeline_type == '1536_cascade':
|
| 577 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 578 |
cond_512, cond_1024,
|
| 579 |
self.models['shape_slat_flow_model_512'], self.models['shape_slat_flow_model_1024'],
|