Danielfonseca1212 commited on
Commit
d27f646
Β·
verified Β·
1 Parent(s): 6864b79

Create graphsage baseline .py

Browse files
Files changed (1) hide show
  1. graphsage baseline .py +231 -0
graphsage baseline .py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ baseline/graphsage_baseline.py
3
+ Baseline GraphSAGE β€” abordagem tradicional.
4
+
5
+ Converte as tabelas SQL em um GRAFO ESTÁTICO e roda GraphSAGE.
6
+ Esta Γ© exatamente a abordagem que RelGNN evita.
7
+
8
+ Implementado com PyTorch puro (sem DGL) para portabilidade no HF Spaces.
9
+ """
10
+
11
+ import time
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from sklearn.model_selection import train_test_split
17
+ from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
18
+ from typing import Dict, Callable, Tuple, List
19
+
20
+
21
+ # ─── GRAPH CONSTRUCTION (o que RelGNN EVITA) ──────────────────────────────────
22
+
23
+ def tables_to_static_graph(tables: Dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
24
+ """
25
+ Converte tabelas SQL em grafo estΓ‘tico.
26
+ Esta etapa Γ© cara e perde semΓ’ntica relacional.
27
+
28
+ Retorna:
29
+ node_features: [N, F]
30
+ edge_src: [E]
31
+ edge_dst: [E]
32
+ labels: [N_customers] (apenas nΓ³s de clientes tΓͺm label)
33
+ """
34
+ customers = tables["customers"]
35
+ orders = tables["orders"]
36
+ lineitem = tables["lineitem"]
37
+
38
+ n_cust = len(customers)
39
+ n_ord = len(orders)
40
+
41
+ # Offset: clientes = [0, n_cust), pedidos = [n_cust, n_cust+n_ord)
42
+ ord_offset = n_cust
43
+
44
+ # Features dos nΓ³s
45
+ cust_feats = customers[["c_acctbal", "c_nationkey", "c_account_age_days",
46
+ "c_num_prev_orders"]].fillna(0).values.astype(np.float32)
47
+ ord_feats = orders[["o_totalprice", "o_shippriority"]].fillna(0).values.astype(np.float32)
48
+
49
+ # Padeia para mesma dim
50
+ max_dim = max(cust_feats.shape[1], ord_feats.shape[1])
51
+ def pad_cols(arr, target):
52
+ if arr.shape[1] < target:
53
+ arr = np.hstack([arr, np.zeros((len(arr), target - arr.shape[1]), dtype=np.float32)])
54
+ return arr
55
+
56
+ cust_feats = pad_cols(cust_feats, max_dim)
57
+ ord_feats = pad_cols(ord_feats, max_dim)
58
+ node_features = np.vstack([cust_feats, ord_feats])
59
+
60
+ # NormalizaΓ§Γ£o
61
+ col_std = node_features.std(axis=0)
62
+ col_std[col_std == 0] = 1
63
+ node_features = (node_features - node_features.mean(axis=0)) / col_std
64
+
65
+ # Arestas: customer ↔ order
66
+ cust_ids = orders["o_custkey"].values
67
+ ord_ids = np.arange(n_ord) + ord_offset
68
+
69
+ valid_mask = cust_ids < n_cust
70
+ src = np.concatenate([cust_ids[valid_mask], ord_ids[valid_mask]])
71
+ dst = np.concatenate([ord_ids[valid_mask], cust_ids[valid_mask]])
72
+
73
+ # Labels para nΓ³s de clientes
74
+ fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max()
75
+ labels = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(np.float32)
76
+
77
+ return node_features, src, dst, labels
78
+
79
+
80
+ # ─── GRAPHSAGE LAYER ──────────────────────────────────────────────────────────
81
+
82
+ class SAGEConv(nn.Module):
83
+ """GraphSAGE conv simplificado (mean aggregator) em PyTorch puro."""
84
+ def __init__(self, in_dim: int, out_dim: int):
85
+ super().__init__()
86
+ self.W_self = nn.Linear(in_dim, out_dim, bias=False)
87
+ self.W_neigh = nn.Linear(in_dim, out_dim, bias=False)
88
+ self.bias = nn.Parameter(torch.zeros(out_dim))
89
+
90
+ def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
91
+ """
92
+ h: [N, in_dim]
93
+ adj: [N, N] β€” adjacΓͺncia normalizada
94
+ """
95
+ agg = torch.mm(adj, h) # Mean neighbor aggregation
96
+ out = self.W_self(h) + self.W_neigh(agg) + self.bias
97
+ return F.relu(out)
98
+
99
+
100
+ class GraphSAGEModel(nn.Module):
101
+ def __init__(self, in_dim: int, hidden_dim: int, dropout: float = 0.2):
102
+ super().__init__()
103
+ self.conv1 = SAGEConv(in_dim, hidden_dim)
104
+ self.conv2 = SAGEConv(hidden_dim, hidden_dim)
105
+ self.dropout = nn.Dropout(dropout)
106
+ self.head = nn.Linear(hidden_dim, 1)
107
+
108
+ def forward(self, h: torch.Tensor, adj: torch.Tensor) -> torch.Tensor:
109
+ h = self.conv1(h, adj)
110
+ h = self.dropout(h)
111
+ h = self.conv2(h, adj)
112
+ return self.head(h).squeeze(-1)
113
+
114
+
115
+ def build_adj_matrix(n_nodes: int, src: np.ndarray, dst: np.ndarray) -> torch.Tensor:
116
+ """AdjacΓͺncia normalizada por grau."""
117
+ adj = torch.zeros(n_nodes, n_nodes)
118
+ for s, d in zip(src, dst):
119
+ if s < n_nodes and d < n_nodes:
120
+ adj[d, s] = 1.0
121
+ # Normaliza por grau
122
+ deg = adj.sum(dim=1, keepdim=True).clamp(min=1)
123
+ return adj / deg
124
+
125
+
126
+ # ─── GRAPHSAGE BASELINE ───────────────────────────────────────────────────────
127
+
128
+ class GraphSAGEBaseline:
129
+ def __init__(self, hidden_dim: int = 64, num_epochs: int = 50):
130
+ self.hidden_dim = hidden_dim
131
+ self.num_epochs = num_epochs
132
+
133
+ def fit(
134
+ self,
135
+ tables: Dict,
136
+ log_fn: Callable = print,
137
+ ) -> Tuple[Dict, List[Dict]]:
138
+
139
+ t_start = time.time()
140
+ log_fn(" [GraphSAGE] Convertendo tabelas SQL β†’ grafo estΓ‘tico...")
141
+
142
+ # Passo custoso: conversΓ£o para grafo
143
+ node_features, src, dst, labels = tables_to_static_graph(tables)
144
+ n_nodes = len(node_features)
145
+ n_cust = len(tables["customers"])
146
+
147
+ log_fn(f" [GraphSAGE] Grafo: {n_nodes} nΓ³s, {len(src)} arestas")
148
+
149
+ # Adj matrix (limitada a n_nodes pequeno para HF Spaces)
150
+ # Para grafos grandes, usarΓ­amos sparse; aqui simplificamos
151
+ if n_nodes > 3000:
152
+ # Subsample para caber em memΓ³ria
153
+ keep = min(n_nodes, 3000)
154
+ node_features = node_features[:keep]
155
+ valid = (src < keep) & (dst < keep)
156
+ src, dst = src[valid], dst[valid]
157
+ labels_full = labels
158
+ labels = labels[:min(n_cust, keep)]
159
+ n_nodes = keep
160
+ n_cust = min(n_cust, keep)
161
+
162
+ adj = build_adj_matrix(n_nodes, src, dst)
163
+
164
+ X = torch.tensor(node_features, dtype=torch.float32)
165
+ y_all = np.zeros(n_nodes)
166
+ y_all[:len(labels)] = labels
167
+
168
+ in_dim = node_features.shape[1]
169
+ model = GraphSAGEModel(in_dim, self.hidden_dim)
170
+ optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
171
+
172
+ # Índices de treino/teste (apenas nΓ³s de clientes tΓͺm label)
173
+ cust_idx = np.arange(n_cust)
174
+ idx_tr, idx_te = train_test_split(
175
+ cust_idx, test_size=0.2, random_state=42,
176
+ stratify=(labels[:n_cust] > 0.5).astype(int)
177
+ )
178
+ y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)
179
+ pos_weight = torch.tensor([(y_tr == 0).sum() / max((y_tr == 1).sum(), 1)])
180
+ loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
181
+
182
+ history = []
183
+ log_interval = max(1, self.num_epochs // 5)
184
+
185
+ model.train()
186
+ for epoch in range(1, self.num_epochs + 1):
187
+ optimizer.zero_grad()
188
+ all_logits = model(X, adj)
189
+ logits_tr = all_logits[idx_tr]
190
+ loss = loss_fn(logits_tr, y_tr)
191
+ loss.backward()
192
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
193
+ optimizer.step()
194
+
195
+ if epoch % log_interval == 0 or epoch == self.num_epochs:
196
+ model.eval()
197
+ with torch.no_grad():
198
+ logits_te = model(X, adj)[idx_te]
199
+ probs_te = torch.sigmoid(logits_te).numpy()
200
+ try:
201
+ auc = roc_auc_score(labels[idx_te], probs_te)
202
+ except Exception:
203
+ auc = 0.5
204
+ history.append({"epoch": epoch, "auc": auc})
205
+ model.train()
206
+
207
+ # MΓ©tricas finais
208
+ model.eval()
209
+ with torch.no_grad():
210
+ logits_te = model(X, adj)[idx_te]
211
+ probs_te = torch.sigmoid(logits_te).numpy()
212
+
213
+ preds = (probs_te > 0.5).astype(int)
214
+ y_true = labels[idx_te].astype(int)
215
+ try:
216
+ auc = roc_auc_score(y_true, probs_te)
217
+ f1 = f1_score(y_true, preds, zero_division=0)
218
+ precision = precision_score(y_true, preds, zero_division=0)
219
+ recall = recall_score(y_true, preds, zero_division=0)
220
+ except Exception:
221
+ auc = f1 = precision = recall = 0.5
222
+
223
+ train_time = round(time.time() - t_start, 1)
224
+ log_fn(f" [GraphSAGE] Tempo total (incl. conversΓ£o para grafo): {train_time}s")
225
+
226
+ metrics = {
227
+ "auc": round(auc, 4), "f1": round(f1, 4),
228
+ "precision": round(precision, 4), "recall": round(recall, 4),
229
+ "train_time": train_time,
230
+ }
231
+ return metrics, history