| """Custom inference handler for Hugging Face Inference Endpoints. |
| |
| NLLB needs a source-language code on the tokenizer and a forced BOS token |
| id for the target language at generation time, so the default translation |
| pipeline is not flexible enough. This handler accepts `src_lang` and |
| `tgt_lang` (NLLB Flores-200 codes, e.g. "eng_Latn", "spa_Latn") per |
| request. |
| |
| Request format: |
| { |
| "inputs": "Hello, world!", # str or List[str] |
| "parameters": { |
| "src_lang": "eng_Latn", # optional, default eng_Latn |
| "tgt_lang": "spa_Latn", # optional, default spa_Latn |
| "max_length": 256, # optional |
| "num_beams": 4, # optional |
| "temperature": 1.0, # optional |
| "do_sample": false # optional |
| } |
| } |
| |
| Response: List[{"translation_text": str}] on success, or |
| [{"error": str}] when the request fails validation. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Any, Dict, List, Union |
|
|
| import torch |
| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer |
|
|
# Fallback request parameters, used when a request leaves the matching key
# out of its "parameters" object.
DEFAULT_SRC_LANG = "eng_Latn"
DEFAULT_TGT_LANG = "spa_Latn"
DEFAULT_MAX_LENGTH = 256
DEFAULT_NUM_BEAMS = 4


class EndpointHandler:
    """Translate text with a seq2seq model, selecting languages per request.

    The handler loads a tokenizer and model once in ``__init__`` and is then
    called with a request dict of the form::

        {"inputs": str | list[str], "parameters": {...optional overrides...}}

    Recognised ``parameters`` keys are ``src_lang``, ``tgt_lang``,
    ``max_length``, ``num_beams``, ``do_sample`` and ``temperature``.
    Invalid requests are answered with ``[{"error": <message>}]`` rather than
    an exception.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model found at ``path``.

        Args:
            path: Model directory or identifier handed to ``from_pretrained``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only requested on GPU; CPU keeps float32.
        dtype = torch.float16 if self.device == "cuda" else torch.float32

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)
        self.model.eval()

    @staticmethod
    def _error(message: str) -> List[Dict[str, str]]:
        """Wrap ``message`` in the handler's error response shape."""
        return [{"error": message}]

    @staticmethod
    def _generation_kwargs(params: Dict[str, Any]) -> Dict[str, Any]:
        """Parse and validate the decoding options found in ``params``.

        Returns:
            Keyword arguments for ``generate`` (``max_length``, ``num_beams``,
            ``do_sample`` and, only when sampling, ``temperature``).

        Raises:
            ValueError: If a value cannot be coerced or is out of range. The
                message is suitable for returning to the caller.
        """
        try:
            max_length = int(params.get("max_length", DEFAULT_MAX_LENGTH))
            num_beams = int(params.get("num_beams", DEFAULT_NUM_BEAMS))
            temperature = float(params.get("temperature", 1.0))
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(
                "'max_length' and 'num_beams' must be integers and "
                "'temperature' must be a number."
            ) from exc
        if max_length < 1 or num_beams < 1:
            raise ValueError("'max_length' and 'num_beams' must be >= 1.")

        do_sample = bool(params.get("do_sample", False))
        kwargs: Dict[str, Any] = {
            "max_length": max_length,
            "num_beams": num_beams,
            "do_sample": do_sample,
        }
        if do_sample:
            # ``not x > 0`` also rejects NaN, which ``x <= 0`` would let through.
            if not temperature > 0:
                raise ValueError(
                    "'temperature' must be greater than 0 when 'do_sample' "
                    "is true."
                )
            # Temperature only matters for sampling, so it is forwarded only
            # in that mode instead of being passed as an unused flag.
            kwargs["temperature"] = temperature
        return kwargs

    def _lang_token_id(self, code: Any) -> Union[int, None]:
        """Return the vocabulary id of language token ``code``, or ``None``.

        ``None`` is returned when ``code`` is not a non-empty string, when the
        tokenizer lookup fails, or when the lookup yields the unknown-token id.
        """
        if not isinstance(code, str) or not code:
            return None
        try:
            token_id = self.tokenizer.convert_tokens_to_ids(code)
        except Exception:
            # Any lookup failure at this request boundary is reported to the
            # caller as an unknown language code.
            return None
        if token_id is None or token_id == self.tokenizer.unk_token_id:
            return None
        return token_id

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Translate the texts in ``data["inputs"]``.

        Args:
            data: Request payload with ``"inputs"`` (a string or a list of
                strings) and an optional ``"parameters"`` dict.

        Returns:
            One ``{"translation_text": ...}`` dict per input text, in input
            order (an empty list for an empty input list), or a single-item
            list ``[{"error": ...}]`` when the request is invalid.
        """
        inputs = data.get("inputs")
        if inputs is None:
            return self._error("Missing 'inputs' field.")
        if isinstance(inputs, str):
            inputs = [inputs]
        # Insist on a real sequence: a bare number would make ``all`` raise,
        # and a dict would be judged by its keys alone.
        if not isinstance(inputs, (list, tuple)) or not all(
            isinstance(text, str) for text in inputs
        ):
            return self._error(
                "'inputs' must be a string or a list of strings."
            )
        texts = list(inputs)

        params = data.get("parameters") or {}
        if not isinstance(params, dict):
            return self._error("'parameters' must be an object.")

        try:
            generate_kwargs = self._generation_kwargs(params)
        except ValueError as exc:
            return self._error(str(exc))

        tgt_lang = params.get("tgt_lang", DEFAULT_TGT_LANG)
        forced_bos_token_id = self._lang_token_id(tgt_lang)
        if forced_bos_token_id is None:
            return self._error(f"Unknown target language code: {tgt_lang!r}")

        src_lang = params.get("src_lang", DEFAULT_SRC_LANG)
        if self._lang_token_id(src_lang) is None:
            return self._error(f"Unknown source language code: {src_lang!r}")

        if not texts:
            # Nothing to translate; skip tokenization and generation.
            return []

        # NOTE(review): this mutates tokenizer state shared by every call on
        # this handler; confirm the serving runtime invokes it serially.
        self.tokenizer.src_lang = src_lang
        encoded = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            # The same limit caps both the source length and the output.
            max_length=generate_kwargs["max_length"],
        ).to(self.device)

        with torch.inference_mode():
            generated = self.model.generate(
                **encoded,
                forced_bos_token_id=forced_bos_token_id,
                **generate_kwargs,
            )

        decoded = self.tokenizer.batch_decode(
            generated, skip_special_tokens=True
        )
        return [{"translation_text": text} for text in decoded]
|
|