noranisa commited on
Commit
5266d4e
·
verified ·
1 Parent(s): 6dda3d4

Update services/gnn.py

Browse files
Files changed (1) hide show
  1. services/gnn.py +18 -22
services/gnn.py CHANGED
@@ -1,23 +1,22 @@
1
- import torch
 
 
2
 
3
  def run_gnn(nodes, edges):
4
  try:
5
- from torch_geometric.data import Data
6
  from torch_geometric.nn import GCNConv
 
7
 
8
- # convert ke index
9
- node_ids = list(range(len(nodes)))
10
- edge_index = []
11
 
 
12
  for e in edges:
13
  edge_index.append([e["source"], e["target"]])
14
 
15
- if len(edge_index) == 0:
16
- return []
17
-
18
  edge_index = torch.tensor(edge_index).t().contiguous()
19
 
20
- # fitur dummy (random)
21
  x = torch.rand((len(nodes), 16))
22
 
23
  class GNN(torch.nn.Module):
@@ -35,19 +34,16 @@ def run_gnn(nodes, edges):
35
  model = GNN()
36
  out = model(x, edge_index)
37
 
38
- scores = out.detach().numpy()
39
-
40
- result = []
41
- for i, s in enumerate(scores):
42
- result.append({
43
- "node": i,
44
- "score": float(s[0])
45
- })
46
-
47
- return result
48
 
49
  except Exception as e:
50
- print("⚠️ GNN fallback:", e)
51
 
52
- # fallback sederhana
53
- return [{"node": i, "score": 0.5} for i in range(len(nodes))]
 
 
 
 
1
+ # =========================
2
+ # 🔥 SAFE GNN (NO CRASH)
3
+ # =========================
4
 
5
  def run_gnn(nodes, edges):
6
  try:
7
+ import torch
8
  from torch_geometric.nn import GCNConv
9
+ from torch_geometric.data import Data
10
 
11
+ if len(nodes) == 0 or len(edges) == 0:
12
+ return []
 
13
 
14
+ edge_index = []
15
  for e in edges:
16
  edge_index.append([e["source"], e["target"]])
17
 
 
 
 
18
  edge_index = torch.tensor(edge_index).t().contiguous()
19
 
 
20
  x = torch.rand((len(nodes), 16))
21
 
22
  class GNN(torch.nn.Module):
 
34
  model = GNN()
35
  out = model(x, edge_index)
36
 
37
+ return [
38
+ {"node": i, "score": float(out[i][0])}
39
+ for i in range(len(nodes))
40
+ ]
 
 
 
 
 
 
41
 
42
  except Exception as e:
43
+ print("⚠️ GNN fallback aktif:", e)
44
 
45
+ # 🔥 fallback TANPA torch
46
+ return [
47
+ {"node": i, "score": 0.5}
48
+ for i in range(len(nodes))
49
+ ]