File size: 2,656 Bytes
14b3daf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
# handler.py
# Minimal Hugging Face Inference Endpoint handler for text2text models (e.g., FLAN-T5)
# Loads the model once at cold start and serves inference requests via __call__.

from typing import Any, Dict, List, Union
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class EndpointHandler:
    """Minimal Hugging Face Inference Endpoint handler for seq2seq models.

    Loads tokenizer + model once at cold start; each request goes through
    ``__call__`` and returns ``{"outputs": [str, ...]}`` or ``{"error": str}``.
    """

    def __init__(self, path: str = ""):
        """
        HF passes `path` pointing to the repo files mounted in the container.
        We load tokenizer + model from that path once at cold start.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

        # sensible defaults—override via request "parameters"
        self.gen_defaults = {
            "max_new_tokens": 128,
            "do_sample": False,
            "temperature": 1.0,
            "top_p": 1.0,
            "num_beams": 1,
        }

    def _generate(self, texts: List[str], params: Dict[str, Any]) -> List[str]:
        """Tokenize `texts`, run generation, and decode to plain strings.

        `params` overrides `self.gen_defaults`; values are coerced to the
        types `generate` expects so JSON payloads ("128", 1.0) work too.
        """
        p = {**self.gen_defaults, **(params or {})}
        enc = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        # no_grad: inference only — skip autograd bookkeeping.
        with torch.no_grad():
            out_ids = self.model.generate(
                **enc,
                max_new_tokens=int(p["max_new_tokens"]),
                do_sample=bool(p["do_sample"]),
                temperature=float(p["temperature"]),
                top_p=float(p["top_p"]),
                num_beams=int(p["num_beams"]),
            )
        return self.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Accepts several common payload shapes and returns:
        {"outputs": ["str", ...]}

        Errors are reported as {"error": "..."} rather than raised, so the
        endpoint responds with a JSON body instead of a 500.
        """
        if data is None:
            return {"error": "No payload provided."}

        # Accept "inputs", "input", or "texts".
        # NOTE: use key presence, not `or`-chaining — a present-but-falsy
        # value (e.g. "" or []) must not be skipped and misreported as missing.
        raw = None
        for key in ("inputs", "input", "texts"):
            if key in data:
                raw = data[key]
                break
        if raw is None:
            return {"error": "Provide 'inputs' (str or list of str)."}

        # Normalize to list[str]
        if isinstance(raw, str):
            texts = [raw]
        elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
            texts = raw
        else:
            return {"error": "inputs must be str or list[str]."}

        params = data.get("parameters", {})  # optional generation params
        outputs = self._generate(texts, params)
        return {"outputs": outputs}