EllipticBitcoin / app.py
Danielfonseca1212's picture
Update app.py
2672a04 verified
# app.py — GraphSAGE Inductive | Elliptic Bitcoin Dataset Real
import streamlit as st
import numpy as np
import torch
import torch.nn.functional as F
import os
from datetime import datetime
# Page-level settings (title, favicon, wide layout, sidebar open on load).
# Streamlit has historically required this to be the first Streamlit command
# executed by the script, so it is kept at the very top.
st.set_page_config(
    layout="wide",
    page_title="GraphSAGE — Elliptic Bitcoin",
    initial_sidebar_state="expanded",
    page_icon="₿",
)
# Dark-theme stylesheet injected once into the main page (fonts, cards,
# metric labels, model cards, benchmark rows, progress-bar colour).
# HTML rendered through st.components.v1.html lives in an iframe and does
# not inherit these classes.
_CSS_GLOBAL = """
<style>
@import url('https://fonts.googleapis.com/css2?family=Outfit:wght@300;400;600;700;800&family=JetBrains+Mono:wght@400;600&display=swap');
html, body, [class*="css"] {
font-family: 'Outfit', sans-serif;
background: #030711; color: #e2e8f0;
}
h1,h2,h3 { font-weight: 800; }
code, pre { font-family: 'JetBrains Mono', monospace !important; }
.card { background:#0d1117; border:1px solid #1e2938; border-radius:12px; padding:18px; }
.metric-val { font-size:2rem; font-weight:800; font-family:'JetBrains Mono'; }
.metric-lbl { font-size:.68rem; color:#64748b; text-transform:uppercase; letter-spacing:2px; }
.model-card {
border-radius:12px; padding:16px; margin:6px 0;
border:1px solid transparent; transition:border-color .2s;
}
.model-sage { background:#0f1f14; border-color:#22c55e; }
.model-gcn { background:#0f1528; border-color:#3b82f6; }
.model-mlp { background:#1a0f0f; border-color:#f59e0b; }
.benchmark-row {
display:flex; align-items:center; gap:8px;
padding:10px 14px; border-radius:8px; margin:4px 0;
font-family:'JetBrains Mono',monospace; font-size:.82rem;
}
.best-row { background:#0d1f10; border:1px solid #22c55e33; }
.stProgress > div > div { background:linear-gradient(90deg,#22c55e,#16a34a) !important; }
.bitcoin-badge {
background:linear-gradient(90deg,#f7931a,#fbbf24);
color:#000; padding:2px 10px; border-radius:20px;
font-size:.75rem; font-weight:700;
}
</style>
"""
st.markdown(_CSS_GLOBAL, unsafe_allow_html=True)
# ── SESSION STATE ─────────────────────────────────────────────
# Seed every session key the app reads. Values that survived an earlier
# rerun are left untouched. The literal is re-evaluated on each script run,
# so the empty dicts are fresh objects, never shared defaults.
_ESTADO_INICIAL = {
    'data': None,
    'stats': None,
    'loaded': False,
    'trainers': {},
    'resultados': {},
    'treinando': None,
    'neo4j': None,
    'neo4j_ok': False,
}
for _chave in _ESTADO_INICIAL:
    if _chave not in st.session_state:
        st.session_state[_chave] = _ESTADO_INICIAL[_chave]
# ── NEO4J ─────────────────────────────────────────────────────
def get_neo4j_config():
    """Resolve the Neo4j connection settings.

    Lookup order:
      1. Flat Streamlit secrets ``NEO4J_URI`` / ``NEO4J_USERNAME`` /
         ``NEO4J_PASSWORD`` (plus optional ``NEO4J_DATABASE``).
      2. A ``[neo4j]`` section in Streamlit secrets with ``uri``,
         ``username``, ``password`` and ``database`` entries.
      3. Environment variables with the flat names. These are used whenever
         the secrets did not produce a non-empty URI.

    Any error while reading the secrets (e.g. no secrets file or a missing
    key) is deliberately ignored so the environment fallback is tried.

    Returns:
        dict with keys 'uri', 'username', 'password', 'database'. The
        database defaults to 'neo4j'; the other values may be empty strings
        when nothing is configured.
    """
    config = {}
    try:
        secrets = st.secrets
        if 'NEO4J_URI' in secrets:
            config = {
                'uri': secrets['NEO4J_URI'],
                'username': secrets['NEO4J_USERNAME'],
                'password': secrets['NEO4J_PASSWORD'],
                'database': secrets.get('NEO4J_DATABASE', 'neo4j'),
            }
        elif 'neo4j' in secrets:
            section = secrets['neo4j']
            config = {
                'uri': section.get('uri', ''),
                'username': section.get('username', ''),
                'password': section.get('password', ''),
                'database': section.get('database', 'neo4j'),
            }
    except Exception:
        # Best effort: fall through to the environment variables below.
        pass

    if config.get('uri'):
        return config
    return {
        'uri': os.getenv('NEO4J_URI', ''),
        'username': os.getenv('NEO4J_USERNAME', ''),
        'password': os.getenv('NEO4J_PASSWORD', ''),
        'database': os.getenv('NEO4J_DATABASE', 'neo4j'),
    }
