Spaces:
Running on Zero
Running on Zero
Fix black texture bug: detect degenerate sampler trajectories and retry
Browse files
trellis2/pipelines/trellis2_image_to_3d.py
CHANGED
|
@@ -392,20 +392,62 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 392 |
self.models['shape_slat_decoder'].low_vram = False
|
| 393 |
return ret
|
| 394 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
def sample_tex_slat(
|
| 396 |
self,
|
| 397 |
cond: dict,
|
| 398 |
flow_model,
|
| 399 |
shape_slat: SparseTensor,
|
| 400 |
sampler_params: dict = {},
|
|
|
|
|
|
|
| 401 |
) -> SparseTensor:
|
| 402 |
"""
|
| 403 |
Sample structured latent with the given conditioning.
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
Args:
|
| 406 |
cond (dict): The conditioning information.
|
| 407 |
shape_slat (SparseTensor): The structured latent for shape
|
| 408 |
sampler_params (dict): Additional parameters for the sampler.
|
|
|
|
|
|
|
| 409 |
"""
|
| 410 |
# Sample structured latent
|
| 411 |
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
|
|
@@ -413,10 +455,24 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 413 |
shape_slat = (shape_slat - mean) / std
|
| 414 |
|
| 415 |
in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
|
| 416 |
-
|
| 417 |
sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
|
| 418 |
if self.low_vram:
|
| 419 |
flow_model.to(self.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
slat = self.tex_slat_sampler.sample(
|
| 421 |
flow_model,
|
| 422 |
noise,
|
|
@@ -432,7 +488,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 432 |
std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
|
| 433 |
mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
|
| 434 |
slat = slat * std + mean
|
| 435 |
-
|
| 436 |
return slat
|
| 437 |
|
| 438 |
def decode_tex_slat(
|
|
|
|
| 392 |
self.models['shape_slat_decoder'].low_vram = False
|
| 393 |
return ret
|
| 394 |
|
| 395 |
+
def _detect_degenerate_tex(
    self,
    flow_model,
    noise: SparseTensor,
    concat_cond: SparseTensor,
    cond: dict,
    sampler_params: dict,
    threshold: float = 1.037,
    probe_steps: int = 2,
) -> bool:
    """
    Run a few probe steps to detect a degenerate texture flow trajectory.

    The first ``probe_steps`` steps of the sampling schedule are executed
    with ``self.tex_slat_sampler.sample_once`` starting from ``noise``. The
    standard deviation of the last predicted ``x_0`` features is then
    compared against ``threshold``.

    Args:
        flow_model: The flow model forwarded to the sampler.
        noise (SparseTensor): Starting sample for the probe.
        concat_cond (SparseTensor): Forwarded to the sampler as ``concat_cond``.
        cond (dict): Conditioning kwargs forwarded to the sampler unchanged.
        sampler_params (dict): Sampler parameters. ``steps`` (default 12) and
            ``rescale_t`` (default 3.0) are consumed here to build the time
            schedule; every other entry is forwarded to ``sample_once``.
        threshold (float): The trajectory is reported as degenerate when the
            std of the predicted ``x_0`` features is strictly greater.
        probe_steps (int): Number of leading schedule steps to run. It is
            clamped to ``steps`` so the probe never reads past the schedule.

    Returns:
        bool: True if the trajectory is considered degenerate (per the
        original author's note, such trajectories produce a black texture).
        False when no probe step can be run (``steps`` or ``probe_steps``
        is not positive).
    """
    steps = sampler_params.get('steps', 12)
    rescale_t = sampler_params.get('rescale_t', 3.0)

    # The schedule holds steps + 1 time points, i.e. `steps` (t, t_prev) pairs.
    # Probing more pairs than exist would index past the end of the schedule.
    n_probe = max(0, min(probe_steps, steps))
    if n_probe == 0:
        # Nothing to probe, so there is no evidence of degeneracy.
        return False

    t_seq = np.linspace(1, 0, steps + 1)
    t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
    t_seq = t_seq.tolist()

    # 'steps' and 'rescale_t' only shape the schedule built above; they are
    # not forwarded to the single-step call.
    step_kwargs = {
        k: v for k, v in sampler_params.items()
        if k not in ('steps', 'rescale_t')
    }

    sample = noise
    for t, t_prev in zip(t_seq[:n_probe], t_seq[1:n_probe + 1]):
        out = self.tex_slat_sampler.sample_once(
            flow_model, sample, t, t_prev,
            concat_cond=concat_cond,
            **cond,
            **step_kwargs,
        )
        sample = out.pred_x_prev

    # Only the prediction from the last probe step is inspected.
    x0_std = out.pred_x_0.feats.std().item()
    return x0_std > threshold
|
| 428 |
+
|
| 429 |
def sample_tex_slat(
|
| 430 |
self,
|
| 431 |
cond: dict,
|
| 432 |
flow_model,
|
| 433 |
shape_slat: SparseTensor,
|
| 434 |
sampler_params: dict = {},
|
| 435 |
+
max_retries: int = 3,
|
| 436 |
+
retry_noise_scale: float = 0.5,
|
| 437 |
) -> SparseTensor:
|
| 438 |
"""
|
| 439 |
Sample structured latent with the given conditioning.
|
| 440 |
+
|
| 441 |
+
Includes early detection of degenerate texture flow trajectories
|
| 442 |
+
(black texture bug). If detected after 2 probe steps, retries with
|
| 443 |
+
scaled noise.
|
| 444 |
+
|
| 445 |
Args:
|
| 446 |
cond (dict): The conditioning information.
|
| 447 |
shape_slat (SparseTensor): The structured latent for shape
|
| 448 |
sampler_params (dict): Additional parameters for the sampler.
|
| 449 |
+
max_retries (int): Max retries on degenerate detection.
|
| 450 |
+
retry_noise_scale (float): Noise scale factor on retry.
|
| 451 |
"""
|
| 452 |
# Sample structured latent
|
| 453 |
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
|
|
|
|
| 455 |
shape_slat = (shape_slat - mean) / std
|
| 456 |
|
| 457 |
in_channels = flow_model.in_channels if isinstance(flow_model, nn.Module) else flow_model[0].in_channels
|
| 458 |
+
n_noise_feats = in_channels - shape_slat.feats.shape[1]
|
| 459 |
sampler_params = {**self.tex_slat_sampler_params, **sampler_params}
|
| 460 |
if self.low_vram:
|
| 461 |
flow_model.to(self.device)
|
| 462 |
+
|
| 463 |
+
noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device)
|
| 464 |
+
for attempt in range(max_retries + 1):
|
| 465 |
+
noise = shape_slat.replace(feats=noise_feats)
|
| 466 |
+
|
| 467 |
+
if self._detect_degenerate_tex(flow_model, noise, shape_slat, cond, sampler_params):
|
| 468 |
+
if attempt < max_retries:
|
| 469 |
+
noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device) * retry_noise_scale
|
| 470 |
+
print(f"\033[93m[tex] Degenerate detected, retry {attempt+1}/{max_retries} with noise_scale={retry_noise_scale}\033[0m")
|
| 471 |
+
continue
|
| 472 |
+
else:
|
| 473 |
+
print(f"\033[93m[tex] Degenerate detected but retries exhausted, proceeding anyway\033[0m")
|
| 474 |
+
break
|
| 475 |
+
|
| 476 |
slat = self.tex_slat_sampler.sample(
|
| 477 |
flow_model,
|
| 478 |
noise,
|
|
|
|
| 488 |
std = torch.tensor(self.tex_slat_normalization['std'])[None].to(slat.device)
|
| 489 |
mean = torch.tensor(self.tex_slat_normalization['mean'])[None].to(slat.device)
|
| 490 |
slat = slat * std + mean
|
| 491 |
+
|
| 492 |
return slat
|
| 493 |
|
| 494 |
def decode_tex_slat(
|