"""Build the multitask GNN and load its config and weights from the Hugging Face Hub."""

import json
import logging
import os

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from config import (
    DROPOUT,
    EDGE_DIM,
    HF_MODEL_REPO,
    HF_TOKEN,
    HIDDEN_DIM,
    MODEL_CACHE_DIR,
    NODE_DIM,
    NUM_TASKS,
)
from .multitask_gnn import MultiTaskGNN_ResGATv2_JK_VN

logger = logging.getLogger(__name__)


def _default_config() -> dict:
    """Return model hyperparameters taken from the local ``config`` module constants."""
    return {
        "node_dim": NODE_DIM,
        "edge_dim": EDGE_DIM,
        "hidden_dim": HIDDEN_DIM,
        "num_tasks": NUM_TASKS,
        "dropout": DROPOUT,
    }


class ModelLoader:
    """Lazily construct ``MultiTaskGNN_ResGATv2_JK_VN`` and cache it on the instance.

    Hyperparameters come from ``model_config.json`` and weights from
    ``model.safetensors``, both downloaded from ``HF_MODEL_REPO``. Each step
    is best-effort:

    * If the config cannot be downloaded or parsed into a JSON object, the
      constants from the ``config`` module are used instead.
    * If the weights cannot be downloaded or applied, the model is still
      returned, but without (fully) loaded trained weights.

    Both fallbacks are logged, and ``weights_loaded`` records whether the
    downloaded state dict was applied successfully.
    """

    def __init__(self):
        # Cached model instance; None until load_model() has run.
        self.model = None
        # Cached hyperparameter dict; None until load_config() has run.
        self.config = None
        # True only once a state dict from the Hub was applied to self.model.
        self.weights_loaded = False

    def _download(self, filename: str) -> str:
        """Download ``filename`` from ``HF_MODEL_REPO`` and return its local path."""
        return hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=filename,
            # A falsy HF_TOKEN (e.g. "") is passed as None rather than as-is.
            token=HF_TOKEN or None,
            cache_dir=MODEL_CACHE_DIR,
        )

    def load_config(self) -> dict:
        """Return the model hyperparameters, loading and caching them on first call.

        Returns:
            The parsed ``model_config.json`` object. If the download or parsing
            fails, or the file does not contain a JSON object, the defaults
            built from the ``config`` module are returned instead. Whichever
            value is chosen is cached, so later calls do not retry the download.
        """
        if self.config is not None:
            return self.config
        try:
            path = self._download("model_config.json")
            with open(path, encoding="utf-8") as f:
                loaded = json.load(f)
            # load_model() calls .get() on the config, so anything other than
            # a JSON object would crash there; treat it as a load failure.
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"model_config.json must contain a JSON object, got {type(loaded).__name__}"
                )
        except Exception:
            logger.warning(
                "Could not load model_config.json from %s; falling back to defaults "
                "from the config module",
                HF_MODEL_REPO,
                exc_info=True,
            )
            loaded = _default_config()
        self.config = loaded
        return self.config

    def load_model(self) -> torch.nn.Module:
        """Build the model, load trained weights if possible, and cache it.

        Missing keys in the loaded config fall back to the ``config`` module
        constants. The returned model is put in eval mode.

        Returns:
            The cached model. Check ``weights_loaded`` to see whether the
            trained weights from the Hub were applied. If they were not, the
            failure is logged with its traceback.
        """
        if self.model is not None:
            return self.model
        config = self.load_config()
        model = MultiTaskGNN_ResGATv2_JK_VN(
            in_channels=config.get("node_dim", NODE_DIM),
            edge_dim=config.get("edge_dim", EDGE_DIM),
            hidden_dim=config.get("hidden_dim", HIDDEN_DIM),
            num_tasks=config.get("num_tasks", NUM_TASKS),
            dropout=config.get("dropout", DROPOUT),
        )
        try:
            path = self._download("model.safetensors")
            # strict (default) load: key/shape mismatches raise and are
            # reported below instead of being silently ignored.
            model.load_state_dict(load_file(path))
        except Exception:
            # Best-effort: keep serving, but make the missing weights visible,
            # because predictions from this model will not be the trained ones.
            logger.exception(
                "Could not load model.safetensors from %s; the model is NOT using "
                "trained weights",
                HF_MODEL_REPO,
            )
            self.weights_loaded = False
        else:
            self.weights_loaded = True
        model.eval()
        self.model = model
        return model

    def is_loaded(self) -> bool:
        """Return True once ``load_model`` has built and cached a model instance."""
        return self.model is not None