File size: 1,499 Bytes
474b76e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | """HuggingFace Inference API handler for text normalization.
This enables the model to work with the HuggingFace Inference API
and the `text2text-generation` pipeline.
"""
import re

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
class EndpointHandler:
    """Custom HuggingFace Inference Endpoint handler for text normalization.

    Loads a seq2seq model from ``path`` plus a tokenizer and answers requests
    by running ``model.generate`` on the (optionally language-prefixed) text.
    """

    # A leading language tag such as "<de>" or "<pt-BR>". Only a tag of this
    # shape counts as "already prefixed"; text that merely starts with "<"
    # (e.g. "<3 dich") still gets an explicit ``language`` parameter applied.
    _LANGUAGE_TAG = re.compile(r"<[A-Za-z]{2,3}(?:[-_][A-Za-z0-9]+)*>")

    def __init__(self, path: str = "", tokenizer_name: str = "google/byt5-base"):
        """Load the model from ``path`` and the tokenizer.

        Args:
            path: Model directory / hub id passed to ``from_pretrained``.
            tokenizer_name: Tokenizer to load. Defaults to
                ``google/byt5-base``, which was previously hard-coded
                (presumably the model repo ships no tokenizer files — confirm).
        """
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Inference only: put the model in eval mode.
        self.model.eval()

    @classmethod
    def _add_language_prefix(cls, text, language):
        """Return ``text`` prefixed with ``"<language> "``.

        The text is returned unchanged when ``language`` is empty/None or the
        text already begins with a language tag.
        """
        if not language or cls._LANGUAGE_TAG.match(text):
            return text
        return f"<{language}> {text}"

    def __call__(self, data):
        """Handle an inference request.

        Expected input format::

            {"inputs": "<de> Das kostet 12,50 €."}

        or::

            {"inputs": "Das kostet 12,50 €.", "parameters": {"language": "de"}}

        ``inputs`` may also be a list of strings; they are then tokenized and
        generated as one padded batch.

        Supported ``parameters``: ``language`` (added as a ``<lang>`` prefix
        when the text has no language tag), ``max_new_tokens`` (default 512)
        and ``num_beams`` (default 1).

        Returns:
            A list with one ``{"generated_text": ...}`` dict per input text
            (a single-element list for a single string input).
        """
        # JSON ``null`` arrives as None; treat it like a missing key.
        params = data.get("parameters") or {}
        raw_inputs = data.get("inputs")
        if raw_inputs is None:
            raw_inputs = ""
        texts = [raw_inputs] if isinstance(raw_inputs, str) else list(raw_inputs)
        if not texts:
            # Nothing to normalize; avoid calling the tokenizer with no text.
            return []

        language = params.get("language")
        texts = [self._add_language_prefix(text, language) for text in texts]

        tokenized = self.tokenizer(
            texts,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            # Required to stack batch items of different lengths; a no-op for
            # a single input, so single-string requests behave as before.
            padding=True,
        )

        import torch

        with torch.no_grad():
            output = self.model.generate(
                **tokenized,
                max_new_tokens=params.get("max_new_tokens", 512),
                num_beams=params.get("num_beams", 1),
            )
        decoded = self.tokenizer.batch_decode(output, skip_special_tokens=True)
        return [{"generated_text": text} for text in decoded]
|