File size: 5,369 Bytes
db886e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12a958b
 
 
 
 
db886e4
12a958b
 
 
 
db886e4
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import re
import traceback
from typing import Tuple, List, Dict, Any

import gradio as gr
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.utils import coalesce

from config import AppConfig
from rate_limit import GlobalRateLimiter, RateLimitExceeded
from explorers import new_client, fetch_with_fallback
from graph_builder import expand_ego
from features import build_features, scale_features
from inference import load_models, run_for_both_models
from viz import render_ego_html, histogram_scores

# App-wide configuration and a single shared rate limiter for all requests.
CFG = AppConfig()
RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)

# Lazily-populated cache of loaded model bundles; filled by _load_models_once().
_models_cache = None

def _is_valid_txid(tx: str) -> bool:
    return bool(re.fullmatch(r"[0-9a-fA-F]{64}", tx or ""))

def _load_models_once():
    """Load the model bundles on first call and memoize them at module level.

    Subsequent calls return the cached bundles without touching disk again.
    """
    global _models_cache
    if _models_cache is not None:
        return _models_cache
    _models_cache = load_models(CFG)
    return _models_cache

@RATE.enforce()
def handle_run(tx_hash: str, k: int, provider: str):
    """Build an ego-subgraph around *tx_hash* and score it with both models.

    Parameters
    ----------
    tx_hash : 64-hex transaction id entered by the user.
    k : ego-graph radius; must be 1, 2 or 3.
    provider : explorer backend name ("mempool", "blockstream", "blockchair").

    Returns
    -------
    A 6-tuple aligned with the Gradio output components wired on run_btn.click:
    (results_df, html_gat, html_gatv2, log_text, hist_gat, hist_gatv2).
    On any failure every slot except the log text is None.
    """
    logs: List[str] = []

    def _fail(msg: str):
        # BUGFIX: early exits previously returned 4-tuples with the error
        # message in the 3rd slot, which mismatched the six wired outputs and
        # routed the message into an HTML component. Always return 6 values
        # with the message (plus any accumulated logs) in the logs slot.
        logs.append(msg)
        return None, None, None, "\n".join(logs), None, None

    try:
        if not _is_valid_txid(tx_hash):
            return _fail("❌ Invalid txid. Please enter a 64-hex transaction hash.")

        k = int(k)
        if not 1 <= k <= 3:
            return _fail("❌ k must be in {1,2,3}.")

        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(tx_hash, k, provider, CFG)
        logs.extend(gb_logs or [])
        if center_idx < 0 or len(nodes) == 0:
            return _fail("❌ Failed to build subgraph (see logs).")

        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")

        # Node features: raw build, then scaling (a fitted scaler from the
        # model repo could be injected here instead of None).
        X, feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, scaler_used, scale_note = scale_features(X, scaler=None)
        logs.append(scale_note)

        # Assemble the PyG graph; coalesce sorts and de-duplicates edges.
        if len(edges) > 0:
            edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)  # shape [2, E]
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_index = coalesce(edge_index)

        data = Data(
            x=torch.tensor(Xs, dtype=torch.float32),
            edge_index=edge_index,
        )

        bundles = _load_models_once()
        results = run_for_both_models(bundles, data, center_idx, CFG)

        # One results-table row per model.
        records = [
            {
                "tx_hash": tx_hash,
                "model_name": name,
                "probability": float(probs[center_idx]),
                "threshold": float(thr),
                "pred_label": int(label),
                "k_used": int(k),
                "num_nodes": int(len(nodes)),
                "num_edges": int(len(edges)),
                "note": note,
            }
            for name, probs, thr, label, note in results
        ]
        df = pd.DataFrame(records)

        # Interactive ego-graph renderings, one per model (results[0] = GAT,
        # results[1] = GATv2 per run_for_both_models' ordering).
        html_gat = render_ego_html(nodes, edges, center_idx, scores=results[0][1])
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=results[1][1])

        # Score distributions over the whole subgraph.
        fig_hist_gat = histogram_scores(results[0][1], title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(results[1][1], title="Scores (GATv2)")

        return df, html_gat, html_gatv2, "\n".join(logs), fig_hist_gat, fig_hist_v2

    except RateLimitExceeded as e:
        return None, None, None, f"❌ {e}", None, None
    except Exception as e:
        # Surface the full traceback in the logs pane instead of crashing the UI.
        tb = traceback.format_exc()
        return None, None, None, f"❌ Error: {e}\n\n{tb}", None, None

# --- Gradio UI layout and event wiring -------------------------------------
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\nEnter a transaction hash and k (1–3). The app builds an ego-subgraph from on-chain data and returns model scores.")
    with gr.Row():
        # Inputs: transaction id, ego-graph radius k, and explorer backend.
        tx_in = gr.Textbox(label="Transaction Hash (64-hex)", placeholder="e.g., 4d3c... (64 hex)")
        k_in = gr.Slider(1, 3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(choices=["mempool", "blockstream", "blockchair"], value="mempool", label="Data Source")
    run_btn = gr.Button("Run", variant="primary")

    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)
    with gr.Tabs():
        # One tab per model: rendered ego-graph HTML plus a score histogram.
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")
    out_logs = gr.Textbox(label="Logs", lines=8)

    # Output order must match handle_run's 6-tuple:
    # (table, html GAT, html GATv2, logs, hist GAT, hist GATv2).
    run_btn.click(
    handle_run,
    inputs=[tx_in, k_in, provider_in],
    outputs=[out_table, out_html_gat, out_html_gatv2, out_logs, out_hist_gat, out_hist_gatv2],
    concurrency_limit=CFG.QUEUE_CONCURRENCY,  # per-event cap (Gradio 4.x API)
)

# Request queue; per-event concurrency is set on the click handler above
# (the deprecated queue-level concurrency_count was removed).
app.queue(max_size=32)

if __name__ == "__main__":
    # BUGFIX: there were two consecutive __main__ guards each calling
    # app.launch(); the second only ran after the first server shut down and
    # then relaunched with different settings. Keep a single launch that binds
    # all interfaces and caps worker threads to the queue concurrency.
    app.launch(server_name="0.0.0.0", server_port=7860, max_threads=CFG.QUEUE_CONCURRENCY)