Commit
·
3b61c0f
1
Parent(s):
17c40e5
fixing bug with attention edge network
Browse files
root_gnn_dgl/models/GCN.py
CHANGED
|
@@ -1269,7 +1269,8 @@ class Attention_Edge_Network(nn.Module):
|
|
| 1269 |
h_original_shape = h.shape
|
| 1270 |
num_graphs = len(dgl.unbatch(g))
|
| 1271 |
num_nodes = g.batch_num_nodes()[0].item()
|
| 1272 |
-
padding_mask = g.ndata['padding_mask']
|
|
|
|
| 1273 |
padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
|
| 1274 |
|
| 1275 |
for i in range(self.n_proc_steps):
|
|
|
|
| 1269 |
h_original_shape = h.shape
|
| 1270 |
num_graphs = len(dgl.unbatch(g))
|
| 1271 |
num_nodes = g.batch_num_nodes()[0].item()
|
| 1272 |
+
padding_mask = g.ndata['padding_mask'].bool()
|
| 1273 |
+
|
| 1274 |
padding_mask = torch.reshape(padding_mask, (num_graphs, num_nodes))
|
| 1275 |
|
| 1276 |
for i in range(self.n_proc_steps):
|