Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
import torch from diffusers import DiffusionPipeline # switch to "mps" for apple devices pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda") prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" image = pipe(prompt).images[0] - Notebooks
- Google Colab
- Kaggle
File size: 7,864 Bytes
4ee0c8c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 | """Train a GraphGPS (GPSConv + GINEConv + global attention) model on the
JSON constraint graphs produced by frame_to_graph.py.
Task (demo): graph-level regression of disassembly progress
target = frame_idx / max_frame_idx ∈ [0, 1]
The model reads a per-frame constraint graph (15 products + optional robot,
fully connected with constraint edge features) and predicts how far along
the disassembly is. This mirrors the GraphGPS tutorial (PyG docs): local
MPNN (GINEConv) + global attention, stacked for N layers.
Run:
python train_gps.py
"""
import json
import math
import random
from pathlib import Path
from typing import List
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GINEConv, GPSConv, global_add_pool
# ─────────────────────────────────────────────────────────────────────────────
# 1. JSON graph → PyG Data
# ─────────────────────────────────────────────────────────────────────────────
# Closed vocabulary of node types; a type's position here is its one-hot slot.
TYPE_VOCAB = [
    "cpu_fan", "cpu_bracket", "cpu", "ram_clip", "ram",
    "connector", "graphic_card", "motherboard", "robot",
]
TYPE_TO_IDX = dict(zip(TYPE_VOCAB, range(len(TYPE_VOCAB))))
# Node feature width: type one-hot + 3-D centroid (3) + [mask_area, emb_norm] (2) -> 14.
NODE_FEAT_DIM = len(TYPE_VOCAB) + 3 + 2
# Edge feature width: [has_constraint, is_locked].
EDGE_FEAT_DIM = 2
def json_to_pyg(path: Path, target: float) -> Data:
    """Read a frame-graph JSON file and build a fully-connected PyG ``Data``.

    Every ordered pair of distinct nodes (i, j) gets a directed edge, giving
    ``N * (N - 1)`` edges. Each edge carries ``[has_constraint, is_locked]``,
    looked up from the sparse ``edges`` list in the JSON. Constraints are
    treated as undirected: (a, b) and (b, a) share one entry.

    Args:
        path: JSON file with a ``nodes`` list (each entry using keys ``id``,
            ``type``, ``centroid_3d``, ``mask_area``, ``embedding_norm``) and
            an ``edges`` list (each entry using ``src``, ``dst``, ``is_locked``).
        target: Graph-level regression target, stored as ``data.y``.

    Returns:
        ``Data`` with ``x`` of shape ``(N, NODE_FEAT_DIM)``, ``edge_index`` of
        shape ``(2, E)``, ``edge_attr`` of shape ``(E, EDGE_FEAT_DIM)`` (also
        when ``E == 0``), and ``y`` of shape ``(1,)``.

    Raises:
        KeyError: if a node ``type`` is not in ``TYPE_VOCAB`` or a required
            key is missing from the JSON.
    """
    with open(path, encoding="utf-8") as f:
        gd = json.load(f)
    nodes = gd["nodes"]
    N = len(nodes)
    node_ids = [n["id"] for n in nodes]
    known_ids = set(node_ids)

    # Column layout: [type one-hot | centroid_3d (3) | mask_area | emb_norm].
    # Offsets are derived from TYPE_VOCAB so they stay consistent with
    # NODE_FEAT_DIM instead of being hard-coded for a 9-type vocabulary.
    centroid_start = len(TYPE_VOCAB)
    area_col = centroid_start + 3
    norm_col = area_col + 1

    x = torch.zeros((N, NODE_FEAT_DIM), dtype=torch.float32)
    for i, n in enumerate(nodes):
        x[i, TYPE_TO_IDX[n["type"]]] = 1.0
        x[i, centroid_start:area_col] = torch.tensor(n["centroid_3d"], dtype=torch.float32)
        x[i, area_col] = n["mask_area"] / 1e5  # scale to ~O(1)
        x[i, norm_col] = n["embedding_norm"] / 5.0

    # Sparse, undirected constraint lookup: frozenset({a, b}) -> is_locked.
    # Edges that reference unknown node ids are skipped; a later duplicate of
    # the same pair overwrites an earlier one.
    constraint = {}
    for e in gd["edges"]:
        if e["src"] in known_ids and e["dst"] in known_ids:
            constraint[frozenset((e["src"], e["dst"]))] = bool(e["is_locked"])

    # Fully connected directed edges (no self-loops), row-major over (i, j).
    src_idx, dst_idx, edge_attr = [], [], []
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            src_idx.append(i)
            dst_idx.append(j)
            locked = constraint.get(frozenset((node_ids[i], node_ids[j])))
            if locked is None:
                edge_attr.append([0.0, 0.0])
            else:
                edge_attr.append([1.0, 1.0 if locked else 0.0])

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        # reshape keeps shape (E, EDGE_FEAT_DIM) even when E == 0 (N <= 1);
        # torch.tensor([]) by itself is 1-D and would break the edge Linear.
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, EDGE_FEAT_DIM),
        y=torch.tensor([target], dtype=torch.float32),
        num_nodes=N,
    )
