Add S23DR 2026 learned baseline (mirrored from usm3d/learned-baseline-2026, repro_runs/submitted_2048 omitted)
e87ec4d verified | """S23DR 2026 submission: learned wireframe prediction from fused point clouds. | |
| Pipeline: raw sample -> point fusion -> priority sample 2048 -> model -> post-process -> wireframe | |
| """ | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import json | |
| import os | |
| import sys | |
| import time | |
| import numpy as np | |
| import torch | |
def empty_solution():
    """Return a minimal placeholder wireframe.

    Two coincident vertices at the origin joined by one edge — used as a
    fallback whenever fusion or prediction fails for a sample.
    """
    vertices = np.zeros((2, 3))
    edges = [(0, 1)]
    return vertices, edges
# ---------------------------------------------------------------------------
# Point fusion + sampling (from cache_scenes.py / make_sampled_cache.py)
# ---------------------------------------------------------------------------
# Add our package to path so the bundled s23dr_2026_example package imports
# regardless of the working directory this script is launched from.
SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(SCRIPT_DIR))
from s23dr_2026_example.point_fusion import build_compact_scene, FuserConfig
from s23dr_2026_example.cache_scenes import (
    _compute_group_and_class, _compute_smart_center_scale,
)
from s23dr_2026_example.make_sampled_cache import _priority_sample
# Tokenizer / model imports
from s23dr_2026_example.tokenizer import EdgeDepthSequenceConfig
from s23dr_2026_example.model import EdgeDepthSegmentsModel
from s23dr_2026_example.segment_postprocess import merge_vertices_iterative
from s23dr_2026_example.varifold import segments_to_vertices_edges
from s23dr_2026_example.postprocess_v2 import snap_to_point_cloud, snap_horizontal

# Inference-time constants. SEQ_LEN / COLMAP_QUOTA / DEPTH_QUOTA feed both
# _priority_sample and EdgeDepthSequenceConfig, so they must match the
# sequence layout the checkpoint was trained with.
SEQ_LEN = 4096        # total points per sample after priority sampling
COLMAP_QUOTA = 3072   # passed as colmap_points / colmap quota to the sampler
DEPTH_QUOTA = 1024    # passed as depth_points / depth quota to the sampler
CONF_THRESH = 0.5     # sigmoid confidence cutoff applied in predict_sample
MERGE_THRESH = 0.4    # NOTE(review): defined but never used below — confirm
                      # merge_vertices_iterative's default matches this value
SNAP_RADIUS = 0.5     # snap_radius passed to snap_to_point_cloud
def fuse_and_sample(sample, cfg, rng):
    """Fuse a raw dataset sample into a normalized, priority-sampled point set.

    Runs compact-scene point fusion, derives per-point group/class ids,
    normalizes coordinates, and subsamples to the fixed SEQ_LEN budget.

    Returns a dict with xyz_norm, class_id, source, mask, center, scale
    (plus optional behind / vote features and visible src/id for snapping),
    ready for model inference. Returns None if fusion fails or yields
    too few points.
    """
    try:
        scene = build_compact_scene(sample, cfg, rng)
    except Exception as e:  # best-effort: caller falls back to empty_solution()
        print(f" Fusion failed: {e}")
        return None

    xyz, source = scene["xyz"], scene["source"]
    if len(xyz) < 10:  # degenerate fusion output, not worth predicting on
        return None

    # Per-point group/class labels (same derivation as cache_scenes.py).
    if "behind_gest_id" in scene:
        behind_id = scene["behind_gest_id"]
    else:
        behind_id = np.full(len(xyz), -1, dtype=np.int16)
    group_id, class_id = _compute_group_and_class(
        scene["visible_src"], scene["visible_id"], behind_id, source)

    # Normalization followed by fixed-budget priority subsampling.
    center, scale = _compute_smart_center_scale(xyz, source)
    indices, mask = _priority_sample(source, group_id, SEQ_LEN, COLMAP_QUOTA, DEPTH_QUOTA)

    result = {
        "xyz_norm": ((xyz[indices] - center) / scale).astype(np.float32),
        "class_id": class_id[indices].astype(np.int64),
        "source": source[indices].astype(np.int64),
        "mask": mask,
        "center": center.astype(np.float32),
        "scale": np.float32(scale),
    }

    # Optional per-point features (presence depends on the fusion output).
    if "behind_gest_id" in scene:
        clipped = np.clip(scene["behind_gest_id"][indices].astype(np.int16), 0, None)
        result["behind"] = clipped.astype(np.int64)
    if "n_views_voted" in scene:
        result["n_views_voted"] = scene["n_views_voted"][indices].astype(np.float32)
    if "vote_frac" in scene:
        result["vote_frac"] = scene["vote_frac"][indices].astype(np.float32)

    # Visible src/id for snap post-processing
    result["visible_src"] = scene["visible_src"][indices].astype(np.int64)
    result["visible_id"] = scene["visible_id"][indices].astype(np.int64)
    return result
