xsponenta commited on
Commit
857514e
·
1 Parent(s): 8e33a89

TTA optimization: cache point fusion, vary only priority sampling

Browse files

The previous TTA commit (8e33a89) timed out because it called
build_compact_scene 3x per sample (the expensive multi-view label
voting step). Refactor:

- compute_scene(sample, cfg, rng): does build_compact_scene + group/class/
center/scale computation. Called ONCE per sample.
- sample_from_scene(scene): does priority sampling + result-dict assembly.
Cheap, called K=3 times per sample.
- fuse_and_sample is preserved as a backward-compat wrapper.

Why this still gives TTA variation: _priority_sample uses the *global*
numpy random state (np.random.shuffle), not an explicit rng arg. Each
consecutive call advances the global state and produces a different
4096-point subset of the same fused scene. The model sees different
inputs across passes despite the scene being identical.

Cost: ~10% extra wall time vs single pass (3x cheap priority sampling
+ 3x cheap model forward), instead of ~200% from the previous commit.
Should fit comfortably in the 2h budget.

Files changed (1) hide show
  1. script.py +59 -30
script.py CHANGED
@@ -60,11 +60,11 @@ MERGE_THRESH = 0.4
60
  SNAP_RADIUS = 0.5
61
 
62
 
63
- def fuse_and_sample(sample, cfg, rng):
64
- """Run point fusion + priority sampling on a raw dataset sample.
65
 
66
- Returns a dict with xyz_norm, class_id, source, mask, center, scale, etc.
67
- ready for model inference. Returns None if fusion fails.
68
  """
69
  try:
70
  scene = build_compact_scene(sample, cfg, rng)
@@ -74,21 +74,43 @@ def fuse_and_sample(sample, cfg, rng):
74
 
75
  xyz = scene["xyz"]
76
  source = scene["source"]
77
-
78
  if len(xyz) < 10:
79
  return None
80
 
81
- # Compute group_id and class_id (same as cache_scenes.py)
82
  behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
83
  group_id, class_id = _compute_group_and_class(
84
  scene["visible_src"], scene["visible_id"], behind_id, source)
85
-
86
- # Normalize
87
  center, scale = _compute_smart_center_scale(xyz, source)
88
 
89
- # Priority sample
90
- indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
92
  xyz_norm = (xyz[indices] - center) / scale
93
 
94
  result = {
@@ -99,23 +121,26 @@ def fuse_and_sample(sample, cfg, rng):
99
  "center": center.astype(np.float32),
100
  "scale": np.float32(scale),
101
  }
102
-
103
- # Optional fields
104
- if "behind_gest_id" in scene:
105
  behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
106
  result["behind"] = behind.astype(np.int64)
107
- if "n_views_voted" in scene:
108
  result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
109
- if "vote_frac" in scene:
110
  result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)
111
-
112
- # Visible src/id for snap post-processing
113
  result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
114
  result["visible_id"] = scene["visible_id"][indices].astype(np.int64)
115
-
116
  return result
117
 
118
 
 
 
 
 
 
 
 
 
119
  def load_model(checkpoint_path, device):
120
  """Load model from checkpoint."""
121
  ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
@@ -455,19 +480,23 @@ if __name__ == "__main__":
455
  order_id = sample["order_id"]
456
 
457
  try:
458
- # ---- TTA: run the learned pipeline K times, union outputs
 
 
 
 
 
459
  tta_outputs = []
460
- for k in range(TTA_PASSES):
461
- rng_k = np.random.RandomState(TTA_BASE_SEED + k * 1000)
462
- fused_k = fuse_and_sample(sample, cfg, rng_k)
463
- if fused_k is None:
464
- continue
465
- try:
466
- pv_k, pe_k = predict_sample(fused_k, model, device)
467
- if isinstance(pv_k, np.ndarray) and len(pv_k) >= 2 and len(pe_k) >= 1:
468
- tta_outputs.append((pv_k, pe_k))
469
- except Exception as tta_e:
470
- print(f" TTA pass {k} failed for {order_id}: {tta_e}")
471
  if torch.cuda.is_available():
472
  torch.cuda.empty_cache()
473
 
 
60
  SNAP_RADIUS = 0.5
61
 
62
 
