Danielfonseca1212's picture
Update app.py
449aa64 verified
import subprocess, sys
# Runtime dependency bootstrap: these pip installs run every time the module is
# imported, before the third-party imports below.
# NOTE(review): presumably needed because the hosting environment does not
# preinstall these packages — confirm; a requirements file would avoid the
# import-time side effect.
# torch is pulled from the CPU-only wheel index.
subprocess.check_call([sys.executable,"-m","pip","install","torch","--index-url","https://download.pytorch.org/whl/cpu","-q","--root-user-action=ignore"])
# plotly and scikit-learn come from the default index.
subprocess.check_call([sys.executable,"-m","pip","install","plotly","scikit-learn","-q","--root-user-action=ignore"])
import gradio as gr
import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time, warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score
from sklearn.ensemble import GradientBoostingClassifier
from collections import deque
from dataclasses import dataclass
from typing import List
warnings.filterwarnings("ignore")
# ══════════════════════════════════════════════════════════════════════════════
# TPC-H DATA GENERATOR
# ══════════════════════════════════════════════════════════════════════════════
def generate_tpch_data(n_customers=500, n_orders=2000, fraud_rate=0.05, seed=42):
    """Generate a synthetic TPC-H-style dataset with a fraud label on orders.

    Args:
        n_customers: number of customer rows.
        n_orders: number of order rows.
        fraud_rate: fraction of orders flagged as fraud (top quantile of a
            hand-built risk score).
        seed: seed for ``np.random.default_rng``.

    Returns:
        dict of DataFrames keyed ``customers``, ``orders``, ``lineitem``,
        ``supplier``, ``nation`` and ``part``.
    """
    rng = np.random.default_rng(seed)
    # The sequence of rng draws below fixes the output for a given seed;
    # do not reorder them.
    nation_count = 25
    supplier_count = max(10, n_customers // 20)
    part_count = max(50, n_orders // 5)

    nations = pd.DataFrame({
        "n_nationkey": np.arange(nation_count),
        "n_regionkey": rng.integers(0, 5, nation_count),
    })

    suppliers = pd.DataFrame({
        "s_suppkey": np.arange(supplier_count),
        "s_nationkey": rng.integers(0, nation_count, supplier_count),
        "s_acctbal": rng.uniform(-999, 9999, supplier_count).round(2),
    })
    # Suppliers with a balance under 100 are marked as risky.
    suppliers["s_risk_flag"] = (suppliers["s_acctbal"] < 100).astype(int)

    parts = pd.DataFrame({
        "p_partkey": np.arange(part_count),
        "p_retailprice": rng.uniform(5, 2000, part_count).round(2),
    })

    customers = pd.DataFrame({
        "c_custkey": np.arange(n_customers),
        "c_nationkey": rng.integers(0, nation_count, n_customers),
        "c_acctbal": rng.uniform(-999, 9999, n_customers).round(2),
        "c_account_age_days": rng.integers(1, 3650, n_customers),
        "c_num_prev_orders": rng.poisson(5, n_customers),
    })

    order_customer = rng.integers(0, n_customers, n_orders)
    order_total = rng.exponential(5000, n_orders).round(2)

    # Risk score: negative balance, very large order, very new account, plus noise.
    negative_balance = (customers["c_acctbal"].to_numpy()[order_customer] < 0).astype(float)
    large_order = (order_total > 15000).astype(float)
    new_account = (customers["c_account_age_days"].to_numpy()[order_customer] < 30).astype(float)
    noise = rng.random(n_orders)
    risk = 0.4 * negative_balance + 0.3 * large_order + 0.2 * new_account + 0.1 * noise

    ship_priority = rng.integers(0, 3, n_orders)
    threshold = np.quantile(risk, 1 - fraud_rate)
    orders = pd.DataFrame({
        "o_orderkey": np.arange(n_orders),
        "o_custkey": order_customer,
        "o_totalprice": order_total,
        "o_shippriority": ship_priority,
        "is_fraud": (risk >= threshold).astype(int),
    })

    # Each order gets between 1 and 7 line items.
    lines_per_order = rng.integers(1, 8, n_orders)
    line_total = lines_per_order.sum()
    lineitem = pd.DataFrame({
        "l_orderkey": np.repeat(np.arange(n_orders), lines_per_order),
        "l_partkey": rng.integers(0, part_count, line_total),
        "l_suppkey": rng.integers(0, supplier_count, line_total),
        "l_quantity": rng.integers(1, 51, line_total).astype(float),
        "l_extendedprice": rng.uniform(10, 5000, line_total).round(2),
        "l_discount": rng.uniform(0, 0.1, line_total).round(2),
        "l_tax": rng.uniform(0, 0.08, line_total).round(2),
    })

    return {
        "customers": customers,
        "orders": orders,
        "lineitem": lineitem,
        "supplier": suppliers,
        "nation": nations,
        "part": parts,
    }
# ══════════════════════════════════════════════════════════════════════════════
# ATOMIC ROUTES
# ══════════════════════════════════════════════════════════════════════════════
# Foreign-key edges of the schema as
# (child_table, child_column, parent_table, parent_column) tuples.
TPCH_FK = [
    ("orders", "o_custkey", "customers", "c_custkey"),
    ("lineitem", "l_orderkey", "orders", "o_orderkey"),
    ("lineitem", "l_suppkey", "supplier", "s_suppkey"),
    ("lineitem", "l_partkey", "part", "p_partkey"),
    ("customers", "c_nationkey", "nation", "n_nationkey"),
    ("supplier", "s_nationkey", "nation", "n_nationkey"),
]
@dataclass
class AtomicRoute:
    """A path of table names through the schema, with its weight and active flag."""

    path: List[str]
    # Always recomputed from ``path`` in __post_init__; any value passed in is overwritten.
    n_hops: int = 0
    attention_weight: float = 1.0
    active: bool = True

    def __post_init__(self):
        # One hop per edge between consecutive tables on the path.
        edge_count = len(self.path) - 1
        self.n_hops = edge_count
def discover_routes(tables, max_hops=3):
    """Enumerate simple paths from ``customers`` through the foreign-key graph.

    A breadth-first walk over the undirected ``TPCH_FK`` edges collects every
    path of 1..``max_hops`` hops that stays inside ``tables`` and never
    revisits a table. Each route starts with weight ``1 / hops**1.5``; routes
    are ordered by that weight (descending) and the weights are then replaced
    by their softmax. Routes of at most two hops start out active.

    Args:
        tables: container of table names (membership is tested with ``in``).
        max_hops: maximum path length in hops.

    Returns:
        list of AtomicRoute, highest weight first.
    """
    neighbours = {}
    for child, _child_col, parent, _parent_col in TPCH_FK:
        neighbours.setdefault(child, []).append(parent)
        neighbours.setdefault(parent, []).append(child)

    found = []
    frontier = deque([(["customers"], {"customers"})])
    while frontier:
        path, seen = frontier.popleft()
        hops = len(path) - 1
        if hops >= 1:
            found.append(
                AtomicRoute(
                    path=list(path),
                    attention_weight=1.0 / (hops ** 1.5),
                    active=(hops <= 2),
                )
            )
        if hops >= max_hops:
            continue
        for candidate in neighbours.get(path[-1], []):
            if candidate in seen or candidate not in tables:
                continue
            frontier.append((path + [candidate], seen | {candidate}))

    # Stable sort: routes with equal weight keep their discovery order.
    found.sort(key=lambda route: route.attention_weight, reverse=True)
    scores = np.exp([route.attention_weight for route in found])
    scores /= scores.sum()
    for route, score in zip(found, scores):
        route.attention_weight = float(score)
    return found
# ══════════════════════════════════════════════════════════════════════════════
# FEATURE EXTRACTION
# ══════════════════════════════════════════════════════════════════════════════
def extract_features(tables):
    """Build per-customer, min-max normalised feature matrices for each table.

    Args:
        tables: dict with DataFrames under ``customers``, ``orders``,
            ``lineitem``, ``supplier`` and ``nation``.

    Returns:
        ``(features, labels)`` where ``features`` maps table name to a float32
        array with one row per customer (in ``customers`` row order) and
        ``labels`` is a float array that is 1.0 for customers with at least
        one fraudulent order, else 0.0.
    """
    customers = tables["customers"]
    orders = tables["orders"]
    lineitem = tables["lineitem"]
    supplier = tables["supplier"]
    nation = tables["nation"]

    def minmax(arr):
        # Column-wise scaling to [0, 1]; constant columns divide by 1 instead of 0.
        low = arr.min(0, keepdims=True)
        high = arr.max(0, keepdims=True)
        return (arr - low) / np.where(high - low == 0, 1, high - low)

    def as_matrix(frame):
        return minmax(frame.values.astype(np.float32))

    # A customer is positive when any of its orders is fraudulent.
    fraud_by_customer = orders.groupby("o_custkey")["is_fraud"].max()
    labels = customers["c_custkey"].map(fraud_by_customer).fillna(0).values.astype(float)

    # Customer attributes taken directly from the customers table.
    customer_cols = ["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]
    customer_feats = as_matrix(customers[customer_cols].fillna(0))

    # Empty frame indexed by customer key; joining onto it keeps customer order
    # and yields NaN (then 0) for customers without matching rows.
    customer_index = customers[["c_custkey"]].set_index("c_custkey")

    # Order aggregates: mean price/priority, max price, order count.
    order_mean = orders.groupby("o_custkey")[["o_totalprice", "o_shippriority"]].mean()
    order_max = orders.groupby("o_custkey")[["o_totalprice"]].max()
    order_count = orders.groupby("o_custkey").size().rename("cnt")
    order_frame = (
        customer_index.join(order_mean)
        .join(order_max, rsuffix="_mx")
        .join(order_count)
        .fillna(0)
    )
    order_feats = as_matrix(order_frame)

    # Line items are attributed to customers through their order.
    lines = lineitem.merge(
        orders[["o_orderkey", "o_custkey"]],
        left_on="l_orderkey",
        right_on="o_orderkey",
        how="left",
    )
    line_mean = lines.groupby("o_custkey")[["l_quantity", "l_extendedprice", "l_discount", "l_tax"]].mean()
    line_count = lines.groupby("o_custkey").size().rename("cnt")
    line_feats = as_matrix(customer_index.join(line_mean).join(line_count).fillna(0))

    # Supplier aggregates over the suppliers of each customer's line items.
    lines_with_supplier = lines.merge(supplier, left_on="l_suppkey", right_on="s_suppkey", how="left")
    supplier_mean = lines_with_supplier.groupby("o_custkey")[["s_acctbal", "s_risk_flag"]].mean()
    supplier_feats = as_matrix(customer_index.join(supplier_mean).fillna(0))

    # Nation and region key of each customer's nation.
    customer_nation = customers[["c_custkey", "c_nationkey"]].merge(
        nation, left_on="c_nationkey", right_on="n_nationkey", how="left"
    )
    nation_feats = as_matrix(customer_nation[["n_nationkey", "n_regionkey"]].fillna(0))

    features = {
        "customers": customer_feats,
        "orders": order_feats,
        "lineitem": line_feats,
        "supplier": supplier_feats,
        "nation": nation_feats,
    }
    return features, labels
# ══════════════════════════════════════════════════════════════════════════════
# RELGNN MODEL
# ══════════════════════════════════════════════════════════════════════════════
class TableEncoder(nn.Module):
    """Two-layer MLP (Linear → LayerNorm → ReLU, twice) mapping table features to ``hid`` dims."""

    def __init__(self, ind, hid):
        super().__init__()
        expanded = hid * 2
        stages = [
            nn.Linear(ind, expanded),
            nn.LayerNorm(expanded),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(expanded, hid),
            nn.LayerNorm(hid),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Encode ``x``; the last dimension must be ``ind`` and becomes ``hid``."""
        encoded = self.net(x)
        return encoded
class RouteAttn(nn.Module):
    """Self-attention over the hops of one route, pooled at the first hop.

    Args:
        hid: embedding width of every hop.
    """

    def __init__(self, hid):
        super().__init__()
        # Aim for one head per 16 dims, capped at 4. MultiheadAttention requires
        # the width to be divisible by the head count, so step down to the
        # nearest divisor; widths that already worked get the same head count.
        heads = max(1, min(4, hid // 16))
        while hid % heads:
            heads -= 1
        self.attn = nn.MultiheadAttention(hid, heads, dropout=0.1, batch_first=True)
        self.norm = nn.LayerNorm(hid)
        self.mlp = nn.Sequential(
            nn.Linear(hid, hid * 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hid * 2, hid),
        )

    def forward(self, hops):
        """Return one vector per row from ``hops`` of shape (batch, n_hops, hid).

        Applies self-attention with a residual connection and LayerNorm, then
        keeps only position 0 and adds an MLP residual on top of it.
        """
        out, _ = self.attn(hops, hops, hops)
        out = self.norm(out + hops)
        first = out[:, 0, :]
        return first + self.mlp(first)
class RelGNNModel(nn.Module):
    """Route-based relational model: per-table encoders, per-route attention, weighted sum.

    Args:
        fdims: mapping of table name to input feature width.
        hid: hidden width shared by all encoders and attention blocks.
        routes: sequence of routes (objects with a ``path`` list of table names).
    """

    def __init__(self, fdims, hid, routes):
        super().__init__()
        self.encs = nn.ModuleDict({name: TableEncoder(width, hid) for name, width in fdims.items()})
        self.rattn = nn.ModuleList([RouteAttn(hid) for _ in routes])
        # Learnable route logits; softmaxed in forward() to mix route outputs.
        self.rw = nn.Parameter(torch.ones(len(routes)))
        self.head = nn.Sequential(
            nn.Linear(hid, hid // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hid // 2, 1),
        )
        self.routes = routes

    def forward(self, feats):
        """Return one logit per row of the feature matrices in ``feats``."""
        embs = {}
        for name, encoder in self.encs.items():
            if name in feats:
                embs[name] = encoder(feats[name])

        per_route = []
        for route, attn in zip(self.routes, self.rattn):
            present = [name for name in route.path if name in embs]
            if len(present) >= 2:
                hops = torch.stack([embs[name] for name in present], dim=1)
                per_route.append(attn(hops))
            else:
                # Fewer than two encodable tables on this route: fall back to
                # the first available embedding.
                per_route.append(list(embs.values())[0])

        stacked = torch.stack(per_route, dim=1)
        mix = F.softmax(self.rw, dim=0)
        pooled = (stacked * mix.unsqueeze(0).unsqueeze(-1)).sum(1)
        return self.head(pooled).squeeze(-1)
def train_relgnn(tables, routes, hidden=64, epochs=50, log_fn=print, pfn=None):
    """Train RelGNNModel on per-customer features and evaluate on a 20% hold-out.

    Args:
        tables: dict of DataFrames accepted by ``extract_features``.
        routes: list of routes; their ``attention_weight`` and ``active``
            fields are overwritten with the learned values before returning.
        hidden: hidden width of the model.
        epochs: number of full-batch training epochs.
        log_fn: callable receiving progress strings.
        pfn: optional progress callback invoked as ``pfn(fraction, desc=...)``.

    Returns:
        ``(metrics, history)``: ``metrics`` has keys ``auc``, ``f1``,
        ``precision``, ``recall`` and ``train_time``; ``history`` is a list of
        ``{"epoch", "auc"}`` dicts recorded at the logging epochs.
    """
    t0 = time.time()
    fn, labels = extract_features(tables)
    fdims = {k: v.shape[1] for k, v in fn.items()}
    model = RelGNNModel(fdims, hidden, routes)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
    idx = np.arange(len(labels))
    itr, ite = train_test_split(
        idx, test_size=0.2, random_state=42, stratify=(labels > 0.5).astype(int)
    )

    def to_tensors(ix):
        return {k: torch.tensor(v[ix], dtype=torch.float32) for k, v in fn.items()}

    # The features never change, so build the train/test tensors once instead
    # of on every epoch.
    x_train, x_test = to_tensors(itr), to_tensors(ite)
    ytr = torch.tensor(labels[itr], dtype=torch.float32)
    # Weight positives by the negative/positive ratio to offset class imbalance.
    pw = torch.tensor([(ytr == 0).sum() / max((ytr == 1).sum(), 1)])
    lossfn = nn.BCEWithLogitsLoss(pos_weight=pw)
    history = []
    logi = max(1, epochs // 8)

    model.train()
    for ep in range(1, epochs + 1):
        opt.zero_grad()
        l = lossfn(model(x_train), ytr)
        l.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        sch.step()
        if ep % logi == 0 or ep == epochs:
            model.eval()
            with torch.no_grad():
                p = torch.sigmoid(model(x_test)).numpy()
            try:
                auc = roc_auc_score(labels[ite], p)
            except ValueError:
                # AUC is undefined (e.g. one class in the hold-out, NaN scores).
                auc = 0.5
            history.append({"epoch": ep, "auc": auc})
            log_fn(f" RelGNN ep={ep}/{epochs} loss={float(l):.4f} auc={auc:.4f}")
            model.train()
        if pfn:
            pfn(0.30 + 0.38 * (ep / epochs), desc=f"RelGNN {ep}/{epochs}")

    model.eval()
    with torch.no_grad():
        p = torch.sigmoid(model(x_test)).numpy()
    pred = (p > 0.5).astype(int)
    yt = labels[ite].astype(int)
    try:
        m = dict(
            auc=round(roc_auc_score(yt, p), 4),
            f1=round(f1_score(yt, pred, zero_division=0), 4),
            precision=round(precision_score(yt, pred, zero_division=0), 4),
            recall=round(recall_score(yt, pred, zero_division=0), 4),
            train_time=round(time.time() - t0, 1),
        )
    except ValueError:
        # Metrics could not be computed; report the neutral placeholders.
        m = dict(auc=0.5, f1=0.5, precision=0.5, recall=0.5,
                 train_time=round(time.time() - t0, 1))

    # Publish the learned route mix back onto the route objects.
    w = F.softmax(model.rw, dim=0).detach().numpy()
    for r, wi in zip(routes, w):
        r.attention_weight = float(wi)
        r.active = float(wi) > 0.15
    return m, history
# ══════════════════════════════════════════════════════════════════════════════
# GRAPHSAGE BASELINE
# ══════════════════════════════════════════════════════════════════════════════
class SAGEConv(nn.Module):
    """GraphSAGE-style layer: ``relu(Ws·h + Wn·(adj @ h) + b)``."""

    def __init__(self, i, o):
        super().__init__()
        self.Ws = nn.Linear(i, o, bias=False)  # transform of the node itself
        self.Wn = nn.Linear(i, o, bias=False)  # transform of the aggregated neighbours
        self.b = nn.Parameter(torch.zeros(o))

    def forward(self, h, adj):
        """Apply the layer to node features ``h`` using dense matrix ``adj``."""
        aggregated = torch.mm(adj, h)
        pre_activation = self.Ws(h) + self.Wn(aggregated) + self.b
        return F.relu(pre_activation)
class GSNet(nn.Module):
    """Two stacked SAGEConv layers with dropout in between and a linear logit head."""

    def __init__(self, i, h):
        super().__init__()
        self.c1 = SAGEConv(i, h)
        self.c2 = SAGEConv(h, h)
        self.d = nn.Dropout(0.2)
        self.head = nn.Linear(h, 1)

    def forward(self, h, adj):
        """Return one logit per node for features ``h`` and dense matrix ``adj``."""
        first = self.c1(h, adj)
        second = self.c2(self.d(first), adj)
        logits = self.head(second)
        return logits.squeeze(-1)
def train_graphsage(tables, hidden=64, epochs=50, log_fn=print):
    """Train the GraphSAGE baseline on a customer–order graph.

    Nodes are the first (up to) 1500 customers followed by the first (up to)
    1500 orders; each order is linked in both directions to its customer and
    the adjacency is row-normalised. Only customer nodes are labelled and
    scored, with a stratified 80/20 split.

    NOTE: order→customer edges use ``o_custkey`` as the customer's row
    position, which holds for tables from ``generate_tpch_data`` where
    ``c_custkey`` is ``arange``. Labels are derived from all orders, not only
    those included in the graph.

    Args:
        tables: dict with ``customers`` and ``orders`` DataFrames.
        hidden: hidden width of the network.
        epochs: number of full-batch training epochs.
        log_fn: callable receiving a one-line summary string.

    Returns:
        ``(metrics, history)`` in the same format as ``train_relgnn``.
    """
    t0 = time.time()
    C, O = tables["customers"], tables["orders"]
    MX = 1500
    nc = min(len(C), MX)
    no = min(len(O), MX)
    cf = C[["c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]].iloc[:nc].fillna(0).values.astype(np.float32)
    of = O[["o_totalprice", "o_shippriority"]].iloc[:no].fillna(0).values.astype(np.float32)
    md = max(cf.shape[1], of.shape[1])

    def pad(a, t):
        # Right-pad with zero columns so both node types share one feature width.
        if a.shape[1] >= t:
            return a
        return np.hstack([a, np.zeros((len(a), t - a.shape[1]), dtype=np.float32)])

    X = np.vstack([pad(cf, md), pad(of, md)])
    N = len(X)
    std = X.std(0)
    X = (X - X.mean(0)) / np.where(std == 0, 1, std)

    ck = O["o_custkey"].values[:no]
    oi = np.arange(no) + nc
    vm = ck < nc  # keep only orders whose customer is among the included nodes
    src = np.concatenate([ck[vm], oi[vm]])
    dst = np.concatenate([oi[vm], ck[vm]])
    adj = torch.zeros(N, N)
    # One vectorised assignment instead of a Python loop over edges. All
    # indices are < N by construction (ck[vm] < nc and oi < nc + no).
    adj[torch.as_tensor(dst, dtype=torch.long), torch.as_tensor(src, dtype=torch.long)] = 1.0
    adj = adj / adj.sum(1, keepdim=True).clamp(min=1)

    fc = O.groupby("o_custkey")["is_fraud"].max()
    labels = C["c_custkey"].iloc[:nc].map(fc).fillna(0).values.astype(np.float32)
    Xt = torch.tensor(X, dtype=torch.float32)
    ci = np.arange(nc)
    itr, ite = train_test_split(
        ci, test_size=0.2, random_state=42, stratify=(labels > 0.5).astype(int)
    )
    ytr = torch.tensor(labels[itr], dtype=torch.float32)
    # Weight positives by the negative/positive ratio to offset class imbalance.
    pw = torch.tensor([(ytr == 0).sum() / max((ytr == 1).sum(), 1)])
    model = GSNet(md, hidden)
    opt = optim.AdamW(model.parameters(), lr=1e-3)
    lossfn = nn.BCEWithLogitsLoss(pos_weight=pw)
    logi = max(1, epochs // 5)
    history = []

    model.train()
    for ep in range(1, epochs + 1):
        opt.zero_grad()
        l = lossfn(model(Xt, adj)[itr], ytr)
        l.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if ep % logi == 0 or ep == epochs:
            model.eval()
            with torch.no_grad():
                p = torch.sigmoid(model(Xt, adj)[ite]).numpy()
            try:
                auc = roc_auc_score(labels[ite], p)
            except ValueError:
                # AUC is undefined (e.g. one class in the hold-out, NaN scores).
                auc = 0.5
            history.append({"epoch": ep, "auc": auc})
            model.train()

    model.eval()
    with torch.no_grad():
        p = torch.sigmoid(model(Xt, adj)[ite]).numpy()
    pred = (p > 0.5).astype(int)
    yt = labels[ite].astype(int)
    try:
        m = dict(
            auc=round(roc_auc_score(yt, p), 4),
            f1=round(f1_score(yt, pred, zero_division=0), 4),
            precision=round(precision_score(yt, pred, zero_division=0), 4),
            recall=round(recall_score(yt, pred, zero_division=0), 4),
            train_time=round(time.time() - t0, 1),
        )
    except ValueError:
        # Metrics could not be computed; report the neutral placeholders.
        m = dict(auc=0.5, f1=0.5, precision=0.5, recall=0.5,
                 train_time=round(time.time() - t0, 1))
    log_fn(f" [GraphSAGE] {N} nΓ³s Β· {len(src)} arestas Β· {m['train_time']}s")
    return m, history
# ══════════════════════════════════════════════════════════════════════════════
# XGBOOST BASELINE
# ══════════════════════════════════════════════════════════════════════════════
def train_xgboost(tables, log_fn=print):
    """Train the tabular gradient-boosting baseline on hand-aggregated features.

    Despite the name, the model is scikit-learn's
    ``GradientBoostingClassifier``. Features are the customer attributes plus
    per-customer aggregates of orders, line items and suppliers; the target
    is 1 for customers with at least one fraudulent order.

    Args:
        tables: dict with ``customers``, ``orders``, ``lineitem`` and
            ``supplier`` DataFrames.
        log_fn: callable receiving a one-line summary string.

    Returns:
        dict with keys ``auc``, ``f1``, ``precision``, ``recall`` and
        ``train_time`` measured on a stratified 20% hold-out.
    """
    t0 = time.time()
    C, O, L, S = tables["customers"], tables["orders"], tables["lineitem"], tables["supplier"]

    def by_customer(frame, **aggs):
        # Aggregate per o_custkey and expose the key as c_custkey for merging.
        return (
            frame.groupby("o_custkey")
            .agg(**aggs)
            .reset_index()
            .rename(columns={"o_custkey": "c_custkey"})
        )

    f = C[["c_custkey", "c_acctbal", "c_nationkey", "c_account_age_days", "c_num_prev_orders"]].copy()
    oa = by_customer(O, oc=("o_orderkey", "count"), om=("o_totalprice", "mean"), ox=("o_totalprice", "max"))
    f = f.merge(oa, on="c_custkey", how="left")

    li = L.merge(O[["o_orderkey", "o_custkey"]], left_on="l_orderkey", right_on="o_orderkey", how="left")
    la = by_customer(li, lc=("l_quantity", "count"), lp=("l_extendedprice", "mean"), ld=("l_discount", "mean"))
    f = f.merge(la, on="c_custkey", how="left")

    sw = li.merge(S, left_on="l_suppkey", right_on="s_suppkey", how="left")
    sa = by_customer(sw, sr=("s_risk_flag", "sum"), sb=("s_acctbal", "mean"))
    # Customers without orders get 0 for every aggregate.
    f = f.merge(sa, on="c_custkey", how="left").drop(columns=["c_custkey"]).fillna(0)

    fc = O.groupby("o_custkey")["is_fraud"].max()
    y = C["c_custkey"].map(fc).fillna(0).values.astype(int)
    X = f.values.astype(np.float32)
    itr, ite = train_test_split(np.arange(len(y)), test_size=0.2, random_state=42, stratify=y)

    model = GradientBoostingClassifier(
        n_estimators=80, max_depth=4, learning_rate=0.05, subsample=0.8, random_state=42
    )
    model.fit(X[itr], y[itr])
    p = model.predict_proba(X[ite])[:, 1]
    pred = (p > 0.5).astype(int)
    try:
        m = dict(
            auc=round(roc_auc_score(y[ite], p), 4),
            f1=round(f1_score(y[ite], pred, zero_division=0), 4),
            precision=round(precision_score(y[ite], pred, zero_division=0), 4),
            recall=round(recall_score(y[ite], pred, zero_division=0), 4),
            train_time=round(time.time() - t0, 1),
        )
    except ValueError:
        # Metrics could not be computed (e.g. one class in the hold-out);
        # report the neutral placeholders.
        m = dict(auc=0.5, f1=0.5, precision=0.5, recall=0.5,
                 train_time=round(time.time() - t0, 1))
    log_fn(f" [XGBoost] {X.shape[1]} features Β· {m['train_time']}s")
    return m
# ══════════════════════════════════════════════════════════════════════════════
# PIPELINE
# ══════════════════════════════════════════════════════════════════════════════
def _empty_fig(msg="Erro"):
    """Return a blank dark-themed figure showing ``msg`` centred in red."""
    annotation_font = dict(size=16, color="#ef4444")
    placeholder = go.Figure()
    placeholder.add_annotation(
        text=msg,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=annotation_font,
    )
    placeholder.update_layout(paper_bgcolor="#0a0e1a", plot_bgcolor="#0f1629")
    return placeholder
def _empty_df(*cols): return pd.DataFrame({c:[] for c in cols})
def run_pipeline(n_customers,n_orders,fraud_rate,hidden_dim,num_epochs,max_hops,progress=gr.Progress()):
    """Run the full demo: generate data, train the three models, build the outputs.

    Each stage runs in its own ``try`` so a failure is reported in the UI
    (error figure, empty tables, traceback, log) instead of raising.

    Args:
        n_customers, n_orders: dataset sizes (cast to int).
        fraud_rate: fraud share in percent (divided by 100 before use).
        hidden_dim, num_epochs: model width and training epochs (cast to int).
        max_hops: maximum route length (cast to int).
        progress: Gradio progress tracker, called as ``progress(fraction, desc=...)``.

    Returns:
        5-tuple ``(figure, metrics_df, routes_df, summary_markdown, log_text)``.
    """
    import traceback
    logs = []

    def log(m):
        logs.append(str(m))

    # Single source of truth for the metrics-table header. The rows previously
    # spelled the precision column two different ways, which split it into two
    # half-empty columns.
    metric_cols = ("Modelo", "AUC", "F1", "Precisão", "Recall", "Tempo(s)")

    def metric_row(label, m):
        values = (label, m["auc"], m["f1"], m["precision"], m["recall"], m["train_time"])
        return dict(zip(metric_cols, values))

    def fail(e):
        # Must be called from inside an ``except`` block so format_exc() sees
        # the active exception.
        tb = traceback.format_exc()
        log(f"❌ ERRO: {e}")
        log(tb)
        return (_empty_fig(f"Erro: {e}"),
                _empty_df(*metric_cols),
                _empty_df("Rota","Hops","Peso Ξ±","Ativa"),
                f"## ❌ Erro\n```\n{tb}\n```",
                "\n".join(logs))

    try:
        progress(0.05,desc="Gerando TPC-H...")
        tables = generate_tpch_data(int(n_customers),int(n_orders),float(fraud_rate)/100,seed=42)
        log(f"βœ… {int(n_customers)} clientes Β· {int(n_orders)} pedidos Β· {tables['orders']['is_fraud'].sum()} fraudes")
    except Exception as e:
        return fail(e)

    try:
        progress(0.15,desc="Rotas atΓ΄micas...")
        routes = discover_routes(tables,max_hops=int(max_hops))
        log(f"βœ… {len(routes)} rotas atΓ΄micas")
        for r in routes:
            log(f" β†’ {' β†’ '.join(r.path)} (hops={r.n_hops} Ξ±={r.attention_weight:.3f})")
    except Exception as e:
        return fail(e)

    try:
        progress(0.30,desc="Treinando RelGNN...")
        rm, rh = train_relgnn(tables,routes,int(hidden_dim),int(num_epochs),log,progress)
        log(f"βœ… RelGNN AUC={rm['auc']} F1={rm['f1']} {rm['train_time']}s")
    except Exception as e:
        return fail(e)

    try:
        progress(0.70,desc="Treinando GraphSAGE...")
        gm, gh = train_graphsage(tables,int(hidden_dim),int(num_epochs),log)
        log(f"βœ… GraphSAGE AUC={gm['auc']} F1={gm['f1']} {gm['train_time']}s")
    except Exception as e:
        return fail(e)

    try:
        progress(0.87,desc="Treinando XGBoost...")
        xm = train_xgboost(tables,log)
        log(f"βœ… XGBoost AUC={xm['auc']} F1={xm['f1']} {xm['train_time']}s")
    except Exception as e:
        return fail(e)

    try:
        progress(0.95,desc="Plotando...")
        fig = build_figure(rm,gm,xm,rh,gh,routes)
    except Exception as e:
        return fail(e)

    try:
        mdf = pd.DataFrame([
            metric_row("πŸ”· RelGNN", rm),
            metric_row("🟣 GraphSAGE", gm),
            metric_row("🟑 XGBoost", xm),
        ]).round(4)
        rdf = pd.DataFrame([{"Rota":" β†’ ".join(r.path),"Hops":r.n_hops,"Peso Ξ±":round(r.attention_weight,4),"Ativa":"βœ…" if r.active else "β€”"} for r in routes])
        # Signed deltas of RelGNN relative to GraphSAGE. The sign comes from
        # the value itself, so a regression no longer renders as "+-3.2%".
        da = (rm["auc"]-gm["auc"])*100
        df1 = (rm["f1"]-gm["f1"])*100
        # Relative change in training time; negative means RelGNN was faster.
        dt = (rm["train_time"]/max(gm["train_time"],0.1)-1)*100
        summary = (f"## 🎯 Resultado Final\n\n||RelGNN|GraphSAGE|Ξ”|\n|---|---|---|---|\n"
                   f"|AUC|**{rm['auc']}**|{gm['auc']}|**{da:+.1f}%**|\n"
                   f"|F1|**{rm['f1']}**|{gm['f1']}|**{df1:+.1f}%**|\n"
                   f"|Tempo|**{rm['train_time']}s**|{gm['train_time']}s|**{dt:+.0f}%**|\n\n"
                   f"πŸš€ {len(routes)} rotas Β· zero grafo estΓ‘tico Β· zero feature engineering")
        progress(1.0)
        log("🏁 Concluído!")
        return fig,mdf,rdf,summary,"\n".join(logs)
    except Exception as e:
        return fail(e)
# ══════════════════════════════════════════════════════════════════════════════
# PLOTLY FIGURE
# ══════════════════════════════════════════════════════════════════════════════
def build_figure(rm,gm,xm,rh,gh,routes):
    """Build the 2×3 Plotly dashboard comparing RelGNN, GraphSAGE and XGBoost.

    Panels: AUC convergence, metric bars, training time, route attention
    weights, metric deltas vs GraphSAGE, and a radar chart.

    Args:
        rm, gm, xm: metric dicts (``auc``, ``f1``, ``precision``, ``recall``,
            ``train_time``) for RelGNN, GraphSAGE and XGBoost.
        rh, gh: AUC histories (lists of ``{"epoch", "auc"}``) for RelGNN and
            GraphSAGE.
        routes: routes with ``path``, ``attention_weight`` and ``active``.

    Returns:
        plotly ``Figure``.
    """
    BG="#0a0e1a"; PANEL="#0f1629"; C="#00d4ff"; P="#7c3aed"; A="#f59e0b"; G="#10b981"; GR="#64748b"
    specs=[[{"type":"xy"},{"type":"xy"},{"type":"xy"}],[{"type":"xy"},{"type":"xy"},{"type":"polar"}]]
    fig=make_subplots(rows=2,cols=3,specs=specs,vertical_spacing=0.22,horizontal_spacing=0.10,
                      subplot_titles=["ConvergΓͺncia AUC-ROC","MΓ©tricas","Tempo Treino(s)","Pesos AtenΓ§Γ£o","Ξ” vs GraphSAGE(%)","Radar"])

    # (1,1) AUC per logged epoch.
    fig.add_trace(go.Scatter(x=[h["epoch"] for h in rh],y=[h["auc"] for h in rh],name="RelGNN",
                             line=dict(color=C,width=3),fill="tozeroy",fillcolor="rgba(0,212,255,0.07)"),row=1,col=1)
    fig.add_trace(go.Scatter(x=[h["epoch"] for h in gh],y=[h["auc"] for h in gh],name="GraphSAGE",
                             line=dict(color=P,width=2,dash="dash")),row=1,col=1)

    # (1,2) grouped metric bars, one trace per model.
    mn=["AUC","F1","Prec","Rec"]
    models=[(rm,"RelGNN",C),(gm,"GraphSAGE",P),(xm,"XGBoost",A)]
    for m,name,col in models:
        vals=[m["auc"],m["f1"],m["precision"],m["recall"]]
        fig.add_trace(go.Bar(x=mn,y=vals,name=name,marker_color=col,opacity=0.85,showlegend=False),row=1,col=2)

    # (1,3) training time.
    times=[rm["train_time"],gm["train_time"],xm["train_time"]]
    fig.add_trace(go.Bar(x=["RelGNN","GraphSAGE","XGBoost"],y=times,
                         marker_color=[C,P,A],opacity=0.85,showlegend=False,
                         text=[f"{v:.1f}s" for v in times],textposition="outside"),row=1,col=3)

    # (2,1) route attention weights, labelled by the last hop of each route
    # (path[-2:] is the whole path for one-hop routes).
    rl=[" β†’ ".join(r.path[-2:]) for r in routes]
    fig.add_trace(go.Bar(x=[r.attention_weight for r in routes],y=rl,orientation="h",
                         marker_color=[G if r.active else GR for r in routes],opacity=0.85,showlegend=False,
                         text=[f"Ξ±={r.attention_weight:.3f}" for r in routes],textposition="outside"),row=2,col=1)

    # (2,2) RelGNN minus GraphSAGE, in percentage points.
    deltas=[(rm[k]-gm[k])*100 for k in ["auc","f1","precision","recall"]]
    fig.add_trace(go.Bar(x=mn,y=deltas,marker_color=[G if d>=0 else "#ef4444" for d in deltas],
                         opacity=0.85,showlegend=False,text=[f"+{d:.1f}%" if d>=0 else f"{d:.1f}%" for d in deltas],
                         textposition="outside"),row=2,col=2)
    fig.add_hline(y=0,line_color=GR,line_width=1,row=2,col=2)

    # (2,3) radar; "Speed" is 1 - time / slowest time.
    cats=["AUC","F1","Prec","Rec","Speed"]
    mt=max(times)
    if mt<=0:
        # Train times are rounded to 0.1 s, so all three can be 0.0 on tiny
        # runs; avoid dividing by zero (every model then gets Speed = 1).
        mt=1.0
    for m,name,col in models:
        vals=[m["auc"],m["f1"],m["precision"],m["recall"],1-m["train_time"]/mt]
        # Repeat the first point to close the polygon.
        fig.add_trace(go.Scatterpolar(r=vals+[vals[0]],theta=cats+[cats[0]],name=name,fill="toself",
                                      line_color=col,opacity=0.55,showlegend=False),row=2,col=3)

    fig.update_layout(height=680,paper_bgcolor=BG,plot_bgcolor=PANEL,barmode="group",
                      font=dict(color="#e2e8f0",family="monospace",size=11),
                      title=dict(text="RelGNN Β· TPC-H Fraud Detection",font=dict(size=14,color=C),x=0.5),
                      legend=dict(bgcolor="#141c33",bordercolor="#1e2d4a"))
    fig.update_xaxes(gridcolor="#1e2d4a"); fig.update_yaxes(gridcolor="#1e2d4a")
    fig.update_yaxes(range=[0.35,1.05],row=1,col=1); fig.update_yaxes(range=[0.35,1.05],row=1,col=2)
    return fig
# ══════════════════════════════════════════════════════════════════════════════
# GRADIO UI
# ══════════════════════════════════════════════════════════════════════════════
# Page-level CSS: caps the container width and hides the Gradio footer.
CSS=".gradio-container{max-width:1100px!important} footer{display:none!important}"
# UI layout: a narrow left column with dataset/model sliders and the run
# button, and a wide right column with four result tabs.
with gr.Blocks(css=CSS,title="RelGNN") as demo:
    gr.Markdown("# ⬑ RelGNN β€” Deep Relational Learning\n### Do SQL ao Graph AI sem Engenharia Manual Β· TPC-H Fraud Detection")
    with gr.Row():
        with gr.Column(scale=1,min_width=230):
            gr.Markdown("### βš™οΈ Dataset")
            n_customers=gr.Slider(100,2000,value=500,step=100,label="NΒΊ Clientes")
            n_orders=gr.Slider(500,10000,value=2000,step=500,label="NΒΊ Pedidos")
            # Percent value; run_pipeline divides it by 100.
            fraud_rate=gr.Slider(1,20,value=5,step=1,label="Fraude (%)")
            gr.Markdown("### 🧠 Modelo")
            hidden_dim=gr.Slider(16,128,value=64,step=16,label="Hidden Dim")
            num_epochs=gr.Slider(10,100,value=50,step=10,label="Γ‰pocas")
            max_hops=gr.Slider(1,4,value=3,step=1,label="Max Hops")
            btn=gr.Button("πŸš€ Rodar Pipeline",variant="primary",size="lg")
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("πŸ“Š VisualizaΓ§Γ΅es"): plot_out=gr.Plot()
                with gr.Tab("πŸ“‹ MΓ©tricas"):
                    metrics_out=gr.Dataframe(label="ComparaΓ§Γ£o"); routes_out=gr.Dataframe(label="Rotas AtΓ΄micas")
                with gr.Tab("πŸ“ Resumo"): summary_out=gr.Markdown()
                with gr.Tab("πŸ”§ Log"): log_out=gr.Textbox(lines=22,max_lines=35)
    # The outputs list matches the order of run_pipeline's 5-tuple:
    # figure, metrics table, routes table, summary markdown, log text.
    btn.click(fn=run_pipeline,
              inputs=[n_customers,n_orders,fraud_rate,hidden_dim,num_epochs,max_hops],
              outputs=[plot_out,metrics_out,routes_out,summary_out,log_out])
# Start the Gradio server only when executed as a script.
if __name__=="__main__":
    demo.launch()