ho22joshua commited on
Commit
acf2939
·
1 Parent(s): a68383d

fixing datatype issue in graph creating

Browse files
root_gnn_dgl/root_gnn_base/dataset.py CHANGED
@@ -185,7 +185,9 @@ class RootDataset(DGLDataset):
185
  globals.append(torch.zeros(len(self.global_features)))
186
  for i_ti, tr_branch in enumerate(self.tracking_info):
187
  if isinstance(tr_branch, str):
188
- tracking[-1][i_ti] = ch[tr_branch]
 
 
189
  else:
190
  tracking[-1][i_ti] = tr_branch
191
  for i_gl, gl_branch in enumerate(self.global_features):
 
185
  globals.append(torch.zeros(len(self.global_features)))
186
  for i_ti, tr_branch in enumerate(self.tracking_info):
187
  if isinstance(tr_branch, str):
188
+ dtype = tracking[-1].dtype
189
+ tracking[-1][i_ti] = torch.as_tensor(ch[tr_branch], dtype=dtype)
190
+ # tracking[-1][i_ti] = ch[tr_branch]
191
  else:
192
  tracking[-1][i_ti] = tr_branch
193
  for i_gl, gl_branch in enumerate(self.global_features):