| """ |
| baseline/graphsage_baseline.py |
| Baseline GraphSAGE — abordagem tradicional. |
| |
| Converte as tabelas SQL em um GRAFO ESTÁTICO e roda GraphSAGE. |
| Esta é exatamente a abordagem que RelGNN evita. |
| |
| Implementado com PyTorch puro (sem DGL) para portabilidade no HF Spaces. |
| """ |
|
|
| import time |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from sklearn.model_selection import train_test_split |
| from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score |
| from typing import Dict, Callable, Tuple, List |
|
|
|
|
| |
|
|
def tables_to_static_graph(tables: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Convert the relational tables into one static, homogeneous graph.

    Customer rows become nodes ``0 .. n_cust - 1`` (in row order) and order
    rows become nodes ``n_cust .. n_cust + n_ord - 1``. Every order whose
    ``o_custkey`` matches a ``c_custkey`` contributes two directed edges:
    customer -> order and order -> customer.

    Args:
        tables: Mapping with ``"customers"`` and ``"orders"`` entries that
            expose the pandas DataFrame API used below (column selection,
            ``fillna``, ``groupby``, ``map``). Required columns are
            ``c_custkey``, ``c_acctbal``, ``c_nationkey``,
            ``c_account_age_days``, ``c_num_prev_orders`` on customers and
            ``o_custkey``, ``o_totalprice``, ``o_shippriority``, ``is_fraud``
            on orders. Other entries (e.g. ``"lineitem"``) are ignored.

    Returns:
        A tuple ``(node_features, edge_src, edge_dst, labels)``:
            node_features: ``[N, F]`` float array, z-scored per column.
            edge_src: ``[E]`` integer source node indices.
            edge_dst: ``[E]`` integer destination node indices.
            labels: ``[n_cust]`` float32 array; a customer's label is the
                maximum ``is_fraud`` over its orders, 0 if it has none.
    """
    customers = tables["customers"]
    orders = tables["orders"]

    n_cust = len(customers)
    n_ord = len(orders)

    # Order nodes are numbered right after the customer nodes.
    ord_offset = n_cust

    cust_feats = customers[["c_acctbal", "c_nationkey", "c_account_age_days",
                            "c_num_prev_orders"]].fillna(0).values.astype(np.float32)
    ord_feats = orders[["o_totalprice", "o_shippriority"]].fillna(0).values.astype(np.float32)

    # Both node types share one feature matrix, so the narrower one is
    # right-padded with zero columns up to the wider one's width.
    max_dim = max(cust_feats.shape[1], ord_feats.shape[1])

    def pad_cols(arr, target):
        if arr.shape[1] < target:
            arr = np.hstack([arr, np.zeros((len(arr), target - arr.shape[1]), dtype=np.float32)])
        return arr

    node_features = np.vstack([pad_cols(cust_feats, max_dim), pad_cols(ord_feats, max_dim)])

    # Column-wise z-score over all nodes; constant columns keep a divisor of 1
    # to avoid dividing by zero.
    col_std = node_features.std(axis=0)
    col_std[col_std == 0] = 1
    node_features = (node_features - node_features.mean(axis=0)) / col_std

    # Resolve each order's customer through c_custkey so that edges point at
    # the same node the labels below are attached to. Using the raw key as a
    # node index is only right when keys are exactly 0..n_cust-1 in row order,
    # and would let negative keys wrap around. If c_custkey has duplicates the
    # last row with that key wins.
    key_to_node = {key: node for node, key in enumerate(customers["c_custkey"].values)}
    cust_nodes = orders["o_custkey"].map(key_to_node)
    valid_mask = cust_nodes.notna().values  # orders with an unknown customer get no edge
    cust_nodes = cust_nodes.values[valid_mask].astype(np.int64)
    ord_nodes = (np.arange(n_ord) + ord_offset)[valid_mask]

    # Bidirectional edges: customer -> order followed by order -> customer.
    src = np.concatenate([cust_nodes, ord_nodes])
    dst = np.concatenate([ord_nodes, cust_nodes])

    # Customer-level label: flagged if any of its orders is flagged.
    fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max()
    labels = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(np.float32)

    return node_features, src, dst, labels
|
|
|
|
| |
|
|
class SAGEConv(nn.Module):
    """Minimal GraphSAGE layer (mean aggregator) written in plain PyTorch.

    The layer keeps two bias-free linear maps, one applied to a node's own
    features and one applied to its aggregated neighbourhood, plus a single
    shared bias vector initialised to zero.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        # Creation order is kept (self, then neighbour) so random
        # initialisation consumes the RNG in the same sequence.
        self.W_self = nn.Linear(in_dim, out_dim, bias=False)
        self.W_neigh = nn.Linear(in_dim, out_dim, bias=False)
        self.bias = nn.Parameter(torch.zeros(out_dim))

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Apply one round of neighbourhood aggregation.

        Args:
            h: Node features, ``[N, in_dim]``.
            adj: ``[N, N]`` adjacency; row ``i`` weights the neighbours of
                node ``i`` (a row-normalised matrix yields a mean).

        Returns:
            ``[N, out_dim]`` activations after ReLU.
        """
        neighbourhood = adj.mm(h)
        own_term = self.W_self(h)
        neigh_term = self.W_neigh(neighbourhood)
        return torch.relu(own_term + neigh_term + self.bias)
|
|
|
|
class GraphSAGEModel(nn.Module):
    """Two stacked SAGEConv layers followed by a linear scoring head.

    Dropout is applied between the two convolutions only. The head maps each
    node's hidden vector to a single raw logit.
    """

    def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.2):
        super().__init__()
        # Sub-modules are created in a fixed order so parameter
        # initialisation draws from the RNG in the same sequence.
        self.conv1 = SAGEConv(in_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.head = nn.Linear(hidden_dim, 1)

    def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
        """Score every node of the graph.

        Args:
            h: Node features, ``[N, in_dim]``.
            adj: ``[N, N]`` adjacency matrix passed to both convolutions.

        Returns:
            A ``[N]`` tensor holding one unnormalised logit per node.
        """
        first_hop = self.dropout(self.conv1(h, adj))
        second_hop = self.conv2(first_hop, adj)
        logits = self.head(second_hop)
        return logits.squeeze(-1)
|
|
|
|
def build_adj_matrix(n_nodes: int, src: np.ndarray, dst: np.ndarray) -> torch.Tensor:
    """Build a dense, row-normalised adjacency matrix from an edge list.

    Entry ``[d, s]`` is non-zero when there is an edge ``s -> d``, so that
    multiplying the matrix by a feature matrix averages, for each node, the
    features of the nodes pointing at it.

    Args:
        n_nodes: Number of nodes; the result is ``[n_nodes, n_nodes]``.
        src: Source node index of each edge.
        dst: Destination node index of each edge. If the two arrays differ
            in length, the extra trailing entries of the longer one are
            ignored.

    Returns:
        A float tensor where every row with at least one incoming edge sums
        to 1. Rows without incoming edges stay all-zero. Duplicate edges
        count once. Edges with an endpoint outside ``[0, n_nodes)`` are
        dropped.
    """
    src_arr = np.asarray(src)
    dst_arr = np.asarray(dst)
    n_edges = min(len(src_arr), len(dst_arr))
    src_idx = torch.as_tensor(src_arr[:n_edges], dtype=torch.long)
    dst_idx = torch.as_tensor(dst_arr[:n_edges], dtype=torch.long)

    # The lower bound matters: a negative index would otherwise be accepted
    # and silently wrap around to a node at the end of the matrix.
    in_range = (
        (src_idx >= 0) & (src_idx < n_nodes)
        & (dst_idx >= 0) & (dst_idx < n_nodes)
    )

    adj = torch.zeros(n_nodes, n_nodes)
    # One indexed assignment instead of a Python-level loop over edges.
    adj[dst_idx[in_range], src_idx[in_range]] = 1.0

    # Clamp the in-degree to 1 so isolated nodes do not divide by zero.
    deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
    return adj / deg
|
|
|
|
| |
|
|
class GraphSAGEBaseline:
    """End-to-end GraphSAGE baseline: graph conversion, training, evaluation.

    Args:
        hidden_dim: Width of the hidden node representations.
        num_epochs: Number of full-batch training epochs.
        max_nodes: Upper bound on the number of graph nodes. Larger graphs
            are cut down to their first ``max_nodes`` nodes because the
            adjacency matrix is dense (``max_nodes ** 2`` floats).
    """

    def __init__(self, hidden_dim: int = 64, num_epochs: int = 50, max_nodes: int = 3000):
        self.hidden_dim = hidden_dim
        self.num_epochs = num_epochs
        self.max_nodes = max_nodes

    @staticmethod
    def _safe_auc(y_true, probs) -> float:
        """Return ROC-AUC, or 0.5 (chance level) when it is undefined."""
        try:
            auc = float(roc_auc_score(y_true, probs))
        except ValueError:
            # Raised e.g. when y_true contains a single class.
            return 0.5
        # NOTE(review): some scikit-learn versions may return NaN instead of
        # raising in the single-class case; treat it the same way.
        return 0.5 if np.isnan(auc) else auc

    def fit(
        self,
        tables: Dict,
        log_fn: Callable = print,
    ) -> Tuple[Dict, List[Dict]]:
        """Build the static graph from ``tables``, train and evaluate.

        Only customer nodes carry labels; they are split 80/20 into train and
        test sets (``random_state=42``), stratified on the label when every
        present class has at least two members.

        Args:
            tables: Table mapping forwarded to ``tables_to_static_graph``;
                ``tables["customers"]`` is also read here for its length.
            log_fn: Callable receiving progress messages.

        Returns:
            ``(metrics, history)``. ``metrics`` has the keys ``auc``, ``f1``,
            ``precision``, ``recall`` (test set, rounded to 4 decimals) and
            ``train_time`` (seconds, including the graph conversion).
            ``history`` is a list of ``{"epoch", "auc"}`` dicts recorded
            roughly five times during training and at the last epoch.
        """
        t_start = time.time()
        log_fn(" [GraphSAGE] Convertendo tabelas SQL → grafo estático...")

        node_features, src, dst, labels = tables_to_static_graph(tables)
        n_nodes = len(node_features)
        n_cust = len(tables["customers"])

        log_fn(f" [GraphSAGE] Grafo: {n_nodes} nós, {len(src)} arestas")

        if n_nodes > self.max_nodes:
            # Keep a prefix of the node list so the dense adjacency stays
            # small. NOTE(review): customers come first in the node order, so
            # once the customer count reaches max_nodes every order node (and
            # therefore every edge) is dropped.
            keep = self.max_nodes
            node_features = node_features[:keep]
            in_range = (src < keep) & (dst < keep)
            src, dst = src[in_range], dst[in_range]
            n_cust = min(n_cust, keep)
            labels = labels[:n_cust]
            n_nodes = keep

        adj = build_adj_matrix(n_nodes, src, dst)
        X = torch.tensor(node_features, dtype=torch.float32)

        model = GraphSAGEModel(node_features.shape[1], self.hidden_dim)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

        # Train/test split over customer nodes only.
        cust_idx = np.arange(n_cust)
        strat_labels = (labels[:n_cust] > 0.5).astype(int)
        class_counts = np.bincount(strat_labels)
        present = class_counts[class_counts > 0]
        # Stratification needs at least two members in each present class;
        # otherwise fall back to a plain random split instead of crashing.
        can_stratify = present.size > 0 and present.min() >= 2
        idx_tr, idx_te = train_test_split(
            cust_idx, test_size=0.2, random_state=42,
            stratify=strat_labels if can_stratify else None,
        )

        y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)
        # Weight positives by the negative/positive ratio to offset imbalance.
        n_pos = float((y_tr == 1).sum())
        n_neg = float((y_tr == 0).sum())
        pos_weight = torch.tensor([n_neg / max(n_pos, 1.0)])
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

        def predict_test() -> np.ndarray:
            """Return test-set probabilities; expects the model in eval mode."""
            with torch.no_grad():
                return torch.sigmoid(model(X, adj)[idx_te]).numpy()

        history = []
        log_interval = max(1, self.num_epochs // 5)

        model.train()
        for epoch in range(1, self.num_epochs + 1):
            optimizer.zero_grad()
            # Full-batch forward pass; the loss uses training customers only.
            all_logits = model(X, adj)
            loss = loss_fn(all_logits[idx_tr], y_tr)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            if epoch % log_interval == 0 or epoch == self.num_epochs:
                model.eval()
                epoch_auc = self._safe_auc(labels[idx_te], predict_test())
                history.append({"epoch": epoch, "auc": epoch_auc})
                model.train()

        # Final evaluation on the held-out customers.
        model.eval()
        probs_te = predict_test()
        preds = (probs_te > 0.5).astype(int)
        y_true = labels[idx_te].astype(int)

        # AUC and the threshold metrics fail independently: an undefined AUC
        # must not overwrite valid F1/precision/recall with a made-up 0.5.
        auc = self._safe_auc(y_true, probs_te)
        try:
            f1 = float(f1_score(y_true, preds, zero_division=0))
            precision = float(precision_score(y_true, preds, zero_division=0))
            recall = float(recall_score(y_true, preds, zero_division=0))
        except ValueError:
            f1 = precision = recall = 0.0

        train_time = round(time.time() - t_start, 1)
        log_fn(f" [GraphSAGE] Tempo total (incl. conversão para grafo): {train_time}s")

        metrics = {
            "auc": round(auc, 4), "f1": round(f1, 4),
            "precision": round(precision, 4), "recall": round(recall, 4),
            "train_time": train_time,
        }
        return metrics, history