Danielfonseca1212 committed on
Commit
18b6b0c
·
verified ·
1 Parent(s): 2672a04

Update elliptic_data.py

Browse files
Files changed (1) hide show
  1. elliptic_data.py +50 -54
elliptic_data.py CHANGED
@@ -1,30 +1,11 @@
1
  # elliptic_data.py — Loader do Elliptic Bitcoin Dataset via PyG
2
  import torch
3
  import numpy as np
4
- import pandas as pd
5
  from torch_geometric.datasets import EllipticBitcoinDataset
6
- from torch_geometric.loader import NeighborLoader
7
  from torch_geometric.transforms import NormalizeFeatures
8
- import os
9
 
10
  def carregar_elliptic(root='/tmp/elliptic', normalize=True):
11
- """
12
- Carrega o Elliptic Bitcoin Dataset via PyG.
13
-
14
- Estatísticas reais:
15
- - 203,769 nós (transações Bitcoin)
16
- - 234,355 arestas (fluxo de Bitcoin)
17
- - 166 features por nó (94 locais + 72 agregadas)
18
- - 2 classes: ilícito (lavagem) / lícito
19
- - 49 timesteps (jan 2017 - set 2018)
20
- - ~21% rotulados, ~79% desconhecidos
21
-
22
- Split temporal (como no paper):
23
- - Treino: timesteps 1-34
24
- - Teste: timesteps 35-49
25
- """
26
  transform = NormalizeFeatures() if normalize else None
27
-
28
  try:
29
  dataset = EllipticBitcoinDataset(root=root, transform=transform)
30
  data = dataset[0]
@@ -32,25 +13,14 @@ def carregar_elliptic(root='/tmp/elliptic', normalize=True):
32
  except Exception as e:
33
  return None, str(e)
34
 
35
-
36
  def preparar_splits(data):
37
- """
38
- Split temporal como descrito no paper original:
39
- Treino nos primeiros timesteps, teste nos últimos.
40
- Máscara 'unknown' (classe 2) excluída do treino/teste.
41
- """
42
- # PyG já fornece máscaras train/test no Elliptic
43
- # Classe 0 = ilícito, 1 = lícito, 2 = desconhecido
44
-
45
- # Filtra apenas nós rotulados
46
  labeled_mask = data.y != 2
47
  train_mask = data.train_mask & labeled_mask
48
  test_mask = data.test_mask & labeled_mask
49
-
50
- # Estatísticas
51
  y_train = data.y[train_mask]
52
  y_test = data.y[test_mask]
53
-
54
  stats = {
55
  'n_nos': data.x.shape[0],
56
  'n_arestas': data.edge_index.shape[1],
@@ -62,33 +32,59 @@ def preparar_splits(data):
62
  'n_licito_train': int((y_train == 1).sum()),
63
  'n_ilicito_test': int((y_test == 0).sum()),
64
  'n_licito_test': int((y_test == 1).sum()),
65
- 'taxa_fraude_train': float((y_train==0).sum()/len(y_train)),
66
- 'taxa_fraude_test': float((y_test ==0).sum()/len(y_test)),
67
  }
68
-
69
  data.train_mask_labeled = train_mask
70
  data.test_mask_labeled = test_mask
71
-
72
  return data, stats
73
 
74
 
75
- def criar_loaders(data, num_neighbors=[10, 5], batch_size=512):
76
  """
77
- Mini-batch com NeighborLoader para GraphSAGE inductive.
78
- Amostra vizinhos em vez de usar o grafo completo.
79
  """
80
- train_loader = NeighborLoader(
81
- data,
82
- num_neighbors=num_neighbors,
83
- batch_size=batch_size,
84
- input_nodes=data.train_mask_labeled,
85
- shuffle=True,
86
- )
87
- test_loader = NeighborLoader(
88
- data,
89
- num_neighbors=num_neighbors,
90
- batch_size=batch_size,
91
- input_nodes=data.test_mask_labeled,
92
- shuffle=False,
93
- )
94
- return train_loader, test_loader
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # elliptic_data.py — Loader do Elliptic Bitcoin Dataset via PyG
2
  import torch
3
  import numpy as np
 