@st.cache_resource
def conectar_neo4j():
    """Open a Neo4j driver and smoke-test it with ``RETURN 1``.

    Memoized by ``st.cache_resource``: the first result (including ``None``)
    is reused until the cache is cleared.

    Returns:
        ``(driver, database)`` when the settings from ``get_neo4j_config()``
        are complete and the test query succeeds. ``None`` otherwise, e.g.
        the ``neo4j`` package is missing, settings are incomplete, or the
        server is unreachable or rejects the credentials.
    """
    driver = None
    try:
        from neo4j import GraphDatabase
        cfg = get_neo4j_config()
        if not all([cfg['uri'], cfg['username'], cfg['password']]):
            return None
        driver = GraphDatabase.driver(cfg['uri'], auth=(cfg['username'], cfg['password']))
        with driver.session(database=cfg['database']) as s:
            # consume() forces the round trip so failures surface here.
            s.run('RETURN 1').consume()
        return driver, cfg['database']
    except Exception:
        # Best effort: the app works without Neo4j. A driver that was created
        # but failed the smoke test is closed so its connections are not leaked.
        if driver is not None:
            try:
                driver.close()
            except Exception:
                pass
        return None
@st.cache_resource
def carregar_libs():
    """Import the project's data/model helpers once per server process.

    The imports happen inside the function so that a missing or broken
    project module is reported in the UI rather than crashing the script.

    Returns:
        On success, the 7-tuple ``(carregar_elliptic, preparar_splits,
        criar_mini_batches, GraphSAGE, GCNBaseline, MLPBaseline,
        TrainerElliptic)``. On any exception, a 7-tuple whose first item is
        ``str(exception)`` and whose other six items are ``None``; callers
        detect this with ``isinstance(result[0], str)``.
    """
    try:
        from elliptic_data import carregar_elliptic, preparar_splits, criar_mini_batches
        from elliptic_model import GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic
    except Exception as falha:
        return (str(falha),) + (None,) * 6
    return (carregar_elliptic, preparar_splits, criar_mini_batches,
            GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic)