def build_dataset(json_dir: Path) -> List[Data]:
    """Load every ``frame_<idx>_graph.json`` in *json_dir* as a PyG graph.

    The frame index is parsed from the second ``_``-separated field of the
    file stem. Each graph's target is ``frame_idx / max_frame_idx``. Files are
    processed in sorted filename order.

    Args:
        json_dir: Directory containing the per-frame graph JSON files. A
            ``str`` is also accepted.

    Returns:
        One ``Data`` object per matching file.

    Raises:
        FileNotFoundError: if *json_dir* contains no matching files.
        ValueError: if a filename's frame field is not an integer.
    """
    json_dir = Path(json_dir)
    paths = sorted(json_dir.glob("frame_*_graph.json"))
    if not paths:
        raise FileNotFoundError(f"No frame_*_graph.json files found in {json_dir}")
    frame_ids = [int(p.stem.split("_")[1]) for p in paths]
    # If the only frame is 0, max is 0 and every target would be 0/0.
    # Fall back to a divisor of 1 so those targets are 0.0.
    max_f = max(frame_ids) or 1
    return [json_to_pyg(p, fid / max_f) for p, fid in zip(paths, frame_ids)]
# ─────────────────────────────────────────────────────────────────────────────
# 2. GraphGPS model — mirrors the PyG tutorial but adapted for continuous
#    node / edge features (Linear instead of Embedding)
# ─────────────────────────────────────────────────────────────────────────────
class GPS(nn.Module):
    """GraphGPS graph-level regressor.

    Continuous node and edge features are first projected to ``channels`` with
    ``Linear`` layers. Then ``num_layers`` ``GPSConv`` blocks run, each wrapping
    a ``GINEConv`` local message-passing layer plus multi-head global attention
    (attention dropout 0.1). Node states are sum-pooled per graph and an MLP
    head maps each graph embedding to a single scalar.

    Args:
        channels: Hidden width shared by every layer.
        num_layers: Number of stacked ``GPSConv`` blocks.
        heads: Attention heads inside each ``GPSConv``.
    """

    def __init__(self, channels: int = 64, num_layers: int = 4, heads: int = 4):
        super().__init__()
        self.node_lin = Linear(NODE_FEAT_DIM, channels)
        self.edge_lin = Linear(EDGE_FEAT_DIM, channels)
        # Inside each iteration, the GINE MLP is built before GINEConv, and
        # GINEConv before GPSConv. This matches the original construction
        # order, so the seeded parameter initialization is unchanged.
        self.convs = nn.ModuleList(
            GPSConv(
                channels,
                GINEConv(
                    Sequential(
                        Linear(channels, channels),
                        ReLU(),
                        Linear(channels, channels),
                    )
                ),
                heads=heads,
                attn_type="multihead",
                attn_kwargs={"dropout": 0.1},
            )
            for _ in range(num_layers)
        )
        hidden = channels // 2
        self.head = Sequential(Linear(channels, hidden), ReLU(), Linear(hidden, 1))

    def forward(self, x, edge_index, edge_attr, batch):
        """Return one prediction per graph in *batch*, shape ``(num_graphs,)``."""
        h = self.node_lin(x)
        edge_emb = self.edge_lin(edge_attr)
        for layer in self.convs:
            h = layer(h, edge_index, batch, edge_attr=edge_emb)
        pooled = global_add_pool(h, batch)
        # The head emits (num_graphs, 1); drop the trailing singleton dim.
        return self.head(pooled).squeeze(-1)
