File size: 909 Bytes
7d6a683
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import Any, Dict

from inference_api import predict


class EndpointHandler:
    """
    Entry point for a Hugging Face Inference Endpoint.

    A request payload is expected to look like:
      {
        "inputs": "dharmo rakṣati rakṣitaḥ",
        "parameters": {"temperature": 0.7, ...}
      }

    The value under "inputs" is handed to ``predict`` as ``text``. Each
    recognised key under "parameters" is handed over as a keyword argument
    of the same name; keys that are absent fall back to the defaults listed
    in ``__call__``.
    """

    def __init__(self, path: str = ""):
        """Remember ``path``; it is only stored, not read, within this class."""
        self.path = path

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Unpack a request payload and delegate to ``predict``.

        Args:
            data: Request payload with an optional "inputs" entry (defaults
                to the empty string) and an optional "parameters" mapping.

        Returns:
            Whatever ``predict`` returns, unmodified.
        """
        # (keyword name, value used when the request does not supply it),
        # listed in the order the keywords are passed to ``predict``.
        fallbacks = (
            ("temperature", 0.7),
            ("top_k", 40),
            ("repetition_penalty", 1.2),
            ("diversity_penalty", 0.0),
            ("num_steps", 64),
            ("clean_output", True),
        )

        prompt = data.get("inputs", "")
        # A missing, null, or empty "parameters" entry all collapse to {}.
        overrides = data.get("parameters") or {}

        return predict(
            text=prompt,
            **{key: overrides.get(key, fallback) for key, fallback in fallbacks},
        )