EzekielMW commited on
Commit
04cd1d6
·
verified ·
1 Parent(s): 16a0610

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -6
app.py CHANGED
@@ -69,8 +69,8 @@ def train():
69
  optimizer.zero_grad()
70
  z = encoder(data.x, data.edge_index)
71
 
72
- pos_edge = data.pos_edge_label_index
73
- neg_edge = data.neg_edge_label_index
74
 
75
  pos_pred = predictor(z[pos_edge[0]], z[pos_edge[1]])
76
  neg_pred = predictor(z[neg_edge[0]], z[neg_edge[1]])
@@ -86,8 +86,10 @@ def train():
86
  encoder.eval()
87
  predictor.eval()
88
  z = encoder(test_data.x, test_data.edge_index)
89
- pos_pred = predictor(z[test_data.pos_edge_label_index[0]], z[test_data.pos_edge_label_index[1]])
90
- neg_pred = predictor(z[test_data.neg_edge_label_index[0]], z[test_data.neg_edge_label_index[1]])
 
 
91
  preds = torch.cat([pos_pred, neg_pred])
92
  labels = torch.cat([
93
  torch.ones(pos_pred.size(0)),
@@ -99,7 +101,7 @@ def train():
99
 
100
  def visualize_graph():
101
  temp_data = data.clone()
102
- temp_data.edge_index = temp_data.pos_edge_label_index
103
  G = to_networkx(temp_data, to_undirected=True)
104
  plt.figure(figsize=(4, 4))
105
  nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray')
@@ -155,7 +157,7 @@ def predict_link(node_i, node_j):
155
  return []
156
 
157
  temp_data = data.clone()
158
- temp_data.edge_index = temp_data.pos_edge_label_index
159
  G = to_networkx(temp_data, to_undirected=True)
160
  neighbors = {node: list(G.neighbors(node)) for node in G.nodes}
161
  dfs = dfs_path(neighbors, int(node_i), int(node_j))
 
69
  optimizer.zero_grad()
70
  z = encoder(data.x, data.edge_index)
71
 
72
+ pos_edge = data.edge_label_index[:, data.edge_label == 1]
73
+ neg_edge = data.edge_label_index[:, data.edge_label == 0]
74
 
75
  pos_pred = predictor(z[pos_edge[0]], z[pos_edge[1]])
76
  neg_pred = predictor(z[neg_edge[0]], z[neg_edge[1]])
 
86
  encoder.eval()
87
  predictor.eval()
88
  z = encoder(test_data.x, test_data.edge_index)
89
+ pos_edge = test_data.edge_label_index[:, test_data.edge_label == 1]
90
+ neg_edge = test_data.edge_label_index[:, test_data.edge_label == 0]
91
+ pos_pred = predictor(z[pos_edge[0]], z[pos_edge[1]])
92
+ neg_pred = predictor(z[neg_edge[0]], z[neg_edge[1]])
93
  preds = torch.cat([pos_pred, neg_pred])
94
  labels = torch.cat([
95
  torch.ones(pos_pred.size(0)),
 
101
 
102
  def visualize_graph():
103
  temp_data = data.clone()
104
+ temp_data.edge_index = data.edge_index # Use actual edge index for visualization
105
  G = to_networkx(temp_data, to_undirected=True)
106
  plt.figure(figsize=(4, 4))
107
  nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray')
 
157
  return []
158
 
159
  temp_data = data.clone()
160
+ temp_data.edge_index = data.edge_index
161
  G = to_networkx(temp_data, to_undirected=True)
162
  neighbors = {node: list(G.neighbors(node)) for node in G.nodes}
163
  dfs = dfs_path(neighbors, int(node_i), int(node_j))