Danielfonseca1212 committed on
Commit
285c0e6
·
verified ·
1 Parent(s): abd4459

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +296 -388
app.py CHANGED
@@ -1,27 +1,7 @@
1
- """
2
- RelGNN — Deep Relational Learning · Projeto 8
3
- Do SQL ao Graph AI sem Engenharia Manual · TPC-H Fraud Detection
4
- Arquivo único para Hugging Face Spaces (sem imports locais)
5
- """
6
 
7
- import subprocess, sys, os
8
-
9
- def _install(*args):
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", *args,
11
- "-q", "--root-user-action=ignore"])
12
-
13
- # Instala dependências que podem estar faltando no container
14
- _deps = {
15
- "torch": ["torch", "--index-url", "https://download.pytorch.org/whl/cpu"],
16
- "plotly": ["plotly"],
17
- "sklearn": ["scikit-learn"],
18
- }
19
- for _mod, _args in _deps.items():
20
- try:
21
- __import__(_mod)
22
- except ImportError:
23
- print(f"Installing {_args[0]}...", flush=True)
24
- _install(*_args)
25
 
26
  import gradio as gr
27
  import pandas as pd
@@ -38,82 +18,57 @@ from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_sco
38
  from sklearn.ensemble import GradientBoostingClassifier
39
  from collections import deque
40
  from dataclasses import dataclass
41
- from typing import List, Dict, Tuple, Optional, Callable
42
  warnings.filterwarnings("ignore")
43
 
44
  # ══════════════════════════════════════════════════════════════════════════════
45
- # DATA — TPC-H GENERATOR
46
  # ══════════════════════════════════════════════════════════════════════════════
47
 
48
  def generate_tpch_data(n_customers=500, n_orders=2000, fraud_rate=0.05, seed=42):
49
  rng = np.random.default_rng(seed)
