opsiclear-admin commited on
Commit
9c5ed9a
·
verified ·
1 Parent(s): 04f45c8

Fix black texture bug: detect degenerate sampler trajectories and retry

Browse files
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -392,20 +392,62 @@ class Trellis2ImageTo3DPipeline(Pipeline):
392
  self.models['shape_slat_decoder'].low_vram = False
393
  return ret
394
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  def sample_tex_slat(
396
  self,
397
  cond: dict,
398
  flow_model,
399
  shape_slat: SparseTensor,
400
  sampler_params: dict = {},
 
 
401
  ) -> SparseTensor:
402
  """
403
  Sample structured latent with the given conditioning.
404
-
 
 
 
 
405
  Args:
406
  cond (dict): The conditioning information.
407
  shape_slat (SparseTensor): The structured latent for shape
408
  sampler_params (dict): Additional parameters for the sampler.
 
 
409
  """
410
  # Sample structured latent
411
  std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
@@ -413,10 +455,24 @@ class Trellis2ImageTo3DPipeline(Pipeline):
413
  shape_slat = (shape_slat - mean) / std
414
 
415
  in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
416
- noise = shape_slat.replace(feats=torch.randn(shape_slat.coords.shape[0], in_channels - shape_slat.feats.shape[1]).to(self.device))
417
  sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
418
  if self.low_vram:
419
  flow_model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  slat = self.tex_slat_sampler.sample(
421
  flow_model,
422
  noise,
@@ -432,7 +488,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
432
  std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
433
  mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
434
  slat = slat * std + mean
435
-
436
  return slat
437
 
438
  def decode_tex_slat(
 
392
  self.models['shape_slat_decoder'].low_vram = False
393
  return ret
394
 
395
+ def _detect_degenerate_tex(
396
+ self,
397
+ flow_model,
398
+ noise: SparseTensor,
399
+ concat_cond: SparseTensor,
400
+ cond: dict,
401
+ sampler_params: dict,
402
+ threshold: float = 1.037,
403
+ ) -> bool:
404
+ """
405
+ Run 2 probe steps to detect degenerate texture flow trajectory.
406
+
407
+ Returns True if the trajectory is degenerate (will produce black texture).
408
+ """
409
+ steps = sampler_params.get('steps', 12)
410
+ rescale_t = sampler_params.get('rescale_t', 3.0)
411
+ t_seq = np.linspace(1, 0, steps + 1)
412
+ t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
413
+ t_seq = t_seq.tolist()
414
+
415
+ sample = noise
416
+ for step_idx in range(2):
417
+ t, t_prev = t_seq[step_idx], t_seq[step_idx + 1]
418
+ out = self.tex_slat_sampler.sample_once(
419
+ flow_model, sample, t, t_prev,
420
+ concat_cond=concat_cond,
421
+ **{k: v for k, v in cond.items()},
422
+ **{k: v for k, v in sampler_params.items() if k not in ('steps', 'rescale_t')},
423
+ )
424
+ sample = out.pred_x_prev
425
+
426
+ x0_std = out.pred_x_0.feats.std().item()
427
+ return x0_std > threshold
428
+
429
  def sample_tex_slat(
430
  self,
431
  cond: dict,
432
  flow_model,
433
  shape_slat: SparseTensor,
434
  sampler_params: dict = {},
435
+ max_retries: int = 3,
436
+ retry_noise_scale: float = 0.5,
437
  ) -> SparseTensor:
438
  """
439
  Sample structured latent with the given conditioning.
440
+
441
+ Includes early detection of degenerate texture flow trajectories
442
+ (black texture bug). If detected after 2 probe steps, retries with
443
+ scaled noise.
444
+
445
  Args:
446
  cond (dict): The conditioning information.
447
  shape_slat (SparseTensor): The structured latent for shape
448
  sampler_params (dict): Additional parameters for the sampler.
449
+ max_retries (int): Max retries on degenerate detection.
450
+ retry_noise_scale (float): Noise scale factor on retry.
451
  """
452
  # Sample structured latent
453
  std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
 
455
  shape_slat = (shape_slat - mean) / std
456
 
457
  in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
458
+ n_noise_feats = in_channels - shape_slat.feats.shape[1]
459
  sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
460
  if self.low_vram:
461
  flow_model.to(self.device)
462
+
463
+ noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device)
464
+ for attempt in range(max_retries + 1):
465
+ noise = shape_slat.replace(feats=noise_feats)
466
+
467
+ if self._detect_degenerate_tex(flow_model, noise, shape_slat, cond, sampler_params):
468
+ if attempt < max_retries:
469
+ noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device) * retry_noise_scale
470
+ print(f"\033[93m[tex] Degenerate detected, retry {attempt+1}/{max_retries} with noise_scale={retry_noise_scale}\033[0m")
471
+ continue
472
+ else:
473
+ print(f"\033[93m[tex] Degenerate detected but retries exhausted, proceeding anyway\033[0m")
474
+ break
475
+
476
  slat = self.tex_slat_sampler.sample(
477
  flow_model,
478
  noise,
 
488
  std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
489
  mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
490
  slat = slat * std + mean
491
+
492
  return slat
493
 
494
  def decode_tex_slat(