"""Load the GAT / GATv2 checkpoints from the Hugging Face Hub and score graphs."""
import json
import os
import warnings
from typing import Dict, Any, List, Tuple, Optional

import torch
from huggingface_hub import snapshot_download
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

from config import AppConfig
from models import GATBaseline, GATv2Enhanced, AdapterWrapper

# JSON files probed, in order, for a decision threshold inside a model repo.
_THRESHOLD_FILES = ("thresholds.json", "threshold.json", "config.json")
# Keys probed, in order, inside each of those files.
_THRESHOLD_KEYS = ("threshold", "default_threshold", "thr", "best_f1", "best_j")
# Optional fitted-scaler artifacts probed, in order.
_SCALER_FILES = (
    "scaler.joblib",
    "scaler.pkl",
    "elliptic_scaler.joblib",
    "elliptic_scaler.pkl",
)
# (bundle key, display name) pairs, in the order results are reported.
_MODEL_DISPLAY_NAMES = (("gat", "GAT"), ("gatv2", "GATv2"))


def _load_threshold(model_dir: str, default_thr: float) -> float:
    """Return the first numeric threshold found in a known JSON file in *model_dir*.

    Files in ``_THRESHOLD_FILES`` are tried in order; inside each, keys in
    ``_THRESHOLD_KEYS`` are tried in order. Missing, unreadable or malformed
    files are skipped (best-effort).

    Args:
        model_dir: Local directory of a downloaded model repo.
        default_thr: Value returned when no usable threshold is found.

    Returns:
        The threshold as a float.
    """
    for name in _THRESHOLD_FILES:
        path = os.path.join(model_dir, name)
        if not os.path.exists(path):
            continue
        try:
            # `with` guarantees the handle is closed (the old code leaked it).
            with open(path, "r", encoding="utf-8") as fh:
                data = json.load(fh)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(data, dict):
            continue
        for key in _THRESHOLD_KEYS:
            value = data.get(key)
            # bool is an int subclass; a JSON true/false is not a threshold.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                return float(value)
    return default_thr


def _load_scaler(model_dir: str):
    """Return an optional fitted scaler stored in *model_dir*, or None.

    Candidates in ``_SCALER_FILES`` are tried in order and the first that
    loads is returned. A missing ``joblib`` or an unloadable file produces a
    warning and the next candidate is tried.

    SECURITY NOTE(review): ``joblib.load`` unpickles and can execute arbitrary
    code. *model_dir* is a downloaded Hub snapshot, so only use this with
    repositories you trust.
    """
    for name in _SCALER_FILES:
        path = os.path.join(model_dir, name)
        if not os.path.exists(path):
            continue
        try:
            import joblib  # optional dependency, imported lazily

            return joblib.load(path)
        except Exception as exc:  # scaler is optional: report, don't crash
            warnings.warn(f"Could not load scaler {path!r}: {exc}", RuntimeWarning)
    return None


def _load_state_dict(ckpt_path: str, allow_unsafe: bool = True):
    """Load the checkpoint at *ckpt_path* onto the CPU.

    The safe ``weights_only=True`` loader is tried first. If it fails for any
    reason (e.g. non-tensor objects in the file, or an older torch without the
    ``weights_only`` argument), a full unpickle is attempted when
    *allow_unsafe* is true; otherwise the original error is re-raised.

    SECURITY NOTE(review): the fallback can execute arbitrary code embedded
    in the checkpoint and is only acceptable for trusted repositories.
    """
    try:
        return torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except Exception as exc:
        if not allow_unsafe:
            raise
        warnings.warn(
            f"weights_only=True load failed for {ckpt_path!r} ({exc}); falling "
            "back to full unpickling, which is only safe for trusted checkpoints.",
            RuntimeWarning,
        )
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)


def _load_bundle(cfg: AppConfig, repo_id: str, ckpt_name: str, model_cls) -> Dict[str, Any]:
    """Download *repo_id*, build *model_cls* and strictly load *ckpt_name* into it.

    Returns:
        Dict with ``core`` (model in eval mode), ``threshold`` (float),
        ``scaler`` (loaded object or None) and ``repo_dir`` (snapshot path).

    Raises:
        FileNotFoundError: the snapshot does not contain *ckpt_name*.
    """
    # local_dir_use_symlinks only applies together with local_dir and is
    # deprecated in recent huggingface_hub releases, so it is not passed.
    repo_dir = snapshot_download(repo_id)
    ckpt_path = os.path.join(repo_dir, ckpt_name)
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Missing {ckpt_name} in {repo_dir}")

    # Architecture must match training (input dim is cfg.IN_CHANNELS).
    core = model_cls(
        cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT
    )
    # Optional config flag; defaults to the historical permissive behaviour.
    allow_unsafe = getattr(cfg, "ALLOW_UNSAFE_CHECKPOINT_LOAD", True)
    state = _load_state_dict(ckpt_path, allow_unsafe=allow_unsafe)
    core.load_state_dict(state, strict=True)

    return {
        "core": core.eval(),
        "threshold": _load_threshold(repo_dir, cfg.DEFAULT_THRESHOLD),
        "scaler": _load_scaler(repo_dir),
        "repo_dir": repo_dir,
    }


