File size: 3,615 Bytes
b79dedc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""Custom inference handler for Hugging Face Inference Endpoints.

NLLB needs a source-language code on the tokenizer and a forced BOS token
id for the target language at generation time, so the default translation
pipeline is not flexible enough. This handler accepts `src_lang` and
`tgt_lang` (NLLB Flores-200 codes, e.g. "eng_Latn", "spa_Latn") per
request.

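Request format:
    {
      "inputs": "Hello, world!",            # str or List[str]
      "parameters": {
        "src_lang": "eng_Latn",             # optional, default eng_Latn
        "tgt_lang": "spa_Latn",             # optional, default spa_Latn
        "max_length": 256,                  # optional, default 256; also truncates the input
        "num_beams": 4,                     # optional, default 4
        "temperature": 1.0,                 # optional, default 1.0
        "do_sample": false                  # optional, default false
      }
    }

Response: List[{"translation_text": str}], one entry per input string.
On an invalid request the response is instead a single-element
List[{"error": str}] describing the problem.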
"""

from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Per-request fallbacks, used whenever "parameters" omits the matching key.
# Language defaults are NLLB Flores-200 codes (source, target).
DEFAULT_SRC_LANG, DEFAULT_TGT_LANG = "eng_Latn", "spa_Latn"
# Generation defaults (max_length, num_beams).
DEFAULT_MAX_LENGTH, DEFAULT_NUM_BEAMS = 256, 4


class EndpointHandler:
    """NLLB translation handler for Hugging Face Inference Endpoints.

    ``__init__`` loads the tokenizer and seq2seq model once; each call on the
    instance handles one request payload. Invalid requests are answered with a
    single-element ``[{"error": message}]`` list instead of raising.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model from ``path`` and place the model on a device.

        Args:
            path: Directory or model id forwarded to ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # fp16 on GPU keeps latency and memory down; stay in fp32 on CPU for stability.
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        self.model.eval()

    @staticmethod
    def _as_bool(value: Any) -> bool:
        """Interpret a request flag, treating strings such as "false" or "0" as False.

        ``bool("false")`` is True in Python, so string values are matched
        against an explicit set of truthy spellings instead.
        """
        if isinstance(value, str):
            return value.strip().lower() in {"true", "1", "yes", "on"}
        return bool(value)

    def _lang_token_id(self, code: Any) -> Union[int, None]:
        """Return the vocabulary id of language token ``code``, or None if it is unknown.

        Non-string or empty codes are rejected up front. A code that the
        tokenizer maps to ``unk_token_id`` is treated as unknown.
        """
        if not isinstance(code, str) or not code:
            return None
        token_id = self.tokenizer.convert_tokens_to_ids(code)
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            return None
        return token_id

    def __call__(
        self, data: Dict[str, Any]
    ) -> List[Dict[str, str]]:
        """Translate ``data["inputs"]`` from ``src_lang`` to ``tgt_lang``.

        Args:
            data: Request payload with ``inputs`` (str or list of str) and an
                optional ``parameters`` object holding ``src_lang``,
                ``tgt_lang``, ``max_length``, ``num_beams``, ``temperature``
                and ``do_sample``.

        Returns:
            One ``{"translation_text": str}`` per input string, in input order;
            ``[]`` for an empty input list; or ``[{"error": str}]`` when the
            request is invalid.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return [{"error": "Missing 'inputs' field."}]
        if isinstance(inputs, str):
            inputs = [inputs]
        # Require a real sequence. A dict would otherwise pass the element check
        # by iterating its string keys, and a number is not iterable at all.
        if not isinstance(inputs, (list, tuple)) or not all(
            isinstance(x, str) for x in inputs
        ):
            return [{"error": "'inputs' must be a string or a list of strings."}]
        if not inputs:
            # Nothing to translate; avoid building an empty tokenizer batch.
            return []

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return [{"error": "'parameters' must be a JSON object."}]

        try:
            max_length = int(params.get("max_length", DEFAULT_MAX_LENGTH))
            num_beams = int(params.get("num_beams", DEFAULT_NUM_BEAMS))
            temperature = float(params.get("temperature", 1.0))
        except (TypeError, ValueError, OverflowError):
            return [{
                "error": "'max_length' and 'num_beams' must be integers "
                "and 'temperature' must be a number."
            }]
        if max_length < 1 or num_beams < 1:
            return [{"error": "'max_length' and 'num_beams' must be >= 1."}]
        do_sample = self._as_bool(params.get("do_sample", False))
        # Temperature scaling needs a strictly positive value when sampling;
        # the negated comparison also rejects NaN.
        if do_sample and not temperature > 0:
            return [{"error": "'temperature' must be > 0 when 'do_sample' is true."}]

        src_lang = params.get("src_lang", DEFAULT_SRC_LANG)
        tgt_lang = params.get("tgt_lang", DEFAULT_TGT_LANG)
        forced_bos_token_id = self._lang_token_id(tgt_lang)
        if forced_bos_token_id is None:
            return [{"error": f"Unknown target language code: {tgt_lang!r}"}]
        # Reject unknown source codes too, rather than handing them to the
        # tokenizer unchecked.
        if self._lang_token_id(src_lang) is None:
            return [{"error": f"Unknown source language code: {src_lang!r}"}]

        # NOTE(review): this mutates tokenizer state shared by every request. If
        # the serving stack ever calls the handler concurrently, requests could
        # race on src_lang. Confirm that the deployment serializes calls.
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            list(inputs),
            return_tensors="pt",
            padding=True,
            truncation=True,
            # The same max_length also bounds the generated output below.
            max_length=max_length,
        ).to(self.device)

        generate_kwargs: Dict[str, Any] = {
            "forced_bos_token_id": forced_bos_token_id,
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
        }
        if do_sample:
            # temperature only matters when sampling. Omitting it otherwise avoids
            # transformers' warning about sampling-only flags while do_sample is false.
            generate_kwargs["temperature"] = temperature

        with torch.inference_mode():
            generated = self.model.generate(**encoded, **generate_kwargs)

        decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
        return [{"translation_text": t} for t in decoded]