Spaces:
Sleeping
Sleeping
File size: 5,369 Bytes
db886e4 12a958b db886e4 12a958b db886e4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | import re
import traceback
from typing import Tuple, List, Dict, Any
import gradio as gr
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.utils import coalesce
from config import AppConfig
from rate_limit import GlobalRateLimiter, RateLimitExceeded
from explorers import new_client, fetch_with_fallback
from graph_builder import expand_ego
from features import build_features, scale_features
from inference import load_models, run_for_both_models
from viz import render_ego_html, histogram_scores
# Application configuration; supplies the rate-limit settings and the queue
# concurrency used below. Instantiated once at import time.
CFG = AppConfig()
# Process-wide rate limiter applied to handle_run via @RATE.enforce().
# Presumably allows MAX_CALLS_PER_MIN calls per WINDOW_SECONDS window — confirm in rate_limit.
RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)
# Model bundles, populated lazily by _load_models_once() on first use.
_models_cache: Any = None
_TXID_PATTERN = re.compile(r"[0-9a-fA-F]{64}")


def _is_valid_txid(tx: str) -> bool:
    """Return True when *tx* consists of exactly 64 hexadecimal characters.

    ``None`` and the empty string are treated as invalid. The match covers the
    whole string, so surrounding whitespace or a trailing newline is rejected.
    """
    candidate = tx if tx else ""
    return _TXID_PATTERN.fullmatch(candidate) is not None
def _load_models_once():
    """Return the model bundles, calling ``load_models(CFG)`` only on first use.

    The loaded value is memoised in the module-level ``_models_cache`` so
    subsequent requests reuse it instead of reloading.
    """
    global _models_cache
    if _models_cache is not None:
        return _models_cache
    # NOTE(review): no lock here, so two concurrent first requests may both
    # call load_models(); the last assignment wins. Harmless but wasteful.
    _models_cache = load_models(CFG)
    return _models_cache
# handle_run feeds six output components, in this order:
# results table, GAT ego-graph HTML, GATv2 ego-graph HTML, logs,
# GAT score histogram, GATv2 score histogram.
_NUM_OUTPUTS = 6
_LOGS_OUTPUT_INDEX = 3


def _error_outputs(message: str, logs=()) -> Tuple[Any, ...]:
    """Return a full output tuple showing *message* (then *logs*) in the Logs box.

    Gradio maps returned values to output components by position, so every
    exit path of ``handle_run`` must return exactly ``_NUM_OUTPUTS`` values.
    All non-log slots are cleared with ``None``.
    """
    text = message
    if logs:
        text = message + "\n\n" + "\n".join(logs)
    outputs: List[Any] = [None] * _NUM_OUTPUTS
    outputs[_LOGS_OUTPUT_INDEX] = text
    return tuple(outputs)


def _build_graph_data(features, edges) -> Data:
    """Pack node features and an edge list into a PyG ``Data`` object."""
    if len(edges) > 0:
        # Assumes edges is a sequence of (src, dst) index pairs, so the
        # transpose yields the [2, E] layout PyG expects — TODO confirm.
        edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    # Sort edge_index and drop duplicate edges.
    edge_index = coalesce(edge_index)
    return Data(
        x=torch.tensor(features, dtype=torch.float32),
        edge_index=edge_index,
    )


def _results_table(results, tx_hash, k, center_idx, num_nodes, num_edges) -> pd.DataFrame:
    """Build one row per model with the center node's score, threshold and label."""
    records = [
        {
            "tx_hash": tx_hash,
            "model_name": name,
            "probability": float(probs[center_idx]),
            "threshold": float(thr),
            "pred_label": int(label),
            "k_used": int(k),
            "num_nodes": int(num_nodes),
            "num_edges": int(num_edges),
            "note": note,
        }
        for name, probs, thr, label, note in results
    ]
    return pd.DataFrame(records)


