xsponenta commited on
Commit
8e33a89
·
1 Parent(s): 3118106

TTA: 3-pass test-time augmentation with priority-sample seed variation

Browse files

Replace the single learned-pipeline pass + losing classical sklearn
ensemble (commit 3118106 regressed -0.018) with K=3 learned passes
that share the same model checkpoint but use different priority-sample
RNG seeds. Each pass produces a different mix of depth-unprojected
points, so the model sees a slightly different input — model averaging
on a fixed checkpoint is the standard inference-time technique for
1-3% F1 gain.

Aggregation: union the 3 (vertices, edges) outputs via ensemble_merge
with a tight 0.3m merge radius (same vertex predicted by multiple
passes consolidates; pass-specific noise stays separate).

The classical predict_wireframe_tracks step (DLT triangulation) is
kept on top of the unioned learned output via hybrid_merge — this part
worked in the 0.4584 baseline and is unchanged.

Removed: the sklearn-classical ensemble (USE_TRACK_ENSEMBLE-disabled
predict_wireframe_sklearn). It added noise from 2D semantic vertex
detection that the learned pipeline had not produced.

Runtime: ~3x learned forward passes per sample. Each was ~1s, so we
expect ~3s/sample x 100 samples = 300s extra. Well within the 2h budget.

Files changed (1) hide show
  1. script.py +42 -58
script.py CHANGED
@@ -434,7 +434,14 @@ if __name__ == "__main__":
434
 
435
  # Point fusion config
436
  cfg = FuserConfig()
437
- rng = np.random.RandomState(2718)
 
 
 
 
 
 
 
438
 
439
  # Process all samples
440
  solution = []
@@ -447,70 +454,47 @@ if __name__ == "__main__":
447
  for sample in tqdm(dataset[subset_name], desc=subset_name):
448
  order_id = sample["order_id"]
449
 
450
- # Fuse + sample
451
- fused = fuse_and_sample(sample, cfg, rng)
452
- if fused is None:
453
- pred_v, pred_e = empty_solution()
454
- else:
455
- try:
456
- pred_v, pred_e = predict_sample(fused, model, device)
457
- if torch.cuda.is_available():
458
- torch.cuda.empty_cache()
459
-
460
- # Apply handcrafted triangulation tracking to catch missing corners/edges
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  try:
462
  from triangulation import predict_wireframe_tracks
463
- # Use min_views=3 for highly precise, conservative geometric tracks
464
  track_v, track_e = predict_wireframe_tracks(sample, min_views=3)
465
-
466
  pred_v, pred_e = hybrid_merge(pred_v, pred_e, track_v, track_e, merge_radius=0.8)
467
  except Exception as track_e_err:
468
  print(f" Track ensemble failed for {order_id}: {track_e_err}")
469
 
470
- # Pipeline ensemble: also run the classical sklearn pipeline
471
- # but disable its heaviest features. predict_wireframe_tracks
472
- # is already called above by hybrid_merge — calling it again
473
- # via USE_TRACK_ENSEMBLE/USE_TRACKS_AS_VERTICES doubles the
474
- # slowest step and causes timeout. Winner candidates and line
475
- # cloud are also expensive. We keep just the core classical
476
- # signal: 2D semantic vertex detection + RANSAC depth fit +
477
- # 3D unprojection + sklearn edge classifier.
478
- try:
479
- import sklearn_submission as _skl
480
- _saved_flags = (
481
- _skl.USE_TRACK_ENSEMBLE,
482
- _skl.USE_TRACKS_AS_VERTICES,
483
- _skl.USE_WINNER_CANDIDATES,
484
- _skl.USE_LINE_EDGES,
485
- )
486
- _skl.USE_TRACK_ENSEMBLE = False
487
- _skl.USE_TRACKS_AS_VERTICES = False
488
- _skl.USE_WINNER_CANDIDATES = False
489
- _skl.USE_LINE_EDGES = False
490
- try:
491
- skl_v, skl_e = _skl.predict_wireframe_sklearn(sample)
492
- finally:
493
- (
494
- _skl.USE_TRACK_ENSEMBLE,
495
- _skl.USE_TRACKS_AS_VERTICES,
496
- _skl.USE_WINNER_CANDIDATES,
497
- _skl.USE_LINE_EDGES,
498
- ) = _saved_flags
499
- if isinstance(pred_v, np.ndarray) and len(pred_v) >= 1 and \
500
- skl_v is not None and len(skl_v) >= 1:
501
- pred_v, pred_e = ensemble_merge(
502
- pred_v, pred_e, skl_v, skl_e,
503
- vertex_merge_radius=0.4,
504
- )
505
- except Exception as ens_err:
506
- print(f" Ensemble merge failed for {order_id}: {ens_err}")
507
-
508
- except Exception as e:
509
- import traceback
510
- print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
511
- pred_v, pred_e = empty_solution()
512
- if torch.cuda.is_available():
513
- torch.cuda.empty_cache()
514
 
