File size: 4,643 Bytes
db886e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4e01607
 
db886e4
 
 
 
 
 
 
 
 
cb08ecf
 
 
 
 
 
 
 
 
 
db886e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import json
import os
from typing import Dict, Any, Tuple, Optional

import torch
from huggingface_hub import snapshot_download
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

from config import AppConfig
from models import GATBaseline, GATv2Enhanced, AdapterWrapper

def _load_threshold(model_dir: str, default_thr: float) -> float:
    for name in ["thresholds.json", "threshold.json", "config.json"]:
        p = os.path.join(model_dir, name)
        if os.path.exists(p):
            try:
                d = json.load(open(p, "r"))
                for k in ["threshold","default_threshold","thr","best_f1","best_j"]:
                    if k in d and isinstance(d[k], (int, float)):
                        return float(d[k])
            except Exception:
                continue
    return default_thr

def _load_scaler(model_dir: str):
    # Optional scaler joblib/pkl
    for name in ["scaler.joblib", "scaler.pkl", "elliptic_scaler.joblib", "elliptic_scaler.pkl"]:
        p = os.path.join(model_dir, name)
        if os.path.exists(p):
            try:
                import joblib
                return joblib.load(p)
            except Exception:
                pass
    return None

def _torch_load_cpu(ckpt_path: str):
    """Load ``ckpt_path`` onto the CPU, preferring torch's weights-only unpickler."""
    try:
        return torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except Exception:
        # SECURITY(review): weights_only=False runs the full pickle machinery and
        # can execute arbitrary code embedded in the checkpoint. This fallback is
        # only acceptable if the HF repos configured in AppConfig are trusted —
        # confirm that, or remove the fallback.
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)

def _load_bundle(repo_id: str, ckpt_name: str, core_cls, cfg: AppConfig) -> Dict[str, Any]:
    """Download one HF repo and build its inference bundle.

    Args:
        repo_id: Hugging Face Hub repository to snapshot.
        ckpt_name: File name of the state-dict checkpoint inside the repo.
        core_cls: Model class, constructed with the architecture values from ``cfg``.
        cfg: Application config (architecture hyper-parameters, default threshold).

    Returns:
        ``{"core", "threshold", "scaler", "repo_dir"}`` with the core in eval mode.

    Raises:
        FileNotFoundError: if ``ckpt_name`` is not present in the snapshot.
    """
    # local_dir_use_symlinks only ever applied together with local_dir and is
    # deprecated in recent huggingface_hub releases, so it is not passed here.
    repo_dir = snapshot_download(repo_id)

    ckpt_path = os.path.join(repo_dir, ckpt_name)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Missing {ckpt_name} in {repo_dir}")

    # Hyper-parameters must match training; strict=True makes any mismatch fail loudly.
    core = core_cls(cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT)
    core.load_state_dict(_torch_load_cpu(ckpt_path), strict=True)

    return {
        "core": core.eval(),
        "threshold": _load_threshold(repo_dir, cfg.DEFAULT_THRESHOLD),
        "scaler": _load_scaler(repo_dir),
        "repo_dir": repo_dir,
    }

def load_models(cfg: AppConfig):
    """Download the GAT baseline and GATv2 repos and build both inference bundles.

    Args:
        cfg: Application config naming the HF repos and model hyper-parameters.

    Returns:
        ``{"gat": bundle, "gatv2": bundle}`` where each bundle holds the
        eval-mode ``core`` model, its decision ``threshold``, an optional
        ``scaler`` (or ``None``) and the local ``repo_dir`` snapshot path.

    Raises:
        FileNotFoundError: if a repo lacks its expected checkpoint file.
    """
    specs = (
        ("gat", cfg.HF_GAT_BASELINE_REPO, "gat_baseline_best.pt", GATBaseline),
        ("gatv2", cfg.HF_GATV2_REPO, "gatv2_enhanced_best.pt", GATv2Enhanced),
    )
    return {key: _load_bundle(repo, ckpt, cls, cfg) for key, repo, ckpt, cls in specs}

@torch.no_grad()
def predict(model, data: Data):
    """Score a graph with ``model`` and return sigmoid probabilities as NumPy.

    The model is invoked as ``model(data.x, data.edge_index)`` with autograd
    disabled. Its raw outputs are passed through a sigmoid, moved to the CPU
    and returned as a NumPy array with the same shape as the model output.
    """
    raw_scores = model(data.x, data.edge_index)
    squashed = raw_scores.sigmoid()
    return squashed.cpu().numpy()

def adapt_and_predict(bundle: Dict[str, Any], in_dim_new: int, data: Data, cfg: AppConfig):
    """Score ``data`` with a bundle's core model, handling feature-dim mismatches.

    Returns ``(probs, note)``. ``probs`` is whatever ``predict`` returns, and
    ``note`` is a human-readable string describing how the input dimension
    was handled:

    * ``in_dim_new == cfg.IN_CHANNELS``: the core runs as-is.
    * mismatch with ``cfg.USE_FEATURE_ADAPTER`` set: the core is wrapped in an
      ``AdapterWrapper(in_dim_new, cfg.IN_CHANNELS, core)``.
    * mismatch without the adapter: the core runs anyway, which may fail.
    """
    core = bundle["core"]
    expected = cfg.IN_CHANNELS

    if in_dim_new == expected:
        model, note = core.eval(), "Input dim matches."
    elif cfg.USE_FEATURE_ADAPTER:
        # NOTE(review): a new AdapterWrapper is constructed on every call; unless
        # it loads trained weights internally, its projection is untrained — confirm.
        model = AdapterWrapper(in_dim_new, expected, core).eval()
        note = f"FeatureAdapter used (new_dim={in_dim_new} → expected={expected})."
    else:
        # Not recommended: the core is fed features of the wrong width.
        model = core.eval()
        note = f"Dimension mismatch (new_dim={in_dim_new}, expected={expected}). Proceeding without adapter (may fail)."

    return predict(model, data), note

def run_for_both_models(bundles, data: Data, center_idx: int, cfg: AppConfig):
    """Score ``data`` with the GAT and GATv2 bundles and label node ``center_idx``.

    Args:
        bundles: Mapping with ``"gat"`` and ``"gatv2"`` entries, each holding at
            least ``"core"`` and ``"threshold"`` (the shape built by ``load_models``).
        data: Graph passed to both models; ``data.x.shape[1]`` is taken as the
            incoming feature dimension.
        center_idx: Index into each model's probability array of the node to label.
        cfg: Application config forwarded to ``adapt_and_predict``.

    Returns:
        ``[(display_name, probs, threshold, label, note), ...]`` for GAT then
        GATv2, where ``label`` is 1 if ``probs[center_idx] >= threshold`` else 0.
    """
    in_dim_new = data.x.shape[1]
    results = []
    for key, display_name in (("gat", "GAT"), ("gatv2", "GATv2")):
        bundle = bundles[key]
        probs, note = adapt_and_predict(bundle, in_dim_new, data, cfg)
        thr = float(bundle["threshold"])
        # .item() yields a Python bool whether probs is shaped (N,) or (N, 1);
        # int() on a 1-element ndarray is deprecated in recent NumPy.
        label = int((probs[center_idx] >= thr).item())
        results.append((display_name, probs, thr, label, note))
    return results