Practice Pulling
#4
by
ypreetham
- opened
root_gnn_dgl/scripts/training_script.py
CHANGED
|
@@ -500,6 +500,8 @@ def train(train_loaders, test_loaders, model, device, config, args, rank):
|
|
| 500 |
|
| 501 |
try:
|
| 502 |
#test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
|
|
|
|
|
|
| 503 |
test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
| 504 |
except ValueError:
|
| 505 |
test_auc = np.nan
|
|
|
|
| 500 |
|
| 501 |
try:
|
| 502 |
#test_auc = roc_auc_score(labels[wgt_mask].to("cpu") == 1, scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
| 503 |
+
if (len(scores[0]) != config["Model"]["args"]["out_size"]):
|
| 504 |
+
print("ERROR: The out_size and the number of class labels don't match! Please check config.")
|
| 505 |
test_auc = roc_auc_score(labels_onehot[wgt_mask], scores[wgt_mask].to("cpu"), multi_class='ovr', sample_weight=weights[wgt_mask].to("cpu"))
|
| 506 |
except ValueError:
|
| 507 |
test_auc = np.nan
|