# app.py — GraphSAGE Inductive | Elliptic Bitcoin Dataset Real import streamlit as st import numpy as np import torch import torch.nn.functional as F import os from datetime import datetime st.set_page_config( page_title="GraphSAGE — Elliptic Bitcoin", page_icon="₿", layout="wide", initial_sidebar_state="expanded" ) st.markdown(""" """, unsafe_allow_html=True) # ── SESSION STATE ───────────────────────────────────────────── for k, v in { 'data': None, 'stats': None, 'loaded': False, 'trainers': {}, 'resultados': {}, 'treinando': None, 'neo4j': None, 'neo4j_ok': False, }.items(): if k not in st.session_state: st.session_state[k] = v # ── NEO4J ───────────────────────────────────────────────────── def get_neo4j_config(): cfg = {} try: s = st.secrets if 'NEO4J_URI' in s: cfg = {'uri': s['NEO4J_URI'], 'username': s['NEO4J_USERNAME'], 'password': s['NEO4J_PASSWORD'], 'database': s.get('NEO4J_DATABASE', 'neo4j')} elif 'neo4j' in s: n = s['neo4j'] cfg = {'uri': n.get('uri',''), 'username': n.get('username',''), 'password': n.get('password',''), 'database': n.get('database','neo4j')} except Exception: pass if not cfg.get('uri'): cfg = {'uri': os.getenv('NEO4J_URI',''), 'username': os.getenv('NEO4J_USERNAME',''), 'password': os.getenv('NEO4J_PASSWORD',''), 'database': os.getenv('NEO4J_DATABASE','neo4j')} return cfg @st.cache_resource def conectar_neo4j(): try: from neo4j import GraphDatabase cfg = get_neo4j_config() if not all([cfg['uri'], cfg['username'], cfg['password']]): return None driver = GraphDatabase.driver(cfg['uri'], auth=(cfg['username'], cfg['password'])) with driver.session(database=cfg['database']) as s: s.run('RETURN 1') return driver, cfg['database'] except Exception: return None @st.cache_resource def carregar_libs(): try: from elliptic_data import carregar_elliptic, preparar_splits, criar_mini_batches from elliptic_model import GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic return carregar_elliptic, preparar_splits, criar_mini_batches, \ GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic except Exception as e: return str(e), None, None, None, None, None, None # ── CHARTS ──────────────────────────────────────────────────── def curves_svg(hist, title='', color='#22c55e'): loss = hist.get('loss_train', []) auc = hist.get('auc_val', []) ep = len(loss) if ep == 0: return '' def pts(vals, H=100): mn,mx = min(vals),max(vals); r=mx-mn or 1 return ' '.join(f'{i*420/max(ep-1,1):.1f},{H-(v-mn)/r*H:.1f}' for i,v in enumerate(vals)) return f"""
{title} — Loss — AUC Test
""" def roc_svg(resultados): """ROC de todos os modelos juntos.""" from sklearn.metrics import roc_curve, auc as sk_auc CORES = {'GraphSAGE':'#22c55e', 'GCN':'#3b82f6', 'MLP':'#f59e0b'} curvas = '' for nome, res in resultados.items(): if 'y_true' not in res: continue fpr,tpr,_ = roc_curve(res['y_true'], res['probs']) ra = sk_auc(fpr,tpr) pts = ' '.join(f'{f*400:.1f},{180-t*180:.1f}' for f,t in zip(fpr,tpr)) cor = CORES.get(nome,'#888') curvas += (f'' f'{nome} {ra:.3f}') return f"""
ROC — TODOS OS MODELOS
{curvas}
""" def benchmark_html(resultados): CORES = {'GraphSAGE':'#22c55e', 'GCN':'#3b82f6', 'MLP':'#f59e0b'} melhor_auc = max((r.get('auc',0) for r in resultados.values()), default=0) html = '' # Ordena por AUC items = sorted(resultados.items(), key=lambda x: x[1].get('auc',0), reverse=True) for nome, res in items: if 'auc' not in res: continue cor = CORES.get(nome, '#888') is_best = res['auc'] == melhor_auc cls = 'best-row' if is_best else '' crown = '👑 ' if is_best else '' bar_auc = int(res['auc']*100) bar_f1 = int(res['f1']*100) html += f"""
{crown}{nome}
AUC {res['auc']:.4f}
F1 {res['f1']:.4f}
P:{res['precision']:.3f} R:{res['recall']:.3f}
""" return html def cm_html(cm, nome): CORES = {'GraphSAGE':'#22c55e', 'GCN':'#3b82f6', 'MLP':'#f59e0b'} cor = CORES.get(nome,'#888') tn,fp,fn,tp = cm.ravel() items = [(cor,tn,'TN','Lícitas\ncorretas'),('#ef4444',fp,'FP','Falsos\nalarmes'), ('#f59e0b',fn,'FN','Ilícitas\nperdidas'),(cor,tp,'TP','Ilícitas\ncapturadas')] html = f'
{nome}
' html += '
' for c,v,a,d in items: html += (f'
' f'
{v}
' f'
{a}
' f'
{d}
') return html + '
' def tsne_svg(embeddings, y_true, titulo=''): try: from sklearn.manifold import TSNE n = min(2000, len(embeddings)) idx = np.random.choice(len(embeddings), n, replace=False) emb = embeddings[idx]; yt = y_true[idx] tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n//3)) coords = tsne.fit_transform(emb) cx,cy = coords[:,0], coords[:,1] mn_x,mx_x = cx.min(),cx.max(); mn_y,mx_y = cy.min(),cy.max() def sc(v,mn,mx,W): return (v-mn)/(mx-mn+1e-8)*W circles = ''.join( f'' for i,(x,y) in enumerate(zip(cx,cy)) ) return (f'
' f't-SNE {titulo} ● Lícita' f'● Ilícita
' f'' f'{circles}
') except Exception as e: return f'

t-SNE: {e}

' # ── SIDEBAR ─────────────────────────────────────────────────── def sidebar(): st.sidebar.markdown('## ₿ GraphSAGE Config') st.sidebar.markdown('### Dataset') norm = st.sidebar.toggle('Normalizar features', True) st.sidebar.markdown('---') st.sidebar.markdown('### Modelos a treinar') treinar_sage = st.sidebar.checkbox('GraphSAGE', True) treinar_gcn = st.sidebar.checkbox('GCN Baseline', True) treinar_mlp = st.sidebar.checkbox('MLP Baseline', True) st.sidebar.markdown('---') st.sidebar.markdown('### Hiperparâmetros') hidden = st.sidebar.select_slider('Hidden dim', [64,128,256], 128) layers = st.sidebar.select_slider('Camadas GNN', [1,2,3], 2) lr = st.sidebar.select_slider('LR', [0.0005,0.001,0.003], 0.001) epocas = st.sidebar.slider('Épocas', 5, 50, 20, 5) dropout = st.sidebar.slider('Dropout', 0.1, 0.5, 0.3, 0.05) batch = st.sidebar.select_slider('Batch size', [256,512,1024], 512) if st.session_state.neo4j_ok: st.sidebar.success('🗄️ Neo4j Conectado') else: st.sidebar.warning('⚠️ Neo4j Offline') return dict(norm=norm, treinar_sage=treinar_sage, treinar_gcn=treinar_gcn, treinar_mlp=treinar_mlp, hidden=hidden, layers=layers, lr=lr, epocas=epocas, dropout=dropout, batch=batch) # ── MAIN ────────────────────────────────────────────────────── def main(): if st.session_state.neo4j is None: conn = conectar_neo4j() st.session_state.neo4j = conn st.session_state.neo4j_ok = conn is not None cfg = sidebar() res_libs = carregar_libs() if isinstance(res_libs[0], str): st.error(f'Erro de importação: {res_libs[0]}') st.stop() carregar_elliptic, preparar_splits, criar_mini_batches, \ GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic = res_libs st.markdown("""