63
+ def compute_scene(sample, cfg, rng):
64
+ """Expensive: multi-view label voting + smart normalization. Call once per sample.
65
 
66
+ Returns a dict with the full pre-priority-sampling fused scene, ready to
67
+ feed into ``sample_from_scene`` repeatedly for TTA. Returns None on failure.
68
  """
69
  try:
70
  scene = build_compact_scene(sample, cfg, rng)
 
74
 
75
  xyz = scene["xyz"]
76
  source = scene["source"]
 
77
  if len(xyz) < 10:
78
  return None
79
 
 
80
  behind_id = scene.get("behind_gest_id", np.full(len(xyz), -1, dtype=np.int16))
81
  group_id, class_id = _compute_group_and_class(
82
  scene["visible_src"], scene["visible_id"], behind_id, source)
 
 
83
  center, scale = _compute_smart_center_scale(xyz, source)
84
 
85
+ return {
86
+ "xyz": xyz,
87
+ "source": source,
88
+ "group_id": group_id,
89
+ "class_id": class_id,
90
+ "center": center,
91
+ "scale": scale,
92
+ "behind_gest_id": scene.get("behind_gest_id"),
93
+ "n_views_voted": scene.get("n_views_voted"),
94
+ "vote_frac": scene.get("vote_frac"),
95
+ "visible_src": scene["visible_src"],
96
+ "visible_id": scene["visible_id"],
97
+ }
98
+
99
+
100
+ def sample_from_scene(scene):
101
+ """Cheap: priority-sample 4096 points from a fused scene.
102
+
103
+ Uses the global numpy random state (advanced internally by ``_priority_sample``),
104
+ so consecutive calls yield different 4096-subsets — perfect for TTA.
105
+ """
106
+ xyz = scene["xyz"]
107
+ source = scene["source"]
108
+ group_id = scene["group_id"]
109
+ class_id = scene["class_id"]
110
+ center = scene["center"]
111
+ scale = scene["scale"]
112
 
113
+ indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)
114
  xyz_norm = (xyz[indices] - center) / scale
115
 
116
  result = {
 
121
  "center": center.astype(np.float32),
122
  "scale": np.float32(scale),
123
  }
124
+ if scene.get("behind_gest_id") is not None:
 
 
125
  behind = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
126
  result["behind"] = behind.astype(np.int64)
127
+ if scene.get("n_views_voted") is not None:
128
  result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
129
+ if scene.get("vote_frac") is not None:
130
  result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)
 
 
131
  result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
132
  result["visible_id"] = scene["visible_id"][indices].astype(np.int64)
 
133
  return result
134
 
135
 
136
+ def fuse_and_sample(sample, cfg, rng):
137
+ """Backward-compatible wrapper: compute scene + one priority sample."""
138
+ scene = compute_scene(sample, cfg, rng)
139
+ if scene is None:
140
+ return None
141
+ return sample_from_scene(scene)
142
+
143
+
144
  def load_model(checkpoint_path, device):
145
  """Load model from checkpoint."""
146
  ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
 
480
  order_id = sample["order_id"]
481
 
482
  try:
483
+ # ---- Build the fused scene ONCE (the expensive multi-view
484
+ # label voting); then run priority sampling + model K times
485
+ # for TTA. _priority_sample uses the global numpy RNG which
486
+ # advances on each call, giving genuine variation cheaply.
487
+ scene_rng = np.random.RandomState(TTA_BASE_SEED)
488
+ scene = compute_scene(sample, cfg, scene_rng)
489
  tta_outputs = []
490
+ if scene is not None:
491
+ np.random.seed(TTA_BASE_SEED) # reset global RNG for reproducibility
492
+ for k in range(TTA_PASSES):
493
+ try:
494
+ fused_k = sample_from_scene(scene)
495
+ pv_k, pe_k = predict_sample(fused_k, model, device)
496
+ if isinstance(pv_k, np.ndarray) and len(pv_k) >= 2 and len(pe_k) >= 1:
497
+ tta_outputs.append((pv_k, pe_k))
498
+ except Exception as tta_e:
499
+ print(f" TTA pass {k} failed for {order_id}: {tta_e}")
 
500
  if torch.cuda.is_available():
501
  torch.cuda.empty_cache()
502