opsiclear-admin commited on
Commit
666c821
·
verified ·
1 Parent(s): e3cd099

Fix degenerate detection: bypass multi-image patching during probe

Browse files
trellis2/pipelines/trellis2_image_to_3d.py CHANGED
@@ -404,6 +404,9 @@ class Trellis2ImageTo3DPipeline(Pipeline):
404
  """
405
  Run 2 probe steps to detect degenerate texture flow trajectory.
406
 
 
 
 
407
  Returns True if the trajectory is degenerate (will produce black texture).
408
  """
409
  steps = sampler_params.get('steps', 12)
@@ -412,16 +415,32 @@ class Trellis2ImageTo3DPipeline(Pipeline):
412
  t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
413
  t_seq = t_seq.tolist()
414
 
415
- sample = noise
416
- for step_idx in range(2):
417
- t, t_prev = t_seq[step_idx], t_seq[step_idx + 1]
418
- out = self.tex_slat_sampler.sample_once(
419
- flow_model, sample, t, t_prev,
420
- concat_cond=concat_cond,
421
- **{k: v for k, v in cond.items()},
422
- **{k: v for k, v in sampler_params.items() if k not in ('steps', 'rescale_t')},
423
- )
424
- sample = out.pred_x_prev
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
425
 
426
  x0_std = out.pred_x_0.feats.std().item()
427
  return x0_std > threshold
 
404
  """
405
  Run 2 probe steps to detect degenerate texture flow trajectory.
406
 
407
+ Uses single-image inference (bypassing multi-image patching) so
408
+ the threshold stays calibrated regardless of fusion mode.
409
+
410
  Returns True if the trajectory is degenerate (will produce black texture).
411
  """
412
  steps = sampler_params.get('steps', 12)
 
415
  t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
416
  t_seq = t_seq.tolist()
417
 
418
+ # Use single-image cond for probe (threshold was calibrated on single-image)
419
+ probe_cond = {}
420
+ for k, v in cond.items():
421
+ probe_cond[k] = v[:1] if torch.is_tensor(v) and v.ndim >= 1 else v
422
+
423
+ # Bypass multi-image patching during probe
424
+ sampler = self.tex_slat_sampler
425
+ patched = hasattr(sampler, '_old_inference_model')
426
+ if patched:
427
+ patched_fn = sampler._inference_model
428
+ sampler._inference_model = sampler._old_inference_model
429
+
430
+ try:
431
+ sample = noise
432
+ for step_idx in range(2):
433
+ t, t_prev = t_seq[step_idx], t_seq[step_idx + 1]
434
+ out = sampler.sample_once(
435
+ flow_model, sample, t, t_prev,
436
+ concat_cond=concat_cond,
437
+ **{k: v for k, v in probe_cond.items()},
438
+ **{k: v for k, v in sampler_params.items() if k not in ('steps', 'rescale_t')},
439
+ )
440
+ sample = out.pred_x_prev
441
+ finally:
442
+ if patched:
443
+ sampler._inference_model = patched_fn
444
 
445
  x0_std = out.pred_x_0.feats.std().item()
446
  return x0_std > threshold