thanhphxu commited on
Commit
db886e4
·
verified ·
1 Parent(s): d15e63a

Upload folder using huggingface_hub

Browse files
Files changed (12) hide show
  1. .gitignore +1 -0
  2. README.md +72 -13
  3. app.py +137 -0
  4. config.py +41 -0
  5. explorers.py +189 -0
  6. features.py +90 -0
  7. graph_builder.py +132 -0
  8. inference.py +109 -0
  9. models.py +78 -0
  10. rate_limit.py +34 -0
  11. requirements.txt +28 -0
  12. viz.py +44 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .venv/**
README.md CHANGED
@@ -1,13 +1,72 @@
1
- ---
2
- title: MLGraph Bitcoin GAD
3
- emoji: 🏢
4
- colorFrom: yellow
5
- colorTo: blue
6
- sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
- license: mit
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Bitcoin Abuse Scoring (GAT / GATv2) — Hugging Face Space
2
+
3
+ This Space builds an **ego-subgraph** from a given Bitcoin transaction hash (`k` steps backward & forward), then runs **two pretrained GNN models** (GAT baseline & GATv2 enhanced) trained on **Elliptic** to score whether the center transaction is *abuse*.
4
+
5
+ ## ✅ Features
6
+
7
+ - Data sources (public JSON APIs, no scraping): `mempool.space` / `blockstream.info` (Esplora), fallback to `Blockchair` (optional key).
8
+ - Ego-subgraph expansion **k ∈ {1,2,3}** (both parents & children).
9
+ - Graph safeguards: `MAX_NODES` & `MAX_EDGES` to avoid explosion.
10
+ - Node features: degree stats, value sums/logs, counts, ratio, distance-to-center, block height.
11
+ - Standardized features (on-the-fly). If your model used different features/scaler, set `USE_FEATURE_ADAPTER=true` (default) — it inserts a `Linear` projection to the expected input dimension (165 by default).
12
+ - Two models are loaded from **Hugging Face Hub** with thresholds (via `thresholds.json` or fallback `0.5`).
13
+ - **Rate limit**: 20 requests/min globally (sliding window).
14
+ - Visualizations: **ego-graph (pyvis HTML)** & **histogram of scores** per model.
15
+ - CPU-only deployment on Spaces.
16
+
17
+ ## 🔧 Configuration
18
+
19
+ Set these **Environment Variables** (Space → Settings → Variables):
20
+
21
+ ```
22
+ HF_GAT_BASELINE_REPO=org/name_gat_baseline
23
+ HF_GATV2_REPO=org/name_gatv2
24
+
25
+ # (Optional overrides)
26
+ IN_CHANNELS=165
27
+ HIDDEN_CHANNELS=128
28
+ HEADS=8
29
+ NUM_BLOCKS=2
30
+ DROPOUT=0.5
31
+
32
+ DATA_PROVIDER=mempool # mempool | blockstream | blockchair
33
+ HTTP_TIMEOUT=10
34
+ HTTP_RETRIES=2
35
+ MAX_NODES=5000
36
+ MAX_EDGES=15000
37
+ USE_FEATURE_ADAPTER=true
38
+ DEFAULT_THRESHOLD=0.5
39
+ QUEUE_CONCURRENCY=2
40
+ BLOCKCHAIR_API_KEY=
41
+ ```
42
+
43
+ Each model repo should contain:
44
+ - `model.pt` — PyTorch Geometric weights.
45
+ - (optional) `thresholds.json` with a key like `{"threshold": 0.42}`.
46
+ - (optional) `scaler.joblib` if you want to reuse the training scaler.
47
+
48
+ ## 📦 API Usage in App
49
+
50
+ - `GET /api/tx/{txid}` and `GET /api/tx/{txid}/outspends` (Esplora).
51
+ - `GET /bitcoin/dashboards/transaction/{txid}` (Blockchair).
52
+
53
+ All calls have **timeouts & retries** and use a small **in-memory cache**.
54
+
55
+ ## 🚦 Rate Limiting
56
+
57
+ Global limit `20 req/min` across the app (sliding window). Exceeding returns `Rate limit exceeded (20 req/60s)`.
58
+
59
+ ## 🧪 Acceptance Criteria
60
+
61
+ - Enter a valid tx hash & `k=2` → ego-graph is built, both models run, and the app displays:
62
+ - `probability`, `threshold`, `label` for **GAT** and **GATv2**,
63
+ - counts of nodes/edges and notes (e.g., *FeatureAdapter used*).
64
+ - Ego-graph renders with center highlighted; tooltips show txid and score.
65
+ - If the first provider fails, the app falls back.
66
+ - If graph exceeds safeguards, the app stops expansion and warns in logs (but still infers with what it has).
67
+
68
+ ## ⚠️ Notes
69
+
70
+ - **Domain shift**: Features from on-chain crawls can differ from Elliptic; use the adapter and consider fine-tuning for production.
71
+ - Public APIs have their own rate limits — this app is conservative with requests, but heavy usage may still hit external limits.
72
+ - Input is validated to be a 64-hex txid. No arbitrary URLs are accepted.
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import traceback
3
+ from typing import Tuple, List, Dict, Any
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from torch_geometric.data import Data
10
+ from torch_geometric.utils import coalesce
11
+
12
+ from config import AppConfig
13
+ from rate_limit import GlobalRateLimiter, RateLimitExceeded
14
+ from explorers import new_client, fetch_with_fallback
15
+ from graph_builder import expand_ego
16
+ from features import build_features, scale_features
17
+ from inference import load_models, run_for_both_models
18
+ from viz import render_ego_html, histogram_scores
19
+
20
+ CFG = AppConfig()
21
+ RATE = GlobalRateLimiter(CFG.MAX_CALLS_PER_MIN, CFG.WINDOW_SECONDS)
22
+
23
+ _models_cache = None
24
+
25
+ def _is_valid_txid(tx: str) -> bool:
26
+ return bool(re.fullmatch(r"[0-9a-fA-F]{64}", tx or ""))
27
+
28
def _load_models_once():
    """Return the cached model bundles, loading them from the Hub on first use."""
    global _models_cache
    cached = _models_cache
    if cached is None:
        # First call pays the download/instantiation cost; later calls reuse it.
        cached = load_models(CFG)
        _models_cache = cached
    return cached
33
+
34
def handle_run(tx_hash: str, k: int, provider: str):
    """Gradio callback: build the ego-subgraph for *tx_hash*, run both models,
    and return (table, html_gat, html_gatv2, logs, hist_gat, hist_gatv2).

    Fixes vs. the previous version:
    - Every return path now yields exactly six values, matching the six output
      components wired in ``run_btn.click`` (validation failures previously
      returned four values, which breaks Gradio's output mapping).
    - The rate limit is checked inside the body instead of via the
      ``@RATE.enforce()`` decorator: the decorator raised *before* this
      function ran, so the ``except RateLimitExceeded`` clause below was
      unreachable and users saw a raw traceback instead of the friendly
      message.
    """
    logs: List[str] = []

    def _fail(msg: str):
        # Uniform error shape: the message goes to the Logs textbox (4th output).
        return None, None, None, msg, None, None

    try:
        if not RATE.allow():
            # Mirror the message format used by GlobalRateLimiter.enforce.
            raise RateLimitExceeded(
                f"Rate limit exceeded ({RATE.max_calls} req/{RATE.window}s)"
            )

        if not _is_valid_txid(tx_hash):
            return _fail("❌ Invalid txid. Please enter a 64-hex transaction hash.")

        k = int(k)
        if not 1 <= k <= 3:
            return _fail("❌ k must be in {1,2,3}.")

        logs.append(f"Fetching ego-subgraph for {tx_hash} with k={k} via {provider}…")
        nodes, edges, center_idx, node_meta, gb_logs = expand_ego(tx_hash, k, provider, CFG)
        logs.extend(gb_logs or [])
        if center_idx < 0 or len(nodes) == 0:
            return _fail("❌ Failed to build subgraph (see logs).\n" + "\n".join(logs))

        if len(edges) == 0:
            logs.append("⚠️ No edges in ego-graph; proceeding with single-node graph.")

        # Node features, standardized on the fly (see scale_features note in logs).
        X, feat_names = build_features(nodes, edges, center_idx, node_meta)
        Xs, scaler_used, scale_note = scale_features(X, scaler=None)  # a repo scaler could be injected here
        logs.append(scale_note)

        # PyG Data object; coalesce() sorts and deduplicates edge_index.
        if len(edges) > 0:
            edge_index = torch.tensor(np.array(edges).T, dtype=torch.long)  # shape [2, E]
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_index = coalesce(edge_index)

        data = Data(
            x=torch.tensor(Xs, dtype=torch.float32),
            edge_index=edge_index,
        )

        bundles = _load_models_once()
        results = run_for_both_models(bundles, data, center_idx, CFG)

        # One row per model for the results table.
        records = [
            {
                "tx_hash": tx_hash,
                "model_name": name,
                "probability": float(probs[center_idx]),
                "threshold": float(thr),
                "pred_label": int(label),
                "k_used": int(k),
                "num_nodes": int(len(nodes)),
                "num_edges": int(len(edges)),
                "note": note,
            }
            for name, probs, thr, label, note in results
        ]
        df = pd.DataFrame(records)

        # Per-model visuals: interactive ego-graph HTML + score histogram.
        html_gat = render_ego_html(nodes, edges, center_idx, scores=results[0][1])
        html_gatv2 = render_ego_html(nodes, edges, center_idx, scores=results[1][1])
        fig_hist_gat = histogram_scores(results[0][1], title="Scores (GAT)")
        fig_hist_v2 = histogram_scores(results[1][1], title="Scores (GATv2)")

        return df, html_gat, html_gatv2, "\n".join(logs), fig_hist_gat, fig_hist_v2

    except RateLimitExceeded as e:
        return _fail(f"❌ {e}")
    except Exception as e:
        tb = traceback.format_exc()
        return _fail(f"❌ Error: {e}\n\n{tb}")
108
+
109
with gr.Blocks(fill_height=True, theme=gr.themes.Soft()) as app:
    gr.Markdown("## 🧭 Bitcoin Abuse Scoring (GAT / GATv2)\nEnter a transaction hash and k (1–3). The app builds an ego-subgraph from on-chain data and returns model scores.")
    with gr.Row():
        tx_in = gr.Textbox(label="Transaction Hash (64-hex)", placeholder="e.g., 4d3c... (64 hex)")
        k_in = gr.Slider(1, 3, value=2, step=1, label="k (steps before/after)")
        provider_in = gr.Dropdown(choices=["mempool", "blockstream", "blockchair"], value="mempool", label="Data Source")
    run_btn = gr.Button("Run", variant="primary")

    with gr.Row():
        out_table = gr.Dataframe(label="Results (GAT vs GATv2)", interactive=False)
    with gr.Tabs():
        with gr.Tab("Ego-graph (GAT)"):
            out_html_gat = gr.HTML()
            out_hist_gat = gr.Plot(label="Score histogram (GAT)")
        with gr.Tab("Ego-graph (GATv2)"):
            out_html_gatv2 = gr.HTML()
            out_hist_gatv2 = gr.Plot(label="Score histogram (GATv2)")
    out_logs = gr.Textbox(label="Logs", lines=8)

    # Output order must match handle_run's six return values.
    run_btn.click(
        handle_run,
        inputs=[tx_in, k_in, provider_in],
        outputs=[out_table, out_html_gat, out_html_gatv2, out_logs, out_hist_gat, out_hist_gatv2]
    )

# Gradio 4/5 renamed `concurrency_count` to `default_concurrency_limit`;
# the old keyword raises TypeError on the gradio>=5 pin in requirements.txt.
app.queue(default_concurrency_limit=CFG.QUEUE_CONCURRENCY, max_size=32)

if __name__ == "__main__":
    app.launch(server_name="0.0.0.0", server_port=7860)
config.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dataclasses import dataclass
3
+
4
@dataclass
class AppConfig:
    """Runtime configuration, resolved from environment variables.

    Field defaults are evaluated once at class-definition time, so the
    environment must be set before this module is imported (true on Spaces,
    where variables come from Settings → Variables).
    """

    # --- Hugging Face model repos (set via environment variables on Spaces) ---
    HF_GAT_BASELINE_REPO: str = os.getenv("HF_GAT_BASELINE_REPO", "org/name_gat_baseline")
    HF_GATV2_REPO: str = os.getenv("HF_GATV2_REPO", "org/name_gatv2")

    # Expected input dim of Elliptic-trained models (given by user).
    # Architecture hyperparameters must match the training run exactly,
    # since checkpoints are loaded with strict=True.
    IN_CHANNELS: int = int(os.getenv("IN_CHANNELS", "165"))
    HIDDEN_CHANNELS: int = int(os.getenv("HIDDEN_CHANNELS", "128"))
    HEADS: int = int(os.getenv("HEADS", "8"))
    NUM_BLOCKS: int = int(os.getenv("NUM_BLOCKS", "2"))
    DROPOUT: float = float(os.getenv("DROPOUT", "0.5"))

    # Data providers
    DATA_PROVIDER: str = os.getenv("DATA_PROVIDER", "mempool")  # mempool | blockstream | blockchair
    HTTP_TIMEOUT: int = int(os.getenv("HTTP_TIMEOUT", "10"))    # seconds per request
    HTTP_RETRIES: int = int(os.getenv("HTTP_RETRIES", "2"))     # attempts on timeout/connection error

    # Graph limits (safeguard against subgraph explosion during expansion)
    MAX_NODES: int = int(os.getenv("MAX_NODES", "5000"))
    MAX_EDGES: int = int(os.getenv("MAX_EDGES", "15000"))

    # Feature handling
    # USE_FEATURE_ADAPTER: insert a Linear projection when on-the-fly feature
    # dim differs from IN_CHANNELS (see inference.adapt_and_predict).
    USE_FEATURE_ADAPTER: bool = os.getenv("USE_FEATURE_ADAPTER", "true").lower() == "true"
    MAKE_UNDIRECTED: bool = os.getenv("MAKE_UNDIRECTED", "false").lower() == "true"

    # Threshold fallback when a model repo ships no thresholds.json
    DEFAULT_THRESHOLD: float = float(os.getenv("DEFAULT_THRESHOLD", "0.5"))

    # Rate limit (global sliding window; see rate_limit.GlobalRateLimiter)
    MAX_CALLS_PER_MIN: int = int(os.getenv("MAX_CALLS_PER_MIN", "20"))
    WINDOW_SECONDS: int = int(os.getenv("WINDOW_SECONDS", "60"))

    # Queue config (Gradio request concurrency)
    QUEUE_CONCURRENCY: int = int(os.getenv("QUEUE_CONCURRENCY", "2"))

    # Blockchair API key (optional; empty string disables the ?key= parameter)
    BLOCKCHAIR_API_KEY: str = os.getenv("BLOCKCHAIR_API_KEY", "").strip()
explorers.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ from typing import Dict, Any, List, Optional, Tuple
5
+ import requests
6
+ from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
7
+ from cachetools import TTLCache
8
+ from config import AppConfig
9
+
10
+ UserAgent = "HF-Space-BTC-Abuse-GNN/1.0 (+https://huggingface.co/spaces)"
11
+
12
class ExplorerError(Exception):
    """Raised when a block-explorer API returns a non-200 response."""
14
+
15
def _req_json(url: str, timeout: int, retries: int = 2) -> Any:
    """GET *url* and return the decoded JSON body.

    Timeouts and connection errors are retried with exponential backoff for
    at most *retries* attempts; a non-200 HTTP status raises ExplorerError
    immediately and is deliberately not retried.
    """
    @retry(stop=stop_after_attempt(retries), wait=wait_exponential(min=0.5, max=4),
           retry=retry_if_exception_type((requests.Timeout, requests.ConnectionError)))
    def _fetch():
        resp = requests.get(url, timeout=timeout, headers={"User-Agent": UserAgent})
        if resp.status_code != 200:
            raise ExplorerError(f"HTTP {resp.status_code} for {url}")
        return resp.json()

    return _fetch()
24
+
25
+ def _satoshis_to_btc(v: Optional[int]) -> float:
26
+ try:
27
+ return float(v) / 1e8 if v is not None else 0.0
28
+ except Exception:
29
+ return 0.0
30
+
31
+ def _normalize_tx_esplora(j: Dict[str, Any]) -> Dict[str, Any]:
32
+ # https://mempool.space/api/tx/{txid}
33
+ vin = j.get("vin", [])
34
+ vout = j.get("vout", [])
35
+ status = j.get("status", {}) or {}
36
+ bh = status.get("block_height")
37
+ bt = status.get("block_time")
38
+ vin_list = []
39
+ for e in vin:
40
+ p = e.get("prevout") or {}
41
+ vin_list.append({
42
+ "txid": e.get("txid"),
43
+ "vout": e.get("vout"),
44
+ "prevout_value": p.get("value"),
45
+ "prevout_address": p.get("scriptpubkey_address") or None
46
+ })
47
+ vout_list = []
48
+ for idx, e in enumerate(vout):
49
+ vout_list.append({
50
+ "n": idx,
51
+ "value": e.get("value"),
52
+ "address": e.get("scriptpubkey_address") or None
53
+ })
54
+ return {
55
+ "txid": j.get("txid") or j.get("hash"),
56
+ "vin": vin_list,
57
+ "vout": vout_list,
58
+ "block_height": bh,
59
+ "block_time": bt,
60
+ }
61
+
62
+ def _normalize_outspends_esplora(j: Any) -> List[Optional[str]]:
63
+ # returns list aligned to outputs: each item has 'spent', 'txid'
64
+ res = []
65
+ if isinstance(j, list):
66
+ for e in j:
67
+ if isinstance(e, dict) and e.get("spent"):
68
+ res.append(e.get("txid"))
69
+ else:
70
+ res.append(None)
71
+ return res
72
+
73
class BaseExplorer:
    """Common state for explorer clients: config plus small TTL caches so a
    txid touched repeatedly during expansion is fetched at most once per 5 min."""

    def __init__(self, cfg: AppConfig):
        self.cfg = cfg
        # Separate caches for tx bodies and for outspend lists.
        self.cache_tx = TTLCache(maxsize=10000, ttl=300)
        self.cache_out = TTLCache(maxsize=10000, ttl=300)

    def get_tx(self, txid: str) -> Dict[str, Any]:
        """Return the normalized transaction dict for *txid*."""
        raise NotImplementedError

    def get_outspends(self, txid: str) -> List[Optional[str]]:
        """Return spending txids aligned to the outputs of *txid*."""
        raise NotImplementedError
84
+
85
class MempoolSpaceClient(BaseExplorer):
    """Esplora-compatible client; the default base URL is mempool.space."""

    def __init__(self, cfg: AppConfig, base: str = "https://mempool.space"):
        super().__init__(cfg)
        self.base = base.rstrip("/")

    def get_tx(self, txid: str) -> Dict[str, Any]:
        cached = self.cache_tx.get(txid)
        if cached is not None:
            return cached
        raw = _req_json(f"{self.base}/api/tx/{txid}",
                        timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        tx = _normalize_tx_esplora(raw)
        self.cache_tx[txid] = tx
        return tx

    def get_outspends(self, txid: str) -> List[Optional[str]]:
        cached = self.cache_out.get(txid)
        if cached is not None:
            return cached
        raw = _req_json(f"{self.base}/api/tx/{txid}/outspends",
                        timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        outspends = _normalize_outspends_esplora(raw)
        self.cache_out[txid] = outspends
        return outspends
107
+
108
class BlockstreamClient(MempoolSpaceClient):
    """Same Esplora API as mempool.space, served from blockstream.info."""

    def __init__(self, cfg: AppConfig):
        super().__init__(cfg, base="https://blockstream.info")
111
+
112
class BlockchairClient(BaseExplorer):
    """Fallback client for the Blockchair dashboards API (optional API key)."""

    def __init__(self, cfg: AppConfig):
        super().__init__(cfg)
        self.base = "https://api.blockchair.com/bitcoin"

    def _dashboard_url(self, txid: str) -> str:
        # Append the API key only when one is configured.
        url = f"{self.base}/dashboards/transaction/{txid}"
        if self.cfg.BLOCKCHAIR_API_KEY:
            url += f"?key={self.cfg.BLOCKCHAIR_API_KEY}"
        return url

    def get_tx(self, txid: str) -> Dict[str, Any]:
        if txid in self.cache_tx:
            return self.cache_tx[txid]
        j = _req_json(self._dashboard_url(txid),
                      timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        data = j.get("data", {}).get(txid, {})
        tx = data.get("transaction", {})
        inputs = data.get("inputs", [])
        outputs = data.get("outputs", [])
        # In Blockchair's schema an "input" is the previous output being
        # spent: `transaction_hash` is the tx that *created* it (the parent
        # we need for backward expansion) and `index` is its position among
        # that parent's outputs. The previous code used
        # `spending_transaction_hash`, which is the current tx itself and
        # would produce self-loop edges. (Verify against the Blockchair API
        # reference if the schema changes.)
        vin_list = [{
            "txid": i.get("transaction_hash"),
            "vout": i.get("index"),
            "prevout_value": i.get("value"),
            "prevout_address": i.get("recipient"),
        } for i in inputs]
        vout_list = [{
            "n": o.get("index"),
            "value": o.get("value"),
            "address": o.get("recipient"),
        } for o in outputs]
        out = {
            "txid": txid,
            "vin": vin_list,
            "vout": vout_list,
            "block_height": tx.get("block_id"),
            "block_time": tx.get("time"),
        }
        self.cache_tx[txid] = out
        return out

    def get_outspends(self, txid: str) -> List[Optional[str]]:
        # Blockchair outputs carry 'spent_by_transaction_hash' (None while unspent).
        if txid in self.cache_out:
            return self.cache_out[txid]
        j = _req_json(self._dashboard_url(txid),
                      timeout=self.cfg.HTTP_TIMEOUT, retries=self.cfg.HTTP_RETRIES)
        outputs = j.get("data", {}).get(txid, {}).get("outputs", [])
        res = [o.get("spent_by_transaction_hash") for o in outputs]
        self.cache_out[txid] = res
        return res
163
+
164
def new_client(cfg: AppConfig, primary: str) -> List[BaseExplorer]:
    """Build the provider chain: the requested primary first, then fallbacks."""
    primary = (primary or cfg.DATA_PROVIDER).lower()
    if primary in ("blockstream", "blockstream.info"):
        order = (BlockstreamClient, MempoolSpaceClient, BlockchairClient)
    elif primary == "blockchair":
        order = (BlockchairClient, MempoolSpaceClient, BlockstreamClient)
    else:
        # "mempool"/"mempool.space" and any unrecognized value share the default.
        order = (MempoolSpaceClient, BlockstreamClient, BlockchairClient)
    return [cls(cfg) for cls in order]
177
+
178
def fetch_with_fallback(txid: str, cfg: AppConfig, source: str):
    """Try each provider in chain order; return (client, tx, outspends, None)
    on the first success, or (None, None, None, errors) if all fail.

    A provider that returns an empty tx or None outspends without raising is
    skipped silently in favor of the next one.
    """
    errors: List[str] = []
    for client in new_client(cfg, source):
        try:
            tx = client.get_tx(txid)
            outspends = client.get_outspends(txid)
        except Exception as exc:
            errors.append(f"{client.__class__.__name__}: {exc}")
            continue
        if tx and outspends is not None:
            return client, tx, outspends, None
    return None, None, None, errors
features.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List, Tuple
2
+ import numpy as np
3
+ from collections import deque, defaultdict
4
+ from sklearn.preprocessing import StandardScaler
5
+
6
+ def _sum_inputs_btc(tx: Dict[str, Any]) -> float:
7
+ s = 0.0
8
+ for v in tx.get("vin", []):
9
+ s += float(v.get("prevout_value") or 0.0) / 1e8
10
+ return s
11
+
12
+ def _sum_outputs_btc(tx: Dict[str, Any]) -> float:
13
+ s = 0.0
14
+ for o in tx.get("vout", []):
15
+ s += float(o.get("value") or 0.0) / 1e8
16
+ return s
17
+
18
+ def _compute_distances(n: int, edges: List[Tuple[int,int]], center: int) -> np.ndarray:
19
+ # undirected BFS distance
20
+ adj = [[] for _ in range(n)]
21
+ for u,v in edges:
22
+ adj[u].append(v); adj[v].append(u)
23
+ dist = np.full(n, fill_value=-1, dtype=np.int32)
24
+ q = deque([center]); dist[center] = 0
25
+ while q:
26
+ u = q.popleft()
27
+ for nb in adj[u]:
28
+ if dist[nb] == -1:
29
+ dist[nb] = dist[u] + 1
30
+ q.append(nb)
31
+ return dist
32
+
33
def build_features(nodes: List[str], edges: List[Tuple[int, int]], center_idx: int, node_meta: Dict[str, Dict[str, Any]]):
    """Compute the per-node feature matrix for the ego-subgraph.

    Returns (feats, feature_names) where feats has shape [len(nodes), 12]:
    degree stats, input/output counts, BTC sums (raw and log1p), BFS hop
    distance to the center node, and block height (0.0 when unknown).
    """
    n = len(nodes)

    # Directed degrees from the edge list (edges are parent -> child).
    in_deg = np.zeros(n, dtype=np.float32)
    out_deg = np.zeros(n, dtype=np.float32)
    for src, dst in edges:
        out_deg[src] += 1
        in_deg[dst] += 1
    deg = in_deg + out_deg
    ratio_in_out = in_deg / (out_deg + 1e-6)  # epsilon avoids div-by-zero

    # Per-tx metadata: counts, BTC sums, block height.
    n_inputs = np.zeros(n, dtype=np.float32)
    n_outputs = np.zeros(n, dtype=np.float32)
    sum_in_btc = np.zeros(n, dtype=np.float32)
    sum_out_btc = np.zeros(n, dtype=np.float32)
    block_height = np.zeros(n, dtype=np.float32)
    for i, txid in enumerate(nodes):
        meta = node_meta.get(txid) or {}
        n_inputs[i] = float(len(meta.get("vin", []) or []))
        n_outputs[i] = float(len(meta.get("vout", []) or []))
        sum_in_btc[i] = _sum_inputs_btc(meta)
        sum_out_btc[i] = _sum_outputs_btc(meta)
        bh = meta.get("block_height")
        block_height[i] = float(bh) if bh is not None else 0.0

    distance = _compute_distances(n, edges, center_idx)

    feats = np.stack([
        in_deg, out_deg, deg, ratio_in_out,
        n_inputs, n_outputs,
        sum_in_btc, sum_out_btc,
        np.log1p(sum_in_btc), np.log1p(sum_out_btc),
        distance.astype(np.float32),
        block_height,
    ], axis=1)

    feature_names = [
        "in_degree", "out_degree", "degree", "ratio_in_out",
        "n_inputs", "n_outputs",
        "sum_in_btc", "sum_out_btc",
        "log_sum_in", "log_sum_out",
        "distance", "block_height",
    ]
    return feats, feature_names
81
+
82
def scale_features(X: np.ndarray, scaler=None):
    """Standardize X, fitting a fresh StandardScaler when none is supplied.

    Returns (X_scaled as float32, scaler, human-readable note for the logs).
    Fitting on the ego-subgraph itself means statistics differ from the
    Elliptic training distribution — hence the domain-shift note.
    """
    if scaler is not None:
        transformed = scaler.transform(X)
        note = "Used provided scaler from model repo."
    else:
        scaler = StandardScaler()
        transformed = scaler.fit_transform(X)
        note = "Fitted new StandardScaler on ego-subgraph (domain shift vs Elliptic)."
    return transformed.astype("float32"), scaler, note
graph_builder.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List, Tuple, Set
2
+ from collections import deque, defaultdict
3
+
4
+ from explorers import fetch_with_fallback
5
+ from config import AppConfig
6
+
7
def expand_ego(txid: str, k: int, source: str, cfg: AppConfig):
    """
    Expand ego-subgraph up to k steps backward (parents) and forward (children).
    Returns:
        nodes: List[str] txids
        edges: List[Tuple[int,int]] parent->child indices
        center_idx: int
        node_meta: Dict[txid, dict]
        logs: List[str]
    """
    logs: List[str] = []
    client, center_tx, _center_out, errs = fetch_with_fallback(txid, cfg, source)
    if client is None:
        return [], [], -1, {}, ["All providers failed", *(errs or [])]

    nodes: List[str] = []
    idx_map: Dict[str, int] = {}
    edges: List[Tuple[int, int]] = []
    node_meta: Dict[str, Dict[str, Any]] = {}

    def add_node(tid: str, meta: Dict[str, Any]):
        # Returns the node's index, or None once MAX_NODES is hit.
        if tid in idx_map:
            return idx_map[tid]
        if len(nodes) >= cfg.MAX_NODES:
            return None
        idx_map[tid] = len(nodes)
        nodes.append(tid)
        node_meta[tid] = meta
        return idx_map[tid]

    def ensure_tx(tid: str):
        # Fetch (tx, outspends) for tid, or (None, None) on failure.
        _, tj, outsp, _ = fetch_with_fallback(tid, cfg, source)
        if tj is None:
            return None, None
        return tj, outsp

    # Seed with the center transaction.
    center_idx = add_node(txid, center_tx)
    if center_idx is None:
        return [], [], -1, {}, ["MAX_NODES reached at seed"]

    parents_q = deque([(txid, 0)])
    children_q = deque([(txid, 0)])

    # --- backward BFS: follow each tx's inputs to its parent txs ---
    while parents_q:
        cur, depth = parents_q.popleft()
        if depth >= k:
            continue
        meta = node_meta.get(cur)
        if meta is None:
            fetched, _ = ensure_tx(cur)
            if fetched is None:
                continue
            node_meta[cur] = fetched
            meta = fetched
        for vin in meta.get("vin", []):
            parent = vin.get("txid")
            if not parent:
                continue  # coinbase or missing prevout txid
            if parent in idx_map:
                pidx = idx_map[parent]
            else:
                if len(nodes) >= cfg.MAX_NODES:
                    logs.append("MAX_NODES reached during backward expansion")
                    break
                parent_meta, _ = ensure_tx(parent)
                if parent_meta is None:
                    continue
                pidx = add_node(parent, parent_meta)
                if pidx is None:
                    continue
                parents_q.append((parent, depth + 1))
            cidx = idx_map.get(cur)
            if cidx is not None:
                edges.append((pidx, cidx))  # edge direction: parent -> child
                if cfg.MAX_EDGES and len(edges) >= cfg.MAX_EDGES:
                    logs.append("MAX_EDGES reached; stopping further edge additions")
                    break

    # --- forward BFS: follow each output's spending tx (children) ---
    while children_q:
        cur, depth = children_q.popleft()
        if depth >= k:
            continue
        meta = node_meta.get(cur)
        if meta is None:
            fetched, _ = ensure_tx(cur)
            if fetched is None:
                continue
            node_meta[cur] = fetched
            meta = fetched
        try:
            _, _, outsp, _ = fetch_with_fallback(cur, cfg, source)
        except Exception:
            outsp = None
        if outsp is None:
            continue
        for child in outsp:
            if not child:
                continue  # unspent output
            if child in idx_map:
                cidx = idx_map[child]
            else:
                if len(nodes) >= cfg.MAX_NODES:
                    logs.append("MAX_NODES reached during forward expansion")
                    break
                child_meta, _ = ensure_tx(child)
                if child_meta is None:
                    continue
                cidx = add_node(child, child_meta)
                if cidx is None:
                    continue
                children_q.append((child, depth + 1))
            pidx = idx_map.get(cur)
            if pidx is not None:
                edges.append((pidx, cidx))
                if cfg.MAX_EDGES and len(edges) >= cfg.MAX_EDGES:
                    logs.append("MAX_EDGES reached; stopping further edge additions")
                    break

    # Deduplicate edges; order is irrelevant downstream (coalesce sorts anyway).
    return nodes, list(set(edges)), center_idx, node_meta, logs
inference.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from typing import Dict, Any, Tuple, Optional
4
+
5
+ import torch
6
+ from huggingface_hub import snapshot_download
7
+ from torch_geometric.data import Data
8
+ from torch_geometric.utils import to_undirected
9
+
10
+ from config import AppConfig
11
+ from models import GATBaseline, GATv2Enhanced, AdapterWrapper
12
+
13
+ def _load_threshold(model_dir: str, default_thr: float) -> float:
14
+ for name in ["thresholds.json", "threshold.json", "config.json"]:
15
+ p = os.path.join(model_dir, name)
16
+ if os.path.exists(p):
17
+ try:
18
+ d = json.load(open(p, "r"))
19
+ for k in ["threshold","default_threshold","thr","best_f1","best_j"]:
20
+ if k in d and isinstance(d[k], (int, float)):
21
+ return float(d[k])
22
+ except Exception:
23
+ continue
24
+ return default_thr
25
+
26
+ def _load_scaler(model_dir: str):
27
+ # Optional scaler joblib/pkl
28
+ for name in ["scaler.joblib", "scaler.pkl", "elliptic_scaler.joblib", "elliptic_scaler.pkl"]:
29
+ p = os.path.join(model_dir, name)
30
+ if os.path.exists(p):
31
+ try:
32
+ import joblib
33
+ return joblib.load(p)
34
+ except Exception:
35
+ pass
36
+ return None
37
+
38
def load_models(cfg: AppConfig):
    """Download both model repos from the Hub and build ready-to-use bundles.

    Each bundle holds: the core model in eval mode, its decision threshold
    (thresholds.json or cfg.DEFAULT_THRESHOLD), an optional training scaler,
    and the local snapshot directory. Raises FileNotFoundError when a repo
    is missing model.pt.
    """
    dirs = {
        "gat": snapshot_download(cfg.HF_GAT_BASELINE_REPO, local_dir_use_symlinks=False),
        "gatv2": snapshot_download(cfg.HF_GATV2_REPO, local_dir_use_symlinks=False),
    }
    cores = {
        "gat": GATBaseline(cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT),
        "gatv2": GATv2Enhanced(cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT),
    }

    bundles = {}
    for key in ("gat", "gatv2"):
        repo_dir = dirs[key]
        ckpt = os.path.join(repo_dir, "model.pt")
        if not os.path.exists(ckpt):
            raise FileNotFoundError(f"Missing model.pt in {repo_dir}")
        # strict=True: the checkpoint must exactly match the architecture above.
        state = torch.load(ckpt, map_location="cpu")
        cores[key].load_state_dict(state, strict=True)
        bundles[key] = {
            "core": cores[key].eval(),
            "threshold": _load_threshold(repo_dir, cfg.DEFAULT_THRESHOLD),
            "scaler": _load_scaler(repo_dir),
            "repo_dir": repo_dir,
        }
    return bundles
72
+
73
@torch.no_grad()
def predict(model, data: Data):
    """Run *model* on *data* and return per-node sigmoid probabilities as a
    NumPy array, with gradient tracking disabled."""
    raw = model(data.x, data.edge_index)
    return torch.sigmoid(raw).cpu().numpy()
78
+
79
def adapt_and_predict(bundle: Dict[str, Any], in_dim_new: int, data: Data, cfg: AppConfig):
    """Score *data* with the bundle's core model, inserting a linear feature
    adapter when the on-the-fly feature dim differs from the trained dim.

    Returns (probs, note) where note explains which path was taken.
    NOTE(review): the adapter's Linear layer is freshly initialized, not
    trained, so adapted scores are indicative only.
    """
    core = bundle["core"]
    if in_dim_new == cfg.IN_CHANNELS:
        model = core.eval()
        note = "Input dim matches."
    elif cfg.USE_FEATURE_ADAPTER:
        model = AdapterWrapper(in_dim_new, cfg.IN_CHANNELS, core).eval()
        note = f"FeatureAdapter used (new_dim={in_dim_new} → expected={cfg.IN_CHANNELS})."
    else:
        # Run the core as-is; the first conv will likely reject the width.
        model = core.eval()
        note = f"Dimension mismatch (new_dim={in_dim_new}, expected={cfg.IN_CHANNELS}). Proceeding without adapter (may fail)."
    probs = predict(model, data)
    return probs, note
93
+
94
def run_for_both_models(bundles, data: Data, center_idx: int, cfg: AppConfig):
    """Score the subgraph with both bundles.

    Returns a list of (display_name, probs, threshold, center_label, note),
    GAT first and GATv2 second — downstream code indexes by position.
    """
    in_dim_new = data.x.shape[1]
    results = []
    for key, display in (("gat", "GAT"), ("gatv2", "GATv2")):
        probs, note = adapt_and_predict(bundles[key], in_dim_new, data, cfg)
        thr = float(bundles[key]["threshold"])
        label = int(probs[center_idx] >= thr)
        results.append((display, probs, thr, label, note))
    return results
models.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import GATConv, GATv2Conv, BatchNorm
4
+
5
+ class ResidualGATBlock(nn.Module):
6
+ def __init__(self, in_channels, hidden_channels, heads=8, dropout=0.5, v2=False):
7
+ super().__init__()
8
+ Conv = GATv2Conv if v2 else GATConv
9
+ self.conv = Conv(in_channels, hidden_channels, heads=heads, dropout=dropout)
10
+ self.bn = BatchNorm(hidden_channels * heads)
11
+ self.act = nn.ReLU()
12
+ self.dropout = nn.Dropout(dropout)
13
+ self.res_proj = None
14
+ out_dim = hidden_channels * heads
15
+ if in_channels != out_dim:
16
+ self.res_proj = nn.Linear(in_channels, out_dim)
17
+
18
+ def forward(self, x, edge_index):
19
+ identity = x
20
+ out = self.conv(x, edge_index)
21
+ out = self.bn(out)
22
+ out = self.act(out)
23
+ out = self.dropout(out)
24
+ if self.res_proj is not None:
25
+ identity = self.res_proj(identity)
26
+ return out + identity
27
+
28
+ class GATBaseline(nn.Module):
29
+ def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
30
+ super().__init__()
31
+ layers = []
32
+ c_in = in_channels
33
+ for _ in range(num_blocks):
34
+ layers.append(ResidualGATBlock(c_in, hidden_channels, heads=heads, dropout=dropout, v2=False))
35
+ c_in = hidden_channels * heads
36
+ self.blocks = nn.ModuleList(layers)
37
+ self.dropout = nn.Dropout(dropout)
38
+ self.out_conv = GATConv(c_in, 1, heads=1, concat=False, dropout=dropout)
39
+
40
+ def forward(self, x, edge_index):
41
+ for block in self.blocks:
42
+ x = block(x, edge_index)
43
+ x = self.dropout(x)
44
+ out = self.out_conv(x, edge_index)
45
+ return out.view(-1)
46
+
47
+ class GATv2Enhanced(nn.Module):
48
+ def __init__(self, in_channels, hidden_channels=128, heads=8, num_blocks=2, dropout=0.5):
49
+ super().__init__()
50
+ layers = []
51
+ c_in = in_channels
52
+ for _ in range(num_blocks):
53
+ layers.append(ResidualGATBlock(c_in, hidden_channels, heads=heads, dropout=dropout, v2=True))
54
+ c_in = hidden_channels * heads
55
+ self.blocks = nn.ModuleList(layers)
56
+ self.dropout = nn.Dropout(dropout)
57
+ self.out_conv = GATv2Conv(c_in, 1, heads=1, concat=False, dropout=dropout)
58
+
59
+ def forward(self, x, edge_index):
60
+ for block in self.blocks:
61
+ x = block(x, edge_index)
62
+ x = self.dropout(x)
63
+ out = self.out_conv(x, edge_index)
64
+ return out.view(-1)
65
+
66
+ class AdapterWrapper(nn.Module):
67
+ def __init__(self, in_dim_new, expected_in_dim, core_model):
68
+ super().__init__()
69
+ if in_dim_new != expected_in_dim:
70
+ self.adapter = nn.Linear(in_dim_new, expected_in_dim, bias=True)
71
+ else:
72
+ self.adapter = None
73
+ self.core = core_model
74
+
75
+ def forward(self, x, edge_index):
76
+ if self.adapter is not None:
77
+ x = self.adapter(x)
78
+ return self.core(x, edge_index)
rate_limit.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import threading
3
+ from collections import deque
4
+ from functools import wraps
5
+
6
class RateLimitExceeded(Exception):
    """Raised when the global call budget for the current window is exhausted."""
    pass


class GlobalRateLimiter:
    """Thread-safe, process-wide sliding-window rate limiter.

    Permits at most ``max_calls`` calls within any trailing
    ``window_seconds``-second window, shared across all threads.
    """

    def __init__(self, max_calls: int, window_seconds: int):
        self.max_calls = max_calls
        self.window = window_seconds
        self._lock = threading.Lock()
        # Monotonic timestamps of accepted calls, oldest first.
        self._events = deque()

    def allow(self) -> bool:
        """Record and permit one call if the window still has budget; else refuse."""
        # time.monotonic() instead of time.time(): wall-clock adjustments
        # (NTP, DST) would otherwise shrink or stretch the window.
        now = time.monotonic()
        with self._lock:
            # Evict events that have fallen out of the trailing window.
            while self._events and now - self._events[0] > self.window:
                self._events.popleft()
            if len(self._events) < self.max_calls:
                self._events.append(now)
                return True
            return False

    def enforce(self, func=None):
        """Decorator raising RateLimitExceeded when the budget is spent.

        Usable both bare (``@limiter.enforce``) and called
        (``@limiter.enforce()``).
        """
        def decorator(f):
            @wraps(f)
            def wrapper(*args, **kwargs):
                if not self.allow():
                    raise RateLimitExceeded(f"Rate limit exceeded ({self.max_calls} req/{self.window}s)")
                return f(*args, **kwargs)
            return wrapper
        return decorator if func is None else decorator(func)
requirements.txt ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PyTorch CPU-only build (3.13 compatible)
2
+ torch==2.6.0
3
+ torchvision==0.21.0
4
+ torchaudio==2.6.0
5
+ -f https://download.pytorch.org/whl/cpu
6
+
7
+ # PyTorch Geometric (compatible with torch 2.6.0 CPU)
8
+ pyg_lib==0.4.0
9
+ torch_scatter==2.1.2
10
+ torch_sparse==0.6.18
11
+ torch_cluster==1.6.3
12
+ torch_spline_conv==1.2.2
13
+ torch_geometric==2.6.0
14
+
15
+ # Core app & utilities
16
+ gradio>=5.0.0,<5.2.0
17
+ huggingface_hub>=0.26.2
18
+ requests>=2.32.3
19
+ tenacity>=9.0.0
20
+ diskcache>=5.6.3
21
+ cachetools>=5.5.0
22
+ pyvis>=0.3.3
23
+ networkx>=3.4.2
24
+ pandas>=2.2.3
25
+ numpy>=2.1.3
26
+ scikit-learn>=1.6.0
27
+ matplotlib>=3.9.2
28
+ tqdm>=4.67.1
viz.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Optional
2
+ from pyvis.network import Network
3
+ import numpy as np
4
+ import io, base64
5
+ import matplotlib.pyplot as plt
6
+
7
+ def _color_from_score(p: float) -> str:
8
+ # blue (0) -> gray (0.5) -> red (1)
9
+ p = float(np.clip(p, 0.0, 1.0))
10
+ if p < 0.5:
11
+ # interpolate blue (#377eb8) to gray (#bbbbbb)
12
+ t = p / 0.5
13
+ c0 = (0x37, 0x7e, 0xb8)
14
+ c1 = (0xbb, 0xbb, 0xbb)
15
+ else:
16
+ # interpolate gray to red (#e41a1c)
17
+ t = (p - 0.5) / 0.5
18
+ c0 = (0xbb, 0xbb, 0xbb)
19
+ c1 = (0xe4, 0x1a, 0x1c)
20
+ r = int((1-t)*c0[0] + t*c1[0])
21
+ g = int((1-t)*c0[1] + t*c1[1])
22
+ b = int((1-t)*c0[2] + t*c1[2])
23
+ return f"#{r:02x}{g:02x}{b:02x}"
24
+
25
def render_ego_html(nodes: List[str], edges: List[Tuple[int,int]], center_idx: int, scores: Optional[np.ndarray]=None) -> str:
    """Render the ego-subgraph as interactive pyvis HTML.

    Nodes are colored by score (gray when no scores are given); the center
    node is drawn larger than its neighbors. Returns the generated HTML string.
    """
    net = Network(height="600px", width="100%", notebook=False, directed=True)
    for idx, txid in enumerate(nodes):
        if scores is None:
            color = "#bbbbbb"
            title = f"{txid}"
        else:
            color = _color_from_score(scores[idx])
            title = f"{txid}" + f"<br/>score={scores[idx]:.4f}"
        node_size = 20 if idx == center_idx else 8
        net.add_node(idx, label=txid[:10]+"…", title=title, color=color, size=node_size)
    for src, dst in edges:
        net.add_edge(src, dst, arrows="to")
    net.toggle_physics(True)
    return net.generate_html()
38
+
39
def histogram_scores(scores, title="Score distribution"):
    """Return a matplotlib figure with a 40-bin histogram of the given scores."""
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.hist(scores, bins=40)
    ax.set_xlabel("Score")
    ax.set_ylabel("Count")
    ax.set_title(title)
    fig.tight_layout()
    return fig