# file: 03_infer_halfedge.py
# -*- coding: utf-8 -*-
"""Single-file inference for a trained ``HalfEdgeGNN`` B-rep classifier.

Loads a checkpoint and one processed B-rep extractor ``.npz`` file, checks
that their feature dimensions agree, runs the model, and prints the argmax
class, the final label (possibly rejected to ``"random"``) and every class
probability.
"""
import argparse
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import (
    HeteroConv,
    SAGEConv,
    GlobalAttention,
    JumpingKnowledge,
    BatchNorm,
)
from torch_geometric.data import HeteroData

from brep_extractor_utils import load_coedge_arrays, make_heterodata

# Node types that carry learned features; each gets its own encoder and BatchNorm.
_NODE_TYPES = ("coedge", "face", "edge")

# Relations message-passed in every HeteroConv layer. HeteroConv names its
# sub-convolutions after these tuples, so they must match the training script
# for ``load_state_dict`` to find the weights.
_EDGE_TYPES = (
    ("coedge", "next", "coedge"),
    ("coedge", "prev", "coedge"),
    ("coedge", "mate", "coedge"),
    ("coedge", "to_face", "face"),
    ("face", "to_coedge", "coedge"),
    ("coedge", "to_edge", "edge"),
    ("edge", "to_coedge", "coedge"),
    ("face", "to_edge", "edge"),
    ("edge", "to_face", "face"),
)


