File size: 2,974 Bytes
d646e7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
import dgl

from models import utils

class Edge_Network(nn.Module):
    """Encode-process-decode graph network with edge, node and global updates.

    Node, edge and global inputs are each encoded to ``cfg.hid_size``
    features. They are then refined for ``cfg.n_proc_steps`` rounds of message
    passing. Finally, the global representation is decoded and projected to
    ``cfg.out_dim`` outputs by a linear layer.

    Args:
        cfg: Config object. This class reads ``hid_size``, ``n_layers``,
            ``n_proc_steps``, ``input_dim_nodes``, ``input_dim_edges``,
            ``input_dim_globals`` and ``out_dim`` from it.
    """

    def __init__(self, cfg):
        super().__init__()
        hid_size = cfg.hid_size
        n_layers = cfg.n_layers
        self.n_proc_steps = cfg.n_proc_steps

        # Encoders: map each raw feature kind into the shared hidden size.
        # NOTE(review): utils.Make_MLP args look like
        # (in_dim, hidden_dim, out_dim, n_layers) -- confirm against utils.
        self.node_encoder = utils.Make_MLP(cfg.input_dim_nodes, hid_size, hid_size, n_layers)
        self.edge_encoder = utils.Make_MLP(cfg.input_dim_edges, hid_size, hid_size, n_layers)
        self.global_encoder = utils.Make_MLP(cfg.input_dim_globals, hid_size, hid_size, n_layers)

        # Processor. Input widths are multiples of hid_size because forward()
        # concatenates several hid_size-wide tensors for each update:
        #   edge:   [e, h_src, h_dst, global]          -> 4 * hid_size
        #   node:   [h, summed incoming e, global]     -> 3 * hid_size
        #   global: [global, mean node h, mean edge e] -> 3 * hid_size
        self.node_update = utils.Make_MLP(3*hid_size, hid_size, hid_size, n_layers)
        self.edge_update = utils.Make_MLP(4*hid_size, hid_size, hid_size, n_layers)
        self.global_update = utils.Make_MLP(3*hid_size, hid_size, hid_size, n_layers)

        # Decoder: only the global representation is decoded and projected.
        self.global_decoder = utils.Make_MLP(hid_size, hid_size, hid_size, n_layers)
        self.classify = nn.Linear(hid_size, cfg.out_dim)

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Run the network on a (batched) DGL graph.

        Args:
            node_feats: Per-node input features, passed to the node encoder.
            edge_feats: Per-edge input features, passed to the edge encoder.
            global_feats: Per-graph input features. A 3-D tensor is flattened
                to 2-D, keeping only the last dimension, before encoding.
            batched_graph: DGL graph. This method writes to it in place:
                ``ndata['h']``, ``ndata['h_e']``, ``edata['e']``,
                ``edata['m_u']`` and ``edata['m_v']``.
            metadata: Optional mapping. The entries ``"batch_num_nodes"`` and
                ``"batch_num_edges"`` (``None`` when absent) are forwarded as
                split sizes to the ``utils`` broadcast/mean helpers.
                Defaults to an empty mapping.

        Returns:
            Output of the ``classify`` linear layer applied to the decoded
            global representation.
        """
        # None instead of a shared ``{}`` default (mutable-default pitfall).
        if metadata is None:
            metadata = {}
        # The split sizes do not change across processing steps; look them up once.
        node_split = metadata.get("batch_num_nodes", None)
        edge_split = metadata.get("batch_num_edges", None)

        # encoders
        batched_graph.ndata['h'] = self.node_encoder(node_feats)
        batched_graph.edata['e'] = self.edge_encoder(edge_feats)

        if global_feats.ndim == 3:
            # reshape (unlike view) also succeeds on non-contiguous tensors.
            global_feats = global_feats.reshape(-1, global_feats.shape[-1])
        h_global = self.global_encoder(global_feats)

        # message passing
        for _ in range(self.n_proc_steps):
            # Source-node state onto each edge as 'm_u'. utils.copy_v is
            # expected to populate 'm_v' (presumably destination-node 'h').
            batched_graph.apply_edges(dgl.function.copy_u('h', 'm_u'))
            batched_graph.apply_edges(utils.copy_v)

            # edge update
            edge_inputs = torch.cat([
                batched_graph.edata['e'],
                batched_graph.edata['m_u'],
                batched_graph.edata['m_v'],
                utils.broadcast_global_to_edges(h_global, edge_split=edge_split)
            ], dim=1)
            batched_graph.edata['e'] = self.edge_update(edge_inputs)

            # node update: sum the updated edge features into destination nodes
            batched_graph.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_inputs = torch.cat([
                batched_graph.ndata['h'],
                batched_graph.ndata['h_e'],
                utils.broadcast_global_to_nodes(h_global, node_split=node_split)
            ], dim=1)
            batched_graph.ndata['h'] = self.node_update(node_inputs)

            # global update from per-graph means of the refreshed node/edge states
            graph_node_feat = utils.mean_nodes(batched_graph, 'h', node_split=node_split)
            graph_edge_feat = utils.mean_edges(batched_graph, 'e', edge_split=edge_split)
            h_global = self.global_update(torch.cat([h_global, graph_node_feat, graph_edge_feat], dim=1))

        h_global = self.global_decoder(h_global)
        out = self.classify(h_global)
        return out