TTA: 3-pass test-time augmentation with priority-sample seed variation
Browse filesReplace the single learned-pipeline pass + losing classical sklearn
ensemble (commit 3118106 regressed -0.018) with K=3 learned passes
that share the same model checkpoint but use different priority-sample
RNG seeds. Each pass produces a different mix of depth-unprojected
points, so the model sees a slightly different input — model averaging
on a fixed checkpoint is the standard inference-time technique for
1-3% F1 gain.
Aggregation: union the 3 (vertices, edges) outputs via ensemble_merge
with a tight 0.3m merge radius (same vertex predicted by multiple
passes consolidates; pass-specific noise stays separate).
The classical predict_wireframe_tracks step (DLT triangulation) is
kept on top of the unioned learned output via hybrid_merge — this part
worked in the 0.4584 baseline and is unchanged.
Removed: the sklearn-classical ensemble (USE_TRACK_ENSEMBLE-disabled
predict_wireframe_sklearn). It added noise from 2D semantic vertex
detection that the learned pipeline had not produced.
Runtime: ~3x learned forward passes per sample. Each was ~1s, so we
expect ~3s/sample x 100 samples = 300s extra. Well within the 2h budget.
|
@@ -434,7 +434,14 @@ if __name__ == "__main__":
|
|
| 434 |
|
| 435 |
# Point fusion config
|
| 436 |
cfg = FuserConfig()
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
# Process all samples
|
| 440 |
solution = []
|
|
@@ -447,70 +454,47 @@ if __name__ == "__main__":
|
|
| 447 |
for sample in tqdm(dataset[subset_name], desc=subset_name):
|
| 448 |
order_id = sample["order_id"]
|
| 449 |
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 461 |
try:
|
| 462 |
from triangulation import predict_wireframe_tracks
|
| 463 |
-
# Use min_views=3 for highly precise, conservative geometric tracks
|
| 464 |
track_v, track_e = predict_wireframe_tracks(sample, min_views=3)
|
| 465 |
-
|
| 466 |
pred_v, pred_e = hybrid_merge(pred_v, pred_e, track_v, track_e, merge_radius=0.8)
|
| 467 |
except Exception as track_e_err:
|
| 468 |
print(f" Track ensemble failed for {order_id}: {track_e_err}")
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
# signal: 2D semantic vertex detection + RANSAC depth fit +
|
| 477 |
-
# 3D unprojection + sklearn edge classifier.
|
| 478 |
-
try:
|
| 479 |
-
import sklearn_submission as _skl
|
| 480 |
-
_saved_flags = (
|
| 481 |
-
_skl.USE_TRACK_ENSEMBLE,
|
| 482 |
-
_skl.USE_TRACKS_AS_VERTICES,
|
| 483 |
-
_skl.USE_WINNER_CANDIDATES,
|
| 484 |
-
_skl.USE_LINE_EDGES,
|
| 485 |
-
)
|
| 486 |
-
_skl.USE_TRACK_ENSEMBLE = False
|
| 487 |
-
_skl.USE_TRACKS_AS_VERTICES = False
|
| 488 |
-
_skl.USE_WINNER_CANDIDATES = False
|
| 489 |
-
_skl.USE_LINE_EDGES = False
|
| 490 |
-
try:
|
| 491 |
-
skl_v, skl_e = _skl.predict_wireframe_sklearn(sample)
|
| 492 |
-
finally:
|
| 493 |
-
(
|
| 494 |
-
_skl.USE_TRACK_ENSEMBLE,
|
| 495 |
-
_skl.USE_TRACKS_AS_VERTICES,
|
| 496 |
-
_skl.USE_WINNER_CANDIDATES,
|
| 497 |
-
_skl.USE_LINE_EDGES,
|
| 498 |
-
) = _saved_flags
|
| 499 |
-
if isinstance(pred_v, np.ndarray) and len(pred_v) >= 1 and \
|
| 500 |
-
skl_v is not None and len(skl_v) >= 1:
|
| 501 |
-
pred_v, pred_e = ensemble_merge(
|
| 502 |
-
pred_v, pred_e, skl_v, skl_e,
|
| 503 |
-
vertex_merge_radius=0.4,
|
| 504 |
-
)
|
| 505 |
-
except Exception as ens_err:
|
| 506 |
-
print(f" Ensemble merge failed for {order_id}: {ens_err}")
|
| 507 |
-
|
| 508 |
-
except Exception as e:
|
| 509 |
-
import traceback
|
| 510 |
-
print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
|
| 511 |
-
pred_v, pred_e = empty_solution()
|
| 512 |
-
if torch.cuda.is_available():
|
| 513 |
-
torch.cuda.empty_cache()
|
| 514 |
|
| 515 |
solution.append({
|
| 516 |
"order_id": order_id,
|
|
|
|
| 434 |
|
| 435 |
# Point fusion config
|
| 436 |
cfg = FuserConfig()
|
| 437 |
+
|
| 438 |
+
# Test-time augmentation: how many learned-pipeline passes per sample.
|
| 439 |
+
# Each pass uses a different priority-sample seed so the input point
|
| 440 |
+
# cloud (especially the depth-unprojected portion) varies. We then
|
| 441 |
+
# union the segment predictions across passes via ensemble_merge.
|
| 442 |
+
TTA_PASSES = 3
|
| 443 |
+
TTA_BASE_SEED = 2718
|
| 444 |
+
TTA_MERGE_RADIUS = 0.3 # tight: same vertex predicted by multiple passes
|
| 445 |
|
| 446 |
# Process all samples
|
| 447 |
solution = []
|
|
|
|
| 454 |
for sample in tqdm(dataset[subset_name], desc=subset_name):
|
| 455 |
order_id = sample["order_id"]
|
| 456 |
|
| 457 |
+
try:
|
| 458 |
+
# ---- TTA: run the learned pipeline K times, union outputs
|
| 459 |
+
tta_outputs = []
|
| 460 |
+
for k in range(TTA_PASSES):
|
| 461 |
+
rng_k = np.random.RandomState(TTA_BASE_SEED + k * 1000)
|
| 462 |
+
fused_k = fuse_and_sample(sample, cfg, rng_k)
|
| 463 |
+
if fused_k is None:
|
| 464 |
+
continue
|
| 465 |
+
try:
|
| 466 |
+
pv_k, pe_k = predict_sample(fused_k, model, device)
|
| 467 |
+
if isinstance(pv_k, np.ndarray) and len(pv_k) >= 2 and len(pe_k) >= 1:
|
| 468 |
+
tta_outputs.append((pv_k, pe_k))
|
| 469 |
+
except Exception as tta_e:
|
| 470 |
+
print(f" TTA pass {k} failed for {order_id}: {tta_e}")
|
| 471 |
+
if torch.cuda.is_available():
|
| 472 |
+
torch.cuda.empty_cache()
|
| 473 |
+
|
| 474 |
+
if not tta_outputs:
|
| 475 |
+
pred_v, pred_e = empty_solution()
|
| 476 |
+
else:
|
| 477 |
+
pred_v, pred_e = tta_outputs[0]
|
| 478 |
+
for pv_k, pe_k in tta_outputs[1:]:
|
| 479 |
+
pred_v, pred_e = ensemble_merge(
|
| 480 |
+
pred_v, pred_e, pv_k, pe_k,
|
| 481 |
+
vertex_merge_radius=TTA_MERGE_RADIUS,
|
| 482 |
+
)
|
| 483 |
+
|
| 484 |
+
# ---- Classical track ensemble (precise DLT triangulation)
|
| 485 |
try:
|
| 486 |
from triangulation import predict_wireframe_tracks
|
|
|
|
| 487 |
track_v, track_e = predict_wireframe_tracks(sample, min_views=3)
|
|
|
|
| 488 |
pred_v, pred_e = hybrid_merge(pred_v, pred_e, track_v, track_e, merge_radius=0.8)
|
| 489 |
except Exception as track_e_err:
|
| 490 |
print(f" Track ensemble failed for {order_id}: {track_e_err}")
|
| 491 |
|
| 492 |
+
except Exception as e:
|
| 493 |
+
import traceback
|
| 494 |
+
print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
|
| 495 |
+
pred_v, pred_e = empty_solution()
|
| 496 |
+
if torch.cuda.is_available():
|
| 497 |
+
torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
solution.append({
|
| 500 |
"order_id": order_id,
|