opsiclear-admin committed on
Commit
a0ea9e7
·
verified ·
1 Parent(s): 1de8ccc

Fix multidiffusion CFG and add per-stage texture mode

Browse files
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -637,23 +637,36 @@ class Trellis2ImageTo3DPipeline(Pipeline):
637
 
638
  elif mode == 'multidiffusion':
639
  from .samplers import FlowEulerSampler
640
- def _new_inference_model(self, model, x_t, t, cond, **kwargs):
641
- # if cfg_interval[0] <= t <= cfg_interval[1]:
642
- # preds = []
643
- # for i in range(len(cond)):
644
- # preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **kwargs))
645
- # pred = sum(preds) / len(preds)
646
- # neg_pred = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
647
- # return (1 + cfg_strength) * pred - cfg_strength * neg_pred
648
- # else:
649
-
650
- # Filter out guidance-related kwargs that the base sampler doesn't handle
651
- filtered_kwargs = {k: v for k, v in kwargs.items()
652
- if k not in ('neg_cond', 'guidance_strength', 'guidance_interval', 'guidance_rescale')}
653
  preds = []
654
  for i in range(len(cond)):
655
- preds.append(FlowEulerSampler._inference_model(self, model, x_t, t, cond[i:i + 1], **filtered_kwargs))
656
- pred = sum(preds) / len(preds)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
657
  return pred
658
 
659
  else:
@@ -679,7 +692,8 @@ class Trellis2ImageTo3DPipeline(Pipeline):
679
  return_latent: bool = False,
680
  pipeline_type: Optional[str] = None,
681
  max_num_tokens: int = 49152,
682
- mode: Literal['stochastic', 'multidiffusion'] = 'multidiffusion',
 
683
  ) -> List[MeshWithVoxel]:
684
  """
685
  Run the multi-image pipeline.
@@ -695,8 +709,10 @@ class Trellis2ImageTo3DPipeline(Pipeline):
695
  return_latent (bool): Whether to return the latent codes.
696
  pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
697
  max_num_tokens (int): The maximum number of tokens to use.
698
- mode: The multi-image conditioning mode.
 
699
  """
 
700
  # Check pipeline type
701
  pipeline_type = pipeline_type or self.default_pipeline_type
702
  if pipeline_type == '512':
@@ -740,7 +756,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
740
  if pipeline_type == '512':
741
  with (
742
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
743
- self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode),
744
  ):
745
  shape_slat = self.sample_shape_slat(
746
  cond_512, self.models['shape_slat_flow_model_512'],
@@ -754,7 +770,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
754
  elif pipeline_type == '1024':
755
  with (
756
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
757
- self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode),
758
  ):
759
  shape_slat = self.sample_shape_slat(
760
  cond_1024, self.models['shape_slat_flow_model_1024'],
@@ -768,7 +784,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
768
  elif pipeline_type == '1024_cascade':
769
  with (
770
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
771
- self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode),
772
  ):
773
  shape_slat, res = self.sample_shape_slat_cascade(
774
  cond_512, cond_1024,
@@ -784,7 +800,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
784
  elif pipeline_type == '1536_cascade':
785
  with (
786
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
787
- self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=mode),
788
  ):
