File size: 608 Bytes
cf02581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from __future__ import annotations

from typing import Any, Dict

try:
    from .runtime import JNUTSBRuntime
except ImportError:  # pragma: no cover
    from runtime import JNUTSBRuntime


class EndpointHandler:
    """Hugging Face Inference Endpoint custom handler.

    Wraps a ``JNUTSBRuntime`` and adapts each request payload into a call to
    its ``predict`` method.
    """

    def __init__(self, model_dir: str, **kwargs: Any) -> None:
        """Load the runtime from ``model_dir``.

        Args:
            model_dir: Directory handed to ``JNUTSBRuntime.from_config_dir``.
            **kwargs: Accepted so extra loader arguments do not break
                construction; currently ignored.
        """
        self.runtime = JNUTSBRuntime.from_config_dir(model_dir)

    def __call__(self, data: Dict[str, Any]) -> Any:
        """Run one prediction for a request payload.

        Args:
            data: Deserialized request payload. If it has an ``"inputs"``
                key, that value is forwarded as ``inputs``. Otherwise the
                payload itself is used as the inputs, minus any
                ``"parameters"`` entry. An optional ``"parameters"`` mapping
                is expanded into keyword arguments for ``predict``; a
                ``None`` value is treated as no parameters.

        Returns:
            Whatever ``self.runtime.predict`` returns.
        """
        parameters = data.get("parameters")
        if parameters is None:
            # ``"parameters": null`` in JSON would otherwise fail on ``**None``.
            parameters = {}

        if "inputs" in data:
            inputs = data["inputs"]
        elif "parameters" in data:
            # Bare payload without an "inputs" envelope: don't smuggle the
            # parameters into the inputs too -- they are passed as kwargs.
            inputs = {key: value for key, value in data.items() if key != "parameters"}
        else:
            inputs = data

        return self.runtime.predict(inputs=inputs, **parameters)