# ── CHARTS ────────────────────────────────────────────────────
def curves_svg(hist, title='', color='#22c55e'):
    """Render the loss and AUC training curves as an inline SVG card.

    Args:
        hist: mapping with optional sequences ``'loss_train'`` and
            ``'auc_val'`` (one value per recorded epoch).
        title: caption shown above the chart.
        color: stroke colour of the AUC curve.

    Returns:
        An HTML string, or ``''`` when there is no loss history.

    Each series is min-max scaled independently onto a 420x100 plot area and
    spread over the full width using its own length. An AUC history that is
    shorter than the loss history, or empty, therefore still renders
    (an empty series becomes an empty polyline).
    """
    loss = list(hist.get('loss_train', []))
    auc = list(hist.get('auc_val', []))
    if not loss:
        return ''

    def pts(vals, H=100):
        # Empty series: nothing to draw (min()/max() would raise).
        if not vals:
            return ''
        mn, mx = min(vals), max(vals)
        r = mx - mn or 1  # flat series: avoid dividing by zero
        n = len(vals)
        return ' '.join(f'{i*420/max(n-1,1):.1f},{H-(v-mn)/r*H:.1f}'
                        for i, v in enumerate(vals))

    # NOTE(review): the legend says "AUC Test" while the key is 'auc_val';
    # confirm which split the trainer records.
    return f"""<div class="card" style="margin-top:8px">
<div style="font-size:11px;color:#64748b;margin-bottom:4px">{title}
<span style="color:#ef4444;margin-left:8px">— Loss</span>
<span style="color:{color};margin-left:8px">— AUC Test</span>
</div>
<svg viewBox="0 0 435 110" style="width:100%">
<polyline points="{pts(loss)}" fill="none" stroke="#ef4444" stroke-width="1.8"/>
<polyline points="{pts(auc)}" fill="none" stroke="{color}" stroke-width="2"/>
<line x1="0" y1="100" x2="420" y2="100" stroke="#1e2938"/>
</svg></div>"""
def roc_svg(resultados):
    """Overlay the ROC curves of every trained model in one SVG card.

    Args:
        resultados: mapping model name -> metrics dict. Entries without a
            ``'y_true'`` key are skipped. The others must also hold
            ``'probs'``, the scores passed to ``sklearn.metrics.roc_curve``.

    Returns:
        HTML string. ROC space (FPR on x, TPR on y, both in [0, 1]) is mapped
        onto a 400x180 SVG area with the y axis flipped (SVG y grows
        downwards). Each model's AUC is listed in a legend column at x=410,
        one 16-px row per entry of ``resultados``; skipped entries keep
        their row.
    """
    from sklearn.metrics import roc_curve, auc as sk_auc
    CORES = {'GraphSAGE': '#22c55e', 'GCN': '#3b82f6', 'MLP': '#f59e0b'}
    partes = []
    # enumerate yields the same row index as list(keys).index(nome), in O(1).
    for linha, (nome, res) in enumerate(resultados.items()):
        if 'y_true' not in res:
            continue
        fpr, tpr, _ = roc_curve(res['y_true'], res['probs'])
        ra = sk_auc(fpr, tpr)
        pts = ' '.join(f'{f*400:.1f},{180-t*180:.1f}' for f, t in zip(fpr, tpr))
        cor = CORES.get(nome, '#888')
        partes.append(f'<polyline points="{pts}" fill="none" stroke="{cor}" stroke-width="2.5"/>'
                      f'<text x="410" y="{linha*16+30}" '
                      f'fill="{cor}" font-size="10">{nome} {ra:.3f}</text>')
    curvas = ''.join(partes)
    # Chance diagonal: ROC (0, 0) -> (1, 1), i.e. SVG (0, 180) -> (400, 0).
    return f"""<div class="card">
<div style="font-size:11px;color:#64748b;margin-bottom:4px">ROC — TODOS OS MODELOS</div>
<svg viewBox="0 0 520 195" style="width:100%">
<line x1="0" y1="180" x2="400" y2="0" stroke="#1e2938" stroke-dasharray="4"/>
{curvas}
<line x1="0" y1="180" x2="400" y2="180" stroke="#1e2938"/>
<line x1="0" y1="0" x2="0" y2="180" stroke="#1e2938"/>
</svg></div>"""
def benchmark_html(resultados):
    """Build the ranked benchmark rows (AUC/F1 bars plus precision/recall).

    Args:
        resultados: mapping model name -> metrics dict. Entries without an
            ``'auc'`` key are skipped. The others must provide ``'auc'``,
            ``'f1'``, ``'precision'`` and ``'recall'``; bar widths assume
            values in [0, 1].

    Returns:
        HTML string with one row per model, best AUC first. Every model tied
        for the best AUC gets a crown and the highlighted background.

    The row styling is inlined because this HTML is shown with
    ``st.components.v1.html``. That renders inside an iframe, which does not
    see the page-level CSS injected with ``st.markdown``. The
    ``benchmark-row``/``best-row`` classes are kept for any caller that
    renders it on the page itself.
    """
    CORES = {'GraphSAGE': '#22c55e', 'GCN': '#3b82f6', 'MLP': '#f59e0b'}
    # Inline copies of the .benchmark-row / .best-row rules from the page CSS.
    ROW_STYLE = ("display:flex;align-items:center;gap:8px;"
                 "padding:10px 14px;border-radius:8px;margin:4px 0;"
                 "font-family:'JetBrains Mono',monospace;font-size:.82rem;")
    BEST_STYLE = "background:#0d1f10;border:1px solid #22c55e33;"
    melhor_auc = max((r.get('auc', 0) for r in resultados.values()), default=0)
    linhas = []
    # Sort by AUC, best first; entries lacking 'auc' sort as 0 and are skipped.
    items = sorted(resultados.items(), key=lambda x: x[1].get('auc', 0), reverse=True)
    for nome, res in items:
        if 'auc' not in res:
            continue
        cor = CORES.get(nome, '#888')
        is_best = res['auc'] == melhor_auc
        cls = 'best-row' if is_best else ''
        estilo = ROW_STYLE + (BEST_STYLE if is_best else '')
        crown = '👑 ' if is_best else ''
        bar_auc = int(res['auc'] * 100)
        bar_f1 = int(res['f1'] * 100)
        linhas.append(f"""<div class="benchmark-row {cls}" style="{estilo}">
<span style="color:{cor};min-width:110px;font-weight:600">{crown}{nome}</span>
<span style="min-width:80px">
<div style="background:#1e2938;border-radius:3px;height:6px;width:80px">
<div style="width:{bar_auc}%;height:6px;background:{cor};border-radius:3px"></div>
</div>
<span style="font-size:.75rem;color:{cor}">AUC {res['auc']:.4f}</span>
</span>
<span style="min-width:80px">
<div style="background:#1e2938;border-radius:3px;height:6px;width:80px">
<div style="width:{bar_f1}%;height:6px;background:{cor}88;border-radius:3px"></div>
</div>
<span style="font-size:.75rem;color:{cor}88">F1 {res['f1']:.4f}</span>
</span>
<span style="color:#64748b;font-size:.75rem">
P:{res['precision']:.3f} R:{res['recall']:.3f}
</span>
</div>""")
    return ''.join(linhas)
