# handler.py
# Minimal Hugging Face Inference Endpoint handler for text2text models
# (e.g., FLAN-T5). Loads the model once at cold start and serves
# __call__ for inference.

from typing import Any, Dict, List

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        HF passes `path` pointing to the repo files mounted in the container.
        We load tokenizer + model from that path once at cold start.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()
        # Sensible defaults -- any of these can be overridden per request
        # via the payload's "parameters" dict.
        self.gen_defaults: Dict[str, Any] = {
            "max_new_tokens": 128,
            "do_sample": False,
            "temperature": 1.0,
            "top_p": 1.0,
            "num_beams": 1,
        }

    def _generate(self, texts: List[str], params: Dict[str, Any]) -> List[str]:
        """Tokenize `texts`, run generation, and decode to plain strings.

        `params` overrides entries in `self.gen_defaults`; keys outside the
        known set are silently ignored (only the five defaults are forwarded
        to `model.generate`).
        """
        p = {**self.gen_defaults, **(params or {})}
        enc = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)
        # No gradient tracking needed for pure inference.
        with torch.no_grad():
            out_ids = self.model.generate(
                **enc,
                max_new_tokens=int(p["max_new_tokens"]),
                do_sample=bool(p["do_sample"]),
                temperature=float(p["temperature"]),
                top_p=float(p["top_p"]),
                num_beams=int(p["num_beams"]),
            )
        return self.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Accepts several common payload shapes and returns:
            {"outputs": ["str", ...]}
        or {"error": "..."} for malformed payloads.
        """
        if data is None:
            return {"error": "No payload provided."}

        # Accept "inputs", "input", or "texts" (first one present wins).
        # Check for presence (skipping only None) rather than truthiness:
        # the old `a or b or c` chain silently rejected falsy-but-valid
        # values such as "" or [].
        raw = None
        for key in ("inputs", "input", "texts"):
            if data.get(key) is not None:
                raw = data[key]
                break
        if raw is None:
            return {"error": "Provide 'inputs' (str or list of str)."}

        # Normalize to list[str].
        if isinstance(raw, str):
            texts = [raw]
        elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
            texts = raw
        else:
            return {"error": "inputs must be str or list[str]."}

        # Nothing to generate for an empty batch; the tokenizer raises on
        # an empty list, so short-circuit with an empty result instead.
        if not texts:
            return {"outputs": []}

        params = data.get("parameters", {})  # optional generation params
        if not isinstance(params, dict):
            # Fail with a structured error instead of a TypeError from the
            # dict-merge inside _generate.
            return {"error": "'parameters' must be an object/dict."}

        outputs = self._generate(texts, params)
        return {"outputs": outputs}