""" baseline/graphsage_baseline.py Baseline GraphSAGE — abordagem tradicional. Converte as tabelas SQL em um GRAFO ESTÁTICO e roda GraphSAGE. Esta é exatamente a abordagem que RelGNN evita. Implementado com PyTorch puro (sem DGL) para portabilidade no HF Spaces. """ import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from sklearn.model_selection import train_test_split from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score from typing import Dict, Callable, Tuple, List # ─── GRAPH CONSTRUCTION (o que RelGNN EVITA) ────────────────────────────────── def tables_to_static_graph(tables: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Converte tabelas SQL em grafo estático. Esta etapa é cara e perde semântica relacional. Retorna: node_features: [N, F] edge_src: [E] edge_dst: [E] labels: [N_customers] (apenas nós de clientes têm label) """ customers = tables["customers"] orders = tables["orders"] lineitem = tables["lineitem"] n_cust = len(customers) n_ord = len(orders) # Offset: clientes = [0, n_cust), pedidos = [n_cust, n_cust+n_ord) ord_offset = n_cust # Features dos nós cust_feats = customers[["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]].fillna(0).values.astype(np.float32) ord_feats = orders[["o_totalprice", "o_shippriority"]].fillna(0).values.astype(np.float32) # Padeia para mesma dim max_dim = max(cust_feats.shape[1], ord_feats.shape[1]) def pad_cols(arr, target): if arr.shape[1] < target: arr = np.hstack([arr, np.zeros((len(arr), target - arr.shape[1]), dtype=np.float32)]) return arr cust_feats = pad_cols(cust_feats, max_dim) ord_feats = pad_cols(ord_feats, max_dim) node_features = np.vstack([cust_feats, ord_feats]) # Normalização col_std = node_features.std(axis=0) col_std[col_std == 0] = 1 node_features = (node_features - node_features.mean(axis=0)) / col_std # Arestas: customer ↔ order cust_ids = orders["o_custkey"].values ord_ids = np.arange(n_ord) + ord_offset valid_mask = cust_ids < n_cust src = np.concatenate([cust_ids[valid_mask], ord_ids[valid_mask]]) dst = np.concatenate([ord_ids[valid_mask], cust_ids[valid_mask]]) # Labels para nós de clientes fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max() labels = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(np.float32) return node_features, src, dst, labels # ─── GRAPHSAGE LAYER ────────────────────────────────────────────────────────── class SAGEConv(nn.Module): """GraphSAGE conv simplificado (mean aggregator) em PyTorch puro.""" def __init__(self, in_dim: int, out_dim: int): super().__init__() self.W_self = nn.Linear(in_dim, out_dim, bias=False) self.W_neigh = nn.Linear(in_dim, out_dim, bias=False) self.bias = nn.Parameter(torch.zeros(out_dim)) def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: """ h: [N, in_dim] adj: [N, N] — adjacência normalizada """ agg = torch.mm(adj, h) # Mean neighbor aggregation out = self.W_self(h) + self.W_neigh(agg) + self.bias return F.relu(out) class GraphSAGEModel(nn.Module): def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.2): super().__init__() self.conv1 = SAGEConv(in_dim, hidden_dim) self.conv2 = SAGEConv(hidden_dim, hidden_dim) self.dropout = nn.Dropout(dropout) self.head = nn.Linear(hidden_dim, 1) def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor: h = self.conv1(h, adj) h = self.dropout(h) h = self.conv2(h, adj) return self.head(h).squeeze(-1) def build_adj_matrix(n_nodes: int, src: np.ndarray, dst: np.ndarray) -> torch.Tensor: """Adjacência normalizada por grau.""" adj = torch.zeros(n_nodes, n_nodes) for s, d in zip(src, dst): if s < n_nodes and d < n_nodes: adj[d, s] = 1.0 # Normaliza por grau deg = adj.sum(dim=1, keepdim=True).clamp(min=1) return adj / deg # ─── GRAPHSAGE BASELINE ─────────────────────────────────────────────────────── class GraphSAGEBaseline: def __init__(self, hidden_dim: int = 64, num_epochs: int = 50): self.hidden_dim = hidden_dim self.num_epochs = num_epochs def fit( self, tables: Dict, log_fn: Callable = print, ) -> Tuple[Dict, List[Dict]]: t_start = time.time() log_fn(" [GraphSAGE] Convertendo tabelas SQL → grafo estático...") # Passo custoso: conversão para grafo node_features, src, dst, labels = tables_to_static_graph(tables) n_nodes = len(node_features) n_cust = len(tables["customers"]) log_fn(f" [GraphSAGE] Grafo: {n_nodes} nós, {len(src)} arestas") # Adj matrix (limitada a n_nodes pequeno para HF Spaces) # Para grafos grandes, usaríamos sparse; aqui simplificamos if n_nodes > 3000: # Subsample para caber em memória keep = min(n_nodes, 3000) node_features = node_features[:keep] valid = (src < keep) & (dst < keep) src, dst = src[valid], dst[valid] labels_full = labels labels = labels[:min(n_cust, keep)] n_nodes = keep n_cust = min(n_cust, keep) adj = build_adj_matrix(n_nodes, src, dst) X = torch.tensor(node_features, dtype=torch.float32) y_all = np.zeros(n_nodes) y_all[:len(labels)] = labels in_dim = node_features.shape[1] model = GraphSAGEModel(in_dim, self.hidden_dim) optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) # Índices de treino/teste (apenas nós de clientes têm label) cust_idx = np.arange(n_cust) idx_tr, idx_te = train_test_split( cust_idx, test_size=0.2, random_state=42, stratify=(labels[:n_cust] > 0.5).astype(int) ) y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32) pos_weight = torch.tensor([(y_tr == 0).sum() / max((y_tr == 1).sum(), 1)]) loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight) history = [] log_interval = max(1, self.num_epochs // 5) model.train() for epoch in range(1, self.num_epochs + 1): optimizer.zero_grad() all_logits = model(X, adj) logits_tr = all_logits[idx_tr] loss = loss_fn(logits_tr, y_tr) loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 1.0) optimizer.step() if epoch % log_interval == 0 or epoch == self.num_epochs: model.eval() with torch.no_grad(): logits_te = model(X, adj)[idx_te] probs_te = torch.sigmoid(logits_te).numpy() try: auc = roc_auc_score(labels[idx_te], probs_te) except Exception: auc = 0.5 history.append({"epoch": epoch, "auc": auc}) model.train() # Métricas finais model.eval() with torch.no_grad(): logits_te = model(X, adj)[idx_te] probs_te = torch.sigmoid(logits_te).numpy() preds = (probs_te > 0.5).astype(int) y_true = labels[idx_te].astype(int) try: auc = roc_auc_score(y_true, probs_te) f1 = f1_score(y_true, preds, zero_division=0) precision = precision_score(y_true, preds, zero_division=0) recall = recall_score(y_true, preds, zero_division=0) except Exception: auc = f1 = precision = recall = 0.5 train_time = round(time.time() - t_start, 1) log_fn(f" [GraphSAGE] Tempo total (incl. conversão para grafo): {train_time}s") metrics = { "auc": round(auc, 4), "f1": round(f1, 4), "precision": round(precision, 4), "recall": round(recall, 4), "train_time": train_time, } return metrics, history