# handler.py
from typing import Any, Dict, List
import os
from unsloth import FastLanguageModel
class EndpointHandler:
    """Inference-endpoint handler: loads a 4-bit Unsloth model once at startup
    and serves text generation for single or batched string prompts."""

    def __init__(self, model_id: str):
        """Called once at endpoint startup with the model repo ID/path.

        Env vars:
            MAX_SEQ_LENGTH: context window for the loaded model (default 1024).
        """
        max_seq = int(os.getenv("MAX_SEQ_LENGTH", 1024))
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_id,
            max_seq_length=max_seq,
            load_in_4bit=True,
        )
        # Switch Unsloth model into optimized inference mode (2x faster
        # generate; also disables gradient tracking).
        FastLanguageModel.for_inference(self.model)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data: {"inputs": "<str>"} or {"inputs": ["<str>", ...]}
        returns: [{"generated_text": "<str>"}, ...]

        Raises:
            ValueError: if "inputs" is neither a string nor a list.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            prompts = [inputs]
        elif isinstance(inputs, list):
            prompts = inputs
        else:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")

        # Hoisted out of the loop: same for every prompt in this request.
        max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 64))

        outputs: List[Dict[str, Any]] = []
        for prompt in prompts:
            # generate() requires token IDs, not a raw string: tokenize and
            # move tensors to the model's device first.
            enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            out_ids = self.model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # generate() returns prompt + continuation; slice off the prompt
            # tokens and decode only the newly generated text.
            prompt_len = enc["input_ids"].shape[1]
            text = self.tokenizer.decode(
                out_ids[0][prompt_len:],
                skip_special_tokens=True,
            )
            outputs.append({"generated_text": text})
        return outputs