Spaces:
Sleeping
Sleeping
| import re | |
| import traceback | |
| from typing import Tuple, List, Dict, Any | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from torch_geometric.data import Data | |
| from torch_geometric.utils import coalesce | |
| from config import AppConfig | |
| from rate_limit import GlobalRateLimiter, RateLimitExceeded | |
| from explorers import new_client, fetch_with_fallback | |
| from graph_builder import expand_ego | |
| from features import build_features, scale_features | |
| from inference import load_models, run_for_both_models | |
| from viz import render_ego_html, histogram_scores | |
# Application config and a process-wide rate limiter shared by all requests.
CFG = AppConfig()
RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)
# Lazily-populated cache of loaded model bundles; filled by _load_models_once().
_models_cache = None
| def _is_valid_txid(tx: str) -> bool: | |
| return bool(re.fullmatch(r"[0-9a-fA-F]{64}", tx or "")) | |
def _load_models_once():
    """Load the model bundles on first use and cache them in a module global."""
    global _models_cache
    bundles = _models_cache
    if bundles is None:
        # First call: do the (potentially expensive) load exactly once.
        bundles = load_models(CFG)
        _models_cache = bundles
    return bundles
def handle_run(tx_hash: str, k: int, provider: str):
    """Score *tx_hash* with both models over a k-step ego-subgraph.

    Always returns a 6-tuple matching the Gradio outputs wired in run_btn.click:
    (results dataframe, GAT ego-graph HTML, GATv2 ego-graph HTML, log text,
     GAT score histogram, GATv2 score histogram).
    On any failure the non-log slots are None and the log slot carries the error.
    """
    logs = []

    def _fail(message: str):
        # Single place that shapes an error into the 6-slot output tuple.
        # (The original code returned 4-tuples here, which broke the 6-output
        # click() wiring and put the error text in the wrong component.)
        logs.append(message)
        return None, None, None, "\n".join(logs), None, None

    try:
        if not _is_valid_txid(tx_hash):
            return _fail("❌ Invalid txid. Please enter a 64-hex transaction hash.")
        k = int(k)
        if k < 1 or k > 3:
            return _fail("❌ k must be in {1,2,3}.")
        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(tx_hash, k, provider, CFG)
        logs.extend(gb_logs or [])
        if center_idx < 0 or len(nodes) == 0:
            return _fail("❌ Failed to build subgraph (see logs).")
        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")
        # Build and scale node features for the subgraph.
        X, _feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, _scaler_used, scale_note = scale_features(X, scaler=None)  # can inject scaler from model repo if desired
        logs.append(scale_note)
        # Assemble the PyG Data object; coalesce sorts and deduplicates edges.
        if len(edges) > 0:
            edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)  # shape [2, E]
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_index = coalesce(edge_index)
        data = Data(
            x=torch.tensor(Xs, dtype=torch.float32),
            edge_index=edge_index,
        )
        bundles = _load_models_once()
        results = run_for_both_models(bundles, data, center_idx, CFG)
        # One row per model for the comparison table.
        records = []
        for name, probs, thr, label, note in results:
            records.append({
                "tx_hash": tx_hash,
                "model_name": name,
                "probability": float(probs[center_idx]),
                "threshold": float(thr),
                "pred_label": int(label),
                "k_used": int(k),
                "num_nodes": int(len(nodes)),
                "num_edges": int(len(edges)),
                "note": note,
            })
        df = pd.DataFrame(records)
        # Ego-graph renderings and score histograms, one per model.
        html_gat = render_ego_html(nodes, edges, center_idx, scores=results[0][1])
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=results[1][1])
        fig_hist_gat = histogram_scores(results[0][1], title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(results[1][1], title="Scores (GATv2)")
        return df, html_gat, html_gatv2, "\n".join(logs), fig_hist_gat, fig_hist_v2
    except RateLimitExceeded as e:
        return None, None, None, f"❌ {e}", None, None
    except Exception as e:
        # Surface the full traceback in the logs component for debugging.
        tb = traceback.format_exc()
        return None, None, None, f"❌ Error: {e}\n\n{tb}", None, None
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\nEnter a transaction hash and k (1–3). The app builds an ego-subgraph from on-chain data and returns model scores.")
    with gr.Row():
        # Inputs: txid, neighborhood depth, and which block explorer to query.
        tx_in = gr.Textbox(label="Transaction Hash (64-hex)", placeholder="e.g., 4d3c... (64 hex)")
        k_in = gr.Slider(1, 3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(choices=["mempool", "blockstream", "blockchair"], value="mempool", label="Data Source")
    run_btn = gr.Button("Run", variant="primary")
    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)
    with gr.Tabs():
        # One tab per model: rendered ego-graph plus a histogram of node scores.
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")
    out_logs = gr.Textbox(label="Logs", lines=8)
    # Output order must match the tuple returned by handle_run.
    run_btn.click(
        handle_run,
        inputs=[tx_in, k_in, provider_in],
        outputs=[out_table, out_html_gat, out_html_gatv2, out_logs, out_hist_gat, out_hist_gatv2],
        concurrency_limit=CFG.QUEUE_CONCURRENCY,  # per-event concurrency cap (Gradio 4.x API)
    )
app.queue(max_size=32)  # queue-level cap; the deprecated concurrency_count arg was removed
if __name__ == "__main__":
    # launch() blocks until the server stops, so only one guard is needed.
    # (The original file had a second, duplicate __main__ block whose
    # launch() call could never run — and would clash on port 7860 if it did.)
    app.launch(server_name="0.0.0.0", server_port=7860, max_threads=CFG.QUEUE_CONCURRENCY)