|
|
|
|
|
import os

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for a seq2seq (text2text) model.

    Locates the real checkpoint directory (some packaging layouts nest the
    model one level below the mount point), loads the tokenizer and model,
    and serves generation requests through a transformers pipeline.
    """

    def __init__(self, model_dir: str):
        """Load tokenizer, model, and generation pipeline.

        Args:
            model_dir: Directory provided by the endpoint runtime. It either
                contains ``config.json`` directly, or holds exactly one
                subdirectory that does (in which case that subdirectory is
                used instead).
        """
        # Accept a nested checkpoint directory only when it is unambiguous:
        # exactly one subdirectory containing a config.json.
        candidates = [
            d for d in os.listdir(model_dir)
            if os.path.isdir(os.path.join(model_dir, d))
            and os.path.exists(os.path.join(model_dir, d, "config.json"))
        ]
        if len(candidates) == 1:
            real_dir = os.path.join(model_dir, candidates[0])
        else:
            # Zero or multiple matches: fall back to the mount point itself.
            real_dir = model_dir

        self.tokenizer = AutoTokenizer.from_pretrained(real_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(real_dir)

        # Bug fix: the original hard-coded device=0 (GPU 0), which raises on
        # CPU-only hosts. Use GPU 0 only when CUDA is actually available.
        device = 0 if torch.cuda.is_available() else -1

        self.generator = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=device,
            max_new_tokens=500,
            # NOTE(review): temperature has no effect unless do_sample=True is
            # set (here or per-request via "parameters"); left as-is so the
            # default greedy behavior is unchanged.
            temperature=0.7,
        )

    def __call__(self, payload: dict) -> list:
        """Handle one inference request.

        Args:
            payload: Request body. Expected shape is
                ``{"inputs": str, "parameters": dict}``; missing keys default
                to ``""`` and ``{}`` respectively, and ``parameters`` are
                forwarded verbatim to the pipeline as keyword arguments.

        Returns:
            The pipeline output (a list of ``{"generated_text": ...}`` dicts).
        """
        text = payload.get("inputs", "")
        params = payload.get("parameters", {})
        return self.generator(text, **params)
|
|
|