Commit ·
acf2939
1
Parent(s): a68383d
fixing datatype issue in graph creating
Browse files
root_gnn_dgl/root_gnn_base/dataset.py
CHANGED
|
@@ -185,7 +185,9 @@ class RootDataset(DGLDataset):
|
|
| 185 |
globals.append(torch.zeros(len(self.global_features)))
|
| 186 |
for i_ti, tr_branch in enumerate(self.tracking_info):
|
| 187 |
if isinstance(tr_branch, str):
|
| 188 |
-
tracking[-1]
|
|
|
|
|
|
|
| 189 |
else:
|
| 190 |
tracking[-1][i_ti] = tr_branch
|
| 191 |
for i_gl, gl_branch in enumerate(self.global_features):
|
|
|
|
| 185 |
globals.append(torch.zeros(len(self.global_features)))
|
| 186 |
for i_ti, tr_branch in enumerate(self.tracking_info):
|
| 187 |
if isinstance(tr_branch, str):
|
| 188 |
+
dtype = tracking[-1].dtype
|
| 189 |
+
tracking[-1][i_ti] = torch.as_tensor(ch[tr_branch], dtype=dtype)
|
| 190 |
+
# tracking[-1][i_ti] = ch[tr_branch]
|
| 191 |
else:
|
| 192 |
tracking[-1][i_ti] = tr_branch
|
| 193 |
for i_gl, gl_branch in enumerate(self.global_features):
|