Update handler.py
Browse files- handler.py +22 -8
handler.py
CHANGED
|
@@ -4,20 +4,34 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
| 4 |
|
| 5 |
class EndpointHandler:
|
| 6 |
def __init__(self, model_dir: str):
|
| 7 |
-
#
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
#
|
|
|
|
|
|
|
|
|
|
| 12 |
self.generator = pipeline(
|
| 13 |
"text2text-generation",
|
| 14 |
-
model=model,
|
| 15 |
-
tokenizer=tokenizer,
|
| 16 |
-
device=0
|
|
|
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
def __call__(self, payload: dict) -> list:
|
| 20 |
-
# receive {"inputs": "...", "parameters": {...}}
|
| 21 |
text = payload.get("inputs", "")
|
| 22 |
params = payload.get("parameters", {})
|
|
|
|
| 23 |
return self.generator(text, **params)
|
|
|
|
| 4 |
|
| 5 |
class EndpointHandler:
|
| 6 |
def __init__(self, model_dir: str):
|
| 7 |
+
# some repos upload into a subfolder; detect that:
|
| 8 |
+
# if there's exactly one directory in model_dir that itself has config.json,
|
| 9 |
+
# assume that's the real checkpoint folder.
|
| 10 |
+
candidates = [
|
| 11 |
+
d for d in os.listdir(model_dir)
|
| 12 |
+
if os.path.isdir(os.path.join(model_dir, d))
|
| 13 |
+
and os.path.exists(os.path.join(model_dir, d, "config.json"))
|
| 14 |
+
]
|
| 15 |
+
if len(candidates) == 1:
|
| 16 |
+
real_dir = os.path.join(model_dir, candidates[0])
|
| 17 |
+
else:
|
| 18 |
+
real_dir = model_dir
|
| 19 |
|
| 20 |
+
# now load from the folder that *actually* has your fine‑tuned files
|
| 21 |
+
self.tokenizer = AutoTokenizer.from_pretrained(real_dir)
|
| 22 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained(real_dir)
|
| 23 |
+
# build the pipeline on GPU if available
|
| 24 |
self.generator = pipeline(
|
| 25 |
"text2text-generation",
|
| 26 |
+
model=self.model,
|
| 27 |
+
tokenizer=self.tokenizer,
|
| 28 |
+
device=0, # GPU
|
| 29 |
+
max_new_tokens=500, # defaults you want in every call
|
| 30 |
+
temperature=0.7
|
| 31 |
)
|
| 32 |
|
| 33 |
def __call__(self, payload: dict) -> list:
|
|
|
|
| 34 |
text = payload.get("inputs", "")
|
| 35 |
params = payload.get("parameters", {})
|
| 36 |
+
# Calls the pipeline with whatever per-call overrides you pass
|
| 37 |
return self.generator(text, **params)
|