File size: 3,453 Bytes
5ceead6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import dgl
import torch
from dataclasses import dataclass, field
from typing import List, Dict

@dataclass
class Graphs:
    """A list of DGL graphs paired with per-graph metadata tensors.

    ``metadata`` maps a name to a tensor whose first dimension indexes the
    graphs, so ``metadata[k][i]`` is associated with ``graphs[i]``.
    """

    graphs: List[dgl.DGLGraph]
    metadata: Dict[str, torch.Tensor]

    def __len__(self):
        """Return the number of graphs (or of batches, after ``batch``)."""
        return len(self.graphs)

    def __getitem__(self, idx):
        """Return ``(graph, meta)`` where ``meta`` maps each metadata key to ``v[idx]``."""
        meta = {k: v[idx] for k, v in self.metadata.items()}
        return self.graphs[idx], meta

    def shuffle(self):
        """Shuffle graphs and metadata in place with one shared random permutation."""
        idx = torch.randperm(len(self.graphs))
        self.graphs = [self.graphs[i] for i in idx]
        for k in self.metadata:
            self.metadata[k] = self.metadata[k][idx]

    def batch(self, batch_size, node_feature_dim=None, dtype=None):
        """Group graphs into fixed-size batches, in place.

        Afterwards ``self.graphs`` is a list of ``dgl.batch``-ed graphs and each
        ``self.metadata[k]`` has shape ``[num_batches, batch_size, ...]``. A final
        partial batch is padded with empty (zero-node, zero-edge) graphs and
        zero-valued metadata rows so every batch holds exactly ``batch_size``
        entries. An empty dataset yields no graphs and metadata of shape
        ``[0, batch_size, ...]``.

        Args:
            batch_size: Number of graphs per batch; must be >= 1.
            node_feature_dim: Optional override for the trailing size of the
                padding graphs' ``ndata['features']`` (padding shape becomes
                ``(0, node_feature_dim)``). Inferred from the first graph when
                omitted.
            dtype: Optional override for the dtype of the padding graphs'
                ``ndata['features']`` and ``edata['features']``. Inferred per
                feature from the first graph when omitted.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        n = len(self.graphs)
        batched_graphs = []
        batched_meta = {k: [] for k in self.metadata}
        dummy = None  # built lazily: only the last batch can need padding

        for start in range(0, n, batch_size):
            end = start + batch_size
            batch_graphs = self.graphs[start:end]

            batch_meta = {}
            for k, v in self.metadata.items():
                chunk = v[start:end]
                if isinstance(chunk, list):  # tolerate list-of-tensors metadata
                    chunk = torch.stack(chunk, dim=0)
                batch_meta[k] = chunk

            pad_count = batch_size - len(batch_graphs)
            if pad_count > 0:
                if dummy is None:
                    dummy = self._empty_graph_like(self.graphs[0], node_feature_dim, dtype)
                batch_graphs = batch_graphs + [dummy] * pad_count

                # Pad metadata with zero rows of the same trailing shape/dtype/device.
                for k, v in batch_meta.items():
                    pad = v.new_zeros((pad_count, *v.shape[1:]))
                    batch_meta[k] = torch.cat([v, pad], dim=0)

            batched_graphs.append(dgl.batch(batch_graphs))
            for k, v in batch_meta.items():
                batched_meta[k].append(v)

        # Stack along a new leading axis: [num_batches, batch_size, ...]
        for k, chunks in batched_meta.items():
            if chunks:
                self.metadata[k] = torch.stack(chunks, dim=0)
            else:
                # Empty dataset: torch.stack([]) would raise, so emit zero batches.
                v = self.metadata[k]
                if isinstance(v, torch.Tensor):
                    self.metadata[k] = v.new_zeros((0, batch_size, *v.shape[1:]))
                else:
                    self.metadata[k] = torch.zeros((0, batch_size))

        self.graphs = batched_graphs

    @staticmethod
    def _empty_graph_like(ref, node_feature_dim=None, dtype=None):
        """Return a zero-node, zero-edge graph whose schema matches ``ref``."""
        # dgl.batch rejects graphs whose idtype, device or feature schemes differ,
        # so the padding graph mirrors every ndata/edata entry of ``ref``.
        empty = dgl.graph(([], []), num_nodes=0, idtype=ref.idtype, device=ref.device)
        for dst, src, is_node in ((empty.ndata, ref.ndata, True),
                                  (empty.edata, ref.edata, False)):
            for key, value in src.items():
                trailing = tuple(value.shape[1:])
                feat_dtype = value.dtype
                if key == 'features':
                    # Explicit caller overrides apply to the 'features' entries only.
                    if dtype is not None:
                        feat_dtype = dtype
                    if is_node and node_feature_dim is not None:
                        trailing = (node_feature_dim,)
                dst[key] = torch.empty((0, *trailing), dtype=feat_dtype, device=ref.device)
        return empty

    def normalize(self, stats):
        """Standardize node and edge ``'features'`` in place.

        ``stats['node']`` and ``stats['edge']`` are unpacked as 3-tuples whose
        first two items are the mean and std used as ``(x - mean) / std``; the
        third item is ignored. No guard against a zero std is applied.
        """
        node_mean, node_std, _ = stats['node']
        edge_mean, edge_std, _ = stats['edge']
        for g in self.graphs:
            g.ndata['features'] = (g.ndata['features'] - node_mean) / node_std
            g.edata['features'] = (g.edata['features'] - edge_mean) / edge_std

def save_graphs(graphs: Graphs, f: str):
    """Write ``graphs.graphs`` to the file ``f``, storing ``graphs.metadata`` as DGL labels."""
    labels = dict(graphs.metadata)  # shallow copy; the tensors themselves are shared
    dgl.save_graphs(f, graphs.graphs, labels)

def load_graphs(f: str) -> Graphs:
    """Read graphs and their labels from ``f`` into a ``Graphs`` container.

    Any label value that is not already a tensor is stacked into one.
    """
    graph_list, labels = dgl.load_graphs(f)
    metadata = {
        key: value if isinstance(value, torch.Tensor) else torch.stack(value)
        for key, value in labels.items()
    }
    return Graphs(graphs=graph_list, metadata=metadata)