Spaces:
Running on Zero
Running on Zero
Best-of-N noise selection: probe 3 candidates, pick lowest std
Browse files
trellis2/pipelines/trellis2_image_to_3d.py
CHANGED
|
@@ -392,22 +392,20 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 392 |
self.models['shape_slat_decoder'].low_vram = False
|
| 393 |
return ret
|
| 394 |
|
| 395 |
-
def
|
| 396 |
self,
|
| 397 |
flow_model,
|
| 398 |
noise: SparseTensor,
|
| 399 |
concat_cond: SparseTensor,
|
| 400 |
cond: dict,
|
| 401 |
sampler_params: dict,
|
| 402 |
-
|
| 403 |
-
) -> bool:
|
| 404 |
"""
|
| 405 |
-
Run 2 probe steps
|
| 406 |
|
|
|
|
| 407 |
Uses single-image inference (bypassing multi-image patching) so
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
Returns True if the trajectory is degenerate (will produce black texture).
|
| 411 |
"""
|
| 412 |
steps = sampler_params.get('steps', 12)
|
| 413 |
rescale_t = sampler_params.get('rescale_t', 3.0)
|
|
@@ -415,7 +413,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 415 |
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
|
| 416 |
t_seq = t_seq.tolist()
|
| 417 |
|
| 418 |
-
# Use single-image cond for probe (
|
| 419 |
probe_cond = {}
|
| 420 |
for k, v in cond.items():
|
| 421 |
probe_cond[k] = v[:1] if torch.is_tensor(v) and v.ndim >= 1 else v
|
|
@@ -442,8 +440,7 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 442 |
if patched:
|
| 443 |
sampler._inference_model = patched_fn
|
| 444 |
|
| 445 |
-
|
| 446 |
-
return x0_std > threshold
|
| 447 |
|
| 448 |
def sample_tex_slat(
|
| 449 |
self,
|
|
@@ -451,22 +448,19 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 451 |
flow_model,
|
| 452 |
shape_slat: SparseTensor,
|
| 453 |
sampler_params: dict = {},
|
| 454 |
-
|
| 455 |
-
retry_noise_scale: float = 0.5,
|
| 456 |
) -> SparseTensor:
|
| 457 |
"""
|
| 458 |
Sample structured latent with the given conditioning.
|
| 459 |
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
scaled noise.
|
| 463 |
|
| 464 |
Args:
|
| 465 |
cond (dict): The conditioning information.
|
| 466 |
shape_slat (SparseTensor): The structured latent for shape
|
| 467 |
sampler_params (dict): Additional parameters for the sampler.
|
| 468 |
-
|
| 469 |
-
retry_noise_scale (float): Noise scale factor on retry.
|
| 470 |
"""
|
| 471 |
# Sample structured latent
|
| 472 |
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
|
|
@@ -479,18 +473,19 @@ class Trellis2ImageTo3DPipeline(Pipeline):
|
|
| 479 |
if self.low_vram:
|
| 480 |
flow_model.to(self.device)
|
| 481 |
|
| 482 |
-
|
| 483 |
-
|
|
|
|
|
|
|
|
|
|
| 484 |
noise = shape_slat.replace(feats=noise_feats)
|
| 485 |
-
|
| 486 |
-
|
| 487 |
-
|
| 488 |
-
|
| 489 |
-
|
| 490 |
-
|
| 491 |
-
|
| 492 |
-
print(f"\033[93m[tex] Degenerate detected but retries exhausted, proceeding anyway\033[0m")
|
| 493 |
-
break
|
| 494 |
|
| 495 |
slat = self.tex_slat_sampler.sample(
|
| 496 |
flow_model,
|
|
|
|
| 392 |
self.models['shape_slat_decoder'].low_vram = False
|
| 393 |
return ret
|
| 394 |
|
| 395 |
+
def _probe_tex_noise(
|
| 396 |
self,
|
| 397 |
flow_model,
|
| 398 |
noise: SparseTensor,
|
| 399 |
concat_cond: SparseTensor,
|
| 400 |
cond: dict,
|
| 401 |
sampler_params: dict,
|
| 402 |
+
) -> float:
|
|
|
|
| 403 |
"""
|
| 404 |
+
Run 2 probe steps and return pred_x0 std as a quality score.
|
| 405 |
|
| 406 |
+
Lower std = better trajectory (further from degenerate attractor).
|
| 407 |
Uses single-image inference (bypassing multi-image patching) so
|
| 408 |
+
scores are comparable regardless of fusion mode.
|
|
|
|
|
|
|
| 409 |
"""
|
| 410 |
steps = sampler_params.get('steps', 12)
|
| 411 |
rescale_t = sampler_params.get('rescale_t', 3.0)
|
|
|
|
| 413 |
t_seq = rescale_t * t_seq / (1 + (rescale_t - 1) * t_seq)
|
| 414 |
t_seq = t_seq.tolist()
|
| 415 |
|
| 416 |
+
# Use single-image cond for probe (scores calibrated on single-image)
|
| 417 |
probe_cond = {}
|
| 418 |
for k, v in cond.items():
|
| 419 |
probe_cond[k] = v[:1] if torch.is_tensor(v) and v.ndim >= 1 else v
|
|
|
|
| 440 |
if patched:
|
| 441 |
sampler._inference_model = patched_fn
|
| 442 |
|
| 443 |
+
return out.pred_x_0.feats.std().item()
|
|
|
|
| 444 |
|
| 445 |
def sample_tex_slat(
|
| 446 |
self,
|
|
|
|
| 448 |
flow_model,
|
| 449 |
shape_slat: SparseTensor,
|
| 450 |
sampler_params: dict = {},
|
| 451 |
+
num_candidates: int = 3,
|
|
|
|
| 452 |
) -> SparseTensor:
|
| 453 |
"""
|
| 454 |
Sample structured latent with the given conditioning.
|
| 455 |
|
| 456 |
+
Probes multiple noise candidates and selects the one with the
|
| 457 |
+
lowest pred_x0 std (furthest from degenerate attractor).
|
|
|
|
| 458 |
|
| 459 |
Args:
|
| 460 |
cond (dict): The conditioning information.
|
| 461 |
shape_slat (SparseTensor): The structured latent for shape
|
| 462 |
sampler_params (dict): Additional parameters for the sampler.
|
| 463 |
+
num_candidates (int): Number of noise candidates to probe.
|
|
|
|
| 464 |
"""
|
| 465 |
# Sample structured latent
|
| 466 |
std = torch.tensor(self.shape_slat_normalization['std'])[None].to(shape_slat.device)
|
|
|
|
| 473 |
if self.low_vram:
|
| 474 |
flow_model.to(self.device)
|
| 475 |
|
| 476 |
+
# Probe multiple noise candidates, pick the best
|
| 477 |
+
best_noise_feats = None
|
| 478 |
+
best_score = float('inf')
|
| 479 |
+
for i in range(num_candidates):
|
| 480 |
+
noise_feats = torch.randn(shape_slat.coords.shape[0], n_noise_feats).to(self.device)
|
| 481 |
noise = shape_slat.replace(feats=noise_feats)
|
| 482 |
+
score = self._probe_tex_noise(flow_model, noise, shape_slat, cond, sampler_params)
|
| 483 |
+
print(f"\033[93m[tex] Candidate {i+1}/{num_candidates}: pred_x0_std={score:.4f}\033[0m")
|
| 484 |
+
if score < best_score:
|
| 485 |
+
best_score = score
|
| 486 |
+
best_noise_feats = noise_feats
|
| 487 |
+
noise = shape_slat.replace(feats=best_noise_feats)
|
| 488 |
+
print(f"\033[93m[tex] Selected candidate with pred_x0_std={best_score:.4f}\033[0m")
|
|
|
|
|
|
|
| 489 |
|
| 490 |
slat = self.tex_slat_sampler.sample(
|
| 491 |
flow_model,
|