# flan-t5-pilgrim-full / handler.py
# chaima01's picture
# Update handler.py
# 9d37bf0 verified
# handler.py
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
class EndpointHandler:
    """Inference-endpoint handler for a fine-tuned seq2seq (FLAN-T5) model.

    Loads the tokenizer and model from ``model_dir`` — or from a single
    nested checkpoint subfolder, if the repo was uploaded that way — and
    serves ``text2text-generation`` requests.
    """

    def __init__(self, model_dir: str):
        """Load the checkpoint and build the generation pipeline.

        Args:
            model_dir: Directory the endpoint mounts the repo into. If it
                contains exactly one subdirectory that holds a
                ``config.json``, that subdirectory is treated as the real
                checkpoint folder.
        """
        # Some repos upload the checkpoint into a subfolder; detect that:
        # if exactly one directory inside model_dir has a config.json,
        # assume it is the actual checkpoint folder.
        candidates = [
            d for d in os.listdir(model_dir)
            if os.path.isdir(os.path.join(model_dir, d))
            and os.path.exists(os.path.join(model_dir, d, "config.json"))
        ]
        if len(candidates) == 1:
            real_dir = os.path.join(model_dir, candidates[0])
        else:
            real_dir = model_dir

        # Load from the folder that actually has the fine-tuned files.
        self.tokenizer = AutoTokenizer.from_pretrained(real_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(real_dir)

        # BUG FIX: the original hard-coded device=0, which raises on
        # CPU-only hosts. Use GPU 0 when CUDA is available, else CPU (-1).
        import torch  # local import; transformers already depends on torch
        device = 0 if torch.cuda.is_available() else -1

        self.generator = pipeline(
            "text2text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            device=device,
            max_new_tokens=500,  # default generation budget for every call
            # BUG FIX: temperature is silently ignored (with a warning)
            # unless sampling is enabled; the 0.7 default clearly intends
            # sampled generation.
            do_sample=True,
            temperature=0.7,
        )

    def __call__(self, payload: dict) -> list:
        """Handle one inference request.

        Args:
            payload: Dict with ``"inputs"`` (the prompt string; defaults to
                ``""``) and optional ``"parameters"`` (per-call generation
                overrides forwarded to the pipeline).

        Returns:
            The pipeline's output list (one dict with ``generated_text``
            per generated sequence).
        """
        text = payload.get("inputs", "")
        params = payload.get("parameters", {})
        # Per-call parameters override the pipeline-level defaults.
        return self.generator(text, **params)