Practice Pulling (#4)
Browse files- Added error message to training script. (a39024d691cfc9ca696261913815367f357cdd02)
root_gnn_dgl/scripts/training_script.py
CHANGED
|
@@ -496,6 +496,8 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
|
|
| 496 |
|
| 497 |
try:
|
| 498 |
#test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
|
|
|
|
|
|
| 499 |
test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
| 500 |
except ValueError:
|
| 501 |
test_auc = np.nan
|
|
|
|
| 496 |
|
| 497 |
try:
|
| 498 |
#test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
| 499 |
+
if (len(scores[0]) != config["Model"]["args"]["out_size"]):
|
| 500 |
+
print("ERROR: The out_size and the number of class labels don't match! Please check config.")
|
| 501 |
test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
| 502 |
except ValueError:
|
| 503 |
test_auc = np.nan
|