ypreetham commited on
Commit
14a33e8
·
1 Parent(s): f251d7d

Practice Pulling (#4)

Browse files

- Added error message to training script. (a39024d691cfc9ca696261913815367f357cdd02)

root_gnn_dgl/scripts/training_script.py CHANGED
@@ -496,6 +496,8 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
496
 
497
  try:
498
  #test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
 
 
499
  test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
500
  except ValueError:
501
  test_auc = np.nan
 
496
 
497
  try:
498
  #test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
499
+ if (len(scores[0]) != config["Model"]["args"]["out_size"]):
500
+ print("ERROR: The out_size and the number of class labels don't match! Please check config.")
501
  test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
502
  except ValueError:
503
  test_auc = np.nan