def load_model(checkpoint_path, device):
    """Restore an EdgeDepthSegmentsModel from a training checkpoint.

    Hyper-parameters are read from the checkpoint's saved args dict, with
    the same defaults used at training time. State-dict keys are rewritten
    to strip the `_orig_mod.` prefix that torch.compile inserts, so the
    uncompiled model can load them strictly. Returns the model in eval mode.
    """
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
    args = ckpt.get("args", {})

    seq_cfg = EdgeDepthSequenceConfig(
        seq_len=SEQ_LEN, colmap_points=COLMAP_QUOTA, depth_points=DEPTH_QUOTA)

    # Model hyper-parameters: checkpoint value if present, else training default.
    hparams = {
        "segments": args.get("segments", 64),
        "hidden": args.get("hidden", 256),
        "num_heads": args.get("num_heads", 4),
        "kv_heads_cross": args.get("kv_heads_cross", 2),
        "kv_heads_self": args.get("kv_heads_self", 2),
        "dim_feedforward": args.get("ff", 1024),
        "dropout": args.get("dropout", 0.1),
        "latent_tokens": args.get("latent_tokens", 256),
        "latent_layers": args.get("latent_layers", 7),
        "decoder_layers": args.get("decoder_layers", 3),
        "cross_attn_interval": args.get("cross_attn_interval", 4),
        "activation": args.get("activation", "gelu"),
        "segment_conf": args.get("segment_conf", True),
        "behind_emb_dim": args.get("behind_emb_dim", 8),
        "use_vote_features": args.get("vote_features", True),
        "arch": args.get("arch", "perceiver"),
        "encoder_layers": args.get("encoder_layers", 4),
        "pre_encoder_layers": args.get("pre_encoder_layers", 0),
        "segment_param": args.get("segment_param", "midpoint_dir_len"),
        "qk_norm": args.get("qk_norm", True),
    }
    norm_class = torch.nn.RMSNorm if args.get("rms_norm") else None
    model = EdgeDepthSegmentsModel(
        seq_cfg=seq_cfg, norm_class=norm_class, **hparams).to(device)

    # Handle the torch.compile `_orig_mod.` prefix on the segmenter submodule.
    fixed = {key.replace("segmenter._orig_mod.", "segmenter."): value
             for key, value in ckpt["model"].items()}
    model.load_state_dict(fixed, strict=True)
    model.eval()
    return model
def build_tokens_single(sample_dict, model, device):
    """Assemble the model's input token tensor for one fused sample.

    Mirrors the training-time tokenizer without a DataLoader: concatenates
    xyz, Fourier positional features, class/source embeddings, and — when
    enabled on the tokenizer — behind embeddings and normalized vote
    statistics along the feature dim. Returns (tokens, masks) with batch=1.
    """
    def _batched(key, dtype):
        # Lift a per-point array to a (1, T, ...) tensor on the target device.
        return torch.as_tensor(sample_dict[key], dtype=dtype).unsqueeze(0).to(device)

    xyz = _batched("xyz_norm", torch.float32)
    cid = _batched("class_id", torch.long)
    src = _batched("source", torch.long)
    masks = _batched("mask", torch.bool)

    B, T, _ = xyz.shape
    tok = model.tokenizer
    if tok.pos_enc is not None:
        fourier = tok.pos_enc(xyz.reshape(-1, 3)).reshape(B, T, -1)
    else:
        fourier = xyz.new_zeros(B, T, 0)

    # Feature order must match training: xyz, fourier, class emb, source emb,
    # then the optional behind / vote features.
    parts = [xyz, fourier, tok.label_emb(cid), tok.src_emb(src.clamp(0, 1))]

    if tok.behind_emb_dim > 0:
        if "behind" in sample_dict:
            beh = _batched("behind", torch.long)
        else:
            beh = xyz.new_zeros(B, T, dtype=torch.long)
        parts.append(tok.behind_emb(beh))

    if tok.use_vote_features:
        if "n_views_voted" in sample_dict and "vote_frac" in sample_dict:
            # Same normalization constants as used at training time.
            nv = ((_batched("n_views_voted", torch.float32) - 2.7) / 1.0).unsqueeze(-1)
            vf = ((_batched("vote_frac", torch.float32) - 0.5) / 0.25).unsqueeze(-1)
            parts.extend([nv, vf])
        else:
            parts.extend([xyz.new_zeros(B, T, 1), xyz.new_zeros(B, T, 1)])

    return torch.cat(parts, dim=-1), masks
