Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| from typing import Dict, Any, Tuple, Optional | |
| import torch | |
| from huggingface_hub import snapshot_download | |
| from torch_geometric.data import Data | |
| from torch_geometric.utils import to_undirected | |
| from config import AppConfig | |
| from models import GATBaseline, GATv2Enhanced, AdapterWrapper | |
def _load_threshold(model_dir: str, default_thr: float) -> float:
    """Return the decision threshold stored next to a model, if there is one.

    Looks in ``model_dir`` for ``thresholds.json``, ``threshold.json`` and
    ``config.json`` (in that order) and returns the first numeric value found
    under one of the keys ``threshold``, ``default_threshold``, ``thr``,
    ``best_f1`` or ``best_j``. Files that are missing, unreadable, not valid
    JSON, or whose top level is not a JSON object are skipped.

    Args:
        model_dir: Directory that may contain the threshold files.
        default_thr: Value returned when no usable threshold is found.

    Returns:
        The stored threshold as a float, or ``default_thr``.
    """
    candidate_files = ("thresholds.json", "threshold.json", "config.json")
    candidate_keys = ("threshold", "default_threshold", "thr", "best_f1", "best_j")

    for file_name in candidate_files:
        path = os.path.join(model_dir, file_name)
        if not os.path.exists(path):
            continue
        try:
            # Context manager so the handle is closed even if parsing fails.
            with open(path, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except (OSError, ValueError):
            # Unreadable or malformed file: best effort, try the next one.
            continue
        if not isinstance(payload, dict):
            # e.g. a top-level list; there are no named keys to look up.
            continue
        for key in candidate_keys:
            value = payload.get(key)
            # bool is a subclass of int, but JSON true/false is not a threshold.
            if isinstance(value, bool) or not isinstance(value, (int, float)):
                continue
            try:
                return float(value)
            except OverflowError:
                # Integer too large to represent as a float: not usable.
                continue
    return default_thr
def _load_scaler(model_dir: str):
    """Load an optional feature scaler shipped alongside a model.

    Tries ``scaler.joblib``, ``scaler.pkl``, ``elliptic_scaler.joblib`` and
    ``elliptic_scaler.pkl`` inside ``model_dir`` (in that order) and returns
    the first one that loads successfully.

    Args:
        model_dir: Directory that may contain a serialized scaler.

    Returns:
        Whatever ``joblib.load`` returns for the first loadable candidate, or
        ``None`` when no candidate exists or none could be loaded.
    """
    import logging

    log = logging.getLogger(__name__)
    candidates = (
        "scaler.joblib",
        "scaler.pkl",
        "elliptic_scaler.joblib",
        "elliptic_scaler.pkl",
    )
    for file_name in candidates:
        path = os.path.join(model_dir, file_name)
        if not os.path.exists(path):
            continue
        try:
            # Imported lazily so a missing joblib only disables the scaler.
            import joblib

            # NOTE(security): joblib.load unpickles the file and can execute
            # arbitrary code; only point this at trusted model repositories.
            return joblib.load(path)
        except Exception as exc:
            # Still best effort (fall through to the next candidate), but
            # leave a trace instead of silently running without a scaler.
            log.warning("Could not load scaler from %s: %s", path, exc)
    return None
def _load_checkpoint_state(ckpt_path: str):
    """Read a checkpoint file onto the CPU, preferring the restricted loader.

    Args:
        ckpt_path: Path of the ``.pt`` file to read.

    Returns:
        The object stored in the checkpoint (passed to ``load_state_dict``
        by the caller).
    """
    try:
        return torch.load(ckpt_path, map_location="cpu", weights_only=True)
    except Exception as exc:
        # NOTE(security): weights_only=False performs a full unpickle and can
        # execute code embedded in the file. The fallback is kept so existing
        # checkpoints keep loading, but it is only safe for trusted repos.
        import logging

        logging.getLogger(__name__).warning(
            "weights_only load of %s failed (%s); falling back to full unpickling",
            ckpt_path,
            exc,
        )
        return torch.load(ckpt_path, map_location="cpu", weights_only=False)


def load_models(cfg: AppConfig):
    """Download both model repositories and build ready-to-use model bundles.

    Args:
        cfg: Application config. Reads ``HF_GAT_BASELINE_REPO``,
            ``HF_GATV2_REPO``, ``IN_CHANNELS``, ``HIDDEN_CHANNELS``, ``HEADS``,
            ``NUM_BLOCKS``, ``DROPOUT`` and ``DEFAULT_THRESHOLD``.

    Returns:
        A dict with the keys ``"gat"`` and ``"gatv2"``. Each value is a dict
        holding ``"core"`` (the model after ``.eval()``), ``"threshold"``,
        ``"scaler"`` (possibly ``None``) and ``"repo_dir"`` (the local
        snapshot directory).

    Raises:
        FileNotFoundError: If an expected checkpoint file is not present in
            its downloaded snapshot.
    """
    # local_dir_use_symlinks is deprecated and only has an effect together
    # with local_dir, which is not used here, so it is not passed.
    dir_gat = snapshot_download(cfg.HF_GAT_BASELINE_REPO)
    dir_gatv2 = snapshot_download(cfg.HF_GATV2_REPO)

    ckpt_gat = os.path.join(dir_gat, "gat_baseline_best.pt")
    ckpt_gatv2 = os.path.join(dir_gatv2, "gatv2_enhanced_best.pt")
    for ckpt_path, repo_dir in ((ckpt_gat, dir_gat), (ckpt_gatv2, dir_gatv2)):
        if not os.path.exists(ckpt_path):
            # Name the file that was actually looked for.
            raise FileNotFoundError(
                f"Missing {os.path.basename(ckpt_path)} in {repo_dir}"
            )

    # Build the cores with the input dimension the config says they expect.
    core_gat = GATBaseline(
        cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT
    )
    core_gatv2 = GATv2Enhanced(
        cfg.IN_CHANNELS, cfg.HIDDEN_CHANNELS, cfg.HEADS, cfg.NUM_BLOCKS, cfg.DROPOUT
    )

    state_gat = _load_checkpoint_state(ckpt_gat)
    state_gatv2 = _load_checkpoint_state(ckpt_gatv2)

    # strict=True: the checkpoint keys must match the architecture exactly.
    core_gat.load_state_dict(state_gat, strict=True)
    core_gatv2.load_state_dict(state_gatv2, strict=True)

    thr_gat = _load_threshold(dir_gat, cfg.DEFAULT_THRESHOLD)
    thr_gatv2 = _load_threshold(dir_gatv2, cfg.DEFAULT_THRESHOLD)
    scaler_gat = _load_scaler(dir_gat)
    scaler_gatv2 = _load_scaler(dir_gatv2)

    return {
        "gat": {
            "core": core_gat.eval(),
            "threshold": thr_gat,
            "scaler": scaler_gat,
            "repo_dir": dir_gat,
        },
        "gatv2": {
            "core": core_gatv2.eval(),
            "threshold": thr_gatv2,
            "scaler": scaler_gatv2,
            "repo_dir": dir_gatv2,
        },
    }
def predict(model, data: Data):
    """Run ``model`` on a graph and return sigmoid probabilities as NumPy.

    Args:
        model: Callable invoked as ``model(data.x, data.edge_index)``.
        data: Object exposing ``x`` and ``edge_index``.

    Returns:
        A NumPy array holding the element-wise sigmoid of the model output.
    """
    # Inference only: with autograd enabled the output stays attached to the
    # graph, and Tensor.numpy() raises for tensors that require grad.
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
        probs = torch.sigmoid(logits)
    return probs.detach().cpu().numpy()
def adapt_and_predict(bundle: Dict[str, Any], in_dim_new: int, data: Data, cfg: AppConfig):
    """Choose the network to run for ``data`` and return its probabilities.

    Three cases, decided by comparing ``in_dim_new`` with ``cfg.IN_CHANNELS``:

    * equal: the bundle's core model is used as is;
    * different and ``cfg.USE_FEATURE_ADAPTER`` is truthy: the core is wrapped
      in a newly constructed ``AdapterWrapper``;
    * different otherwise: the core is used anyway (the forward pass may fail).

    Args:
        bundle: Model bundle; only ``bundle["core"]`` is read here.
        in_dim_new: Feature dimension of the incoming data.
        data: Graph passed on to ``predict``.
        cfg: Application config (``IN_CHANNELS``, ``USE_FEATURE_ADAPTER``).

    Returns:
        ``(probs, note)``: the output of ``predict`` and a human-readable
        string describing which case applied.
    """
    core = bundle["core"]

    if in_dim_new == cfg.IN_CHANNELS:
        network = core.eval()
        note = "Input dim matches."
    elif cfg.USE_FEATURE_ADAPTER:
        # NOTE(review): a new AdapterWrapper is built on every call; whether
        # its weights are trained is not visible from here.
        network = AdapterWrapper(in_dim_new, cfg.IN_CHANNELS, core).eval()
        note = f"FeatureAdapter used (new_dim={in_dim_new} → expected={cfg.IN_CHANNELS})."
    else:
        # Running without an adapter is not recommended.
        network = core.eval()
        note = f"Dimension mismatch (new_dim={in_dim_new}, expected={cfg.IN_CHANNELS}). Proceeding without adapter (may fail)."

    return predict(network, data), note
def run_for_both_models(bundles, data: Data, center_idx: int, cfg: AppConfig):
    """Score ``data`` with both models and label the entry at ``center_idx``.

    Args:
        bundles: Mapping with the keys ``"gat"`` and ``"gatv2"``; each bundle
            must provide ``"core"`` and ``"threshold"``.
        data: Graph whose ``x`` has the feature dimension on axis 1.
        center_idx: Index into the probability array of the entry to label.
        cfg: Application config forwarded to ``adapt_and_predict``.

    Returns:
        A list of two tuples, GAT first and GATv2 second, each of the form
        ``(model_name, probs, threshold, label, note)``, where ``label`` is
        1 when ``probs[center_idx] >= threshold`` and 0 otherwise.
    """
    feature_dim = data.x.shape[1]

    outcomes = []
    for display_name, bundle_key in (("GAT", "gat"), ("GATv2", "gatv2")):
        probs, note = adapt_and_predict(bundles[bundle_key], feature_dim, data, cfg)
        threshold = float(bundles[bundle_key]["threshold"])
        label = int(probs[center_idx] >= threshold)
        outcomes.append((display_name, probs, threshold, label, note))
    return outcomes