# MLGraph-Bitcoin-GAD / inference.py
# thanhphxu's picture
# Upload folder using huggingface_hub
# cb08ecf verified
import json
import os
from typing import Dict, Any, Tuple, Optional
import torch
from huggingface_hub import snapshot_download
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from config import AppConfig
from models import GATBaseline, GATv2Enhanced, AdapterWrapper
def _load_threshold(model_dir: str, default_thr: float) -> float:
    """Return the decision threshold stored next to a model, if any.

    Looks in ``model_dir`` for ``thresholds.json``, ``threshold.json`` and
    ``config.json`` (in that order). Within each file that parses to a JSON
    object, the keys ``threshold``, ``default_threshold``, ``thr``,
    ``best_f1`` and ``best_j`` are tried in that order; the first real-number
    value found is returned.

    Args:
        model_dir: Directory that may contain one of the threshold files.
        default_thr: Value returned when no usable threshold is found.

    Returns:
        The stored threshold as a ``float``, or ``default_thr``.
    """
    candidate_files = ("thresholds.json", "threshold.json", "config.json")
    candidate_keys = ("threshold", "default_threshold", "thr", "best_f1", "best_j")

    for name in candidate_files:
        p = os.path.join(model_dir, name)
        if not os.path.exists(p):
            continue
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(p, "r", encoding="utf-8") as fh:
                d = json.load(fh)
        except (OSError, ValueError):
            # Unreadable file or invalid JSON / bad encoding (JSONDecodeError
            # and UnicodeDecodeError are ValueError subclasses): best-effort,
            # move on to the next candidate file.
            continue
        if not isinstance(d, dict):
            # A top-level list/scalar cannot carry a named threshold.
            continue
        for k in candidate_keys:
            v = d.get(k)
            # bool is a subclass of int; a JSON true/false is not a threshold.
            if isinstance(v, (int, float)) and not isinstance(v, bool):
                return float(v)
    return default_thr
def _load_scaler(model_dir: str):
    """Load an optional scaler object stored next to a model.

    Tries ``scaler.joblib``, ``scaler.pkl``, ``elliptic_scaler.joblib`` and
    ``elliptic_scaler.pkl`` inside ``model_dir`` (in that order) and returns
    the first one ``joblib.load`` can read.

    Loading is best-effort: nothing is raised on failure, but a
    ``RuntimeWarning`` is emitted so a broken or unreadable scaler file does
    not go unnoticed.

    Args:
        model_dir: Directory that may contain a serialized scaler.

    Returns:
        The object returned by ``joblib.load``, or ``None`` when no candidate
        file exists, ``joblib`` is not installed, or every candidate fails to
        load.
    """
    import warnings

    candidates = (
        "scaler.joblib",
        "scaler.pkl",
        "elliptic_scaler.joblib",
        "elliptic_scaler.pkl",
    )
    for name in candidates:
        p = os.path.join(model_dir, name)
        if not os.path.exists(p):
            continue
        try:
            # Imported lazily so joblib is only required when a scaler exists.
            import joblib
        except ImportError as exc:
            warnings.warn(
                f"Scaler file {p} found but joblib is unavailable ({exc!r}); "
                "continuing without a scaler.",
                RuntimeWarning,
                stacklevel=2,
            )
            return None
        try:
            # SECURITY: joblib.load unpickles the file, which can execute
            # arbitrary code. Only point this at model directories you trust.
            return joblib.load(p)
        except Exception as exc:  # best-effort: try the next candidate
            warnings.warn(
                f"Could not load scaler from {p}: {exc!r}",
                RuntimeWarning,
                stacklevel=2,
            )
    return None
def _load_checkpoint_state(ckpt_path: str):
    """Deserialize a checkpoint onto the CPU, preferring the safe loader."""
    try:
        return torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except Exception as exc:
        import warnings

        # SECURITY: weights_only=False performs full unpickling and can run
        # arbitrary code embedded in the file. This fallback is acceptable
        # ONLY IF the checkpoint comes from a trusted repository.
        warnings.warn(
            f"weights_only load failed for {ckpt_path} ({exc!r}); "
            "falling back to full unpickling.",
            RuntimeWarning,
            stacklevel=2,
        )
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)


