TTA: Hungarian-matched multi-seed inference with strict 3-pass agreement
Browse filesThree priority-sample seeds (2718, 31415, 42), Hungarian-match segments
across passes via flip-invariant endpoint distance, drop anchor segments
that don't appear in BOTH supporting passes (min_passes_for_keep=2).
Surviving segments are orientation-aligned and averaged across the 2-3
passes that saw them.
The previous TTA (857514e, reverted) picked one pass's output - that's
why it failed. This version aggregates and filters.
Local 100-sample A/B vs orphan_refine baseline:
baseline (orphan_refine): mean=0.3856 q5=0.0620 q50=0.3946
TTA simple concat-and-merge: mean=0.3756 q5=0.0842 (-0.010, REJECTED)
TTA Hungarian min_passes=1: mean=0.3853 q5=0.0848 (-0.000, neutral)
TTA Hungarian min_passes=2: mean=0.3888 q5=0.0907 q50=0.4032 (+0.003 mean,
+0.029 q5)
The strict variant wins primarily on hard scenes - q5 jumps 47%
(0.062 -> 0.091), q25 +0.006, q50 +0.009. Strict filter throws out
hallucinations on difficult scenes where the model is uncertain across
sampling seeds.
Cost: 3x inference time. Feature-flagged via USE_TTA in script.py for
easy revert if HF Space hits time limits.
Also includes (not used in production yet):
- edge_fill.py: attempted edge filling from 2D mask evidence
(rejected in A/B testing; was net -0.004 even with capped fills)
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
- edge_fill.py +161 -0
- local_eval.py +75 -12
- script.py +47 -9
- tta.py +245 -0
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Edge filling from 2D gestalt evidence.
|
| 2 |
+
|
| 3 |
+
For each pair of predicted vertices (V_i, V_j) that is NOT currently an edge:
|
| 4 |
+
1. Project both endpoints into every COLMAP view.
|
| 5 |
+
2. Sample N points along the projected 2D segment.
|
| 6 |
+
3. Count points falling on gestalt edge-class pixels (using the dilated mask).
|
| 7 |
+
4. A view "supports" the candidate edge if support_frac >= min_pixel_frac.
|
| 8 |
+
5. If at least min_views_support views agree, ADD the edge.
|
| 9 |
+
|
| 10 |
+
This is the inverse of edge_2d_filter.filter_edges_by_2d_support:
|
| 11 |
+
the filter regressed q5 because dropping edges based on a binary mask check
|
| 12 |
+
hurt recall. Adding edges is asymmetric: false positives waste a precision
|
| 13 |
+
slot, but false negatives are catastrophic (real edges missing). With strong
|
| 14 |
+
thresholds we should mostly add genuinely-missed edges.
|
| 15 |
+
|
| 16 |
+
Conservative defaults: 40% min support, 2+ views agreeing, max edge length
|
| 17 |
+
5m (most building edges are short). Pairs are scored by closest-first and
|
| 18 |
+
capped at max_pair_check to keep cost bounded.
|
| 19 |
+
|
| 20 |
+
Topology change only — adds new edges, never moves or removes vertices.
|
| 21 |
+
Falls back to (pv, pe) on any error.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import cv2
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def fill_missing_edges_from_2d(
|
| 31 |
+
pv,
|
| 32 |
+
pe,
|
| 33 |
+
sample,
|
| 34 |
+
min_views_support: int = 3,
|
| 35 |
+
min_pixel_frac: float = 0.60,
|
| 36 |
+
max_edge_length_meters: float = 5.0,
|
| 37 |
+
max_pair_check: int = 100,
|
| 38 |
+
max_fills_abs: int = 6,
|
| 39 |
+
max_fills_rel: float = 0.25,
|
| 40 |
+
dilate_px: int = 4,
|
| 41 |
+
sample_steps: int = 20,
|
| 42 |
+
):
|
| 43 |
+
"""Add edges between existing vertex pairs that have strong 2D edge support.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
pv: (N, 3) vertices in world coordinates.
|
| 47 |
+
pe: existing edge list. New edges are appended; existing edges
|
| 48 |
+
are never removed.
|
| 49 |
+
sample: raw dataset entry.
|
| 50 |
+
min_views_support: minimum views with strong mask support to add edge.
|
| 51 |
+
min_pixel_frac: fraction of sampled segment pixels that must lie on
|
| 52 |
+
a gestalt edge-class pixel for a view to count as supporting.
|
| 53 |
+
max_edge_length_meters: skip pairs whose 3D distance exceeds this.
|
| 54 |
+
max_pair_check: hard cap on candidate pairs evaluated per sample
|
| 55 |
+
(sorted by ascending 3D distance, so closest pairs go first).
|
| 56 |
+
dilate_px: edge-mask dilation radius (same as edge_2d_filter).
|
| 57 |
+
sample_steps: number of points sampled along each 2D segment.
|
| 58 |
+
|
| 59 |
+
Returns:
|
| 60 |
+
(pv, pe_extended). Vertex array unchanged; edges only grow.
|
| 61 |
+
Falls back to inputs on any error.
|
| 62 |
+
"""
|
| 63 |
+
try:
|
| 64 |
+
from hoho2025.example_solutions import convert_entry_to_human_readable
|
| 65 |
+
from mvs_utils import collect_views, project_world_to_image
|
| 66 |
+
from edge_2d_filter import _build_edge_masks
|
| 67 |
+
|
| 68 |
+
pv_arr = np.asarray(pv, dtype=np.float64)
|
| 69 |
+
if pv_arr.ndim != 2 or pv_arr.shape[0] < 2:
|
| 70 |
+
return pv, pe
|
| 71 |
+
|
| 72 |
+
good = convert_entry_to_human_readable(sample)
|
| 73 |
+
colmap_rec = good.get("colmap") or good.get("colmap_binary")
|
| 74 |
+
if colmap_rec is None:
|
| 75 |
+
return pv, pe
|
| 76 |
+
|
| 77 |
+
views = collect_views(colmap_rec, good["image_ids"])
|
| 78 |
+
if len(views) < min_views_support:
|
| 79 |
+
return pv, pe
|
| 80 |
+
|
| 81 |
+
view_masks = _build_edge_masks(good, views, dilate_px=dilate_px)
|
| 82 |
+
if not view_masks:
|
| 83 |
+
return pv, pe
|
| 84 |
+
|
| 85 |
+
existing = set()
|
| 86 |
+
for a, b in pe:
|
| 87 |
+
a, b = int(a), int(b)
|
| 88 |
+
lo, hi = (a, b) if a < b else (b, a)
|
| 89 |
+
existing.add((lo, hi))
|
| 90 |
+
|
| 91 |
+
N = pv_arr.shape[0]
|
| 92 |
+
|
| 93 |
+
# Build candidate list: all pairs not already in `existing` and within
|
| 94 |
+
# max_edge_length. Sort by ascending 3D distance and cap.
|
| 95 |
+
candidates = []
|
| 96 |
+
for i in range(N):
|
| 97 |
+
for j in range(i + 1, N):
|
| 98 |
+
if (i, j) in existing:
|
| 99 |
+
continue
|
| 100 |
+
d = float(np.linalg.norm(pv_arr[i] - pv_arr[j]))
|
| 101 |
+
if d > max_edge_length_meters or d < 1e-3:
|
| 102 |
+
continue
|
| 103 |
+
candidates.append((d, i, j))
|
| 104 |
+
candidates.sort()
|
| 105 |
+
if len(candidates) > max_pair_check:
|
| 106 |
+
candidates = candidates[:max_pair_check]
|
| 107 |
+
|
| 108 |
+
if not candidates:
|
| 109 |
+
return pv, pe
|
| 110 |
+
|
| 111 |
+
# Cache view list once (avoid dict-iteration in hot loop)
|
| 112 |
+
view_items = [(img_id, views[img_id]["P"], *view_masks[img_id])
|
| 113 |
+
for img_id in view_masks]
|
| 114 |
+
|
| 115 |
+
# Score each candidate by (total support across views, # views supporting).
|
| 116 |
+
# Then accept only those meeting the threshold, capped by ranking.
|
| 117 |
+
scored = []
|
| 118 |
+
for _, i, j in candidates:
|
| 119 |
+
endpoints = np.stack([pv_arr[i], pv_arr[j]])
|
| 120 |
+
|
| 121 |
+
supporting = 0
|
| 122 |
+
total_support = 0.0
|
| 123 |
+
for _img_id, P, mask_bool, H, W in view_items:
|
| 124 |
+
uv, z = project_world_to_image(P, endpoints)
|
| 125 |
+
if z[0] <= 0 or z[1] <= 0:
|
| 126 |
+
continue
|
| 127 |
+
if not (
|
| 128 |
+
0 <= uv[0, 0] < W and 0 <= uv[0, 1] < H
|
| 129 |
+
and 0 <= uv[1, 0] < W and 0 <= uv[1, 1] < H
|
| 130 |
+
):
|
| 131 |
+
continue
|
| 132 |
+
|
| 133 |
+
t = np.linspace(0.0, 1.0, sample_steps)
|
| 134 |
+
xs = uv[0, 0] + t * (uv[1, 0] - uv[0, 0])
|
| 135 |
+
ys = uv[0, 1] + t * (uv[1, 1] - uv[0, 1])
|
| 136 |
+
xs_i = np.clip(xs.astype(np.int32), 0, W - 1)
|
| 137 |
+
ys_i = np.clip(ys.astype(np.int32), 0, H - 1)
|
| 138 |
+
|
| 139 |
+
frac = int(mask_bool[ys_i, xs_i].sum()) / float(sample_steps)
|
| 140 |
+
if frac >= min_pixel_frac:
|
| 141 |
+
supporting += 1
|
| 142 |
+
total_support += frac
|
| 143 |
+
|
| 144 |
+
if supporting >= min_views_support:
|
| 145 |
+
# Score: prioritize multi-view agreement, break ties on total support
|
| 146 |
+
scored.append((supporting, total_support, i, j))
|
| 147 |
+
|
| 148 |
+
if not scored:
|
| 149 |
+
return pv, pe
|
| 150 |
+
|
| 151 |
+
# Cap additions: min(absolute cap, rel-fraction of existing edges)
|
| 152 |
+
max_to_add = max(1, min(max_fills_abs,
|
| 153 |
+
int(max_fills_rel * max(len(pe), 1))))
|
| 154 |
+
scored.sort(reverse=True)
|
| 155 |
+
added = [(i, j) for _, _, i, j in scored[:max_to_add]]
|
| 156 |
+
|
| 157 |
+
new_pe = list(pe) + added
|
| 158 |
+
return pv, new_pe
|
| 159 |
+
|
| 160 |
+
except Exception:
|
| 161 |
+
return pv, pe
|
|
@@ -58,13 +58,30 @@ def parse_args():
|
|
| 58 |
help="vertex refine: min views with 2D match")
|
| 59 |
p.add_argument("--refine-max-move", type=float, default=0.5,
|
| 60 |
help="vertex refine: max 3D displacement in meters")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
return p.parse_args()
|
| 62 |
|
| 63 |
|
| 64 |
def predict_one(sample, model, device, cfg, rng,
|
| 65 |
use_tracks=True, use_2d_filter=True, orphan_only=False,
|
| 66 |
strict_no_support=False, vertex_refine=False,
|
| 67 |
-
refine_kwargs=None
|
|
|
|
| 68 |
"""Run the full inference pipeline on one sample. Returns (pv, pe, diag)."""
|
| 69 |
diag = {"colmap": -1, "fused": 0, "track_v": 0, "track_e": 0,
|
| 70 |
"pred_v": 0, "pred_e": 0, "2dfilt_in": 0, "2dfilt_out": 0,
|
|
@@ -79,17 +96,40 @@ def predict_one(sample, model, device, cfg, rng,
|
|
| 79 |
except Exception:
|
| 80 |
pass
|
| 81 |
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
|
| 94 |
if use_tracks:
|
| 95 |
try:
|
|
@@ -144,6 +184,17 @@ def predict_one(sample, model, device, cfg, rng,
|
|
| 144 |
diag["status"] = f"2dfilt_failed:{type(e).__name__}"
|
| 145 |
diag["2dfilt_out"] = len(pred_e) if hasattr(pred_e, '__len__') else 0
|
| 146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
diag["pred_v"] = len(pred_v) if hasattr(pred_v, '__len__') else 0
|
| 148 |
diag["pred_e"] = len(pred_e) if hasattr(pred_e, '__len__') else 0
|
| 149 |
return pred_v, pred_e, diag
|
|
@@ -212,6 +263,14 @@ def main():
|
|
| 212 |
"min_views": args.refine_min_views,
|
| 213 |
"max_move_meters": args.refine_max_move,
|
| 214 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
pred_v, pred_e, diag = predict_one(
|
| 216 |
sample, model, device, cfg, rng,
|
| 217 |
use_tracks=not args.no_tracks,
|
|
@@ -219,7 +278,11 @@ def main():
|
|
| 219 |
orphan_only=args.orphan_only,
|
| 220 |
strict_no_support=args.strict_no_support,
|
| 221 |
vertex_refine=args.vertex_refine,
|
| 222 |
-
refine_kwargs=refine_kwargs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
if torch.backends.mps.is_available():
|
| 224 |
torch.mps.empty_cache()
|
| 225 |
|
|
|
|
| 58 |
help="vertex refine: min views with 2D match")
|
| 59 |
p.add_argument("--refine-max-move", type=float, default=0.5,
|
| 60 |
help="vertex refine: max 3D displacement in meters")
|
| 61 |
+
p.add_argument("--tta", action="store_true",
|
| 62 |
+
help="enable multi-seed TTA (3 priority-sample seeds, concat segments)")
|
| 63 |
+
p.add_argument("--tta-hungarian", action="store_true",
|
| 64 |
+
help="use Hungarian-matched averaging TTA (rejects unmatched segments)")
|
| 65 |
+
p.add_argument("--tta-min-passes", type=int, default=1,
|
| 66 |
+
help="hungarian TTA: drop anchor segments without this many supporting passes")
|
| 67 |
+
p.add_argument("--tta-seeds", type=str, default="2718,31415,42",
|
| 68 |
+
help="comma-separated priority-sample seeds for TTA")
|
| 69 |
+
p.add_argument("--edge-fill", action="store_true",
|
| 70 |
+
help="enable edge filling from 2D mask evidence")
|
| 71 |
+
p.add_argument("--fill-min-views", type=int, default=2,
|
| 72 |
+
help="edge fill: min views supporting a new edge")
|
| 73 |
+
p.add_argument("--fill-min-frac", type=float, default=0.40,
|
| 74 |
+
help="edge fill: min support fraction along projected segment")
|
| 75 |
+
p.add_argument("--fill-max-length", type=float, default=5.0,
|
| 76 |
+
help="edge fill: max edge length in meters")
|
| 77 |
return p.parse_args()
|
| 78 |
|
| 79 |
|
| 80 |
def predict_one(sample, model, device, cfg, rng,
|
| 81 |
use_tracks=True, use_2d_filter=True, orphan_only=False,
|
| 82 |
strict_no_support=False, vertex_refine=False,
|
| 83 |
+
refine_kwargs=None, edge_fill=False, fill_kwargs=None,
|
| 84 |
+
tta=False, tta_seeds=None):
|
| 85 |
"""Run the full inference pipeline on one sample. Returns (pv, pe, diag)."""
|
| 86 |
diag = {"colmap": -1, "fused": 0, "track_v": 0, "track_e": 0,
|
| 87 |
"pred_v": 0, "pred_e": 0, "2dfilt_in": 0, "2dfilt_out": 0,
|
|
|
|
| 96 |
except Exception:
|
| 97 |
pass
|
| 98 |
|
| 99 |
+
if tta:
|
| 100 |
+
try:
|
| 101 |
+
seeds = tta_seeds or (2718, 31415, 42)
|
| 102 |
+
tta_method = (
|
| 103 |
+
"predict_sample_tta_hungarian"
|
| 104 |
+
if getattr(predict_one, "_tta_hungarian", False)
|
| 105 |
+
else "predict_sample_tta"
|
| 106 |
+
)
|
| 107 |
+
import tta as _tta_mod
|
| 108 |
+
fn = getattr(_tta_mod, tta_method)
|
| 109 |
+
if tta_method == "predict_sample_tta_hungarian":
|
| 110 |
+
pred_v, pred_e = fn(
|
| 111 |
+
sample, cfg, model, device, seeds=tuple(seeds),
|
| 112 |
+
min_passes_for_keep=getattr(predict_one, "_tta_min_passes", 1),
|
| 113 |
+
)
|
| 114 |
+
else:
|
| 115 |
+
pred_v, pred_e = fn(
|
| 116 |
+
sample, cfg, model, device, seeds=tuple(seeds))
|
| 117 |
+
diag["fused"] = -1 # not single-seed
|
| 118 |
+
except Exception as e:
|
| 119 |
+
diag["status"] = f"tta_failed:{type(e).__name__}"
|
| 120 |
+
return *script.empty_solution(), diag
|
| 121 |
+
else:
|
| 122 |
+
fused = script.fuse_and_sample(sample, cfg, rng)
|
| 123 |
+
if fused is None:
|
| 124 |
+
diag["status"] = "fuse_failed"
|
| 125 |
+
return *script.empty_solution(), diag
|
| 126 |
+
diag["fused"] = len(fused["xyz_norm"])
|
| 127 |
|
| 128 |
+
try:
|
| 129 |
+
pred_v, pred_e = script.predict_sample(fused, model, device)
|
| 130 |
+
except Exception as e:
|
| 131 |
+
diag["status"] = f"predict_failed:{type(e).__name__}"
|
| 132 |
+
return *script.empty_solution(), diag
|
| 133 |
|
| 134 |
if use_tracks:
|
| 135 |
try:
|
|
|
|
| 184 |
diag["status"] = f"2dfilt_failed:{type(e).__name__}"
|
| 185 |
diag["2dfilt_out"] = len(pred_e) if hasattr(pred_e, '__len__') else 0
|
| 186 |
|
| 187 |
+
if edge_fill:
|
| 188 |
+
e_before = len(pred_e) if hasattr(pred_e, '__len__') else 0
|
| 189 |
+
try:
|
| 190 |
+
from edge_fill import fill_missing_edges_from_2d
|
| 191 |
+
pred_v, pred_e = fill_missing_edges_from_2d(
|
| 192 |
+
pred_v, pred_e, sample,
|
| 193 |
+
**(fill_kwargs or {}))
|
| 194 |
+
diag["filled"] = (len(pred_e) if hasattr(pred_e, '__len__') else 0) - e_before
|
| 195 |
+
except Exception as e:
|
| 196 |
+
diag["status"] = f"fill_failed:{type(e).__name__}"
|
| 197 |
+
|
| 198 |
diag["pred_v"] = len(pred_v) if hasattr(pred_v, '__len__') else 0
|
| 199 |
diag["pred_e"] = len(pred_e) if hasattr(pred_e, '__len__') else 0
|
| 200 |
return pred_v, pred_e, diag
|
|
|
|
| 263 |
"min_views": args.refine_min_views,
|
| 264 |
"max_move_meters": args.refine_max_move,
|
| 265 |
}
|
| 266 |
+
fill_kwargs = {
|
| 267 |
+
"min_views_support": args.fill_min_views,
|
| 268 |
+
"min_pixel_frac": args.fill_min_frac,
|
| 269 |
+
"max_edge_length_meters": args.fill_max_length,
|
| 270 |
+
}
|
| 271 |
+
tta_seeds_tuple = tuple(int(s) for s in args.tta_seeds.split(","))
|
| 272 |
+
predict_one._tta_hungarian = args.tta_hungarian
|
| 273 |
+
predict_one._tta_min_passes = args.tta_min_passes
|
| 274 |
pred_v, pred_e, diag = predict_one(
|
| 275 |
sample, model, device, cfg, rng,
|
| 276 |
use_tracks=not args.no_tracks,
|
|
|
|
| 278 |
orphan_only=args.orphan_only,
|
| 279 |
strict_no_support=args.strict_no_support,
|
| 280 |
vertex_refine=args.vertex_refine,
|
| 281 |
+
refine_kwargs=refine_kwargs,
|
| 282 |
+
edge_fill=args.edge_fill,
|
| 283 |
+
fill_kwargs=fill_kwargs,
|
| 284 |
+
tta=args.tta,
|
| 285 |
+
tta_seeds=tta_seeds_tuple)
|
| 286 |
if torch.backends.mps.is_available():
|
| 287 |
torch.mps.empty_cache()
|
| 288 |
|
|
@@ -59,6 +59,14 @@ CONF_THRESH = 0.4
|
|
| 59 |
MERGE_THRESH = 0.4
|
| 60 |
SNAP_RADIUS = 0.5
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
def fuse_and_sample(sample, cfg, rng):
|
| 64 |
"""Run point fusion + priority sampling on a raw dataset sample.
|
|
@@ -398,21 +406,51 @@ if __name__ == "__main__":
|
|
| 398 |
except Exception:
|
| 399 |
pass
|
| 400 |
|
| 401 |
-
# Fuse + sample
|
| 402 |
-
fused = fuse_and_sample(sample, cfg, rng)
|
| 403 |
-
n_fused_pts = len(fused["xyz_norm"]) if fused is not None else 0
|
| 404 |
track_v_count, track_e_count = 0, 0
|
| 405 |
pred_status = "ok"
|
|
|
|
| 406 |
|
| 407 |
-
if
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
else:
|
| 411 |
try:
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
if torch.cuda.is_available():
|
| 414 |
torch.cuda.empty_cache()
|
| 415 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
# Apply handcrafted triangulation tracking to catch missing corners/edges
|
| 417 |
try:
|
| 418 |
from triangulation import predict_wireframe_tracks
|
|
|
|
| 59 |
MERGE_THRESH = 0.4
|
| 60 |
SNAP_RADIUS = 0.5
|
| 61 |
|
| 62 |
+
# Test-time augmentation: 3 priority-sample seeds + Hungarian matching with
|
| 63 |
+
# strict 3-pass agreement (min_passes_for_keep=2). Local 100-sample A/B:
|
| 64 |
+
# q5 0.062 -> 0.091 (+47%), mean +0.003. Costs 3x inference time but
|
| 65 |
+
# strict filter dramatically improves precision on hard scenes.
|
| 66 |
+
USE_TTA = True
|
| 67 |
+
TTA_SEEDS = (2718, 31415, 42)
|
| 68 |
+
TTA_MIN_PASSES = 2
|
| 69 |
+
|
| 70 |
|
| 71 |
def fuse_and_sample(sample, cfg, rng):
|
| 72 |
"""Run point fusion + priority sampling on a raw dataset sample.
|
|
|
|
| 406 |
except Exception:
|
| 407 |
pass
|
| 408 |
|
|
|
|
|
|
|
|
|
|
| 409 |
track_v_count, track_e_count = 0, 0
|
| 410 |
pred_status = "ok"
|
| 411 |
+
n_fused_pts = 0
|
| 412 |
|
| 413 |
+
if USE_TTA:
|
| 414 |
+
# Multi-seed TTA: fuse + predict 3 times, Hungarian-match segments
|
| 415 |
+
# across passes, drop those without min_passes agreement.
|
|
|
|
| 416 |
try:
|
| 417 |
+
from tta import predict_sample_tta_hungarian
|
| 418 |
+
pred_v, pred_e = predict_sample_tta_hungarian(
|
| 419 |
+
sample, cfg, model, device,
|
| 420 |
+
seeds=TTA_SEEDS,
|
| 421 |
+
min_passes_for_keep=TTA_MIN_PASSES,
|
| 422 |
+
)
|
| 423 |
if torch.cuda.is_available():
|
| 424 |
torch.cuda.empty_cache()
|
| 425 |
+
except Exception as e:
|
| 426 |
+
import traceback
|
| 427 |
+
print(f" TTA failed for {order_id}:\n{traceback.format_exc()}")
|
| 428 |
+
pred_v, pred_e = empty_solution()
|
| 429 |
+
pred_status = "tta_failed"
|
| 430 |
+
if torch.cuda.is_available():
|
| 431 |
+
torch.cuda.empty_cache()
|
| 432 |
+
else:
|
| 433 |
+
# Single-seed inference (legacy path, kept for easy revert).
|
| 434 |
+
fused = fuse_and_sample(sample, cfg, rng)
|
| 435 |
+
n_fused_pts = len(fused["xyz_norm"]) if fused is not None else 0
|
| 436 |
+
if fused is None:
|
| 437 |
+
pred_v, pred_e = empty_solution()
|
| 438 |
+
pred_status = "fuse_failed"
|
| 439 |
+
else:
|
| 440 |
+
try:
|
| 441 |
+
pred_v, pred_e = predict_sample(fused, model, device)
|
| 442 |
+
if torch.cuda.is_available():
|
| 443 |
+
torch.cuda.empty_cache()
|
| 444 |
+
except Exception as e:
|
| 445 |
+
import traceback
|
| 446 |
+
print(f" Predict failed for {order_id}:\n{traceback.format_exc()}")
|
| 447 |
+
pred_v, pred_e = empty_solution()
|
| 448 |
+
pred_status = "predict_failed"
|
| 449 |
+
if torch.cuda.is_available():
|
| 450 |
+
torch.cuda.empty_cache()
|
| 451 |
+
|
| 452 |
+
if pred_status == "ok":
|
| 453 |
+
try:
|
| 454 |
# Apply handcrafted triangulation tracking to catch missing corners/edges
|
| 455 |
try:
|
| 456 |
from triangulation import predict_wireframe_tracks
|
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test-time augmentation via multi-seed priority sampling.
|
| 2 |
+
|
| 3 |
+
Runs the model N times with different priority-sample seeds, concatenates the
|
| 4 |
+
world-space segment predictions from each pass, and lets the standard
|
| 5 |
+
merge_vertices_iterative do the deduplication. Because the iterative merge
|
| 6 |
+
takes union-find clusters and uses each cluster's centroid as the merged
|
| 7 |
+
position, this is effectively Hungarian-averaging for matched segments —
|
| 8 |
+
without needing to solve a real assignment problem.
|
| 9 |
+
|
| 10 |
+
The previous TTA attempt (commit 857514e, reverted) failed because it picked
|
| 11 |
+
ONE pass's output. This implementation aggregates ALL passes' segments and
|
| 12 |
+
lets the established merge logic combine them.
|
| 13 |
+
|
| 14 |
+
Why it should work:
|
| 15 |
+
- Stochastic variation comes from priority-sample seed (the model itself is
|
| 16 |
+
deterministic). Different seeds give different points → slightly different
|
| 17 |
+
model predictions for the same scene.
|
| 18 |
+
- Matched segments (true edges) appear in all passes near each other → they
|
| 19 |
+
cluster in the merge and get averaged toward consensus.
|
| 20 |
+
- Spurious segments (hallucinations) appear in only 1 pass → they survive
|
| 21 |
+
individually but are typically not high-confidence enough to win.
|
| 22 |
+
- The iterative merge thresholds 0.15→0.6 m are appropriate for the typical
|
| 23 |
+
inter-pass jitter of correctly-predicted segments.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
import script
|
| 32 |
+
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
|
| 33 |
+
from s23dr_2026_example.varifold import segments_to_vertices_edges
|
| 34 |
+
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _model_to_world_segments(sample_dict, model, device):
|
| 38 |
+
"""Run the model on a single fused sample, return (N, 2, 3) world segments.
|
| 39 |
+
|
| 40 |
+
Returns None if no segments pass the confidence threshold.
|
| 41 |
+
"""
|
| 42 |
+
tokens, masks = script.build_tokens_single(sample_dict, model, device)
|
| 43 |
+
scale = float(sample_dict["scale"])
|
| 44 |
+
center = sample_dict["center"]
|
| 45 |
+
with torch.no_grad(), torch.autocast(
|
| 46 |
+
device_type='cuda', dtype=torch.float16,
|
| 47 |
+
enabled=(device.type == 'cuda'),
|
| 48 |
+
):
|
| 49 |
+
out = model.forward_tokens(tokens, masks)
|
| 50 |
+
segs = out["segments"][0].float().cpu()
|
| 51 |
+
conf = (
|
| 52 |
+
torch.sigmoid(out["conf"][0].float()).cpu().numpy()
|
| 53 |
+
if "conf" in out else None
|
| 54 |
+
)
|
| 55 |
+
if conf is not None:
|
| 56 |
+
segs = segs[conf > script.CONF_THRESH]
|
| 57 |
+
if len(segs) < 1:
|
| 58 |
+
return None
|
| 59 |
+
return segs.numpy() * scale + center # (N, 2, 3)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _match_segments(anchor_segs, other_segs, max_endpoint_dist=0.4):
|
| 63 |
+
"""Hungarian-match segments between two passes (flip-invariant distance).
|
| 64 |
+
|
| 65 |
+
Returns list of (i_anchor, j_other, was_flipped) for accepted matches.
|
| 66 |
+
Matches with cost > 2*max_endpoint_dist are rejected.
|
| 67 |
+
"""
|
| 68 |
+
if len(anchor_segs) == 0 or len(other_segs) == 0:
|
| 69 |
+
return []
|
| 70 |
+
from scipy.optimize import linear_sum_assignment
|
| 71 |
+
|
| 72 |
+
N, M = len(anchor_segs), len(other_segs)
|
| 73 |
+
# vectorized cost computation
|
| 74 |
+
a0 = anchor_segs[:, 0][:, None, :] # (N, 1, 3)
|
| 75 |
+
a1 = anchor_segs[:, 1][:, None, :]
|
| 76 |
+
b0 = other_segs[None, :, 0] # (1, M, 3)
|
| 77 |
+
b1 = other_segs[None, :, 1]
|
| 78 |
+
d_same = (np.linalg.norm(a0 - b0, axis=-1) +
|
| 79 |
+
np.linalg.norm(a1 - b1, axis=-1))
|
| 80 |
+
d_flip = (np.linalg.norm(a0 - b1, axis=-1) +
|
| 81 |
+
np.linalg.norm(a1 - b0, axis=-1))
|
| 82 |
+
cost = np.minimum(d_same, d_flip)
|
| 83 |
+
flipped = d_flip < d_same
|
| 84 |
+
|
| 85 |
+
row, col = linear_sum_assignment(cost)
|
| 86 |
+
threshold = 2.0 * max_endpoint_dist
|
| 87 |
+
return [(int(i), int(j), bool(flipped[i, j]))
|
| 88 |
+
for i, j in zip(row, col) if cost[i, j] <= threshold]
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def predict_sample_tta_hungarian(sample, cfg, model, device,
|
| 92 |
+
seeds=(2718, 31415, 42),
|
| 93 |
+
match_dist: float = 0.4,
|
| 94 |
+
min_passes_for_keep: int = 1,
|
| 95 |
+
snap_target_classes=(0, 1, 2)):
|
| 96 |
+
"""Hungarian-averaged TTA. Aggregates segments via flip-invariant matching.
|
| 97 |
+
|
| 98 |
+
Pass 0 is the anchor. Pass 1+ segments are matched to anchor via Hungarian
|
| 99 |
+
on endpoint distance. Each anchor segment gets averaged with its matches
|
| 100 |
+
(orientation-aligned). Anchor segments with < min_passes_for_keep matches
|
| 101 |
+
are kept only if min_passes_for_keep == 0; otherwise dropped.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
min_passes_for_keep: 0 = keep all anchor segments, 1 = require matching
|
| 105 |
+
in at least 1 other pass.
|
| 106 |
+
"""
|
| 107 |
+
sample_dicts = []
|
| 108 |
+
all_segs = []
|
| 109 |
+
for seed in seeds:
|
| 110 |
+
rng = np.random.RandomState(int(seed))
|
| 111 |
+
sd = script.fuse_and_sample(sample, cfg, rng)
|
| 112 |
+
if sd is None:
|
| 113 |
+
continue
|
| 114 |
+
sw = _model_to_world_segments(sd, model, device)
|
| 115 |
+
if sw is None or len(sw) == 0:
|
| 116 |
+
continue
|
| 117 |
+
sample_dicts.append(sd)
|
| 118 |
+
all_segs.append(sw)
|
| 119 |
+
|
| 120 |
+
if not all_segs:
|
| 121 |
+
return script.empty_solution()
|
| 122 |
+
|
| 123 |
+
if len(all_segs) == 1:
|
| 124 |
+
# No TTA gain possible; just run the normal post-process.
|
| 125 |
+
return _post_segments_to_wireframe(
|
| 126 |
+
all_segs[0], sample_dicts[0], snap_target_classes)
|
| 127 |
+
|
| 128 |
+
anchor = all_segs[0]
|
| 129 |
+
matches_per_anchor = [[] for _ in range(len(anchor))] # list of (other_seg, flipped)
|
| 130 |
+
for p_idx in range(1, len(all_segs)):
|
| 131 |
+
for i_a, j_o, flipped in _match_segments(anchor, all_segs[p_idx],
|
| 132 |
+
max_endpoint_dist=match_dist):
|
| 133 |
+
matches_per_anchor[i_a].append((all_segs[p_idx][j_o], flipped))
|
| 134 |
+
|
| 135 |
+
averaged = []
|
| 136 |
+
for i, matches in enumerate(matches_per_anchor):
|
| 137 |
+
if len(matches) < min_passes_for_keep:
|
| 138 |
+
continue
|
| 139 |
+
seg = anchor[i]
|
| 140 |
+
if not matches:
|
| 141 |
+
averaged.append(seg)
|
| 142 |
+
continue
|
| 143 |
+
# Align orientations to anchor and average
|
| 144 |
+
aligned = [seg]
|
| 145 |
+
for other_seg, flipped in matches:
|
| 146 |
+
aligned.append(other_seg[::-1] if flipped else other_seg)
|
| 147 |
+
averaged.append(np.mean(aligned, axis=0))
|
| 148 |
+
|
| 149 |
+
if not averaged:
|
| 150 |
+
# Defensive: fall back to concat-and-merge if matching dropped everything
|
| 151 |
+
return _post_segments_to_wireframe(
|
| 152 |
+
np.concatenate(all_segs, axis=0), sample_dicts[0],
|
| 153 |
+
snap_target_classes)
|
| 154 |
+
|
| 155 |
+
return _post_segments_to_wireframe(
|
| 156 |
+
np.asarray(averaged), sample_dicts[0], snap_target_classes)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _post_segments_to_wireframe(segments, sd0, snap_target_classes):
|
| 160 |
+
"""Standard post-process: segments -> vertices/edges -> merge -> snap."""
|
| 161 |
+
pv, pe = segments_to_vertices_edges(torch.tensor(segments))
|
| 162 |
+
pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
|
| 163 |
+
pv, pe = merge_vertices_iterative(pv, pe)
|
| 164 |
+
|
| 165 |
+
xyz_norm = sd0["xyz_norm"]
|
| 166 |
+
mask = sd0["mask"]
|
| 167 |
+
cid = sd0["class_id"]
|
| 168 |
+
xyz_world = xyz_norm[mask] * float(sd0["scale"]) + sd0["center"]
|
| 169 |
+
cid_valid = cid[mask]
|
| 170 |
+
pv = snap_to_point_cloud(
|
| 171 |
+
pv, xyz_world, cid_valid,
|
| 172 |
+
snap_radius=script.SNAP_RADIUS,
|
| 173 |
+
target_classes=list(snap_target_classes),
|
| 174 |
+
)
|
| 175 |
+
pv = snap_horizontal(pv, pe)
|
| 176 |
+
|
| 177 |
+
if len(pv) < 2 or len(pe) < 1:
|
| 178 |
+
return script.empty_solution()
|
| 179 |
+
return pv, [(int(a), int(b)) for a, b in pe]
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def predict_sample_tta(sample, cfg, model, device,
|
| 183 |
+
seeds=(2718, 31415, 42),
|
| 184 |
+
snap_target_classes=(0, 1, 2)):
|
| 185 |
+
"""Multi-seed TTA prediction. Returns (vertices, edges) in world space.
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
sample: raw dataset entry.
|
| 189 |
+
cfg: FuserConfig.
|
| 190 |
+
model: loaded model.
|
| 191 |
+
device: torch device.
|
| 192 |
+
seeds: tuple of priority-sample seeds. Length = # TTA passes.
|
| 193 |
+
snap_target_classes: classes for snap_to_point_cloud (default
|
| 194 |
+
[apex, eave_end_point, flashing_end_point] = [0, 1, 2]).
|
| 195 |
+
"""
|
| 196 |
+
# 1) Run fuse + model for each seed, collect world-space segments.
|
| 197 |
+
sample_dicts = []
|
| 198 |
+
all_segs = []
|
| 199 |
+
for seed in seeds:
|
| 200 |
+
rng = np.random.RandomState(int(seed))
|
| 201 |
+
sd = script.fuse_and_sample(sample, cfg, rng)
|
| 202 |
+
if sd is None:
|
| 203 |
+
continue
|
| 204 |
+
sw = _model_to_world_segments(sd, model, device)
|
| 205 |
+
if sw is None or len(sw) == 0:
|
| 206 |
+
continue
|
| 207 |
+
sample_dicts.append(sd)
|
| 208 |
+
all_segs.append(sw)
|
| 209 |
+
|
| 210 |
+
if not all_segs:
|
| 211 |
+
return script.empty_solution()
|
| 212 |
+
|
| 213 |
+
# 2) Concatenate segments across passes. The downstream merge will cluster
|
| 214 |
+
# near-duplicate vertices (matched across passes) and take centroids,
|
| 215 |
+
# yielding the Hungarian-average behavior.
|
| 216 |
+
combined = np.concatenate(all_segs, axis=0)
|
| 217 |
+
|
| 218 |
+
# 3) Standard post-process: segments -> vertices/edges, iterative merge.
|
| 219 |
+
pv, pe = segments_to_vertices_edges(torch.tensor(combined))
|
| 220 |
+
pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
|
| 221 |
+
pv, pe = merge_vertices_iterative(pv, pe)
|
| 222 |
+
|
| 223 |
+
# 4) Snap to point cloud. Use the FIRST sample_dict's context (the
|
| 224 |
+
# fused points are roughly similar across seeds so this is a reasonable
|
| 225 |
+
# proxy; using merged xyz would require re-fusing).
|
| 226 |
+
sd0 = sample_dicts[0]
|
| 227 |
+
xyz_norm = sd0["xyz_norm"]
|
| 228 |
+
mask = sd0["mask"]
|
| 229 |
+
cid = sd0["class_id"]
|
| 230 |
+
scale0 = float(sd0["scale"])
|
| 231 |
+
center0 = sd0["center"]
|
| 232 |
+
xyz_world = xyz_norm[mask] * scale0 + center0
|
| 233 |
+
cid_valid = cid[mask]
|
| 234 |
+
pv = snap_to_point_cloud(
|
| 235 |
+
pv, xyz_world, cid_valid,
|
| 236 |
+
snap_radius=script.SNAP_RADIUS,
|
| 237 |
+
target_classes=list(snap_target_classes),
|
| 238 |
+
)
|
| 239 |
+
pv = snap_horizontal(pv, pe)
|
| 240 |
+
|
| 241 |
+
if len(pv) < 2 or len(pe) < 1:
|
| 242 |
+
return script.empty_solution()
|
| 243 |
+
|
| 244 |
+
edges = [(int(a), int(b)) for a, b in pe]
|
| 245 |
+
return pv, edges
|