Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
File size: 2,049 Bytes
da7bf91 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 | from __future__ import annotations
from typing import Literal
import torch
import torch.nn as nn
from torch_geometric.nn import GINEConv, GPSConv
class GraphSpatialEncoder(nn.Module):
    """Per-frame spatial graph encoder built on top of PyG.

    Node and edge features are linearly projected to ``hidden_dim`` and then
    passed through ``num_layers`` message-passing layers. Each layer is either
    a ``GINEConv`` (``backbone="gine"``) or a ``GPSConv`` wrapping a
    ``GINEConv`` (``backbone="gps"``). After every layer the result is combined
    with the layer input, passed through dropout, and normalised by a single
    ``LayerNorm`` that is shared by all layers.
    """

    # Backbones accepted by ``__init__``; anything else raises ``ValueError``.
    _SUPPORTED_BACKBONES = ("gine", "gps")

    def __init__(
        self,
        node_in_dim: int,
        edge_in_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 4,
        dropout: float = 0.1,
        backbone: Literal["gine", "gps"] = "gine",
        num_heads: int = 8,
    ):
        """Build the projections and the stack of graph layers.

        Args:
            node_in_dim: Size of the last dimension of ``data.x``.
            edge_in_dim: Size of the last dimension of ``data.edge_attr``.
            hidden_dim: Width of every projection, MLP and graph layer.
            num_layers: Number of graph layers. With ``0`` the encoder
                reduces to the node projection alone.
            dropout: Dropout probability used after each layer and, for the
                ``"gps"`` backbone, passed to the attention module.
            backbone: ``"gine"`` or ``"gps"``.
            num_heads: Attention heads for the ``"gps"`` backbone; unused by
                ``"gine"``.

        Raises:
            ValueError: If ``backbone`` is not a supported name.
        """
        super().__init__()

        # Validate once, before allocating any parameters. Doing this inside
        # the layer loop would let an invalid name slip through whenever
        # ``num_layers`` is 0.
        if backbone not in self._SUPPORTED_BACKBONES:
            raise ValueError(f"Unsupported graph backbone: {backbone}")

        self.backbone = backbone
        self.node_proj = nn.Linear(node_in_dim, hidden_dim)
        self.edge_proj = nn.Linear(edge_in_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            # A fresh MLP per layer, so no weights are shared between layers.
            mlp = nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.SiLU(),
                nn.Linear(hidden_dim, hidden_dim),
            )
            if backbone == "gine":
                conv = GINEConv(mlp, edge_dim=hidden_dim)
            else:  # "gps" -- guaranteed by the validation above
                conv = GPSConv(
                    channels=hidden_dim,
                    conv=GINEConv(mlp, edge_dim=hidden_dim),
                    heads=num_heads,
                    attn_type="multihead",
                    attn_kwargs={"dropout": dropout},
                )
            self.layers.append(conv)

        self.norm = nn.LayerNorm(hidden_dim)

    def forward(self, data) -> torch.Tensor:
        """Encode the nodes of ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_attr``;
                ``batch`` is additionally read for the ``"gps"`` backbone.
                Presumably a PyG ``Data``/``Batch`` -- confirm with callers.

        Returns:
            Node embeddings whose last dimension is ``hidden_dim``.
        """
        x = self.node_proj(data.x)
        edge_attr = self.edge_proj(data.edge_attr)

        for layer in self.layers:
            residual = x
            if self.backbone == "gps":
                x = layer(x, data.edge_index, data.batch, edge_attr=edge_attr)
            else:
                x = layer(x, data.edge_index, edge_attr=edge_attr)
            # NOTE(review): the layer output is added twice here (once raw,
            # once through dropout). This looks like it was meant to be
            # ``residual + self.dropout(x)``; it is kept as-is because
            # changing it would alter the outputs of already-trained weights.
            x = self.norm(x + self.dropout(x) + residual)

        return x
|