File size: 891 Bytes
ec42e29
b56dbfa
761b08f
 
 
 
ec42e29
 
 
 
 
70583d3
 
 
 
 
 
 
 
761b08f
 
 
 
 
ec42e29
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import os
from huggingface_hub import hf_hub_download
from nail_classification.inference import Inference


class Model:
    """Wrapper that loads several DeepNAPSI checkpoints into one ``Inference``.

    Checkpoints for ``CHECKPOINT_VERSIONS`` are resolved from one of two places:
    a local directory (when ``DEBUG`` is truthy) or the Hugging Face Hub. The
    resolved paths are handed to ``nail_classification.inference.Inference``.
    Calling the instance is the same as calling :meth:`predict`.
    """

    # Checkpoint versions to load. Shared by the local and Hub code paths so
    # the two configurations cannot drift apart.
    CHECKPOINT_VERSIONS = (10, 11, 12, 13, 14)
    # Local checkpoint directory used when DEBUG is truthy.
    DEBUG_CHECKPOINT_DIR = r"C:\Users\follels\Documents\hand-ki-model-weights\DeepNAPSIModel\inference_checkpoints_v1"
    # Hugging Face Hub repository holding the ``version_<v>.ckpt`` files.
    HF_REPO_ID = "lfolle/DeepNAPSIModel"
    # Environment variable holding the access token passed to the Hub download.
    HF_TOKEN_ENV = "DeepNAPSIModel"

    def __init__(self, DEBUG):
        """Resolve the checkpoint files and build the ``Inference`` object.

        Args:
            DEBUG: If truthy, use checkpoints under ``DEBUG_CHECKPOINT_DIR``.
                Otherwise download them from ``HF_REPO_ID``.

        Raises:
            KeyError: If downloading from the Hub and the environment variable
                named by ``HF_TOKEN_ENV`` is not set.
        """
        if DEBUG:
            file_paths = self._local_checkpoint_paths(
                self.DEBUG_CHECKPOINT_DIR, self.CHECKPOINT_VERSIONS
            )
        else:
            file_paths = self._hub_checkpoint_paths()
        self.inference = Inference(file_paths)

    @staticmethod
    def _local_checkpoint_paths(base, versions):
        """Return ``os.path.join(base, "version_<v>")`` for each version, in order."""
        # NOTE(review): unlike the Hub files, these paths carry no ``.ckpt``
        # suffix. Confirm the local layout really uses bare ``version_<v>``
        # entries rather than ``version_<v>.ckpt`` files.
        return [os.path.join(base, f"version_{v}") for v in versions]

    @classmethod
    def _hub_token(cls):
        """Return the Hub access token stored in the ``HF_TOKEN_ENV`` variable.

        Raises:
            KeyError: With an explanatory message if the variable is unset.
        """
        try:
            return os.environ[cls.HF_TOKEN_ENV]
        except KeyError:
            # Same exception type as a plain os.environ lookup, but the
            # message says what is missing and why.
            raise KeyError(
                f"Environment variable {cls.HF_TOKEN_ENV!r} must hold an access "
                f"token for the Hugging Face repository {cls.HF_REPO_ID!r}"
            ) from None

    def _hub_checkpoint_paths(self):
        """Fetch ``version_<v>.ckpt`` per version via ``hf_hub_download``.

        Returns:
            The paths returned by ``hf_hub_download``, in version order.
        """
        token = self._hub_token()  # read once instead of once per download
        # NOTE(review): newer huggingface_hub releases deprecate
        # ``use_auth_token`` in favour of ``token``. It is kept here for
        # compatibility; confirm against the pinned huggingface_hub version.
        return [
            hf_hub_download(
                self.HF_REPO_ID,
                f"version_{v}.ckpt",
                use_auth_token=token,
            )
            for v in self.CHECKPOINT_VERSIONS
        ]

    def predict(self, x):
        """Run inference on ``x``.

        Returns:
            The ``(y_hat, uncertainty)`` pair produced by ``self.inference.predict``.
        """
        y_hat, uncertainty = self.inference.predict(x)
        return y_hat, uncertainty

    def __call__(self, x):
        """Alias for :meth:`predict`."""
        return self.predict(x)