# app.py — GraphSAGE Inductive | Elliptic Bitcoin Dataset Real
import streamlit as st
import numpy as np
import torch
import torch.nn.functional as F
import os
from datetime import datetime
# Browser-tab title and icon, full-width layout, sidebar expanded on first load.
_PAGE_CONFIG = {
    "page_title": "GraphSAGE — Elliptic Bitcoin",
    "page_icon": "₿",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**_PAGE_CONFIG)
# Inject raw HTML/CSS into the page (unsafe_allow_html lets markup through
# instead of escaping it).
# NOTE(review): the literal below is effectively empty ("\n") — the custom
# stylesheet appears to be missing; restore it or drop this call.
st.markdown("""
""", unsafe_allow_html=True)
# ── SESSION STATE ─────────────────────────────────────────────
# Per-session defaults. Keys already present in st.session_state are left
# untouched, so values stored by earlier runs of the script are preserved.
_SESSION_DEFAULTS = {
    'data': None,
    'stats': None,
    'loaded': False,
    'trainers': {},
    'resultados': {},
    'treinando': None,
    'neo4j': None,
    'neo4j_ok': False,
}
for k, v in _SESSION_DEFAULTS.items():
    if k in st.session_state:
        continue
    st.session_state[k] = v
# ── NEO4J ─────────────────────────────────────────────────────
def get_neo4j_config():
    """Resolve Neo4j connection settings.

    Lookup order:
      1. Flat Streamlit secrets: ``NEO4J_URI``, ``NEO4J_USERNAME``,
         ``NEO4J_PASSWORD`` and optional ``NEO4J_DATABASE``.
      2. A ``neo4j`` secrets section with ``uri``/``username``/``password``/
         ``database`` keys (missing ones default to ``''``).
      3. Environment variables named as in (1).

    Any exception raised while reading Streamlit secrets is swallowed and the
    lookup falls through. The environment is used whenever the secrets did
    not yield a non-empty ``uri``.

    Returns:
        dict with keys ``uri``, ``username``, ``password`` and ``database``.
        ``database`` defaults to ``'neo4j'``.
    """
    config = {}
    try:
        secrets = st.secrets
        if 'NEO4J_URI' in secrets:
            config = {
                'uri': secrets['NEO4J_URI'],
                'username': secrets['NEO4J_USERNAME'],
                'password': secrets['NEO4J_PASSWORD'],
                'database': secrets.get('NEO4J_DATABASE', 'neo4j'),
            }
        elif 'neo4j' in secrets:
            section = secrets['neo4j']
            config = {
                'uri': section.get('uri', ''),
                'username': section.get('username', ''),
                'password': section.get('password', ''),
                'database': section.get('database', 'neo4j'),
            }
    except Exception:
        # Best effort: unreadable or incomplete secrets simply mean
        # "try the environment instead".
        pass
    if config.get('uri'):
        return config
    return {
        'uri': os.getenv('NEO4J_URI', ''),
        'username': os.getenv('NEO4J_USERNAME', ''),
        'password': os.getenv('NEO4J_PASSWORD', ''),
        'database': os.getenv('NEO4J_DATABASE', 'neo4j'),
    }
@st.cache_resource
def conectar_neo4j():
    """Create a Neo4j driver from ``get_neo4j_config()`` and smoke-test it.

    Cached with ``st.cache_resource`` so the driver is reused rather than
    recreated on every script rerun.

    Returns:
        ``(driver, database)`` when all of the following hold:
          - the ``neo4j`` package imports;
          - uri, username and password are all non-empty;
          - a ``RETURN 1`` query succeeds.
        Otherwise ``None``: incomplete config, missing package, or any
        error while connecting.

    NOTE(review): ``st.cache_resource`` presumably caches the ``None``
    failure result too, so a transient outage sticks until the cache is
    cleared (e.g. ``conectar_neo4j.clear()``) — confirm this is intended.
    """
    driver = None
    try:
        # Imported lazily: an ImportError lands in the except below, so the
        # app still runs without the neo4j package.
        from neo4j import GraphDatabase
        cfg = get_neo4j_config()
        if not all([cfg['uri'], cfg['username'], cfg['password']]):
            return None
        driver = GraphDatabase.driver(cfg['uri'], auth=(cfg['username'], cfg['password']))
        # Smoke test. consume() fully executes the query, so connection or
        # auth problems are raised here rather than later.
        with driver.session(database=cfg['database']) as s:
            s.run('RETURN 1').consume()
        return driver, cfg['database']
    except Exception:
        # Best effort: callers get None for "Neo4j unavailable". A driver
        # that was created but failed the smoke test must be closed,
        # otherwise its connection pool is leaked.
        if driver is not None:
            try:
                driver.close()
            except Exception:
                pass
        return None
@st.cache_resource
def carregar_libs():
    """Import the project's data and model helpers once (cached).

    Returns:
        On success, the 7-tuple ``(carregar_elliptic, preparar_splits,
        criar_mini_batches, GraphSAGE, GCNBaseline, MLPBaseline,
        TrainerElliptic)``.
        If either import raises, the first slot holds the error message
        (``str``) and the remaining six are ``None``. Failure is therefore
        presumably detected by callers via the first element's type.
    """
    try:
        from elliptic_data import carregar_elliptic, preparar_splits, criar_mini_batches
        from elliptic_model import GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic
    except Exception as err:
        # Same arity as the success tuple, so callers can always unpack 7 values.
        return (str(err),) + (None,) * 6
    return (
        carregar_elliptic, preparar_splits, criar_mini_batches,
        GraphSAGE, GCNBaseline, MLPBaseline, TrainerElliptic,
    )