def cm_html(cm, nome):
    """Render a binary confusion matrix as a 2x2 grid of coloured cards.

    Args:
        cm: confusion matrix laid out as ``[[TN, FP], [FN, TP]]`` (the
            order produced by flattening scikit-learn's binary matrix), with
            the positive class being the illicit one. Any array-like accepted
            by ``np.asarray`` works, e.g. a NumPy array or a nested list.
        nome: model name, used as the heading and to pick the accent colour.

    Returns:
        HTML string: a heading followed by cards TN, FP / FN, TP.

    Raises:
        ValueError: if ``cm`` does not contain exactly four counts.
    """
    CORES = {'GraphSAGE': '#22c55e', 'GCN': '#3b82f6', 'MLP': '#f59e0b'}
    cor = CORES.get(nome, '#888')
    valores = np.asarray(cm).ravel()
    if valores.size != 4:
        raise ValueError(f'cm_html expects a 2x2 confusion matrix, got {valores.size} values')
    tn, fp, fn, tp = valores
    # (colour, count, short label, description); errors get warning colours.
    items = [(cor, tn, 'TN', 'Lícitas\ncorretas'), ('#ef4444', fp, 'FP', 'Falsos\nalarmes'),
             ('#f59e0b', fn, 'FN', 'Ilícitas\nperdidas'), (cor, tp, 'TP', 'Ilícitas\ncapturadas')]
    html = f'<div style="font-size:.8rem;font-weight:700;color:{cor};margin-bottom:8px">{nome}</div>'
    html += '<div style="display:grid;grid-template-columns:1fr 1fr;gap:6px">'
    for c, v, a, d in items:
        html += (f'<div style="background:{c}15;border:1px solid {c}40;border-radius:8px;'
                 f'padding:10px;text-align:center">'
                 f'<div style="font-size:1.4rem;font-weight:800;color:{c}">{v}</div>'
                 f'<div style="color:{c};font-size:.8rem;font-weight:600">{a}</div>'
                 f'<div style="color:#64748b;font-size:.65rem;white-space:pre-line">{d}</div></div>')
    return html + '</div>'
