RelGNNDeepRelationalLearning / graphsage baseline .py
Danielfonseca1212's picture
Create graphsage baseline .py
d27f646 verified
"""
baseline/graphsage_baseline.py
Baseline GraphSAGE — abordagem tradicional.
Converte as tabelas SQL em um GRAFO ESTÁTICO e roda GraphSAGE.
Esta é exatamente a abordagem que RelGNN evita.
Implementado com PyTorch puro (sem DGL) para portabilidade no HF Spaces.
"""
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from typing import Dict, Callable, Tuple, List
# ─── GRAPH CONSTRUCTION (o que RelGNN EVITA) ──────────────────────────────────
def tables_to_static_graph(tables: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert the SQL tables into a single static, homogeneous graph.

    This flattening step is the costly part of the traditional pipeline and
    discards relational semantics: both tables are forced into one padded
    feature space.

    Node layout: customers occupy ``[0, n_cust)`` in the row order of
    ``tables["customers"]``; orders occupy ``[n_cust, n_cust + n_ord)`` in the
    row order of ``tables["orders"]``.

    Args:
        tables: mapping with a ``"customers"`` DataFrame (columns ``c_custkey``,
            ``c_acctbal``, ``c_nationkey``, ``c_account_age_days``,
            ``c_num_prev_orders``) and an ``"orders"`` DataFrame (columns
            ``o_custkey``, ``o_totalprice``, ``o_shippriority``, ``is_fraud``).

    Returns:
        node_features: [N, F] float32, standardized per column.
        edge_src: [E] int64 source node indices (each customer-order link is
            emitted in both directions).
        edge_dst: [E] int64 destination node indices.
        labels: [N_customers] float32, the max ``is_fraud`` over each
            customer's orders (0 for customers without orders). Only customer
            nodes carry labels.
    """
    customers = tables["customers"]
    orders = tables["orders"]
    n_cust = len(customers)
    n_ord = len(orders)
    # Order nodes are placed after all customer nodes.
    ord_offset = n_cust

    # Node features.
    cust_feats = customers[["c_acctbal", "c_nationkey", "c_account_age_days",
                            "c_num_prev_orders"]].fillna(0).values.astype(np.float32)
    ord_feats = orders[["o_totalprice", "o_shippriority"]].fillna(0).values.astype(np.float32)

    # Zero-pad both feature matrices to a common width so they can be stacked.
    max_dim = max(cust_feats.shape[1], ord_feats.shape[1])

    def pad_cols(arr, target):
        """Right-pad ``arr`` with float32 zero columns up to ``target`` columns."""
        if arr.shape[1] < target:
            arr = np.hstack([arr, np.zeros((len(arr), target - arr.shape[1]), dtype=np.float32)])
        return arr

    node_features = np.vstack([pad_cols(cust_feats, max_dim), pad_cols(ord_feats, max_dim)])

    # Column standardization; constant columns keep std 1 to avoid dividing by 0.
    col_std = node_features.std(axis=0)
    col_std[col_std == 0] = 1
    node_features = (node_features - node_features.mean(axis=0)) / col_std

    # Edges: customer <-> order. ``o_custkey`` is a key, not a row position, so
    # translate it to the customer's node index (its row in ``customers``) --
    # the same alignment the labels below use. Orders whose customer is not in
    # the customers table (or whose key is missing) get no edge.
    key_to_pos = dict(zip(customers["c_custkey"].tolist(), range(n_cust)))
    cust_pos = orders["o_custkey"].map(key_to_pos)
    valid_mask = cust_pos.notna().values
    cust_ids = cust_pos.values[valid_mask].astype(np.int64)
    ord_ids = np.arange(n_ord, dtype=np.int64)[valid_mask] + ord_offset
    src = np.concatenate([cust_ids, ord_ids])
    dst = np.concatenate([ord_ids, cust_ids])

    # Customer labels: a customer is positive if any of its orders is fraud.
    fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max()
    labels = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(np.float32)
    return node_features, src, dst, labels
# ─── GRAPHSAGE LAYER ──────────────────────────────────────────────────────────
class SAGEConv(nn.Module):
    """
    Minimal GraphSAGE convolution layer written in plain PyTorch.

    Computes ``relu(W_self(h) + W_neigh(adj @ h) + bias)``. When ``adj`` is
    row-normalized by degree, the neighbour term is a mean aggregation.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Registration order is kept so parameter initialization consumes the
        # RNG in the same sequence.
        self.W_self = nn.Linear(in_dim, out_dim, bias=False)   # node's own features
        self.W_neigh = nn.Linear(in_dim, out_dim, bias=False)  # aggregated neighbours
        self.bias = nn.Parameter(torch.zeros(out_dim))          # shared additive bias

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: [N, in_dim] node features.
            adj: [N, N] adjacency, expected to be degree-normalized.

        Returns:
            [N, out_dim] ReLU-activated node embeddings.
        """
        neighbour_mean = adj @ h
        self_term = self.W_self(h)
        neigh_term = self.W_neigh(neighbour_mean)
        return torch.relu(self_term + neigh_term + self.bias)
class GraphSAGEModel(nn.Module):
    """
    Two-layer GraphSAGE encoder followed by a linear scoring head.

    Returns one raw logit per node (no sigmoid is applied).
    """

    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.2):
        super().__init__()
        # Keep this registration order: it fixes the parameter-init RNG draws
        # and the state_dict layout.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """
        Args:
            h: [N, in_dim] node features.
            adj: [N, N] normalized adjacency, shared by both layers.

        Returns:
            [N] per-node logits.
        """
        # Dropout is applied only between the two convolutions.
        hidden = self.dropout(self.conv1(h, adj))
        hidden = self.conv2(hidden, adj)
        logits = self.head(hidden)
        return logits.squeeze(-1)
def build_adj_matrix(n_nodes: int, src: np.ndarray, dst: np.ndarray) -> torch.Tensor:
    """
    Build a dense, degree-normalized adjacency matrix.

    Sets ``adj[d, s] = 1`` for every edge ``s -> d`` whose endpoints both lie
    in ``[0, n_nodes)``. Duplicate edges count once. Out-of-range edges,
    including negative indices, are dropped. Each row is then divided by its
    number of nonzero entries (clamped to at least 1), so ``adj @ h`` averages
    each node's in-neighbours and isolated nodes get an all-zero row.

    Args:
        n_nodes: number of nodes N.
        src: [E] integer source node indices.
        dst: [E] integer destination node indices.

    Returns:
        Float tensor of shape [N, N].
    """
    adj = torch.zeros(n_nodes, n_nodes)
    # Coerce to int64 so empty or float-typed arrays are usable as indices.
    src = np.asarray(src, dtype=np.int64)
    dst = np.asarray(dst, dtype=np.int64)
    # Vectorized scatter instead of a per-edge Python loop. The explicit lower
    # bound stops negative ids from wrapping around to the last rows/columns.
    keep = (src >= 0) & (src < n_nodes) & (dst >= 0) & (dst < n_nodes)
    adj[torch.from_numpy(dst[keep]), torch.from_numpy(src[keep])] = 1.0
    # Normalize each row by its degree.
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    return adj / deg
# ─── GRAPHSAGE BASELINE ───────────────────────────────────────────────────────
class GraphSAGEBaseline:
    """
    Traditional baseline: flatten the SQL tables into one static graph and
    train a full-batch, 2-layer GraphSAGE customer-fraud classifier on it.

    Args:
        hidden_dim: width of both GraphSAGE layers.
        num_epochs: number of full-batch training epochs.
        max_nodes: cap on graph size. The [N, N] adjacency is dense, so larger
            graphs are truncated to their first ``max_nodes`` nodes (default
            3000, the limit used for HF Spaces).
    """

    def __init__(self, hidden_dim: int = 64, num_epochs: int = 50, max_nodes: int = 3000):
        self.hidden_dim = hidden_dim
        self.num_epochs = num_epochs
        self.max_nodes = max_nodes

    @staticmethod
    def _split_indices(cust_idx: np.ndarray, labels: np.ndarray) -> List[np.ndarray]:
        """Split customer node indices 80/20, stratified on the label when possible."""
        strata = (labels > 0.5).astype(int)
        try:
            return train_test_split(cust_idx, test_size=0.2, random_state=42, stratify=strata)
        except ValueError:
            # Stratification fails when a class has fewer than 2 members (or
            # the test split is smaller than the number of classes). Fall back
            # to an unstratified split rather than aborting the run.
            return train_test_split(cust_idx, test_size=0.2, random_state=42)

    @staticmethod
    def _safe_auc(y_true: np.ndarray, probs: np.ndarray) -> float:
        """ROC-AUC of ``probs`` against ``y_true``; 0.5 (chance level) when undefined."""
        try:
            return roc_auc_score(y_true, probs)
        except ValueError:
            # roc_auc_score raises ValueError e.g. when y_true has a single class.
            return 0.5

    @staticmethod
    def _predict_proba(model: nn.Module, X: torch.Tensor, adj: torch.Tensor,
                       idx: np.ndarray) -> np.ndarray:
        """Sigmoid probabilities for nodes ``idx`` from one eval-mode full-graph pass."""
        model.eval()
        with torch.no_grad():
            logits = model(X, adj)[idx]
        return torch.sigmoid(logits).numpy()

    def fit(
        self,
        tables: Dict,
        log_fn: Callable = print,
    ) -> Tuple[Dict, List[Dict]]:
        """
        Build the static graph, train GraphSAGE and evaluate on a held-out split.

        Args:
            tables: SQL tables, in the format consumed by ``tables_to_static_graph``.
            log_fn: callable used for progress messages.

        Returns:
            metrics: test-set ``auc``/``f1``/``precision``/``recall`` rounded to
                4 decimals, plus ``train_time`` in seconds (includes the graph
                conversion).
            history: ``{"epoch", "auc"}`` records taken every
                ``max(1, num_epochs // 5)`` epochs and at the last epoch.
        """
        t_start = time.time()
        log_fn(" [GraphSAGE] Convertendo tabelas SQL → grafo estático...")
        # The costly step: flattening the tables into a graph.
        node_features, src, dst, labels = tables_to_static_graph(tables)
        n_nodes = len(node_features)
        n_cust = len(tables["customers"])
        log_fn(f" [GraphSAGE] Grafo: {n_nodes} nós, {len(src)} arestas")

        # The adjacency is dense, so cap the graph size (a sparse adjacency
        # would avoid this). Customers come first in the node layout, so they
        # are kept preferentially. Edges touching removed nodes are dropped.
        if n_nodes > self.max_nodes:
            keep = self.max_nodes
            node_features = node_features[:keep]
            in_range = (src < keep) & (dst < keep)
            src, dst = src[in_range], dst[in_range]
            n_cust = min(n_cust, keep)
            labels = labels[:n_cust]
            n_nodes = keep

        adj = build_adj_matrix(n_nodes, src, dst)
        X = torch.tensor(node_features, dtype=torch.float32)

        model = GraphSAGEModel(node_features.shape[1], self.hidden_dim)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

        # Only customer nodes carry labels, so only they are split and scored.
        idx_tr, idx_te = self._split_indices(np.arange(n_cust), labels[:n_cust])

        y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)
        # Up-weight positives by the negative/positive ratio (class imbalance).
        n_neg = int((y_tr == 0).sum())
        n_pos = int((y_tr == 1).sum())
        pos_weight = torch.tensor([n_neg / max(n_pos, 1)], dtype=torch.float32)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        history = []
        log_interval = max(1, self.num_epochs // 5)
        model.train()
        for epoch in range(1, self.num_epochs + 1):
            optimizer.zero_grad()
            logits_tr = model(X, adj)[idx_tr]
            loss = loss_fn(logits_tr, y_tr)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            if epoch % log_interval == 0 or epoch == self.num_epochs:
                probs_te = self._predict_proba(model, X, adj, idx_te)
                history.append({"epoch": epoch, "auc": self._safe_auc(labels[idx_te], probs_te)})
                model.train()

        # Final test metrics. Only AUC can be undefined (single-class test set)
        # and fall back to 0.5. The threshold metrics are always computed from
        # the actual predictions instead of being masked by that fallback.
        probs_te = self._predict_proba(model, X, adj, idx_te)
        preds = (probs_te > 0.5).astype(int)
        y_true = labels[idx_te].astype(int)
        auc = self._safe_auc(y_true, probs_te)
        f1 = f1_score(y_true, preds, zero_division=0)
        precision = precision_score(y_true, preds, zero_division=0)
        recall = recall_score(y_true, preds, zero_division=0)

        train_time = round(time.time() - t_start, 1)
        log_fn(f" [GraphSAGE] Tempo total (incl. conversão para grafo): {train_time}s")
        metrics = {
            "auc": round(auc, 4), "f1": round(f1, 4),
            "precision": round(precision, 4), "recall": round(recall, 4),
            "train_time": train_time,
        }
        return metrics, history