# Source: sample/03_infer_halfedge.py (Hugging Face upload page residue removed;
# uploaded by Silly98, commit dc71d7e, "Upload 2 files")
# file: 03_infer_halfedge.py
# -*- coding: utf-8 -*-
import argparse
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv, GlobalAttention, JumpingKnowledge, BatchNorm
from torch_geometric.data import HeteroData
from brep_extractor_utils import load_coedge_arrays, make_heterodata
class HalfEdgeGNN(nn.Module):
    """Heterogeneous GNN over a B-rep half-edge (coedge) graph for whole-graph classification.

    Pipeline (see ``forward``):
      1. Per-node-type MLP encoders project coedge/face/edge features to ``hidden``.
      2. ``layers`` rounds of heterogeneous SAGE message passing with per-type
         BatchNorm and a residual ReLU update.
      3. JumpingKnowledge aggregates the per-layer coedge embeddings.
      4. Attention pooling (``GlobalAttention`` with a learned scalar gate)
         produces one graph embedding per batch element.
      5. A FiLM-style modulation from global features (``gamma``/``beta``)
         conditions the graph embedding before the classification head.

    Args:
        coedge_in: input feature width of coedge nodes.
        face_in: input feature width of face nodes.
        edge_in: input feature width of edge nodes.
        global_in: width of the per-graph global feature vector.
        hidden: hidden width used throughout the message-passing trunk.
        layers: number of HeteroConv rounds (also the JK depth).
        dropout: dropout rate in the encoders and the head.
        num_classes: output logits dimension.
        jk_mode: JumpingKnowledge mode; "cat" concatenates all layer outputs,
            giving ``hidden * layers`` features, otherwise ``hidden``.
        gating_dim: width of the gated graph embedding; defaults to ``hidden``.
    """
    def __init__(
        self,
        coedge_in: int,
        face_in: int,
        edge_in: int,
        global_in: int,
        hidden=256,
        layers=6,
        dropout=0.2,
        num_classes=3,
        jk_mode="cat",
        gating_dim=None,
    ):
        super().__init__()
        self.convs = nn.ModuleList(); self.bns = nn.ModuleList()
        # One Linear+ReLU+Dropout encoder per node type, all projecting to `hidden`.
        self.encoders = nn.ModuleDict({
            "coedge": nn.Sequential(nn.Linear(coedge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "face": nn.Sequential(nn.Linear(face_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
            "edge": nn.Sequential(nn.Linear(edge_in, hidden), nn.ReLU(), nn.Dropout(dropout)),
        })
        # Each round: one SAGEConv per relation, summed per destination node type.
        # Relations cover half-edge topology (next/prev/mate) plus both directions
        # of the coedge<->face, coedge<->edge and face<->edge incidences.
        for _ in range(layers):
            conv = HeteroConv({
                ('coedge','next','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','prev','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','mate','coedge'): SAGEConv((hidden,hidden), hidden),
                ('coedge','to_face','face'): SAGEConv((hidden, hidden), hidden),
                ('face','to_coedge','coedge'): SAGEConv((hidden, hidden), hidden),
                ('coedge','to_edge','edge'): SAGEConv((hidden, hidden), hidden),
                ('edge','to_coedge','coedge'): SAGEConv((hidden, hidden), hidden),
                ('face','to_edge','edge'): SAGEConv((hidden, hidden), hidden),
                ('edge','to_face','face'): SAGEConv((hidden, hidden), hidden),
            }, aggr='sum')
            self.convs.append(conv)
            # Per-layer, per-node-type BatchNorm applied before the residual add.
            self.bns.append(nn.ModuleDict({
                "coedge": BatchNorm(hidden),
                "face": BatchNorm(hidden),
                "edge": BatchNorm(hidden),
            }))
        self.jk = JumpingKnowledge(mode=jk_mode)
        # "cat" stacks every layer's coedge output, so the JK width scales with depth.
        self.jk_out = hidden * layers if jk_mode == "cat" else hidden
        if gating_dim is None:
            gating_dim = hidden
        self.gating_dim = gating_dim
        # Scalar attention gate over coedge embeddings for GlobalAttention pooling.
        self.gate = nn.Sequential(
            nn.Linear(self.jk_out, self.jk_out//2),
            nn.ReLU(),
            nn.Linear(self.jk_out//2, 1),
        )
        self.pool = GlobalAttention(self.gate)
        # Project pooled embedding to gating_dim (identity when widths already match).
        self.proj = nn.Identity() if self.jk_out == gating_dim else nn.Linear(self.jk_out, gating_dim)
        # Maps global features to a (gamma, beta) pair — hence the 2 * gating_dim output.
        self.global_mlp = nn.Sequential(
            nn.Linear(global_in, gating_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(gating_dim, 2 * gating_dim),
        )
        self.head = nn.Sequential(
            nn.Linear(gating_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, num_classes),
        )

    def forward(self, data: HeteroData) -> torch.Tensor:
        """Return classification logits of shape (batch, num_classes).

        Requires ``data['coedge'].batch`` to be set (used by attention pooling)
        and ``data['global'].x`` to hold one global feature row per graph.
        """
        # Encode each node type into the shared hidden width.
        x = {
            "coedge": self.encoders["coedge"](data["coedge"].x),
            "face": self.encoders["face"](data["face"].x),
            "edge": self.encoders["edge"](data["edge"].x),
        }
        outs = []
        for conv, bn in zip(self.convs, self.bns):
            x_new = conv(x, data.edge_index_dict)
            # Residual update: normalize the message output, add the previous
            # state, then ReLU. Keeps deep stacks trainable.
            x = {k: F.relu(bn[k](x_new[k]) + x[k]) for k in x}
            outs.append(x["coedge"])
        # JumpingKnowledge over the per-layer coedge embeddings.
        xj = self.jk(outs)
        # Attention-pool coedges into one embedding per graph.
        g = self.pool(xj, data['coedge'].batch)
        g0 = self.proj(g)
        global_x = data["global"].x
        if global_x.dim() == 1:
            # Single-graph input stored as a flat vector; lift to (1, global_in).
            global_x = global_x.view(1, -1)
        if global_x.size(0) != g0.size(0):
            raise RuntimeError(
                f"Global feature batch mismatch: {global_x.size(0)} vs {g0.size(0)}"
            )
        gb = self.global_mlp(global_x)
        # FiLM-style conditioning: sigmoid-bounded scale plus additive shift.
        gamma, beta = gb.chunk(2, dim=-1)
        gamma = torch.sigmoid(gamma)
        g_mod = g0 * gamma + beta
        return self.head(g_mod)
def _load_checkpoint(path: str) -> dict:
    """Load a training checkpoint onto CPU, tolerating older torch versions.

    NOTE(security): ``weights_only=False`` unpickles arbitrary objects (the
    checkpoint carries labels/stats metadata) — only load trusted checkpoints.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Older torch releases do not accept the weights_only kwarg.
        return torch.load(path, map_location="cpu")


def _check_feature_dims(graph_data, coedge_in, face_in, edge_in, global_in) -> None:
    """Raise RuntimeError if any npz feature width disagrees with the checkpoint."""
    checks = (
        # (label, actual width from npz, expected width from checkpoint)
        ("Coedge", int(graph_data["coedge_x"].shape[1]), int(coedge_in)),
        ("Face", int(graph_data["face_x"].shape[1]), int(face_in)),
        ("Edge", int(graph_data["edge_x"].shape[1]), int(edge_in)),
        # Global features are a single vector, hence shape[0].
        ("Global", int(graph_data["global_x"].shape[0]), int(global_in)),
    )
    for label, actual, expected in checks:
        if actual != expected:
            raise RuntimeError(
                f"{label} feature dim mismatch: npz={actual} ckpt={expected}"
            )


def _build_data(graph_data, stats):
    """Assemble a single-graph HeteroData with batch vectors for pooling."""
    data = make_heterodata(
        graph_data["coedge_x"],
        graph_data["face_x"],
        graph_data["edge_x"],
        graph_data["next"],
        graph_data["mate"],
        graph_data["coedge_face"],
        graph_data["coedge_edge"],
        graph_data["global_x"],
        label=None,
        norm_stats=stats,
    )
    # All nodes belong to graph 0; GlobalAttention pooling needs these vectors.
    data["coedge"].batch = torch.zeros(data["coedge"].x.size(0), dtype=torch.long)
    data["global"].batch = torch.zeros(1, dtype=torch.long)
    data["face"].batch = torch.zeros(data["face"].x.size(0), dtype=torch.long)
    data["edge"].batch = torch.zeros(data["edge"].x.size(0), dtype=torch.long)
    return data


def main():
    """CLI entry: classify one processed B-rep npz with a trained HalfEdgeGNN.

    Prints the argmax label with its confidence, the thresholded final label
    (falls back to "random" below max(tau, min_conf)), and the full class
    probability table.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True)
    ap.add_argument("--npz", required=True, help="Path to a processed BRep extractor npz file")
    ap.add_argument("--tau", type=float, default=0.0, help="Reject threshold; below this outputs random")
    ap.add_argument("--min_conf", type=float, default=0.85, help="Hard minimum confidence for known classes")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    ckpt = _load_checkpoint(args.model)

    # --- Checkpoint metadata validation -------------------------------------
    if "global_in" not in ckpt or "gating_dim" not in ckpt:
        raise RuntimeError(
            "Checkpoint missing gating metadata. Please retrain with global gating enabled."
        )
    labels = ckpt["labels"]
    inv_labels = {v: k for k, v in labels.items()}
    random_id = labels.get("random")
    # Rejection thresholds only make sense if the model has a "random" class.
    if (args.tau > 0 or args.min_conf > 0) and random_id is None:
        raise RuntimeError("Model labels do not include 'random'; retrain a 4-class model.")
    stats = ckpt["stats"]
    if not all(k in stats for k in ("coedge", "face", "edge")):
        raise RuntimeError("Checkpoint missing heterograph stats; retrain required.")
    # "node_in" is the legacy key name for coedge input width.
    coedge_in = ckpt.get("coedge_in", ckpt.get("node_in"))
    face_in = ckpt.get("face_in")
    edge_in = ckpt.get("edge_in")
    if coedge_in is None or face_in is None or edge_in is None:
        raise RuntimeError("Checkpoint missing heterograph input dims; retrain required.")

    # --- Input graph loading and validation ---------------------------------
    graph_data = load_coedge_arrays(Path(args.npz))
    _check_feature_dims(graph_data, coedge_in, face_in, edge_in, ckpt["global_in"])
    data = _build_data(graph_data, stats)

    # --- Model construction and inference -----------------------------------
    model = HalfEdgeGNN(
        coedge_in=coedge_in,
        face_in=face_in,
        edge_in=edge_in,
        global_in=ckpt["global_in"],
        hidden=ckpt["hp"]["hidden"],
        layers=ckpt["hp"]["layers"],
        dropout=ckpt["hp"]["dropout"],
        num_classes=len(labels),
        gating_dim=ckpt["gating_dim"],
    ).to(args.device)
    model.load_state_dict(ckpt["state_dict"])
    model.eval()
    with torch.no_grad():
        logits = model(data.to(args.device))
        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # --- Thresholded decision and report ------------------------------------
    pred = int(probs.argmax())
    conf = float(probs[pred])
    arg_label = inv_labels[pred]
    # min_conf acts as a hard floor on top of the user-supplied tau.
    effective_tau = max(args.tau, args.min_conf)
    if conf < effective_tau and random_id is not None:
        final_label = "random"
    else:
        final_label = arg_label
    print(f"Argmax: {arg_label} (conf={conf:.4f})")
    print(f"Predicted: {final_label} (tau={effective_tau:.2f})")
    for i, p in enumerate(probs):
        print(f"{inv_labels[i]:>6s}: {p:.4f}")


if __name__ == "__main__":
    main()