| |
| |
| import argparse |
| from pathlib import Path |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.nn import HeteroConv, SAGEConv, GlobalAttention, JumpingKnowledge, BatchNorm |
| from torch_geometric.data import HeteroData |
|
|
| from brep_extractor_utils import load_coedge_arrays, make_heterodata |
|
|
class HalfEdgeGNN(nn.Module):
    """Heterogeneous message-passing network over a B-Rep half-edge graph.

    Node types are "coedge", "face" and "edge"; relations encode the loop
    topology (next/prev/mate between coedges) plus incidence links between
    coedges, faces and edges.  After `layers` rounds of HeteroConv with
    residual updates, the coedge embeddings from every layer are combined
    with JumpingKnowledge, attention-pooled into one vector per graph, and
    modulated by per-graph "global" features (FiLM-style scale + shift)
    before the classification head produces logits.
    """
    def __init__(
        self,
        coedge_in: int,
        face_in: int,
        edge_in: int,
        global_in: int,
        hidden: int = 256,
        layers: int = 6,
        dropout: float = 0.2,
        num_classes: int = 3,
        jk_mode: str = "cat",
        gating_dim=None,
    ):
        """Build encoders, conv stack, pooling, gating and head.

        Args:
            coedge_in: raw feature width of coedge nodes.
            face_in: raw feature width of face nodes.
            edge_in: raw feature width of edge nodes.
            global_in: width of the per-graph global feature vector.
            hidden: shared embedding width for all node types / conv layers.
            layers: number of HeteroConv message-passing rounds.
            dropout: dropout used in the input encoders and the head.
            num_classes: number of output classes.
            jk_mode: JumpingKnowledge mode; "cat" concatenates all layer
                outputs so the pooled width is hidden * layers.
            gating_dim: width at which global gating is applied; defaults
                to `hidden` when None.
        """
        super().__init__()
        self.convs = nn.ModuleList(); self.bns = nn.ModuleList()
        # Per-type linear encoders lift raw features to the shared hidden width.
        self.encoders = nn.ModuleDict({
            "coedge": nn.Sequential(nn.Linear(coedge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "face": nn.Sequential(nn.Linear(face_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "edge": nn.Sequential(nn.Linear(edge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
        })
        for _ in range(layers):
            # One SAGEConv per relation; messages arriving at the same node
            # type from different relations are summed (aggr='sum').
            conv = HeteroConv({
                ('coedge','next','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','prev','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','mate','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','to_face','face'): SAGEConv((hidden, hidden), hidden),
                ('face','to_coedge','coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge','to_edge','edge'): SAGEConv((hidden, hidden), hidden),
                ('edge','to_coedge','coedge'): SAGEConv((hidden, hidden), hidden),
                ('face','to_edge','edge'): SAGEConv((hidden, hidden), hidden),
                ('edge','to_face','face'): SAGEConv((hidden, hidden), hidden),
            }, aggr='sum')
            self.convs.append(conv)
            # One BatchNorm per node type per layer, applied before the
            # residual addition in forward().
            self.bns.append(nn.ModuleDict({
                "coedge": BatchNorm(hidden),
                "face": BatchNorm(hidden),
                "edge": BatchNorm(hidden),
            }))
        self.jk = JumpingKnowledge(mode=jk_mode)
        # JK output width: "cat" grows with depth; other modes keep `hidden`.
        self.jk_out = hidden * layers if jk_mode == "cat" else hidden
        if gating_dim is None:
            gating_dim = hidden
        self.gating_dim = gating_dim
        # Scores each coedge embedding for attention pooling (scalar gate).
        self.gate = nn.Sequential(
            nn.Linear(self.jk_out, self.jk_out//2),
            nn.ReLU(),
            nn.Linear(self.jk_out//2, 1),
        )
        self.pool = GlobalAttention(self.gate)
        # Project the pooled vector to gating_dim unless widths already match.
        self.proj = nn.Identity() if self.jk_out == gating_dim else nn.Linear(self.jk_out, gating_dim)
        # Maps global features to a concatenated (gamma, beta) pair used for
        # FiLM-style modulation of the pooled graph embedding.
        self.global_mlp = nn.Sequential(
            nn.Linear(global_in, gating_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(gating_dim, 2 * gating_dim),
        )
        self.head = nn.Sequential(
            nn.Linear(gating_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )


    def forward(self, data: HeteroData):
        """Classify a (batched) heterograph; returns [batch, num_classes] logits.

        Expects `data['coedge'].batch` to be populated by the caller so the
        attention pooling can group coedges per graph.
        """
        x = {
            "coedge": self.encoders["coedge"](data["coedge"].x),
            "face": self.encoders["face"](data["face"].x),
            "edge": self.encoders["edge"](data["edge"].x),
        }
        outs = []
        for conv, bn in zip(self.convs, self.bns):
            x_new = conv(x, data.edge_index_dict)
            # Residual update: ReLU(BatchNorm(new) + old) per node type.
            x = {k: F.relu(bn[k](x_new[k]) + x[k]) for k in x}
            # Only coedge embeddings feed JumpingKnowledge / pooling.
            outs.append(x["coedge"])
        xj = self.jk(outs)
        g = self.pool(xj, data['coedge'].batch)
        g0 = self.proj(g)
        global_x = data["global"].x
        if global_x.dim() == 1:
            # Single-graph input: promote the global vector to a batch of one.
            global_x = global_x.view(1, -1)
        if global_x.size(0) != g0.size(0):
            raise RuntimeError(
                f"Global feature batch mismatch: {global_x.size(0)} vs {g0.size(0)}"
            )
        gb = self.global_mlp(global_x)
        gamma, beta = gb.chunk(2, dim=-1)
        # Sigmoid keeps the multiplicative gate in (0, 1).
        gamma = torch.sigmoid(gamma)
        g_mod = g0 * gamma + beta
        return self.head(g_mod)
|
|
| def main(): |
| ap = argparse.ArgumentParser() |
| ap.add_argument("--model", required=True) |
| ap.add_argument("--npz", required=True, help="Path to a processed BRep extractor npz file") |
| ap.add_argument("--tau", type=float, default=0.0, help="Reject threshold; below this outputs random") |
| ap.add_argument("--min_conf", type=float, default=0.85, help="Hard minimum confidence for known classes") |
| ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu") |
| args = ap.parse_args() |
|
|
| try: |
| ckpt = torch.load(args.model, map_location="cpu", weights_only=False) |
| except TypeError: |
| ckpt = torch.load(args.model, map_location="cpu") |
| if "global_in" not in ckpt or "gating_dim" not in ckpt: |
| raise RuntimeError( |
| "Checkpoint missing gating metadata. Please retrain with global gating enabled." |
| ) |
| labels = ckpt["labels"]; inv_labels = {v:k for k,v in labels.items()} |
| random_id = labels.get("random") |
| if (args.tau > 0 or args.min_conf > 0) and random_id is None: |
| raise RuntimeError("Model labels do not include 'random'; retrain a 4-class model.") |
| stats = ckpt["stats"] |
| if not all(k in stats for k in ("coedge", "face", "edge")): |
| raise RuntimeError("Checkpoint missing heterograph stats; retrain required.") |
|
|
| coedge_in = ckpt.get("coedge_in", ckpt.get("node_in")) |
| face_in = ckpt.get("face_in") |
| edge_in = ckpt.get("edge_in") |
| if coedge_in is None or face_in is None or edge_in is None: |
| raise RuntimeError("Checkpoint missing heterograph input dims; retrain required.") |
|
|
| graph_data = load_coedge_arrays(Path(args.npz)) |
| if int(graph_data["coedge_x"].shape[1]) != int(coedge_in): |
| raise RuntimeError( |
| f"Coedge feature dim mismatch: npz={int(graph_data['coedge_x'].shape[1])} " |
| f"ckpt={int(coedge_in)}" |
| ) |
| if int(graph_data["face_x"].shape[1]) != int(face_in): |
| raise RuntimeError( |
| f"Face feature dim mismatch: npz={int(graph_data['face_x'].shape[1])} " |
| f"ckpt={int(face_in)}" |
| ) |
| if int(graph_data["edge_x"].shape[1]) != int(edge_in): |
| raise RuntimeError( |
| f"Edge feature dim mismatch: npz={int(graph_data['edge_x'].shape[1])} " |
| f"ckpt={int(edge_in)}" |
| ) |
| if int(graph_data["global_x"].shape[0]) != int(ckpt["global_in"]): |
| raise RuntimeError( |
| f"Global feature dim mismatch: npz={int(graph_data['global_x'].shape[0])} " |
| f"ckpt={int(ckpt['global_in'])}" |
| ) |
| data = make_heterodata( |
| graph_data["coedge_x"], |
| graph_data["face_x"], |
| graph_data["edge_x"], |
| graph_data["next"], |
| graph_data["mate"], |
| graph_data["coedge_face"], |
| graph_data["coedge_edge"], |
| graph_data["global_x"], |
| label=None, |
| norm_stats=stats, |
| ) |
| data['coedge'].batch = torch.zeros(data['coedge'].x.size(0), dtype=torch.long) |
| data["global"].batch = torch.zeros(1, dtype=torch.long) |
| data["face"].batch = torch.zeros(data["face"].x.size(0), dtype=torch.long) |
| data["edge"].batch = torch.zeros(data["edge"].x.size(0), dtype=torch.long) |
|
|
| global_in = ckpt["global_in"] |
| gating_dim = ckpt["gating_dim"] |
| model = HalfEdgeGNN(coedge_in=coedge_in, face_in=face_in, edge_in=edge_in, global_in=global_in, |
| hidden=ckpt["hp"]["hidden"], |
| layers=ckpt["hp"]["layers"], dropout=ckpt["hp"]["dropout"], |
| num_classes=len(labels), gating_dim=gating_dim).to(args.device) |
| model.load_state_dict(ckpt["state_dict"]); model.eval() |
|
|
| with torch.no_grad(): |
| logits = model(data.to(args.device)) |
| probs = F.softmax(logits, dim=-1).cpu().numpy()[0] |
| pred = int(probs.argmax()) |
| conf = float(probs[pred]) |
| arg_label = inv_labels[pred] |
| effective_tau = max(args.tau, args.min_conf) |
| if conf < effective_tau and random_id is not None: |
| final_label = "random" |
| else: |
| final_label = arg_label |
| print(f"Argmax: {arg_label} (conf={conf:.4f})") |
| print(f"Predicted: {final_label} (tau={effective_tau:.2f})") |
| for i, p in enumerate(probs): |
| print(f"{inv_labels[i]:>6s}: {p:.4f}") |
|
|
# Run the CLI only when executed directly (not when imported as a module).
if __name__ == "__main__":
    main()
|
|