def predict_sample(sample_dict, model, device):
    """Run model inference + post-processing on a fused sample.
    Returns (vertices, edges) in world space.

    Pipeline: tokenize -> forward pass (fp16 autocast on CUDA) ->
    confidence filter -> de-normalize -> segments-to-graph -> vertex merge
    -> snap to point cloud -> horizontal snap. Falls back to
    empty_solution() when nothing survives filtering/merging.
    """
    tokens, masks = build_tokens_single(sample_dict, model, device)
    scale = float(sample_dict["scale"])
    center = sample_dict["center"]
    # Autocast is disabled on CPU (enabled=False makes the context a no-op).
    with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.float16,
                                         enabled=(device.type == 'cuda')):
        out = model.forward_tokens(tokens, masks)
    segs = out["segments"][0].float().cpu()
    # Confidence head is optional on the model output; None skips filtering.
    conf = torch.sigmoid(out["conf"][0].float()).cpu().numpy() if "conf" in out else None
    # Confidence filter
    if conf is not None:
        keep = conf > CONF_THRESH
        segs = segs[keep]
    if len(segs) < 1:
        return empty_solution()
    # To world space (undo the center/scale normalization from fusion)
    segs_world = segs.numpy() * scale + center
    # Vertices + edges from segments
    pv, pe = segments_to_vertices_edges(torch.tensor(segs_world))
    pv, pe = pv.numpy(), np.array(pe, dtype=np.int32)
    # Merge duplicate/near vertices.
    # NOTE(review): MERGE_THRESH (0.4) is defined at module level but not
    # passed here — confirm merge_vertices_iterative's default matches it.
    pv, pe = merge_vertices_iterative(pv, pe)
    # Snap to point cloud: only the valid (masked) points, back in world space
    xyz_norm = sample_dict["xyz_norm"]
    mask = sample_dict["mask"]
    cid = sample_dict["class_id"]
    xyz_world = xyz_norm[mask] * scale + center
    cid_valid = cid[mask]
    pv = snap_to_point_cloud(pv, xyz_world, cid_valid, snap_radius=SNAP_RADIUS)
    # Horizontal snap (presumably levels near-horizontal edges — see
    # postprocess_v2.snap_horizontal for the exact behavior)
    pv = snap_horizontal(pv, pe)
    if len(pv) < 2 or len(pe) < 1:
        return empty_solution()
    edges = [(int(a), int(b)) for a, b in pe]
    return pv, edges
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    t_start = time.time()

    # Load params (params.json is read from the CWD — presumably provided by
    # the competition harness; verify against the submission environment)
    param_path = Path("params.json")
    with param_path.open() as f:
        params = json.load(f)
    print(f"Competition: {params.get('competition_id', '?')}")
    print(f"Dataset: {params.get('dataset', '?')}")

    # Load test data: download the dataset snapshot once, then build the
    # validation/test splits from the public/private tar shards.
    data_path = Path("/tmp/data")
    if not data_path.exists():
        from huggingface_hub import snapshot_download
        snapshot_download(
            repo_id=params["dataset"],
            local_dir="/tmp/data",
            repo_type="dataset",
        )
    from datasets import load_dataset
    data_files = {
        "validation": [str(p) for p in data_path.rglob("*public*/**/*.tar")],
        "test": [str(p) for p in data_path.rglob("*private*/**/*.tar")],
    }
    print(f"Data files: {data_files}")
    dataset = load_dataset(
        str(data_path / "hoho22k_2026_test_x_anon.py"),
        data_files=data_files,
        trust_remote_code=True,
        writer_batch_size=100,
    )
    print(f"Loaded: {dataset}")

    # Load model (checkpoint ships alongside this script)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")
    checkpoint_path = SCRIPT_DIR / "checkpoint.pt"
    model = load_model(checkpoint_path, device)
    print(f"Model loaded: {sum(p.numel() for p in model.parameters()):,} params")

    # Point fusion config; fixed seed keeps the sampling reproducible
    cfg = FuserConfig()
    rng = np.random.RandomState(2718)

    # Process all samples. Any per-sample failure falls back to
    # empty_solution() so the submission always has one entry per order_id.
    solution = []
    total_samples = sum(len(dataset[s]) for s in dataset)
    processed = 0
    for subset_name in dataset:
        print(f"\nProcessing {subset_name} ({len(dataset[subset_name])} samples)...")
        for sample in tqdm(dataset[subset_name], desc=subset_name):
            order_id = sample["order_id"]
            # Fuse + sample
            fused = fuse_and_sample(sample, cfg, rng)
            if fused is None:
                pred_v, pred_e = empty_solution()
            else:
                try:
                    pred_v, pred_e = predict_sample(fused, model, device)
                except Exception as e:
                    print(f" Predict failed for {order_id}: {e}")
                    pred_v, pred_e = empty_solution()
            solution.append({
                "order_id": order_id,
                "wf_vertices": pred_v.tolist() if isinstance(pred_v, np.ndarray) else pred_v,
                "wf_edges": [(int(a), int(b)) for a, b in pred_e],
            })
            processed += 1
            # Coarse ETA print every 50 samples (in addition to tqdm)
            if processed % 50 == 0:
                elapsed = time.time() - t_start
                rate = elapsed / processed
                remaining = (total_samples - processed) * rate
                print(f" [{processed}/{total_samples}] "
                      f"{elapsed:.0f}s elapsed, ~{remaining:.0f}s remaining")

    # Save
    with open("submission.json", "w") as f:
        json.dump(solution, f)
    elapsed = time.time() - t_start
    print(f"\nDone. {processed} samples in {elapsed:.0f}s ({elapsed/max(processed,1):.1f}s/sample)")
    print(f"Saved submission.json ({len(solution)} entries)")