# File size: 2,974 Bytes (revision d646e7f)
import torch
import torch.nn as nn
import dgl
from models import utils
class Edge_Network(nn.Module):
    """Encode-process-decode graph network producing one output row per graph.

    Node, edge and global input features are each encoded by an MLP. They are
    then refined by ``cfg.n_proc_steps`` rounds of edge -> node -> global
    updates. The final global state is passed through a decoder MLP and a
    linear head.

    Args:
        cfg: configuration object. This constructor reads ``hid_size``,
            ``n_layers``, ``n_proc_steps``, ``input_dim_nodes``,
            ``input_dim_edges``, ``input_dim_globals`` and ``out_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        hid_size = cfg.hid_size
        n_layers = cfg.n_layers
        self.n_proc_steps = cfg.n_proc_steps

        # Encoders: map raw node / edge / global features to hid_size.
        self.node_encoder = utils.Make_MLP(cfg.input_dim_nodes, hid_size, hid_size, n_layers)
        self.edge_encoder = utils.Make_MLP(cfg.input_dim_edges, hid_size, hid_size, n_layers)
        self.global_encoder = utils.Make_MLP(cfg.input_dim_globals, hid_size, hid_size, n_layers)

        # GNN update MLPs. Their first-argument widths match the concatenations
        # built in forward():
        #   edge:   [e, h_src, h_dst, global]      -> 4 * hid_size
        #   node:   [h, summed incoming e, global] -> 3 * hid_size
        #   global: [global, mean h, mean e]       -> 3 * hid_size
        self.node_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)
        self.edge_update = utils.Make_MLP(4 * hid_size, hid_size, hid_size, n_layers)
        self.global_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)

        # Decoder on the global state, followed by a linear output head.
        self.global_decoder = utils.Make_MLP(hid_size, hid_size, hid_size, n_layers)
        self.classify = nn.Linear(hid_size, cfg.out_dim)

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Run encode -> message passing -> decode on a (batched) graph.

        Args:
            node_feats: node input features, fed to ``node_encoder``.
            edge_feats: edge input features, fed to ``edge_encoder``.
            global_feats: per-graph input features. A 3-D tensor is flattened
                to 2-D, with the last dimension kept as the feature axis.
            batched_graph: DGL graph. Its ``ndata`` / ``edata`` are written
                in place (keys 'h', 'h_e', 'e', 'm_u', 'm_v').
            metadata: optional mapping. ``"batch_num_nodes"`` and
                ``"batch_num_edges"`` are passed as split sizes to the
                ``utils`` broadcast / pooling helpers. A missing key (or no
                mapping at all) passes ``None``.

        Returns:
            ``self.classify`` applied to the decoded global state. Its last
            dimension is ``cfg.out_dim``.
        """
        # Avoid a shared mutable default; behaviour is unchanged for callers.
        if metadata is None:
            metadata = {}
        # Split sizes do not change across message-passing steps, so look
        # them up once instead of on every step.
        node_split = metadata.get("batch_num_nodes", None)
        edge_split = metadata.get("batch_num_edges", None)

        # Encoders.
        batched_graph.ndata['h'] = self.node_encoder(node_feats)
        batched_graph.edata['e'] = self.edge_encoder(edge_feats)
        if global_feats.ndim == 3:
            # Collapse the leading dims so each row is one feature vector.
            # Unlike view(), reshape() also accepts non-contiguous tensors.
            global_feats = global_feats.reshape(-1, global_feats.shape[-1])
        h_global = self.global_encoder(global_feats)

        # Message passing.
        for _ in range(self.n_proc_steps):
            # Put each edge's source-node state in edata['m_u']. utils.copy_v
            # is expected to fill edata['m_v'] with the destination-node state
            # (it is read just below) -- confirm in models.utils.
            batched_graph.apply_edges(dgl.function.copy_u('h', 'm_u'))
            batched_graph.apply_edges(utils.copy_v)

            # Edge update: [edge, src node, dst node, global].
            edge_inputs = torch.cat([
                batched_graph.edata['e'],
                batched_graph.edata['m_u'],
                batched_graph.edata['m_v'],
                utils.broadcast_global_to_edges(h_global, edge_split=edge_split),
            ], dim=1)
            batched_graph.edata['e'] = self.edge_update(edge_inputs)

            # Node update: sum updated edge states into their destination node,
            # then combine with the node state and the global state.
            batched_graph.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_inputs = torch.cat([
                batched_graph.ndata['h'],
                batched_graph.ndata['h_e'],
                utils.broadcast_global_to_nodes(h_global, node_split=node_split),
            ], dim=1)
            batched_graph.ndata['h'] = self.node_update(node_inputs)

            # Global update from per-graph means of node and edge states.
            graph_node_feat = utils.mean_nodes(batched_graph, 'h', node_split=node_split)
            graph_edge_feat = utils.mean_edges(batched_graph, 'e', edge_split=edge_split)
            h_global = self.global_update(
                torch.cat([h_global, graph_node_feat, graph_edge_feat], dim=1)
            )

        # Decoder and output head.
        h_global = self.global_decoder(h_global)
        out = self.classify(h_global)
        return out
|