Spaces:
Running on Zero
Running on Zero
Fix multidiffusion CFG and add per-stage texture mode
Browse files
trellis2/pipelines/trellis2_image_to_3d.py
CHANGED
|
@@ -637,23 +637,36 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 637 |
|
| 638 |
elif mode == 'multidiffusion':
|
| 639 |
from .samplers import FlowEulerSampler
|
| 640 |
-
def _new_inference_model(self, model, x_t, t, cond,
|
| 641 |
-
|
| 642 |
-
|
| 643 |
-
|
| 644 |
-
#
|
| 645 |
-
# pred = sum(preds) / len(preds)
|
| 646 |
-
# neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
|
| 647 |
-
# return (1 + cfg_strength) * pred - cfg_strength * neg_pred
|
| 648 |
-
# else:
|
| 649 |
-
|
| 650 |
-
# Filter out guidance-related kwargs that the base sampler doesn't handle
|
| 651 |
-
filtered_kwargs = {k: v for k, v in kwargs.items()
|
| 652 |
-
if k not in ('neg_cond', 'guidance_strength', 'guidance_interval', 'guidance_rescale')}
|
| 653 |
preds = []
|
| 654 |
for i in range(len(cond)):
|
| 655 |
-
preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **
|
| 656 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
return pred
|
| 658 |
|
| 659 |
else:
|
|
@@ -679,7 +692,8 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 679 |
return_latent: bool = False,
|
| 680 |
pipeline_type: Optional[str] = None,
|
| 681 |
max_num_tokens: int = 49152,
|
| 682 |
-
mode: Literal['stochastic', 'multidiffusion'] = '
|
|
|
|
| 683 |
) -> List[MeshWithVoxel]:
|
| 684 |
"""
|
| 685 |
Run the multi-image pipeline.
|
|
@@ -695,8 +709,10 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 695 |
return_latent (bool): Whether to return the latent codes.
|
| 696 |
pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
|
| 697 |
max_num_tokens (int): The maximum number of tokens to use.
|
| 698 |
-
mode: The multi-image conditioning mode.
|
|
|
|
| 699 |
"""
|
|
|
|
| 700 |
# Check pipeline type
|
| 701 |
pipeline_type = pipeline_type or self.default_pipeline_type
|
| 702 |
if pipeline_type == '512':
|
|
@@ -740,7 +756,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 740 |
if pipeline_type == '512':
|
| 741 |
with (
|
| 742 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 743 |
-
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=
|
| 744 |
):
|
| 745 |
shape_slat = self.sample_shape_slat(
|
| 746 |
cond_512, self.models['shape_slat_flow_model_512'],
|
|
@@ -754,7 +770,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 754 |
elif pipeline_type == '1024':
|
| 755 |
with (
|
| 756 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 757 |
-
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=
|
| 758 |
):
|
| 759 |
shape_slat = self.sample_shape_slat(
|
| 760 |
cond_1024, self.models['shape_slat_flow_model_1024'],
|
|
@@ -768,7 +784,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 768 |
elif pipeline_type == '1024_cascade':
|
| 769 |
with (
|
| 770 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 771 |
-
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=
|
| 772 |
):
|
| 773 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 774 |
cond_512, cond_1024,
|
|
@@ -784,7 +800,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 784 |
elif pipeline_type == '1536_cascade':
|
| 785 |
with (
|
| 786 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 787 |
-
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=
|
| 788 |
):
|
| 789 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 790 |
cond_512, cond_1024,
|
|
|
|
| 637 |
|
| 638 |
elif mode == 'multidiffusion':
|
| 639 |
from .samplers import FlowEulerSampler
|
| 640 |
+
def _new_inference_model(self, model, x_t, t, cond,
|
| 641 |
+
neg_cond=None, guidance_strength=1,
|
| 642 |
+
guidance_interval=(0, 1), guidance_rescale=0.0,
|
| 643 |
+
**kwargs):
|
| 644 |
+
# Average per-image positive predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
preds = []
|
| 646 |
for i in range(len(cond)):
|
| 647 |
+
preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **kwargs))
|
| 648 |
+
avg_pred = sum(preds) / len(preds)
|
| 649 |
+
|
| 650 |
+
# Apply guidance interval
|
| 651 |
+
if not (guidance_interval[0] <= t <= guidance_interval[1]):
|
| 652 |
+
guidance_strength = 1
|
| 653 |
+
|
| 654 |
+
# Apply CFG
|
| 655 |
+
if guidance_strength == 1 or neg_cond is None:
|
| 656 |
+
return avg_pred
|
| 657 |
+
neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
|
| 658 |
+
pred = guidance_strength * avg_pred + (1 - guidance_strength) * neg_pred
|
| 659 |
+
|
| 660 |
+
# Apply guidance rescale
|
| 661 |
+
if guidance_rescale > 0:
|
| 662 |
+
x_0_pos = self._pred_to_xstart(x_t, t, avg_pred)
|
| 663 |
+
x_0_cfg = self._pred_to_xstart(x_t, t, pred)
|
| 664 |
+
std_pos = x_0_pos.std(dim=list(range(1, x_0_pos.ndim)), keepdim=True)
|
| 665 |
+
std_cfg = x_0_cfg.std(dim=list(range(1, x_0_cfg.ndim)), keepdim=True)
|
| 666 |
+
x_0_rescaled = x_0_cfg * (std_pos / std_cfg)
|
| 667 |
+
x_0 = guidance_rescale * x_0_rescaled + (1 - guidance_rescale) * x_0_cfg
|
| 668 |
+
pred = self._xstart_to_pred(x_t, t, x_0)
|
| 669 |
+
|
| 670 |
return pred
|
| 671 |
|
| 672 |
else:
|
|
|
|
| 692 |
return_latent: bool = False,
|
| 693 |
pipeline_type: Optional[str] = None,
|
| 694 |
max_num_tokens: int = 49152,
|
| 695 |
+
mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
|
| 696 |
+
tex_mode: Optional[Literal['stochastic', 'multidiffusion']] = None,
|
| 697 |
) -> List[MeshWithVoxel]:
|
| 698 |
"""
|
| 699 |
Run the multi-image pipeline.
|
|
|
|
| 709 |
return_latent (bool): Whether to return the latent codes.
|
| 710 |
pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
|
| 711 |
max_num_tokens (int): The maximum number of tokens to use.
|
| 712 |
+
mode: The multi-image conditioning mode for structure and shape.
|
| 713 |
+
tex_mode: The multi-image conditioning mode for texture. If None, uses mode.
|
| 714 |
"""
|
| 715 |
+
tex_mode = tex_mode or mode
|
| 716 |
# Check pipeline type
|
| 717 |
pipeline_type = pipeline_type or self.default_pipeline_type
|
| 718 |
if pipeline_type == '512':
|
|
|
|
| 756 |
if pipeline_type == '512':
|
| 757 |
with (
|
| 758 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 759 |
+
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
|
| 760 |
):
|
| 761 |
shape_slat = self.sample_shape_slat(
|
| 762 |
cond_512, self.models['shape_slat_flow_model_512'],
|
|
|
|
| 770 |
elif pipeline_type == '1024':
|
| 771 |
with (
|
| 772 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 773 |
+
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
|
| 774 |
):
|
| 775 |
shape_slat = self.sample_shape_slat(
|
| 776 |
cond_1024, self.models['shape_slat_flow_model_1024'],
|
|
|
|
| 784 |
elif pipeline_type == '1024_cascade':
|
| 785 |
with (
|
| 786 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 787 |
+
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
|
| 788 |
):
|
| 789 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 790 |
cond_512, cond_1024,
|
|
|
|
| 800 |
elif pipeline_type == '1536_cascade':
|
| 801 |
with (
|
| 802 |
self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
|
| 803 |
+
self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
|
| 804 |
):
|
| 805 |
shape_slat, res = self.sample_shape_slat_cascade(
|
| 806 |
cond_512, cond_1024,
|