| """V4 edge classifier with DINOv2-pretrained patch features. |
| |
| DINOv2-small (22M params, frozen) encodes each gestalt view to a 16x16 grid |
| of 384-dim semantic features. For each predicted edge, we bilinearly sample |
| features at the projected edge midpoint in each view, then mean+max pool |
| across views to get a 768-dim feature per edge. Concatenated with v2's |
| 40-dim geometric+mask features = 808-dim input to MLP. |
| |
| DINOv2 patch size is 14, input 224x224 → 16x16 patches. We feed gestalt RGB |
| resized to 224x224. |
| |
| Falls back to no-op on any error. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from edge_classifier_v2 import ( |
| NUM_FEATURES as V2_NUM, extract_features_v2, |
| ) |
|
|
# Geometry of the DINOv2-small backbone as used by this module.
DINO_FEAT_DIM = 384  # channels of each patch token
DINO_PATCH_SIZE = 14  # side of one patch, in input pixels
DINO_INPUT = 224  # side the gestalt image is resized to before encoding
DINO_GRID = DINO_INPUT // DINO_PATCH_SIZE  # patches per side (224 // 14 == 16)
# Per-edge descriptor = mean-pooled half followed by max-pooled half.
EDGE_FEAT_DIM = 2 * DINO_FEAT_DIM
|
|
|
|
def get_dino_model(device="cpu"):
    """Return the frozen DINOv2-small backbone placed on ``device``.

    Loaded models are memoised in the ``_cache`` dict stored on this function,
    keyed by ``str(device)``, so ``torch.hub.load`` runs at most once per
    distinct device string. The returned module is in eval mode and none of
    its parameters require gradients.
    """
    try:
        registry = get_dino_model._cache
    except AttributeError:
        registry = get_dino_model._cache = {}

    key = str(device)
    if key in registry:
        return registry[key]

    backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14', verbose=False)
    backbone = backbone.to(device).eval()
    for param in backbone.parameters():
        param.requires_grad = False
    registry[key] = backbone
    return registry[key]
|
|
|
|
class EdgeClassifierV4(nn.Module):
    """MLP head scoring edges from v2 geometric features plus pooled DINOv2 features.

    The two feature groups are concatenated along dim 1 and passed through a
    GELU MLP with dropout 0.2 after the first two hidden layers; the output is
    one logit per edge.
    """

    def __init__(self, geom_dim: int = V2_NUM, edge_feat_dim: int = EDGE_FEAT_DIM,
                 hidden: int = 128):
        super().__init__()
        half = hidden // 2
        # The position of each layer fixes the ``net.<idx>`` state-dict keys,
        # so the order below must not change.
        layers = [
            nn.Linear(geom_dim + edge_feat_dim, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(hidden, half),
            nn.GELU(),
            nn.Linear(half, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, geom_feats, edge_feats):
        """Return the logits for a batch; the trailing size-1 dim is squeezed away."""
        fused = torch.cat((geom_feats, edge_feats), dim=1)
        logits = self.net(fused)
        return logits.squeeze(-1)
|
|
|
|
@torch.no_grad()
def _encode_views_with_dino(good, views, dino, device):
    """Encode the gestalt image of every view listed in ``views`` with DINOv2.

    Returns a tuple ``(features, Hs, Ws)``:

    * ``features`` maps img_id to a ``(DINO_GRID, DINO_GRID, DINO_FEAT_DIM)``
      NumPy array of patch tokens,
    * ``Hs`` / ``Ws`` map img_id to the height / width of that view's depth
      image.

    As a side effect the same sizes are written to ``views[img_id]["H"]`` and
    ``views[img_id]["W"]``. Entries of ``good`` whose id is not in ``views``
    are skipped; if nothing is left, all three dicts are returned empty.
    """
    # Per-channel mean / std applied after scaling to [0, 1] (these are the
    # standard ImageNet normalisation values).
    chan_mean = np.array([0.485, 0.456, 0.406])
    chan_std = np.array([0.229, 0.224, 0.225])
    target_size = (DINO_INPUT, DINO_INPUT)

    features, heights, widths = {}, {}, {}
    planes, kept_ids = [], []

    triples = zip(good["gestalt"], good["depth"], good["image_ids"])
    for gestalt_img, depth_img, view_id in triples:
        if view_id not in views:
            continue
        # The depth image defines the pixel frame projections are tested against.
        height, width = np.array(depth_img).shape[:2]

        rgb = np.array(gestalt_img.resize(target_size)).astype(np.float32) / 255.0
        if rgb.ndim == 2:
            # Single-channel image: replicate it into three channels.
            rgb = np.stack([rgb, rgb, rgb], axis=-1)
        else:
            # Drop any channel beyond the first three (e.g. alpha).
            rgb = rgb[..., :3]
        rgb = (rgb - chan_mean) / chan_std

        planes.append(rgb.transpose(2, 0, 1))  # HWC -> CHW
        kept_ids.append(view_id)
        heights[view_id] = height
        widths[view_id] = width
        views[view_id]["H"] = height
        views[view_id]["W"] = width

    if not planes:
        return features, heights, widths

    # All kept views go through the backbone as one batch.
    batch = torch.tensor(np.stack(planes), dtype=torch.float32, device=device)
    tokens = dino.forward_features(batch)["x_norm_patchtokens"]
    grids = tokens.reshape(len(planes), DINO_GRID, DINO_GRID, DINO_FEAT_DIM).cpu().numpy()
    features.update(zip(kept_ids, grids))
    return features, heights, widths
|
|
|
|
| def _bilinear_sample_grid(grid, u_norm, v_norm): |
| """Bilinearly sample (G, G, D) grid at (u, v) in [0, 1] coords. Returns (D,).""" |
| G = grid.shape[0] |
| u = u_norm * (G - 1) |
| v = v_norm * (G - 1) |
| x0 = int(np.floor(u)); x1 = min(x0 + 1, G - 1) |
| y0 = int(np.floor(v)); y1 = min(y0 + 1, G - 1) |
| x0 = max(0, x0); y0 = max(0, y0) |
| dx = u - x0; dy = v - y0 |
| f00 = grid[y0, x0] |
| f01 = grid[y0, x1] |
| f10 = grid[y1, x0] |
| f11 = grid[y1, x1] |
| f0 = f00 * (1 - dx) + f01 * dx |
| f1 = f10 * (1 - dx) + f11 * dx |
| return f0 * (1 - dy) + f1 * dy |
|
|
|
|
def extract_features_v4(pv, pe, sample, dino, device="cpu"):
    """Build the two feature groups consumed by ``EdgeClassifierV4``.

    Args:
        pv: predicted vertices; indexed by the entries of ``pe``.
        pe: predicted edges as pairs of vertex indices.
        sample: raw dataset entry, passed to ``extract_features_v2`` and to
            ``convert_entry_to_human_readable``.
        dino: backbone exposing ``forward_features`` (see ``get_dino_model``).
        device: device the image batch is created on.

    Returns:
        ``(geom, edge_feats)``. ``geom`` is whatever ``extract_features_v2``
        returns for these edges. ``edge_feats`` is a float32 NumPy array of
        shape ``(len(pe), EDGE_FEAT_DIM)``: for each edge, the DINO features
        sampled at the midpoint of its projected endpoints are mean-pooled and
        max-pooled over the views that see it and concatenated. Rows stay zero
        for edges that are invalid or visible in no view.

    The image branch is best effort: any exception is logged and the rows
    filled so far are returned, the rest staying zero.
    """
    pv_arr = np.asarray(pv)
    num_edges = len(pe)
    geom = extract_features_v2(pv, pe, sample)
    edge_feats = np.zeros((num_edges, EDGE_FEAT_DIM), dtype=np.float32)
    if num_edges == 0:
        return geom, edge_feats

    try:
        from hoho2025.example_solutions import convert_entry_to_human_readable
        from mvs_utils import collect_views, project_world_to_image

        good = convert_entry_to_human_readable(sample)
        colmap_rec = good.get("colmap") or good.get("colmap_binary")
        if colmap_rec is None:
            return geom, edge_feats
        views = collect_views(colmap_rec, good["image_ids"])
        if not views:
            return geom, edge_feats
        per_view, Hs, Ws = _encode_views_with_dino(good, views, dino, device)
        if not per_view:
            return geom, edge_feats

        # Only views that were actually encoded can contribute; filter once
        # instead of once per edge.
        usable_views = [
            (img_id, view) for img_id, view in views.items() if img_id in per_view
        ]
        num_vertices = len(pv_arr)

        for i, (u, vv) in enumerate(pe):
            u, vv = int(u), int(vv)
            # Skip degenerate edges and indices outside the vertex array.
            # Negative indices are rejected too: they would otherwise wrap
            # around and silently pick a vertex from the end.
            if u == vv or not (0 <= u < num_vertices and 0 <= vv < num_vertices):
                continue
            endpoints = np.stack([pv_arr[u], pv_arr[vv]])

            samples = []
            for img_id, view in usable_views:
                H, W = Hs[img_id], Ws[img_id]
                uv, z = project_world_to_image(view["P"], endpoints)
                # Both endpoints must have positive depth in this view.
                if z[0] <= 0 or z[1] <= 0:
                    continue
                mx = 0.5 * (uv[0, 0] + uv[1, 0])
                my = 0.5 * (uv[0, 1] + uv[1, 1])
                if not (0 <= mx < W and 0 <= my < H):
                    continue
                # Pixel coordinates -> [0, 1] coordinates of the feature grid.
                samples.append(
                    _bilinear_sample_grid(
                        per_view[img_id], mx / max(W - 1, 1), my / max(H - 1, 1)
                    )
                )

            if samples:
                stacked = np.asarray(samples)
                edge_feats[i] = np.concatenate(
                    [stacked.mean(axis=0), stacked.max(axis=0)]
                )
    except Exception:
        # Deliberate best effort: keep going with zero/partial image features,
        # but leave a trace so the degradation is not invisible.
        import logging

        logging.getLogger(__name__).warning(
            "extract_features_v4: DINO feature extraction failed; "
            "remaining edge features stay zero",
            exc_info=True,
        )

    return geom, edge_feats
|
|
|
|
def label_edges_vs_gt(pv, pe, gt_v, gt_e, match_radius: float = 0.4):
    """Delegate to ``edge_classifier_v2.label_edges_vs_gt`` with the same arguments.

    The v2 module is imported at call time, and ``match_radius`` is forwarded
    positionally.
    """
    import edge_classifier_v2 as _v2

    return _v2.label_edges_vs_gt(pv, pe, gt_v, gt_e, match_radius)
|
|
|
|
def classify_edges_v4(pv, pe, sample, classifier, dino, device="cpu",
                      threshold: float = 0.5,
                      feature_mean=None, feature_std=None,
                      edge_feat_mean=None, edge_feat_std=None,
                      min_keep_frac: float = 0.7):
    """Prune predicted edges with the v4 classifier.

    Args:
        pv, pe: predicted vertices and edges (pairs of vertex indices).
        sample: raw dataset entry forwarded to ``extract_features_v4``.
        classifier: callable ``classifier(geom, edge_feats) -> logits``.
        dino: DINOv2 backbone forwarded to ``extract_features_v4``.
        device: device the feature tensors are created on.
        threshold: edges whose sigmoid score is >= this value are kept.
        feature_mean, feature_std: optional normalisation of the geometric
            features; applied only when both are given.
        edge_feat_mean, edge_feat_std: same, for the DINO edge features.
        min_keep_frac: if thresholding keeps fewer than
            ``max(1, ceil(min_keep_frac * len(pe)))`` edges, that many
            top-scoring edges are kept instead.

    Returns:
        The result of ``drop_orphan_vertices(np.asarray(pv), kept_edges)``.
        When ``pe`` is empty, or when anything fails, the inputs ``(pv, pe)``
        are returned unchanged; failures are logged rather than raised.
    """
    try:
        if len(pe) == 0:
            return pv, pe

        geom, edge_feats = extract_features_v4(pv, pe, sample, dino, device=device)
        if feature_mean is not None and feature_std is not None:
            geom = (geom - feature_mean) / (feature_std + 1e-6)
        if edge_feat_mean is not None and edge_feat_std is not None:
            edge_feats = (edge_feats - edge_feat_mean) / (edge_feat_std + 1e-6)

        with torch.no_grad():
            g = torch.tensor(geom, dtype=torch.float32, device=device)
            e = torch.tensor(edge_feats, dtype=torch.float32, device=device)
            scores = torch.sigmoid(classifier(g, e)).cpu().numpy()

        keep_mask = scores >= threshold
        # Safety floor: never prune below a fixed fraction of the edges.
        min_keep = max(1, int(np.ceil(min_keep_frac * len(pe))))
        if keep_mask.sum() < min_keep:
            top_idx = np.argsort(-scores)[:min_keep]
            keep_mask = np.zeros_like(keep_mask)
            keep_mask[top_idx] = True

        keep_edges = [edge for i, edge in enumerate(pe) if keep_mask[i]]
        if not keep_edges:
            return pv, pe

        from edge_2d_filter import drop_orphan_vertices
        return drop_orphan_vertices(np.asarray(pv), keep_edges)
    except Exception:
        # Deliberate no-op fallback, but logged: otherwise a broken classifier
        # path is indistinguishable from "nothing needed pruning".
        import logging

        logging.getLogger(__name__).warning(
            "classify_edges_v4 failed; returning edges unfiltered",
            exc_info=True,
        )
        return pv, pe
|
|
|
|
def load_classifier_v4(path: str, device: str = "cpu"):
    """Load a v4 checkpoint and return the model with its normalisation stats.

    The checkpoint is a dict with a required ``"model"`` state dict and
    optional ``"hidden"`` (default 128), ``"feature_mean"``,
    ``"feature_std"``, ``"edge_feat_mean"`` and ``"edge_feat_std"`` entries.

    Returns ``(model, feature_mean, feature_std, edge_feat_mean,
    edge_feat_std)``; the model is moved to ``device`` and put in eval mode,
    and any statistic missing from the checkpoint is ``None``.
    """
    # NOTE(review): weights_only=False performs full unpickling, which can
    # execute arbitrary code; only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=device, weights_only=False)

    hidden_width = checkpoint.get("hidden", 128)
    model = EdgeClassifierV4(geom_dim=V2_NUM, edge_feat_dim=EDGE_FEAT_DIM,
                             hidden=hidden_width)
    model.load_state_dict(checkpoint["model"])
    model.to(device).eval()

    stat_keys = ("feature_mean", "feature_std", "edge_feat_mean", "edge_feat_std")
    stats = tuple(checkpoint.get(key) for key in stat_keys)
    return (model, *stats)
|
|