Spaces:
Sleeping
Sleeping
File size: 2,076 Bytes
648720e 136190c 648720e 136190c 648720e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 | import json
import logging
import os

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from config import HF_MODEL_REPO, HF_TOKEN, MODEL_CACHE_DIR, NODE_DIM, EDGE_DIM, HIDDEN_DIM, NUM_TASKS, DROPOUT
from .multitask_gnn import MultiTaskGNN_ResGATv2_JK_VN
class ModelLoader:
    """Lazily fetch, build, and cache the multi-task GNN from the Hugging Face Hub.

    ``model_config.json`` and ``model.safetensors`` are downloaded from
    ``HF_MODEL_REPO`` on first use and the results are cached on the instance,
    so later calls return the same objects without touching the network.

    Loading is best-effort: if the config cannot be fetched or parsed, the
    hyper-parameters from the ``config`` module are used instead; if the
    weights cannot be fetched or applied, the model keeps the parameters its
    constructor produced.  Both fallbacks are logged, and ``weights_loaded``
    reports whether the trained weights were actually restored.
    """

    _logger = logging.getLogger(__name__)

    def __init__(self):
        self.model = None  # cached model once load_model() has run
        self.config = None  # cached hyper-parameter dict once load_config() has run
        # True only after model.safetensors was downloaded and applied without error.
        self.weights_loaded = False

    def _download(self, filename: str) -> str:
        """Download ``filename`` from ``HF_MODEL_REPO`` and return its local path."""
        return hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=filename,
            # An empty token string is treated the same as "no token".
            token=HF_TOKEN or None,
            cache_dir=MODEL_CACHE_DIR,
        )

    @staticmethod
    def _default_config() -> dict:
        """Return the fallback hyper-parameters taken from the ``config`` module."""
        return {
            "node_dim": NODE_DIM,
            "edge_dim": EDGE_DIM,
            "hidden_dim": HIDDEN_DIM,
            "num_tasks": NUM_TASKS,
            "dropout": DROPOUT,
        }

    def load_config(self) -> dict:
        """Return the model hyper-parameters, downloading them on the first call.

        Returns:
            The dict parsed from ``model_config.json``, or the defaults from
            ``_default_config()`` if the download/parse fails or the file does
            not contain a JSON object.  The result is cached on ``self.config``.
        """
        if self.config is not None:
            return self.config
        try:
            path = self._download("model_config.json")
            with open(path, encoding="utf-8") as f:
                loaded = json.load(f)
            if not isinstance(loaded, dict):
                # load_model() reads keys with .get(), so anything but an object is unusable.
                raise ValueError(
                    f"model_config.json must contain a JSON object, got {type(loaded).__name__}"
                )
            self.config = loaded
        except Exception:
            # Best-effort: keep going with defaults, but leave a trace of why.
            self._logger.warning(
                "Could not load model_config.json from %s; using default config",
                HF_MODEL_REPO,
                exc_info=True,
            )
            self.config = self._default_config()
        return self.config

    def load_model(self) -> torch.nn.Module:
        """Build the GNN, restore its weights if possible, and return it in eval mode.

        Any key missing from the config falls back to the matching constant
        from the ``config`` module.  If the weights cannot be downloaded or
        applied, the failure is logged, ``weights_loaded`` stays ``False``, and
        the model is returned with its constructor-initialised parameters.

        Returns:
            The cached model instance (the same object on every call).
        """
        if self.model is not None:
            return self.model
        config = self.load_config()
        model = MultiTaskGNN_ResGATv2_JK_VN(
            in_channels=config.get("node_dim", NODE_DIM),
            edge_dim=config.get("edge_dim", EDGE_DIM),
            hidden_dim=config.get("hidden_dim", HIDDEN_DIM),
            num_tasks=config.get("num_tasks", NUM_TASKS),
            dropout=config.get("dropout", DROPOUT),
        )
        try:
            state_dict = load_file(self._download("model.safetensors"))
            model.load_state_dict(state_dict)
        except Exception:
            # Previously swallowed silently, which served an un-restored model with
            # no signal.  Keep serving (best-effort) but make the failure visible.
            self._logger.exception(
                "Could not load weights from %s/model.safetensors; the model's "
                "parameters may be missing or only partially restored",
                HF_MODEL_REPO,
            )
            self.weights_loaded = False
        else:
            self.weights_loaded = True
        model.eval()
        self.model = model
        return model

    def is_loaded(self) -> bool:
        """Return True once a model object has been built and cached.

        This does not say whether the trained weights were restored; check
        ``weights_loaded`` for that.
        """
        return self.model is not None
|