File size: 2,076 Bytes
648720e
 
 
 
 
 
136190c
648720e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136190c
648720e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import json
import logging
import os

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

from config import HF_MODEL_REPO, HF_TOKEN, MODEL_CACHE_DIR, NODE_DIM, EDGE_DIM, HIDDEN_DIM, NUM_TASKS, DROPOUT
from .multitask_gnn import MultiTaskGNN_ResGATv2_JK_VN


logger = logging.getLogger(__name__)


class ModelLoader:
    """Lazily fetch and cache the multitask GNN and its config from the Hugging Face Hub.

    Both the hyperparameter config and the model are built at most once per
    instance; later calls return the cached objects without re-downloading.
    """

    def __init__(self):
        # Cached model; None until ``load_model`` has constructed one.
        self.model = None
        # Cached hyperparameter dict; None until ``load_config`` has run.
        self.config = None
        # True only when the pretrained ``model.safetensors`` weights were
        # loaded into ``self.model``. ``is_loaded()`` is True as soon as a model
        # is cached, even if weight loading failed, so check this flag to know
        # whether the model is actually using pretrained parameters.
        self.weights_loaded = False

    def _download(self, filename: str) -> str:
        """Download ``filename`` from ``HF_MODEL_REPO`` and return its local path."""
        return hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename=filename,
            # A falsy token (e.g. an empty string) is passed as None.
            token=HF_TOKEN or None,
            cache_dir=MODEL_CACHE_DIR,
        )

    def load_config(self) -> dict:
        """Return the model hyperparameter config, downloading it on first use.

        The config is read from ``model_config.json`` in ``HF_MODEL_REPO``. On any
        failure (download error, unreadable file, invalid JSON, or JSON that
        is not an object), a warning is logged and a dict built from the
        ``config`` module constants is used instead.

        The result is cached on the instance. A failed download is therefore
        not retried on later calls.

        Returns:
            dict with (at least some of) the keys ``node_dim``, ``edge_dim``,
            ``hidden_dim``, ``num_tasks`` and ``dropout``.
        """
        if self.config is not None:
            return self.config
        try:
            path = self._download("model_config.json")
            with open(path, encoding="utf-8") as f:
                loaded = json.load(f)
            # load_model calls .get() on the config, so anything other than a
            # JSON object would crash it later; treat it as a load failure.
            if not isinstance(loaded, dict):
                raise ValueError(
                    f"model_config.json must contain a JSON object, got {type(loaded).__name__}"
                )
            self.config = loaded
        except Exception:
            logger.warning(
                "Could not load model_config.json from %s; falling back to default hyperparameters",
                HF_MODEL_REPO,
                exc_info=True,
            )
            self.config = {
                "node_dim": NODE_DIM,
                "edge_dim": EDGE_DIM,
                "hidden_dim": HIDDEN_DIM,
                "num_tasks": NUM_TASKS,
                "dropout": DROPOUT,
            }
        return self.config

    def load_model(self) -> torch.nn.Module:
        """Build the GNN, load pretrained weights if available, and cache it.

        The architecture is configured from ``load_config()``. Missing keys fall
        back to the ``config`` module constants. Weights are then loaded from
        ``model.safetensors`` in ``HF_MODEL_REPO`` using torch's default strict
        ``load_state_dict``.

        If downloading or loading the weights fails, the error is logged and
        the model keeps its constructor-initialized parameters. In that case
        ``self.weights_loaded`` is False. Either way the model is switched to
        eval mode, cached and returned.

        Returns:
            The cached ``MultiTaskGNN_ResGATv2_JK_VN`` instance, in eval mode.
        """
        if self.model is not None:
            return self.model

        config = self.load_config()

        model = MultiTaskGNN_ResGATv2_JK_VN(
            in_channels=config.get("node_dim", NODE_DIM),
            edge_dim=config.get("edge_dim", EDGE_DIM),
            hidden_dim=config.get("hidden_dim", HIDDEN_DIM),
            num_tasks=config.get("num_tasks", NUM_TASKS),
            dropout=config.get("dropout", DROPOUT),
        )

        try:
            state_dict = load_file(self._download("model.safetensors"))
            model.load_state_dict(state_dict)
        except Exception:
            # Best-effort: keep serving, but never silently. A missing file or
            # a key/shape mismatch would otherwise go completely unnoticed.
            logger.exception(
                "Could not load pretrained weights from %s; model is using "
                "constructor-initialized parameters",
                HF_MODEL_REPO,
            )
            self.weights_loaded = False
        else:
            self.weights_loaded = True

        model.eval()
        self.model = model
        return model

    def is_loaded(self) -> bool:
        """Return True once a model has been constructed and cached.

        This does not imply pretrained weights were loaded; see
        ``weights_loaded``.
        """
        return self.model is not None