Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| """Train a GraphGPS (GPSConv + GINEConv + global attention) model on the | |
| JSON constraint graphs produced by frame_to_graph.py. | |
| Task (demo): graph-level regression of disassembly progress | |
| target = frame_idx / max_frame_idx ∈ [0, 1] | |
| The model reads a per-frame constraint graph (15 products + optional robot, | |
| fully connected with constraint edge features) and predicts how far along | |
| the disassembly is. This mirrors the GraphGPS tutorial (PyG docs): local | |
| MPNN (GINEConv) + global attention, stacked for N layers. | |
| Run: | |
| python train_gps.py | |
| """ | |
| import json | |
| import math | |
| import random | |
| from pathlib import Path | |
| from typing import List | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from torch.nn import Linear, ReLU, Sequential | |
| from torch_geometric.data import Data | |
| from torch_geometric.loader import DataLoader | |
| from torch_geometric.nn import GINEConv, GPSConv, global_add_pool | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 1. JSON graph → PyG Data | |
| # ───────────────────────────────────────────────────────────────────────────── | |
# Closed vocabulary of node types; a type's position in this list is its
# one-hot slot in the node feature vector.
TYPE_VOCAB = [
    "cpu_fan",
    "cpu_bracket",
    "cpu",
    "ram_clip",
    "ram",
    "connector",
    "graphic_card",
    "motherboard",
    "robot",
]
TYPE_TO_IDX = dict(zip(TYPE_VOCAB, range(len(TYPE_VOCAB))))

# Node feature width: type one-hot (9) + 3D centroid (3)
# + (mask_area, emb_norm) (2) = 14.
NODE_FEAT_DIM = len(TYPE_VOCAB) + 3 + 2
# Edge feature layout: [has_constraint, is_locked].
EDGE_FEAT_DIM = 2
def json_to_pyg(path: Path, target: float) -> Data:
    """Read a frame graph JSON and build a fully-connected PyG Data object.

    The JSON must provide ``nodes`` (each with ``id``, ``type``,
    ``centroid_3d``, ``mask_area`` and ``embedding_norm``) and ``edges``
    (each with ``src``, ``dst`` and ``is_locked``). Constraint edges are
    treated as undirected; edges that reference an unknown node id are
    ignored.

    Args:
        path: Location of the frame graph JSON file.
        target: Regression label, stored in ``Data.y`` as a 1-element tensor.

    Returns:
        A ``Data`` object with ``x`` of shape ``(N, NODE_FEAT_DIM)``, one
        directed edge for every ordered pair of distinct nodes, and
        ``edge_attr`` of shape ``(N * (N - 1), EDGE_FEAT_DIM)`` holding
        ``[has_constraint, is_locked]`` per edge.

    Raises:
        KeyError: If a required JSON key is missing or a node ``type`` is
            not present in ``TYPE_TO_IDX``.
    """
    with open(path, encoding="utf-8") as f:
        gd = json.load(f)

    nodes = gd["nodes"]
    N = len(nodes)
    node_ids = [n["id"] for n in nodes]
    known_ids = set(node_ids)

    # Column layout: [type one-hot | centroid_3d (3) | mask_area | emb_norm].
    # Offsets are derived from the vocabulary so they stay in sync with
    # NODE_FEAT_DIM if TYPE_VOCAB changes.
    centroid_start = len(TYPE_VOCAB)
    area_col = centroid_start + 3
    norm_col = area_col + 1

    # Node features: [type one-hot, centroid_3d, mask_area_scaled, emb_norm]
    x = torch.zeros((N, NODE_FEAT_DIM), dtype=torch.float32)
    for i, n in enumerate(nodes):
        x[i, TYPE_TO_IDX[n["type"]]] = 1.0
        x[i, centroid_start:area_col] = torch.tensor(
            n["centroid_3d"], dtype=torch.float32
        )
        x[i, area_col] = n["mask_area"] / 1e5  # scale to ~O(1)
        x[i, norm_col] = n["embedding_norm"] / 5.0

    # Sparse constraint lookup: frozenset({a, b}) -> is_locked.
    # The unordered key makes the constraint apply in both directions.
    constraint = {}
    for e in gd["edges"]:
        if e["src"] in known_ids and e["dst"] in known_ids:
            constraint[frozenset((e["src"], e["dst"]))] = bool(e["is_locked"])

    # Fully connected edges (no self-loops) with features
    # [has_constraint, is_locked].
    src_idx, dst_idx, edge_attr = [], [], []
    for i in range(N):
        for j in range(N):
            if i == j:
                continue
            src_idx.append(i)
            dst_idx.append(j)
            locked = constraint.get(frozenset((node_ids[i], node_ids[j])))
            if locked is None:
                edge_attr.append([0.0, 0.0])
            else:
                edge_attr.append([1.0, 1.0 if locked else 0.0])

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        # Explicit reshape keeps the (E, EDGE_FEAT_DIM) layout even when the
        # graph has fewer than two nodes and therefore no edges; a bare
        # torch.tensor([]) would be 1-D and break batching / the edge Linear.
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32).reshape(
            len(edge_attr), EDGE_FEAT_DIM
        ),
        y=torch.tensor([target], dtype=torch.float32),
        num_nodes=N,
    )