# ─────────────────────────────────────────────────────────────────────────────
# 3. Train / eval loop
# ─────────────────────────────────────────────────────────────────────────────
def main(json_dir="graph_jsons", epochs: int = 50):
    """Train ``GPS`` on the frame graphs in *json_dir* and report MAE.

    The dataset is shuffled with a fixed seed and split 80/10/10 into
    train/val/test. Each epoch prints the MAE on all three splits. At the end,
    a summary line reports the test MAE at the epoch with the lowest
    validation MAE, so the reported test number is chosen by validation.

    Args:
        json_dir: Directory holding ``frame_*_graph.json`` files.
        epochs: Number of training epochs.
    """
    torch.manual_seed(0)
    random.seed(0)
    dataset = build_dataset(Path(json_dir))
    print(f"Loaded {len(dataset)} frame graphs")
    print(f"Example: {dataset[0]} target={dataset[0].y.item():.3f}")

    # Shuffle and split 80/10/10. With small datasets the int() truncation
    # can leave the val or train split empty; step() handles that case.
    random.shuffle(dataset)
    n = len(dataset)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    train_ds = dataset[:n_train]
    val_ds = dataset[n_train:n_train + n_val]
    test_ds = dataset[n_train + n_val:]
    print(f"Splits: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)
    test_loader = DataLoader(test_ds, batch_size=32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPS(channels=64, num_layers=4, heads=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    def step(loader, train: bool):
        """Run one pass over *loader* and return the per-graph mean absolute error.

        When *train* is true, gradients are enabled and the optimizer steps
        after every batch. Returns ``nan`` when the loader yields no graphs,
        instead of dividing by zero.
        """
        model.train(train)
        total = 0.0
        count = 0
        for data in loader:
            data = data.to(device)
            if train:
                optimizer.zero_grad()
            with torch.set_grad_enabled(train):
                pred = model(data.x, data.edge_index, data.edge_attr, data.batch)
                loss = F.l1_loss(pred, data.y)
            if train:
                loss.backward()
                optimizer.step()
            # l1_loss is a per-graph mean; re-weight by batch size so the
            # epoch figure is a true per-graph average across uneven batches.
            total += loss.item() * data.num_graphs
            count += data.num_graphs
        return total / count if count else float("nan")

    best_epoch, best_val, best_test = None, math.inf, math.nan
    for epoch in range(1, epochs + 1):
        tr = step(train_loader, train=True)
        vl = step(val_loader, train=False)
        te = step(test_loader, train=False)
        print(f"Epoch {epoch:02d} | train MAE {tr:.4f} | val MAE {vl:.4f} | test MAE {te:.4f}")
        # Any comparison with NaN is False, so an empty val split never
        # selects an epoch.
        if vl < best_val:
            best_epoch, best_val, best_test = epoch, vl, te

    if best_epoch is None:
        print("No validation graphs; skipping best-epoch selection")
    else:
        print(f"Best val MAE {best_val:.4f} at epoch {best_epoch:02d} | test MAE {best_test:.4f}")


if __name__ == "__main__":
    main()
|