# SciGuru-zero / handler.py
# Uploaded by golyuval ("Upload 2 files", commit 11f4a01, verified)
# handler.py
from typing import Any, Dict, List
import os
from unsloth import FastLanguageModel
class EndpointHandler:
    """HuggingFace Inference Endpoints handler for an Unsloth 4-bit causal LM.

    Loads the model once at startup and serves text generation via
    ``__call__``. Configurable through environment variables:
    ``MAX_SEQ_LENGTH`` (default 1024) and ``MAX_NEW_TOKENS`` (default 64).
    """

    def __init__(self, model_id: str):
        """Called once at endpoint startup with the model repo ID/path.

        Args:
            model_id: HF Hub repo ID or local path of the model to load.
        """
        max_seq = int(os.getenv("MAX_SEQ_LENGTH", 1024))
        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_id,
            max_seq_length=max_seq,
            load_in_4bit=True,
        )
        # BUG FIX: unsloth requires switching to its optimized inference
        # mode before generate() is called (disables training-only paths).
        FastLanguageModel.for_inference(self.model)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate one completion per prompt.

        Args:
            data: ``{"inputs": "<str>"}`` or ``{"inputs": ["<str>", ...]}``.

        Returns:
            ``[{"generated_text": "<str>"}, ...]`` — one entry per prompt.

        Raises:
            ValueError: if ``inputs`` is neither a string nor a list.
        """
        inputs = data.get("inputs", data)
        if isinstance(inputs, str):
            prompts = [inputs]
        elif isinstance(inputs, list):
            prompts = inputs
        else:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")

        # Hoisted out of the loop: the env var cannot change between prompts.
        max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 64))

        outputs: List[Dict[str, Any]] = []
        for prompt in prompts:
            # BUG FIX: generate() expects token-ID tensors, not a raw
            # string — tokenize first and move onto the model's device.
            enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            out_ids = self.model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
            )
            # BUG FIX: the original appended the raw output tensor; the
            # documented contract is a decoded string.
            text = self.tokenizer.decode(out_ids[0], skip_special_tokens=True)
            outputs.append({"generated_text": text})
        return outputs