def build_dataset(json_dir: Path) -> List[Data]:
    """Load every ``frame_*_graph.json`` in *json_dir* as a labelled graph.

    The frame index is parsed from the file name (the second
    underscore-separated field of the stem) and the regression target is
    ``frame_idx / max_frame_idx``.

    Args:
        json_dir: Directory containing the frame graph JSON files.

    Returns:
        One ``Data`` object per matching file, in sorted-path order.

    Raises:
        ValueError: If no matching file is found, or a file name does not
            carry an integer frame index.
    """
    paths = sorted(json_dir.glob("frame_*_graph.json"))
    if not paths:
        raise ValueError(f"No frame_*_graph.json files found in {json_dir}")

    frame_ids = [int(p.stem.split("_")[1]) for p in paths]
    max_f = max(frame_ids)

    dataset = []
    for p, fid in zip(paths, frame_ids):
        # If the largest frame index is 0 the normaliser would be zero;
        # such a frame is at the very start, so its progress is 0.
        target = fid / max_f if max_f else 0.0
        dataset.append(json_to_pyg(p, target))
    return dataset
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 2. GraphGPS model — mirrors the PyG tutorial but adapted for continuous | |
| # node / edge features (Linear instead of Embedding) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class GPS(nn.Module):
    """GraphGPS-style regressor that emits one scalar per graph.

    Node and edge features are linearly projected to ``channels``, refined by
    ``num_layers`` GPSConv layers (each wrapping a GINEConv with a two-layer
    MLP, plus multi-head attention), sum-pooled per graph and mapped to a
    single value by a small MLP head.
    """

    def __init__(self, channels: int = 64, num_layers: int = 4, heads: int = 4):
        super().__init__()
        # Continuous inputs, so linear projections rather than embeddings.
        self.node_lin = nn.Linear(NODE_FEAT_DIM, channels)
        self.edge_lin = nn.Linear(EDGE_FEAT_DIM, channels)

        # Layers are created in order so parameter initialisation consumes
        # the RNG in a fixed sequence.
        self.convs = nn.ModuleList(
            [self._make_layer(channels, heads) for _ in range(num_layers)]
        )

        hidden = channels // 2
        self.head = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )

    @staticmethod
    def _make_layer(channels, heads):
        """Build one GPSConv block: GINEConv local MPNN + global attention."""
        local_mlp = nn.Sequential(
            nn.Linear(channels, channels),
            nn.ReLU(),
            nn.Linear(channels, channels),
        )
        return GPSConv(
            channels,
            GINEConv(local_mlp),
            heads=heads,
            attn_type="multihead",
            attn_kwargs={"dropout": 0.1},
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Return one prediction per graph in the batch (last dim squeezed)."""
        h = self.node_lin(x)
        edge_emb = self.edge_lin(edge_attr)
        for layer in self.convs:
            h = layer(h, edge_index, batch, edge_attr=edge_emb)
        pooled = global_add_pool(h, batch)
        out = self.head(pooled)
        return out.squeeze(-1)
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 3. Train / eval loop | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def main():
    """Train the GPS regressor on ``graph_jsons/`` and print per-epoch MAE.

    Seeds ``torch`` and ``random``, loads the dataset, shuffles it and splits
    it 80/10/10 into train/val/test, then trains for 50 epochs with Adam on an
    L1 loss, printing train/val/test MAE after every epoch. A split that ends
    up empty (small datasets) is reported as ``nan`` instead of crashing.
    """
    torch.manual_seed(0)
    random.seed(0)

    json_dir = Path("graph_jsons")
    dataset = build_dataset(json_dir)
    print(f"Loaded {len(dataset)} frame graphs")
    print(f"Example: {dataset[0]} target={dataset[0].y.item():.3f}")

    # Shuffle and split 80/10/10
    random.shuffle(dataset)
    n = len(dataset)
    n_train = int(0.8 * n)
    n_val = int(0.1 * n)
    train_ds = dataset[:n_train]
    val_ds = dataset[n_train:n_train + n_val]
    test_ds = dataset[n_train + n_val:]
    print(f"Splits: train={len(train_ds)}, val={len(val_ds)}, test={len(test_ds)}")

    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)
    test_loader = DataLoader(test_ds, batch_size=32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPS(channels=64, num_layers=4, heads=4).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

    def step(loader, train: bool):
        """Run one pass over *loader*; return the graph-weighted mean L1 loss.

        Updates the model only when *train* is true. Returns NaN when the
        loader yields no graphs.
        """
        model.train(train)
        total = 0.0
        count = 0
        for data in loader:
            data = data.to(device)
            if train:
                optimizer.zero_grad()
            with torch.set_grad_enabled(train):
                pred = model(data.x, data.edge_index, data.edge_attr, data.batch)
                loss = F.l1_loss(pred, data.y)
            if train:
                loss.backward()
                optimizer.step()
            # Weight by batch size so the final value is a per-graph mean.
            total += loss.item() * data.num_graphs
            count += data.num_graphs
        # An empty split (e.g. fewer than 10 graphs leaves no validation
        # set) has no defined MAE; report NaN rather than dividing by zero.
        return total / count if count else float("nan")

    for epoch in range(1, 51):
        tr = step(train_loader, train=True)
        vl = step(val_loader, train=False)
        te = step(test_loader, train=False)
        print(f"Epoch {epoch:02d} | train MAE {tr:.4f} | val MAE {vl:.4f} | test MAE {te:.4f}")


if __name__ == "__main__":
    main()