opsiclear-admin commited on
Commit
9c08519
·
verified ·
1 Parent(s): 40e9737

Best-of-N noise selection: probe 3 candidates, pick lowest std

Browse files
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -392,22 +392,20 @@ class Trellis2ImageTo3DPipeline(Pipeline):
392
  self.models['shape_slat_decoder'].low_vram = False
393
  return ret
394
 
395
- def _detect_degenerate_tex(
396
  self,
397
  flow_model,
398
  noise: SparseTensor,
399
  concat_cond: SparseTensor,
400
  cond: dict,
401
  sampler_params: dict,
402
- threshold: float = 1.037,
403
- ) -> bool:
404
  """
405
- Run 2 probe steps to detect degenerate texture flow trajectory.
406
 
 
407
  Uses single-image inference (bypassing multi-image patching) so
408
- the threshold stays calibrated regardless of fusion mode.
409
-
410
- Returns True if the trajectory is degenerate (will produce black texture).
411
  """
412
  steps = sampler_params.get('steps', 12)
413
  rescale_t = sampler_params.get('rescale_t', 3.0)
@@ -415,7 +413,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
415
  t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
416
  t_seq = t_seq.tolist()
417
 
418
- # Use single-image cond for probe (threshold was calibrated on single-image)
419
  probe_cond = {}
420
  for k, v in cond.items():
421
  probe_cond[k] = v[:1] if torch.is_tensor(v) and v.ndim >= 1 else v
@@ -442,8 +440,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
442
  if patched:
443
  sampler._inference_model = patched_fn
444
 
445
- x0_std = out.pred_x_0.feats.std().item()
446
- return x0_std > threshold
447
 
448
  def sample_tex_slat(
449
  self,
@@ -451,22 +448,19 @@ class Trellis2ImageTo3DPipeline(Pipeline):
451
  flow_model,
452
  shape_slat: SparseTensor,
453
  sampler_params: dict = {},
454
- max_retries: int = 3,
455
- retry_noise_scale: float = 0.5,
456
  ) -> SparseTensor:
457
  """
458
  Sample structured latent with the given conditioning.
459
 
460
- Includes early detection of degenerate texture flow trajectories
461
- (black texture bug). If detected after 2 probe steps, retries with
462
- scaled noise.
463
 
464
  Args:
465
  cond (dict): The conditioning information.
466
  shape_slat (SparseTensor): The structured latent for shape
467
  sampler_params (dict): Additional parameters for the sampler.
468
- max_retries (int): Max retries on degenerate detection.
469
- retry_noise_scale (float): Noise scale factor on retry.
470
  """
471
  # Sample structured latent
472
  std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
@@ -479,18 +473,19 @@ class Trellis2ImageTo3DPipeline(Pipeline):
479
  if self.low_vram:
480
  flow_model.to(self.device)
481
 
482
- noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device)
483
- for attempt in range(max_retries + 1):
 
 
 
484
  noise = shape_slat.replace(feats=noise_feats)
485
-
486
- if self._detect_degenerate_tex(flow_model, noise, shape_slat, cond, sampler_params):
487
- if attempt < max_retries:
488
- noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device) * retry_noise_scale
489
- print(f"\033[93m[tex] Degenerate detected, retry {attempt+1}/{max_retries} with noise_scale={retry_noise_scale}\033[0m")
490
- continue
491
- else:
492
- print(f"\033[93m[tex] Degenerate detected but retries exhausted, proceeding anyway\033[0m")
493
- break
494
 
495
  slat = self.tex_slat_sampler.sample(
496
  flow_model,
 
392
  self.models['shape_slat_decoder'].low_vram = False
393
  return ret
394
 
395
+ def _probe_tex_noise(
396
  self,
397
  flow_model,
398
  noise: SparseTensor,
399
  concat_cond: SparseTensor,
400
  cond: dict,
401
  sampler_params: dict,
402
+ ) -> float:
 
403
  """
404
+ Run 2 probe steps and return pred_x0 std as a quality score.
405
 
406
+ Lower std = better trajectory (further from degenerate attractor).
407
  Uses single-image inference (bypassing multi-image patching) so
408
+ scores are comparable regardless of fusion mode.
 
 
409
  """
410
  steps = sampler_params.get('steps', 12)
411
  rescale_t = sampler_params.get('rescale_t', 3.0)
 
413
  t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
414
  t_seq = t_seq.tolist()
415
 
416
+ # Use single-image cond for probe (scores calibrated on single-image)
417
  probe_cond = {}
418
  for k, v in cond.items():
419
  probe_cond[k] = v[:1] if torch.is_tensor(v) and v.ndim >= 1 else v
 
440
  if patched:
441
  sampler._inference_model = patched_fn
442
 
443
+ return out.pred_x_0.feats.std().item()
 
444
 
445
  def sample_tex_slat(
446
  self,
 
448
  flow_model,
449
  shape_slat: SparseTensor,
450
  sampler_params: dict = {},
451
+ num_candidates: int = 3,
 
452
  ) -> SparseTensor:
453
  """
454
  Sample structured latent with the given conditioning.
455
 
456
+ Probes multiple noise candidates and selects the one with the
457
+ lowest pred_x0 std (furthest from degenerate attractor).
 
458
 
459
  Args:
460
  cond (dict): The conditioning information.
461
  shape_slat (SparseTensor): The structured latent for shape
462
  sampler_params (dict): Additional parameters for the sampler.
463
+ num_candidates (int): Number of noise candidates to probe.
 
464
  """
465
  # Sample structured latent
466
  std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
 
473
  if self.low_vram:
474
  flow_model.to(self.device)
475
 
476
+ # Probe multiple noise candidates, pick the best
477
+ best_noise_feats = None
478
+ best_score = float('inf')
479
+ for i in range(num_candidates):
480
+ noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device)
481
  noise = shape_slat.replace(feats=noise_feats)
482
+ score = self._probe_tex_noise(flow_model, noise, shape_slat, cond, sampler_params)
483
+ print(f"\033[93m[tex] Candidate {i+1}/{num_candidates}: pred_x0_std={score:.4f}\033[0m")
484
+ if score < best_score:
485
+ best_score = score
486
+ best_noise_feats = noise_feats
487
+ noise = shape_slat.replace(feats=best_noise_feats)
488
+ print(f"\033[93m[tex] Selected candidate with pred_x0_std={best_score:.4f}\033[0m")
 
 
489
 
490
  slat = self.tex_slat_sampler.sample(
491
  flow_model,