File size: 5,353 Bytes
d646e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import dgl

def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
    """
    Aggregates node features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph (batched; readout is per component graph)
        feat_key: str, node feature key in ``batched_graph.ndata``
        op: 'mean', 'sum', or 'max'
        node_split: optional 1D tensor or list of ints, number of nodes per
            graph. When omitted (or empty), DGL's built-in readout is used;
            otherwise features are split manually along dim 0.

    Returns:
        Tensor of shape [num_graphs, node_feat_dim]

    Raises:
        ValueError: if ``op`` is not one of 'mean', 'sum', 'max'.
    """
    # Validate once up front instead of duplicating the check per branch.
    if op not in ('mean', 'sum', 'max'):
        raise ValueError(f"Unknown op: {op}")
    if node_split is None or len(node_split) == 0:
        # Delegate to DGL's per-graph readout over the batch structure.
        readout = {'mean': dgl.mean_nodes,
                   'sum': dgl.sum_nodes,
                   'max': dgl.max_nodes}[op]
        return readout(batched_graph, feat_key)
    h = batched_graph.ndata[feat_key]
    # Normalize node_split to a flat list of ints for torch.split.
    if isinstance(node_split, torch.Tensor):
        splits = node_split.view(-1).tolist()
    else:
        splits = [int(x) for x in node_split]
    chunks = torch.split(h, splits, dim=0)
    # new_zeros keeps dtype/device and, unlike the previous
    # torch.zeros_like(h[0]), does not raise when h itself is empty.
    zero = h.new_zeros(h.shape[1:])
    if op == 'mean':
        out = torch.stack([c.mean(0) if c.shape[0] > 0 else zero for c in chunks])
    elif op == 'sum':
        out = torch.stack([c.sum(0) if c.shape[0] > 0 else zero for c in chunks])
    else:  # 'max'
        out = torch.stack([c.max(0).values if c.shape[0] > 0 else zero for c in chunks])
    return out
    
def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
    """
    Aggregates edge features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph (batched; readout is per component graph)
        feat_key: str, edge feature key in ``batched_graph.edata``
        op: 'mean', 'sum', or 'max'
        edge_split: optional 1D tensor or list of ints, number of edges per
            graph. When omitted (or empty), DGL's built-in readout is used;
            otherwise features are split manually along dim 0.

    Returns:
        Tensor of shape [num_graphs, edge_feat_dim]

    Raises:
        ValueError: if ``op`` is not one of 'mean', 'sum', 'max'.
    """
    # Validate once up front instead of duplicating the check per branch.
    if op not in ('mean', 'sum', 'max'):
        raise ValueError(f"Unknown op: {op}")
    if edge_split is None or len(edge_split) == 0:
        # Delegate to DGL's per-graph readout over the batch structure.
        readout = {'mean': dgl.mean_edges,
                   'sum': dgl.sum_edges,
                   'max': dgl.max_edges}[op]
        return readout(batched_graph, feat_key)
    e = batched_graph.edata[feat_key]
    # Normalize edge_split to a flat list of ints for torch.split.
    if isinstance(edge_split, torch.Tensor):
        splits = edge_split.view(-1).tolist()
    else:
        splits = [int(x) for x in edge_split]
    chunks = torch.split(e, splits, dim=0)
    # new_zeros keeps dtype/device and, unlike the previous
    # torch.zeros_like(e[0]), does not raise when e itself is empty.
    zero = e.new_zeros(e.shape[1:])
    if op == 'mean':
        out = torch.stack([c.mean(0) if c.shape[0] > 0 else zero for c in chunks])
    elif op == 'sum':
        out = torch.stack([c.sum(0) if c.shape[0] > 0 else zero for c in chunks])
    else:  # 'max'
        out = torch.stack([c.max(0).values if c.shape[0] > 0 else zero for c in chunks])
    return out
    
def Make_SLP(in_size, out_size, activation = nn.ReLU, dropout = 0):
    """Build one Linear -> activation -> Dropout stage as a list of modules.

    Args:
        in_size: input feature dimension of the Linear layer
        out_size: output feature dimension of the Linear layer
        activation: activation module class (instantiated with no args)
        dropout: dropout probability

    Returns:
        list of three nn.Module instances (not wrapped in Sequential).
    """
    return [
        nn.Linear(in_size, out_size),
        activation(),
        nn.Dropout(dropout),
    ]

def Make_MLP(in_size, hid_size, out_size, n_layers, activation = nn.ReLU, dropout = 0):
    """Build an MLP of ``n_layers`` Linear/activation/Dropout stages,
    finished with a LayerNorm over the output dimension.

    Args:
        in_size: input feature dimension
        hid_size: hidden feature dimension (unused when n_layers <= 1)
        out_size: output feature dimension
        n_layers: number of Linear stages
        activation: activation module class applied after every Linear
        dropout: dropout probability applied after every activation

    Returns:
        nn.Sequential ending in nn.LayerNorm(out_size).
    """
    def stage(d_in, d_out):
        # One Linear -> activation -> Dropout triple (same as Make_SLP).
        return [nn.Linear(d_in, d_out), activation(), nn.Dropout(dropout)]

    if n_layers <= 1:
        layers = stage(in_size, out_size)
    else:
        # Chain of dimensions: in -> hid -> ... -> hid -> out.
        dims = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += stage(d_in, d_out)
    layers.append(torch.nn.LayerNorm(out_size))
    return nn.Sequential(*layers)

def broadcast_global_to_nodes(globals, node_split):
    """
    Repeat each graph's global feature vector once per node of that graph.

    globals: [num_graphs, global_dim]
    node_split: list/1D tensor of length num_graphs, number of nodes per graph
    Returns: [total_num_nodes, global_dim]

    Raises:
        ValueError: if node_split is None.
    """
    if node_split is None:
        raise ValueError("node_split must be provided")
    # Bring the counts onto globals' device as a flat long tensor,
    # accepting either a tensor or any int sequence.
    counts = (
        node_split.to(device=globals.device, dtype=torch.long)
        if torch.is_tensor(node_split)
        else torch.tensor(node_split, dtype=torch.long, device=globals.device)
    )
    return torch.repeat_interleave(globals, counts.flatten(), dim=0)

def broadcast_global_to_edges(globals, edge_split):
    """
    Repeat each graph's global feature vector once per edge of that graph.

    globals: [num_graphs, global_dim] (on CUDA or CPU)
    edge_split: list/1D tensor of length num_graphs, number of edges per graph
    Returns: [total_num_edges, global_dim]

    Raises:
        ValueError: if edge_split is None.
    """
    if edge_split is None:
        raise ValueError("edge_split must be provided")
    # Bring the counts onto globals' device as a flat long tensor,
    # accepting either a tensor or any int sequence.
    counts = (
        edge_split.to(device=globals.device, dtype=torch.long)
        if torch.is_tensor(edge_split)
        else torch.tensor(edge_split, dtype=torch.long, device=globals.device)
    )
    return torch.repeat_interleave(globals, counts.flatten(), dim=0)

def copy_v(edges):
    """User-defined message function: emit each edge's destination-node
    feature 'h' as the message 'm_v'.

    Note: this reads ``edges.dst`` (not ``edges.src``), consistent with the
    'v' (destination endpoint) naming — presumably mirroring dgl.function's
    u/v convention.
    """
    return dict(m_v=edges.dst['h'])