henribonamy commited on
Commit
c99cf4a
·
verified ·
1 Parent(s): 5808458

Fix label mismatch guard, self-loop removal, lasftm typo

Browse files
Files changed (1) hide show
  1. netfm/evaluate/pipeline.py +2 -2
netfm/evaluate/pipeline.py CHANGED
@@ -51,7 +51,7 @@ def run_full_evaluation(
51
 
52
  if data.y is not None:
53
  labels = data.y.cpu().numpy().squeeze()
54
- if labels.ndim == 1 and len(np.unique(labels)) > 1:
55
  n = data.num_nodes
56
  train_mask = np.zeros(n, dtype=bool)
57
  test_mask = np.zeros(n, dtype=bool)
@@ -90,7 +90,7 @@ def run_full_evaluation(
90
 
91
  if data.y is not None:
92
  labels = data.y.cpu().numpy().squeeze()
93
- if labels.ndim == 1 and len(np.unique(labels)) > 1:
94
  results["community_detection"] = {
95
  "netfm": evaluate_community_detection(embeddings, labels),
96
  }
 
51
 
52
  if data.y is not None:
53
  labels = data.y.cpu().numpy().squeeze()
54
+ if labels.ndim == 1 and len(labels) == data.num_nodes and len(np.unique(labels)) > 1:
55
  n = data.num_nodes
56
  train_mask = np.zeros(n, dtype=bool)
57
  test_mask = np.zeros(n, dtype=bool)
 
90
 
91
  if data.y is not None:
92
  labels = data.y.cpu().numpy().squeeze()
93
+ if labels.ndim == 1 and len(labels) == data.num_nodes and len(np.unique(labels)) > 1:
94
  results["community_detection"] = {
95
  "netfm": evaluate_community_detection(embeddings, labels),
96
  }