"""Train a GraphGPS (GPSConv + GINEConv + global attention) model on the JSON
constraint graphs produced by frame_to_graph.py.

Task (demo): graph-level regression of disassembly progress

    target = frame_idx / max_frame_idx  ∈ [0, 1]

The model reads a per-frame constraint graph (15 products + optional robot,
fully connected with constraint edge features) and predicts how far along the
disassembly is.

This mirrors the GraphGPS tutorial (PyG docs): local MPNN (GINEConv) + global
attention, stacked for N layers.

Run:
    python train_gps.py
"""

import json
import math
import random
from pathlib import Path
from typing import List

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool


# ─────────────────────────────────────────────────────────────────────────────
# 1. JSON graph → PyG Data
# ─────────────────────────────────────────────────────────────────────────────
TYPE_VOCAB = ["cpu_fan", "cpu_bracket", "cpu", "ram_clip", "ram", "connector",
              "graphic_card", "motherboard", "robot"]
TYPE_TO_IDX = {t: i for i, t in enumerate(TYPE_VOCAB)}
NODE_FEAT_DIM = len(TYPE_VOCAB) + 3 + 2  # 9 type one-hot + 3D centroid + (mask_area, emb_norm) = 14
EDGE_FEAT_DIM = 2                         # [has_constraint, is_locked]

# Column offsets inside the node feature vector, derived from TYPE_VOCAB so the
# layout stays consistent with NODE_FEAT_DIM if the vocabulary changes.
_CENTROID_COL = len(TYPE_VOCAB)      # 3 columns: centroid_3d
_AREA_COL = _CENTROID_COL + 3        # 1 column: scaled mask_area
_EMB_NORM_COL = _AREA_COL + 1        # 1 column: scaled embedding_norm


def json_to_pyg(path: Path, target: float) -> Data:
    """Read one frame-graph JSON file and build a fully connected PyG ``Data``.

    Args:
        path: JSON file with a ``nodes`` list (each node has ``id``, ``type``,
            ``centroid_3d``, ``mask_area``, ``embedding_norm``) and an
            ``edges`` list (each edge has ``src``, ``dst``, ``is_locked``).
        target: graph-level regression label, stored as ``Data.y``.

    Returns:
        ``Data`` with
          * ``x``: (N, NODE_FEAT_DIM) = [type one-hot, centroid_3d,
            mask_area / 1e5, embedding_norm / 5.0];
          * ``edge_index``: one directed edge per ordered pair of distinct
            nodes (N * (N - 1) edges, row-major i-then-j order);
          * ``edge_attr``: (E, EDGE_FEAT_DIM) = [has_constraint, is_locked],
            always 2-D even when there are no edges;
          * ``y``: shape (1,).

    Raises:
        KeyError: a node ``type`` is not in ``TYPE_VOCAB``.
        ValueError: node ids are not unique.
    """
    with open(path, encoding="utf-8") as f:
        gd = json.load(f)

    nodes = gd["nodes"]
    N = len(nodes)
    id_to_idx = {n["id"]: i for i, n in enumerate(nodes)}
    if len(id_to_idx) != N:
        raise ValueError(f"{path}: node ids are not unique")

    # Node features: [type one-hot, centroid_3d, mask_area_scaled, emb_norm]
    x = torch.zeros((N, NODE_FEAT_DIM), dtype=torch.float32)
    for i, n in enumerate(nodes):
        x[i, TYPE_TO_IDX[n["type"]]] = 1.0
        x[i, _CENTROID_COL:_AREA_COL] = torch.tensor(n["centroid_3d"], dtype=torch.float32)
        x[i, _AREA_COL] = n["mask_area"] / 1e5         # scale to ~O(1)
        x[i, _EMB_NORM_COL] = n["embedding_norm"] / 5.0

    # Symmetric constraint matrices. Edges whose endpoints are not among the
    # nodes are ignored; a later edge for the same unordered pair overwrites
    # the lock flag of an earlier one.
    has_constraint = torch.zeros((N, N), dtype=torch.float32)
    is_locked = torch.zeros((N, N), dtype=torch.float32)
    for e in gd["edges"]:
        s, d = id_to_idx.get(e["src"]), id_to_idx.get(e["dst"])
        if s is None or d is None:
            continue
        locked = 1.0 if e["is_locked"] else 0.0
        has_constraint[s, d] = has_constraint[d, s] = 1.0
        is_locked[s, d] = is_locked[d, s] = locked

    # Fully connected without self-loops. nonzero() yields indices in row-major
    # order, matching a nested `for i: for j: if i != j` enumeration, and gives
    # correctly shaped empty tensors when N <= 1.
    off_diag = ~torch.eye(N, dtype=torch.bool)
    src, dst = off_diag.nonzero(as_tuple=True)
    edge_attr = torch.stack([has_constraint[src, dst], is_locked[src, dst]], dim=1)

    return Data(
        x=x,
        edge_index=torch.stack([src, dst]),
        edge_attr=edge_attr,
        y=torch.tensor([target], dtype=torch.float32),
        num_nodes=N,
    )