# ── CHARTS ────────────────────────────────────────────────────
def curves_svg(hist, title='', color='#22c55e'):
    """Render a training-curve snippet for one model's history.

    Args:
        hist: mapping read via ``hist.get`` for the ``'loss_train'`` and
            ``'auc_val'`` sequences (missing keys default to empty lists).
        title: heading text interpolated into the returned markup.
        color: accent colour. NOTE(review): not referenced in the template
            as it appears here — confirm whether the markup should use it.

    Returns:
        str: the rendered snippet, or ``''`` when ``'loss_train'`` is empty.
    """
    loss = hist.get('loss_train', [])
    auc = hist.get('auc_val', [])
    ep = len(loss)  # length of the loss series; sets the x-axis point spacing
    if ep == 0: return ''
    def pts(vals, H=100):
        # Map a numeric series to space-separated "x,y" pairs (presumably for
        # an SVG polyline):
        #   - x spans 0..420 evenly over ep points;
        #   - y is min-max normalised into 0..H and inverted, so the series
        #     maximum lands at y=0 and the minimum at y=H.
        # NOTE(review): the spacing always uses len(loss) even if `vals` is a
        # different series (e.g. auc) — lengths presumably match; confirm.
        # min()/max() raise ValueError if `vals` is empty.
        mn,mx = min(vals),max(vals); r=mx-mn or 1  # r=1 avoids dividing by zero on a flat series
        return ' '.join(f'{i*420/max(ep-1,1):.1f},{H-(v-mn)/r*H:.1f}'
                        for i,v in enumerate(vals))
    # NOTE(review): `auc` and `pts` are not referenced by the template as it
    # appears here. The legend also says "AUC Test" while the data key is
    # 'auc_val'. The SVG markup looks incomplete; verify against the original.
    return f"""
{title}
— Loss— AUC Test
"""
def roc_svg(resultados):
    """ROC curves for all models together.

    Args:
        resultados: mapping model name -> result dict. Entries without
            ``'y_true'`` are skipped. The rest must also provide ``'probs'``,
            the scores passed to ``sklearn.metrics.roc_curve``.

    Returns:
        str: the rendered snippet.
    """
    from sklearn.metrics import roc_curve, auc as sk_auc  # imported lazily, inside the function
    CORES = {'GraphSAGE':'#22c55e', 'GCN':'#3b82f6', 'MLP':'#f59e0b'}  # per-model colours; other names fall back to '#888'
    curvas = ''
    for nome, res in resultados.items():
        if 'y_true' not in res: continue
        fpr,tpr,_ = roc_curve(res['y_true'], res['probs'])
        ra = sk_auc(fpr,tpr)  # area under this model's ROC curve
        # "x,y" pairs: FPR scaled to 0..400 on x; TPR scaled to 0..180 on y
        # and inverted (TPR=1 -> y=0).
        pts = ' '.join(f'{f*400:.1f},{180-t*180:.1f}' for f,t in zip(fpr,tpr))
        cor = CORES.get(nome,'#888')
        curvas += (f''
                   f'{nome} {ra:.3f}')
    # NOTE(review): neither `curvas`, `pts` nor `cor` is referenced by the
    # template below as it appears here, so the computed curves are discarded.
    # The SVG markup looks incomplete; verify against the original.
    return f"""
ROC — TODOS OS MODELOS
"""
def benchmark_html(resultados):
    """Render one benchmark row per model, best AUC first.

    Args:
        resultados: mapping model name -> metrics dict. Entries without
            ``'auc'`` are skipped. Entries with it must also carry ``'f1'``,
            ``'precision'`` and ``'recall'``; these are indexed directly, so
            a missing key raises KeyError.

    Returns:
        str: concatenated row markup (``''`` when no entry has ``'auc'``).
    """
    CORES = {'GraphSAGE':'#22c55e', 'GCN':'#3b82f6', 'MLP':'#f59e0b'}
    # Highest AUC across all entries (0 when empty or no entry has 'auc').
    melhor_auc = max((r.get('auc',0) for r in resultados.values()), default=0)
    html = ''
    # Sort by AUC, highest first (entries without 'auc' count as 0).
    items = sorted(resultados.items(), key=lambda x: x[1].get('auc',0), reverse=True)
    for nome, res in items:
        if 'auc' not in res: continue
        cor = CORES.get(nome, '#888')
        # Exact equality against the maximum: every model tying the top AUC is marked best.
        is_best = res['auc'] == melhor_auc
        cls = 'best-row' if is_best else ''
        crown = '👑 ' if is_best else ''
        # Scores scaled to integer percentages, presumably for bar widths.
        bar_auc = int(res['auc']*100)
        bar_f1 = int(res['f1']*100)
        # NOTE(review): `cor`, `cls`, `bar_auc` and `bar_f1` are not referenced
        # by the template as it appears here. The row markup looks incomplete;
        # verify against the original.
        html += f"""
{crown}{nome}
AUC {res['auc']:.4f}
F1 {res['f1']:.4f}
P:{res['precision']:.3f} R:{res['recall']:.3f}
"""
    return html
def cm_html(cm, nome):
CORES = {'GraphSAGE':'#22c55e', 'GCN':'#3b82f6', 'MLP':'#f59e0b'}
cor = CORES.get(nome,'#888')
tn,fp,fn,tp = cm.ravel()
items = [(cor,tn,'TN','Lícitas\ncorretas'),('#ef4444',fp,'FP','Falsos\nalarmes'),
('#f59e0b',fn,'FN','Ilícitas\nperdidas'),(cor,tp,'TP','Ilícitas\ncapturadas')]
html = f'