GraphSAGE — Elliptic Bitcoin

Dataset real · 203k transações Bitcoin · GraphSAGE Inductive vs GCN vs MLP

""", unsafe_allow_html=True) tabs = st.tabs(['₿ Dataset', '🧠 Treinar', '📊 Benchmark', '🔍 Inductive Demo', '🗄️ Neo4j']) # ── TAB 0: DATASET ──────────────────────────────────────── with tabs[0]: c1, c2 = st.columns([1, 2]) with c1: st.markdown('### Elliptic Bitcoin Dataset') st.markdown(""" **Dataset real** coletado pelo MIT Media Lab: - **203,769** transações Bitcoin reais - **234,355** arestas de fluxo de Bitcoin - **166 features** por transação - **49 timesteps** (jan 2017 — set 2018) - Labels: `ilícito` (lavagem) / `lícito` / `desconhecido` **Split temporal** — como no paper: - Treino: timesteps **1 – 34** - Teste: timesteps **35 – 49** *(evita data leakage temporal)* **Por que GraphSAGE?** > Transações novas chegam todo segundo. > GCN precisaria retreinar o grafo inteiro. > GraphSAGE aprende agregadores — prediz > para nós novos sem retreinar. """) if st.button('📥 Carregar Dataset', type='primary', use_container_width=True): with st.spinner('Baixando Elliptic Bitcoin Dataset via PyG...'): data, ok = carregar_elliptic(normalize=cfg['norm']) if ok is True: data, stats = preparar_splits(data) st.session_state.data = data st.session_state.stats = stats st.session_state.loaded = True st.session_state.trainers = {} st.session_state.resultados = {} st.success('✅ Dataset carregado!') else: st.error(f'Erro ao carregar: {ok}') with c2: if st.session_state.loaded and st.session_state.stats: s = st.session_state.stats m1,m2,m3,m4 = st.columns(4) for col, v, l in [ (m1, f"{s['n_nos']:,}", 'Nós'), (m2, f"{s['n_arestas']:,}", 'Arestas'), (m3, f"{s['n_features']}", 'Features'), (m4, f"{s['taxa_fraude_train']:.1%}", 'Taxa ilícito'), ]: col.markdown( f'
' f'
{v}
' f'
{l}
', unsafe_allow_html=True) st.markdown('
', unsafe_allow_html=True) c_a, c_b = st.columns(2) with c_a: st.markdown('#### Split Treino') st.markdown(f""" | | Count | |---|---| | Lícito | {s['n_licito_train']:,} | | Ilícito (fraude) | {s['n_ilicito_train']:,} | | **Total** | **{s['n_train']:,}** | """) with c_b: st.markdown('#### Split Teste') st.markdown(f""" | | Count | |---|---| | Lícito | {s['n_licito_test']:,} | | Ilícito (fraude) | {s['n_ilicito_test']:,} | | **Total** | **{s['n_test']:,}** | """) # Bar chart splits n_tr = s['n_train']; n_te = s['n_test'] total = n_tr + n_te w_tr = int(n_tr/total*400); w_te = 400-w_tr st.markdown(f""" Treino t1-34 Teste t35-49 """, unsafe_allow_html=True) else: st.info('Clique em **Carregar Dataset** para começar.') st.markdown(""" **O dataset será baixado automaticamente** via PyTorch Geometric (~50MB, hospedado pela comunidade PyG). """) # ── TAB 1: TREINAR ──────────────────────────────────────── with tabs[1]: if not st.session_state.loaded: st.warning('⬅️ Carregue o dataset primeiro.') else: data = st.session_state.data in_dim = data.x.shape[1] modelos_cfg = [] if cfg['treinar_sage']: modelos_cfg.append(('GraphSAGE', '#22c55e', 'model-sage')) if cfg['treinar_gcn']: modelos_cfg.append(('GCN', '#3b82f6', 'model-gcn')) if cfg['treinar_mlp']: modelos_cfg.append(('MLP', '#f59e0b', 'model-mlp')) # Cards dos modelos cols = st.columns(len(modelos_cfg)) for col, (nome, cor, cls) in zip(cols, modelos_cfg): col.markdown( f'
' f'
{nome}
' f'
' f'{"Inductive · SAGEConv · mini-batch" if nome=="GraphSAGE" else ("Transductive · GCNConv · full graph" if nome=="GCN" else "Sem grafo · MLP puro · baseline")}' f'
', unsafe_allow_html=True) if st.button('🚀 Treinar Todos', type='primary', use_container_width=True): batches_train = criar_mini_batches(data, batch_size=cfg['batch'], split='train') for nome, cor, _ in modelos_cfg: st.markdown(f'#### Treinando {nome}...') prog = st.progress(0) status = st.empty() if nome == 'GraphSAGE': model = GraphSAGE(in_dim, cfg['hidden'], 2, cfg['layers'], cfg['dropout']) elif nome == 'GCN': model = GCNBaseline(in_dim, cfg['hidden'], 2, cfg['layers'], cfg['dropout']) else: model = MLPBaseline(in_dim, cfg['hidden'], 2, cfg['dropout']) trainer = TrainerElliptic(model, data, lr=cfg['lr']) def make_cb(n, c, p, s): def cb(ep, total, loss, auc, f1): p.progress(ep/total) s.markdown( f'**{n}** · Época {ep}/{total} · ' f'Loss `{loss:.4f}` · AUC `{auc:.3f}` · F1 `{f1:.3f}`') return cb lotes = batches_train if nome == 'GraphSAGE' else None trainer.treinar(cfg['epocas'], batches=lotes, callback=make_cb(nome, cor, prog, status)) m = trainer.metricas_completas() st.session_state.trainers[nome] = trainer st.session_state.resultados[nome] = m st.success(f'✅ {nome} — AUC: {m["auc"]:.4f} · F1: {m["f1"]:.4f}') # ── TAB 2: BENCHMARK ────────────────────────────────────── with tabs[2]: res = st.session_state.resultados if not res: st.warning('⬅️ Treine os modelos primeiro.') else: st.markdown('### Comparação de Modelos — Dataset Real Elliptic') st.components.v1.html( f'
' f'{benchmark_html(res)}
', height=160) st.markdown('
', unsafe_allow_html=True) st.components.v1.html(roc_svg(res), height=220) st.markdown('
', unsafe_allow_html=True) cols = st.columns(len(res)) for col, (nome, m) in zip(cols, res.items()): col.markdown( f'
{cm_html(m["cm"], nome)}
', unsafe_allow_html=True) # t-SNE por modelo st.markdown('### Embeddings t-SNE') cols2 = st.columns(len(res)) for col, (nome, m) in zip(cols2, res.items()): with col: st.components.v1.html( tsne_svg(m['embeddings'], m['y_true'], nome), height=320) # Insight automático if len(res) >= 2: melhor = max(res.items(), key=lambda x: x[1]['auc']) pior = min(res.items(), key=lambda x: x[1]['auc']) ganho = melhor[1]['auc'] - pior[1]['auc'] st.info( f'**{melhor[0]}** superou **{pior[0]}** em ' f'`{ganho:.4f}` AUC — ' f'demonstrando o ganho da informação estrutural do grafo.') # ── TAB 3: INDUCTIVE DEMO ───────────────────────────────── with tabs[3]: st.markdown('### 🔍 GraphSAGE Inductive — Nó Novo') st.markdown(""" **O diferencial do GraphSAGE:** classifica transações novas que não existiam no treino, usando apenas seus vizinhos amostrados. GCN não consegue isso sem retreinar. """) if 'GraphSAGE' not in st.session_state.trainers: st.warning('⬅️ Treine o GraphSAGE primeiro.') else: trainer = st.session_state.trainers['GraphSAGE'] data = st.session_state.data st.markdown('#### Simule uma transação nova') c1, c2 = st.columns(2) with c1: # Seleciona nó do teste como "novo" idx_no = st.number_input( 'Índice do nó (do split teste — "novo" para o modelo)', 0, int(data.test_mask_labeled.sum())-1, 0) test_indices = data.test_mask_labeled.nonzero(as_tuple=True)[0] no_real = int(test_indices[idx_no]) label_real = int(data.y[no_real]) label_str = '🚨 Ilícita (lavagem)' if label_real == 0 else '✅ Lícita' if st.button('🔮 Predizer', type='primary', use_container_width=True): # Forward inductive: apenas features + vizinhos trainer.model.eval() with torch.no_grad(): logits, embeds = trainer.model( data.x, data.edge_index, return_embed=True) prob_ilicito = float( F.softmax(logits[no_real], dim=0)[0]) pred_str = '🚨 Ilícita' if prob_ilicito > 0.5 else '✅ Lícita' correto = (prob_ilicito > 0.5) == (label_real == 0) st.markdown(f'**Nó:** `{no_real}`') st.markdown(f'**Label real:** {label_str}') st.markdown(f'**Predição:** {pred_str}') st.markdown(f'**P(ilícita):** `{prob_ilicito:.4f}`') st.progress(float(prob_ilicito)) if correto: st.success('✅ Predição correta!') else: st.error('❌ Predição incorreta') with c2: st.markdown('#### Por que inductive importa em produção') st.markdown(""" ``` Modelo transductive (GCN): Nova tx chega → Adiciona ao grafo → Retreina tudo → Tempo: horas ❌ Inviável em produção Modelo inductive (GraphSAGE): Nova tx chega → Amostra vizinhos → Aplica agregadores → Predição: milissegundos ✅ Pronto para produção ``` """) # ── TAB 4: NEO4J ───────────────────────────────────────── with tabs[4]: st.header('🗄️ Neo4j') if not st.session_state.neo4j_ok: st.warning('Neo4j offline.') with st.expander('Como configurar'): st.markdown(""" **HF Spaces → Settings → Variables and secrets:** | Chave | Valor | |---|---| | `NEO4J_URI` | `neo4j+s://XXXXXXXX.databases.neo4j.io` | | `NEO4J_USERNAME` | `neo4j` | | `NEO4J_PASSWORD` | `sua_senha` | | `NEO4J_DATABASE` | `neo4j` | """) else: st.success('Conectado!') res = st.session_state.resultados if res and st.button('💾 Salvar benchmark no Neo4j'): driver, db = st.session_state.neo4j try: with driver.session(database=db) as s: ts = datetime.now().isoformat() for nome, m in res.items(): s.run(""" MERGE (r:EllipticRun {nome:$nome, ts:$ts}) SET r.auc=$auc, r.f1=$f1, r.ap=$ap, r.precision=$pr, r.recall=$rc """, nome=nome, ts=ts, auc=float(m['auc']), f1=float(m['f1']), ap=float(m['ap']), pr=float(m['precision']), rc=float(m['recall'])) st.success(f'✅ {len(res)} modelos salvos!') except Exception as e: st.error(str(e)) if __name__ == '__main__': main()