FedGNN / data_loader.py
Danielfonseca1212's picture
Create data_loader.py
e8a515a verified
import torch
from torch_geometric.data import Data
import numpy as np
def generate_bank_data(bank_id, n_nodes=500, n_edges=1500, fraud_rate=0.05):
"""
Gera um grafo sintético representando transações bancárias.
Nós = Contas/Transações
Arestas = Fluxo de dinheiro
Features = Valor, hora, localização, histórico
Label = 0 (Legítimo), 1 (Fraude)
"""
np.random.seed(bank_id * 42) # Reprodutibilidade por banco
# Features simuladas (Valor, Hora, Score_Risco)
x = np.random.rand(n_nodes, 3).astype(np.float32)
# Criar arestas aleatórias (transações)
edge_index = torch.randint(0, n_nodes, (2, n_edges))
# Labels
y = torch.zeros(n_nodes, dtype=torch.long)
num_frauds = int(n_nodes * fraud_rate)
fraud_indices = np.random.choice(n_nodes, num_frauds, replace=False)
y[fraud_indices] = 1
# Máscara de treino/teste (80/20)
train_mask = torch.zeros(n_nodes, dtype=torch.bool)
test_mask = torch.zeros(n_nodes, dtype=torch.bool)
indices = np.random.permutation(n_nodes)
train_idx = indices[:int(n_nodes * 0.8)]
test_idx = indices[int(n_nodes * 0.8):]
train_mask[train_idx] = True
test_mask[test_idx] = True
data = Data(x=torch.tensor(x), edge_index=edge_index, y=y,
train_mask=train_mask, test_mask=test_mask)
return data
def load_datasets():
"""Carrega dados para 3 bancos fictícios"""
banks = {
"Banco A (Varejo)": generate_bank_data(1, fraud_rate=0.05),
"Banco B (Investimentos)": generate_bank_data(2, fraud_rate=0.03),
"Banco C (Fintech)": generate_bank_data(3, fraud_rate=0.08)
}
return banks