Update app.py
Browse files
app.py
CHANGED
|
@@ -69,8 +69,8 @@ def train():
|
|
| 69 |
optimizer.zero_grad()
|
| 70 |
z = encoder(data.x, data.edge_index)
|
| 71 |
|
| 72 |
-
pos_edge = data.
|
| 73 |
-
neg_edge = data.
|
| 74 |
|
| 75 |
pos_pred = predictor(z[pos_edge[0]], z[pos_edge[1]])
|
| 76 |
neg_pred = predictor(z[neg_edge[0]], z[neg_edge[1]])
|
|
@@ -86,8 +86,10 @@ def train():
|
|
| 86 |
encoder.eval()
|
| 87 |
predictor.eval()
|
| 88 |
z = encoder(test_data.x, test_data.edge_index)
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
| 91 |
preds = torch.cat([pos_pred, neg_pred])
|
| 92 |
labels = torch.cat([
|
| 93 |
torch.ones(pos_pred.size(0)),
|
|
@@ -99,7 +101,7 @@ def train():
|
|
| 99 |
|
| 100 |
def visualize_graph():
|
| 101 |
temp_data = data.clone()
|
| 102 |
-
temp_data.edge_index =
|
| 103 |
G = to_networkx(temp_data, to_undirected=True)
|
| 104 |
plt.figure(figsize=(4, 4))
|
| 105 |
nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray')
|
|
@@ -155,7 +157,7 @@ def predict_link(node_i, node_j):
|
|
| 155 |
return []
|
| 156 |
|
| 157 |
temp_data = data.clone()
|
| 158 |
-
temp_data.edge_index =
|
| 159 |
G = to_networkx(temp_data, to_undirected=True)
|
| 160 |
neighbors = {node: list(G.neighbors(node)) for node in G.nodes}
|
| 161 |
dfs = dfs_path(neighbors, int(node_i), int(node_j))
|
|
|
|
| 69 |
optimizer.zero_grad()
|
| 70 |
z = encoder(data.x, data.edge_index)
|
| 71 |
|
| 72 |
+
pos_edge = data.edge_label_index[:, data.edge_label == 1]
|
| 73 |
+
neg_edge = data.edge_label_index[:, data.edge_label == 0]
|
| 74 |
|
| 75 |
pos_pred = predictor(z[pos_edge[0]], z[pos_edge[1]])
|
| 76 |
neg_pred = predictor(z[neg_edge[0]], z[neg_edge[1]])
|
|
|
|
| 86 |
encoder.eval()
|
| 87 |
predictor.eval()
|
| 88 |
z = encoder(test_data.x, test_data.edge_index)
|
| 89 |
+
pos_edge = test_data.edge_label_index[:, test_data.edge_label == 1]
|
| 90 |
+
neg_edge = test_data.edge_label_index[:, test_data.edge_label == 0]
|
| 91 |
+
pos_pred = predictor(z[pos_edge[0]], z[pos_edge[1]])
|
| 92 |
+
neg_pred = predictor(z[neg_edge[0]], z[neg_edge[1]])
|
| 93 |
preds = torch.cat([pos_pred, neg_pred])
|
| 94 |
labels = torch.cat([
|
| 95 |
torch.ones(pos_pred.size(0)),
|
|
|
|
| 101 |
|
| 102 |
def visualize_graph():
|
| 103 |
temp_data = data.clone()
|
| 104 |
+
temp_data.edge_index = data.edge_index # Use actual edge index for visualization
|
| 105 |
G = to_networkx(temp_data, to_undirected=True)
|
| 106 |
plt.figure(figsize=(4, 4))
|
| 107 |
nx.draw(G, with_labels=True, node_color='lightblue', edge_color='gray')
|
|
|
|
| 157 |
return []
|
| 158 |
|
| 159 |
temp_data = data.clone()
|
| 160 |
+
temp_data.edge_index = data.edge_index
|
| 161 |
G = to_networkx(temp_data, to_undirected=True)
|
| 162 |
neighbors = {node: list(G.neighbors(node)) for node in G.nodes}
|
| 163 |
dfs = dfs_path(neighbors, int(node_i), int(node_j))
|