@RATE.enforce()
def handle_run(tx_hash: str, k: int, provider: str):
    """Score a Bitcoin transaction with both models (GAT and GATv2).

    Validates the inputs and builds a k-step ego-subgraph around *tx_hash*
    with ``expand_ego`` using *provider*. It then builds and scales node
    features, runs both models on the resulting PyG graph, and renders the
    results.

    Args:
        tx_hash: Transaction id; must be 64 hex characters. Surrounding
            whitespace is ignored.
        k: Neighbourhood depth; must convert to an int in {1, 2, 3}.
        provider: Data-source name forwarded to ``expand_ego``.

    Returns:
        A 6-tuple ``(results_df, html_gat, html_gatv2, log_text,
        fig_hist_gat, fig_hist_gatv2)`` matching the Run button's outputs.
        On any failure every slot except ``log_text`` is ``None`` and
        ``log_text`` carries the error message.
    """
    logs: List[str] = []
    try:
        # Tolerate pasted hashes with stray whitespace around them.
        tx_hash = (tx_hash or "").strip()
        if not _is_valid_txid(tx_hash):
            return _error_outputs("❌ Invalid txid. Please enter a 64-hex transaction hash.")

        try:
            k = int(k)
        except (TypeError, ValueError):
            return _error_outputs("❌ k must be in {1,2,3}.")
        if not 1 <= k <= 3:
            return _error_outputs("❌ k must be in {1,2,3}.")

        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(tx_hash, k, provider, CFG)
        logs.extend(gb_logs or [])
        if center_idx < 0 or len(nodes) == 0:
            return _error_outputs("❌ Failed to build subgraph (see logs).", logs)
        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")

        # Build and scale node features (a scaler from the model repo could be injected here).
        X, _feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, _scaler_used, scale_note = scale_features(X, scaler=None)
        if scale_note:
            logs.append(scale_note)

        data = _build_graph_data(Xs, edges)
        bundles = _load_models_once()
        results = run_for_both_models(bundles, data, center_idx, CFG)
        # The visual outputs below index the first two models explicitly.
        if len(results) < 2:
            return _error_outputs(
                f"❌ Expected results for 2 models (GAT, GATv2), got {len(results)}.", logs
            )

        df = _results_table(results, tx_hash, k, center_idx, len(nodes), len(edges))
        gat_scores = results[0][1]
        gatv2_scores = results[1][1]

        html_gat = render_ego_html(nodes, edges, center_idx, scores=gat_scores)
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=gatv2_scores)
        fig_hist_gat = histogram_scores(gat_scores, title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(gatv2_scores, title="Scores (GATv2)")

        return df, html_gat, html_gatv2, "\n".join(logs), fig_hist_gat, fig_hist_v2
    except RateLimitExceeded as e:
        # NOTE(review): @RATE.enforce() wraps this function, so a limit hit
        # raised by that wrapper presumably never reaches this handler; only a
        # RateLimitExceeded raised from inside the body is caught here.
        return _error_outputs(f"❌ {e}")
    except Exception as e:
        # UI boundary: show the error and traceback in the Logs box instead of
        # failing the request. NOTE: this exposes internal tracebacks to users.
        tb = traceback.format_exc()
        return _error_outputs(f"❌ Error: {e}\n\n{tb}", logs)
# Gradio UI. Inputs are the tx hash, k and data source. Outputs are a results
# table, one tab per model with its ego-graph HTML and score histogram, and a
# Logs textbox.
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\nEnter a transaction hash and k (1–3). The app builds an ego-subgraph from on-chain data and returns model scores.")
    with gr.Row():
        tx_in = gr.Textbox(label="Transaction Hash (64-hex)", placeholder="e.g., 4d3c... (64 hex)")
        # Slider bounds mirror the k in {1,2,3} check inside handle_run.
        k_in = gr.Slider(1, 3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(choices=["mempool", "blockstream", "blockchair"], value="mempool", label="Data Source")
        run_btn = gr.Button("Run", variant="primary")
    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)
    with gr.Tabs():
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")
    out_logs = gr.Textbox(label="Logs", lines=8)
    # The outputs order must match the positional tuple returned by handle_run.
    run_btn.click(
        handle_run,
        inputs=[tx_in, k_in, provider_in],
        outputs=[out_table, out_html_gat, out_html_gatv2, out_logs, out_hist_gat, out_hist_gatv2],
        concurrency_limit=CFG.QUEUE_CONCURRENCY,  # max simultaneous runs of this event
    )
# Enable the request queue, holding at most 32 pending requests; per-event
# concurrency is set through concurrency_limit on the click listener.
app.queue(max_size=32)
if __name__ == "__main__":
    # Single entry point. launch() blocks until the server stops, so a second
    # launch() call after it would only restart the app once it shuts down.
    app.launch(server_name="0.0.0.0", server_port=7860, max_threads=CFG.QUEUE_CONCURRENCY)
|