Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| from __future__ import annotations | |
| from typing import Literal | |
| import torch | |
| import torch.nn as nn | |
| from torch_geometric.nn import GINEConv, GPSConv | |
class GraphSpatialEncoder(nn.Module):
    """Per-frame spatial graph encoder built on top of PyG.

    Node and edge features are first projected to ``hidden_dim`` with linear
    layers. The node embeddings are then passed through ``num_layers``
    message-passing layers (``GINEConv``, optionally wrapped in ``GPSConv``).
    After every layer the result is combined with the layer input and passed
    through a single ``LayerNorm`` that is shared by all layers.

    Args:
        node_in_dim: Size of the raw node feature vectors (``data.x``).
        edge_in_dim: Size of the raw edge feature vectors (``data.edge_attr``).
        hidden_dim: Width of the node/edge embeddings used by every layer.
        num_layers: Number of message-passing layers to stack.
        dropout: Dropout probability applied to each layer's output; also
            forwarded to the attention module when ``backbone == "gps"``.
        backbone: ``"gine"`` for plain ``GINEConv`` layers, ``"gps"`` for
            ``GPSConv`` layers that wrap a ``GINEConv``.
        num_heads: Number of attention heads; only used by the ``"gps"``
            backbone.

    Raises:
        ValueError: If ``backbone`` is not one of the supported values.
    """

    _SUPPORTED_BACKBONES = ("gine", "gps")

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        backbone: Literal["gine", "gps"] = "gine",
        num_heads: int = 8,
    ):
        super().__init__()
        # Validate once, up front, so that a bad value is rejected even when
        # ``num_layers == 0`` and the layer-building loop never executes.
        if backbone not in self._SUPPORTED_BACKBONES:
            raise ValueError(f"Unsupported graph backbone: {backbone}")

        self.backbone = backbone
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_in_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            # Each layer owns its own two-layer MLP; construction order
            # (MLP -> GINEConv -> optional GPSConv) is kept stable so seeded
            # initialisation is reproducible.
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            conv = GINEConv(mlp, edge_dim=hidden_dim)
            if backbone == "gps":
                conv = GPSConv(
                    channels=hidden_dim,
                    conv=conv,
                    heads=num_heads,
                    attn_type="multihead",
                    attn_kwargs={"dropout": dropout},
                )
            self.layers.append(conv)

        # One LayerNorm instance, reused after every layer in ``forward``.
        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, data) -> torch.Tensor:
        """Encode the nodes of a graph (or batch of graphs).

        Args:
            data: Object exposing ``x``, ``edge_attr`` and ``edge_index``;
                ``batch`` is additionally read when the backbone is
                ``"gps"``. Presumably a PyG ``Data``/``Batch`` -- confirm
                against callers.

        Returns:
            Node embeddings whose last dimension is ``hidden_dim``.
        """
        x = self.node_proj(data.x)
        edge_attr = self.edge_proj(data.edge_attr)

        use_gps = self.backbone == "gps"
        for layer in self.layers:
            residual = x
            if use_gps:
                # GPSConv takes the batch assignment vector as its third
                # positional argument.
                x = layer(x, data.edge_index, data.batch, edge_attr=edge_attr)
            else:
                x = layer(x, data.edge_index, edge_attr=edge_attr)
            # NOTE(review): the layer output enters the sum twice (once raw,
            # once through dropout). This may be unintended, but it is kept
            # as-is because existing trained weights depend on it -- confirm
            # before changing.
            x = self.norm(x + self.dropout(x) + residual)
        return x