# gnn_wm/sampled_data/train_gps.py
# Uploaded by EndeavourDD ("Add files using upload-large-folder tool", commit 4ee0c8c).
"""Train a GraphGPS (GPSConv + GINEConv + global attention) model on the
JSON constraint graphs produced by frame_to_graph.py.
Task (demo): graph-level regression of disassembly progress
target = frame_idx / max_frame_idx ∈ [0, 1]
The model reads a per-frame constraint graph (15 products + optional robot,
fully connected with constraint edge features) and predicts how far along
the disassembly is. This mirrors the GraphGPS tutorial (PyG docs): local
MPNN (GINEConv) + global attention, stacked for N layers.
Run:
python train_gps.py
"""
import json
import math
import random
from pathlib import Path
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
# ─────────────────────────────────────────────────────────────────────────────
# 1. JSON graph → PyG Data
# ─────────────────────────────────────────────────────────────────────────────
# Closed vocabulary of node (component) types; position = one-hot column.
TYPE_VOCAB = [
    "cpu_fan", "cpu_bracket", "cpu", "ram_clip", "ram",
    "connector", "graphic_card", "motherboard", "robot",
]
# Inverse lookup: type name -> one-hot column index.
TYPE_TO_IDX = dict(zip(TYPE_VOCAB, range(len(TYPE_VOCAB))))
# Node layout: type one-hot (9) + 3-D centroid (3) + (mask_area, emb_norm) (2) = 14.
NODE_FEAT_DIM = len(TYPE_VOCAB) + 3 + 2
# Edge layout: [has_constraint, is_locked].
EDGE_FEAT_DIM = 2
def json_to_pyg(path: Path, target: float) -> Data:
    """Read a frame graph JSON and build a fully-connected PyG ``Data`` object.

    Node features (``NODE_FEAT_DIM`` columns), in order:
        type one-hot (``len(TYPE_VOCAB)``), ``centroid_3d`` (3),
        ``mask_area / 1e5``, ``embedding_norm / 5.0``.
    Edge features (``EDGE_FEAT_DIM`` columns): ``[has_constraint, is_locked]``,
    one directed edge for every ordered pair of distinct nodes.

    Args:
        path: JSON file with a ``nodes`` list (``id``, ``type``,
            ``centroid_3d``, ``mask_area``, ``embedding_norm``) and an
            ``edges`` list (``src``, ``dst``, ``is_locked``).
        target: Graph-level regression target stored in ``data.y``.

    Returns:
        ``Data`` with ``x`` of shape (N, NODE_FEAT_DIM), ``edge_index`` of
        shape (2, N*(N-1)), ``edge_attr`` of shape (N*(N-1), EDGE_FEAT_DIM)
        and ``y`` of shape (1,).

    Raises:
        KeyError: if a node type is not in ``TYPE_VOCAB`` or a field is missing.
    """
    with open(path, encoding="utf-8") as f:
        gd = json.load(f)

    nodes = gd["nodes"]
    num_nodes = len(nodes)
    node_ids = [n["id"] for n in nodes]
    known_ids = set(node_ids)

    # Column offsets derived from the vocabulary so the layout always agrees
    # with NODE_FEAT_DIM, even if TYPE_VOCAB is extended.
    centroid_col = len(TYPE_VOCAB)
    area_col = centroid_col + 3
    norm_col = area_col + 1

    x = torch.zeros((num_nodes, NODE_FEAT_DIM), dtype=torch.float32)
    for i, n in enumerate(nodes):
        x[i, TYPE_TO_IDX[n["type"]]] = 1.0
        x[i, centroid_col:area_col] = torch.tensor(n["centroid_3d"], dtype=torch.float32)
        x[i, area_col] = n["mask_area"] / 1e5  # scale to ~O(1)
        x[i, norm_col] = n["embedding_norm"] / 5.0

    # Undirected constraint lookup: frozenset({a, b}) -> is_locked.
    # Edges touching nodes absent from this frame are ignored; a later
    # duplicate entry for the same pair overrides an earlier one.
    constraint = {
        frozenset((e["src"], e["dst"])): bool(e["is_locked"])
        for e in gd["edges"]
        if e["src"] in known_ids and e["dst"] in known_ids
    }

    # Fully connected (no self-loops) with 2-D features [has_constraint, is_locked].
    src_idx, dst_idx, edge_attr = [], [], []
    for i, a in enumerate(node_ids):
        for j, b in enumerate(node_ids):
            if i == j:
                continue
            src_idx.append(i)
            dst_idx.append(j)
            locked = constraint.get(frozenset((a, b)))
            if locked is None:
                edge_attr.append([0.0, 0.0])
            else:
                edge_attr.append([1.0, 1.0 if locked else 0.0])

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        # Explicit reshape keeps edge_attr 2-D, i.e. (0, EDGE_FEAT_DIM), for
        # graphs with fewer than two nodes, where torch.tensor([]) is 1-D.
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32).reshape(
            len(edge_attr), EDGE_FEAT_DIM
        ),
        y=torch.tensor([target], dtype=torch.float32),
        num_nodes=num_nodes,
    )
