Build error

Upload folder using huggingface_hub

Files changed:
- .gitignore +1 -0
- README.md +72 -13
- app.py +137 -0
- config.py +41 -0
- explorers.py +189 -0
- features.py +90 -0
- graph_builder.py +132 -0
- inference.py +109 -0
- models.py +78 -0
- rate_limit.py +34 -0
- requirements.txt +28 -0
- viz.py +44 -0
.gitignore
ADDED

```
.venv/**
```
README.md
CHANGED

# Bitcoin Abuse Scoring (GAT / GATv2) — Hugging Face Space

This Space builds an **ego-subgraph** from a given Bitcoin transaction hash (`k` steps backward & forward), then runs **two pretrained GNN models** (GAT baseline & GATv2 enhanced) trained on **Elliptic** to score whether the center transaction is *abuse*.

## ✅ Features

- Data sources (public JSON APIs, no scraping): `mempool.space` / `blockstream.info` (Esplora), fallback to `Blockchair` (optional key).
- Ego-subgraph expansion **k ∈ {1,2,3}** (both parents & children).
- Graph safeguards: `MAX_NODES` & `MAX_EDGES` to avoid explosion.
- Node features: degree stats, value sums/logs, counts, ratio, distance-to-center, block height.
- Standardized features (on-the-fly). If your model used different features/scaler, set `USE_FEATURE_ADAPTER=true` (default) — it inserts a `Linear` projection to the expected input dimension (165 by default).
- Two models are loaded from **Hugging Face Hub** with thresholds (via `thresholds.json` or fallback `0.5`).
- **Rate limit**: 20 requests/min globally (sliding window).
- Visualizations: **ego-graph (pyvis HTML)** & **histogram of scores** per model.
- CPU-only deployment on Spaces.

## 🔧 Configuration

Set these **Environment Variables** (Space → Settings → Variables):

```
HF_GAT_BASELINE_REPO=org/name_gat_baseline
HF_GATV2_REPO=org/name_gatv2

# (Optional overrides)
IN_CHANNELS=165
HIDDEN_CHANNELS=128
HEADS=8
NUM_BLOCKS=2
DROPOUT=0.5

DATA_PROVIDER=mempool   # mempool | blockstream | blockchair
HTTP_TIMEOUT=10
HTTP_RETRIES=2
MAX_NODES=5000
MAX_EDGES=15000
USE_FEATURE_ADAPTER=true
DEFAULT_THRESHOLD=0.5
QUEUE_CONCURRENCY=2
BLOCKCHAIR_API_KEY=
```

Each model repo should contain:

- `model.pt` — PyTorch Geometric weights.
- (optional) `thresholds.json` with a key like `{"threshold": 0.42}`.
- (optional) `scaler.joblib` if you want to reuse the training scaler.

## 📦 API Usage in App

- `GET /api/tx/{txid}` and `GET /api/tx/{txid}/outspends` (Esplora).
- `GET /bitcoin/dashboards/transaction/{txid}` (Blockchair).

All calls have **timeouts & retries** and use a small **in-memory cache**.

## 🚦 Rate Limiting

Global limit `20 req/min` across the app (sliding window). Exceeding it returns `Rate limit exceeded (20 req/min)`.

## 🧪 Acceptance Criteria

- Enter a valid tx hash & `k=2` → ego-graph is built, both models run, and the app displays:
  - `probability`, `threshold`, `label` for **GAT** and **GATv2**,
  - counts of nodes/edges and notes (e.g., *FeatureAdapter used*).
- Ego-graph renders with center highlighted; tooltips show txid and score.
- If the first provider fails, the app falls back to the next one.
- If the graph exceeds safeguards, the app stops expansion and warns in logs (but still infers with what it has).

## ⚠️ Notes

- **Domain shift**: Features from on-chain crawls can differ from Elliptic; use the adapter and consider fine-tuning for production.
- Public APIs have their own rate limits — this app is conservative with requests, but heavy usage may still hit external limits.
- Input is validated to be a 64-hex txid. No arbitrary URLs are accepted.
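As a sketch of the Esplora transaction shape the app consumes (field names follow the `/api/tx/{txid}` schema listed above; the sample values and addresses are made up, not a real transaction):

```python
# Minimal illustration of the Esplora /api/tx/{txid} payload the app parses.
# Only the field names follow the Esplora schema; values are illustrative.
sample_tx = {
    "txid": "ab" * 32,
    "vin": [
        {"txid": "cd" * 32, "vout": 0,
         "prevout": {"value": 150_000, "scriptpubkey_address": "bc1qexample"}},
    ],
    "vout": [
        {"value": 90_000, "scriptpubkey_address": "bc1qdest1"},
        {"value": 55_000, "scriptpubkey_address": "bc1qdest2"},
    ],
    "status": {"block_height": 800_000, "block_time": 1_690_000_000},
}

def fee_sats(tx: dict) -> int:
    """Fee = sum of input prevout values minus sum of output values (satoshis)."""
    in_sum = sum(v["prevout"]["value"] for v in tx["vin"])
    out_sum = sum(o["value"] for o in tx["vout"])
    return in_sum - out_sum

print(fee_sats(sample_tx))  # 5000
```

The `prevout` sub-object is why input values are available without a second lookup: Esplora inlines the spent output into each `vin` entry.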
app.py
ADDED

```python
import re
import traceback

import gradio as gr
import numpy as np
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.utils import coalesce

from config import AppConfig
from rate_limit import GlobalRateLimiter, RateLimitExceeded
from graph_builder import expand_ego
from features import build_features, scale_features
from inference import load_models, run_for_both_models
from viz import render_ego_html, histogram_scores

CFG = AppConfig()
RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)

_models_cache = None


def _is_valid_txid(tx: str) -> bool:
    return bool(re.fullmatch(r"[0-9a-fA-F]{64}", tx or ""))


def _load_models_once():
    global _models_cache
    if _models_cache is None:
        _models_cache = load_models(CFG)
    return _models_cache


@RATE.enforce()
def handle_run(tx_hash: str, k: int, provider: str):
    # Every return must carry one value per wired output, in order:
    # (table, html_gat, html_gatv2, logs, hist_gat, hist_gatv2).
    logs = []
    try:
        if not _is_valid_txid(tx_hash):
            return None, None, None, "❌ Invalid txid. Please enter a 64-hex transaction hash.", None, None

        k = int(k)
        if k < 1 or k > 3:
            return None, None, None, "❌ k must be in {1,2,3}.", None, None

        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(tx_hash, k, provider, CFG)
        logs.extend(gb_logs or [])
        if center_idx < 0 or len(nodes) == 0:
            logs.append("❌ Failed to build subgraph (see logs).")
            return None, None, None, "\n".join(logs), None, None

        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")

        # Build features
        X, feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, scaler_used, scale_note = scale_features(X, scaler=None)  # a scaler from the model repo can be injected here
        logs.append(scale_note)

        # PyG Data
        if len(edges) > 0:
            edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)  # shape [2, E]
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_index = coalesce(edge_index)

        data = Data(
            x=torch.tensor(Xs, dtype=torch.float32),
            edge_index=edge_index,
        )

        bundles = _load_models_once()
        results = run_for_both_models(bundles, data, center_idx, CFG)

        # Compose output table
        records = []
        for name, probs, thr, label, note in results:
            records.append({
                "tx_hash": tx_hash,
                "model_name": name,
                "probability": float(probs[center_idx]),
                "threshold": float(thr),
                "pred_label": int(label),
                "k_used": int(k),
                "num_nodes": int(len(nodes)),
                "num_edges": int(len(edges)),
                "note": note,
            })
        df = pd.DataFrame(records)

        # Visuals (two HTML ego-graphs)
        html_gat = render_ego_html(nodes, edges, center_idx, scores=results[0][1])
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=results[1][1])

        # Histogram of scores for the subgraph
        fig_hist_gat = histogram_scores(results[0][1], title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(results[1][1], title="Scores (GATv2)")

        return df, html_gat, html_gatv2, "\n".join(logs), fig_hist_gat, fig_hist_v2

    except RateLimitExceeded as e:
        return None, None, None, f"❌ {e}", None, None
    except Exception as e:
        tb = traceback.format_exc()
        return None, None, None, f"❌ Error: {e}\n\n{tb}", None, None


with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\nEnter a transaction hash and k (1–3). The app builds an ego-subgraph from on-chain data and returns model scores.")
    with gr.Row():
        tx_in = gr.Textbox(label="Transaction Hash (64-hex)", placeholder="e.g., 4d3c... (64 hex)")
        k_in = gr.Slider(1, 3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(choices=["mempool", "blockstream", "blockchair"], value="mempool", label="Data Source")
    run_btn = gr.Button("Run", variant="primary")

    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)
    with gr.Tabs():
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")
    out_logs = gr.Textbox(label="Logs", lines=8)

    run_btn.click(
        handle_run,
        inputs=[tx_in, k_in, provider_in],
        outputs=[out_table, out_html_gat, out_html_gatv2, out_logs, out_hist_gat, out_hist_gatv2],
    )

# Note: `concurrency_count` is the Gradio 3.x parameter name; Gradio 4
# renamed it to `default_concurrency_limit`.
app.queue(concurrency_count=CFG.QUEUE_CONCURRENCY, max_size=32)

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
```
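`rate_limit.py` is listed in this commit (+34 lines) but its diff is not shown above. Based only on how app.py uses it — `GlobalRateLimiter(max_calls, window_seconds)`, an `.enforce()` decorator factory, and a `RateLimitExceeded` exception — a sliding-window implementation could look like the following. This is a sketch inferred from the call sites, not the actual file; the injectable `clock` parameter is an addition for testability:

```python
import time
from collections import deque
from functools import wraps

class RateLimitExceeded(Exception):
    pass

class GlobalRateLimiter:
    """Global sliding-window limiter: at most max_calls per window_seconds."""

    def __init__(self, max_calls: int, window_seconds: int, clock=time.monotonic):
        self.max_calls = max_calls
        self.window = window_seconds
        self.clock = clock      # injectable for testing (hypothetical extra arg)
        self.calls = deque()    # timestamps of calls still inside the window

    def check(self):
        now = self.clock()
        # Evict timestamps that have fallen out of the window.
        while self.calls and now - self.calls[0] >= self.window:
            self.calls.popleft()
        if len(self.calls) >= self.max_calls:
            raise RateLimitExceeded(f"Rate limit exceeded ({self.max_calls} req/min)")
        self.calls.append(now)

    def enforce(self):
        def deco(fn):
            @wraps(fn)
            def wrapper(*args, **kwargs):
                self.check()
                return fn(*args, **kwargs)
            return wrapper
        return deco
```

With `max_calls=20` and `window_seconds=60` this reproduces the README's "20 req/min (sliding window)" behavior: the 21st call within any rolling 60-second span raises `RateLimitExceeded`.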
config.py
ADDED

```python
import os
from dataclasses import dataclass

@dataclass
class AppConfig:
    # --- Hugging Face model repos (set via environment variables on Spaces) ---
    HF_GAT_BASELINE_REPO: str = os.getenv("HF_GAT_BASELINE_REPO", "org/name_gat_baseline")
    HF_GATV2_REPO: str = os.getenv("HF_GATV2_REPO", "org/name_gatv2")

    # Expected input dim of Elliptic-trained models (given by user)
    IN_CHANNELS: int = int(os.getenv("IN_CHANNELS", "165"))
    HIDDEN_CHANNELS: int = int(os.getenv("HIDDEN_CHANNELS", "128"))
    HEADS: int = int(os.getenv("HEADS", "8"))
    NUM_BLOCKS: int = int(os.getenv("NUM_BLOCKS", "2"))
    DROPOUT: float = float(os.getenv("DROPOUT", "0.5"))

    # Data providers
    DATA_PROVIDER: str = os.getenv("DATA_PROVIDER", "mempool")  # mempool | blockstream | blockchair
    HTTP_TIMEOUT: int = int(os.getenv("HTTP_TIMEOUT", "10"))
    HTTP_RETRIES: int = int(os.getenv("HTTP_RETRIES", "2"))

    # Graph limits (safeguard)
    MAX_NODES: int = int(os.getenv("MAX_NODES", "5000"))
    MAX_EDGES: int = int(os.getenv("MAX_EDGES", "15000"))

    # Feature handling
    USE_FEATURE_ADAPTER: bool = os.getenv("USE_FEATURE_ADAPTER", "true").lower() == "true"
    MAKE_UNDIRECTED: bool = os.getenv("MAKE_UNDIRECTED", "false").lower() == "true"

    # Threshold fallback
    DEFAULT_THRESHOLD: float = float(os.getenv("DEFAULT_THRESHOLD", "0.5"))

    # Rate limit
    MAX_CALLS_PER_MIN: int = int(os.getenv("MAX_CALLS_PER_MIN", "20"))
    WINDOW_SECONDS: int = int(os.getenv("WINDOW_SECONDS", "60"))

    # Queue config
    QUEUE_CONCURRENCY: int = int(os.getenv("QUEUE_CONCURRENCY", "2"))

    # Blockchair API key (optional)
    BLOCKCHAIR_API_KEY: str = os.getenv("BLOCKCHAIR_API_KEY", "").strip()
```
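Every override above arrives as a string, so booleans need the explicit `.lower() == "true"` comparison — `bool("false")` is `True` in Python. A minimal demonstration of the same parsing pattern (the two environment variables set here are simulated, as they would be in Space Settings → Variables):

```python
import os

# Simulate Space variables (normally set in Settings → Variables).
os.environ["MAX_NODES"] = "1000"
os.environ["USE_FEATURE_ADAPTER"] = "False"

# Same parsing pattern as AppConfig's dataclass field defaults.
max_nodes = int(os.getenv("MAX_NODES", "5000"))
use_adapter = os.getenv("USE_FEATURE_ADAPTER", "true").lower() == "true"

print(max_nodes, use_adapter)  # 1000 False

# Pitfall the .lower() comparison avoids: any non-empty string is truthy.
print(bool("false"))  # True
```

One consequence of reading the environment in dataclass *field defaults* is that the values are captured when `config.py` is first imported, so the variables must be set before that import happens — on Spaces this holds naturally, since the platform injects them before the process starts.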
explorers.py
ADDED

```python
from typing import Dict, Any, List, Optional

import requests
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from cachetools import TTLCache

from config import AppConfig

USER_AGENT = "HF-Space-BTC-Abuse-GNN/1.0 (+https://huggingface.co/spaces)"


class ExplorerError(Exception):
    pass


def _req_json(url: str, timeout: int, retries: int = 2) -> Any:
    @retry(stop=stop_after_attempt(retries), wait=wait_exponential(min=0.5, max=4),
           retry=retry_if_exception_type((requests.Timeout, requests.ConnectionError)))
    def _do():
        r = requests.get(url, timeout=timeout, headers={"User-Agent": USER_AGENT})
        if r.status_code != 200:
            raise ExplorerError(f"HTTP {r.status_code} for {url}")
        return r.json()
    return _do()


def _satoshis_to_btc(v: Optional[int]) -> float:
    try:
        return float(v) / 1e8 if v is not None else 0.0
    except Exception:
        return 0.0


def _normalize_tx_esplora(j: Dict[str, Any]) -> Dict[str, Any]:
    # https://mempool.space/api/tx/{txid}
    vin = j.get("vin", [])
    vout = j.get("vout", [])
    status = j.get("status", {}) or {}
    bh = status.get("block_height")
    bt = status.get("block_time")
    vin_list = []
    for e in vin:
        p = e.get("prevout") or {}
        vin_list.append({
            "txid": e.get("txid"),
            "vout": e.get("vout"),
            "prevout_value": p.get("value"),
            "prevout_address": p.get("scriptpubkey_address") or None,
        })
    vout_list = []
    for idx, e in enumerate(vout):
        vout_list.append({
            "n": idx,
            "value": e.get("value"),
            "address": e.get("scriptpubkey_address") or None,
        })
    return {
        "txid": j.get("txid") or j.get("hash"),
        "vin": vin_list,
        "vout": vout_list,
        "block_height": bh,
        "block_time": bt,
    }


def _normalize_outspends_esplora(j: Any) -> List[Optional[str]]:
    # Returns a list aligned to outputs: the spending txid if spent, else None.
    res = []
    if isinstance(j, list):
        for e in j:
            if isinstance(e, dict) and e.get("spent"):
                res.append(e.get("txid"))
            else:
                res.append(None)
    return res


class BaseExplorer:
    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        self.cache_tx = TTLCache(maxsize=10000, ttl=300)
        self.cache_out = TTLCache(maxsize=10000, ttl=300)

    def get_tx(self, txid: str) -> Dict[str, Any]:
        raise NotImplementedError

    def get_outspends(self, txid: str) -> List[Optional[str]]:
        raise NotImplementedError


class MempoolSpaceClient(BaseExplorer):
    def __init__(self, cfg: AppConfig, base: str = "https://mempool.space"):
        super().__init__(cfg)
        self.base = base.rstrip("/")

    def get_tx(self, txid: str) -> Dict[str, Any]:
        if txid in self.cache_tx:
            return self.cache_tx[txid]
        url = f"{self.base}/api/tx/{txid}"
        j = _req_json(url, timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        tx = _normalize_tx_esplora(j)
        self.cache_tx[txid] = tx
        return tx

    def get_outspends(self, txid: str) -> List[Optional[str]]:
        if txid in self.cache_out:
            return self.cache_out[txid]
        url = f"{self.base}/api/tx/{txid}/outspends"
        j = _req_json(url, timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        out = _normalize_outspends_esplora(j)
        self.cache_out[txid] = out
        return out


class BlockstreamClient(MempoolSpaceClient):
    def __init__(self, cfg: AppConfig):
        super().__init__(cfg, base="https://blockstream.info")


class BlockchairClient(BaseExplorer):
    def __init__(self, cfg: AppConfig):
        super().__init__(cfg)
        self.base = "https://api.blockchair.com/bitcoin"

    def get_tx(self, txid: str) -> Dict[str, Any]:
        if txid in self.cache_tx:
            return self.cache_tx[txid]
        url = f"{self.base}/dashboards/transaction/{txid}"
        if self.cfg.BLOCKCHAIR_API_KEY:
            url += f"?key={self.cfg.BLOCKCHAIR_API_KEY}"
        j = _req_json(url, timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        data = j.get("data", {}).get(txid, {})
        tx = data.get("transaction", {})
        inputs = data.get("inputs", [])
        outputs = data.get("outputs", [])
        vin_list = [{
            "txid": i.get("spending_transaction_hash") or i.get("recipient_transaction_hash"),
            "vout": i.get("spending_index"),
            "prevout_value": i.get("value"),
            "prevout_address": i.get("recipient"),
        } for i in inputs]
        vout_list = [{
            "n": o.get("index"),
            "value": o.get("value"),
            "address": o.get("recipient"),
        } for o in outputs]
        out = {
            "txid": txid,
            "vin": vin_list,
            "vout": vout_list,
            "block_height": tx.get("block_id"),
            "block_time": tx.get("time"),
        }
        self.cache_tx[txid] = out
        return out

    def get_outspends(self, txid: str) -> List[Optional[str]]:
        # Blockchair exposes outputs with 'spent_by_transaction_hash'.
        if txid in self.cache_out:
            return self.cache_out[txid]
        url = f"{self.base}/dashboards/transaction/{txid}"
        if self.cfg.BLOCKCHAIR_API_KEY:
            url += f"?key={self.cfg.BLOCKCHAIR_API_KEY}"
        j = _req_json(url, timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        outputs = j.get("data", {}).get(txid, {}).get("outputs", [])
        res = [o.get("spent_by_transaction_hash") for o in outputs]
        self.cache_out[txid] = res
        return res


def new_client(cfg: AppConfig, primary: str) -> List[BaseExplorer]:
    # Primary provider first, then fallbacks.
    primary = (primary or cfg.DATA_PROVIDER).lower()
    if primary in ("mempool", "mempool.space"):
        chain = [MempoolSpaceClient(cfg), BlockstreamClient(cfg), BlockchairClient(cfg)]
    elif primary in ("blockstream", "blockstream.info"):
        chain = [BlockstreamClient(cfg), MempoolSpaceClient(cfg), BlockchairClient(cfg)]
    elif primary == "blockchair":
        chain = [BlockchairClient(cfg), MempoolSpaceClient(cfg), BlockstreamClient(cfg)]
    else:
        chain = [MempoolSpaceClient(cfg), BlockstreamClient(cfg), BlockchairClient(cfg)]
    return chain


def fetch_with_fallback(txid: str, cfg: AppConfig, source: str):
    errors = []
    for c in new_client(cfg, source):
        try:
            tx = c.get_tx(txid)
            outspends = c.get_outspends(txid)
            if tx and outspends is not None:
                return c, tx, outspends, None
        except Exception as e:
            errors.append(f"{c.__class__.__name__}: {e}")
            continue
    return None, None, None, errors
```
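The chain built by `new_client` plus `fetch_with_fallback` is a plain ordered-failover pattern: try each client in turn, collect the error messages, and return the first success. The same logic, stripped down to stand alone (with dummy clients in place of the HTTP-backed ones):

```python
class FlakyClient:
    """Stand-in for an explorer client; fails on demand."""
    def __init__(self, name: str, fail: bool):
        self.name, self.fail = name, fail

    def get_tx(self, txid: str) -> dict:
        if self.fail:
            raise RuntimeError(f"{self.name} unavailable")
        return {"txid": txid, "source": self.name}

def fetch_first_success(txid: str, chain):
    # Mirror of fetch_with_fallback: first client that answers wins;
    # failures are accumulated as human-readable log entries.
    errors = []
    for c in chain:
        try:
            return c.get_tx(txid), errors
        except Exception as e:
            errors.append(f"{c.name}: {e}")
    return None, errors

tx, errs = fetch_first_success("ab" * 32, [
    FlakyClient("mempool", fail=True),
    FlakyClient("blockstream", fail=False),
    FlakyClient("blockchair", fail=False),
])
print(tx["source"], errs)  # blockstream ['mempool: mempool unavailable']
```

Note that each call to `new_client` constructs fresh clients (and thus fresh `TTLCache`s); because app.py routes everything through `expand_ego`, which calls `fetch_with_fallback` per transaction, the caches are per-request rather than process-wide.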
features.py
ADDED

```python
from typing import Dict, Any, List, Tuple
from collections import deque

import numpy as np
from sklearn.preprocessing import StandardScaler


def _sum_inputs_btc(tx: Dict[str, Any]) -> float:
    s = 0.0
    for v in tx.get("vin", []):
        s += float(v.get("prevout_value") or 0.0) / 1e8
    return s


def _sum_outputs_btc(tx: Dict[str, Any]) -> float:
    s = 0.0
    for o in tx.get("vout", []):
        s += float(o.get("value") or 0.0) / 1e8
    return s


def _compute_distances(n: int, edges: List[Tuple[int, int]], center: int) -> np.ndarray:
    # Undirected BFS distance from the center; -1 marks unreachable nodes.
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    dist = np.full(n, fill_value=-1, dtype=np.int32)
    q = deque([center])
    dist[center] = 0
    while q:
        u = q.popleft()
        for nb in adj[u]:
            if dist[nb] == -1:
                dist[nb] = dist[u] + 1
                q.append(nb)
    return dist


def build_features(nodes: List[str], edges: List[Tuple[int, int]], center_idx: int,
                   node_meta: Dict[str, Dict[str, Any]]):
    n = len(nodes)
    # Degrees
    out_deg = np.zeros(n, dtype=np.float32)
    in_deg = np.zeros(n, dtype=np.float32)
    for u, v in edges:
        out_deg[u] += 1
        in_deg[v] += 1
    deg = in_deg + out_deg
    ratio_in_out = in_deg / (out_deg + 1e-6)

    # Sums & counts from metadata
    sum_in_btc = np.zeros(n, dtype=np.float32)
    sum_out_btc = np.zeros(n, dtype=np.float32)
    n_inputs = np.zeros(n, dtype=np.float32)
    n_outputs = np.zeros(n, dtype=np.float32)
    block_height = np.zeros(n, dtype=np.float32)

    for idx, txid in enumerate(nodes):
        meta = node_meta.get(txid) or {}
        n_inputs[idx] = float(len(meta.get("vin", []) or []))
        n_outputs[idx] = float(len(meta.get("vout", []) or []))
        sum_in_btc[idx] = _sum_inputs_btc(meta)
        sum_out_btc[idx] = _sum_outputs_btc(meta)
        bh = meta.get("block_height")
        block_height[idx] = float(bh) if bh is not None else 0.0

    log_sum_in = np.log1p(sum_in_btc)
    log_sum_out = np.log1p(sum_out_btc)
    distance = _compute_distances(n, edges, center_idx)

    feats = np.stack([
        in_deg, out_deg, deg, ratio_in_out,
        n_inputs, n_outputs,
        sum_in_btc, sum_out_btc,
        log_sum_in, log_sum_out,
        distance.astype(np.float32),
        block_height,
    ], axis=1)

    feature_names = [
        "in_degree", "out_degree", "degree", "ratio_in_out",
        "n_inputs", "n_outputs",
        "sum_in_btc", "sum_out_btc",
        "log_sum_in", "log_sum_out",
        "distance", "block_height",
    ]
    return feats, feature_names


def scale_features(X: np.ndarray, scaler=None):
    if scaler is None:
        scaler = StandardScaler()
        Xs = scaler.fit_transform(X)
        note = "Fitted new StandardScaler on ego-subgraph (domain shift vs Elliptic)."
    else:
        Xs = scaler.transform(X)
        note = "Used provided scaler from model repo."
    return Xs.astype("float32"), scaler, note
```
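`_compute_distances` treats the directed edges as undirected for the hop count and leaves -1 on nodes unreachable from the center, which then flows into the feature matrix as -1.0. A self-contained check of that behavior on a tiny graph (pure-Python lists in place of the NumPy array, same algorithm):

```python
from collections import deque

def bfs_distances(n: int, edges, center: int):
    # Undirected BFS hop-count from `center`; -1 marks unreachable nodes,
    # mirroring _compute_distances in features.py.
    adj = [[] for _ in range(n)]
    for u, v in edges:
        adj[u].append(v)
        adj[v].append(u)
    dist = [-1] * n
    dist[center] = 0
    q = deque([center])
    while q:
        u = q.popleft()
        for nb in adj[u]:
            if dist[nb] == -1:
                dist[nb] = dist[u] + 1
                q.append(nb)
    return dist

# Chain 0 -> 1 -> 2 plus an isolated node 3; center at node 0.
print(bfs_distances(4, [(0, 1), (1, 2)], 0))  # [0, 1, 2, -1]
```

In practice an ego-graph built by `expand_ego` is connected through the center, so -1 should not occur there; the sentinel mainly matters if the function is reused on arbitrary edge lists.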
graph_builder.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, List, Tuple, Set
|
| 2 |
+
from collections import deque, defaultdict
|
| 3 |
+
|
| 4 |
+
from explorers import fetch_with_fallback
|
| 5 |
+
from config import AppConfig
|
| 6 |
+
|
| 7 |
+
def expand_ego(txid: str, k: int, source: str, cfg: AppConfig):
|
| 8 |
+
"""
|
| 9 |
+
Expand ego-subgraph up to k steps backward (parents) and forward (children).
|
| 10 |
+
Returns:
|
| 11 |
+
nodes: List[str] txids
|
| 12 |
+
edges: List[Tuple[int,int]] parent->child indices
|
| 13 |
+
center_idx: int
|
| 14 |
+
node_meta: Dict[txid, dict]
|
| 15 |
+
logs: List[str]
|
| 16 |
+
"""
|
| 17 |
+
logs = []
|
| 18 |
+
client, tx0, out0, errs = fetch_with_fallback(txid, cfg, source)
|
| 19 |
+
if client is None:
|
| 20 |
+
return [], [], -1, {}, ["All providers failed", *(errs or [])]
|
| 21 |
+
|
| 22 |
+
nodes: List[str] = []
|
| 23 |
+
idx_map: Dict[str, int] = {}
|
| 24 |
+
edges: List[Tuple[int,int]] = []
|
| 25 |
+
node_meta: Dict[str, Dict[str, Any]] = {}
|
| 26 |
+
|
| 27 |
+
def add_node(tid: str, meta: Dict[str, Any]):
|
| 28 |
+
if tid in idx_map:
|
| 29 |
+
return idx_map[tid]
|
| 30 |
+
if len(nodes) >= cfg.MAX_NODES:
|
| 31 |
+
return None
|
| 32 |
+
idx = len(nodes)
|
| 33 |
+
nodes.append(tid)
|
| 34 |
+
idx_map[tid] = idx
|
| 35 |
+
node_meta[tid] = meta
|
| 36 |
+
return idx
|
| 37 |
+
|
| 38 |
+
def ensure_tx(tid: str):
|
| 39 |
+
c, tj, outsp, _ = fetch_with_fallback(tid, cfg, source)
|
| 40 |
+
if tj is None:
|
| 41 |
+
return None, None
|
| 42 |
+
return tj, outsp
|
| 43 |
+
|
| 44 |
+
# seed
|
| 45 |
+
center_idx = add_node(txid, tx0)
|
| 46 |
+
if center_idx is None:
|
| 47 |
+
return [], [], -1, {}, ["MAX_NODES reached at seed"]
|
| 48 |
+
frontier_par = deque([(txid, 0)])
|
| 49 |
+
frontier_ch = deque([(txid, 0)])
|
| 50 |
+
|
| 51 |
+
# BFS backward (parents)
|
| 52 |
+
while frontier_par:
|
| 53 |
+
cur, depth = frontier_par.popleft()
|
| 54 |
+
if depth >= k:
|
| 55 |
+
continue
|
| 56 |
+
        tj = node_meta.get(cur)
        if tj is None:
            t, o = ensure_tx(cur)
            if t is None:
                continue
            node_meta[cur] = t
            tj = t
        for vi in tj.get("vin", []):
            ptx = vi.get("txid")
            if not ptx:
                continue
            if ptx not in idx_map:
                if len(nodes) >= cfg.MAX_NODES:
                    logs.append("MAX_NODES reached during backward expansion")
                    break
                ptj, pout = ensure_tx(ptx)
                if ptj is None:
                    continue
                pidx = add_node(ptx, ptj)
                if pidx is None:
                    continue
                frontier_par.append((ptx, depth + 1))
            else:
                pidx = idx_map[ptx]
            # edge parent -> child
            cidx = idx_map.get(cur)
            if cidx is not None:
                edges.append((pidx, cidx))
            if cfg.MAX_EDGES and len(edges) >= cfg.MAX_EDGES:
                logs.append("MAX_EDGES reached; stopping further edge additions")
                break

    # BFS forward (children)
    while frontier_ch:
        cur, depth = frontier_ch.popleft()
        if depth >= k:
            continue
        tj = node_meta.get(cur)
        if tj is None:
            t, o = ensure_tx(cur)
            if t is None:
                continue
            node_meta[cur] = t
            tj = t
        outsp = None
        try:
            _, _, outsp, _ = fetch_with_fallback(cur, cfg, source)
        except Exception:
            outsp = None
        if outsp is None:
            continue
        for child_tx in outsp:
            if not child_tx:
                continue
            if child_tx not in idx_map:
                if len(nodes) >= cfg.MAX_NODES:
                    logs.append("MAX_NODES reached during forward expansion")
                    break
                ctj, cout = ensure_tx(child_tx)
                if ctj is None:
                    continue
                cidx = add_node(child_tx, ctj)
                if cidx is None:
                    continue
                frontier_ch.append((child_tx, depth + 1))
            else:
                cidx = idx_map[child_tx]
            pidx = idx_map.get(cur)
            if pidx is not None:
                edges.append((pidx, cidx))
            if cfg.MAX_EDGES and len(edges) >= cfg.MAX_EDGES:
                logs.append("MAX_EDGES reached; stopping further edge additions")
                break

    # deduplicate edges
    edges = list(set(edges))
    return nodes, edges, center_idx, node_meta, logs
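The builder returns `nodes` (txids), `edges` (index pairs into `nodes`), and `center_idx`, which downstream code packs into a PyG-style `edge_index`. A minimal sketch of that packing in plain Python — the `tensorize_edges` helper name is illustrative, not part of the repo:

```python
from typing import List, Tuple

def tensorize_edges(edges: List[Tuple[int, int]]) -> List[List[int]]:
    # Deduplicate (as graph_builder does with list(set(edges))) and split
    # pairs into the 2 x E source/target layout expected for edge_index.
    uniq = sorted(set(edges))
    src = [u for u, _ in uniq]
    dst = [v for _, v in uniq]
    return [src, dst]

edge_index = tensorize_edges([(0, 1), (0, 1), (1, 2)])
# duplicate (0, 1) collapses to a single edge: [[0, 1], [1, 2]]
```

Wrapping the result in `torch.tensor(edge_index, dtype=torch.long)` gives the `edge_index` consumed by the models.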
inference.py
ADDED
@@ -0,0 +1,109 @@
import json
import os
from typing import Dict, Any, Tuple, Optional

import torch
from huggingface_hub import snapshot_download
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

from config import AppConfig
from models import GATBaseline, GATv2Enhanced, AdapterWrapper

def _load_threshold(model_dir: str, default_thr: float) -> float:
    for name in ["thresholds.json", "threshold.json", "config.json"]:
        p = os.path.join(model_dir, name)
        if os.path.exists(p):
            try:
                d = json.load(open(p, "r"))
                for k in ["threshold", "default_threshold", "thr", "best_f1", "best_j"]:
                    if k in d and isinstance(d[k], (int, float)):
                        return float(d[k])
            except Exception:
                continue
    return default_thr

def _load_scaler(model_dir: str):
    # Optional scaler joblib/pkl
    for name in ["scaler.joblib", "scaler.pkl", "elliptic_scaler.joblib", "elliptic_scaler.pkl"]:
        p = os.path.join(model_dir, name)
        if os.path.exists(p):
            try:
                import joblib
                return joblib.load(p)
            except Exception:
                pass
    return None

def load_models(cfg: AppConfig):
    # Download both repos
    dir_gat = snapshot_download(cfg.HF_GAT_BASELINE_REPO, local_dir_use_symlinks=False)
    dir_gatv2 = snapshot_download(cfg.HF_GATV2_REPO, local_dir_use_symlinks=False)

    # Model files
    ckpt_gat = os.path.join(dir_gat, "model.pt")
    ckpt_gatv2 = os.path.join(dir_gatv2, "model.pt")
    if not os.path.exists(ckpt_gat):
        raise FileNotFoundError(f"Missing model.pt in {dir_gat}")
    if not os.path.exists(ckpt_gatv2):
        raise FileNotFoundError(f"Missing model.pt in {dir_gatv2}")

    # Build cores (expected input dim from training)
    core_gat = GATBaseline(cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT)
    core_gatv2 = GATv2Enhanced(cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT)

    state_gat = torch.load(ckpt_gat, map_location="cpu")
    state_gatv2 = torch.load(ckpt_gatv2, map_location="cpu")

    # strict load for cores
    core_gat.load_state_dict(state_gat, strict=True)
    core_gatv2.load_state_dict(state_gatv2, strict=True)

    thr_gat = _load_threshold(dir_gat, cfg.DEFAULT_THRESHOLD)
    thr_gatv2 = _load_threshold(dir_gatv2, cfg.DEFAULT_THRESHOLD)

    scaler_gat = _load_scaler(dir_gat)
    scaler_gatv2 = _load_scaler(dir_gatv2)

    return {
        "gat": {"core": core_gat.eval(), "threshold": thr_gat, "scaler": scaler_gat, "repo_dir": dir_gat},
        "gatv2": {"core": core_gatv2.eval(), "threshold": thr_gatv2, "scaler": scaler_gatv2, "repo_dir": dir_gatv2},
    }

@torch.no_grad()
def predict(model, data: Data):
    logits = model(data.x, data.edge_index)
    probs = torch.sigmoid(logits).cpu().numpy()
    return probs

def adapt_and_predict(bundle: Dict[str, Any], in_dim_new: int, data: Data, cfg: AppConfig):
    core = bundle["core"]
    if in_dim_new != cfg.IN_CHANNELS and cfg.USE_FEATURE_ADAPTER:
        model = AdapterWrapper(in_dim_new, cfg.IN_CHANNELS, core).eval()
        note = f"FeatureAdapter used (new_dim={in_dim_new} → expected={cfg.IN_CHANNELS})."
    elif in_dim_new != cfg.IN_CHANNELS:
        # attempt to run without adapter (not recommended)
        model = core.eval()
        note = f"Dimension mismatch (new_dim={in_dim_new}, expected={cfg.IN_CHANNELS}). Proceeding without adapter (may fail)."
    else:
        model = core.eval()
        note = "Input dim matches."
    probs = predict(model, data)
    return probs, note

def run_for_both_models(bundles, data: Data, center_idx: int, cfg: AppConfig):
    in_dim_new = data.x.shape[1]

    probs_g, note_g = adapt_and_predict(bundles["gat"], in_dim_new, data, cfg)
    thr_g = float(bundles["gat"]["threshold"])
    label_g = int(probs_g[center_idx] >= thr_g)

    probs_v2, note_v2 = adapt_and_predict(bundles["gatv2"], in_dim_new, data, cfg)
    thr_v2 = float(bundles["gatv2"]["threshold"])
    label_v2 = int(probs_v2[center_idx] >= thr_v2)

    return [
        ("GAT", probs_g, thr_g, label_g, note_g),
        ("GATv2", probs_v2, thr_v2, label_v2, note_v2),
    ]
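`_load_threshold` scans `thresholds.json`, `threshold.json`, then `config.json` for the first numeric value among the keys `threshold`, `default_threshold`, `thr`, `best_f1`, `best_j`, falling back to the default. A self-contained, stdlib-only sketch of that lookup — `read_threshold` is an illustrative name, not the repo's function:

```python
import json
import os
import tempfile

KEYS = ["threshold", "default_threshold", "thr", "best_f1", "best_j"]

def read_threshold(model_dir: str, default_thr: float = 0.5) -> float:
    # Same file and key priority order as inference._load_threshold.
    for name in ["thresholds.json", "threshold.json", "config.json"]:
        p = os.path.join(model_dir, name)
        if not os.path.exists(p):
            continue
        try:
            d = json.load(open(p, "r"))
        except Exception:
            continue
        for k in KEYS:
            if isinstance(d.get(k), (int, float)):
                return float(d[k])
    return default_thr

with tempfile.TemporaryDirectory() as tmp:
    json.dump({"threshold": 0.85}, open(os.path.join(tmp, "thresholds.json"), "w"))
    assert read_threshold(tmp) == 0.85  # reads the stored threshold
assert read_threshold("/nonexistent_model_dir") == 0.5  # falls back to default
```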
models.py
ADDED
@@ -0,0 +1,78 @@
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, GATv2Conv, BatchNorm

class ResidualGATBlock(nn.Module):
    def __init__(self, in_channels, hidden_channels, heads=8, dropout=0.5, v2=False):
        super().__init__()
        Conv = GATv2Conv if v2 else GATConv
        self.conv = Conv(in_channels, hidden_channels, heads=heads, dropout=dropout)
        self.bn = BatchNorm(hidden_channels * heads)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.res_proj = None
        out_dim = hidden_channels * heads
        if in_channels != out_dim:
            self.res_proj = nn.Linear(in_channels, out_dim)

    def forward(self, x, edge_index):
        identity = x
        out = self.conv(x, edge_index)
        out = self.bn(out)
        out = self.act(out)
        out = self.dropout(out)
        if self.res_proj is not None:
            identity = self.res_proj(identity)
        return out + identity

class GATBaseline(nn.Module):
    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__()
        layers = []
        c_in = in_channels
        for _ in range(num_blocks):
            layers.append(ResidualGATBlock(c_in, hidden_channels, heads=heads, dropout=dropout, v2=False))
            c_in = hidden_channels * heads
        self.blocks = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.out_conv = GATConv(c_in, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        for block in self.blocks:
            x = block(x, edge_index)
        x = self.dropout(x)
        out = self.out_conv(x, edge_index)
        return out.view(-1)

class GATv2Enhanced(nn.Module):
    def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
        super().__init__()
        layers = []
        c_in = in_channels
        for _ in range(num_blocks):
            layers.append(ResidualGATBlock(c_in, hidden_channels, heads=heads, dropout=dropout, v2=True))
            c_in = hidden_channels * heads
        self.blocks = nn.ModuleList(layers)
        self.dropout = nn.Dropout(dropout)
        self.out_conv = GATv2Conv(c_in, 1, heads=1, concat=False, dropout=dropout)

    def forward(self, x, edge_index):
        for block in self.blocks:
            x = block(x, edge_index)
        x = self.dropout(x)
        out = self.out_conv(x, edge_index)
        return out.view(-1)

class AdapterWrapper(nn.Module):
    def __init__(self, in_dim_new, expected_in_dim, core_model):
        super().__init__()
        if in_dim_new != expected_in_dim:
            self.adapter = nn.Linear(in_dim_new, expected_in_dim, bias=True)
        else:
            self.adapter = None
        self.core = core_model

    def forward(self, x, edge_index):
        if self.adapter is not None:
            x = self.adapter(x)
        return self.core(x, edge_index)
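Each `ResidualGATBlock` widens its input to `hidden_channels * heads` (the heads are concatenated), so with the defaults (`hidden_channels=128`, `heads=8`) every block outputs 1024 features, and the residual path needs a `Linear` projection only when the input width differs. A quick, dependency-free check of the dimension flow — `block_dims` is an illustrative helper mirroring the loop in `GATBaseline.__init__`:

```python
def block_dims(in_channels: int, hidden_channels: int = 128, heads: int = 8, num_blocks: int = 2):
    # Each block maps c_in -> hidden_channels * heads; a residual
    # projection is needed whenever c_in differs from that output width.
    dims, c_in = [], in_channels
    for _ in range(num_blocks):
        out = hidden_channels * heads
        dims.append((c_in, out, c_in != out))  # (in, out, needs_res_proj)
        c_in = out
    return dims

dims = block_dims(165)
# block 0: 165 -> 1024 (projection needed); block 1: 1024 -> 1024 (identity residual)
```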
rate_limit.py
ADDED
@@ -0,0 +1,34 @@
import time
import threading
from collections import deque
from functools import wraps

class RateLimitExceeded(Exception):
    pass

class GlobalRateLimiter:
    def __init__(self, max_calls: int, window_seconds: int):
        self.max_calls = max_calls
        self.window = window_seconds
        self._lock = threading.Lock()
        self._events = deque()

    def allow(self) -> bool:
        now = time.time()
        with self._lock:
            while self._events and now - self._events[0] > self.window:
                self._events.popleft()
            if len(self._events) < self.max_calls:
                self._events.append(now)
                return True
            return False

    def enforce(self, func=None):
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                if not self.allow():
                    raise RateLimitExceeded(f"Rate limit exceeded ({self.max_calls} req/{self.window}s)")
                return f(*args, **kwargs)
            return wrapper
        return decorator if func is None else decorator(func)
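`GlobalRateLimiter` keeps a deque of call timestamps and admits a call only while fewer than `max_calls` timestamps fall inside the sliding window. An inlined, minimal copy to illustrate the behavior (locking stripped for brevity; the real class is thread-safe):

```python
import time
from collections import deque

class SlidingWindowLimiter:
    # Simplified, single-threaded version of GlobalRateLimiter.allow().
    def __init__(self, max_calls: int, window_seconds: float):
        self.max_calls = max_calls
        self.window = window_seconds
        self._events = deque()

    def allow(self) -> bool:
        now = time.time()
        # Drop timestamps that have aged out of the window.
        while self._events and now - self._events[0] > self.window:
            self._events.popleft()
        if len(self._events) < self.max_calls:
            self._events.append(now)
            return True
        return False

limiter = SlidingWindowLimiter(max_calls=2, window_seconds=60)
results = [limiter.allow() for _ in range(3)]
# first two calls pass, the third is rejected: [True, True, False]
```

In the app, `enforce` wraps handlers so a rejected call raises `RateLimitExceeded` instead of returning `False`.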
requirements.txt
ADDED
@@ -0,0 +1,28 @@
# PyTorch CPU-only build (3.13 compatible)
torch==2.6.0
torchvision==0.21.0
torchaudio==2.6.0
-f https://download.pytorch.org/whl/cpu

# PyTorch Geometric (compatible with torch 2.6.0 CPU)
pyg_lib==0.4.0
torch_scatter==2.1.2
torch_sparse==0.6.18
torch_cluster==1.6.3
torch_spline_conv==1.2.2
torch_geometric==2.6.0

# Core app & utilities
gradio>=5.0.0,<5.2.0
huggingface_hub>=0.26.2
requests>=2.32.3
tenacity>=9.0.0
diskcache>=5.6.3
cachetools>=5.5.0
pyvis>=0.3.3
networkx>=3.4.2
pandas>=2.2.3
numpy>=2.1.3
scikit-learn>=1.6.0
matplotlib>=3.9.2
tqdm>=4.67.1
viz.py
ADDED
@@ -0,0 +1,44 @@
from typing import List, Tuple, Optional
from pyvis.network import Network
import numpy as np
import io, base64
import matplotlib.pyplot as plt

def _color_from_score(p: float) -> str:
    # blue (0) -> gray (0.5) -> red (1)
    p = float(np.clip(p, 0.0, 1.0))
    if p < 0.5:
        # interpolate blue (#377eb8) to gray (#bbbbbb)
        t = p / 0.5
        c0 = (0x37, 0x7e, 0xb8)
        c1 = (0xbb, 0xbb, 0xbb)
    else:
        # interpolate gray to red (#e41a1c)
        t = (p - 0.5) / 0.5
        c0 = (0xbb, 0xbb, 0xbb)
        c1 = (0xe4, 0x1a, 0x1c)
    r = int((1 - t) * c0[0] + t * c1[0])
    g = int((1 - t) * c0[1] + t * c1[1])
    b = int((1 - t) * c0[2] + t * c1[2])
    return f"#{r:02x}{g:02x}{b:02x}"

def render_ego_html(nodes: List[str], edges: List[Tuple[int, int]], center_idx: int, scores: Optional[np.ndarray] = None) -> str:
    net = Network(height="600px", width="100%", notebook=False, directed=True)
    for i, txid in enumerate(nodes):
        color = _color_from_score(scores[i]) if scores is not None else "#bbbbbb"
        size = 20 if i == center_idx else 8
        title = f"{txid}"
        if scores is not None:
            title += f"<br/>score={scores[i]:.4f}"
        net.add_node(i, label=txid[:10] + "…", title=title, color=color, size=size)
    for (u, v) in edges:
        net.add_edge(u, v, arrows="to")
    net.toggle_physics(True)
    return net.generate_html()

def histogram_scores(scores, title="Score distribution"):
    fig = plt.figure(figsize=(6, 4))
    plt.hist(scores, bins=40)
    plt.xlabel("Score")
    plt.ylabel("Count")
    plt.title(title)
    plt.tight_layout()
    return fig
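`_color_from_score` maps a probability onto a blue → gray → red gradient with a breakpoint at 0.5. A numpy-free sketch of the same interpolation (endpoint colors taken from viz.py; `color_from_score` is an illustrative standalone copy):

```python
def color_from_score(p: float) -> str:
    # Clamp to [0, 1], then interpolate blue -> gray for p < 0.5
    # and gray -> red for p >= 0.5 (same endpoints as viz.py).
    p = min(max(float(p), 0.0), 1.0)
    if p < 0.5:
        t, c0, c1 = p / 0.5, (0x37, 0x7E, 0xB8), (0xBB, 0xBB, 0xBB)
    else:
        t, c0, c1 = (p - 0.5) / 0.5, (0xBB, 0xBB, 0xBB), (0xE4, 0x1A, 0x1C)
    r, g, b = (int((1 - t) * a + t * b_) for a, b_ in zip(c0, c1))
    return f"#{r:02x}{g:02x}{b:02x}"

assert color_from_score(0.0) == "#377eb8"  # pure blue
assert color_from_score(0.5) == "#bbbbbb"  # neutral gray
assert color_from_score(1.0) == "#e41a1c"  # pure red
```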