# EllipticBitcoin / elliptic_data.py
# Author: Danielfonseca1212 — last change "Update elliptic_data.py" (commit 18b6b0c, verified)
# elliptic_data.py — Loader do Elliptic Bitcoin Dataset via PyG
import torch
import numpy as np
from torch_geometric.datasets import EllipticBitcoinDataset
from torch_geometric.transforms import NormalizeFeatures
def carregar_elliptic(root='/tmp/elliptic', normalize=True):
    """Load the single graph held by PyG's ``EllipticBitcoinDataset``.

    Args:
        root: directory handed to ``EllipticBitcoinDataset`` as its ``root``.
        normalize: when True, ``NormalizeFeatures()`` is used as the dataset
            transform; otherwise no transform is applied.

    Returns:
        ``(graph, True)`` on success, or ``(None, message)`` where ``message``
        is ``str()`` of whatever exception building or indexing the dataset
        raised. NOTE(review): a non-empty error message is truthy too, so
        callers should test ``graph is not None`` (or ``ok is True``) rather
        than the truthiness of the second element.
    """
    transform = None
    if normalize:
        transform = NormalizeFeatures()
    try:
        grafo = EllipticBitcoinDataset(root=root, transform=transform)[0]
    except Exception as erro:  # best-effort: report the failure instead of raising
        return None, str(erro)
    return grafo, True
def preparar_splits(data, *, rotulo_ilicito=1, rotulo_licito=0, rotulo_desconhecido=2):
    """Build labelled train/test masks and summary statistics for the graph.

    PyG's ``EllipticBitcoinDataset`` encodes the classes as 0 = licit,
    1 = illicit, 2 = unknown (raw class '2' -> 0, '1' -> 1, 'unknown' -> 2).
    The keyword defaults follow that encoding; override them if ``data`` was
    produced with a different label mapping.

    Args:
        data: graph object exposing ``x``, ``y``, ``edge_index``,
            ``train_mask`` and ``test_mask`` tensors.
        rotulo_ilicito: class id counted as illicit (fraud).
        rotulo_licito: class id counted as licit.
        rotulo_desconhecido: class id of unlabelled nodes, excluded from both
            splits.

    Returns:
        ``(data, stats)``. ``data`` is the same object, mutated in place with
        the ``train_mask_labeled`` and ``test_mask_labeled`` boolean masks.
        ``stats`` is a dict of node/edge/feature counts, per-split labelled
        counts, per-class counts, and ``taxa_fraude_*``: the fraction of
        illicit nodes among the labelled nodes of each split (0.0 when the
        split is empty).
    """
    labeled_mask = data.y != rotulo_desconhecido
    train_mask = data.train_mask & labeled_mask
    test_mask = data.test_mask & labeled_mask
    y_train = data.y[train_mask]
    y_test = data.y[test_mask]

    n_train = int(train_mask.sum())
    n_test = int(test_mask.sum())
    n_ilicito_train = int((y_train == rotulo_ilicito).sum())
    n_ilicito_test = int((y_test == rotulo_ilicito).sum())

    stats = {
        'n_nos': data.x.shape[0],
        'n_arestas': data.edge_index.shape[1],
        'n_features': data.x.shape[1],
        'n_rotulados': int(labeled_mask.sum()),
        'n_train': n_train,
        'n_test': n_test,
        'n_ilicito_train': n_ilicito_train,
        'n_licito_train': int((y_train == rotulo_licito).sum()),
        'n_ilicito_test': n_ilicito_test,
        'n_licito_test': int((y_test == rotulo_licito).sum()),
        # max(..., 1) keeps an empty split at rate 0.0 instead of dividing by zero
        'taxa_fraude_train': n_ilicito_train / max(n_train, 1),
        'taxa_fraude_test': n_ilicito_test / max(n_test, 1),
    }
    data.train_mask_labeled = train_mask
    data.test_mask_labeled = test_mask
    return data, stats
def criar_mini_batches(data, batch_size=512, split='train', *, direcao='out'):
    """Split labelled seed nodes into mini-batches with their 1-hop neighbourhood.

    Works without ``NeighborLoader``, so torch-sparse is not required.

    Args:
        data: graph exposing ``x``, ``y``, ``edge_index`` and the
            ``train_mask_labeled`` / ``test_mask_labeled`` masks.
        batch_size: number of seed nodes per batch (the last batch may be
            smaller). Must be >= 1.
        split: ``'train'`` uses ``train_mask_labeled`` and shuffles the seeds
            with ``torch.randperm``; any other value uses
            ``test_mask_labeled`` in index order.
        direcao: which 1-hop neighbours are added to each batch. ``'out'``
            (default, original behaviour) adds the targets of edges leaving a
            seed; ``'in'`` adds the sources of edges entering a seed; ``'both'``
            adds both. With PyG's default ``flow='source_to_target'`` a node
            aggregates messages from the sources of its incoming edges, so
            ``'in'``/``'both'`` is probably what a message-passing model needs.
            NOTE(review): confirm against the model in use.

    Returns:
        List of ``(x_batch, edge_index_local, y_batch, seed_local)`` tuples.
        Batch nodes are ordered by ascending global id. ``edge_index_local``
        keeps only edges whose endpoints are both in the batch, renumbered to
        local positions (int64, shape ``(2, E_b)``, including when
        ``E_b == 0``). ``seed_local`` holds the local positions of the seeds,
        in seed order.

    Raises:
        ValueError: if ``batch_size < 1`` or ``direcao`` is not one of
            ``'out'``, ``'in'``, ``'both'``.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
    if direcao not in ('out', 'in', 'both'):
        raise ValueError(f"direcao must be 'out', 'in' or 'both', got {direcao!r}")

    mask = data.train_mask_labeled if split == 'train' else data.test_mask_labeled
    indices = mask.nonzero(as_tuple=True)[0]
    if split == 'train':
        # Shuffle seed order for training only.
        indices = indices[torch.randperm(len(indices))]

    ei = data.edge_index
    src, dst = ei[0], ei[1]
    batches = []
    for inicio in range(0, len(indices), batch_size):
        seed = indices[inicio:inicio + batch_size]

        # Collect the seeds plus the requested 1-hop neighbours.
        partes = [seed]
        if direcao in ('out', 'both'):
            partes.append(dst[torch.isin(src, seed)])
        if direcao in ('in', 'both'):
            partes.append(src[torch.isin(dst, seed)])
        # sorted=True is required: searchsorted below needs ascending ids.
        nos_batch = torch.cat(partes).unique(sorted=True)

        # Keep only edges with both endpoints inside the batch.
        mask_int = torch.isin(src, nos_batch) & torch.isin(dst, nos_batch)
        # Every endpoint/seed is present in nos_batch, so searchsorted returns
        # its exact position, i.e. the global -> local id mapping (int64).
        ei_local = torch.searchsorted(nos_batch, ei[:, mask_int])
        seed_local = torch.searchsorted(nos_batch, seed)

        batches.append((data.x[nos_batch], ei_local, data.y[nos_batch], seed_local))
    return batches