|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, List, Union |
|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
|
class EndpointHandler:
    """Custom handler for Hugging Face Inference Endpoints.

    Loads a seq2seq tokenizer + model once at cold start and serves
    batched text-to-text generation. Requests look like
    ``{"inputs": str | list[str], "parameters": {...}}`` and responses are
    ``{"outputs": [str, ...]}`` (or ``{"error": "..."}`` for bad payloads).
    """

    def __init__(self, path: str = ""):
        """Load tokenizer and model from ``path``.

        HF passes `path` pointing to the repo files mounted in the
        container. We load tokenizer + model from that path once at
        cold start.

        Args:
            path: Directory containing the model repository files.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        self.model.to(self.device)
        self.model.eval()

        # Generation defaults; any of these may be overridden per-request
        # via the payload's "parameters" object.
        self.gen_defaults: Dict[str, Any] = {
            "max_new_tokens": 128,
            "do_sample": False,
            "temperature": 1.0,
            "top_p": 1.0,
            "num_beams": 1,
        }

    def _generate(self, texts: List[str], params: Dict[str, Any]) -> List[str]:
        """Run batched generation and decode the results.

        Args:
            texts: Non-empty batch of input strings.
            params: Per-request overrides, merged on top of ``gen_defaults``.

        Returns:
            One decoded output string per input, special tokens stripped.
        """
        p = {**self.gen_defaults, **(params or {})}
        enc = self.tokenizer(
            texts, return_tensors="pt", padding=True, truncation=True
        ).to(self.device)

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            out_ids = self.model.generate(
                **enc,
                max_new_tokens=int(p["max_new_tokens"]),
                do_sample=bool(p["do_sample"]),
                temperature=float(p["temperature"]),
                top_p=float(p["top_p"]),
                num_beams=int(p["num_beams"]),
            )
        return self.tokenizer.batch_decode(out_ids, skip_special_tokens=True)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Handle one request.

        Accepts several common payload shapes and returns:
            {"outputs": ["str", ...]}
        or {"error": "..."} when the payload is malformed.
        """
        if data is None:
            return {"error": "No payload provided."}
        if not isinstance(data, dict):
            # A raw string/list body would crash on .get() below; reject it
            # cleanly instead of producing a 500.
            return {"error": "Payload must be a JSON object."}

        # BUGFIX: the previous `a or b or c` chain treated present-but-falsy
        # values ("" or []) as missing. Check each key explicitly, skipping
        # only None, so an empty string still flows through validation and
        # a None value still falls back to the next accepted key.
        raw: Union[str, List[str], None] = None
        for key in ("inputs", "input", "texts"):
            value = data.get(key)
            if value is not None:
                raw = value
                break
        if raw is None:
            return {"error": "Provide 'inputs' (str or list of str)."}

        if isinstance(raw, str):
            texts = [raw]
        elif isinstance(raw, list) and all(isinstance(x, str) for x in raw):
            texts = raw
        else:
            return {"error": "inputs must be str or list[str]."}

        if not texts:
            # Nothing to generate — and tokenizing an empty batch raises.
            return {"outputs": []}

        params = data.get("parameters", {})
        if params is None:
            params = {}
        if not isinstance(params, dict):
            # A non-dict would blow up in the {**defaults, **params} merge.
            return {"error": "'parameters' must be an object."}

        return {"outputs": self._generate(texts, params)}
|
|
|