789
  shape_slat, res = self.sample_shape_slat_cascade(
790
  cond_512, cond_1024,
 
637
 
638
  elif mode == 'multidiffusion':
639
  from .samplers import FlowEulerSampler
640
def _new_inference_model(self, model, x_t, t, cond,
                         neg_cond=None, guidance_strength=1,
                         guidance_interval=(0, 1), guidance_rescale=0.0,
                         **kwargs):
    """
    Multidiffusion prediction with classifier-free guidance (CFG).

    Runs the base ``FlowEulerSampler._inference_model`` once per entry of
    ``cond`` (sliced as ``cond[i:i + 1]``), averages those predictions, and
    then optionally mixes in a negative-condition prediction and rescales.

    Args:
        model: The model forwarded to the base sampler.
        x_t: The current sample, forwarded to the base sampler.
        t: The current timestep; compared against ``guidance_interval``.
        cond: Positive conditioning; one forward pass is made per entry.
        neg_cond: Negative conditioning. If None, no guidance is applied.
        guidance_strength: CFG weight; a value of 1 disables guidance.
        guidance_interval: Inclusive ``(low, high)`` range of ``t`` in which
            guidance is active.
        guidance_rescale: Blend factor for the std-matching rescale step;
            the step is skipped unless this is greater than 0.
        **kwargs: Extra arguments forwarded unchanged to the base sampler.

    Returns:
        The averaged positive prediction when guidance is inactive,
        otherwise the guided (and possibly rescaled) prediction.
    """
    # One base-sampler pass per conditioning entry, then the plain mean.
    per_cond_preds = [
        FlowEulerSampler._inference_model(self, model, x_t, t, cond[idx:idx + 1], **kwargs)
        for idx in range(len(cond))
    ]
    positive = sum(per_cond_preds) / len(per_cond_preds)

    # Outside the guidance interval the strength falls back to 1 (no guidance).
    strength = guidance_strength if guidance_interval[0] <= t <= guidance_interval[1] else 1

    # Nothing to guide against, or guidance disabled: return the average as-is.
    if strength == 1 or neg_cond is None:
        return positive

    negative = FlowEulerSampler._inference_model(self, model, x_t, t, neg_cond, **kwargs)
    guided = strength * positive + (1 - strength) * negative

    if not (guidance_rescale > 0):
        return guided

    # Guidance rescale: scale the guided x_0 estimate so its per-sample std
    # matches the unguided one, then blend by guidance_rescale.
    # NOTE(review): _pred_to_xstart / _xstart_to_pred are sampler methods not
    # visible here; presumably they convert between prediction and x_0 -- confirm.
    x0_positive = self._pred_to_xstart(x_t, t, positive)
    x0_guided = self._pred_to_xstart(x_t, t, guided)
    positive_std = x0_positive.std(dim=list(range(1, x0_positive.ndim)), keepdim=True)
    guided_std = x0_guided.std(dim=list(range(1, x0_guided.ndim)), keepdim=True)
    x0_matched = x0_guided * (positive_std / guided_std)
    x0_blend = guidance_rescale * x0_matched + (1 - guidance_rescale) * x0_guided
    return self._xstart_to_pred(x_t, t, x0_blend)
671
 
672
  else:
 
692
  return_latent: bool = False,
693
  pipeline_type: Optional[str] = None,
694
  max_num_tokens: int = 49152,
695
+ mode: Literal['stochastic', 'multidiffusion'] = 'stochastic',
696
+ tex_mode: Optional[Literal['stochastic', 'multidiffusion']] = None,
697
  ) -> List[MeshWithVoxel]:
698
  """
699
  Run the multi-image pipeline.
 
709
  return_latent (bool): Whether to return the latent codes.
710
  pipeline_type (str): The type of the pipeline. Options: '512', '1024', '1024_cascade', '1536_cascade'.
711
  max_num_tokens (int): The maximum number of tokens to use.
712
+ mode: The multi-image conditioning mode for structure and shape.
713
+ tex_mode: The multi-image conditioning mode for texture. If None, uses mode.
714
  """
715
+ tex_mode = tex_mode or mode
716
  # Check pipeline type
717
  pipeline_type = pipeline_type or self.default_pipeline_type
718
  if pipeline_type == '512':
 
756
  if pipeline_type == '512':
757
  with (
758
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
759
+ self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
760
  ):
761
  shape_slat = self.sample_shape_slat(
762
  cond_512, self.models['shape_slat_flow_model_512'],
 
770
  elif pipeline_type == '1024':
771
  with (
772
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
773
+ self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
774
  ):
775
  shape_slat = self.sample_shape_slat(
776
  cond_1024, self.models['shape_slat_flow_model_1024'],
 
784
  elif pipeline_type == '1024_cascade':
785
  with (
786
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
787
+ self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
788
  ):
789
  shape_slat, res = self.sample_shape_slat_cascade(
790
  cond_512, cond_1024,
 
800
  elif pipeline_type == '1536_cascade':
801
  with (
802
  self.inject_sampler_multi_image('shape_slat_sampler', len(images), shape_slat_steps, mode=mode),
803
+ self.inject_sampler_multi_image('tex_slat_sampler', len(images), tex_slat_steps, mode=tex_mode),
804
  ):
805
  shape_slat, res = self.sample_shape_slat_cascade(
806
  cond_512, cond_1024,