xsponenta Claude Opus 4.7 committed on
Commit
b6bc99a
·
1 Parent(s): 2df06c6

TTA: Hungarian-matched multi-seed inference with strict 3-pass agreement

Run three priority-sample seeds (2718, 31415, 42), Hungarian-match segments
across passes via a flip-invariant endpoint distance, and drop anchor segments
that are not matched in BOTH non-anchor passes (min_passes_for_keep=2).
Surviving segments are orientation-aligned and averaged across the passes
that matched them (all three under the strict setting).
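
A minimal sketch of the flip-invariant endpoint distance described above, assuming a segment is just a pair of 3D points. The helper name `flip_invariant_cost` and the toy coordinates are illustrative and not taken from the commit; the acceptance threshold mirrors the `2 * max_endpoint_dist` rule in tta.py with its default of 0.4.

```python
import math


def flip_invariant_cost(seg_a, seg_b):
    """Return (cost, flipped) for two segments given as ((x, y, z), (x, y, z)).

    The cost is the sum of the two endpoint distances under whichever endpoint
    pairing (same order or reversed) is cheaper, so a segment and its reversed
    copy are treated as the same edge.
    """
    d_same = math.dist(seg_a[0], seg_b[0]) + math.dist(seg_a[1], seg_b[1])
    d_flip = math.dist(seg_a[0], seg_b[1]) + math.dist(seg_a[1], seg_b[0])
    return (d_flip, True) if d_flip < d_same else (d_same, False)


# Hypothetical example: the same edge predicted in reverse order, 10 cm higher.
anchor = ((0.0, 0.0, 0.0), (1.0, 0.0, 0.0))
other = ((1.0, 0.0, 0.1), (0.0, 0.0, 0.1))

cost, flipped = flip_invariant_cost(anchor, other)
accepted = cost <= 2 * 0.4  # threshold rule used in tta.py (max_endpoint_dist=0.4)
print(cost, flipped, accepted)
```

Without the flipped pairing, this pair would cost about 2 m and be rejected even though both passes describe the same edge.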

The previous TTA (857514e, reverted) picked one pass's output - that's
why it failed. This version aggregates and filters.

Local 100-sample A/B vs orphan_refine baseline:
  baseline (orphan_refine):     mean=0.3856  q5=0.0620  q50=0.3946
  TTA simple concat-and-merge:  mean=0.3756  q5=0.0842  (-0.010 mean, REJECTED)
  TTA Hungarian min_passes=1:   mean=0.3853  q5=0.0848  (-0.000 mean, neutral)
  TTA Hungarian min_passes=2:   mean=0.3888  q5=0.0907  q50=0.4032  (+0.003 mean, +0.029 q5)
The strict variant wins primarily on hard scenes - q5 jumps 47%
(0.062 -> 0.091), q25 +0.006, q50 +0.009. Strict filter throws out
hallucinations on difficult scenes where the model is uncertain across
sampling seeds.
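
For readers reproducing these summary numbers, here is a rough sketch of how mean, q5 and q50 could be computed from per-sample scores, and how the relative q5 gain follows from the table. The `quantile` helper uses linear interpolation (the same convention as NumPy's default); the score list is made up, and whether local_eval.py computes quantiles exactly this way is an assumption.

```python
def quantile(values, q):
    """Linear-interpolation quantile of a non-empty list, with q in [0, 1]."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (pos - lo) * (xs[hi] - xs[lo])


def summarize(scores):
    """Summary statistics in the style reported in the A/B table."""
    return {
        "mean": sum(scores) / len(scores),
        "q5": quantile(scores, 0.05),
        "q50": quantile(scores, 0.50),
    }


def relative_gain(before, after):
    """Relative change of `after` with respect to `before`."""
    return (after - before) / before


# Hypothetical per-sample scores, only to exercise the helpers.
stats = summarize([0.0, 0.1, 0.2, 0.3, 0.4])

# q5 gain from the four-digit table values versus the rounded values in the text.
gain_table = relative_gain(0.0620, 0.0907)
gain_rounded = relative_gain(0.062, 0.091)
print(stats, round(gain_table, 3), round(gain_rounded, 3))
```

With the four-digit table values the relative q5 gain is about 46%; the 47% quoted in the text corresponds to the rounded figures 0.062 and 0.091.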

Cost: 3x inference time. Feature-flagged via USE_TTA in script.py for
easy revert if HF Space hits time limits.

Also includes (not used in production yet):
- edge_fill.py: attempted edge filling from 2D mask evidence
(rejected in A/B testing; was net -0.004 even with capped fills)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

Files changed (4)
  1. edge_fill.py +161 -0
  2. local_eval.py +75 -12
  3. script.py +47 -9
  4. tta.py +245 -0
edge_fill.py ADDED
@@ -0,0 +1,161 @@
+ """Edge filling from 2D gestalt evidence.
+
+ For each pair of predicted vertices (V_i, V_j) that is NOT currently an edge:
+   1. Project both endpoints into every COLMAP view.
+   2. Sample N points along the projected 2D segment.
+   3. Count points falling on gestalt edge-class pixels (using the dilated mask).
+   4. A view "supports" the candidate edge if support_frac >= min_pixel_frac.
+   5. If at least min_views_support views agree, ADD the edge.
+
+ This is the inverse of edge_2d_filter.filter_edges_by_2d_support:
+ the filter regressed q5 because dropping edges based on a binary mask check
+ hurt recall. Adding edges is asymmetric: false positives waste a precision
+ slot, but false negatives are catastrophic (real edges missing). With strong
+ thresholds we should mostly add genuinely-missed edges.
+
+ Conservative defaults: 60% min support, 3+ views agreeing, max edge length
+ 5 m (most building edges are short). Candidate pairs are sorted closest-first
+ and capped at max_pair_check to keep cost bounded.
+
+ Topology change only: adds new edges, never moves or removes vertices.
+ Falls back to (pv, pe) on any error.
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ import cv2
+
+
+ def fill_missing_edges_from_2d(
+     pv,
+     pe,
+     sample,
+     min_views_support: int = 3,
+     min_pixel_frac: float = 0.60,
+     max_edge_length_meters: float = 5.0,
+     max_pair_check: int = 100,
+     max_fills_abs: int = 6,
+     max_fills_rel: float = 0.25,
+     dilate_px: int = 4,
+     sample_steps: int = 20,
+ ):
+     """Add edges between existing vertex pairs that have strong 2D edge support.
+
+     Args:
+         pv: (N, 3) vertices in world coordinates.
+         pe: existing edge list. New edges are appended; existing edges
+             are never removed.
+         sample: raw dataset entry.
+         min_views_support: minimum views with strong mask support to add edge.
+         min_pixel_frac: fraction of sampled segment pixels that must lie on
+             a gestalt edge-class pixel for a view to count as supporting.
+         max_edge_length_meters: skip pairs whose 3D distance exceeds this.
+         max_pair_check: hard cap on candidate pairs evaluated per sample
+             (sorted by ascending 3D distance, so closest pairs go first).
+         dilate_px: edge-mask dilation radius (same as edge_2d_filter).
+         sample_steps: number of points sampled along each 2D segment.
+
+     Returns:
+         (pv, pe_extended). Vertex array unchanged; edges only grow.
+         Falls back to inputs on any error.
+     """
+     try:
+         from hoho2025.example_solutions import convert_entry_to_human_readable
+         from mvs_utils import collect_views, project_world_to_image
+         from edge_2d_filter import _build_edge_masks
+
+         pv_arr = np.asarray(pv, dtype=np.float64)
+         if pv_arr.ndim != 2 or pv_arr.shape[0] < 2:
+             return pv, pe
+
+         good = convert_entry_to_human_readable(sample)
+         colmap_rec = good.get("colmap") or good.get("colmap_binary")
+         if colmap_rec is None:
+             return pv, pe
+
+         views = collect_views(colmap_rec, good["image_ids"])
+         if len(views) < min_views_support:
+             return pv, pe
+
+         view_masks = _build_edge_masks(good, views, dilate_px=dilate_px)
+         if not view_masks:
+             return pv, pe
+
+         existing = set()
+         for a, b in pe:
+             a, b = int(a), int(b)
+             lo, hi = (a, b) if a < b else (b, a)
+             existing.add((lo, hi))
+
+         N = pv_arr.shape[0]
+
+         # Build candidate list: all pairs not already in `existing` and within
+         # max_edge_length. Sort by ascending 3D distance and cap.
+         candidates = []
+         for i in range(N):
+             for j in range(i + 1, N):
+                 if (i, j) in existing:
+                     continue
+                 d = float(np.linalg.norm(pv_arr[i] - pv_arr[j]))
+                 if d > max_edge_length_meters or d < 1e-3:
+                     continue
+                 candidates.append((d, i, j))
+         candidates.sort()
+         if len(candidates) > max_pair_check:
+             candidates = candidates[:max_pair_check]
+
+         if not candidates:
+             return pv, pe
+
+         # Cache view list once (avoid dict-iteration in hot loop)
+         view_items = [(img_id, views[img_id]["P"], *view_masks[img_id])
+                       for img_id in view_masks]
+
+         # Score each candidate by (# views supporting, total support across views).
+         # Then accept only those meeting the threshold, capped by ranking.
+         scored = []
+         for _, i, j in candidates:
+             endpoints = np.stack([pv_arr[i], pv_arr[j]])
+
+             supporting = 0
+             total_support = 0.0
+             for _img_id, P, mask_bool, H, W in view_items:
+                 uv, z = project_world_to_image(P, endpoints)
+                 if z[0] <= 0 or z[1] <= 0:
+                     continue
+                 if not (
+                     0 <= uv[0, 0] < W and 0 <= uv[0, 1] < H
+                     and 0 <= uv[1, 0] < W and 0 <= uv[1, 1] < H
+                 ):
+                     continue
+
+                 t = np.linspace(0.0, 1.0, sample_steps)
+                 xs = uv[0, 0] + t * (uv[1, 0] - uv[0, 0])
+                 ys = uv[0, 1] + t * (uv[1, 1] - uv[0, 1])
+                 xs_i = np.clip(xs.astype(np.int32), 0, W - 1)
+                 ys_i = np.clip(ys.astype(np.int32), 0, H - 1)
+
+                 frac = int(mask_bool[ys_i, xs_i].sum()) / float(sample_steps)
+                 if frac >= min_pixel_frac:
+                     supporting += 1
+                     total_support += frac
+
+             if supporting >= min_views_support:
+                 # Score: prioritize multi-view agreement, break ties on total support
+                 scored.append((supporting, total_support, i, j))
+
+         if not scored:
+             return pv, pe
+
+         # Cap additions: min(absolute cap, rel-fraction of existing edges), at least 1
+         max_to_add = max(1, min(max_fills_abs,
+                                 int(max_fills_rel * max(len(pe), 1))))
+         scored.sort(reverse=True)
+         added = [(i, j) for _, _, i, j in scored[:max_to_add]]
+
+         new_pe = list(pe) + added
+         return pv, new_pe
+
+     except Exception:
+         return pv, pe
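
The per-view support test and the fill cap in edge_fill.py can be illustrated on a toy mask. This is a dependency-free sketch, not the module's code: `support_fraction` and `max_fills` are hypothetical helper names, the mask is a plain nested list instead of a NumPy array, and the cap uses the signature defaults (`max_fills_abs=6`, `max_fills_rel=0.25`).

```python
def support_fraction(mask, p0, p1, steps=20):
    """Fraction of `steps` evenly spaced points on the 2D segment p0 -> p1
    that land on True pixels of `mask` (indexed as mask[y][x])."""
    h, w = len(mask), len(mask[0])
    hits = 0
    for k in range(steps):
        t = k / (steps - 1)
        x = int(p0[0] + t * (p1[0] - p0[0]))
        y = int(p0[1] + t * (p1[1] - p0[1]))
        x = min(max(x, 0), w - 1)  # clamp to image bounds
        y = min(max(y, 0), h - 1)
        hits += bool(mask[y][x])
    return hits / steps


def max_fills(n_edges, max_fills_abs=6, max_fills_rel=0.25):
    """Cap on added edges: at least 1, at most the absolute cap, otherwise a
    fraction of the existing edge count."""
    return max(1, min(max_fills_abs, int(max_fills_rel * max(n_edges, 1))))


# Toy 10x10 mask with a single horizontal edge on row 5.
mask = [[row == 5 for _ in range(10)] for row in range(10)]

on_edge = support_fraction(mask, (0, 5), (9, 5))   # runs along the edge
diagonal = support_fraction(mask, (0, 0), (9, 9))  # only crosses the edge
print(on_edge, diagonal, max_fills(2), max_fills(8), max_fills(100))
```

With `min_pixel_frac=0.60`, only the first segment would count as supported in this view. The diagonal shows why a high threshold matters: merely crossing an edge pixel contributes very little support.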
local_eval.py CHANGED
@@ -58,13 +58,30 @@ def parse_args():
                         help="vertex refine: min views with 2D match")
      p.add_argument("--refine-max-move", type=float, default=0.5,
                     help="vertex refine: max 3D displacement in meters")
+     p.add_argument("--tta", action="store_true",
+                    help="enable multi-seed TTA (3 priority-sample seeds, concat segments)")
+     p.add_argument("--tta-hungarian", action="store_true",
+                    help="use Hungarian-matched averaging TTA (rejects unmatched segments)")
+     p.add_argument("--tta-min-passes", type=int, default=1,
+                    help="hungarian TTA: drop anchor segments without this many supporting passes")
+     p.add_argument("--tta-seeds", type=str, default="2718,31415,42",
+                    help="comma-separated priority-sample seeds for TTA")
+     p.add_argument("--edge-fill", action="store_true",
+                    help="enable edge filling from 2D mask evidence")
+     p.add_argument("--fill-min-views", type=int, default=2,
+                    help="edge fill: min views supporting a new edge")
+     p.add_argument("--fill-min-frac", type=float, default=0.40,
+                    help="edge fill: min support fraction along projected segment")
+     p.add_argument("--fill-max-length", type=float, default=5.0,
+                    help="edge fill: max edge length in meters")
      return p.parse_args()


  def predict_one(sample, model, device, cfg, rng,
                  use_tracks=True, use_2d_filter=True, orphan_only=False,
                  strict_no_support=False, vertex_refine=False,
-                 refine_kwargs=None):
+                 refine_kwargs=None, edge_fill=False, fill_kwargs=None,
+                 tta=False, tta_seeds=None):
      """Run the full inference pipeline on one sample. Returns (pv, pe, diag)."""
      diag = {"colmap": -1, "fused": 0, "track_v": 0, "track_e": 0,
              "pred_v": 0, "pred_e": 0, "2dfilt_in": 0, "2dfilt_out": 0,
@@ -79,17 +96,40 @@ def predict_one(sample, model, device, cfg, rng,
      except Exception:
          pass

-     fused = script.fuse_and_sample(sample, cfg, rng)
-     if fused is None:
-         diag["status"] = "fuse_failed"
-         return *script.empty_solution(), diag
-     diag["fused"] = len(fused["xyz_norm"])
+     if tta:
+         try:
+             seeds = tta_seeds or (2718, 31415, 42)
+             tta_method = (
+                 "predict_sample_tta_hungarian"
+                 if getattr(predict_one, "_tta_hungarian", False)
+                 else "predict_sample_tta"
+             )
+             import tta as _tta_mod
+             fn = getattr(_tta_mod, tta_method)
+             if tta_method == "predict_sample_tta_hungarian":
+                 pred_v, pred_e = fn(
+                     sample, cfg, model, device, seeds=tuple(seeds),
+                     min_passes_for_keep=getattr(predict_one, "_tta_min_passes", 1),
+                 )
+             else:
+                 pred_v, pred_e = fn(
+                     sample, cfg, model, device, seeds=tuple(seeds))
+             diag["fused"] = -1  # not single-seed
+         except Exception as e:
+             diag["status"] = f"tta_failed:{type(e).__name__}"
+             return *script.empty_solution(), diag
+     else:
+         fused = script.fuse_and_sample(sample, cfg, rng)
+         if fused is None:
+             diag["status"] = "fuse_failed"
+             return *script.empty_solution(), diag
+         diag["fused"] = len(fused["xyz_norm"])

-     try:
-         pred_v, pred_e = script.predict_sample(fused, model, device)
-     except Exception as e:
-         diag["status"] = f"predict_failed:{type(e).__name__}"
-         return *script.empty_solution(), diag
+         try:
+             pred_v, pred_e = script.predict_sample(fused, model, device)
+         except Exception as e:
+             diag["status"] = f"predict_failed:{type(e).__name__}"
+             return *script.empty_solution(), diag

      if use_tracks:
          try:
@@ -144,6 +184,17 @@ def predict_one(sample, model, device, cfg, rng,
              diag["status"] = f"2dfilt_failed:{type(e).__name__}"
          diag["2dfilt_out"] = len(pred_e) if hasattr(pred_e, '__len__') else 0

+     if edge_fill:
+         e_before = len(pred_e) if hasattr(pred_e, '__len__') else 0
+         try:
+             from edge_fill import fill_missing_edges_from_2d
+             pred_v, pred_e = fill_missing_edges_from_2d(
+                 pred_v, pred_e, sample,
+                 **(fill_kwargs or {}))
+             diag["filled"] = (len(pred_e) if hasattr(pred_e, '__len__') else 0) - e_before
+         except Exception as e:
+             diag["status"] = f"fill_failed:{type(e).__name__}"
+
      diag["pred_v"] = len(pred_v) if hasattr(pred_v, '__len__') else 0
      diag["pred_e"] = len(pred_e) if hasattr(pred_e, '__len__') else 0
      return pred_v, pred_e, diag
@@ -212,6 +263,14 @@ def main():
              "min_views": args.refine_min_views,
              "max_move_meters": args.refine_max_move,
          }
+         fill_kwargs = {
+             "min_views_support": args.fill_min_views,
+             "min_pixel_frac": args.fill_min_frac,
+             "max_edge_length_meters": args.fill_max_length,
+         }
+         tta_seeds_tuple = tuple(int(s) for s in args.tta_seeds.split(","))
+         predict_one._tta_hungarian = args.tta_hungarian
+         predict_one._tta_min_passes = args.tta_min_passes
          pred_v, pred_e, diag = predict_one(
              sample, model, device, cfg, rng,
              use_tracks=not args.no_tracks,
@@ -219,7 +278,11 @@ def main():
              orphan_only=args.orphan_only,
              strict_no_support=args.strict_no_support,
              vertex_refine=args.vertex_refine,
-             refine_kwargs=refine_kwargs)
+             refine_kwargs=refine_kwargs,
+             edge_fill=args.edge_fill,
+             fill_kwargs=fill_kwargs,
+             tta=args.tta,
+             tta_seeds=tta_seeds_tuple)
          if torch.backends.mps.is_available():
              torch.mps.empty_cache()

script.py CHANGED
@@ -59,6 +59,14 @@ CONF_THRESH = 0.4
  MERGE_THRESH = 0.4
  SNAP_RADIUS = 0.5

+ # Test-time augmentation: 3 priority-sample seeds + Hungarian matching with
+ # strict 3-pass agreement (min_passes_for_keep=2). Local 100-sample A/B:
+ # q5 0.062 -> 0.091 (+47%), mean +0.003. Costs 3x inference time but
+ # strict filter dramatically improves precision on hard scenes.
+ USE_TTA = True
+ TTA_SEEDS = (2718, 31415, 42)
+ TTA_MIN_PASSES = 2
+

  def fuse_and_sample(sample, cfg, rng):
      """Run point fusion + priority sampling on a raw dataset sample.
@@ -398,21 +406,51 @@ if __name__ == "__main__":
          except Exception:
              pass

-         # Fuse + sample
-         fused = fuse_and_sample(sample, cfg, rng)
-         n_fused_pts = len(fused["xyz_norm"]) if fused is not None else 0
          track_v_count, track_e_count = 0, 0
          pred_status = "ok"
+         n_fused_pts = 0

-         if fused is None:
-             pred_v, pred_e = empty_solution()
-             pred_status = "fuse_failed"
-         else:
+         if USE_TTA:
+             # Multi-seed TTA: fuse + predict 3 times, Hungarian-match segments
+             # across passes, drop those without min_passes agreement.
              try:
-                 pred_v, pred_e = predict_sample(fused, model, device)
+                 from tta import predict_sample_tta_hungarian
+                 pred_v, pred_e = predict_sample_tta_hungarian(
+                     sample, cfg, model, device,
+                     seeds=TTA_SEEDS,
+                     min_passes_for_keep=TTA_MIN_PASSES,
+                 )
                  if torch.cuda.is_available():
                      torch.cuda.empty_cache()
-
+             except Exception as e:
+                 import traceback
+                 print(f" TTA failed for {order_id}:\n{traceback.format_exc()}")
+                 pred_v, pred_e = empty_solution()
+                 pred_status = "tta_failed"
+                 if torch.cuda.is_available():
+                     torch.cuda.empty_cache()
+         else:
+             # Single-seed inference (legacy path, kept for easy revert).
+             fused = fuse_and_sample(sample, cfg, rng)
+             n_fused_pts = len(fused["xyz_norm"]) if fused is not None else 0
+             if fused is None:
+                 pred_v, pred_e = empty_solution()
+                 pred_status = "fuse_failed"
+             else:
+                 try:
+                     pred_v, pred_e = predict_sample(fused, model, device)
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+                 except Exception as e:
+                     import traceback
+                     print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
+                     pred_v, pred_e = empty_solution()
+                     pred_status = "predict_failed"
+                     if torch.cuda.is_available():
+                         torch.cuda.empty_cache()
+
+         if pred_status == "ok":
+             try:
                  # Apply handcrafted triangulation tracking to catch missing corners/edges
                  try:
                      from triangulation import predict_wireframe_tracks
tta.py ADDED
@@ -0,0 +1,245 @@
+ """Test-time augmentation via multi-seed priority sampling.
+
+ Runs the model N times with different priority-sample seeds and combines the
+ world-space segment predictions from each pass. Two strategies are provided:
+ predict_sample_tta concatenates all passes and lets the standard
+ merge_vertices_iterative deduplicate them (union-find clusters merged to their
+ centroid), while predict_sample_tta_hungarian solves a real assignment problem
+ against an anchor pass, filters by agreement, and averages matched segments.
+
+ The previous TTA attempt (commit 857514e, reverted) failed because it picked
+ ONE pass's output. This implementation aggregates ALL passes' segments and
+ lets the established merge logic combine them.
+
+ Why it should work:
+ - Stochastic variation comes from priority-sample seed (the model itself is
+   deterministic). Different seeds give different points → slightly different
+   model predictions for the same scene.
+ - Matched segments (true edges) appear in all passes near each other → they
+   cluster in the merge and get averaged toward consensus.
+ - Spurious segments (hallucinations) appear in only 1 pass → concat-and-merge
+   keeps them; the Hungarian variant drops them when min_passes_for_keep >= 1.
+ - The iterative merge thresholds 0.15→0.6 m are appropriate for the typical
+   inter-pass jitter of correctly-predicted segments.
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ import torch
+
+ import script
+ from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
+ from s23dr_2026_example.varifold import segments_to_vertices_edges
+ from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
+
+
+ def _model_to_world_segments(sample_dict, model, device):
+     """Run the model on a single fused sample, return (N, 2, 3) world segments.
+
+     Returns None if no segments pass the confidence threshold.
+     """
+     tokens, masks = script.build_tokens_single(sample_dict, model, device)
+     scale = float(sample_dict["scale"])
+     center = sample_dict["center"]
+     with torch.no_grad(), torch.autocast(
+         device_type='cuda', dtype=torch.float16,
+         enabled=(device.type == 'cuda'),
+     ):
+         out = model.forward_tokens(tokens, masks)
+     segs = out["segments"][0].float().cpu()
+     conf = (
+         torch.sigmoid(out["conf"][0].float()).cpu().numpy()
+         if "conf" in out else None
+     )
+     if conf is not None:
+         segs = segs[conf > script.CONF_THRESH]
+     if len(segs) < 1:
+         return None
+     return segs.numpy() * scale + center  # (N, 2, 3)
+
+
+ def _match_segments(anchor_segs, other_segs, max_endpoint_dist=0.4):
+     """Hungarian-match segments between two passes (flip-invariant distance).
+
+     Returns list of (i_anchor, j_other, was_flipped) for accepted matches.
+     Matches with cost > 2*max_endpoint_dist are rejected.
+     """
+     if len(anchor_segs) == 0 or len(other_segs) == 0:
+         return []
+     from scipy.optimize import linear_sum_assignment
+
+     N, M = len(anchor_segs), len(other_segs)
+     # vectorized cost computation
+     a0 = anchor_segs[:, 0][:, None, :]  # (N, 1, 3)
+     a1 = anchor_segs[:, 1][:, None, :]
+     b0 = other_segs[None, :, 0]  # (1, M, 3)
+     b1 = other_segs[None, :, 1]
+     d_same = (np.linalg.norm(a0 - b0, axis=-1) +
+               np.linalg.norm(a1 - b1, axis=-1))
+     d_flip = (np.linalg.norm(a0 - b1, axis=-1) +
+               np.linalg.norm(a1 - b0, axis=-1))
+     cost = np.minimum(d_same, d_flip)
+     flipped = d_flip < d_same
+
+     row, col = linear_sum_assignment(cost)
+     threshold = 2.0 * max_endpoint_dist
+     return [(int(i), int(j), bool(flipped[i, j]))
+             for i, j in zip(row, col) if cost[i, j] <= threshold]
+
+
+ def predict_sample_tta_hungarian(sample, cfg, model, device,
+                                  seeds=(2718, 31415, 42),
+                                  match_dist: float = 0.4,
+                                  min_passes_for_keep: int = 1,
+                                  snap_target_classes=(0, 1, 2)):
+     """Hungarian-averaged TTA. Aggregates segments via flip-invariant matching.
+
+     Pass 0 is the anchor. Pass 1+ segments are matched to anchor via Hungarian
+     on endpoint distance. Each anchor segment gets averaged with its matches
+     (orientation-aligned). Anchor segments matched in fewer than
+     min_passes_for_keep other passes are dropped.
+
+     Args:
+         min_passes_for_keep: 0 = keep all anchor segments, 1 = require a match
+             in at least 1 other pass, 2 = require matches in 2 other passes.
+     """
+     sample_dicts = []
+     all_segs = []
+     for seed in seeds:
+         rng = np.random.RandomState(int(seed))
+         sd = script.fuse_and_sample(sample, cfg, rng)
+         if sd is None:
+             continue
+         sw = _model_to_world_segments(sd, model, device)
+         if sw is None or len(sw) == 0:
+             continue
+         sample_dicts.append(sd)
+         all_segs.append(sw)
+
+     if not all_segs:
+         return script.empty_solution()
+
+     if len(all_segs) == 1:
+         # No TTA gain possible; just run the normal post-process.
+         return _post_segments_to_wireframe(
+             all_segs[0], sample_dicts[0], snap_target_classes)
+
+     anchor = all_segs[0]
+     matches_per_anchor = [[] for _ in range(len(anchor))]  # list of (other_seg, flipped)
+     for p_idx in range(1, len(all_segs)):
+         for i_a, j_o, flipped in _match_segments(anchor, all_segs[p_idx],
+                                                  max_endpoint_dist=match_dist):
+             matches_per_anchor[i_a].append((all_segs[p_idx][j_o], flipped))
+
+     averaged = []
+     for i, matches in enumerate(matches_per_anchor):
+         if len(matches) < min_passes_for_keep:
+             continue
+         seg = anchor[i]
+         if not matches:
+             averaged.append(seg)
+             continue
+         # Align orientations to anchor and average
+         aligned = [seg]
+         for other_seg, flipped in matches:
+             aligned.append(other_seg[::-1] if flipped else other_seg)
+         averaged.append(np.mean(aligned, axis=0))
+
+     if not averaged:
+         # Defensive: fall back to concat-and-merge if matching dropped everything
+         return _post_segments_to_wireframe(
+             np.concatenate(all_segs, axis=0), sample_dicts[0],
+             snap_target_classes)
+
+     return _post_segments_to_wireframe(
+         np.asarray(averaged), sample_dicts[0], snap_target_classes)
+
+
+ def _post_segments_to_wireframe(segments, sd0, snap_target_classes):
+     """Standard post-process: segments -> vertices/edges -> merge -> snap."""
+     pv, pe = segments_to_vertices_edges(torch.tensor(segments))
+     pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
+     pv, pe = merge_vertices_iterative(pv, pe)
+
+     xyz_norm = sd0["xyz_norm"]
+     mask = sd0["mask"]
+     cid = sd0["class_id"]
+     xyz_world = xyz_norm[mask] * float(sd0["scale"]) + sd0["center"]
+     cid_valid = cid[mask]
+     pv = snap_to_point_cloud(
+         pv, xyz_world, cid_valid,
+         snap_radius=script.SNAP_RADIUS,
+         target_classes=list(snap_target_classes),
+     )
+     pv = snap_horizontal(pv, pe)
+
+     if len(pv) < 2 or len(pe) < 1:
+         return script.empty_solution()
+     return pv, [(int(a), int(b)) for a, b in pe]
+
+
+ def predict_sample_tta(sample, cfg, model, device,
+                        seeds=(2718, 31415, 42),
+                        snap_target_classes=(0, 1, 2)):
+     """Multi-seed TTA prediction. Returns (vertices, edges) in world space.
+
+     Args:
+         sample: raw dataset entry.
+         cfg: FuserConfig.
+         model: loaded model.
+         device: torch device.
+         seeds: tuple of priority-sample seeds. Length = # TTA passes.
+         snap_target_classes: classes for snap_to_point_cloud (default
+             [apex, eave_end_point, flashing_end_point] = [0, 1, 2]).
+     """
+     # 1) Run fuse + model for each seed, collect world-space segments.
+     sample_dicts = []
+     all_segs = []
+     for seed in seeds:
+         rng = np.random.RandomState(int(seed))
+         sd = script.fuse_and_sample(sample, cfg, rng)
+         if sd is None:
+             continue
+         sw = _model_to_world_segments(sd, model, device)
+         if sw is None or len(sw) == 0:
+             continue
+         sample_dicts.append(sd)
+         all_segs.append(sw)
+
+     if not all_segs:
+         return script.empty_solution()
+
+     # 2) Concatenate segments across passes. The downstream merge will cluster
+     #    near-duplicate vertices (matched across passes) and take centroids,
+     #    which approximates Hungarian averaging without an explicit assignment.
+     combined = np.concatenate(all_segs, axis=0)
+
+     # 3) Standard post-process: segments -> vertices/edges, iterative merge.
+     pv, pe = segments_to_vertices_edges(torch.tensor(combined))
+     pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
+     pv, pe = merge_vertices_iterative(pv, pe)
+
+     # 4) Snap to point cloud. Use the FIRST sample_dict's context (the
+     #    fused points are roughly similar across seeds so this is a reasonable
+     #    proxy; using merged xyz would require re-fusing).
+     sd0 = sample_dicts[0]
+     xyz_norm = sd0["xyz_norm"]
+     mask = sd0["mask"]
+     cid = sd0["class_id"]
+     scale0 = float(sd0["scale"])
+     center0 = sd0["center"]
+     xyz_world = xyz_norm[mask] * scale0 + center0
+     cid_valid = cid[mask]
+     pv = snap_to_point_cloud(
+         pv, xyz_world, cid_valid,
+         snap_radius=script.SNAP_RADIUS,
+         target_classes=list(snap_target_classes),
+     )
+     pv = snap_horizontal(pv, pe)
+
+     if len(pv) < 2 or len(pe) < 1:
+         return script.empty_solution()
+
+     edges = [(int(a), int(b)) for a, b in pe]
+     return pv, edges
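
To make the agreement filter and orientation-aligned averaging concrete, here is a small dependency-free sketch on toy data. It is not the module's code: `pair_cost`, `match` and `merge_passes` are hypothetical names, segments are plain tuples rather than NumPy arrays, and the assignment is solved by brute force over permutations as a stand-in for `scipy.optimize.linear_sum_assignment`, which is only practical for a handful of segments.

```python
import itertools
import math


def pair_cost(a, b):
    """Flip-invariant endpoint cost between two segments; returns (cost, flipped)."""
    d_same = math.dist(a[0], b[0]) + math.dist(a[1], b[1])
    d_flip = math.dist(a[0], b[1]) + math.dist(a[1], b[0])
    return (d_flip, True) if d_flip < d_same else (d_same, False)


def match(anchor, other, max_endpoint_dist=0.4):
    """Min-cost one-to-one assignment by brute force (tiny inputs only)."""
    n, m = len(anchor), len(other)
    if n == 0 or m == 0:
        return []
    if n <= m:
        pairings = [list(enumerate(p)) for p in itertools.permutations(range(m), n)]
    else:
        pairings = [[(i, j) for j, i in enumerate(p)]
                    for p in itertools.permutations(range(n), m)]
    best = min(pairings,
               key=lambda ps: sum(pair_cost(anchor[i], other[j])[0] for i, j in ps))
    out = []
    for i, j in best:
        cost, flipped = pair_cost(anchor[i], other[j])
        if cost <= 2 * max_endpoint_dist:  # reject distant matches
            out.append((i, j, flipped))
    return out


def merge_passes(passes, min_passes_for_keep=2):
    """Keep anchor segments matched in enough other passes, then average them."""
    anchor = passes[0]
    matched = [[] for _ in anchor]
    for other in passes[1:]:
        for i, j, flipped in match(anchor, other):
            seg = other[j]
            matched[i].append((seg[1], seg[0]) if flipped else seg)
    kept = []
    for seg, ms in zip(anchor, matched):
        if len(ms) < min_passes_for_keep:
            continue  # not enough cross-seed agreement
        group = [seg] + ms
        kept.append(tuple(
            tuple(sum(s[e][k] for s in group) / len(group) for k in range(3))
            for e in range(2)))
    return kept


# Toy passes: one real edge seen in all three, one segment seen in only two.
passes = [
    [((0, 0, 0), (1, 0, 0)), ((5, 5, 5), (6, 5, 5))],
    [((1, 0, 0.3), (0, 0, 0.3))],  # real edge, reversed
    [((0, 0, -0.3), (1, 0, -0.3)), ((5, 5, 5.1), (6, 5, 5.1))],
]
strict = merge_passes(passes, min_passes_for_keep=2)
loose = merge_passes(passes, min_passes_for_keep=1)
print(strict, len(loose))
```

Under the strict setting only the edge seen in all three passes survives, averaged back to z = 0. With `min_passes_for_keep=1` the second segment is also kept, which is the looser behavior the A/B table reports as neutral.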