ho22joshua committed
Commit 3b61c0f · 1 Parent(s): 17c40e5

fixing bug with attention edge network

Files changed (1):
  1. root_gnn_dgl/models/GCN.py  +2 -1
root_gnn_dgl/models/GCN.py CHANGED

@@ -1269,7 +1269,8 @@ class Attention_Edge_Network(nn.Module):
         h_original_shape = h.shape
         num_graphs = len(dgl.unbatch(g))
         num_nodes = g.batch_num_nodes()[0].item()
-        padding_mask = g.ndata['padding_mask'] > 0
+        padding_mask = g.ndata['padding_mask'].bool()
+
         padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
 
         for i in range(self.n_proc_steps):
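
For context, a minimal sketch of how a per-node padding mask on a batched DGL graph can be reshaped to (num_graphs, num_nodes), and what the one-line change does. This is not the repository's code: the helper name build_padding_mask, the toy graphs, and the assumptions that 'padding_mask' is a float node feature (nonzero = real node) and that every graph in the batch has the same padded node count are illustrative only. The behavioural point is that .bool() marks any nonzero entry as True, whereas the previous `> 0` comparison kept only strictly positive entries.

import dgl
import torch

def build_padding_mask(g: dgl.DGLGraph) -> torch.Tensor:
    """Reshape a batched graph's per-node padding mask to (num_graphs, num_nodes)."""
    num_graphs = len(dgl.unbatch(g))
    num_nodes = g.batch_num_nodes()[0].item()  # assumes every graph is padded to the same size
    # .bool() treats any nonzero entry as True; the old `> 0` kept only strictly positive values.
    padding_mask = g.ndata['padding_mask'].bool()
    return torch.reshape(padding_mask, (num_graphs, num_nodes))

# Toy usage: two graphs padded to 3 nodes each, batched together.
g1 = dgl.graph(([0, 1], [1, 2]), num_nodes=3)
g2 = dgl.graph(([0], [1]), num_nodes=3)
g1.ndata['padding_mask'] = torch.tensor([1.0, 1.0, 0.0])  # third node is padding
g2.ndata['padding_mask'] = torch.tensor([1.0, 1.0, 1.0])
batched = dgl.batch([g1, g2])
print(build_padding_mask(batched))
# tensor([[ True,  True, False],
#         [ True,  True,  True]])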