"""Gradio front-end for scoring a Bitcoin transaction with GAT / GATv2 models.

Given a transaction hash and a hop count ``k``, the app builds an ego-subgraph
via ``expand_ego``, derives and scales node features, runs both models and
shows a results table, two ego-graph renderings, two score histograms and a
log panel.
"""

import re
import traceback
from typing import Tuple, List, Dict, Any

import gradio as gr
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.utils import coalesce

from config import AppConfig
from rate_limit import GlobalRateLimiter, RateLimitExceeded
from explorers import new_client, fetch_with_fallback
from graph_builder import expand_ego
from features import build_features, scale_features
from inference import load_models, run_for_both_models
from viz import render_ego_html, histogram_scores

CFG = AppConfig()
RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)

# Lazily populated by _load_models_once(); None until the first request.
_models_cache = None

# A txid is accepted only if it is exactly 64 hexadecimal characters.
_TXID_RE = re.compile(r"[0-9a-fA-F]{64}")

# Shape of every value returned by handle_run, in the order of the `outputs`
# list wired to the Run button:
# (table, GAT html, GATv2 html, log text, GAT histogram, GATv2 histogram)
RunOutputs = Tuple[Any, Any, Any, Any, Any, Any]


def _is_valid_txid(tx: str) -> bool:
    """Return True if *tx* is exactly 64 hex characters (None/empty -> False)."""
    return bool(_TXID_RE.fullmatch(tx or ""))


def _load_models_once():
    """Return the model bundles, calling ``load_models(CFG)`` only on first use."""
    global _models_cache
    if _models_cache is None:
        _models_cache = load_models(CFG)
    return _models_cache


def _error_outputs(message: str, logs: List[str] = None) -> RunOutputs:
    """Build the 6-slot failure result: only the log textbox gets content.

    The message is placed first, followed by any log lines collected so far,
    so the user sees the reason before the supporting detail.
    """
    lines = [message]
    if logs:
        lines.extend(logs)
    return None, None, None, "\n".join(lines), None, None


@RATE.enforce()
def handle_run(tx_hash: str, k: int, provider: str) -> RunOutputs:
    """Score the transaction *tx_hash* on its k-hop ego-subgraph with both models.

    Args:
        tx_hash: Transaction id; must be 64 hex characters.
        k: Expansion depth; must be 1, 2 or 3 after ``int()`` coercion.
        provider: Data-source name forwarded unchanged to ``expand_ego``.

    Returns:
        A 6-tuple matching the Run button's ``outputs`` order:
        ``(results DataFrame, GAT ego-graph HTML, GATv2 ego-graph HTML,
        log text, GAT histogram, GATv2 histogram)``. On any failure all
        slots are ``None`` except the log text, which carries the error.

    Exceptions raised inside the body are caught and reported in the log slot
    rather than propagated.
    """
    logs: List[str] = []
    try:
        if not _is_valid_txid(tx_hash):
            return _error_outputs(
                "❌ Invalid txid. Please enter a 64-hex transaction hash."
            )

        # The UI slider may deliver k as a float; normalise before range check.
        k = int(k)
        if not 1 <= k <= 3:
            return _error_outputs("❌ k must be in {1,2,3}.")

        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(
            tx_hash, k, provider, CFG
        )
        logs.extend(gb_logs or [])

        if center_idx < 0 or len(nodes) == 0:
            return _error_outputs("❌ Failed to build subgraph (see logs).", logs)

        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")

        # Build features. scaler=None: no pre-fitted scaler is injected here
        # (one could be supplied from the model repo if desired).
        X, _feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, _scaler_used, scale_note = scale_features(X, scaler=None)
        logs.append(scale_note)

        # PyG Data: edge_index has shape [2, E]; an edgeless graph needs an
        # explicit empty [2, 0] tensor.
        if len(edges) > 0:
            edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_index = coalesce(edge_index)

        data = Data(
            x=torch.tensor(Xs, dtype=torch.float32),
            edge_index=edge_index,
        )

        bundles = _load_models_once()
        results = run_for_both_models(bundles, data, center_idx, CFG)

        # Compose output table: one row per model, scored at the centre node.
        num_nodes = int(len(nodes))
        num_edges = int(len(edges))
        records: List[Dict[str, Any]] = [
            {
                "tx_hash": tx_hash,
                "model_name": name,
                "probability": float(probs[center_idx]),
                "threshold": float(thr),
                "pred_label": int(label),
                "k_used": int(k),
                "num_nodes": num_nodes,
                "num_edges": num_edges,
                "note": note,
            }
            for name, probs, thr, label, note in results
        ]
        df = pd.DataFrame(records)

        # NOTE(review): assumes results[0] is the GAT model and results[1] is
        # GATv2 — the order comes from run_for_both_models; confirm there.
        scores_gat = results[0][1]
        scores_gatv2 = results[1][1]

        # Visuals (two HTML ego-graphs)
        html_gat = render_ego_html(nodes, edges, center_idx, scores=scores_gat)
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=scores_gatv2)

        # Histogram of scores for the subgraph
        fig_hist_gat = histogram_scores(scores_gat, title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(scores_gatv2, title="Scores (GATv2)")

        log_text = "\n".join(logs)
        return df, html_gat, html_gatv2, log_text, fig_hist_gat, fig_hist_v2

    except RateLimitExceeded as e:
        # NOTE(review): this only catches RateLimitExceeded raised inside the
        # body. If RATE.enforce() raises before calling this function, the
        # exception bypasses this handler — verify in rate_limit.
        return _error_outputs(f"❌ {e}")
    except Exception as e:
        # Top-level UI boundary: surface the failure in the log panel instead
        # of letting the request crash.
        tb = traceback.format_exc()
        return _error_outputs(f"❌ Error: {e}\n\n{tb}")


with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown(
        "## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\n"
        "Enter a transaction hash and k (1–3). The app builds an ego-subgraph "
        "from on-chain data and returns model scores."
    )

    with gr.Row():
        tx_in = gr.Textbox(
            label="Transaction Hash (64-hex)",
            placeholder="e.g., 4d3c... (64 hex)",
        )
        k_in = gr.Slider(1, 3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(
            choices=["mempool", "blockstream", "blockchair"],
            value="mempool",
            label="Data Source",
        )
        run_btn = gr.Button("Run", variant="primary")

    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)

    with gr.Tabs():
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")

    out_logs = gr.Textbox(label="Logs", lines=8)

    # The order of `outputs` defines the order of handle_run's return tuple.
    # Concurrency is set per event here rather than via the deprecated
    # queue(concurrency_count=...) argument.
    run_btn.click(
        handle_run,
        inputs=[tx_in, k_in, provider_in],
        outputs=[
            out_table,
            out_html_gat,
            out_html_gatv2,
            out_logs,
            out_hist_gat,
            out_hist_gatv2,
        ],
        concurrency_limit=CFG.QUEUE_CONCURRENCY,
    )

app.queue(max_size=32)

if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        max_threads=CFG.QUEUE_CONCURRENCY,
    )