| import dgl |
| import torch |
| from dataclasses import dataclass, field |
| from typing import List, Dict |
|
|
@dataclass
class Graphs:
    """A list of DGL graphs plus per-graph metadata tensors that stay aligned.

    ``metadata`` maps a name to a tensor whose leading dimension indexes the
    same positions as ``graphs`` (after ``batch()`` has been called, that
    leading dimension indexes batches instead of individual graphs).
    """

    # DGL annotation is quoted (forward reference) so the dataclass can be
    # defined and introspected without requiring dgl at class-creation time.
    graphs: List["dgl.DGLGraph"]
    metadata: Dict[str, torch.Tensor]

    def __len__(self) -> int:
        """Number of graphs held (or number of batches, once batched)."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, meta)`` with every metadata tensor sliced at ``idx``."""
        meta = {k: v[idx] for k, v in self.metadata.items()}
        return self.graphs[idx], meta

    def shuffle(self):
        """Shuffle in place; graphs and metadata share one permutation so
        their pairing is preserved."""
        idx = torch.randperm(len(self.graphs))
        self.graphs = [self.graphs[i] for i in idx]
        for k in self.metadata:
            self.metadata[k] = self.metadata[k][idx]

    def batch(self, batch_size, node_feature_dim=None, dtype=None,
              edge_feature_dim=None):
        """
        In-place batching: after this, self.graphs is a list of batched DGLGraphs,
        and self.metadata[k] is a tensor of shape [num_batches, batch_size, ...].

        The final short batch is padded with empty dummy graphs and zero
        metadata rows up to ``batch_size``; consumers must mask the padding.

        Args:
            batch_size: Number of graphs per batch; must be positive.
            node_feature_dim: Width of the dummy graphs' node features.
                Inferred from the first graph's ``ndata['features']`` if None.
            dtype: dtype of the dummy features. Inferred from the first
                graph's node features if None.
            edge_feature_dim: Width of the dummy graphs' edge features.
                Inferred from the first graph's ``edata['features']`` if None;
                falls back to 3 (the previously hard-coded value) when it
                cannot be inferred.

        Raises:
            ValueError: If ``batch_size`` is not positive (a negative step
                would otherwise silently discard every graph).
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        batched_graphs = []
        batched_meta = {k: [] for k in self.metadata}
        N = len(self.graphs)

        # Infer dummy-feature geometry from the first real graph so padded
        # graphs are feature-compatible with the graphs they are batched with.
        if N > 0:
            first = self.graphs[0]
            feats = first.ndata['features']
            if node_feature_dim is None:
                node_feature_dim = feats.shape[1] if feats.ndim > 1 else 1
            if dtype is None:
                dtype = feats.dtype
            if edge_feature_dim is None and 'features' in first.edata:
                efeats = first.edata['features']
                edge_feature_dim = efeats.shape[1] if efeats.ndim > 1 else 1
        if edge_feature_dim is None:
            edge_feature_dim = 3  # historical default, kept for compatibility

        for start in range(0, N, batch_size):
            end = start + batch_size
            batch_graphs = self.graphs[start:end]
            batch_meta = {k: v[start:end] for k, v in self.metadata.items()}

            # Pad the last (short) batch so every batch has batch_size entries.
            pad_count = batch_size - len(batch_graphs)
            if pad_count > 0:
                dummy_graph = dgl.graph(([], []))
                dummy_graph.ndata['features'] = torch.empty((0, node_feature_dim), dtype=dtype)
                dummy_graph.edata['features'] = torch.empty((0, edge_feature_dim), dtype=dtype)
                batch_graphs += [dummy_graph] * pad_count

                # Zero rows for the padded graphs, matching each tensor's
                # per-item shape, dtype and device.
                for k, v in batch_meta.items():
                    shape = list(v[0].shape) if len(v) > 0 else []
                    pad_tensor = torch.zeros([pad_count] + shape, dtype=v.dtype, device=v.device)
                    batch_meta[k] = torch.cat([v, pad_tensor], dim=0)
            else:
                for k, v in batch_meta.items():
                    batch_meta[k] = torch.stack(v, dim=0) if isinstance(v, list) else v

            batched_graphs.append(dgl.batch(batch_graphs))
            for k in batched_meta:
                batched_meta[k].append(batch_meta[k])

        # All per-batch tensors now share a shape, so they stack cleanly into
        # [num_batches, batch_size, ...].
        for k in batched_meta:
            self.metadata[k] = torch.stack(batched_meta[k], dim=0)

        self.graphs = batched_graphs

    def normalize(self, stats):
        """Standardize node/edge 'features' in place.

        ``stats['node']`` / ``stats['edge']`` are ``(mean, std, _)`` tuples;
        the third element is ignored.
        """
        node_mean, node_std, _ = stats['node']
        edge_mean, edge_std, _ = stats['edge']
        for g in self.graphs:
            g.ndata['features'] = (g.ndata['features'] - node_mean) / node_std
            g.edata['features'] = (g.edata['features'] - edge_mean) / edge_std
|
|
def save_graphs(graphs: Graphs, f: str):
    """Persist a Graphs collection to path *f* with DGL's native serializer.

    The metadata dict is shallow-copied and passed as DGL's label dict.
    """
    dgl.save_graphs(f, graphs.graphs, dict(graphs.metadata))
|
|
def load_graphs(f: str) -> Graphs:
    """Load a Graphs collection previously written by ``save_graphs``.

    Label values that come back as lists of tensors are stacked into a
    single tensor so ``metadata`` always maps names to tensors.
    """
    graph_list, labels = dgl.load_graphs(f)
    normalized = {
        k: v if isinstance(v, torch.Tensor) else torch.stack(v)
        for k, v in labels.items()
    }
    return Graphs(graphs=graph_list, metadata=normalized)