def build_dataset(json_dir: Path) -> List[Data]:
    """Load every ``frame_*_graph.json`` in *json_dir* as a labelled PyG graph.

    The target of each graph is ``frame_idx / max_frame_idx``. The frame index
    is parsed from the second ``_``-separated token of the file stem
    (``frame_0042_graph.json`` -> 42).

    Args:
        json_dir: Directory containing the per-frame graph JSON files.

    Returns:
        One ``Data`` per file, in lexicographic filename order.

    Raises:
        ValueError: if *json_dir* contains no matching file.
    """
    paths = sorted(json_dir.glob("frame_*_graph.json"))
    if not paths:
        raise ValueError(f"no frame_*_graph.json files found in {json_dir}")
    frame_ids = [int(p.stem.split("_")[1]) for p in paths]
    max_f = max(frame_ids)
    dataset = []
    for p, fid in zip(paths, frame_ids):
        # A lone frame 0 would otherwise divide by zero; label it 0.0.
        target = fid / max_f if max_f > 0 else 0.0
        dataset.append(json_to_pyg(p, target))
    return dataset
# ─────────────────────────────────────────────────────────────────────────────
# 2. GraphGPS model — mirrors the PyG tutorial but adapted for continuous
# node / edge features (Linear instead of Embedding)
# ─────────────────────────────────────────────────────────────────────────────
class GPS(nn.Module):
    """GraphGPS graph-level regressor.

    Continuous node and edge features are projected to ``channels`` with
    ``Linear`` layers (instead of the tutorial's ``Embedding``). They then pass
    through ``num_layers`` ``GPSConv`` blocks (local ``GINEConv`` + global
    multi-head attention), are sum-pooled per graph, and are mapped to a
    single scalar by a small MLP head.

    Args:
        channels: Hidden width for node/edge embeddings and every layer.
        num_layers: Number of stacked ``GPSConv`` layers.
        heads: Attention heads per ``GPSConv`` layer; ``channels`` should be
            divisible by it.
        node_dim: Width of the raw node features (default ``NODE_FEAT_DIM``).
        edge_dim: Width of the raw edge features (default ``EDGE_FEAT_DIM``).
        attn_dropout: Dropout passed to the global attention via ``attn_kwargs``.
    """

    def __init__(
        self,
        channels: int = 64,
        num_layers: int = 4,
        heads: int = 4,
        node_dim: int = NODE_FEAT_DIM,
        edge_dim: int = EDGE_FEAT_DIM,
        attn_dropout: float = 0.1,
    ):
        super().__init__()
        self.node_lin = Linear(node_dim, channels)
        # Edges are embedded to the same width as nodes. GINEConv, built here
        # without an ``edge_dim``, is understood to require matching widths
        # (see the PyG GINEConv docs).
        self.edge_lin = Linear(edge_dim, channels)
        self.convs = nn.ModuleList()
        for _ in range(num_layers):
            mlp = Sequential(
                Linear(channels, channels),
                ReLU(),
                Linear(channels, channels),
            )
            self.convs.append(
                GPSConv(channels, GINEConv(mlp), heads=heads,
                        attn_type="multihead",
                        attn_kwargs={"dropout": attn_dropout})
            )
        self.head = Sequential(
            Linear(channels, channels // 2), ReLU(),
            Linear(channels // 2, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Predict one scalar per graph.

        Args:
            x: Raw node features, ``node_dim`` columns.
            edge_index: COO connectivity passed through to each ``GPSConv``.
            edge_attr: Raw edge features, ``edge_dim`` columns.
            batch: Node-to-graph assignment vector used by attention and pooling.

        Returns:
            Tensor of shape ``(num_graphs,)``.
        """
        x = self.node_lin(x)
        e = self.edge_lin(edge_attr)
        for conv in self.convs:
            x = conv(x, edge_index, batch, edge_attr=e)
        g = global_add_pool(x, batch)  # (num_graphs, channels)
        return self.head(g).squeeze(-1)
# ─────────────────────────────────────────────────────────────────────────────
# 3. Train / eval loop
# ─────────────────────────────────────────────────────────────────────────────
def _split_dataset(dataset: List[Data], train_frac: float = 0.8,
                   val_frac: float = 0.1):
    """Shuffle *dataset* in place with ``random`` and split it into (train, val, test)."""
    random.shuffle(dataset)
    n = len(dataset)
    n_train = int(train_frac * n)
    n_val = int(val_frac * n)
    return (
        dataset[:n_train],
        dataset[n_train:n_train + n_val],
        dataset[n_train + n_val:],
    )


def _run_epoch(model, loader, device, optimizer=None) -> float:
    """Run one pass over *loader*; train if *optimizer* is given, else evaluate.

    Returns the per-graph mean L1 error (MAE), or NaN for an empty loader
    (e.g. a val/test split of size 0 on a tiny dataset).
    """
    train = optimizer is not None
    model.train(train)
    total = 0.0
    count = 0
    for data in loader:
        data = data.to(device)
        if train:
            optimizer.zero_grad()
        with torch.set_grad_enabled(train):
            pred = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = F.l1_loss(pred, data.y)
            if train:
                loss.backward()
                optimizer.step()
        # l1_loss averages over the batch; re-weight so the epoch mean is per graph.
        total += loss.item() * data.num_graphs
        count += data.num_graphs
    return total / count if count else float("nan")


def main(json_dir: Path = Path("graph_jsons"), epochs: int = 50, seed: int = 0):
    """Train GPS on the frame graphs in *json_dir* and report MAE per epoch.

    Args:
        json_dir: Directory of ``frame_*_graph.json`` files (str or Path).
        epochs: Number of training epochs.
        seed: Seed for ``torch`` and ``random`` (shuffle / split / init).

    Raises:
        ValueError: if no graphs are found or the training split is empty.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    dataset = build_dataset(Path(json_dir))
    print(f"Loaded {len(dataset)} frame graphs")
    print(f"Example: {dataset[0]} target={dataset[0].y.item():.3f}")

    # Shuffle and split 80/10/10
    train_ds, val_ds, test_ds = _split_dataset(dataset)
    print(f"Splits: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")
    if not train_ds:
        raise ValueError("training split is empty; need more frame graphs")

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)
    test_loader = DataLoader(test_ds, batch_size=32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPS(channels=64, num_layers=4, heads=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    # Model selection uses val only; the test MAE at the best-val epoch is the
    # number to report (picking the epoch by test MAE would leak the test set).
    best_val, best_epoch, best_test = math.inf, 0, math.nan
    for epoch in range(1, epochs + 1):
        tr = _run_epoch(model, train_loader, device, optimizer)
        vl = _run_epoch(model, val_loader, device)
        te = _run_epoch(model, test_loader, device)
        print(f"Epoch {epoch:02d} | train MAE {tr:.4f} | val MAE {vl:.4f} | test MAE {te:.4f}")
        if vl < best_val:  # NaN (empty val split) never counts as an improvement
            best_val, best_epoch, best_test = vl, epoch, te
    if best_epoch:
        print(f"Best val MAE {best_val:.4f} at epoch {best_epoch:02d} "
              f"| test MAE at that epoch {best_test:.4f}")


if __name__ == "__main__":
    main()