def load_models(cfg: AppConfig):
    """Download and load both pretrained models configured in *cfg*.

    Returns:
        ``{"gat": bundle, "gatv2": bundle}``. Each bundle holds ``core``,
        ``threshold``, ``scaler`` and ``repo_dir`` (see ``_load_bundle``).
        NOTE(review): ``scaler`` is returned but never applied by
        ``predict``/``adapt_and_predict`` in this module. Callers presumably
        scale ``data.x`` themselves; confirm.

    Raises:
        FileNotFoundError: a repo snapshot lacks its expected checkpoint.
    """
    return {
        "gat": _load_bundle(
            cfg, cfg.HF_GAT_BASELINE_REPO, "gat_baseline_best.pt", GATBaseline
        ),
        "gatv2": _load_bundle(
            cfg, cfg.HF_GATV2_REPO, "gatv2_enhanced_best.pt", GATv2Enhanced
        ),
    }


@torch.no_grad()
def predict(model, data: Data):
    """Run *model* on ``data.x`` / ``data.edge_index`` and return sigmoid probabilities.

    The result is a NumPy array on the CPU whose shape follows the model's
    logits. It is presumably one entry per node; whether that is (N,) or
    (N, 1) is a TODO to confirm.
    """
    logits = model(data.x, data.edge_index)
    return torch.sigmoid(logits).cpu().numpy()


def adapt_and_predict(
    bundle: Dict[str, Any], in_dim_new: int, data: Data, cfg: AppConfig
) -> Tuple[Any, str]:
    """Score *data* with ``bundle["core"]``, adapting the feature dimension if needed.

    Args:
        bundle: One entry from ``load_models``.
        in_dim_new: Feature dimension of ``data.x``.
        data: Graph to score.
        cfg: Supplies ``IN_CHANNELS`` and ``USE_FEATURE_ADAPTER``.

    Returns:
        ``(probs, note)``: probabilities from ``predict`` and a human-readable
        note describing how the input dimension was handled.
    """
    core = bundle["core"]
    if in_dim_new == cfg.IN_CHANNELS:
        model = core.eval()
        note = "Input dim matches."
    elif cfg.USE_FEATURE_ADAPTER:
        # NOTE(review): a new AdapterWrapper is built on every call and no
        # weights are loaded into it here. Unless its initialisation is
        # deterministic, adapter outputs are not reproducible between calls;
        # confirm in models.AdapterWrapper.
        model = AdapterWrapper(in_dim_new, cfg.IN_CHANNELS, core).eval()
        note = f"FeatureAdapter used (new_dim={in_dim_new} → expected={cfg.IN_CHANNELS})."
    else:
        # Attempt to run without an adapter (not recommended; may fail).
        model = core.eval()
        note = f"Dimension mismatch (new_dim={in_dim_new}, expected={cfg.IN_CHANNELS}). Proceeding without adapter (may fail)."
    return predict(model, data), note


def run_for_both_models(
    bundles, data: Data, center_idx: int, cfg: AppConfig
) -> List[Tuple[str, Any, float, int, str]]:
    """Score *data* with the GAT and GATv2 bundles and threshold entry *center_idx*.

    Returns:
        ``[(display_name, probs, threshold, label, note), ...]`` in the order
        GAT, GATv2. ``label`` is 1 when ``probs[center_idx] >= threshold``,
        otherwise 0.
    """
    in_dim_new = data.x.shape[1]
    results = []
    for key, display_name in _MODEL_DISPLAY_NAMES:
        bundle = bundles[key]
        probs, note = adapt_and_predict(bundle, in_dim_new, data, cfg)
        thr = float(bundle["threshold"])
        # .item() accepts both a NumPy scalar and a size-1 array (e.g. an
        # (N, 1) output) without NumPy's array-to-scalar deprecation warning.
        label = int((probs[center_idx] >= thr).item())
        results.append((display_name, probs, thr, label, note))
    return results