# app.py — GraphSAGE Inductive | Elliptic Bitcoin Dataset Real
# NOTE: this listing appears to have been captured from a Hugging Face
# Spaces page (status banner at capture time: "Sleeping").
| import streamlit as st | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| from datetime import datetime | |
# ── PAGE SETUP ────────────────────────────────────────────────
# Browser-tab title/icon and overall layout of the app.
_PAGE_SETTINGS = {
    'page_title': "GraphSAGE — Elliptic Bitcoin",
    'page_icon': "₿",
    'layout': "wide",
    'initial_sidebar_state': "expanded",
}
st.set_page_config(**_PAGE_SETTINGS)

# Global stylesheet injected once per run; the class names declared here
# (card, metric-val, model-*, benchmark-row, ...) are referenced by the HTML
# snippets generated further down in this file.
_GLOBAL_CSS = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap');
html, body, [class*="css"] {
font-family: 'Outfit', sans-serif;
background: #030711; color: #e2e8f0;
}
h1,h2,h3 { font-weight: 800; }
code, pre { font-family: 'JetBrains Mono', monospace !important; }
.card { background:#0d1117; border:1px solid #1e2938; border-radius:12px; padding:18px; }
.metric-val { font-size:2rem; font-weight:800; font-family:'JetBrains Mono'; }
.metric-lbl { font-size:.68rem; color:#64748b; text-transform:uppercase; letter-spacing:2px; }
.model-card {
border-radius:12px; padding:16px; margin:6px 0;
border:1px solid transparent; transition:border-color .2s;
}
.model-sage { background:#0f1f14; border-color:#22c55e; }
.model-gcn { background:#0f1528; border-color:#3b82f6; }
.model-mlp { background:#1a0f0f; border-color:#f59e0b; }
.benchmark-row {
display:flex; align-items:center; gap:8px;
padding:10px 14px; border-radius:8px; margin:4px 0;
font-family:'JetBrains Mono',monospace; font-size:.82rem;
}
.best-row { background:#0d1f10; border:1px solid #22c55e33; }
.stProgress > div > div { background:linear-gradient(90deg,#22c55e,#16a34a) !important; }
.bitcoin-badge {
background:linear-gradient(90deg,#f7931a,#fbbf24);
color:#000; padding:2px 10px; border-radius:20px;
font-size:.75rem; font-weight:700;
}
</style>
"""
st.markdown(_GLOBAL_CSS, unsafe_allow_html=True)

# ── SESSION STATE ─────────────────────────────────────────────
# Keys the rest of the app reads from st.session_state, with the value each
# one gets the first time the script runs in a session.
_SESSION_DEFAULTS = {
    'data': None, 'stats': None, 'loaded': False,
    'trainers': {}, 'resultados': {},
    'treinando': None, 'neo4j': None, 'neo4j_ok': False,
}
for _key in [k for k in _SESSION_DEFAULTS if k not in st.session_state]:
    # Only seed keys that are absent so values from earlier reruns survive.
    st.session_state[_key] = _SESSION_DEFAULTS[_key]
| # ── NEO4J ───────────────────────────────────────────────────── | |
def get_neo4j_config():
    """Collect Neo4j connection settings.

    Lookup order:
      1. flat Streamlit secrets (``NEO4J_URI``, ``NEO4J_USERNAME``, ...);
      2. a ``neo4j`` section inside Streamlit secrets;
      3. environment variables with the same ``NEO4J_*`` names.

    Returns:
        dict with the keys ``uri``, ``username``, ``password`` and
        ``database``. Missing values come back as empty strings, except
        ``database`` which defaults to ``'neo4j'``.
    """
    settings = {}
    try:
        secrets = st.secrets
        if 'NEO4J_URI' in secrets:
            settings = dict(
                uri=secrets['NEO4J_URI'],
                username=secrets['NEO4J_USERNAME'],
                password=secrets['NEO4J_PASSWORD'],
                database=secrets.get('NEO4J_DATABASE', 'neo4j'),
            )
        elif 'neo4j' in secrets:
            section = secrets['neo4j']
            settings = dict(
                uri=section.get('uri', ''),
                username=section.get('username', ''),
                password=section.get('password', ''),
                database=section.get('database', 'neo4j'),
            )
    except Exception:
        # Best effort: any problem reading secrets simply means we fall
        # through to the environment variables below.
        pass

    if settings.get('uri'):
        return settings

    return dict(
        uri=os.getenv('NEO4J_URI', ''),
        username=os.getenv('NEO4J_USERNAME', ''),
        password=os.getenv('NEO4J_PASSWORD', ''),
        database=os.getenv('NEO4J_DATABASE', 'neo4j'),
    )
def conectar_neo4j():
    """Open a Neo4j driver and verify it with a trivial query.

    Returns:
        ``(driver, database_name)`` when the connection works, otherwise
        ``None`` (driver package missing, incomplete configuration, or any
        connection/authentication error).
    """
    driver = None
    try:
        from neo4j import GraphDatabase
        cfg = get_neo4j_config()
        if not all([cfg['uri'], cfg['username'], cfg['password']]):
            return None
        driver = GraphDatabase.driver(cfg['uri'], auth=(cfg['username'], cfg['password']))
        # Cheap round-trip so a bad URI/credential is detected here rather
        # than later when the user tries to save results.
        with driver.session(database=cfg['database']) as session:
            session.run('RETURN 1')
        return driver, cfg['database']
    except Exception:
        # The caller treats None as "offline". If the driver was already
        # created, release it so a failed probe does not leak its resources.
        if driver is not None:
            try:
                driver.close()
            except Exception:
                pass
        return None
def carregar_libs():
    """Import the project-local data and model helpers lazily.

    Returns:
        A 7-tuple. On success it holds, in order: ``carregar_elliptic``,
        ``preparar_splits``, ``criar_mini_batches``, ``GraphSAGE``,
        ``GCNBaseline``, ``MLPBaseline``, ``TrainerElliptic``.
        On failure the first element is the error message (a ``str``) and
        the remaining six are ``None``; callers detect failure by checking
        whether element 0 is a string.
    """
    try:
        from elliptic_data import carregar_elliptic, preparar_splits, criar_mini_batches
        from elliptic_model import GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic
    except Exception as erro:
        return (str(erro),) + (None,) * 6
    return (carregar_elliptic, preparar_splits, criar_mini_batches,
            GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic)
| # ── CHARTS ──────────────────────────────────────────────────── | |
def curves_svg(hist, title='', color='#22c55e'):
    """Render training curves (loss and AUC) as an inline SVG card.

    Args:
        hist: mapping with optional list entries ``'loss_train'`` and
            ``'auc_val'``.
        title: text shown in the card header.
        color: stroke/legend colour used for the AUC series.

    Returns:
        An HTML string, or ``''`` when there is no loss history at all.
        Each series is min-max scaled independently to the plot height, so
        the two curves share the x axis but not the y scale.
    """
    loss = hist.get('loss_train', [])
    auc = hist.get('auc_val', [])
    ep = len(loss)
    if ep == 0:
        return ''

    def pts(vals, H=100):
        # A missing/empty series yields an empty polyline instead of
        # crashing on min()/max() of an empty sequence.
        if not vals:
            return ''
        mn, mx = min(vals), max(vals)
        r = mx - mn or 1  # flat series: avoid division by zero
        return ' '.join(f'{i*420/max(ep-1,1):.1f},{H-(v-mn)/r*H:.1f}'
                        for i, v in enumerate(vals))

    return '\n'.join([
        '<div class="card" style="margin-top:8px">',
        f'<div style="font-size:11px;color:#64748b;margin-bottom:4px">{title}',
        '<span style="color:#ef4444;margin-left:8px">— Loss</span>',
        f'<span style="color:{color};margin-left:8px">— AUC Test</span>',
        '</div>',
        '<svg viewBox="0 0 435 110" style="width:100%">',
        f'<polyline points="{pts(loss)}" fill="none" stroke="#ef4444" stroke-width="1.8"/>',
        f'<polyline points="{pts(auc)}" fill="none" stroke="{color}" stroke-width="2"/>',
        '<line x1="0" y1="100" x2="420" y2="100" stroke="#1e2938"/>',
        '</svg></div>',
    ])
def roc_svg(resultados):
    """Draw the ROC curves of all models in a single SVG card.

    Args:
        resultados: mapping ``model name -> metrics dict``. Only entries
            containing ``'y_true'`` (and ``'probs'``) are plotted.

    Returns:
        An HTML string with one polyline and one legend label per plotted
        model, plus the dashed chance diagonal and the two axes.
    """
    CORES = {'GraphSAGE': '#22c55e', 'GCN': '#3b82f6', 'MLP': '#f59e0b'}
    # Keep each model's position in `resultados` so its legend row does not
    # move when another model is skipped.
    plotaveis = [(pos, nome, res)
                 for pos, (nome, res) in enumerate(resultados.items())
                 if 'y_true' in res]
    curvas = ''
    if plotaveis:
        # Imported only when there is something to plot.
        from sklearn.metrics import roc_curve, auc as sk_auc
        for pos, nome, res in plotaveis:
            fpr, tpr, _ = roc_curve(res['y_true'], res['probs'])
            ra = sk_auc(fpr, tpr)
            # SVG y grows downwards, hence 180 - tpr*180.
            pts = ' '.join(f'{f*400:.1f},{180-t*180:.1f}' for f, t in zip(fpr, tpr))
            cor = CORES.get(nome, '#888')
            curvas += (f'<polyline points="{pts}" fill="none" stroke="{cor}" stroke-width="2.5"/>'
                       f'<text x="410" y="{pos*16+30}" '
                       f'fill="{cor}" font-size="10">{nome} {ra:.3f}</text>')
    return '\n'.join([
        '<div class="card">',
        '<div style="font-size:11px;color:#64748b;margin-bottom:4px">ROC — TODOS OS MODELOS</div>',
        '<svg viewBox="0 0 520 195" style="width:100%">',
        # Chance line (TPR == FPR): bottom-left to top-right in SVG space,
        # using the same coordinate mapping as the curves above.
        '<line x1="0" y1="180" x2="400" y2="0" stroke="#1e2938" stroke-dasharray="4"/>',
        curvas,
        '<line x1="0" y1="180" x2="400" y2="180" stroke="#1e2938"/>',
        '<line x1="0" y1="0" x2="0" y2="180" stroke="#1e2938"/>',
        '</svg></div>',
    ])
def benchmark_html(resultados):
    """Build the HTML leaderboard comparing the trained models.

    Args:
        resultados: mapping ``model name -> metrics dict``. Entries without
            an ``'auc'`` key are left out; the others must also provide
            ``'f1'``, ``'precision'`` and ``'recall'``.

    Returns:
        Concatenated ``<div class="benchmark-row">`` elements ordered by
        descending AUC; every row whose AUC equals the maximum gets the
        ``best-row`` class and a crown. Empty string when nothing to show.
    """
    paleta = {'GraphSAGE': '#22c55e', 'GCN': '#3b82f6', 'MLP': '#f59e0b'}
    topo = max((m.get('auc', 0) for m in resultados.values()), default=0)

    # Highest AUC first; the sort is stable so ties keep insertion order.
    ranking = sorted(
        ((nome, m) for nome, m in resultados.items() if 'auc' in m),
        key=lambda par: par[1]['auc'],
        reverse=True,
    )

    linhas = []
    for nome, m in ranking:
        cor = paleta.get(nome, '#888')
        if m['auc'] == topo:
            classe, coroa = 'best-row', '👑 '
        else:
            classe, coroa = '', ''
        # Bar widths are whole percentages of the 80px track.
        pct_auc = int(m['auc'] * 100)
        pct_f1 = int(m['f1'] * 100)
        linhas.append('\n'.join([
            f'<div class="benchmark-row {classe}">',
            f'<span style="color:{cor};min-width:110px;font-weight:600">{coroa}{nome}</span>',
            '<span style="min-width:80px">',
            '<div style="background:#1e2938;border-radius:3px;height:6px;width:80px">',
            f'<div style="width:{pct_auc}%;height:6px;background:{cor};border-radius:3px"></div>',
            '</div>',
            f'<span style="font-size:.75rem;color:{cor}">AUC {m["auc"]:.4f}</span>',
            '</span>',
            '<span style="min-width:80px">',
            '<div style="background:#1e2938;border-radius:3px;height:6px;width:80px">',
            f'<div style="width:{pct_f1}%;height:6px;background:{cor}88;border-radius:3px"></div>',
            '</div>',
            f'<span style="font-size:.75rem;color:{cor}88">F1 {m["f1"]:.4f}</span>',
            '</span>',
            '<span style="color:#64748b;font-size:.75rem">',
            f'P:{m["precision"]:.3f} R:{m["recall"]:.3f}',
            '</span>',
            '</div>',
        ]))
    return ''.join(linhas)
def cm_html(cm, nome):
    """Render a binary confusion matrix as a 2x2 grid of HTML tiles.

    Args:
        cm: array-like exposing ``ravel()`` that flattens to exactly four
            values, unpacked in the order TN, FP, FN, TP.
        nome: model name; shown as the heading and used to pick the accent
            colour (grey ``#888`` for unknown names).

    Returns:
        HTML string: a heading followed by the grid container.
    """
    destaque = {'GraphSAGE': '#22c55e', 'GCN': '#3b82f6', 'MLP': '#f59e0b'}.get(nome, '#888')
    tn, fp, fn, tp = cm.ravel()

    # (colour, count, short label, description) per tile, in display order.
    celulas = (
        (destaque, tn, 'TN', 'Lícitas\ncorretas'),
        ('#ef4444', fp, 'FP', 'Falsos\nalarmes'),
        ('#f59e0b', fn, 'FN', 'Ilícitas\nperdidas'),
        (destaque, tp, 'TP', 'Ilícitas\ncapturadas'),
    )

    partes = [
        f'<div style="font-size:.8rem;font-weight:700;color:{destaque};margin-bottom:8px">{nome}</div>',
        '<div style="display:grid;grid-template-columns:1fr 1fr;gap:6px">',
    ]
    partes.extend(
        f'<div style="background:{tom}15;border:1px solid {tom}40;border-radius:8px;'
        f'padding:10px;text-align:center">'
        f'<div style="font-size:1.4rem;font-weight:800;color:{tom}">{valor}</div>'
        f'<div style="color:{tom};font-size:.8rem;font-weight:600">{sigla}</div>'
        f'<div style="color:#64748b;font-size:.65rem;white-space:pre-line">{descricao}</div></div>'
        for tom, valor, sigla, descricao in celulas
    )
    partes.append('</div>')
    return ''.join(partes)
def tsne_svg(embeddings, y_true, titulo=''):
    """Project embeddings to 2-D with t-SNE and draw them as an SVG scatter.

    At most 2000 points are plotted. The subsample is drawn from a local,
    fixed-seed generator so that (together with the seeded TSNE) the same
    inputs give the same picture on every Streamlit rerun, and the global
    NumPy random state is left untouched.

    Args:
        embeddings: indexable collection of embedding vectors (indexed with
            an integer array).
        y_true: labels aligned with ``embeddings``; points whose label
            equals 1 are drawn larger and red (legend: "Ilícita").
        titulo: text appended to the card header.

    Returns:
        HTML string with the scatter plot, or a short ``<p>`` element
        carrying the error message if anything fails (including
        scikit-learn not being installed).
    """
    try:
        n = min(2000, len(embeddings))
        amostrador = np.random.default_rng(42)
        idx = amostrador.choice(len(embeddings), n, replace=False)
        emb = embeddings[idx]
        yt = y_true[idx]

        from sklearn.manifold import TSNE
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n//3))
        coords = tsne.fit_transform(emb)
        cx, cy = coords[:, 0], coords[:, 1]
        mn_x, mx_x = cx.min(), cx.max()
        mn_y, mx_y = cy.min(), cy.max()

        def sc(v, mn, mx, W):
            # Min-max scale to [0, W]; epsilon guards a zero-width range.
            return (v-mn)/(mx-mn+1e-8)*W

        circles = ''.join(
            f'<circle cx="{sc(x,mn_x,mx_x,440):.1f}" cy="{sc(y,mn_y,mx_y,260):.1f}" '
            f'r="{5 if yt[i]==1 else 3}" '
            f'fill="{"#ef4444" if yt[i]==1 else "#22c55e44"}" opacity=".85"/>'
            for i, (x, y) in enumerate(zip(cx, cy))
        )
        return (f'<div class="card"><div style="font-size:11px;color:#64748b;margin-bottom:4px">'
                f't-SNE {titulo} <span style="color:#22c55e">● Lícita</span>'
                f'<span style="color:#ef4444;margin-left:8px">● Ilícita</span></div>'
                f'<svg viewBox="0 0 460 270" style="width:100%;background:#070d14;border-radius:6px">'
                f'{circles}</svg></div>')
    except Exception as e:
        # Deliberate best effort: the plot is optional, so show the reason
        # instead of breaking the page.
        return f'<p style="color:#64748b">t-SNE: {e}</p>'
| # ── SIDEBAR ─────────────────────────────────────────────────── | |
def sidebar():
    """Draw the configuration sidebar and return the chosen settings.

    Returns:
        dict with the keys ``norm``, ``treinar_sage``, ``treinar_gcn``,
        ``treinar_mlp``, ``hidden``, ``layers``, ``lr``, ``epocas``,
        ``dropout`` and ``batch``.
    """
    painel = st.sidebar
    opcoes = {}

    painel.markdown('## ₿ GraphSAGE Config')
    painel.markdown('### Dataset')
    opcoes['norm'] = painel.toggle('Normalizar features', value=True)

    painel.markdown('---')
    painel.markdown('### Modelos a treinar')
    opcoes['treinar_sage'] = painel.checkbox('GraphSAGE', value=True)
    opcoes['treinar_gcn'] = painel.checkbox('GCN Baseline', value=True)
    opcoes['treinar_mlp'] = painel.checkbox('MLP Baseline', value=True)

    painel.markdown('---')
    painel.markdown('### Hiperparâmetros')
    opcoes['hidden'] = painel.select_slider('Hidden dim', options=[64, 128, 256], value=128)
    opcoes['layers'] = painel.select_slider('Camadas GNN', options=[1, 2, 3], value=2)
    opcoes['lr'] = painel.select_slider('LR', options=[0.0005, 0.001, 0.003], value=0.001)
    opcoes['epocas'] = painel.slider('Épocas', min_value=5, max_value=50, value=20, step=5)
    opcoes['dropout'] = painel.slider('Dropout', min_value=0.1, max_value=0.5, value=0.3, step=0.05)
    opcoes['batch'] = painel.select_slider('Batch size', options=[256, 512, 1024], value=512)

    # Connection badge reflects the flag stored in the session state.
    if not st.session_state.neo4j_ok:
        painel.warning('⚠️ Neo4j Offline')
    else:
        painel.success('🗄️ Neo4j Conectado')

    return opcoes
| # ── MAIN ────────────────────────────────────────────────────── | |
def _make_progress_callback(nome, barra, status):
    """Return the per-epoch callback that updates a progress bar and a status line."""
    def cb(ep, total, loss, auc, f1):
        barra.progress(ep/total)
        status.markdown(
            f'**{nome}** · Época {ep}/{total} · '
            f'Loss `{loss:.4f}` · AUC `{auc:.3f}` · F1 `{f1:.3f}`')
    return cb


def _render_dataset_tab(cfg, carregar_elliptic, preparar_splits):
    """Dataset tab: description, load button and split statistics."""
    c1, c2 = st.columns([1, 2])
    with c1:
        st.markdown('### Elliptic Bitcoin Dataset')
        # NOTE(review): blank lines between the paragraphs/lists of the
        # markdown literals in this file may have been lost before this
        # listing was captured — confirm against the deployed app.
        st.markdown("""
        **Dataset real** coletado pelo MIT Media Lab:
        - **203,769** transações Bitcoin reais
        - **234,355** arestas de fluxo de Bitcoin
        - **166 features** por transação
        - **49 timesteps** (jan 2017 — set 2018)
        - Labels: `ilícito` (lavagem) / `lícito` / `desconhecido`
        **Split temporal** — como no paper:
        - Treino: timesteps **1 – 34**
        - Teste: timesteps **35 – 49**
        *(evita data leakage temporal)*
        **Por que GraphSAGE?**
        > Transações novas chegam todo segundo.
        > GCN precisaria retreinar o grafo inteiro.
        > GraphSAGE aprende agregadores — prediz
        > para nós novos sem retreinar.
        """)
        if st.button('📥 Carregar Dataset', type='primary', use_container_width=True):
            with st.spinner('Baixando Elliptic Bitcoin Dataset via PyG...'):
                # `ok` is True on success; otherwise it carries the error
                # that is shown to the user below.
                data, ok = carregar_elliptic(normalize=cfg['norm'])
                if ok is True:
                    data, stats = preparar_splits(data)
                    st.session_state.data = data
                    st.session_state.stats = stats
                    st.session_state.loaded = True
                    # A freshly loaded dataset invalidates earlier runs.
                    st.session_state.trainers = {}
                    st.session_state.resultados = {}
                    st.success('✅ Dataset carregado!')
                else:
                    st.error(f'Erro ao carregar: {ok}')
    with c2:
        if st.session_state.loaded and st.session_state.stats:
            s = st.session_state.stats
            m1, m2, m3, m4 = st.columns(4)
            for col, v, l in [
                (m1, f"{s['n_nos']:,}", 'Nós'),
                (m2, f"{s['n_arestas']:,}", 'Arestas'),
                (m3, f"{s['n_features']}", 'Features'),
                (m4, f"{s['taxa_fraude_train']:.1%}", 'Taxa ilícito'),
            ]:
                col.markdown(
                    f'<div class="card" style="text-align:center">'
                    f'<div class="metric-val" style="color:#f7931a">{v}</div>'
                    f'<div class="metric-lbl">{l}</div></div>',
                    unsafe_allow_html=True)
            st.markdown('<br>', unsafe_allow_html=True)
            c_a, c_b = st.columns(2)
            with c_a:
                st.markdown('#### Split Treino')
                st.markdown(f"""
                | | Count |
                |---|---|
                | Lícito | {s['n_licito_train']:,} |
                | Ilícito (fraude) | {s['n_ilicito_train']:,} |
                | **Total** | **{s['n_train']:,}** |
                """)
            with c_b:
                st.markdown('#### Split Teste')
                st.markdown(f"""
                | | Count |
                |---|---|
                | Lícito | {s['n_licito_test']:,} |
                | Ilícito (fraude) | {s['n_ilicito_test']:,} |
                | **Total** | **{s['n_test']:,}** |
                """)
            # Bar chart of the split: widths proportional to the node
            # counts over a 400px track.
            n_tr = s['n_train']; n_te = s['n_test']
            total = n_tr + n_te
            w_tr = int(n_tr/total*400); w_te = 400-w_tr
            st.markdown(f"""
            <svg viewBox="0 0 420 40" style="width:100%;margin-top:8px">
            <rect x="0" y="8" width="{w_tr}" height="20" fill="#22c55e" rx="4"/>
            <rect x="{w_tr+2}" y="8" width="{w_te}" height="20" fill="#3b82f6" rx="4"/>
            <text x="{w_tr//2}" y="22" text-anchor="middle" fill="#000" font-size="11" font-weight="bold">Treino t1-34</text>
            <text x="{w_tr+2+w_te//2}" y="22" text-anchor="middle" fill="#fff" font-size="11" font-weight="bold">Teste t35-49</text>
            </svg>
            """, unsafe_allow_html=True)
        else:
            st.info('Clique em **Carregar Dataset** para começar.')
            st.markdown("""
            **O dataset será baixado automaticamente** via PyTorch Geometric
            (~50MB, hospedado pela comunidade PyG).
            """)


def _render_train_tab(cfg, criar_mini_batches, GraphSAGE, GCNBaseline,
                      MLPBaseline, TrainerElliptic):
    """Training tab: model cards plus the button that trains every selected model."""
    if not st.session_state.loaded:
        st.warning('⬅️ Carregue o dataset primeiro.')
        return
    data = st.session_state.data
    in_dim = data.x.shape[1]

    # (display name, accent colour, CSS class) for each selected model.
    modelos_cfg = []
    if cfg['treinar_sage']:
        modelos_cfg.append(('GraphSAGE', '#22c55e', 'model-sage'))
    if cfg['treinar_gcn']:
        modelos_cfg.append(('GCN', '#3b82f6', 'model-gcn'))
    if cfg['treinar_mlp']:
        modelos_cfg.append(('MLP', '#f59e0b', 'model-mlp'))

    # With every checkbox cleared, st.columns(0) would raise; tell the user
    # what to do instead.
    if not modelos_cfg:
        st.info('Selecione ao menos um modelo na barra lateral.')
        return

    descricoes = {
        'GraphSAGE': 'Inductive · SAGEConv · mini-batch',
        'GCN': 'Transductive · GCNConv · full graph',
    }
    cols = st.columns(len(modelos_cfg))
    for col, (nome, cor, cls) in zip(cols, modelos_cfg):
        col.markdown(
            f'<div class="model-card {cls}">'
            f'<div style="color:{cor};font-weight:700;font-size:1rem">{nome}</div>'
            f'<div style="color:#64748b;font-size:.78rem;margin-top:6px">'
            f'{descricoes.get(nome, "Sem grafo · MLP puro · baseline")}'
            f'</div></div>',
            unsafe_allow_html=True)

    if not st.button('🚀 Treinar Todos', type='primary', use_container_width=True):
        return

    batches_train = criar_mini_batches(data, batch_size=cfg['batch'], split='train')
    for nome, cor, _ in modelos_cfg:
        st.markdown(f'#### Treinando {nome}...')
        prog = st.progress(0)
        status = st.empty()
        if nome == 'GraphSAGE':
            model = GraphSAGE(in_dim, cfg['hidden'], 2,
                              cfg['layers'], cfg['dropout'])
        elif nome == 'GCN':
            model = GCNBaseline(in_dim, cfg['hidden'], 2,
                                cfg['layers'], cfg['dropout'])
        else:
            model = MLPBaseline(in_dim, cfg['hidden'], 2, cfg['dropout'])
        trainer = TrainerElliptic(model, data, lr=cfg['lr'])
        # Only GraphSAGE is given the mini-batches; the baselines get None.
        lotes = batches_train if nome == 'GraphSAGE' else None
        trainer.treinar(cfg['epocas'], batches=lotes,
                        callback=_make_progress_callback(nome, prog, status))
        m = trainer.metricas_completas()
        st.session_state.trainers[nome] = trainer
        st.session_state.resultados[nome] = m
        st.success(f'✅ {nome} — AUC: {m["auc"]:.4f} · F1: {m["f1"]:.4f}')


def _render_benchmark_tab():
    """Benchmark tab: leaderboard, ROC curves, confusion matrices, t-SNE and a summary."""
    res = st.session_state.resultados
    if not res:
        st.warning('⬅️ Treine os modelos primeiro.')
        return
    st.markdown('### Comparação de Modelos — Dataset Real Elliptic')
    st.components.v1.html(
        f'<div style="background:#030711;padding:8px;border-radius:10px">'
        f'{benchmark_html(res)}</div>', height=160)
    st.markdown('<br>', unsafe_allow_html=True)
    st.components.v1.html(roc_svg(res), height=220)
    st.markdown('<br>', unsafe_allow_html=True)
    cols = st.columns(len(res))
    for col, (nome, m) in zip(cols, res.items()):
        col.markdown(
            f'<div class="card">{cm_html(m["cm"], nome)}</div>',
            unsafe_allow_html=True)
    # t-SNE per model
    st.markdown('### Embeddings t-SNE')
    cols2 = st.columns(len(res))
    for col, (nome, m) in zip(cols2, res.items()):
        with col:
            st.components.v1.html(
                tsne_svg(m['embeddings'], m['y_true'], nome),
                height=320)
    # Automatic insight
    if len(res) >= 2:
        melhor = max(res.items(), key=lambda x: x[1]['auc'])
        pior = min(res.items(), key=lambda x: x[1]['auc'])
        ganho = melhor[1]['auc'] - pior[1]['auc']
        resumo = (f'**{melhor[0]}** superou **{pior[0]}** em '
                  f'`{ganho:.4f}` AUC')
        # The "graph structure helps" conclusion only follows when a graph
        # model beats the graph-free MLP baseline; otherwise report just
        # the numbers.
        if melhor[0] in ('GraphSAGE', 'GCN') and pior[0] == 'MLP':
            st.info(resumo + ' — demonstrando o ganho da informação estrutural do grafo.')
        else:
            st.info(resumo + '.')


def _render_inductive_tab():
    """Inductive demo tab: classify one labeled test node with the trained GraphSAGE."""
    st.markdown('### 🔍 GraphSAGE Inductive — Nó Novo')
    st.markdown("""
    **O diferencial do GraphSAGE:** classifica transações novas
    que não existiam no treino, usando apenas seus vizinhos amostrados.
    GCN não consegue isso sem retreinar.
    """)
    if 'GraphSAGE' not in st.session_state.trainers:
        st.warning('⬅️ Treine o GraphSAGE primeiro.')
        return
    trainer = st.session_state.trainers['GraphSAGE']
    data = st.session_state.data
    st.markdown('#### Simule uma transação nova')
    c1, c2 = st.columns(2)
    with c1:
        # Pick a node from the test split to play the role of "new" node.
        idx_no = st.number_input(
            'Índice do nó (do split teste — "novo" para o modelo)',
            0, int(data.test_mask_labeled.sum())-1, 0)
        test_indices = data.test_mask_labeled.nonzero(as_tuple=True)[0]
        no_real = int(test_indices[idx_no])
        label_real = int(data.y[no_real])
        # NOTE(review): class 0 is treated as "illicit" here (label and
        # softmax index), whereas tsne_svg highlights y_true == 1 as
        # illicit. The label mapping is defined outside this file — confirm
        # that both conventions agree with it.
        label_str = '🚨 Ilícita (lavagem)' if label_real == 0 else '✅ Lícita'
        if st.button('🔮 Predizer', type='primary', use_container_width=True):
            # Forward pass over the stored graph; only this node's logits
            # are used for the prediction.
            trainer.model.eval()
            with torch.no_grad():
                logits, _ = trainer.model(
                    data.x, data.edge_index, return_embed=True)
                prob_ilicito = float(
                    F.softmax(logits[no_real], dim=0)[0])
            pred_str = '🚨 Ilícita' if prob_ilicito > 0.5 else '✅ Lícita'
            correto = (prob_ilicito > 0.5) == (label_real == 0)
            st.markdown(f'**Nó:** `{no_real}`')
            st.markdown(f'**Label real:** {label_str}')
            st.markdown(f'**Predição:** {pred_str}')
            st.markdown(f'**P(ilícita):** `{prob_ilicito:.4f}`')
            st.progress(float(prob_ilicito))
            if correto:
                st.success('✅ Predição correta!')
            else:
                st.error('❌ Predição incorreta')
    with c2:
        st.markdown('#### Por que inductive importa em produção')
        st.markdown("""
        ```
        Modelo transductive (GCN):
        Nova tx chega
        → Adiciona ao grafo
        → Retreina tudo
        → Tempo: horas
        ❌ Inviável em produção
        Modelo inductive (GraphSAGE):
        Nova tx chega
        → Amostra vizinhos
        → Aplica agregadores
        → Predição: milissegundos
        ✅ Pronto para produção
        ```
        """)


def _render_neo4j_tab():
    """Neo4j tab: setup instructions when offline, otherwise a save-benchmark button."""
    st.header('🗄️ Neo4j')
    if not st.session_state.neo4j_ok:
        st.warning('Neo4j offline.')
        with st.expander('Como configurar'):
            st.markdown("""
            **HF Spaces → Settings → Variables and secrets:**
            | Chave | Valor |
            |---|---|
            | `NEO4J_URI` | `neo4j+s://XXXXXXXX.databases.neo4j.io` |
            | `NEO4J_USERNAME` | `neo4j` |
            | `NEO4J_PASSWORD` | `sua_senha` |
            | `NEO4J_DATABASE` | `neo4j` |
            """)
        return
    st.success('Conectado!')
    res = st.session_state.resultados
    if res and st.button('💾 Salvar benchmark no Neo4j'):
        driver, db = st.session_state.neo4j
        try:
            with driver.session(database=db) as s:
                # One timestamp for the whole batch so the runs saved
                # together share the same `ts`.
                ts = datetime.now().isoformat()
                for nome, m in res.items():
                    s.run("""
                    MERGE (r:EllipticRun {nome:$nome, ts:$ts})
                    SET r.auc=$auc, r.f1=$f1, r.ap=$ap,
                    r.precision=$pr, r.recall=$rc
                    """, nome=nome, ts=ts,
                          auc=float(m['auc']), f1=float(m['f1']),
                          ap=float(m['ap']),
                          pr=float(m['precision']),
                          rc=float(m['recall']))
            st.success(f'✅ {len(res)} modelos salvos!')
        except Exception as e:
            st.error(str(e))


def main():
    """Entry point: connect to Neo4j, draw the sidebar, then render the five tabs.

    Stops the script with an error message when the project-local modules
    cannot be imported.
    """
    # NOTE(review): while the connection keeps failing, `neo4j` stays None,
    # so a new attempt is made on every rerun.
    if st.session_state.neo4j is None:
        conn = conectar_neo4j()
        st.session_state.neo4j = conn
        st.session_state.neo4j_ok = conn is not None
    cfg = sidebar()

    res_libs = carregar_libs()
    if isinstance(res_libs[0], str):
        # A string in the first slot is the import error message.
        st.error(f'Erro de importação: {res_libs[0]}')
        st.stop()
    (carregar_elliptic, preparar_splits, criar_mini_batches,
     GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic) = res_libs

    st.markdown("""
    <div style="margin-bottom:28px">
    <h1 style="font-size:2.2rem;margin:0;
    background:linear-gradient(90deg,#f7931a,#22c55e,#3b82f6);
    -webkit-background-clip:text;-webkit-text-fill-color:transparent">
    GraphSAGE — Elliptic Bitcoin
    </h1>
    <p style="color:#64748b;margin:3px 0 0 2px;font-size:.9rem">
    Dataset real · 203k transações Bitcoin · GraphSAGE Inductive vs GCN vs MLP
    </p>
    </div>
    """, unsafe_allow_html=True)

    tabs = st.tabs(['₿ Dataset', '🧠 Treinar', '📊 Benchmark', '🔍 Inductive Demo', '🗄️ Neo4j'])
    # ── TAB 0: DATASET ────────────────────────────────────────
    with tabs[0]:
        _render_dataset_tab(cfg, carregar_elliptic, preparar_splits)
    # ── TAB 1: TRAIN ──────────────────────────────────────────
    with tabs[1]:
        _render_train_tab(cfg, criar_mini_batches, GraphSAGE, GCNBaseline,
                          MLPBaseline, TrainerElliptic)
    # ── TAB 2: BENCHMARK ──────────────────────────────────────
    with tabs[2]:
        _render_benchmark_tab()
    # ── TAB 3: INDUCTIVE DEMO ─────────────────────────────────
    with tabs[3]:
        _render_inductive_tab()
    # ── TAB 4: NEO4J ──────────────────────────────────────────
    with tabs[4]:
        _render_neo4j_tab()


if __name__ == '__main__':
    main()