515
  solution.append({
516
  "order_id": order_id,
 
434
 
435
  # Point fusion config
436
  cfg = FuserConfig()
437
+
438
+ # Test-time augmentation: how many learned-pipeline passes per sample.
439
+ # Each pass uses a different priority-sample seed so the input point
440
+ # cloud (especially the depth-unprojected portion) varies. We then
441
+ # union the segment predictions across passes via ensemble_merge.
442
+ TTA_PASSES = 3
443
+ TTA_BASE_SEED = 2718
444
+ TTA_MERGE_RADIUS = 0.3 # tight: same vertex predicted by multiple passes
445
 
446
  # Process all samples
447
  solution = []
 
454
  for sample in tqdm(dataset[subset_name], desc=subset_name):
455
  order_id = sample["order_id"]
456
 
457
+ try:
458
+ # ---- TTA: run the learned pipeline K times, union outputs
459
+ tta_outputs = []
460
+ for k in range(TTA_PASSES):
461
+ rng_k = np.random.RandomState(TTA_BASE_SEED + k * 1000)
462
+ fused_k = fuse_and_sample(sample, cfg, rng_k)
463
+ if fused_k is None:
464
+ continue
465
+ try:
466
+ pv_k, pe_k = predict_sample(fused_k, model, device)
467
+ if isinstance(pv_k, np.ndarray) and len(pv_k) >= 2 and len(pe_k) >= 1:
468
+ tta_outputs.append((pv_k, pe_k))
469
+ except Exception as tta_e:
470
+ print(f" TTA pass {k} failed for {order_id}: {tta_e}")
471
+ if torch.cuda.is_available():
472
+ torch.cuda.empty_cache()
473
+
474
+ if not tta_outputs:
475
+ pred_v, pred_e = empty_solution()
476
+ else:
477
+ pred_v, pred_e = tta_outputs[0]
478
+ for pv_k, pe_k in tta_outputs[1:]:
479
+ pred_v, pred_e = ensemble_merge(
480
+ pred_v, pred_e, pv_k, pe_k,
481
+ vertex_merge_radius=TTA_MERGE_RADIUS,
482
+ )
483
+
484
+ # ---- Classical track ensemble (precise DLT triangulation)
485
  try:
486
  from triangulation import predict_wireframe_tracks
 
487
  track_v, track_e = predict_wireframe_tracks(sample, min_views=3)
 
488
  pred_v, pred_e = hybrid_merge(pred_v, pred_e, track_v, track_e, merge_radius=0.8)
489
  except Exception as track_e_err:
490
  print(f" Track ensemble failed for {order_id}: {track_e_err}")
491
 
492
+ except Exception as e:
493
+ import traceback
494
+ print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
495
+ pred_v, pred_e = empty_solution()
496
+ if torch.cuda.is_available():
497
+ torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498
 
499
  solution.append({
500
  "order_id": order_id,