class HalfEdgeGNN(nn.Module):
    """Heterogeneous GNN that produces one set of class logits per B-rep graph.

    Pipeline:
      1. Per-node-type MLP encoders (Linear -> ReLU -> Dropout) lift the
         coedge/face/edge features to width ``hidden``.
      2. ``layers`` rounds of HeteroConv (one SAGEConv per relation, summed);
         each round applies a per-type BatchNorm, adds the previous state
         (residual) and applies ReLU.
      3. JumpingKnowledge combines the coedge states from every round.
      4. Attention pooling reduces the coedge states to one vector per graph,
         optionally projected to ``gating_dim``.
      5. The global features go through an MLP that yields a sigmoid scale
         ``gamma`` and a shift ``beta``. These modulate the pooled vector
         (FiLM-style).
      6. An MLP head maps the modulated vector to ``num_classes`` logits.

    Args:
        coedge_in: Input feature width of coedge nodes.
        face_in: Input feature width of face nodes.
        edge_in: Input feature width of edge nodes.
        global_in: Width of the global feature vector.
        hidden: Hidden width used throughout message passing.
        layers: Number of HeteroConv rounds.
        dropout: Dropout rate in the encoders and the head (the global MLP
            uses a fixed 0.3).
        num_classes: Number of output logits.
        jk_mode: JumpingKnowledge mode: ``"cat"``, ``"max"`` or ``"lstm"``.
        gating_dim: Width of the pooled/modulated vector; defaults to ``hidden``.
    """

    def __init__(
        self,
        coedge_in: int,
        face_in: int,
        edge_in: int,
        global_in: int,
        hidden=256,
        layers=6,
        dropout=0.2,
        num_classes=3,
        jk_mode="cat",
        gating_dim=None,
    ):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()

        in_dims = {"coedge": coedge_in, "face": face_in, "edge": edge_in}
        self.encoders = nn.ModuleDict({
            ntype: nn.Sequential(
                nn.Linear(in_dims[ntype], hidden), nn.ReLU(), nn.Dropout(dropout)
            )
            for ntype in _NODE_TYPES
        })

        for _ in range(layers):
            conv = HeteroConv(
                {etype: SAGEConv((hidden, hidden), hidden) for etype in _EDGE_TYPES},
                aggr="sum",
            )
            self.convs.append(conv)
            self.bns.append(
                nn.ModuleDict({ntype: BatchNorm(hidden) for ntype in _NODE_TYPES})
            )

        if jk_mode == "lstm":
            # The LSTM aggregator must know the channel width and the number of
            # layers up front; the other modes take no extra arguments.
            self.jk = JumpingKnowledge(mode=jk_mode, channels=hidden, num_layers=layers)
        else:
            self.jk = JumpingKnowledge(mode=jk_mode)
        # "cat" concatenates every round's output; the other modes keep ``hidden``.
        self.jk_out = hidden * layers if jk_mode == "cat" else hidden

        if gating_dim is None:
            gating_dim = hidden
        self.gating_dim = gating_dim

        # Scores each coedge for attention pooling.
        self.gate = nn.Sequential(
            nn.Linear(self.jk_out, self.jk_out // 2),
            nn.ReLU(),
            nn.Linear(self.jk_out // 2, 1),
        )
        self.pool = GlobalAttention(self.gate)
        self.proj = (
            nn.Identity()
            if self.jk_out == gating_dim
            else nn.Linear(self.jk_out, gating_dim)
        )

        # Produces [gamma | beta], each of width ``gating_dim``.
        self.global_mlp = nn.Sequential(
            nn.Linear(global_in, gating_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(gating_dim, 2 * gating_dim),
        )
        self.head = nn.Sequential(
            nn.Linear(gating_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, data: HeteroData) -> torch.Tensor:
        """Return class logits with one row per pooled graph in ``data``.

        ``data`` must provide ``x`` for every node type,
        ``data["coedge"].batch`` (the graph assignment used for pooling) and
        ``data["global"].x``. A 1-D global vector is reshaped to a single row.

        Raises:
            RuntimeError: If the number of global-feature rows differs from
                the number of pooled graphs.
        """
        x = {ntype: self.encoders[ntype](data[ntype].x) for ntype in _NODE_TYPES}

        outs = []
        for conv, bn in zip(self.convs, self.bns):
            x_new = conv(x, data.edge_index_dict)
            # Normalize the aggregated messages, add the previous state, then ReLU.
            x = {k: F.relu(bn[k](x_new[k]) + x[k]) for k in x}
            # Only the coedge states feed JumpingKnowledge and the readout.
            outs.append(x["coedge"])

        xj = self.jk(outs)
        g = self.pool(xj, data["coedge"].batch)
        g0 = self.proj(g)

        global_x = data["global"].x
        if global_x.dim() == 1:
            global_x = global_x.view(1, -1)
        if global_x.size(0) != g0.size(0):
            raise RuntimeError(
                f"Global feature batch mismatch: {global_x.size(0)} vs {g0.size(0)}"
            )

        gb = self.global_mlp(global_x)
        gamma, beta = gb.chunk(2, dim=-1)
        gamma = torch.sigmoid(gamma)
        g_mod = g0 * gamma + beta
        return self.head(g_mod)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options for single-file inference."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--npz", required=True, help="Path to a processed BRep extractor npz file")
    ap.add_argument("--tau", type=float, default=0.0, help="Reject threshold; below this outputs random")
    ap.add_argument("--min_conf", type=float, default=0.85, help="Hard minimum confidence for known classes")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    return ap.parse_args()


def _load_checkpoint(path: str):
    """Load a training checkpoint with all tensors mapped to the CPU."""
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from trusted sources.
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        return torch.load(path, map_location="cpu")


def _check_feature_dim(kind: str, array, axis: int, expected) -> None:
    """Raise RuntimeError if ``array.shape[axis]`` differs from ``expected``."""
    got = int(array.shape[axis])
    if got != int(expected):
        raise RuntimeError(
            f"{kind} feature dim mismatch: npz={got} ckpt={int(expected)}"
        )


def _build_graph(graph_data, stats) -> HeteroData:
    """Build a normalized single-graph HeteroData with batch vectors set."""
    data = make_heterodata(
        graph_data["coedge_x"], graph_data["face_x"], graph_data["edge_x"],
        graph_data["next"], graph_data["mate"],
        graph_data["coedge_face"], graph_data["coedge_edge"],
        graph_data["global_x"],
        label=None, norm_stats=stats,
    )
    # Inference runs on exactly one graph, so every node belongs to graph 0.
    for ntype in _NODE_TYPES:
        data[ntype].batch = torch.zeros(data[ntype].x.size(0), dtype=torch.long)
    data["global"].batch = torch.zeros(1, dtype=torch.long)
    return data


def main() -> None:
    """Classify one B-rep npz file with a trained checkpoint and print the result.

    Raises:
        RuntimeError: If the checkpoint lacks required metadata, if it lacks a
            ``"random"`` label while a rejection threshold is positive, or if
            the npz feature dimensions disagree with the checkpoint.
    """
    args = _parse_args()
    ckpt = _load_checkpoint(args.model)

    if "global_in" not in ckpt or "gating_dim" not in ckpt:
        raise RuntimeError(
            "Checkpoint missing gating metadata. Please retrain with global gating enabled."
        )

    labels = ckpt["labels"]
    inv_labels = {v: k for k, v in labels.items()}
    random_id = labels.get("random")
    # Any positive threshold may map a prediction to "random", so the model
    # must know that class.
    if (args.tau > 0 or args.min_conf > 0) and random_id is None:
        raise RuntimeError("Model labels do not include 'random'; retrain a 4-class model.")

    stats = ckpt["stats"]
    if not all(k in stats for k in _NODE_TYPES):
        raise RuntimeError("Checkpoint missing heterograph stats; retrain required.")

    # Fall back to the "node_in" key when "coedge_in" is absent.
    coedge_in = ckpt.get("coedge_in", ckpt.get("node_in"))
    face_in = ckpt.get("face_in")
    edge_in = ckpt.get("edge_in")
    if coedge_in is None or face_in is None or edge_in is None:
        raise RuntimeError("Checkpoint missing heterograph input dims; retrain required.")
    global_in = ckpt["global_in"]

    graph_data = load_coedge_arrays(Path(args.npz))
    _check_feature_dim("Coedge", graph_data["coedge_x"], 1, coedge_in)
    _check_feature_dim("Face", graph_data["face_x"], 1, face_in)
    _check_feature_dim("Edge", graph_data["edge_x"], 1, edge_in)
    # global_x is checked on axis 0 (presumably a 1-D per-graph vector; verify
    # against load_coedge_arrays).
    _check_feature_dim("Global", graph_data["global_x"], 0, global_in)

    data = _build_graph(graph_data, stats)

    hp = ckpt["hp"]
    model = HalfEdgeGNN(
        coedge_in=coedge_in,
        face_in=face_in,
        edge_in=edge_in,
        global_in=global_in,
        hidden=hp["hidden"],
        layers=hp["layers"],
        dropout=hp["dropout"],
        num_classes=len(labels),
        # Rebuild with the training-time JK mode when it was recorded, so the
        # architecture matches the saved weights; default to "cat".
        jk_mode=hp.get("jk_mode", "cat"),
        gating_dim=ckpt["gating_dim"],
    ).to(args.device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()

    with torch.no_grad():
        logits = model(data.to(args.device))
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
    pred = int(probs.argmax())
    conf = float(probs[pred])
    arg_label = inv_labels[pred]

    # Reject to "random" when the top probability is below the stricter of the
    # two thresholds.
    effective_tau = max(args.tau, args.min_conf)
    if conf < effective_tau and random_id is not None:
        final_label = "random"
    else:
        final_label = arg_label

    print(f"Argmax: {arg_label} (conf={conf:.4f})")
    print(f"Predicted: {final_label} (tau={effective_tau:.2f})")
    for i, p in enumerate(probs):
        print(f"{inv_labels[i]:>6s}: {p:.4f}")


if __name__ == "__main__":
    main()