4
  from torch_geometric.datasets import EllipticBitcoinDataset
 
5
  from torch_geometric.transforms import NormalizeFeatures
 
6
 
7
  def carregar_elliptic(root='/tmp/elliptic', normalize=True):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  transform = NormalizeFeatures() if normalize else None
 
9
  try:
10
  dataset = EllipticBitcoinDataset(root=root, transform=transform)
11
  data = dataset[0]
 
13
  except Exception as e:
14
  return None, str(e)
15
 
 
16
  def preparar_splits(data):
 
 
 
 
 
 
 
 
 
17
  labeled_mask = data.y != 2
18
  train_mask = data.train_mask & labeled_mask
19
  test_mask = data.test_mask & labeled_mask
20
+
 
21
  y_train = data.y[train_mask]
22
  y_test = data.y[test_mask]
23
+
24
  stats = {
25
  'n_nos': data.x.shape[0],
26
  'n_arestas': data.edge_index.shape[1],
 
32
  'n_licito_train': int((y_train == 1).sum()),
33
  'n_ilicito_test': int((y_test == 0).sum()),
34
  'n_licito_test': int((y_test == 1).sum()),
35
+ 'taxa_fraude_train': float((y_train==0).sum()/max(len(y_train),1)),
36
+ 'taxa_fraude_test': float((y_test ==0).sum()/max(len(y_test),1)),
37
  }
38
+
39
  data.train_mask_labeled = train_mask
40
  data.test_mask_labeled = test_mask
 
41
  return data, stats
42
 
43
 
44
def criar_mini_batches(data, batch_size=512, split='train'):
    """
    Build mini-batches without NeighborLoader (no torch-sparse required).

    Each batch is the subgraph induced by a slice of seed nodes plus the
    nodes they point to (edges whose source is a seed, i.e. outgoing 1-hop
    neighbours). Node ids inside a batch are remapped to local positions
    ``0..k-1`` following ascending global id.

    Args:
        data: graph object exposing ``x``, ``y``, ``edge_index`` (2 x E),
            ``train_mask_labeled`` and ``test_mask_labeled``.
        batch_size: number of seed nodes per batch; must be >= 1.
        split: ``'train'`` selects (and shuffles) the training seeds; any
            other value selects the test seeds in ascending order.

    Returns:
        A list of ``(x_batch, ei_local, y_batch, seed_local)`` tuples where
        ``ei_local`` is the batch edge index in local ids and ``seed_local``
        holds the local positions of the seed nodes.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    mask = data.train_mask_labeled if split == 'train' else data.test_mask_labeled
    indices = mask.nonzero(as_tuple=True)[0]

    # Shuffle seeds for training only; keep the permutation on the same
    # device as the indices it reorders.
    if split == 'train':
        perm = torch.randperm(len(indices), device=indices.device)
        indices = indices[perm]

    ei = data.edge_index
    src, dst = ei[0], ei[1]

    batches = []
    for start in range(0, len(indices), batch_size):
        seed = indices[start:start + batch_size]

        # Seeds plus their outgoing 1-hop neighbours. torch.unique returns
        # the ids sorted, which is what the searchsorted remap below needs.
        vizinhos = dst[torch.isin(src, seed)]
        nos_sorted = torch.unique(torch.cat([seed, vizinhos]), sorted=True)

        # Edges whose two endpoints both belong to the batch.
        mask_int = torch.isin(src, nos_sorted) & torch.isin(dst, nos_sorted)
        ei_batch = ei[:, mask_int].to(nos_sorted.dtype)

        # Global -> local remap: every id is guaranteed to be present in
        # nos_sorted, so its insertion point is its local position. This
        # also yields an integer (2, 0) tensor when the batch has no edges.
        ei_local = torch.stack([
            torch.searchsorted(nos_sorted, ei_batch[0]),
            torch.searchsorted(nos_sorted, ei_batch[1]),
        ])

        x_batch = data.x[nos_sorted]
        y_batch = data.y[nos_sorted]

        # Local positions of the seed nodes inside the batch.
        seed_local = torch.searchsorted(nos_sorted, seed)

        batches.append((x_batch, ei_local, y_batch, seed_local))

    return batches