def build_dataset(json_dir: Path) -> List[Data]:
    """Load every ``frame_<idx>_graph.json`` in *json_dir* as a PyG ``Data``.

    Each graph's target is ``frame_idx / max_frame_idx``. When the largest
    frame index is 0, every target is 0.0 instead of dividing by zero. Graphs
    are returned in lexicographic file-path order.

    Raises:
        ValueError: no matching files exist in *json_dir*.
    """
    paths = sorted(Path(json_dir).glob("frame_*_graph.json"))
    if not paths:
        raise ValueError(f"No frame_*_graph.json files found in {json_dir}")

    # "frame_000123_graph" -> 123
    frame_ids = [int(p.stem.split("_")[1]) for p in paths]
    max_f = max(frame_ids)
    return [
        json_to_pyg(p, fid / max_f if max_f > 0 else 0.0)
        for p, fid in zip(paths, frame_ids)
    ]
# ─────────────────────────────────────────────────────────────────────────────
# 2. GraphGPS model — mirrors the PyG tutorial but adapted for continuous
#    node / edge features (Linear instead of Embedding)
# ─────────────────────────────────────────────────────────────────────────────
class GPS(nn.Module):
    """GraphGPS graph-level regressor.

    Node and edge features are linearly projected to ``channels``. They then
    pass through ``num_layers`` GPSConv layers, each combining GINEConv local
    message passing with multi-head global attention. Node embeddings are
    sum-pooled per graph, and a two-layer MLP maps each graph to one scalar.

    Args:
        channels: hidden width. It must be divisible by ``heads``, because
            multi-head attention splits the width evenly across heads.
        num_layers: number of stacked GPSConv layers.
        heads: number of attention heads per layer.
        node_dim: input node-feature width (defaults to NODE_FEAT_DIM).
        edge_dim: input edge-feature width (defaults to EDGE_FEAT_DIM).
        attn_dropout: dropout passed to the attention module via ``attn_kwargs``.

    Raises:
        ValueError: ``channels`` is not divisible by ``heads``.
    """

    def __init__(self, channels: int = 64, num_layers: int = 4, heads: int = 4,
                 node_dim: int = NODE_FEAT_DIM, edge_dim: int = EDGE_FEAT_DIM,
                 attn_dropout: float = 0.1):
        super().__init__()
        # Validate before building any layers, so a bad config fails with a
        # clear message and consumes no RNG state.
        if channels % heads != 0:
            raise ValueError(f"channels ({channels}) must be divisible by heads ({heads})")

        # Keep the construction order (node_lin, edge_lin, convs, head) so
        # seeded weight initialization stays reproducible.
        self.node_lin = Linear(node_dim, channels)
        self.edge_lin = Linear(edge_dim, channels)

        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            # Edge features are already projected to `channels`, so GINEConv
            # needs no separate edge_dim.
            self.convs.append(
                GPSConv(channels, GINEConv(mlp), heads=heads,
                        attn_type="multihead",
                        attn_kwargs={"dropout": attn_dropout})
            )

        self.head = Sequential(
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, 1),
        )

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
                edge_attr: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Return one prediction per graph, shape ``(num_graphs,)``.

        ``batch`` assigns each node to its graph (PyG mini-batch convention);
        it is used both by GPSConv's global attention and by pooling.
        """
        x = self.node_lin(x)
        e = self.edge_lin(edge_attr)
        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=e)
        # Sum pooling: the pooled vector scales with node count (e.g. with or
        # without the optional robot node).
        g = global_add_pool(x, batch)
        return self.head(g).squeeze(-1)
# ─────────────────────────────────────────────────────────────────────────────
# 3. Train / eval loop
# ─────────────────────────────────────────────────────────────────────────────
def run_epoch(model, loader, device, optimizer=None) -> float:
    """Run one pass over *loader* and return the mean per-graph L1 error.

    The model is put in training mode and updated when *optimizer* is given.
    Otherwise it is evaluated with gradients disabled. Each batch's mean loss
    is weighted by its number of graphs, so the result is a per-graph mean.
    An empty loader yields NaN instead of a ZeroDivisionError.
    """
    train = optimizer is not None
    model.train(train)
    total = 0.0
    count = 0
    for data in loader:
        data = data.to(device)
        with torch.set_grad_enabled(train):
            pred = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = F.l1_loss(pred, data.y)
        if train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        total += loss.item() * data.num_graphs
        count += data.num_graphs
    return total / count if count else float("nan")


def main(json_dir: Path = Path("graph_jsons"), epochs: int = 50) -> None:
    """Train GPS on the frame graphs in *json_dir* and report L1 errors (MAE).

    Steps:
      * seed the torch and Python RNGs;
      * shuffle the graphs and split them 80/10/10;
      * train with Adam for *epochs* epochs, printing train/val/test MAE each epoch;
      * finally report the test MAE at the epoch with the lowest validation MAE.
    """
    torch.manual_seed(0)
    random.seed(0)

    dataset = build_dataset(Path(json_dir))
    print(f"Loaded {len(dataset)} frame graphs")
    print(f"Example: {dataset[0]} target={dataset[0].y.item():.3f}")

    # Shuffle and split 80/10/10.
    # NOTE(review): consecutive frames are probably near-duplicates, so a random
    # split may leak information between splits and overstate accuracy.
    # Consider a contiguous split by frame index — TODO confirm against the data.
    random.shuffle(dataset)
    n = len(dataset)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    train_ds = dataset[:n_train]
    val_ds = dataset[n_train:n_train + n_val]
    test_ds = dataset[n_train + n_val:]
    print(f"Splits: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)
    test_loader = DataLoader(test_ds, batch_size=32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPS(channels=64, num_layers=4, heads=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    best_val, best_epoch, test_at_best = math.inf, 0, math.nan
    for epoch in range(1, epochs + 1):
        tr = run_epoch(model, train_loader, device, optimizer)
        vl = run_epoch(model, val_loader, device)
        te = run_epoch(model, test_loader, device)
        print(f"Epoch {epoch:02d} | train MAE {tr:.4f} | val MAE {vl:.4f} | test MAE {te:.4f}")
        # Select the epoch on validation only. A NaN (empty val split) never
        # compares less than best_val, so no epoch gets selected in that case.
        if vl < best_val:
            best_val, best_epoch, test_at_best = vl, epoch, te

    if best_epoch:
        print(f"Best val MAE {best_val:.4f} at epoch {best_epoch:02d} | "
              f"test MAE {test_at_best:.4f}")
    else:
        print("No validation graphs; best-epoch selection skipped")


if __name__ == "__main__":
    main()