toxipredict-api / models /loader.py
Arko006's picture
fix: update model architecture, features, config to match trained model
136190c verified
Raw
History Blame Contribute Delete
2.08 kB
import json
import logging
import os

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from config import HF_MODEL_REPO, HF_TOKEN, MODEL_CACHE_DIR, NODE_DIM, EDGE_DIM, HIDDEN_DIM, NUM_TASKS, DROPOUT
from .multitask_gnn import MultiTaskGNN_ResGATv2_JK_VN
_logger = logging.getLogger(__name__)


class ModelLoader:
    """Lazily fetch and cache the model configuration and the model itself.

    Both artefacts are downloaded from the ``HF_MODEL_REPO`` repository via
    ``hf_hub_download`` the first time they are requested and are then kept
    on the instance, so later calls return the cached objects.
    """

    def __init__(self):
        # Populated on the first successful call to load_model().
        self.model = None
        # Populated on the first call to load_config().
        self.config = None
        # True only once the downloaded state dict was applied to the model.
        # Lets callers tell a trained model apart from the best-effort
        # fallback that keeps freshly initialised weights.
        self.weights_loaded = False

    def load_config(self) -> dict:
        """Return the model configuration, downloading it on first use.

        The configuration is read from ``model_config.json`` in the model
        repository. If the download or parsing fails, or the file does not
        contain a JSON object, the failure is logged and the defaults
        imported from ``config`` are used instead.

        Returns:
            The cached configuration dictionary.
        """
        if self.config is not None:
            return self.config

        try:
            path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="model_config.json",
                token=HF_TOKEN or None,
                cache_dir=MODEL_CACHE_DIR,
            )
            with open(path, encoding="utf-8") as f:
                loaded = json.load(f)
            # load_model() calls .get() on the config, so anything other
            # than a JSON object is unusable and handled like a failure.
            if not isinstance(loaded, dict):
                raise TypeError(
                    f"model_config.json must contain a JSON object, "
                    f"got {type(loaded).__name__}"
                )
        except Exception:
            # Best-effort: keep the service usable with the local defaults,
            # but leave a trace instead of failing silently.
            _logger.warning(
                "Could not load model_config.json from %s; "
                "falling back to default configuration",
                HF_MODEL_REPO,
                exc_info=True,
            )
            loaded = {
                "node_dim": NODE_DIM,
                "edge_dim": EDGE_DIM,
                "hidden_dim": HIDDEN_DIM,
                "num_tasks": NUM_TASKS,
                "dropout": DROPOUT,
            }

        self.config = loaded
        return self.config

    def load_model(self) -> torch.nn.Module:
        """Return the model in eval mode, building it on first use.

        The model is constructed from the values in :meth:`load_config`
        (missing keys fall back to the defaults imported from ``config``),
        then ``model.safetensors`` is downloaded and applied as its state
        dict. If fetching or applying the weights fails, the error is
        logged, ``weights_loaded`` stays ``False`` and the model is still
        returned with the weights it was constructed with.

        Returns:
            The cached model instance.
        """
        if self.model is not None:
            return self.model

        config = self.load_config()
        model = MultiTaskGNN_ResGATv2_JK_VN(
            in_channels=config.get("node_dim", NODE_DIM),
            edge_dim=config.get("edge_dim", EDGE_DIM),
            hidden_dim=config.get("hidden_dim", HIDDEN_DIM),
            num_tasks=config.get("num_tasks", NUM_TASKS),
            dropout=config.get("dropout", DROPOUT),
        )

        try:
            path = hf_hub_download(
                repo_id=HF_MODEL_REPO,
                filename="model.safetensors",
                token=HF_TOKEN or None,
                cache_dir=MODEL_CACHE_DIR,
            )
            state_dict = load_file(path)
            model.load_state_dict(state_dict)
        except Exception:
            # Best-effort fallback is preserved, but a model without its
            # trained weights must never go unnoticed.
            _logger.exception(
                "Could not load model.safetensors from %s; "
                "the model keeps its initial weights",
                HF_MODEL_REPO,
            )
        else:
            self.weights_loaded = True

        model.eval()
        self.model = model
        return model

    def is_loaded(self) -> bool:
        """Return True once a model instance has been cached."""
        return self.model is not None