# Uploaded by thanhphxu via huggingface_hub ("Upload folder using huggingface_hub", commit 12a958b, verified).
import re
import traceback
from typing import Tuple, List, Dict, Any
import gradio as gr
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.utils import coalesce
from config import AppConfig
from rate_limit import GlobalRateLimiter, RateLimitExceeded
from explorers import new_client, fetch_with_fallback
from graph_builder import expand_ego
from features import build_features, scale_features
from inference import load_models, run_for_both_models
from viz import render_ego_html, histogram_scores
# Cache for the model bundles returned by ``load_models``; stays ``None``
# until ``_load_models_once`` is first called.
_models_cache: Any = None

# Application configuration, and the rate limiter built from its
# per-window call budget (applied to ``handle_run`` via ``RATE.enforce()``).
CFG = AppConfig()
RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)
def _is_valid_txid(tx: str) -> bool:
return bool(re.fullmatch(r"[0-9a-fA-F]{64}", tx or ""))
def _load_models_once():
    """Return the model bundles, calling ``load_models(CFG)`` only on first use.

    The result is memoised in the module-level ``_models_cache``.
    """
    global _models_cache
    if _models_cache is not None:
        return _models_cache
    _models_cache = load_models(CFG)
    return _models_cache
def _error_outputs(message: str, logs: List[str]) -> Tuple[Any, ...]:
    """Build the six UI outputs for a failed run.

    ``handle_run`` feeds six output components, so every return path must
    yield six values. On failure every slot is ``None`` except the fourth
    (the logs textbox). That slot holds ``message`` followed by the log
    lines collected so far.
    """
    text = "\n".join([message, *(str(line) for line in logs)])
    return None, None, None, text, None, None


def _build_edge_index(edges) -> torch.Tensor:
    """Convert an edge list into a coalesced ``[2, E]`` ``torch.long`` tensor.

    An empty edge list yields an empty ``[2, 0]`` tensor, so single-node
    ego-graphs still produce valid PyG input.
    """
    if len(edges) > 0:
        # Presumably ``edges`` is a sequence of (src, dst) index pairs
        # (shape [E, 2]); the transpose gives PyG's [2, E] layout. TODO confirm.
        edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return coalesce(edge_index)


@RATE.enforce()
def handle_run(tx_hash: str, k: int, provider: str) -> Tuple[Any, ...]:
    """Score one transaction with both models and build the UI outputs.

    Args:
        tx_hash: Transaction id. It must be 64 hex characters after
            surrounding whitespace is stripped.
        k: Ego-graph expansion depth. It is coerced with ``int`` and must be
            1, 2 or 3.
        provider: Data-source name forwarded to ``expand_ego``.

    Returns:
        A 6-tuple ``(results_df, html_gat, html_gatv2, log_text,
        fig_hist_gat, fig_hist_v2)`` matching the click handler's outputs.
        On any failure, every slot except ``log_text`` is ``None``, and
        ``log_text`` carries the error message plus the logs collected so far.
    """
    logs: List[str] = []
    try:
        # Users commonly paste hashes with stray whitespace around them.
        tx_hash = (tx_hash or "").strip()
        if not _is_valid_txid(tx_hash):
            return _error_outputs("❌ Invalid txid. Please enter a 64-hex transaction hash.", logs)
        k = int(k)
        if not 1 <= k <= 3:
            return _error_outputs("❌ k must be in {1,2,3}.", logs)

        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(tx_hash, k, provider, CFG)
        logs.extend(gb_logs or [])
        if center_idx is None or center_idx < 0 or len(nodes) == 0:
            return _error_outputs("❌ Failed to build subgraph (see logs).", logs)
        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")

        # Node features; a fitted scaler from the model repo could be injected here.
        X, _feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, _scaler_used, scale_note = scale_features(X, scaler=None)
        if scale_note:
            logs.append(str(scale_note))

        data = Data(
            x=torch.tensor(Xs, dtype=torch.float32),
            edge_index=_build_edge_index(edges),
        )

        results = run_for_both_models(_load_models_once(), data, center_idx, CFG)
        # The visuals below index results[0] (GAT) and results[1] (GATv2);
        # presumably run_for_both_models returns them in that order — confirm.
        if len(results) < 2:
            return _error_outputs(
                f"❌ Expected results from two models, got {len(results)}.", logs
            )

        df = pd.DataFrame([
            {
                "tx_hash": tx_hash,
                "model_name": name,
                "probability": float(probs[center_idx]),
                "threshold": float(thr),
                "pred_label": int(label),
                "k_used": k,
                "num_nodes": len(nodes),
                "num_edges": len(edges),
                "note": note,
            }
            for name, probs, thr, label, note in results
        ])

        probs_gat, probs_v2 = results[0][1], results[1][1]
        html_gat = render_ego_html(nodes, edges, center_idx, scores=probs_gat)
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=probs_v2)
        fig_hist_gat = histogram_scores(probs_gat, title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(probs_v2, title="Scores (GATv2)")

        log_text = "\n".join(str(line) for line in logs)
        return df, html_gat, html_gatv2, log_text, fig_hist_gat, fig_hist_v2
    except RateLimitExceeded as e:
        # Only reached when the limit is hit inside the body. Whether
        # RATE.enforce() raises before the body runs can't be told from here.
        return _error_outputs(f"❌ {e}", logs)
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        # NOTE(review): the full traceback is shown to end users; consider
        # logging it server-side only for a public deployment.
        tb = traceback.format_exc()
        return _error_outputs(f"❌ Error: {e}\n\n{tb}", logs)
# Gradio UI: an input row, the results table, one tab per model (ego-graph
# HTML plus score histogram), and a logs box. Components are created inside
# the layout contexts, so creation order is the on-screen order.
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown(
        "## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\n"
        "Enter a transaction hash and k (1–3). The app builds an ego-subgraph from on-chain data and returns model scores."
    )

    with gr.Row():
        tx_in = gr.Textbox(
            label="Transaction Hash (64-hex)",
            placeholder="e.g., 4d3c... (64 hex)",
        )
        k_in = gr.Slider(minimum=1, maximum=3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(
            choices=["mempool", "blockstream", "blockchair"],
            value="mempool",
            label="Data Source",
        )
        run_btn = gr.Button("Run", variant="primary")

    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)

    with gr.Tabs():
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")

    out_logs = gr.Textbox(label="Logs", lines=8)

    # The output order must match handle_run's return tuple:
    # (df, html_gat, html_gatv2, log_text, fig_hist_gat, fig_hist_v2).
    run_btn.click(
        fn=handle_run,
        inputs=[tx_in, k_in, provider_in],
        outputs=[
            out_table,
            out_html_gat,
            out_html_gatv2,
            out_logs,
            out_hist_gat,
            out_hist_gatv2,
        ],
        concurrency_limit=CFG.QUEUE_CONCURRENCY,
    )

# Bounded request queue; per-event concurrency is set on the click above.
app.queue(max_size=32)
# A single entry point. ``launch()`` blocks until the server shuts down, so
# a second ``__main__`` guard would start another server afterwards.
if __name__ == "__main__":
    # Listen on all interfaces at port 7860. The worker-thread cap follows
    # the handler's queue concurrency.
    # NOTE(review): a small max_threads also limits threads for non-queued
    # requests — confirm this cap is intended.
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        max_threads=CFG.QUEUE_CONCURRENCY,
    )