# Source: Hugging Face repo file "handler.py" (commit 14b3daf, verified), uploaded by ClergeF.
# handler.py
# Minimal Hugging Face Inference Endpoint handler for text2text models (e.g., FLAN-T5)
# Loads the model once at startup and serves /__call__ for inference.
from typing import Any, Dict, List, Union
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
class EndpointHandler:
    """Minimal Hugging Face Inference Endpoint handler for text2text
    (seq2seq) models such as FLAN-T5.

    The tokenizer and model are loaded once at cold start; each request
    is served through :meth:`__call__`, which returns
    ``{"outputs": [str, ...]}`` on success or ``{"error": msg}`` on
    malformed input.
    """

    # Payload keys accepted for the input text(s), checked in this order.
    _INPUT_KEYS = ("inputs", "input", "texts")

    def __init__(self, path: str = ""):
        """Load tokenizer + model from ``path`` once at cold start.

        Args:
            path: Directory with the repo files; HF mounts the model
                repository into the container and passes its path here.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()
        # Sensible defaults — any of these can be overridden per request
        # via the payload's "parameters" field.
        self.gen_defaults = {
            "max_new_tokens": 128,
            "do_sample": False,
            "temperature": 1.0,
            "top_p": 1.0,
            "num_beams": 1,
        }

    def _generate(self, texts: List[str], params: Dict[str, Any]) -> List[str]:
        """Tokenize ``texts``, run generation, and decode the outputs.

        Args:
            texts: Non-empty batch of input strings.
            params: Per-request overrides merged over ``self.gen_defaults``.

        Returns:
            One decoded output string per input text.
        """
        p = {**self.gen_defaults, **(params or {})}
        enc = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        with torch.no_grad():
            out_ids = self.model.generate(
                **enc,
                max_new_tokens=int(p["max_new_tokens"]),
                do_sample=bool(p["do_sample"]),
                temperature=float(p["temperature"]),
                top_p=float(p["top_p"]),
                num_beams=int(p["num_beams"]),
            )
        return self.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one inference request.

        Accepts several common payload shapes ("inputs", "input", or
        "texts"; value is a str or list[str]) and returns::

            {"outputs": ["str", ...]}

        or ``{"error": msg}`` when the payload is malformed.
        """
        if data is None:
            return {"error": "No payload provided."}
        # BUGFIX: the original used `data.get("inputs") or data.get(...)`,
        # which treated falsy-but-present values ("" or []) as missing and
        # rejected them.  Look each key up explicitly instead.
        raw = None
        for key in self._INPUT_KEYS:
            if key in data and data[key] is not None:
                raw = data[key]
                break
        if raw is None:
            return {"error": "Provide 'inputs' (str or list of str)."}
        # Normalize to list[str]
        if isinstance(raw, str):
            texts = [raw]
        elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
            texts = raw
        else:
            return {"error": "inputs must be str or list[str]."}
        if not texts:
            # BUGFIX: an empty batch would make the tokenizer raise;
            # short-circuit to an empty result instead.
            return {"outputs": []}
        params = data.get("parameters", {})  # optional generation params
        if not isinstance(params, dict):
            # BUGFIX: a non-dict value used to crash in `{**params}`.
            return {"error": "parameters must be an object (dict)."}
        outputs = self._generate(texts, params)
        return {"outputs": outputs}