def load_models(cfg: AppConfig):
    """Download both model repos and build ready-to-use model bundles.

    Snapshots ``cfg.HF_GAT_BASELINE_REPO`` and ``cfg.HF_GATV2_REPO``, builds a
    ``GATBaseline`` and a ``GATv2Enhanced`` from the ``cfg`` hyper-parameters,
    strictly loads their checkpoints, and collects each repo's optional
    threshold and scaler.

    Args:
        cfg: Application config providing the repo ids, ``IN_CHANNELS``,
            ``HIDDEN_CHANNELS``, ``HEADS``, ``NUM_BLOCKS``, ``DROPOUT`` and
            ``DEFAULT_THRESHOLD``.

    Returns:
        A dict with keys ``"gat"`` and ``"gatv2"``; each value is a dict with
        ``"core"`` (the model in eval mode), ``"threshold"``, ``"scaler"``
        (possibly ``None``) and ``"repo_dir"``.

    Raises:
        FileNotFoundError: If a repo snapshot lacks its expected checkpoint.
    """
    gat_ckpt_name = "gat_baseline_best.pt"
    gatv2_ckpt_name = "gatv2_enhanced_best.pt"

    # Download both repos.
    # NOTE(review): local_dir_use_symlinks appears to be deprecated in recent
    # huggingface_hub releases -- confirm against the pinned version.
    dir_gat = snapshot_download(cfg.HF_GAT_BASELINE_REPO, local_dir_use_symlinks=False)
    dir_gatv2 = snapshot_download(cfg.HF_GATV2_REPO, local_dir_use_symlinks=False)

    # Model files; the error names the file that was actually looked for.
    ckpt_gat = os.path.join(dir_gat, gat_ckpt_name)
    ckpt_gatv2 = os.path.join(dir_gatv2, gatv2_ckpt_name)
    if not os.path.exists(ckpt_gat):
        raise FileNotFoundError(f"Missing {gat_ckpt_name} in {dir_gat}")
    if not os.path.exists(ckpt_gatv2):
        raise FileNotFoundError(f"Missing {gatv2_ckpt_name} in {dir_gatv2}")

    # Build cores (expected input dim from training).
    core_gat = GATBaseline(
        cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT
    )
    core_gatv2 = GATv2Enhanced(
        cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT
    )

    state_gat = _load_checkpoint_state(ckpt_gat)
    state_gatv2 = _load_checkpoint_state(ckpt_gatv2)

    # Strict load: any missing or unexpected key is an error.
    core_gat.load_state_dict(state_gat, strict=True)
    core_gatv2.load_state_dict(state_gatv2, strict=True)

    thr_gat = _load_threshold(dir_gat, cfg.DEFAULT_THRESHOLD)
    thr_gatv2 = _load_threshold(dir_gatv2, cfg.DEFAULT_THRESHOLD)
    scaler_gat = _load_scaler(dir_gat)
    scaler_gatv2 = _load_scaler(dir_gatv2)

    return {
        "gat": {
            "core": core_gat.eval(),
            "threshold": thr_gat,
            "scaler": scaler_gat,
            "repo_dir": dir_gat,
        },
        "gatv2": {
            "core": core_gatv2.eval(),
            "threshold": thr_gatv2,
            "scaler": scaler_gatv2,
            "repo_dir": dir_gatv2,
        },
    }
@torch.no_grad()
def predict(model, data: Data):
    """Score a graph with ``model`` and return sigmoid outputs as NumPy.

    The model is called as ``model(data.x, data.edge_index)`` with gradient
    tracking disabled; its output is passed through a sigmoid, moved to the
    CPU and converted to a NumPy array.

    Args:
        model: Callable taking node features and an edge index.
        data: Graph object exposing ``x`` and ``edge_index``.

    Returns:
        NumPy array of sigmoid-activated model outputs.
    """
    activated = torch.sigmoid(model(data.x, data.edge_index))
    return activated.cpu().numpy()
def adapt_and_predict(bundle: Dict[str, Any], in_dim_new: int, data: Data, cfg: AppConfig):
    """Run one model bundle on ``data``, adapting the input width if needed.

    Three cases, decided by comparing ``in_dim_new`` with ``cfg.IN_CHANNELS``:

    * equal -- the bundle's core model is used as is;
    * different and ``cfg.USE_FEATURE_ADAPTER`` is truthy -- the core is
      wrapped in an ``AdapterWrapper`` built for ``in_dim_new``;
    * different and the adapter is disabled -- the core is used anyway and
      the note warns that the call may fail.

    NOTE(review): the ``AdapterWrapper`` is constructed anew on every call;
    whether it carries trained weights is not visible from here -- confirm.

    Args:
        bundle: Model bundle; only its ``"core"`` entry is read.
        in_dim_new: Feature width of the incoming data.
        data: Graph passed to ``predict``.
        cfg: Config providing ``IN_CHANNELS`` and ``USE_FEATURE_ADAPTER``.

    Returns:
        ``(probs, note)``: the output of ``predict`` and a human-readable
        message describing which of the three cases applied.
    """
    backbone = bundle["core"]

    if in_dim_new == cfg.IN_CHANNELS:
        net = backbone.eval()
        message = "Input dim matches."
    elif cfg.USE_FEATURE_ADAPTER:
        net = AdapterWrapper(in_dim_new, cfg.IN_CHANNELS, backbone).eval()
        message = f"FeatureAdapter used (new_dim={in_dim_new} → expected={cfg.IN_CHANNELS})."
    else:
        # Attempt to run without an adapter (not recommended).
        net = backbone.eval()
        message = f"Dimension mismatch (new_dim={in_dim_new}, expected={cfg.IN_CHANNELS}). Proceeding without adapter (may fail)."

    return predict(net, data), message
def run_for_both_models(bundles, data: Data, center_idx: int, cfg: AppConfig):
    """Score ``data`` with both models and label the node at ``center_idx``.

    For the ``"gat"`` bundle and then the ``"gatv2"`` bundle, runs
    ``adapt_and_predict`` and compares the output at ``center_idx`` with that
    bundle's threshold.

    Args:
        bundles: Mapping with ``"gat"`` and ``"gatv2"`` entries, each holding
            at least ``"core"`` and ``"threshold"``.
        data: Graph to score; ``data.x.shape[1]`` is taken as the input width.
        center_idx: Index into the model output used for the label.
        cfg: Config forwarded to ``adapt_and_predict``.

    Returns:
        A two-element list, GAT first and GATv2 second, of tuples
        ``(name, probs, threshold, label, note)`` where ``label`` is ``1``
        when the output at ``center_idx`` is ``>=`` the threshold, else ``0``.
    """
    in_dim_new = data.x.shape[1]

    outputs = []
    for display_name, key in (("GAT", "gat"), ("GATv2", "gatv2")):
        bundle = bundles[key]
        probs, note = adapt_and_predict(bundle, in_dim_new, data, cfg)
        thr = float(bundle["threshold"])
        # NOTE(review): assumes probs[center_idx] is a single value (one
        # output per node) -- confirm against the model's output shape.
        label = int(probs[center_idx] >= thr)
        outputs.append((display_name, probs, thr, label, note))
    return outputs