Danielfonseca1212 committed on
Commit
0b037bd
·
verified ·
1 Parent(s): 2b0de6d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +288 -0
app.py ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RelGNN — Deep Relational Learning para Detecção de Fraude
3
+ Projeto 8: Do SQL ao Graph AI sem Engenharia Manual
4
+
5
+ Hugging Face Spaces — Gradio Interface
6
+ """
7
+
8
+ import gradio as gr
9
+ import pandas as pd
10
+ import numpy as np
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.patches as mpatches
13
+ from matplotlib.gridspec import GridSpec
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ from data.tpch_generator import generate_tpch_data
18
+ from data.routes import discover_atomic_routes, RouteConfig
19
+ from relgnn.model import RelGNN, RelGNNConfig
20
+ from relgnn.trainer import Trainer
21
+ from baseline.graphsage_baseline import GraphSAGEBaseline
22
+ from baseline.xgboost_baseline import XGBoostBaseline
23
+
24
+ # ─── GLOBALS ──────────────────────────────────────────────────────────────────
25
+
26
+ RESULTS_CACHE = {}
27
+
28
+ # ─── CORE PIPELINE ────────────────────────────────────────────────────────────
29
+
30
def run_full_pipeline(n_customers, n_orders, fraud_rate, hidden_dim, num_epochs, max_hops, progress=gr.Progress()):
    """Run the whole demo: generate data, train RelGNN and two baselines, build outputs.

    Args:
        n_customers: Number of customers to generate (cast to ``int``).
        n_orders: Number of orders to generate (cast to ``int``).
        fraud_rate: Fraud rate expressed as a percentage; it is divided by
            100 before being handed to the data generator.
        hidden_dim: Hidden size passed to RelGNN and the GraphSAGE baseline.
        num_epochs: Epoch count passed to RelGNN and the GraphSAGE baseline.
        max_hops: Maximum hop count used for atomic-route discovery.
        progress: Gradio progress tracker; it is called with a fraction and a
            ``desc`` at each stage. The ``gr.Progress()`` default follows the
            Gradio convention for requesting a tracker and is intentional.

    Returns:
        A 5-tuple ``(figure, metrics_df, routes_df, summary_markdown,
        log_text)``, in the same order as the outputs wired to the button.

    Side effects:
        Stores the three metric dicts in ``RESULTS_CACHE["last"]``.
    """
    logs = []

    def log(msg):
        # The joined text is returned because ``log`` is also handed to the
        # model ``fit`` methods as ``log_fn``; whether they use the return
        # value is not visible from here, so the contract is kept.
        logs.append(msg)
        return "\n".join(logs)

    # 1. Synthetic data -----------------------------------------------------
    progress(0.05, desc="Gerando dataset TPC-H sintético...")
    tables = generate_tpch_data(
        n_customers=int(n_customers),
        n_orders=int(n_orders),
        fraud_rate=float(fraud_rate) / 100.0,
        seed=42,
    )
    fraud_count = tables["orders"]["is_fraud"].sum()
    log(f"✅ Dataset gerado: {int(n_customers)} clientes, {int(n_orders)} pedidos, {fraud_count} fraudes ({fraud_rate:.1f}%)")

    # 2. Route discovery ----------------------------------------------------
    progress(0.15, desc="Descobrindo rotas atômicas...")
    route_config = RouteConfig(max_hops=int(max_hops))
    routes = discover_atomic_routes(tables, route_config)
    log(f"✅ {len(routes)} rotas atômicas descobertas (max {max_hops} hops)")
    for r in routes:
        log(f" → {' → '.join(r.path)} (hop={r.n_hops})")

    # 3. RelGNN -------------------------------------------------------------
    progress(0.30, desc="Treinando RelGNN...")
    config = RelGNNConfig(
        hidden_dim=int(hidden_dim),
        num_epochs=int(num_epochs),
        learning_rate=1e-3,
        dropout=0.2,
    )
    relgnn = RelGNN(config)
    relgnn_metrics, relgnn_history = relgnn.fit(tables, routes, log_fn=log, progress_fn=progress)
    log(f"✅ RelGNN — AUC: {relgnn_metrics['auc']:.4f} F1: {relgnn_metrics['f1']:.4f} Tempo: {relgnn_metrics['train_time']:.1f}s")

    # 4. Baselines ----------------------------------------------------------
    progress(0.70, desc="Treinando GraphSAGE baseline...")
    graphsage = GraphSAGEBaseline(hidden_dim=int(hidden_dim), num_epochs=int(num_epochs))
    gs_metrics, gs_history = graphsage.fit(tables, log_fn=log)
    log(f"✅ GraphSAGE — AUC: {gs_metrics['auc']:.4f} F1: {gs_metrics['f1']:.4f} Tempo: {gs_metrics['train_time']:.1f}s")

    progress(0.85, desc="Treinando XGBoost baseline...")
    xgb = XGBoostBaseline()
    xgb_metrics = xgb.fit(tables, log_fn=log)
    log(f"✅ XGBoost — AUC: {xgb_metrics['auc']:.4f} F1: {xgb_metrics['f1']:.4f} Tempo: {xgb_metrics['train_time']:.1f}s")

    # 5. Outputs ------------------------------------------------------------
    progress(0.93, desc="Gerando visualizações...")

    fig = plot_results(relgnn_metrics, gs_metrics, xgb_metrics, relgnn_history, gs_history, routes)

    metrics_df = pd.DataFrame([
        {"Modelo": "🔷 RelGNN (Rotas Atômicas)", **relgnn_metrics},
        {"Modelo": "🟣 GraphSAGE (Grafo Estático)", **gs_metrics},
        {"Modelo": "🟡 XGBoost (Flat Features)", **xgb_metrics},
    ]).rename(columns={"auc": "AUC-ROC", "f1": "F1-Score",
                       "precision": "Precisão", "recall": "Recall",
                       "train_time": "Tempo (s)"})
    metrics_df = metrics_df.round(4)

    routes_df = pd.DataFrame([{
        "Rota": " → ".join(r.path),
        "Hops": r.n_hops,
        "Peso α": f"{r.attention_weight:.3f}",
        "Ativa": "✅" if r.active else "—",
    } for r in routes])

    summary = _build_summary(relgnn_metrics, gs_metrics, len(routes))

    RESULTS_CACHE["last"] = {
        "relgnn": relgnn_metrics,
        "graphsage": gs_metrics,
        "xgboost": xgb_metrics,
    }

    progress(1.0, desc="Concluído!")
    log("─" * 60)
    log("🏁 Pipeline completo!")

    return fig, metrics_df, routes_df, summary, "\n".join(logs)


def _build_summary(relgnn_metrics, gs_metrics, n_routes):
    """Render the Markdown summary comparing RelGNN against GraphSAGE.

    Signs and wording follow the measured deltas, so the text stays truthful
    when RelGNN is less accurate or slower than the baseline. When RelGNN
    wins on both axes the output is the same text the summary always showed.

    Args:
        relgnn_metrics: Dict with at least ``auc``, ``f1`` and ``train_time``.
        gs_metrics: Dict with the same keys for the GraphSAGE baseline.
        n_routes: Number of discovered atomic routes.

    Returns:
        The summary as a Markdown string.
    """
    delta_auc = (relgnn_metrics["auc"] - gs_metrics["auc"]) * 100
    delta_f1 = (relgnn_metrics["f1"] - gs_metrics["f1"]) * 100

    # Relative training-time saving; guard against a zero baseline time,
    # which would otherwise raise ZeroDivisionError.
    gs_time = gs_metrics["train_time"]
    if gs_time > 0:
        delta_time = (1 - relgnn_metrics["train_time"] / gs_time) * 100
    else:
        delta_time = 0.0

    if delta_time >= 0:
        time_cell = f"−{delta_time:.0f}%"
        speed_claim = f"{delta_time:.0f}% mais rápido"
    else:
        time_cell = f"+{-delta_time:.0f}%"
        speed_claim = f"{-delta_time:.0f}% mais lento"

    if delta_auc >= 0:
        accuracy_claim = f"+{delta_auc:.1f}% mais preciso"
    else:
        accuracy_claim = f"{-delta_auc:.1f}% menos preciso"

    # Only show the success mark when RelGNN actually wins on both axes.
    verdict_icon = "✅" if delta_auc >= 0 and delta_time >= 0 else "⚠️"

    return (
        f"## 🎯 Resultado Final\n\n"
        f"| Métrica | RelGNN | GraphSAGE | Δ |\n"
        f"|---------|--------|-----------|---|\n"
        f"| AUC-ROC | **{relgnn_metrics['auc']:.4f}** | {gs_metrics['auc']:.4f} | **{delta_auc:+.1f}%** |\n"
        f"| F1-Score | **{relgnn_metrics['f1']:.4f}** | {gs_metrics['f1']:.4f} | **{delta_f1:+.1f}%** |\n"
        f"| Tempo Treino | **{relgnn_metrics['train_time']:.1f}s** | {gs_metrics['train_time']:.1f}s | **{time_cell}** |\n\n"
        f"{verdict_icon} RelGNN é **{accuracy_claim}** e **{speed_claim}** que GraphSAGE.\n"
        f"🔑 **{n_routes} rotas atômicas** aprendidas automaticamente das FKs do schema SQL.\n"
        f"🚀 **Zero engenharia manual** — sem conversão explícita para grafo."
    )
123
+
124
+
125
def plot_results(rm, gm, xm, rh, gh, routes):
    """Build the 2x3 comparison dashboard as a matplotlib figure.

    Panels: AUC convergence curves (top, two columns wide), metric bars for
    the three models, per-route attention weights, training times, and the
    RelGNN-minus-GraphSAGE deltas.

    Args:
        rm: RelGNN metrics dict; reads ``auc``, ``f1``, ``precision``,
            ``recall`` and ``train_time``.
        gm: GraphSAGE metrics dict with the same keys.
        xm: XGBoost metrics dict with the same keys.
        rh: RelGNN history, an iterable of dicts with ``epoch`` and ``auc``.
        gh: GraphSAGE history with the same shape.
        routes: Route objects exposing ``path``, ``attention_weight`` and
            ``active``.

    Returns:
        The populated ``Figure``. It is detached from pyplot's registry
        before being returned, so repeated calls do not accumulate figures.
    """
    plt.style.use("dark_background")
    fig = plt.figure(figsize=(14, 9), facecolor="#0a0e1a")
    gs = GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)

    CYAN = "#00d4ff"
    PURPLE = "#7c3aed"
    AMBER = "#f59e0b"
    GREEN = "#10b981"
    PANEL = "#0f1629"

    ax_curve = fig.add_subplot(gs[0, :2])
    ax_bar = fig.add_subplot(gs[0, 2])
    ax_route = fig.add_subplot(gs[1, 0])
    ax_time = fig.add_subplot(gs[1, 1])
    ax_delta = fig.add_subplot(gs[1, 2])

    for ax in [ax_curve, ax_bar, ax_route, ax_time, ax_delta]:
        ax.set_facecolor(PANEL)
        for spine in ax.spines.values():
            spine.set_color("#1e2d4a")

    # 1. Training curves
    epochs_r = [h["epoch"] for h in rh]
    auc_r = [h["auc"] for h in rh]
    epochs_g = [h["epoch"] for h in gh]
    auc_g = [h["auc"] for h in gh]

    ax_curve.plot(epochs_r, auc_r, color=CYAN, lw=2.5, label="RelGNN", zorder=3)
    ax_curve.plot(epochs_g, auc_g, color=PURPLE, lw=2, label="GraphSAGE", linestyle="--", zorder=2)
    ax_curve.fill_between(epochs_r, auc_r, alpha=0.12, color=CYAN)
    ax_curve.set_title("Curva de Convergência (AUC-ROC)", color="white", fontsize=11, pad=8)
    ax_curve.set_xlabel("Época", color="#64748b", fontsize=9)
    ax_curve.set_ylabel("AUC-ROC", color="#64748b", fontsize=9)
    ax_curve.tick_params(colors="#64748b", labelsize=8)
    ax_curve.legend(facecolor="#141c33", edgecolor="#1e2d4a", labelcolor="white", fontsize=9)
    ax_curve.grid(color="#1e2d4a", alpha=0.5, linewidth=0.5)
    # Keep the zoomed 0.5 floor only when every point is visible with it;
    # otherwise early epochs below 0.5 would be drawn off-axis.
    curve_floor = 0.5 if all(v >= 0.5 for v in auc_r + auc_g) else 0.0
    ax_curve.set_ylim(curve_floor, 1.0)

    # 2. Bar comparison
    metrics = ["AUC", "F1", "Prec", "Rec"]
    relgnn_v = [rm["auc"], rm["f1"], rm["precision"], rm["recall"]]
    graph_v = [gm["auc"], gm["f1"], gm["precision"], gm["recall"]]
    xgb_v = [xm["auc"], xm["f1"], xm["precision"], xm["recall"]]

    x = np.arange(len(metrics))
    w = 0.25
    ax_bar.bar(x - w, relgnn_v, w, color=CYAN, alpha=0.85, label="RelGNN")
    ax_bar.bar(x, graph_v, w, color=PURPLE, alpha=0.85, label="GraphSAGE")
    ax_bar.bar(x + w, xgb_v, w, color=AMBER, alpha=0.85, label="XGBoost")
    ax_bar.set_title("Métricas Comparativas", color="white", fontsize=11, pad=8)
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(metrics, color="#64748b", fontsize=8)
    # With a 0.5 floor, a bar whose value is below 0.5 is invisible; fall
    # back to a zero baseline when any metric is that low.
    bar_floor = 0.5 if all(v >= 0.5 for v in relgnn_v + graph_v + xgb_v) else 0.0
    ax_bar.set_ylim(bar_floor, 1.05)
    ax_bar.tick_params(colors="#64748b", labelsize=8)
    ax_bar.legend(facecolor="#141c33", edgecolor="#1e2d4a", labelcolor="white", fontsize=7)
    ax_bar.grid(axis="y", color="#1e2d4a", alpha=0.5, linewidth=0.5)

    # 3. Atomic routes weights
    # Slicing the last two hops already yields the whole path when it has
    # two or fewer entries, so no length check is needed.
    route_labels = [" → ".join(r.path[-2:]) for r in routes]
    route_weights = [r.attention_weight for r in routes]
    colors_r = [GREEN if r.active else "#334155" for r in routes]

    # Place bars at numeric positions: string positions are treated as
    # categories, so two routes with the same shortened label would be drawn
    # on top of each other.
    y_pos = np.arange(len(routes))
    bars = ax_route.barh(y_pos, route_weights, color=colors_r, alpha=0.85)
    ax_route.set_yticks(y_pos)
    ax_route.set_yticklabels(route_labels)
    ax_route.set_title("Pesos de Atenção (α)\nRotas Atômicas", color="white", fontsize=10, pad=8)
    ax_route.set_xlim(0, 1)
    ax_route.tick_params(colors="#64748b", labelsize=7)
    ax_route.grid(axis="x", color="#1e2d4a", alpha=0.5, linewidth=0.5)
    for bar, w_ in zip(bars, route_weights):
        ax_route.text(w_ + 0.02, bar.get_y() + bar.get_height()/2,
                      f"{w_:.2f}", va="center", color="white", fontsize=8)

    # 4. Training time
    models_t = ["RelGNN", "GraphSAGE", "XGBoost"]
    times = [rm["train_time"], gm["train_time"], xm["train_time"]]
    cols_t = [CYAN, PURPLE, AMBER]
    ax_time.bar(models_t, times, color=cols_t, alpha=0.85, width=0.5)
    ax_time.set_title("Tempo de Treino (s)", color="white", fontsize=11, pad=8)
    ax_time.tick_params(colors="#64748b", labelsize=8)
    ax_time.grid(axis="y", color="#1e2d4a", alpha=0.5, linewidth=0.5)
    # Scale the label offset and headroom with the data: a fixed 0.5 s offset
    # throws labels outside the axes when training takes well under a second.
    tallest = max(times)
    if tallest > 0:
        ax_time.set_ylim(0, tallest * 1.15)
    label_pad = tallest * 0.02 if tallest > 0 else 0.0
    for i, (t, c) in enumerate(zip(times, cols_t)):
        ax_time.text(i, t + label_pad, f"{t:.1f}s", ha="center", va="bottom",
                     color=c, fontsize=9, fontweight="bold")

    # 5. Delta vs GraphSAGE
    delta_metrics = ["AUC", "F1", "Precisão", "Recall"]
    deltas = [
        (rm["auc"] - gm["auc"]) * 100,
        (rm["f1"] - gm["f1"]) * 100,
        (rm["precision"] - gm["precision"]) * 100,
        (rm["recall"] - gm["recall"]) * 100,
    ]
    colors_d = [GREEN if d > 0 else "#ef4444" for d in deltas]
    ax_delta.bar(delta_metrics, deltas, color=colors_d, alpha=0.85, width=0.5)
    ax_delta.axhline(0, color="#64748b", linewidth=0.8)
    ax_delta.set_title("RelGNN vs GraphSAGE\n(Δ pontos percentuais)", color="white", fontsize=10, pad=8)
    ax_delta.tick_params(colors="#64748b", labelsize=8)
    ax_delta.grid(axis="y", color="#1e2d4a", alpha=0.5, linewidth=0.5)
    for i, (d, c) in enumerate(zip(deltas, colors_d)):
        # The "+" format flag emits exactly one sign (a hand-built "+" prefix
        # prints "+-0.0%" for negative zero).
        ax_delta.text(i, d + 0.1 if d >= 0 else d - 0.3,
                      f"{d:+.1f}%",
                      ha="center", color=c, fontsize=9, fontweight="bold")

    fig.suptitle("RelGNN — Deep Relational Learning · TPC-H Fraud Detection",
                 color="white", fontsize=13, fontweight="bold", y=1.01)

    # pyplot keeps every figure it creates alive until closed; in a
    # long-running app that is one leaked figure per run. Closing only drops
    # pyplot's reference; the Figure object is returned to the caller.
    plt.close(fig)
    return fig
231
+
232
+
233
+ # ─── GRADIO UI ────────────────────────────────────────────────────────────────
234
+
235
# Custom stylesheet handed to gr.Blocks: dark page background, gradient
# primary button, hidden footer.
CSS = """
.gradio-container { background: #0a0e1a !important; }
.gr-button-primary { background: linear-gradient(135deg, #00d4ff, #7c3aed) !important; border: none !important; }
footer { display: none !important; }
"""

# Components are created in the same order as before, because creation order
# inside the Blocks/Row/Column contexts determines the rendered layout.
with gr.Blocks(css=CSS, title="RelGNN — Deep Relational Learning") as demo:

    gr.Markdown("""
    # ⬡ RelGNN — Deep Relational Learning
    ### Do SQL ao Graph AI sem Engenharia Manual · TPC-H Fraud Detection
    **Projeto 8** | Compare RelGNN (Rotas Atômicas) vs GraphSAGE vs XGBoost
    """)

    with gr.Row():
        # Left column: dataset and model controls plus the run button.
        with gr.Column(scale=1):
            gr.Markdown("### ⚙️ Configuração")
            n_customers = gr.Slider(minimum=100, maximum=2000, value=500, step=100, label="Nº de Clientes")
            n_orders = gr.Slider(minimum=500, maximum=10000, value=2000, step=500, label="Nº de Pedidos")
            fraud_rate = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Taxa de Fraude (%)")

            gr.Markdown("### 🧠 Hiperparâmetros")
            hidden_dim = gr.Slider(minimum=16, maximum=128, value=64, step=16, label="Hidden Dim")
            num_epochs = gr.Slider(minimum=10, maximum=100, value=50, step=10, label="Épocas")
            max_hops = gr.Slider(minimum=1, maximum=4, value=3, step=1, label="Max Hops (Rotas Atômicas)")

            btn = gr.Button("🚀 Rodar Pipeline Completo", variant="primary")

        # Right column: one tab per pipeline output.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.Tab("📊 Visualizações"):
                    plot_out = gr.Plot(label="Resultados")
                with gr.Tab("📋 Métricas"):
                    metrics_out = gr.Dataframe(label="Comparação de Modelos")
                    routes_out = gr.Dataframe(label="Rotas Atômicas Descobertas")
                with gr.Tab("📝 Resumo"):
                    summary_out = gr.Markdown()
                with gr.Tab("🔧 Log"):
                    log_out = gr.Textbox(label="Log de Execução", lines=20, max_lines=30)

    # Order must match run_full_pipeline's parameters and its returned tuple.
    pipeline_inputs = [n_customers, n_orders, fraud_rate, hidden_dim, num_epochs, max_hops]
    pipeline_outputs = [plot_out, metrics_out, routes_out, summary_out, log_out]
    btn.click(fn=run_full_pipeline, inputs=pipeline_inputs, outputs=pipeline_outputs)

    gr.Markdown("""
    ---
    **Referências:** [RelBench](https://relbench.stanford.edu/) · [TPC-H Benchmark](https://www.tpc.org/tpch/) · [GraphSAGE](https://arxiv.org/abs/1706.02216)
    """)


if __name__ == "__main__":
    demo.launch()