50
- n_nations = 25
51
- nations = pd.DataFrame({
52
- "n_nationkey": np.arange(n_nations),
53
- "n_name": [f"NATION_{i}" for i in range(n_nations)],
54
- "n_regionkey": rng.integers(0, 5, n_nations),
55
- })
56
- n_suppliers = max(10, n_customers // 20)
57
- suppliers = pd.DataFrame({
58
- "s_suppkey": np.arange(n_suppliers),
59
- "s_nationkey": rng.integers(0, n_nations, n_suppliers),
60
- "s_acctbal": rng.uniform(-999, 9999, n_suppliers).round(2),
61
- })
62
- suppliers["s_risk_flag"] = (suppliers["s_acctbal"] < 100).astype(int)
63
- n_parts = max(50, n_orders // 5)
64
- parts = pd.DataFrame({
65
- "p_partkey": np.arange(n_parts),
66
- "p_retailprice": rng.uniform(5, 2000, n_parts).round(2),
67
- })
68
- customers = pd.DataFrame({
69
- "c_custkey": np.arange(n_customers),
70
- "c_nationkey": rng.integers(0, n_nations, n_customers),
71
- "c_acctbal": rng.uniform(-999, 9999, n_customers).round(2),
72
- "c_account_age_days": rng.integers(1, 3650, n_customers),
73
- "c_num_prev_orders": rng.poisson(5, n_customers),
74
- })
75
- customer_keys = rng.integers(0, n_customers, n_orders)
76
- totalprice = rng.exponential(scale=5000, size=n_orders).round(2)
77
- cust_acctbal = customers.loc[customer_keys, "c_acctbal"].values
78
- cust_age = customers.loc[customer_keys, "c_account_age_days"].values
79
- fraud_score = (0.4*(cust_acctbal<0).astype(float)
80
- + 0.3*(totalprice>15000).astype(float)
81
- + 0.2*(cust_age<30).astype(float)
82
- + 0.1*rng.random(n_orders))
83
- threshold = np.quantile(fraud_score, 1-fraud_rate)
84
- orders = pd.DataFrame({
85
- "o_orderkey": np.arange(n_orders),
86
- "o_custkey": customer_keys,
87
- "o_totalprice": totalprice,
88
- "o_shippriority": rng.integers(0, 3, n_orders),
89
- "is_fraud": (fraud_score >= threshold).astype(int),
90
- })
91
- n_lines = rng.integers(1, 8, n_orders)
92
- total_lines = n_lines.sum()
93
- lineitem = pd.DataFrame({
94
- "l_orderkey": np.repeat(np.arange(n_orders), n_lines),
95
- "l_partkey": rng.integers(0, n_parts, total_lines),
96
- "l_suppkey": rng.integers(0, n_suppliers, total_lines),
97
- "l_quantity": rng.integers(1, 51, total_lines).astype(float),
98
- "l_extendedprice":rng.uniform(10, 5000, total_lines).round(2),
99
- "l_discount": rng.uniform(0, 0.1, total_lines).round(2),
100
- "l_tax": rng.uniform(0, 0.08, total_lines).round(2),
101
- })
102
- return dict(customers=customers, orders=orders, lineitem=lineitem,
103
- supplier=suppliers, nation=nations, part=parts)
104
 
105
  # ══════════════════════════════════════════════════════════════════════════════
106
- # DATA — ATOMIC ROUTES
107
  # ══════════════════════════════════════════════════════════════════════════════
108
 
109
- TPCH_FK = [
110
- ("orders", "o_custkey", "customers","c_custkey"),
111
- ("lineitem", "l_orderkey", "orders", "o_orderkey"),
112
- ("lineitem", "l_suppkey", "supplier", "s_suppkey"),
113
- ("lineitem", "l_partkey", "part", "p_partkey"),
114
- ("customers","c_nationkey", "nation", "n_nationkey"),
115
- ("supplier", "s_nationkey", "nation", "n_nationkey"),
116
- ]
117
 
118
  @dataclass
119
  class AtomicRoute:
@@ -123,27 +78,24 @@ class AtomicRoute:
123
  active: bool = True
124
  def __post_init__(self): self.n_hops = len(self.path)-1
125
 
126
- def discover_atomic_routes(tables, max_hops=3):
127
  adj = {}
128
- for (s,sc,d,dc) in TPCH_FK:
129
- adj.setdefault(s,[]).append((d,"fwd"))
130
- adj.setdefault(d,[]).append((s,"bwd"))
131
- routes, queue = [], deque()
132
- queue.append((["customers"], {"customers"}))
133
- while queue:
134
- path, visited = queue.popleft()
135
- if len(path)-1 >= 1:
136
  w = 1.0/((len(path)-1)**1.5)
137
- routes.append(AtomicRoute(path=list(path), attention_weight=w, active=(len(path)-1<=2)))
138
- if len(path)-1 >= max_hops:
139
- continue
140
- for (nb,_) in adj.get(path[-1],[]):
141
  if nb not in visited and nb in tables:
142
- queue.append((path+[nb], visited|{nb}))
143
- routes.sort(key=lambda r: -r.attention_weight)
144
- ws = np.array([r.attention_weight for r in routes])
145
- ws = np.exp(ws)/np.exp(ws).sum()
146
- for r,w in zip(routes,ws): r.attention_weight = float(w)
147
  return routes
148
 
149
  # ══════════════════════════════════════════════════════════════════════════════
@@ -151,347 +103,305 @@ def discover_atomic_routes(tables, max_hops=3):
151
  # ══════════════════════════════════════════════════════════════════════════════
152
 
153
  def extract_features(tables):
154
- customers, orders, lineitem = tables["customers"], tables["orders"], tables["lineitem"]
155
- supplier, nation = tables["supplier"], tables["nation"]
156
- n = len(customers)
157
- fraud_by_cust = orders.groupby("o_custkey")["is_fraud"].max()
158
- labels = customers["c_custkey"].map(fraud_by_cust).fillna(0).values.astype(float)
159
-
160
- def norm(arr):
161
- mn,mx = arr.min(0,keepdims=True),arr.max(0,keepdims=True)
162
- return (arr-mn)/np.where(mx-mn==0,1,mx-mn)
163
-
164
- # customers
165
- c_feat = norm(customers[["c_acctbal","c_nationkey","c_account_age_days","c_num_prev_orders"]].fillna(0).values.astype(np.float32))
166
-
167
- # orders agg per customer
168
- om = orders.groupby("o_custkey")[["o_totalprice","o_shippriority"]].mean()
169
- ox = orders.groupby("o_custkey")[["o_totalprice"]].max()
170
- oc = orders.groupby("o_custkey").size().rename("cnt")
171
- oa = customers[["c_custkey"]].set_index("c_custkey").join(om).join(ox,rsuffix="_max").join(oc).fillna(0)
172
- o_feat = norm(oa.values.astype(np.float32))
173
-
174
- # lineitem agg
175
- li = lineitem.merge(orders[["o_orderkey","o_custkey"]], on="o_orderkey", how="left")
176
- lm = li.groupby("o_custkey")[["l_quantity","l_extendedprice","l_discount","l_tax"]].mean()
177
- lc = li.groupby("o_custkey").size().rename("cnt")
178
- la = customers[["c_custkey"]].set_index("c_custkey").join(lm).join(lc).fillna(0)
179
- l_feat = norm(la.values.astype(np.float32))
180
-
181
- # supplier agg
182
- sw = li.merge(supplier, left_on="l_suppkey", right_on="s_suppkey", how="left")
183
- sm = sw.groupby("o_custkey")[["s_acctbal","s_risk_flag"]].mean()
184
- sa = customers[["c_custkey"]].set_index("c_custkey").join(sm).fillna(0)
185
- s_feat = norm(sa.values.astype(np.float32))
186
-
187
- # nation
188
- nj = customers[["c_custkey","c_nationkey"]].merge(nation,left_on="c_nationkey",right_on="n_nationkey",how="left")[["n_nationkey","n_regionkey"]].fillna(0)
189
- n_feat = norm(nj.values.astype(np.float32))
190
-
191
- return dict(customers=c_feat, orders=o_feat, lineitem=l_feat, supplier=s_feat, nation=n_feat), labels
192
 
193
  # ══════════════════════════════════════════════════════════════════════════════
194
  # RELGNN MODEL
195
  # ══════════════════════════════════════════════════════════════════════════════
196
 
197
  class TableEncoder(nn.Module):
198
- def __init__(self, in_dim, hidden):
199
  super().__init__()
200
- self.net = nn.Sequential(
201
- nn.Linear(in_dim, hidden*2), nn.LayerNorm(hidden*2), nn.ReLU(), nn.Dropout(0.2),
202
- nn.Linear(hidden*2, hidden), nn.LayerNorm(hidden), nn.ReLU())
203
  def forward(self,x): return self.net(x)
204
 
205
- class RouteAttention(nn.Module):
206
- def __init__(self, hidden, heads=4):
207
  super().__init__()
208
- self.attn = nn.MultiheadAttention(hidden, heads, dropout=0.1, batch_first=True)
209
- self.norm = nn.LayerNorm(hidden)
210
- self.mlp = nn.Sequential(nn.Linear(hidden,hidden*2),nn.ReLU(),nn.Dropout(0.1),nn.Linear(hidden*2,hidden))
211
- def forward(self, hops):
212
- out,alpha = self.attn(hops,hops,hops)
213
- out = self.norm(out+hops)
214
- return out[:,0,:] + self.mlp(out[:,0,:]), alpha
215
 
216
  class RelGNNModel(nn.Module):
217
- def __init__(self, feat_dims, hidden, routes):
218
  super().__init__()
219
- self.encoders = nn.ModuleDict({t: TableEncoder(d,hidden) for t,d in feat_dims.items()})
220
- self.route_attn= nn.ModuleList([RouteAttention(hidden) for _ in routes])
221
- self.route_w = nn.Parameter(torch.ones(len(routes)))
222
- self.head = nn.Sequential(nn.Linear(hidden,hidden//2),nn.ReLU(),nn.Dropout(0.2),nn.Linear(hidden//2,1))
223
- self.routes = routes
224
-
225
- def forward(self, feats):
226
- embs = {t: enc(feats[t]) for t,enc in self.encoders.items() if t in feats}
227
- route_embs = []
228
- for i,(route,attn) in enumerate(zip(self.routes, self.route_attn)):
229
- avail = [t for t in route.path if t in embs]
230
- if len(avail) < 2:
231
- route_embs.append(list(embs.values())[0])
232
- continue
233
- hops = torch.stack([embs[t] for t in avail], dim=1)
234
- re, _ = attn(hops)
235
- route_embs.append(re)
236
- stacked = torch.stack(route_embs, dim=1)
237
- w = F.softmax(self.route_w, dim=0)
238
- agg = (stacked * w.unsqueeze(0).unsqueeze(-1)).sum(1)
239
  return self.head(agg).squeeze(-1)
240
 
241
- def train_relgnn(tables, routes, hidden=64, epochs=50, log_fn=print, progress_fn=None):
242
- t0 = time.time()
243
- feats_np, labels = extract_features(tables)
244
- feat_dims = {k:v.shape[1] for k,v in feats_np.items()}
245
- model = RelGNNModel(feat_dims, hidden, routes)
246
- opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
247
- sched = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
248
- idx = np.arange(len(labels))
249
- idx_tr,idx_te = train_test_split(idx,test_size=0.2,random_state=42,
250
- stratify=(labels>0.5).astype(int))
251
- def tensors(idx_):
252
- return {k: torch.tensor(v[idx_],dtype=torch.float32) for k,v in feats_np.items()}
253
- y_tr = torch.tensor(labels[idx_tr], dtype=torch.float32)
254
- pw = torch.tensor([(y_tr==0).sum()/max((y_tr==1).sum(),1)])
255
- loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
256
- history, log_every = [], max(1,epochs//8)
257
  model.train()
258
  for ep in range(1,epochs+1):
259
- opt.zero_grad()
260
- logits = model(tensors(idx_tr))
261
- loss = loss_fn(logits, y_tr)
262
- loss.backward()
263
- nn.utils.clip_grad_norm_(model.parameters(),1.0)
264
- opt.step(); sched.step()
265
- if ep % log_every == 0 or ep == epochs:
266
  model.eval()
267
- with torch.no_grad():
268
- p = torch.sigmoid(model(tensors(idx_te))).numpy()
269
- try: auc = roc_auc_score(labels[idx_te], p)
270
- except: auc = 0.5
271
  history.append({"epoch":ep,"auc":auc})
272
- log_fn(f" RelGNN ep={ep:3d}/{epochs} loss={float(loss):.4f} auc={auc:.4f}")
273
  model.train()
274
- if progress_fn: progress_fn(0.30+0.38*(ep/epochs), desc=f"RelGNN Γ©poca {ep}/{epochs}")
275
  model.eval()
276
- with torch.no_grad():
277
- p = torch.sigmoid(model(tensors(idx_te))).numpy()
278
- pred = (p>0.5).astype(int); yt = labels[idx_te].astype(int)
279
- try:
280
- metrics = dict(auc=round(roc_auc_score(yt,p),4), f1=round(f1_score(yt,pred,zero_division=0),4),
281
- precision=round(precision_score(yt,pred,zero_division=0),4),
282
- recall=round(recall_score(yt,pred,zero_division=0),4),
283
- train_time=round(time.time()-t0,1))
284
- except: metrics = dict(auc=0.5,f1=0.5,precision=0.5,recall=0.5,train_time=round(time.time()-t0,1))
285
- w = F.softmax(model.route_w,dim=0).detach().numpy()
286
  for i,r in enumerate(routes):
287
  if i<len(w): r.attention_weight=float(w[i]); r.active=float(w[i])>0.15
288
- return metrics, history
289
 
290
  # ══════════════════════════════════════════════════════════════════════════════
291
  # GRAPHSAGE BASELINE
292
  # ══════════════════════════════════════════════════════════════════════════════
293
 
294
  class SAGEConv(nn.Module):
295
- def __init__(self,in_d,out_d):
296
  super().__init__()
297
- self.Ws=nn.Linear(in_d,out_d,bias=False); self.Wn=nn.Linear(in_d,out_d,bias=False)
298
- self.b=nn.Parameter(torch.zeros(out_d))
299
  def forward(self,h,adj): return F.relu(self.Ws(h)+self.Wn(torch.mm(adj,h))+self.b)
300
 
301
- class GraphSAGENet(nn.Module):
302
- def __init__(self,in_d,hid):
303
  super().__init__()
304
- self.c1=SAGEConv(in_d,hid); self.c2=SAGEConv(hid,hid)
305
- self.drop=nn.Dropout(0.2); self.head=nn.Linear(hid,1)
306
- def forward(self,h,adj):
307
- h=self.c1(h,adj); h=self.drop(h); h=self.c2(h,adj)
308
- return self.head(h).squeeze(-1)
309
-
310
- def train_graphsage(tables, hidden=64, epochs=50, log_fn=print):
311
- t0 = time.time()
312
- log_fn(" [GraphSAGE] Convertendo SQL β†’ grafo estΓ‘tico...")
313
- customers,orders = tables["customers"],tables["orders"]
314
- n_c,n_o = len(customers),len(orders)
315
- MAX_N = 2000
316
- n_c = min(n_c,MAX_N); n_o = min(n_o,MAX_N)
317
- cf = customers[["c_acctbal","c_nationkey","c_account_age_days","c_num_prev_orders"]].iloc[:n_c].fillna(0).values.astype(np.float32)
318
- of = orders[["o_totalprice","o_shippriority"]].iloc[:n_o].fillna(0).values.astype(np.float32)
319
- md = max(cf.shape[1],of.shape[1])
320
- def pad(a,t):
321
- if a.shape[1]<t: a=np.hstack([a,np.zeros((len(a),t-a.shape[1]),dtype=np.float32)])
322
- return a
323
- X = np.vstack([pad(cf,md),pad(of,md)])
324
- X = (X-X.mean(0))/np.where(X.std(0)==0,1,X.std(0))
325
- N = len(X)
326
- ck = orders["o_custkey"].values[:n_o]
327
- oi = np.arange(n_o)+n_c
328
- vm = ck<n_c
329
- src = np.concatenate([ck[vm],oi[vm]]); dst = np.concatenate([oi[vm],ck[vm]])
330
- adj = torch.zeros(N,N)
331
  for s,d in zip(src,dst):
332
  if s<N and d<N: adj[d,s]=1.0
333
- deg = adj.sum(1,keepdim=True).clamp(min=1); adj = adj/deg
334
- fraud_c = orders.groupby("o_custkey")["is_fraud"].max()
335
- labels = customers["c_custkey"].iloc[:n_c].map(fraud_c).fillna(0).values.astype(np.float32)
336
- Xt = torch.tensor(X,dtype=torch.float32)
337
- ci = np.arange(n_c)
338
- i_tr,i_te = train_test_split(ci,test_size=0.2,random_state=42,stratify=(labels>0.5).astype(int))
339
- y_tr = torch.tensor(labels[i_tr],dtype=torch.float32)
340
- pw = torch.tensor([(y_tr==0).sum()/max((y_tr==1).sum(),1)])
341
- model = GraphSAGENet(md,hidden)
342
- opt = optim.AdamW(model.parameters(),lr=1e-3)
343
- loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
344
- log_every = max(1,epochs//5)
345
- history = []
346
  model.train()
347
  for ep in range(1,epochs+1):
348
- opt.zero_grad()
349
- loss = loss_fn(model(Xt,adj)[i_tr],y_tr)
350
- loss.backward(); nn.utils.clip_grad_norm_(model.parameters(),1.0); opt.step()
351
- if ep%log_every==0 or ep==epochs:
352
  model.eval()
353
- with torch.no_grad(): p=torch.sigmoid(model(Xt,adj)[i_te]).numpy()
354
- try: auc=roc_auc_score(labels[i_te],p)
355
  except: auc=0.5
356
- history.append({"epoch":ep,"auc":auc})
357
- model.train()
358
  model.eval()
359
- with torch.no_grad(): p=torch.sigmoid(model(Xt,adj)[i_te]).numpy()
360
- pred=(p>0.5).astype(int); yt=labels[i_te].astype(int)
361
- try:
362
- m=dict(auc=round(roc_auc_score(yt,p),4),f1=round(f1_score(yt,pred,zero_division=0),4),
363
- precision=round(precision_score(yt,pred,zero_division=0),4),
364
- recall=round(recall_score(yt,pred,zero_division=0),4),
365
- train_time=round(time.time()-t0,1))
366
  except: m=dict(auc=0.5,f1=0.5,precision=0.5,recall=0.5,train_time=round(time.time()-t0,1))
367
- log_fn(f" [GraphSAGE] {N} nΓ³s, {len(src)} arestas. Tempo: {m['train_time']}s")
368
- return m, history
369
 
370
  # ══════════════════════════════════════════════════════════════════════════════
371
  # XGBOOST BASELINE
372
  # ══════════════════════════════════════════════════════════════════════════════
373
 
374
- def train_xgboost(tables, log_fn=print):
375
- t0 = time.time()
376
- customers,orders,lineitem,supplier = tables["customers"],tables["orders"],tables["lineitem"],tables["supplier"]
377
- f = customers[["c_custkey","c_acctbal","c_nationkey","c_account_age_days","c_num_prev_orders"]].copy()
378
- oa = orders.groupby("o_custkey").agg(ord_cnt=("o_orderkey","count"),
379
- ord_mean=("o_totalprice","mean"),ord_max=("o_totalprice","max")).reset_index().rename(columns={"o_custkey":"c_custkey"})
380
- f = f.merge(oa,on="c_custkey",how="left")
381
- li = lineitem.merge(orders[["o_orderkey","o_custkey"]],on="o_orderkey",how="left")
382
- la = li.groupby("o_custkey").agg(li_cnt=("l_quantity","count"),li_price=("l_extendedprice","mean"),
383
- li_disc=("l_discount","mean")).reset_index().rename(columns={"o_custkey":"c_custkey"})
384
- f = f.merge(la,on="c_custkey",how="left")
385
- sw = li.merge(supplier,left_on="l_suppkey",right_on="s_suppkey",how="left")
386
- sa = sw.groupby("o_custkey").agg(sup_risk=("s_risk_flag","sum"),sup_bal=("s_acctbal","mean")).reset_index().rename(columns={"o_custkey":"c_custkey"})
387
- f = f.merge(sa,on="c_custkey",how="left").drop(columns=["c_custkey"]).fillna(0)
388
- fraud_c = orders.groupby("o_custkey")["is_fraud"].max()
389
- y = customers["c_custkey"].map(fraud_c).fillna(0).values.astype(int)
390
- X = f.values.astype(np.float32)
391
- i_tr,i_te = train_test_split(np.arange(len(y)),test_size=0.2,random_state=42,stratify=y)
392
- model = GradientBoostingClassifier(n_estimators=80,max_depth=4,learning_rate=0.05,subsample=0.8,random_state=42)
393
- model.fit(X[i_tr],y[i_tr])
394
- p = model.predict_proba(X[i_te])[:,1]; pred=(p>0.5).astype(int)
395
- try:
396
- m=dict(auc=round(roc_auc_score(y[i_te],p),4),f1=round(f1_score(y[i_te],pred,zero_division=0),4),
397
- precision=round(precision_score(y[i_te],pred,zero_division=0),4),
398
- recall=round(recall_score(y[i_te],pred,zero_division=0),4),
399
- train_time=round(time.time()-t0,1))
400
  except: m=dict(auc=0.5,f1=0.5,precision=0.5,recall=0.5,train_time=round(time.time()-t0,1))
401
- log_fn(f" [XGBoost] features={X.shape[1]} Tempo={m['train_time']}s")
402
  return m
403
 
404
  # ══════════════════════════════════════════════════════════════════════════════
405
  # PIPELINE
406
  # ══════════════════════════════════════════════════════════════════════════════
407
 
408
- def run_pipeline(n_customers, n_orders, fraud_rate, hidden_dim, num_epochs, max_hops, progress=gr.Progress()):
409
- logs=[]; log=lambda m: logs.append(str(m))
410
-
411
- progress(0.05, desc="Gerando TPC-H...")
412
- tables = generate_tpch_data(int(n_customers),int(n_orders),float(fraud_rate)/100,seed=42)
413
- log(f"βœ… {int(n_customers)} clientes Β· {int(n_orders)} pedidos Β· {tables['orders']['is_fraud'].sum()} fraudes")
414
-
415
- progress(0.15, desc="Rotas atΓ΄micas...")
416
- routes = discover_atomic_routes(tables, max_hops=int(max_hops))
417
- log(f"βœ… {len(routes)} rotas descobertas")
418
- for r in routes: log(f" β†’ {' β†’ '.join(r.path)} (hops={r.n_hops} Ξ±={r.attention_weight:.3f})")
419
-
420
- progress(0.30, desc="Treinando RelGNN...")
421
- rm,rh = train_relgnn(tables,routes,int(hidden_dim),int(num_epochs),log,progress)
422
- log(f"βœ… RelGNN AUC={rm['auc']} F1={rm['f1']} {rm['train_time']}s")
423
-
424
- progress(0.70, desc="Treinando GraphSAGE...")
425
- gm,gh = train_graphsage(tables,int(hidden_dim),int(num_epochs),log)
426
- log(f"βœ… GraphSAGE AUC={gm['auc']} F1={gm['f1']} {gm['train_time']}s")
427
-
428
- progress(0.87, desc="Treinando XGBoost...")
429
- xm = train_xgboost(tables,log)
430
- log(f"βœ… XGBoost AUC={xm['auc']} F1={xm['f1']} {xm['train_time']}s")
431
-
432
- progress(0.95, desc="Plotando...")
433
- fig = build_figure(rm,gm,xm,rh,gh,routes)
434
-
435
- metrics_df = pd.DataFrame([
436
- {"Modelo":"πŸ”· RelGNN", "AUC":rm["auc"],"F1":rm["f1"],"PrecisΓ£o":rm["precision"],"Recall":rm["recall"],"Tempo(s)":rm["train_time"]},
437
- {"Modelo":"🟣 GraphSAGE","AUC":gm["auc"],"F1":gm["f1"],"Precisão":gm["precision"],"Recall":gm["recall"],"Tempo(s)":gm["train_time"]},
438
- {"Modelo":"🟑 XGBoost", "AUC":xm["auc"],"F1":xm["f1"],"Precisão":xm["precision"],"Recall":xm["recall"],"Tempo(s)":xm["train_time"]},
439
- ]).round(4)
440
-
441
- routes_df = pd.DataFrame([{"Rota":" β†’ ".join(r.path),"Hops":r.n_hops,
442
- "Peso Ξ±":round(r.attention_weight,4),"Ativa":"βœ…" if r.active else "β€”"} for r in routes])
443
 
444
- da=(rm["auc"]-gm["auc"])*100; dt=(1-rm["train_time"]/max(gm["train_time"],0.1))*100
445
- summary=(f"## 🎯 Resultado Final\n\n| |RelGNN|GraphSAGE|Ξ”|\n|---|---|---|---|\n"
446
- f"|AUC|**{rm['auc']}**|{gm['auc']}|**+{da:.1f}%**|\n"
447
- f"|F1|**{rm['f1']}**|{gm['f1']}|**+{(rm['f1']-gm['f1'])*100:.1f}%**|\n"
448
- f"|Tempo|**{rm['train_time']}s**|{gm['train_time']}s|**βˆ’{dt:.0f}%**|\n\n"
449
- f"πŸš€ {len(routes)} rotas atΓ΄micas Β· zero conversΓ£o para grafo Β· zero feature engineering")
450
 
451
- progress(1.0); log("🏁 Concluído!")
452
- return fig, metrics_df, routes_df, summary, "\n".join(logs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
 
454
  # ══════════════════════════════════════════════════════════════════════════════
455
  # PLOTLY FIGURE
456
  # ══════════════════════════════════════════════════════════════════════════════
457
 
458
  def build_figure(rm,gm,xm,rh,gh,routes):
459
- BG="#0a0e1a"; PANEL="#0f1629"; CYAN="#00d4ff"; PURP="#7c3aed"; AMBER="#f59e0b"; GREEN="#10b981"; GRAY="#64748b"
460
  specs=[[{"type":"xy"},{"type":"xy"},{"type":"xy"}],[{"type":"xy"},{"type":"xy"},{"type":"polar"}]]
461
  fig=make_subplots(rows=2,cols=3,specs=specs,vertical_spacing=0.22,horizontal_spacing=0.10,
462
- subplot_titles=["ConvergΓͺncia AUC-ROC","MΓ©tricas Comparativas","Tempo de Treino (s)",
463
- "Pesos de AtenΓ§Γ£o (Rotas)","Ξ” RelGNN vs GraphSAGE (%)","Radar de Performance"])
464
  fig.add_trace(go.Scatter(x=[h["epoch"] for h in rh],y=[h["auc"] for h in rh],name="RelGNN",
465
- line=dict(color=CYAN,width=3),fill="tozeroy",fillcolor="rgba(0,212,255,0.07)"),row=1,col=1)
466
  fig.add_trace(go.Scatter(x=[h["epoch"] for h in gh],y=[h["auc"] for h in gh],name="GraphSAGE",
467
- line=dict(color=PURP,width=2,dash="dash")),row=1,col=1)
468
- mn=["AUC","F1","PrecisΓ£o","Recall"]
469
- for vals,name,col in [([rm["auc"],rm["f1"],rm["precision"],rm["recall"]],"RelGNN",CYAN),
470
- ([gm["auc"],gm["f1"],gm["precision"],gm["recall"]],"GraphSAGE",PURP),
471
- ([xm["auc"],xm["f1"],xm["precision"],xm["recall"]],"XGBoost",AMBER)]:
472
  fig.add_trace(go.Bar(x=mn,y=vals,name=name,marker_color=col,opacity=0.85,showlegend=False),row=1,col=2)
473
  fig.add_trace(go.Bar(x=["RelGNN","GraphSAGE","XGBoost"],y=[rm["train_time"],gm["train_time"],xm["train_time"]],
474
- marker_color=[CYAN,PURP,AMBER],opacity=0.85,showlegend=False,
475
  text=[f"{v:.1f}s" for v in [rm["train_time"],gm["train_time"],xm["train_time"]]],textposition="outside"),row=1,col=3)
476
  rl=[" β†’ ".join(r.path[-2:]) if len(r.path)>2 else " β†’ ".join(r.path) for r in routes]
477
- rw=[r.attention_weight for r in routes]
478
- fig.add_trace(go.Bar(x=rw,y=rl,orientation="h",marker_color=[GREEN if r.active else GRAY for r in routes],
479
- opacity=0.85,showlegend=False,text=[f"Ξ±={w:.3f}" for w in rw],textposition="outside"),row=2,col=1)
480
  deltas=[(rm[k]-gm[k])*100 for k in ["auc","f1","precision","recall"]]
481
- fig.add_trace(go.Bar(x=mn,y=deltas,marker_color=[GREEN if d>=0 else "#ef4444" for d in deltas],
482
  opacity=0.85,showlegend=False,text=[f"+{d:.1f}%" if d>=0 else f"{d:.1f}%" for d in deltas],
483
  textposition="outside"),row=2,col=2)
484
- fig.add_hline(y=0,line_color=GRAY,line_width=1,row=2,col=2)
485
- cats=["AUC","F1","PrecisΓ£o","Recall","Velocidade"]
486
- mx_t=max(rm["train_time"],gm["train_time"],xm["train_time"])
487
- for vals,name,col in [([rm["auc"],rm["f1"],rm["precision"],rm["recall"],1-rm["train_time"]/mx_t],"RelGNN",CYAN),
488
- ([gm["auc"],gm["f1"],gm["precision"],gm["recall"],1-gm["train_time"]/mx_t],"GraphSAGE",PURP),
489
- ([xm["auc"],xm["f1"],xm["precision"],xm["recall"],1-xm["train_time"]/mx_t],"XGBoost",AMBER)]:
490
  fig.add_trace(go.Scatterpolar(r=vals+[vals[0]],theta=cats+[cats[0]],name=name,fill="toself",
491
  line_color=col,opacity=0.55,showlegend=False),row=2,col=3)
492
  fig.update_layout(height=680,paper_bgcolor=BG,plot_bgcolor=PANEL,barmode="group",
493
  font=dict(color="#e2e8f0",family="monospace",size=11),
494
- title=dict(text="RelGNN Β· TPC-H Fraud Detection",font=dict(size=14,color=CYAN),x=0.5),
495
  legend=dict(bgcolor="#141c33",bordercolor="#1e2d4a"))
496
  fig.update_xaxes(gridcolor="#1e2d4a"); fig.update_yaxes(gridcolor="#1e2d4a")
497
  fig.update_yaxes(range=[0.35,1.05],row=1,col=1); fig.update_yaxes(range=[0.35,1.05],row=1,col=2)
@@ -501,32 +411,30 @@ def build_figure(rm,gm,xm,rh,gh,routes):
501
  # GRADIO UI
502
  # ══════════════════════════════════════════════════════════════════════════════
503
 
504
- CSS = ".gradio-container{max-width:1100px!important} footer{display:none!important}"
505
-
506
- with gr.Blocks(css=CSS, title="RelGNN") as demo:
507
  gr.Markdown("# ⬑ RelGNN β€” Deep Relational Learning\n### Do SQL ao Graph AI sem Engenharia Manual Β· TPC-H Fraud Detection")
508
  with gr.Row():
509
- with gr.Column(scale=1, min_width=230):
510
  gr.Markdown("### βš™οΈ Dataset")
511
- n_customers = gr.Slider(100, 2000, value=500, step=100, label="NΒΊ Clientes")
512
- n_orders = gr.Slider(500, 10000, value=2000, step=500, label="NΒΊ Pedidos")
513
- fraud_rate = gr.Slider(1, 20, value=5, step=1, label="Fraude (%)")
514
  gr.Markdown("### 🧠 Modelo")
515
- hidden_dim = gr.Slider(16, 128, value=64, step=16, label="Hidden Dim")
516
- num_epochs = gr.Slider(10, 100, value=50, step=10, label="Γ‰pocas")
517
- max_hops = gr.Slider(1, 4, value=3, step=1, label="Max Hops")
518
- btn = gr.Button("πŸš€ Rodar Pipeline", variant="primary", size="lg")
519
  with gr.Column(scale=3):
520
  with gr.Tabs():
521
- with gr.Tab("πŸ“Š VisualizaΓ§Γ΅es"): plot_out = gr.Plot()
522
  with gr.Tab("πŸ“‹ MΓ©tricas"):
523
- metrics_out = gr.Dataframe(label="ComparaΓ§Γ£o de Modelos")
524
- routes_out = gr.Dataframe(label="Rotas AtΓ΄micas")
525
- with gr.Tab("πŸ“ Resumo"): summary_out = gr.Markdown()
526
- with gr.Tab("πŸ”§ Log"): log_out = gr.Textbox(lines=22, max_lines=35)
527
  btn.click(fn=run_pipeline,
528
  inputs=[n_customers,n_orders,fraud_rate,hidden_dim,num_epochs,max_hops],
529
  outputs=[plot_out,metrics_out,routes_out,summary_out,log_out])
530
 
531
- if __name__ == "__main__":
532
  demo.launch()
 
1
import importlib.util
import subprocess
import sys


def _ensure_installed(module_name, *pip_args):
    """Pip-install a package only when ``module_name`` cannot be imported.

    Args:
        module_name: Import name to probe (e.g. ``"sklearn"``).
        *pip_args: Arguments forwarded to ``pip install``; the first one is
            the distribution name (e.g. ``"scikit-learn"``).

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.
    """
    # Skip the (slow, network-bound) pip call when the module is already
    # available, instead of reinstalling on every start-up.
    if importlib.util.find_spec(module_name) is not None:
        return
    print(f"Installing {pip_args[0]}...", flush=True)
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", *pip_args,
         "-q", "--root-user-action=ignore"]
    )


# Bootstrap third-party dependencies that may be missing from the container.
# torch is pulled from the CPU-only wheel index.
_ensure_installed("torch", "torch", "--index-url", "https://download.pytorch.org/whl/cpu")
_ensure_installed("plotly", "plotly")
_ensure_installed("sklearn", "scikit-learn")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  import gradio as gr
7
  import pandas as pd
 
18
  from sklearn.ensemble import GradientBoostingClassifier
19
  from collections import deque
20
  from dataclasses import dataclass
21
+ from typing import List
22
  warnings.filterwarnings("ignore")
23
 
24
  # ══════════════════════════════════════════════════════════════════════════════
25
+ # TPC-H DATA GENERATOR
26
  # ══════════════════════════════════════════════════════════════════════════════
27
 
28
def generate_tpch_data(n_customers=500, n_orders=2000, fraud_rate=0.05, seed=42):
    """Generate synthetic TPC-H-style tables with a per-order fraud label.

    Args:
        n_customers: Number of customer rows (must be >= 1).
        n_orders: Number of order rows (must be >= 1).
        fraud_rate: Target fraction of orders flagged as fraud, in [0, 1].
        seed: Seed for ``numpy.random.default_rng``; the same seed yields the
            same tables.

    Returns:
        dict with the DataFrames ``customers``, ``orders``, ``lineitem``,
        ``supplier``, ``nation`` and ``part``.

    Raises:
        ValueError: If a size is below 1 or ``fraud_rate`` is outside [0, 1].
    """
    # Fail early with a clear message instead of an opaque NumPy error
    # (empty integer range, quantile of an empty array, quantile out of range).
    if n_customers < 1 or n_orders < 1:
        raise ValueError("n_customers and n_orders must both be >= 1")
    if not 0.0 <= fraud_rate <= 1.0:
        raise ValueError("fraud_rate must be a fraction in [0, 1]")

    rng = np.random.default_rng(seed)
    # NOTE: the order of the rng draws below determines the generated values;
    # keep it unchanged so a given seed keeps producing the same data.

    n_nat = 25
    nations = pd.DataFrame({
        "n_nationkey": np.arange(n_nat),
        "n_regionkey": rng.integers(0, 5, n_nat),
    })

    n_sup = max(10, n_customers // 20)
    suppliers = pd.DataFrame({
        "s_suppkey": np.arange(n_sup),
        "s_nationkey": rng.integers(0, n_nat, n_sup),
        "s_acctbal": rng.uniform(-999, 9999, n_sup).round(2),
    })
    # Suppliers with an account balance under 100 are flagged as risky.
    suppliers["s_risk_flag"] = (suppliers["s_acctbal"] < 100).astype(int)

    n_parts = max(50, n_orders // 5)
    parts = pd.DataFrame({
        "p_partkey": np.arange(n_parts),
        "p_retailprice": rng.uniform(5, 2000, n_parts).round(2),
    })

    customers = pd.DataFrame({
        "c_custkey": np.arange(n_customers),
        "c_nationkey": rng.integers(0, n_nat, n_customers),
        "c_acctbal": rng.uniform(-999, 9999, n_customers).round(2),
        "c_account_age_days": rng.integers(1, 3650, n_customers),
        "c_num_prev_orders": rng.poisson(5, n_customers),
    })

    ck = rng.integers(0, n_customers, n_orders)      # customer of each order
    tp = rng.exponential(5000, n_orders).round(2)    # order total price

    # Positional lookup of each order's customer attributes (customer keys
    # are 0..n_customers-1, i.e. row positions).
    cust_bal = customers["c_acctbal"].to_numpy()[ck]
    cust_age = customers["c_account_age_days"].to_numpy()[ck]

    # Heuristic fraud score: negative balance, very large order, very young
    # account, plus a random component.
    fscore = (0.4 * (cust_bal < 0).astype(float)
              + 0.3 * (tp > 15000).astype(float)
              + 0.2 * (cust_age < 30).astype(float)
              + 0.1 * rng.random(n_orders))

    orders = pd.DataFrame({
        "o_orderkey": np.arange(n_orders),
        "o_custkey": ck,
        "o_totalprice": tp,
        "o_shippriority": rng.integers(0, 3, n_orders),
        # Orders at or above the (1 - fraud_rate) score quantile are fraud.
        "is_fraud": (fscore >= np.quantile(fscore, 1 - fraud_rate)).astype(int),
    })

    nl = rng.integers(1, 8, n_orders)   # 1..7 line items per order
    tl = nl.sum()
    lineitem = pd.DataFrame({
        "l_orderkey": np.repeat(np.arange(n_orders), nl),
        "l_partkey": rng.integers(0, n_parts, tl),
        "l_suppkey": rng.integers(0, n_sup, tl),
        "l_quantity": rng.integers(1, 51, tl).astype(float),
        "l_extendedprice": rng.uniform(10, 5000, tl).round(2),
        "l_discount": rng.uniform(0, 0.1, tl).round(2),
        "l_tax": rng.uniform(0, 0.08, tl).round(2),
    })

    return dict(customers=customers, orders=orders, lineitem=lineitem,
                supplier=suppliers, nation=nations, part=parts)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  # ══════════════════════════════════════════════════════════════════════════════
63
+ # ATOMIC ROUTES
64
  # ══════════════════════════════════════════════════════════════════════════════
65
 
66
# Foreign-key links between the generated tables, one tuple per link:
# (source table, source column, destination table, destination column).
TPCH_FK = [
    ("orders", "o_custkey", "customers", "c_custkey"),
    ("lineitem", "l_orderkey", "orders", "o_orderkey"),
    ("lineitem", "l_suppkey", "supplier", "s_suppkey"),
    ("lineitem", "l_partkey", "part", "p_partkey"),
    ("customers", "c_nationkey", "nation", "n_nationkey"),
    ("supplier", "s_nationkey", "nation", "n_nationkey"),
]
 
 
72
 
73
  @dataclass
74
  class AtomicRoute:
 
78
  active: bool = True
79
  def __post_init__(self): self.n_hops = len(self.path)-1
80
 
81
def discover_routes(tables, max_hops=3):
    """Enumerate join paths that start at ``customers`` and follow TPCH_FK links.

    A breadth-first walk over the (undirected) foreign-key graph collects
    every simple path of 1..``max_hops`` hops whose tables are all present in
    ``tables``. Each path becomes an ``AtomicRoute`` whose initial weight is
    ``1 / hops**1.5``; the weights are then softmax-normalised across routes.

    Args:
        tables: Container of available table names (membership is tested
            with ``in``).
        max_hops: Maximum number of hops a route may have.

    Returns:
        List of ``AtomicRoute`` sorted by descending pre-softmax weight.
    """
    # Undirected adjacency between table names, built from the FK list.
    neighbours = {}
    for src, _src_col, dst, _dst_col in TPCH_FK:
        neighbours.setdefault(src, []).append(dst)
        neighbours.setdefault(dst, []).append(src)

    found = []
    frontier = deque([(["customers"], {"customers"})])
    while frontier:
        path, seen = frontier.popleft()
        hops = len(path) - 1
        if hops >= 1:
            found.append(
                AtomicRoute(
                    path=list(path),
                    attention_weight=1.0 / (hops ** 1.5),
                    active=(hops <= 2),  # only routes of up to 2 hops start active
                )
            )
        if hops >= max_hops:
            continue
        for table in neighbours.get(path[-1], []):
            # Never revisit a table on the same path, and skip tables
            # that were not supplied.
            if table in seen or table not in tables:
                continue
            frontier.append((path + [table], seen | {table}))

    found.sort(key=lambda route: -route.attention_weight)

    # Softmax over the raw weights so they sum to 1.
    soft = np.exp([route.attention_weight for route in found])
    soft /= soft.sum()
    for route, weight in zip(found, soft):
        route.attention_weight = float(weight)
    return found
100
 
101
  # ══════════════════════════════════════════════════════════════════════════════
 
103
  # ══════════════════════════════════════════════════════════════════════════════
104
 
105
def extract_features(tables):
    """Aggregate each table to one row per customer and min-max scale it.

    Args:
        tables: mapping holding the DataFrames ``customers``, ``orders``,
            ``lineitem``, ``supplier`` and ``nation``.

    Returns:
        ``(features, labels)``. ``features`` maps each of the five table
        names to a 2-D array whose rows follow the row order of
        ``customers``. ``labels`` is a float array holding, per customer,
        the maximum ``is_fraud`` over that customer's orders (0 when the
        customer has no orders).
    """
    C, O, L, S, N = (tables[k] for k in ("customers", "orders", "lineitem", "supplier", "nation"))

    fraud_c = O.groupby("o_custkey")["is_fraud"].max()
    labels = C["c_custkey"].map(fraud_c).fillna(0).values.astype(float)

    def norm(a):
        # Column-wise min-max scaling; constant columns divide by 1 instead of 0.
        mn, mx = a.min(0, keepdims=True), a.max(0, keepdims=True)
        span = mx - mn
        return (a - mn) / np.where(span == 0, 1, span)

    def by_customer():
        # Empty frame indexed by c_custkey, used as the left side of every join
        # so each aggregate ends up aligned with the customers table.
        return C[["c_custkey"]].set_index("c_custkey")

    customer_cols = ["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]
    c_f = norm(C[customer_cols].fillna(0).values.astype(np.float32))

    # Orders: mean price/priority, max price and order count per customer.
    om = O.groupby("o_custkey")[["o_totalprice", "o_shippriority"]].mean()
    ox = O.groupby("o_custkey")[["o_totalprice"]].max()
    oc = O.groupby("o_custkey").size().rename("cnt")
    oa = by_customer().join(om).join(ox, rsuffix="_mx").join(oc).fillna(0)
    o_f = norm(oa.values.astype(np.float32))

    # Line items reach the customer through their order. TPCH_FK names the
    # lineitem key "l_orderkey"; a frame that already carries "o_orderkey"
    # is still accepted so existing callers keep working.
    order_keys = O[["o_orderkey", "o_custkey"]]
    if "o_orderkey" in L.columns:
        li = L.merge(order_keys, on="o_orderkey", how="left")
    else:
        li = L.merge(order_keys, left_on="l_orderkey", right_on="o_orderkey", how="left")
    lm = li.groupby("o_custkey")[["l_quantity", "l_extendedprice", "l_discount", "l_tax"]].mean()
    lc = li.groupby("o_custkey").size().rename("cnt")
    la = by_customer().join(lm).join(lc).fillna(0)
    l_f = norm(la.values.astype(np.float32))

    # Suppliers: mean balance and risk flag over the suppliers of a customer's line items.
    sw = li.merge(S, left_on="l_suppkey", right_on="s_suppkey", how="left")
    sm = sw.groupby("o_custkey")[["s_acctbal", "s_risk_flag"]].mean()
    sa = by_customer().join(sm).fillna(0)
    s_f = norm(sa.values.astype(np.float32))

    # Nation: key and region of the customer's own nation.
    nj = C[["c_custkey", "c_nationkey"]].merge(
        N, left_on="c_nationkey", right_on="n_nationkey", how="left"
    )[["n_nationkey", "n_regionkey"]].fillna(0)
    n_f = norm(nj.values.astype(np.float32))

    return dict(customers=c_f, orders=o_f, lineitem=l_f, supplier=s_f, nation=n_f), labels
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
  # ══════════════════════════════════════════════════════════════════════════════
132
  # RELGNN MODEL
133
  # ══════════════════════════════════════════════════════════════════════════════
134
 
135
class TableEncoder(nn.Module):
    """MLP that maps a table's feature rows to ``hid``-dimensional embeddings.

    Layout: Linear(ind, 2*hid) -> LayerNorm -> ReLU -> Dropout(0.2)
    -> Linear(2*hid, hid) -> LayerNorm -> ReLU.
    """

    def __init__(self, ind, hid):
        super().__init__()
        wide = hid * 2
        layers = [
            nn.Linear(ind, wide),
            nn.LayerNorm(wide),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(wide, hid),
            nn.LayerNorm(hid),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (last dimension ``ind``) into ``hid`` features."""
        return self.net(x)
141
 
142
class RouteAttn(nn.Module):
    """Self-attention over the hop embeddings of one route.

    The output is taken at hop position 0 after a residual LayerNorm,
    followed by a residual feed-forward block.
    """

    def __init__(self, hid):
        super().__init__()
        # One head per 16 hidden units, limited to the range 1..4.
        n_heads = max(1, min(4, hid // 16))
        self.attn = nn.MultiheadAttention(hid, n_heads, dropout=0.1, batch_first=True)
        self.norm = nn.LayerNorm(hid)
        self.mlp = nn.Sequential(
            nn.Linear(hid, hid * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hid * 2, hid),
        )

    def forward(self, hops):
        """Reduce ``hops`` of shape (batch, n_hops, hid) to (batch, hid)."""
        attended, _ = self.attn(hops, hops, hops)
        mixed = self.norm(attended + hops)
        anchor = mixed[:, 0, :]
        return anchor + self.mlp(anchor)
152
 
153
class RelGNNModel(nn.Module):
    """Route-based relational model: per-table encoders, per-route attention.

    Each table gets a TableEncoder and each route a RouteAttn. Route outputs
    are mixed with softmax-normalised learnable weights ``rw`` and mapped to
    one logit per row by ``head``.
    """

    def __init__(self, fdims, hid, routes):
        super().__init__()
        self.encs = nn.ModuleDict({table: TableEncoder(dim, hid) for table, dim in fdims.items()})
        self.rattn = nn.ModuleList([RouteAttn(hid) for _ in routes])
        self.rw = nn.Parameter(torch.ones(len(routes)))
        self.head = nn.Sequential(
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hid // 2, 1),
        )
        self.routes = routes

    def forward(self, feats):
        """Return one logit per row of the tensors in ``feats``."""
        embs = {}
        for table, encoder in self.encs.items():
            if table in feats:
                embs[table] = encoder(feats[table])

        per_route = []
        for route, attn in zip(self.routes, self.rattn):
            present = [table for table in route.path if table in embs]
            if len(present) >= 2:
                hops = torch.stack([embs[table] for table in present], dim=1)
                per_route.append(attn(hops))
            else:
                # Fewer than two encoded tables on this route: use the
                # first available embedding unchanged.
                per_route.append(list(embs.values())[0])

        stacked = torch.stack(per_route, dim=1)
        weights = F.softmax(self.rw, dim=0)
        pooled = (stacked * weights.view(1, -1, 1)).sum(1)
        return self.head(pooled).squeeze(-1)
173
 
174
def train_relgnn(tables, routes, hidden=64, epochs=50, log_fn=print, pfn=None):
    """Train RelGNNModel full-batch on customer-level fraud labels.

    Args:
        tables: table mapping accepted by ``extract_features``.
        routes: route objects; their ``attention_weight`` and ``active``
            attributes are overwritten in place after training.
        hidden: hidden width passed to the model.
        epochs: number of full-batch optimisation steps.
        log_fn: callable that receives progress strings.
        pfn: optional progress callback invoked as ``pfn(fraction, desc=...)``.

    Returns:
        ``(metrics, history)``. ``metrics`` has keys ``auc``, ``f1``,
        ``precision``, ``recall`` and ``train_time``. ``history`` is a list
        of ``{"epoch", "auc"}`` dicts measured on the held-out split.
    """
    t0 = time.time()
    fn, labels = extract_features(tables)
    fdims = {k: v.shape[1] for k, v in fn.items()}
    model = RelGNNModel(fdims, hidden, routes)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

    idx = np.arange(len(labels))
    itr, ite = train_test_split(idx, test_size=0.2, random_state=42,
                                stratify=(labels > 0.5).astype(int))

    def batch(ix):
        return {k: torch.tensor(v[ix], dtype=torch.float32) for k, v in fn.items()}

    ytr = torch.tensor(labels[itr], dtype=torch.float32)
    # Weight positives by the negative/positive ratio to counter class imbalance.
    n_neg = float((ytr == 0).sum())
    n_pos = float((ytr == 1).sum())
    pw = torch.tensor([n_neg / max(n_pos, 1.0)])
    lossfn = nn.BCEWithLogitsLoss(pos_weight=pw)

    history = []
    logi = max(1, epochs // 8)
    model.train()
    for ep in range(1, epochs + 1):
        opt.zero_grad()
        l = lossfn(model(batch(itr)), ytr)
        l.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sch.step()
        if ep % logi == 0 or ep == epochs:
            model.eval()
            with torch.no_grad():
                p = torch.sigmoid(model(batch(ite))).numpy()
            try:
                auc = roc_auc_score(labels[ite], p)
            except ValueError:
                # roc_auc_score rejects splits containing a single class.
                auc = 0.5
            history.append({"epoch": ep, "auc": auc})
            log_fn(f"  RelGNN ep={ep}/{epochs} loss={float(l):.4f} auc={auc:.4f}")
            model.train()
        if pfn:
            pfn(0.30 + 0.38 * (ep / epochs), desc=f"RelGNN {ep}/{epochs}")

    model.eval()
    with torch.no_grad():
        p = torch.sigmoid(model(batch(ite))).numpy()
    pred = (p > 0.5).astype(int)
    yt = labels[ite].astype(int)
    try:
        m = dict(auc=round(roc_auc_score(yt, p), 4),
                 f1=round(f1_score(yt, pred, zero_division=0), 4),
                 precision=round(precision_score(yt, pred, zero_division=0), 4),
                 recall=round(recall_score(yt, pred, zero_division=0), 4),
                 train_time=round(time.time() - t0, 1))
    except ValueError as err:
        # Keep the historical placeholder metrics, but say so instead of hiding it.
        log_fn(f"  RelGNN metrics unavailable ({err}); using 0.5 placeholders")
        m = dict(auc=0.5, f1=0.5, precision=0.5, recall=0.5, train_time=round(time.time() - t0, 1))

    # Publish the learned route mixture back onto the route objects.
    w = F.softmax(model.rw, dim=0).detach().numpy()
    for r, wi in zip(routes, w):
        r.attention_weight = float(wi)
        r.active = float(wi) > 0.15
    return m, history
212
 
213
  # ══════════════════════════════════════════════════════════════════════════════
214
  # GRAPHSAGE BASELINE
215
  # ══════════════════════════════════════════════════════════════════════════════
216
 
217
class SAGEConv(nn.Module):
    """Dense GraphSAGE-style layer: ReLU(Ws·h + Wn·(adj @ h) + b)."""

    def __init__(self, i, o):
        super().__init__()
        self.Ws = nn.Linear(i, o, bias=False)  # applied to the node's own features
        self.Wn = nn.Linear(i, o, bias=False)  # applied to the adj-aggregated features
        self.b = nn.Parameter(torch.zeros(o))

    def forward(self, h, adj):
        """Apply the layer to node features ``h`` with dense adjacency ``adj``."""
        own = self.Ws(h)
        neigh = self.Wn(torch.mm(adj, h))
        return F.relu(own + neigh + self.b)
222
 
223
class GSNet(nn.Module):
    """Two SAGEConv layers with dropout in between and a linear logit head."""

    def __init__(self, i, h):
        super().__init__()
        self.c1 = SAGEConv(i, h)
        self.c2 = SAGEConv(h, h)
        self.d = nn.Dropout(0.2)
        self.head = nn.Linear(h, 1)

    def forward(self, h, adj):
        """Return one logit per node."""
        hidden = self.d(self.c1(h, adj))
        hidden = self.c2(hidden, adj)
        return self.head(hidden).squeeze(-1)
228
+
229
def train_graphsage(tables, hidden=64, epochs=50, log_fn=print):
    """Train the GSNet baseline on a customer/order graph.

    At most 1500 customers and 1500 orders are used. Customer rows come
    first in the node matrix, followed by order rows. Each order is linked
    to its customer in both directions and the adjacency is row-normalised.

    Args:
        tables: mapping holding the ``customers`` and ``orders`` DataFrames.
        hidden: hidden width of the network.
        epochs: number of full-batch optimisation steps.
        log_fn: callable that receives the final summary string.

    Returns:
        ``(metrics, history)`` with the same layout as ``train_relgnn``.
    """
    t0 = time.time()
    C, O = tables["customers"], tables["orders"]
    MX = 1500
    nc = min(len(C), MX)
    no = min(len(O), MX)
    cf = C[["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]].iloc[:nc].fillna(0).values.astype(np.float32)
    of = O[["o_totalprice", "o_shippriority"]].iloc[:no].fillna(0).values.astype(np.float32)
    md = max(cf.shape[1], of.shape[1])

    def pad(a, t):
        # Right-pad with zero columns so both node types share one width.
        if a.shape[1] < t:
            return np.hstack([a, np.zeros((len(a), t - a.shape[1]), dtype=np.float32)])
        return a

    X = np.vstack([pad(cf, md), pad(of, md)])
    N = len(X)
    std = X.std(0)
    X = (X - X.mean(0)) / np.where(std == 0, 1, std)

    # NOTE(review): o_custkey is used directly as the customer's row position;
    # this presumes c_custkey equals the row index - confirm against the generator.
    ck = O["o_custkey"].values[:no]
    oi = np.arange(no) + nc
    vm = ck < nc
    src = np.concatenate([ck[vm], oi[vm]])
    dst = np.concatenate([oi[vm], ck[vm]])

    adj = torch.zeros(N, N)
    keep = (src < N) & (dst < N)
    # One indexed assignment instead of a Python loop over every edge.
    adj[torch.as_tensor(dst[keep], dtype=torch.long),
        torch.as_tensor(src[keep], dtype=torch.long)] = 1.0
    adj = adj / adj.sum(1, keepdim=True).clamp(min=1)

    fc = O.groupby("o_custkey")["is_fraud"].max()
    labels = C["c_custkey"].iloc[:nc].map(fc).fillna(0).values.astype(np.float32)
    Xt = torch.tensor(X, dtype=torch.float32)
    ci = np.arange(nc)
    itr, ite = train_test_split(ci, test_size=0.2, random_state=42,
                                stratify=(labels > 0.5).astype(int))
    ytr = torch.tensor(labels[itr], dtype=torch.float32)
    n_neg = float((ytr == 0).sum())
    n_pos = float((ytr == 1).sum())
    pw = torch.tensor([n_neg / max(n_pos, 1.0)])

    model = GSNet(md, hidden)
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    lossfn = nn.BCEWithLogitsLoss(pos_weight=pw)
    logi = max(1, epochs // 5)
    history = []

    model.train()
    for ep in range(1, epochs + 1):
        opt.zero_grad()
        l = lossfn(model(Xt, adj)[itr], ytr)
        l.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if ep % logi == 0 or ep == epochs:
            model.eval()
            with torch.no_grad():
                p = torch.sigmoid(model(Xt, adj)[ite]).numpy()
            try:
                auc = roc_auc_score(labels[ite], p)
            except ValueError:
                # roc_auc_score rejects splits containing a single class.
                auc = 0.5
            history.append({"epoch": ep, "auc": auc})
            model.train()

    model.eval()
    with torch.no_grad():
        p = torch.sigmoid(model(Xt, adj)[ite]).numpy()
    pred = (p > 0.5).astype(int)
    yt = labels[ite].astype(int)
    try:
        m = dict(auc=round(roc_auc_score(yt, p), 4),
                 f1=round(f1_score(yt, pred, zero_division=0), 4),
                 precision=round(precision_score(yt, pred, zero_division=0), 4),
                 recall=round(recall_score(yt, pred, zero_division=0), 4),
                 train_time=round(time.time() - t0, 1))
    except ValueError as err:
        # Keep the historical placeholder metrics, but say so instead of hiding it.
        log_fn(f"  [GraphSAGE] metrics unavailable ({err}); using 0.5 placeholders")
        m = dict(auc=0.5, f1=0.5, precision=0.5, recall=0.5, train_time=round(time.time() - t0, 1))
    log_fn(f"  [GraphSAGE] {N} nós · {len(src)} arestas · {m['train_time']}s")
    return m, history
271
 
272
  # ══════════════════════════════════════════════════════════════════════════════
273
  # XGBOOST BASELINE
274
  # ══════════════════════════════════════════════════════════════════════════════
275
 
276
def train_xgboost(tables, log_fn=print):
    """Train the tabular baseline on hand-aggregated per-customer features.

    Despite the name, the estimator is scikit-learn's
    ``GradientBoostingClassifier``; the xgboost package is not used.

    Args:
        tables: mapping holding the ``customers``, ``orders``, ``lineitem``
            and ``supplier`` DataFrames.
        log_fn: callable that receives the final summary string.

    Returns:
        Metrics dict with keys ``auc``, ``f1``, ``precision``, ``recall``
        and ``train_time``.
    """
    t0 = time.time()
    C, O, L, S = tables["customers"], tables["orders"], tables["lineitem"], tables["supplier"]

    def as_customer_frame(agg):
        # Turn an o_custkey-indexed aggregate into a frame keyed by c_custkey.
        return agg.reset_index().rename(columns={"o_custkey": "c_custkey"})

    f = C[["c_custkey", "c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]].copy()

    oa = as_customer_frame(O.groupby("o_custkey").agg(
        oc=("o_orderkey", "count"), om=("o_totalprice", "mean"), ox=("o_totalprice", "max")))
    f = f.merge(oa, on="c_custkey", how="left")

    # TPCH_FK names the lineitem key "l_orderkey"; a frame that already
    # carries "o_orderkey" is still accepted so existing callers keep working.
    order_keys = O[["o_orderkey", "o_custkey"]]
    if "o_orderkey" in L.columns:
        li = L.merge(order_keys, on="o_orderkey", how="left")
    else:
        li = L.merge(order_keys, left_on="l_orderkey", right_on="o_orderkey", how="left")
    la = as_customer_frame(li.groupby("o_custkey").agg(
        lc=("l_quantity", "count"), lp=("l_extendedprice", "mean"), ld=("l_discount", "mean")))
    f = f.merge(la, on="c_custkey", how="left")

    sw = li.merge(S, left_on="l_suppkey", right_on="s_suppkey", how="left")
    sa = as_customer_frame(sw.groupby("o_custkey").agg(
        sr=("s_risk_flag", "sum"), sb=("s_acctbal", "mean")))
    f = f.merge(sa, on="c_custkey", how="left").drop(columns=["c_custkey"]).fillna(0)

    fc = O.groupby("o_custkey")["is_fraud"].max()
    y = C["c_custkey"].map(fc).fillna(0).values.astype(int)
    X = f.values.astype(np.float32)
    itr, ite = train_test_split(np.arange(len(y)), test_size=0.2, random_state=42, stratify=y)

    model = GradientBoostingClassifier(n_estimators=80, max_depth=4, learning_rate=0.05,
                                       subsample=0.8, random_state=42)
    model.fit(X[itr], y[itr])
    p = model.predict_proba(X[ite])[:, 1]
    pred = (p > 0.5).astype(int)
    try:
        m = dict(auc=round(roc_auc_score(y[ite], p), 4),
                 f1=round(f1_score(y[ite], pred, zero_division=0), 4),
                 precision=round(precision_score(y[ite], pred, zero_division=0), 4),
                 recall=round(recall_score(y[ite], pred, zero_division=0), 4),
                 train_time=round(time.time() - t0, 1))
    except ValueError as err:
        # Keep the historical placeholder metrics, but say so instead of hiding it.
        log_fn(f"  [XGBoost] metrics unavailable ({err}); using 0.5 placeholders")
        m = dict(auc=0.5, f1=0.5, precision=0.5, recall=0.5, train_time=round(time.time() - t0, 1))
    log_fn(f"  [XGBoost] {X.shape[1]} features · {m['train_time']}s")
    return m
299
 
300
  # ══════════════════════════════════════════════════════════════════════════════
301
  # PIPELINE
302
  # ══════════════════════════════════════════════════════════════════════════════
303
 
304
def _empty_fig(msg="Erro"):
    """Return a blank dark-themed Plotly figure showing ``msg`` centred in red."""
    placeholder = go.Figure()
    text_style = dict(size=16, color="#ef4444")
    placeholder.add_annotation(
        text=msg,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=text_style,
    )
    placeholder.update_layout(paper_bgcolor="#0a0e1a", plot_bgcolor="#0f1629")
    return placeholder
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
 
308
+ def _empty_df(*cols): return pd.DataFrame({c:[] for c in cols})
 
 
 
 
 
309
 
310
def run_pipeline(n_customers, n_orders, fraud_rate, hidden_dim, num_epochs, max_hops, progress=gr.Progress()):
    """Gradio callback: generate data, train the three models, build the outputs.

    Args:
        n_customers, n_orders: dataset sizes (cast to int).
        fraud_rate: fraud percentage; divided by 100 before use.
        hidden_dim, num_epochs: shared by the RelGNN and GraphSAGE trainers.
        max_hops: maximum route length passed to ``discover_routes``.
        progress: Gradio progress tracker.

    Returns:
        5-tuple ``(figure, metrics_df, routes_df, summary_markdown, log_text)``.
        If any stage raises, the tuple holds an error figure, two empty
        frames, the traceback as markdown and the log collected so far.
    """
    import traceback

    logs = []

    def log(m):
        logs.append(str(m))

    def fail(e):
        # Must be called from inside an ``except`` block so format_exc()
        # sees the active exception.
        tb = traceback.format_exc()
        log(f"❌ ERRO: {e}")
        log(tb)
        return (_empty_fig(f"Erro: {e}"),
                _empty_df("Modelo", "AUC", "F1", "Precisão", "Recall", "Tempo(s)"),
                _empty_df("Rota", "Hops", "Peso α", "Ativa"),
                f"## ❌ Erro\n```\n{tb}\n```",
                "\n".join(logs))

    def metric_row(label, m):
        # One spelling of every column key, so the three rows share columns.
        return {"Modelo": label, "AUC": m["auc"], "F1": m["f1"], "Precisão": m["precision"],
                "Recall": m["recall"], "Tempo(s)": m["train_time"]}

    # UI boundary: every stage failure is reported through fail(), so a
    # single handler covers the whole pipeline.
    try:
        progress(0.05, desc="Gerando TPC-H...")
        tables = generate_tpch_data(int(n_customers), int(n_orders), float(fraud_rate) / 100, seed=42)
        log(f"✅ {int(n_customers)} clientes · {int(n_orders)} pedidos · {tables['orders']['is_fraud'].sum()} fraudes")

        progress(0.15, desc="Rotas atômicas...")
        routes = discover_routes(tables, max_hops=int(max_hops))
        log(f"✅ {len(routes)} rotas atômicas")
        for r in routes:
            log(f"   → {' → '.join(r.path)}  (hops={r.n_hops} α={r.attention_weight:.3f})")

        progress(0.30, desc="Treinando RelGNN...")
        rm, rh = train_relgnn(tables, routes, int(hidden_dim), int(num_epochs), log, progress)
        log(f"✅ RelGNN  AUC={rm['auc']}  F1={rm['f1']}  {rm['train_time']}s")

        progress(0.70, desc="Treinando GraphSAGE...")
        gm, gh = train_graphsage(tables, int(hidden_dim), int(num_epochs), log)
        log(f"✅ GraphSAGE  AUC={gm['auc']}  F1={gm['f1']}  {gm['train_time']}s")

        progress(0.87, desc="Treinando XGBoost...")
        xm = train_xgboost(tables, log)
        log(f"✅ XGBoost  AUC={xm['auc']}  F1={xm['f1']}  {xm['train_time']}s")

        progress(0.95, desc="Plotando...")
        fig = build_figure(rm, gm, xm, rh, gh, routes)

        mdf = pd.DataFrame([metric_row("🔷 RelGNN", rm),
                            metric_row("🟣 GraphSAGE", gm),
                            metric_row("🟡 XGBoost", xm)]).round(4)
        rdf = pd.DataFrame([{"Rota": " → ".join(r.path), "Hops": r.n_hops,
                             "Peso α": round(r.attention_weight, 4),
                             "Ativa": "✅" if r.active else "—"} for r in routes])

        # Signed deltas: the format spec supplies the sign, so a regression
        # prints as "-1.2%" rather than "+-1.2%".
        d_auc = (rm["auc"] - gm["auc"]) * 100
        d_f1 = (rm["f1"] - gm["f1"]) * 100
        d_time = (rm["train_time"] / max(gm["train_time"], 0.1) - 1) * 100
        summary = (f"## 🎯 Resultado Final\n\n||RelGNN|GraphSAGE|Δ|\n|---|---|---|---|\n"
                   f"|AUC|**{rm['auc']}**|{gm['auc']}|**{d_auc:+.1f}%**|\n"
                   f"|F1|**{rm['f1']}**|{gm['f1']}|**{d_f1:+.1f}%**|\n"
                   f"|Tempo|**{rm['train_time']}s**|{gm['train_time']}s|**{d_time:+.0f}%**|\n\n"
                   f"🚀 {len(routes)} rotas · zero grafo estático · zero feature engineering")
        progress(1.0)
        log("🏁 Concluído!")
        return fig, mdf, rdf, summary, "\n".join(logs)
    except Exception as e:
        return fail(e)
364
 
365
  # ══════════════════════════════════════════════════════════════════════════════
366
  # PLOTLY FIGURE
367
  # ══════════════════════════════════════════════════════════════════════════════
368
 
369
  def build_figure(rm,gm,xm,rh,gh,routes):
370
+ BG="#0a0e1a"; PANEL="#0f1629"; C="#00d4ff"; P="#7c3aed"; A="#f59e0b"; G="#10b981"; GR="#64748b"
371
  specs=[[{"type":"xy"},{"type":"xy"},{"type":"xy"}],[{"type":"xy"},{"type":"xy"},{"type":"polar"}]]
372
  fig=make_subplots(rows=2,cols=3,specs=specs,vertical_spacing=0.22,horizontal_spacing=0.10,
373
+ subplot_titles=["ConvergΓͺncia AUC-ROC","MΓ©tricas","Tempo Treino(s)","Pesos AtenΓ§Γ£o","Ξ” vs GraphSAGE(%)","Radar"])
 
374
  fig.add_trace(go.Scatter(x=[h["epoch"] for h in rh],y=[h["auc"] for h in rh],name="RelGNN",
375
+ line=dict(color=C,width=3),fill="tozeroy",fillcolor="rgba(0,212,255,0.07)"),row=1,col=1)
376
  fig.add_trace(go.Scatter(x=[h["epoch"] for h in gh],y=[h["auc"] for h in gh],name="GraphSAGE",
377
+ line=dict(color=P,width=2,dash="dash")),row=1,col=1)
378
+ mn=["AUC","F1","Prec","Rec"]
379
+ for vals,name,col in [([rm["auc"],rm["f1"],rm["precision"],rm["recall"]],"RelGNN",C),
380
+ ([gm["auc"],gm["f1"],gm["precision"],gm["recall"]],"GraphSAGE",P),
381
+ ([xm["auc"],xm["f1"],xm["precision"],xm["recall"]],"XGBoost",A)]:
382
  fig.add_trace(go.Bar(x=mn,y=vals,name=name,marker_color=col,opacity=0.85,showlegend=False),row=1,col=2)
383
  fig.add_trace(go.Bar(x=["RelGNN","GraphSAGE","XGBoost"],y=[rm["train_time"],gm["train_time"],xm["train_time"]],
384
+ marker_color=[C,P,A],opacity=0.85,showlegend=False,
385
  text=[f"{v:.1f}s" for v in [rm["train_time"],gm["train_time"],xm["train_time"]]],textposition="outside"),row=1,col=3)
386
  rl=[" β†’ ".join(r.path[-2:]) if len(r.path)>2 else " β†’ ".join(r.path) for r in routes]
387
+ fig.add_trace(go.Bar(x=[r.attention_weight for r in routes],y=rl,orientation="h",
388
+ marker_color=[G if r.active else GR for r in routes],opacity=0.85,showlegend=False,
389
+ text=[f"Ξ±={r.attention_weight:.3f}" for r in routes],textposition="outside"),row=2,col=1)
390
  deltas=[(rm[k]-gm[k])*100 for k in ["auc","f1","precision","recall"]]
391
+ fig.add_trace(go.Bar(x=mn,y=deltas,marker_color=[G if d>=0 else "#ef4444" for d in deltas],
392
  opacity=0.85,showlegend=False,text=[f"+{d:.1f}%" if d>=0 else f"{d:.1f}%" for d in deltas],
393
  textposition="outside"),row=2,col=2)
394
+ fig.add_hline(y=0,line_color=GR,line_width=1,row=2,col=2)
395
+ cats=["AUC","F1","Prec","Rec","Speed"]
396
+ mt=max(rm["train_time"],gm["train_time"],xm["train_time"])
397
+ for vals,name,col in [([rm["auc"],rm["f1"],rm["precision"],rm["recall"],1-rm["train_time"]/mt],"RelGNN",C),
398
+ ([gm["auc"],gm["f1"],gm["precision"],gm["recall"],1-gm["train_time"]/mt],"GraphSAGE",P),
399
+ ([xm["auc"],xm["f1"],xm["precision"],xm["recall"],1-xm["train_time"]/mt],"XGBoost",A)]:
400
  fig.add_trace(go.Scatterpolar(r=vals+[vals[0]],theta=cats+[cats[0]],name=name,fill="toself",
401
  line_color=col,opacity=0.55,showlegend=False),row=2,col=3)
402
  fig.update_layout(height=680,paper_bgcolor=BG,plot_bgcolor=PANEL,barmode="group",
403
  font=dict(color="#e2e8f0",family="monospace",size=11),
404
+ title=dict(text="RelGNN Β· TPC-H Fraud Detection",font=dict(size=14,color=C),x=0.5),
405
  legend=dict(bgcolor="#141c33",bordercolor="#1e2d4a"))
406
  fig.update_xaxes(gridcolor="#1e2d4a"); fig.update_yaxes(gridcolor="#1e2d4a")
407
  fig.update_yaxes(range=[0.35,1.05],row=1,col=1); fig.update_yaxes(range=[0.35,1.05],row=1,col=2)
 
411
  # GRADIO UI
412
  # ══════════════════════════════════════════════════════════════════════════════
413
 
414
# Page-level CSS: cap the container width and hide the Gradio footer.
CSS = ".gradio-container{max-width:1100px!important} footer{display:none!important}"

# Layout: a narrow column of controls on the left, tabbed results on the right.
with gr.Blocks(css=CSS, title="RelGNN") as demo:
    gr.Markdown("# ⬡ RelGNN — Deep Relational Learning\n### Do SQL ao Graph AI sem Engenharia Manual · TPC-H Fraud Detection")
    with gr.Row():
        with gr.Column(scale=1, min_width=230):
            gr.Markdown("### ⚙️ Dataset")
            n_customers = gr.Slider(100, 2000, value=500, step=100, label="Nº Clientes")
            n_orders = gr.Slider(500, 10000, value=2000, step=500, label="Nº Pedidos")
            fraud_rate = gr.Slider(1, 20, value=5, step=1, label="Fraude (%)")
            gr.Markdown("### 🧠 Modelo")
            hidden_dim = gr.Slider(16, 128, value=64, step=16, label="Hidden Dim")
            num_epochs = gr.Slider(10, 100, value=50, step=10, label="Épocas")
            max_hops = gr.Slider(1, 4, value=3, step=1, label="Max Hops")
            btn = gr.Button("🚀 Rodar Pipeline", variant="primary", size="lg")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("📊 Visualizações"):
                    plot_out = gr.Plot()
                with gr.Tab("📋 Métricas"):
                    metrics_out = gr.Dataframe(label="Comparação")
                    routes_out = gr.Dataframe(label="Rotas Atômicas")
                with gr.Tab("📝 Resumo"):
                    summary_out = gr.Markdown()
                with gr.Tab("🔧 Log"):
                    log_out = gr.Textbox(lines=22, max_lines=35)
    # The output order must match the 5-tuple returned by run_pipeline.
    btn.click(fn=run_pipeline,
              inputs=[n_customers, n_orders, fraud_rate, hidden_dim, num_epochs, max_hops],
              outputs=[plot_out, metrics_out, routes_out, summary_out, log_out])

if __name__ == "__main__":
    demo.launch()