| import torch |
| import torch.nn as nn |
| import dgl |
|
|
| from models import utils |
|
|
class Edge_Network(nn.Module):
    """Graph-network classifier with encode / process / decode stages.

    Node, edge and global input features are each encoded to ``cfg.hid_size``.
    They are then refined by ``cfg.n_proc_steps`` rounds of
    edge -> node -> global updates. Finally, the global representation is
    decoded and projected by a linear layer to ``cfg.out_dim`` outputs.

    Config attributes read: ``hid_size``, ``n_layers``, ``n_proc_steps``,
    ``input_dim_nodes``, ``input_dim_edges``, ``input_dim_globals``,
    ``out_dim``.
    """

    def __init__(self, cfg):
        super().__init__()
        hid_size = cfg.hid_size
        n_layers = cfg.n_layers
        self.n_proc_steps = cfg.n_proc_steps

        # Encoders: map raw node / edge / global inputs to width hid_size.
        # NOTE(review): assumes utils.Make_MLP(in_dim, hidden_dim, out_dim,
        # n_layers) — confirm against models/utils.
        self.node_encoder = utils.Make_MLP(cfg.input_dim_nodes, hid_size, hid_size, n_layers)
        self.edge_encoder = utils.Make_MLP(cfg.input_dim_edges, hid_size, hid_size, n_layers)
        self.global_encoder = utils.Make_MLP(cfg.input_dim_globals, hid_size, hid_size, n_layers)

        # Processors. Input widths match the concatenations built in forward():
        #   node:   [h, summed incoming edge feats, global]        -> 3 * hid
        #   edge:   [e, m_u (source h), m_v, global]               -> 4 * hid
        #   global: [global, mean node feats, mean edge feats]     -> 3 * hid
        self.node_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)
        self.edge_update = utils.Make_MLP(4 * hid_size, hid_size, hid_size, n_layers)
        self.global_update = utils.Make_MLP(3 * hid_size, hid_size, hid_size, n_layers)

        # Decoder on the global state, then the final linear output head.
        self.global_decoder = utils.Make_MLP(hid_size, hid_size, hid_size, n_layers)
        self.classify = nn.Linear(hid_size, cfg.out_dim)

    def forward(self, node_feats, edge_feats, global_feats, batched_graph, metadata=None):
        """Run the network on a (batched) DGL graph and return the head's output.

        Args:
            node_feats: Node input features, passed to ``node_encoder``.
            edge_feats: Edge input features, passed to ``edge_encoder``.
            global_feats: Graph-level input features. A 3-D tensor is
                flattened to 2-D ``(-1, last_dim)`` before encoding.
            batched_graph: DGL graph. Its ``ndata`` (``'h'``, ``'h_e'``) and
                ``edata`` (``'e'``, ``'m_u'``, ``'m_v'``) are written in place
                and stay on the graph after this call.
            metadata: Optional mapping. ``"batch_num_nodes"`` and
                ``"batch_num_edges"`` are forwarded as ``node_split`` and
                ``edge_split`` to the ``utils`` broadcast/pooling helpers;
                ``None`` is passed for a missing key. Defaults to an empty
                mapping.

        Returns:
            Output of the final ``nn.Linear`` layer (last dimension
            ``cfg.out_dim``).
        """
        # Avoid a shared mutable default; an absent mapping behaves like {}.
        if metadata is None:
            metadata = {}
        # Loop-invariant lookups, read once instead of once per step.
        node_split = metadata.get("batch_num_nodes", None)
        edge_split = metadata.get("batch_num_edges", None)

        # ---- Encode -------------------------------------------------------
        batched_graph.ndata['h'] = self.node_encoder(node_feats)
        batched_graph.edata['e'] = self.edge_encoder(edge_feats)

        if global_feats.ndim == 3:
            # reshape (not view): identical for contiguous tensors, but it
            # also accepts non-contiguous inputs, where view() would raise.
            global_feats = global_feats.reshape(-1, global_feats.shape[-1])
        h_global = self.global_encoder(global_feats)

        # ---- Process ------------------------------------------------------
        for _ in range(self.n_proc_steps):
            # Edge update: copy the source node's 'h' onto each edge as 'm_u'.
            batched_graph.apply_edges(dgl.function.copy_u('h', 'm_u'))
            # NOTE(review): presumably copies the destination node's 'h' into
            # edata['m_v'] (read just below) — confirm in models/utils.
            batched_graph.apply_edges(utils.copy_v)

            edge_inputs = torch.cat([
                batched_graph.edata['e'],
                batched_graph.edata['m_u'],
                batched_graph.edata['m_v'],
                utils.broadcast_global_to_edges(h_global, edge_split=edge_split),
            ], dim=1)
            batched_graph.edata['e'] = self.edge_update(edge_inputs)

            # Node update: sum the updated edge features over each node's
            # incoming edges into 'h_e', then combine with 'h' and the global.
            batched_graph.update_all(dgl.function.copy_e('e', 'm'), dgl.function.sum('m', 'h_e'))
            node_inputs = torch.cat([
                batched_graph.ndata['h'],
                batched_graph.ndata['h_e'],
                utils.broadcast_global_to_nodes(h_global, node_split=node_split),
            ], dim=1)
            batched_graph.ndata['h'] = self.node_update(node_inputs)

            # Global update: pool nodes and edges (presumably per graph, using
            # the split sizes) and concatenate with the current global state.
            graph_node_feat = utils.mean_nodes(batched_graph, 'h', node_split=node_split)
            graph_edge_feat = utils.mean_edges(batched_graph, 'e', edge_split=edge_split)
            h_global = self.global_update(torch.cat([h_global, graph_node_feat, graph_edge_feat], dim=1))

        # ---- Decode -------------------------------------------------------
        h_global = self.global_decoder(h_global)
        return self.classify(h_global)
|
|
|
|