def tsne_svg(embeddings, y_true, titulo='', seed=42):
    """Project embeddings to 2-D with t-SNE and draw them as an SVG scatter.

    Args:
        embeddings: per-sample embeddings, indexable with an integer index
            array and accepted by ``TSNE.fit_transform`` (presumably a
            ``(n_samples, dim)`` NumPy array; TODO confirm for tensors).
        y_true: labels aligned with ``embeddings`` and indexable the same way.
            Label 1 is drawn as illicit (red, r=5); anything else as licit
            (translucent green, r=3).
        titulo: caption appended to the "t-SNE" heading.
        seed: seeds both the subsample and t-SNE (default 42), so the picture
            is identical across Streamlit reruns.

    Returns:
        HTML string with the scatter plot, or a short ``<p>`` message
        containing the error text if anything fails (e.g. too few points or
        scikit-learn missing).

    At most 2000 points are drawn, sampled without replacement.
    """
    try:
        from sklearn.manifold import TSNE
        n = min(2000, len(embeddings))
        # Local seeded generator: reproducible subsample, and NumPy's global
        # random state is left untouched.
        rng = np.random.default_rng(seed)
        idx = rng.choice(len(embeddings), n, replace=False)
        emb = embeddings[idx]
        yt = y_true[idx]
        # scikit-learn requires 0 < perplexity < n_samples.
        tsne = TSNE(n_components=2, random_state=seed, perplexity=min(30, max(1, n // 3)))
        coords = tsne.fit_transform(emb)
        cx, cy = coords[:, 0], coords[:, 1]
        mn_x, mx_x = cx.min(), cx.max()
        mn_y, mx_y = cy.min(), cy.max()

        def sc(v, mn, mx, W):
            # Min-max scale onto [0, W]; the epsilon guards a zero range.
            return (v - mn) / (mx - mn + 1e-8) * W

        circles = ''.join(
            f'<circle cx="{sc(x,mn_x,mx_x,440):.1f}" cy="{sc(y,mn_y,mx_y,260):.1f}" '
            f'r="{5 if yt[i]==1 else 3}" '
            f'fill="{"#ef4444" if yt[i]==1 else "#22c55e44"}" opacity=".85"/>'
            for i, (x, y) in enumerate(zip(cx, cy))
        )
        return (f'<div class="card"><div style="font-size:11px;color:#64748b;margin-bottom:4px">'
                f't-SNE {titulo} <span style="color:#22c55e">● Lícita</span>'
                f'<span style="color:#ef4444;margin-left:8px">● Ilícita</span></div>'
                f'<svg viewBox="0 0 460 270" style="width:100%;background:#070d14;border-radius:6px">'
                f'{circles}</svg></div>')
    except Exception as e:
        return f'<p style="color:#64748b">t-SNE: {e}</p>'
# ── SIDEBAR ───────────────────────────────────────────────────
def sidebar():
    """Render the configuration sidebar and collect the user's choices.

    Widgets are drawn top to bottom:
      - the dataset normalisation toggle;
      - one checkbox per model;
      - the hyper-parameter sliders;
      - a Neo4j status badge driven by ``st.session_state.neo4j_ok``.

    Returns:
        dict with keys ``norm``, ``treinar_sage``, ``treinar_gcn``,
        ``treinar_mlp``, ``hidden``, ``layers``, ``lr``, ``epocas``,
        ``dropout`` and ``batch``.
    """
    painel = st.sidebar
    painel.markdown('## ₿ GraphSAGE Config')

    painel.markdown('### Dataset')
    escolhas = {'norm': painel.toggle('Normalizar features', True)}
    painel.markdown('---')

    painel.markdown('### Modelos a treinar')
    escolhas['treinar_sage'] = painel.checkbox('GraphSAGE', True)
    escolhas['treinar_gcn'] = painel.checkbox('GCN Baseline', True)
    escolhas['treinar_mlp'] = painel.checkbox('MLP Baseline', True)
    painel.markdown('---')

    painel.markdown('### Hiperparâmetros')
    escolhas['hidden'] = painel.select_slider('Hidden dim', [64, 128, 256], 128)
    escolhas['layers'] = painel.select_slider('Camadas GNN', [1, 2, 3], 2)
    escolhas['lr'] = painel.select_slider('LR', [0.0005, 0.001, 0.003], 0.001)
    escolhas['epocas'] = painel.slider('Épocas', 5, 50, 20, 5)
    escolhas['dropout'] = painel.slider('Dropout', 0.1, 0.5, 0.3, 0.05)
    escolhas['batch'] = painel.select_slider('Batch size', [256, 512, 1024], 512)

    if not st.session_state.neo4j_ok:
        painel.warning('⚠️ Neo4j Offline')
    else:
        painel.success('🗄️ Neo4j Conectado')
    return escolhas
# ── MAIN ──────────────────────────────────────────────────────
def _render_header():
    """Draw the gradient page title and the one-line subtitle."""
    st.markdown("""
<div style="margin-bottom:28px">
<h1 style="font-size:2.2rem;margin:0;
background:linear-gradient(90deg,#f7931a,#22c55e,#3b82f6);
-webkit-background-clip:text;-webkit-text-fill-color:transparent">
GraphSAGE — Elliptic Bitcoin
</h1>
<p style="color:#64748b;margin:3px 0 0 2px;font-size:.9rem">
Dataset real · 203k transações Bitcoin · GraphSAGE Inductive vs GCN vs MLP
</p>
</div>
""", unsafe_allow_html=True)


def _tab_dataset(cfg, carregar_elliptic, preparar_splits):
    """Dataset tab: description, the load button and split statistics.

    A successful load stores ``data``/``stats`` in session state. It also
    clears trainers and results computed on any previously loaded dataset.
    """
    c1, c2 = st.columns([1, 2])
    with c1:
        st.markdown('### Elliptic Bitcoin Dataset')
        st.markdown("""
**Dataset real** coletado pelo MIT Media Lab:
- **203,769** transações Bitcoin reais
- **234,355** arestas de fluxo de Bitcoin
- **166 features** por transação
- **49 timesteps** (jan 2017 — set 2018)
- Labels: `ilícito` (lavagem) / `lícito` / `desconhecido`
**Split temporal** — como no paper:
- Treino: timesteps **1 – 34**
- Teste: timesteps **35 – 49**
*(evita data leakage temporal)*
**Por que GraphSAGE?**
> Transações novas chegam todo segundo.
> GCN precisaria retreinar o grafo inteiro.
> GraphSAGE aprende agregadores — prediz
> para nós novos sem retreinar.
""")
        if st.button('📥 Carregar Dataset', type='primary', use_container_width=True):
            with st.spinner('Baixando Elliptic Bitcoin Dataset via PyG...'):
                data, ok = carregar_elliptic(normalize=cfg['norm'])
                # `ok` is True on success; otherwise it is shown as the error.
                if ok is True:
                    data, stats = preparar_splits(data)
                    st.session_state.data = data
                    st.session_state.stats = stats
                    st.session_state.loaded = True
                    st.session_state.trainers = {}
                    st.session_state.resultados = {}
                    st.success('✅ Dataset carregado!')
                else:
                    st.error(f'Erro ao carregar: {ok}')
    with c2:
        if st.session_state.loaded and st.session_state.stats:
            s = st.session_state.stats
            m1, m2, m3, m4 = st.columns(4)
            for col, valor, rotulo in [
                (m1, f"{s['n_nos']:,}", 'Nós'),
                (m2, f"{s['n_arestas']:,}", 'Arestas'),
                (m3, f"{s['n_features']}", 'Features'),
                (m4, f"{s['taxa_fraude_train']:.1%}", 'Taxa ilícito'),
            ]:
                col.markdown(
                    f'<div class="card" style="text-align:center">'
                    f'<div class="metric-val" style="color:#f7931a">{valor}</div>'
                    f'<div class="metric-lbl">{rotulo}</div></div>',
                    unsafe_allow_html=True)
            st.markdown('<br>', unsafe_allow_html=True)
            c_a, c_b = st.columns(2)
            with c_a:
                st.markdown('#### Split Treino')
                st.markdown(f"""
| | Count |
|---|---|
| Lícito | {s['n_licito_train']:,} |
| Ilícito (fraude) | {s['n_ilicito_train']:,} |
| **Total** | **{s['n_train']:,}** |
""")
            with c_b:
                st.markdown('#### Split Teste')
                st.markdown(f"""
| | Count |
|---|---|
| Lícito | {s['n_licito_test']:,} |
| Ilícito (fraude) | {s['n_ilicito_test']:,} |
| **Total** | **{s['n_test']:,}** |
""")
            # Proportional train/test bar, 400 SVG units wide in total.
            n_tr = s['n_train']
            n_te = s['n_test']
            total = max(n_tr + n_te, 1)  # guard an empty split against ZeroDivisionError
            w_tr = int(n_tr / total * 400)
            w_te = 400 - w_tr
            st.markdown(f"""
<svg viewBox="0 0 420 40" style="width:100%;margin-top:8px">
<rect x="0" y="8" width="{w_tr}" height="20" fill="#22c55e" rx="4"/>
<rect x="{w_tr+2}" y="8" width="{w_te}" height="20" fill="#3b82f6" rx="4"/>
<text x="{w_tr//2}" y="22" text-anchor="middle" fill="#000" font-size="11" font-weight="bold">Treino t1-34</text>
<text x="{w_tr+2+w_te//2}" y="22" text-anchor="middle" fill="#fff" font-size="11" font-weight="bold">Teste t35-49</text>
</svg>
""", unsafe_allow_html=True)
        else:
            st.info('Clique em **Carregar Dataset** para começar.')
            st.markdown("""
**O dataset será baixado automaticamente** via PyTorch Geometric
(~50MB, hospedado pela comunidade PyG).
""")


def _progress_callback(nome, barra, status):
    """Build the per-epoch callback passed to ``TrainerElliptic.treinar``.

    The callback takes ``(ep, total, loss, auc, f1)`` and updates the given
    progress bar and status placeholder. A factory is used so that each
    model's callback binds its own widgets instead of late-binding loop
    variables.
    """
    def cb(ep, total, loss, auc, f1):
        # st.progress rejects floats outside [0, 1]; clamp defensively.
        barra.progress(min(max(ep / total, 0.0), 1.0))
        status.markdown(
            f'**{nome}** · Época {ep}/{total} · '
            f'Loss `{loss:.4f}` · AUC `{auc:.3f}` · F1 `{f1:.3f}`')
    return cb


def _tab_treinar(cfg, criar_mini_batches, GraphSAGE, GCNBaseline,
                 MLPBaseline, TrainerElliptic):
    """Training tab: one card per selected model and the "train all" button.

    Each trained trainer and its ``metricas_completas()`` result are stored
    by model name in ``st.session_state.trainers`` and
    ``st.session_state.resultados``.
    """
    if not st.session_state.loaded:
        st.warning('⬅️ Carregue o dataset primeiro.')
        return
    data = st.session_state.data
    in_dim = data.x.shape[1]
    # (name, accent colour, CSS class) for every model ticked in the sidebar.
    modelos_cfg = []
    if cfg['treinar_sage']:
        modelos_cfg.append(('GraphSAGE', '#22c55e', 'model-sage'))
    if cfg['treinar_gcn']:
        modelos_cfg.append(('GCN', '#3b82f6', 'model-gcn'))
    if cfg['treinar_mlp']:
        modelos_cfg.append(('MLP', '#f59e0b', 'model-mlp'))
    if not modelos_cfg:
        # st.columns(0) raises, so explain instead of crashing the tab.
        st.warning('Selecione ao menos um modelo na barra lateral.')
        return
    descricoes = {
        'GraphSAGE': 'Inductive · SAGEConv · mini-batch',
        'GCN': 'Transductive · GCNConv · full graph',
        'MLP': 'Sem grafo · MLP puro · baseline',
    }
    cols = st.columns(len(modelos_cfg))
    for col, (nome, cor, cls) in zip(cols, modelos_cfg):
        col.markdown(
            f'<div class="model-card {cls}">'
            f'<div style="color:{cor};font-weight:700;font-size:1rem">{nome}</div>'
            f'<div style="color:#64748b;font-size:.78rem;margin-top:6px">'
            f'{descricoes[nome]}'
            f'</div></div>',
            unsafe_allow_html=True)
    if not st.button('🚀 Treinar Todos', type='primary', use_container_width=True):
        return
    batches_train = criar_mini_batches(data, batch_size=cfg['batch'], split='train')
    for nome, _cor, _cls in modelos_cfg:
        st.markdown(f'#### Treinando {nome}...')
        prog = st.progress(0)
        status = st.empty()
        # Two output classes (licit / illicit).
        if nome == 'GraphSAGE':
            model = GraphSAGE(in_dim, cfg['hidden'], 2, cfg['layers'], cfg['dropout'])
        elif nome == 'GCN':
            model = GCNBaseline(in_dim, cfg['hidden'], 2, cfg['layers'], cfg['dropout'])
        else:
            model = MLPBaseline(in_dim, cfg['hidden'], 2, cfg['dropout'])
        trainer = TrainerElliptic(model, data, lr=cfg['lr'])
        # Only GraphSAGE trains on mini-batches; the other models get None.
        lotes = batches_train if nome == 'GraphSAGE' else None
        trainer.treinar(cfg['epocas'], batches=lotes,
                        callback=_progress_callback(nome, prog, status))
        m = trainer.metricas_completas()
        st.session_state.trainers[nome] = trainer
        st.session_state.resultados[nome] = m
        st.success(f'✅ {nome} — AUC: {m["auc"]:.4f} · F1: {m["f1"]:.4f}')


def _tab_benchmark():
    """Benchmark tab: ranking, ROC overlay, confusion matrices, t-SNE, insight."""
    res = st.session_state.resultados
    if not res:
        st.warning('⬅️ Treine os modelos primeiro.')
        return
    st.markdown('### Comparação de Modelos — Dataset Real Elliptic')
    st.components.v1.html(
        f'<div style="background:#030711;padding:8px;border-radius:10px">'
        f'{benchmark_html(res)}</div>', height=160)
    st.markdown('<br>', unsafe_allow_html=True)
    st.components.v1.html(roc_svg(res), height=220)
    st.markdown('<br>', unsafe_allow_html=True)
    cols = st.columns(len(res))
    for col, (nome, m) in zip(cols, res.items()):
        col.markdown(
            f'<div class="card">{cm_html(m["cm"], nome)}</div>',
            unsafe_allow_html=True)
    # One t-SNE projection of the learned embeddings per model.
    st.markdown('### Embeddings t-SNE')
    cols2 = st.columns(len(res))
    for col, (nome, m) in zip(cols2, res.items()):
        with col:
            st.components.v1.html(
                tsne_svg(m['embeddings'], m['y_true'], nome),
                height=320)
    # Automatic insight: compare the best and the worst model by AUC.
    if len(res) >= 2:
        melhor = max(res.items(), key=lambda x: x[1]['auc'])
        pior = min(res.items(), key=lambda x: x[1]['auc'])
        ganho = melhor[1]['auc'] - pior[1]['auc']
        if ganho > 0:
            msg = (f'**{melhor[0]}** superou **{pior[0]}** em '
                   f'`{ganho:.4f}` AUC')
            # Only a graph model beating the graph-free MLP supports the claim
            # that graph structure made the difference.
            if melhor[0] in ('GraphSAGE', 'GCN') and pior[0] == 'MLP':
                msg += ' — demonstrando o ganho da informação estrutural do grafo.'
            else:
                msg += '.'
            st.info(msg)


def _tab_inductive():
    """Inductive-demo tab: score one labelled test node with GraphSAGE."""
    # Class index of "illicit". The rest of this file treats label 1 as
    # illicit: tsne_svg draws y == 1 as 'Ilícita', and cm_html calls TP
    # 'Ilícitas capturadas'. PyG's EllipticBitcoinDataset also maps
    # 0 = licit, 1 = illicit, 2 = unknown.
    # NOTE(review): confirm against elliptic_data.carregar_elliptic.
    CLASSE_ILICITA = 1
    st.markdown('### 🔍 GraphSAGE Inductive — Nó Novo')
    st.markdown("""
**O diferencial do GraphSAGE:** classifica transações novas
que não existiam no treino, usando apenas seus vizinhos amostrados.
GCN não consegue isso sem retreinar.
""")
    if 'GraphSAGE' not in st.session_state.trainers:
        st.warning('⬅️ Treine o GraphSAGE primeiro.')
        return
    trainer = st.session_state.trainers['GraphSAGE']
    data = st.session_state.data
    st.markdown('#### Simule uma transação nova')
    c1, c2 = st.columns(2)
    with c1:
        # Labelled test nodes stand in for "new" transactions.
        test_indices = data.test_mask_labeled.nonzero(as_tuple=True)[0]
        n_teste = int(test_indices.numel())
        if n_teste == 0:
            # number_input would receive max_value < min_value and raise.
            st.warning('Nenhum nó rotulado no split de teste.')
        else:
            idx_no = st.number_input(
                'Índice do nó (do split teste — "novo" para o modelo)',
                0, n_teste - 1, 0)
            no_real = int(test_indices[idx_no])
            label_real = int(data.y[no_real])
            eh_ilicita = label_real == CLASSE_ILICITA
            label_str = '🚨 Ilícita (lavagem)' if eh_ilicita else '✅ Lícita'
            if st.button('🔮 Predizer', type='primary', use_container_width=True):
                # Forward pass over the graph (features + edges) in eval mode;
                # only this node's logits are read.
                trainer.model.eval()
                with torch.no_grad():
                    logits, _embeds = trainer.model(
                        data.x, data.edge_index, return_embed=True)
                prob_ilicito = float(
                    F.softmax(logits[no_real], dim=0)[CLASSE_ILICITA])
                pred_ilicita = prob_ilicito > 0.5
                pred_str = '🚨 Ilícita' if pred_ilicita else '✅ Lícita'
                correto = pred_ilicita == eh_ilicita
                st.markdown(f'**Nó:** `{no_real}`')
                st.markdown(f'**Label real:** {label_str}')
                st.markdown(f'**Predição:** {pred_str}')
                st.markdown(f'**P(ilícita):** `{prob_ilicito:.4f}`')
                st.progress(min(max(prob_ilicito, 0.0), 1.0))
                if correto:
                    st.success('✅ Predição correta!')
                else:
                    st.error('❌ Predição incorreta')
    with c2:
        st.markdown('#### Por que inductive importa em produção')
        st.markdown("""
```
Modelo transductive (GCN):
Nova tx chega
→ Adiciona ao grafo
→ Retreina tudo
→ Tempo: horas
❌ Inviável em produção
Modelo inductive (GraphSAGE):
Nova tx chega
→ Amostra vizinhos
→ Aplica agregadores
→ Predição: milissegundos
✅ Pronto para produção
```
""")


def _tab_neo4j():
    """Neo4j tab: setup help when offline, otherwise save benchmark results."""
    st.header('🗄️ Neo4j')
    if not st.session_state.neo4j_ok:
        st.warning('Neo4j offline.')
        with st.expander('Como configurar'):
            st.markdown("""
**HF Spaces → Settings → Variables and secrets:**
| Chave | Valor |
|---|---|
| `NEO4J_URI` | `neo4j+s://XXXXXXXX.databases.neo4j.io` |
| `NEO4J_USERNAME` | `neo4j` |
| `NEO4J_PASSWORD` | `sua_senha` |
| `NEO4J_DATABASE` | `neo4j` |
""")
        return
    st.success('Conectado!')
    res = st.session_state.resultados
    if res and st.button('💾 Salvar benchmark no Neo4j'):
        driver, db = st.session_state.neo4j
        try:
            with driver.session(database=db) as s:
                # One timestamp shared by every model saved in this click.
                ts = datetime.now().isoformat()
                for nome, m in res.items():
                    s.run("""
MERGE (r:EllipticRun {nome:$nome, ts:$ts})
SET r.auc=$auc, r.f1=$f1, r.ap=$ap,
r.precision=$pr, r.recall=$rc
""", nome=nome, ts=ts,
                          auc=float(m['auc']), f1=float(m['f1']),
                          ap=float(m['ap']),
                          pr=float(m['precision']),
                          rc=float(m['recall']))
            st.success(f'✅ {len(res)} modelos salvos!')
        except Exception as e:
            st.error(str(e))


def main():
    """Streamlit entry point: Neo4j status, sidebar, project imports, tabs."""
    # Retried while this session has no connection; conectar_neo4j() is
    # cached by st.cache_resource, so repeated calls reuse its first result.
    if st.session_state.neo4j is None:
        conn = conectar_neo4j()
        st.session_state.neo4j = conn
        st.session_state.neo4j_ok = conn is not None
    cfg = sidebar()
    res_libs = carregar_libs()
    # carregar_libs() signals an import failure by putting the error text first.
    if isinstance(res_libs[0], str):
        st.error(f'Erro de importação: {res_libs[0]}')
        st.stop()
    (carregar_elliptic, preparar_splits, criar_mini_batches,
     GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic) = res_libs
    _render_header()
    tabs = st.tabs(['₿ Dataset', '🧠 Treinar', '📊 Benchmark', '🔍 Inductive Demo', '🗄️ Neo4j'])
    with tabs[0]:
        _tab_dataset(cfg, carregar_elliptic, preparar_splits)
    with tabs[1]:
        _tab_treinar(cfg, criar_mini_batches, GraphSAGE, GCNBaseline,
                     MLPBaseline, TrainerElliptic)
    with tabs[2]:
        _tab_benchmark()
    with tabs[3]:
        _tab_inductive()
    with tabs[4]:
        _tab_neo4j()
# Streamlit runs the script as the "__main__" module, so this guard also
# starts the app under `streamlit run`.
if __name__ == "__main__":
    main()