|
|
| from typing import Any, Dict, List
|
| import os
|
| from unsloth import FastLanguageModel
|
|
|
class EndpointHandler:
    """Inference-endpoint handler wrapping a 4-bit Unsloth language model.

    Loads the model/tokenizer once at construction, then serves text
    generation through ``__call__`` with the standard HF endpoint payload
    shape ({"inputs": str | list[str]}).
    """

    def __init__(self, model_id: str):
        """Load *model_id* in 4-bit and switch it into inference mode.

        Environment:
            MAX_SEQ_LENGTH: maximum sequence length (default 1024).
        """
        max_seq = int(os.getenv("MAX_SEQ_LENGTH", 1024))

        self.model, self.tokenizer = FastLanguageModel.from_pretrained(
            model_id,
            max_seq_length=max_seq,
            load_in_4bit=True,
        )
        # Enable Unsloth's optimized inference path; without this the model
        # stays in training mode and generation is slower.
        FastLanguageModel.for_inference(self.model)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data: {"inputs": "<str>"} or {"inputs": ["<str>", ...]}
        returns: [{"generated_text": "<str>"}, ...]
        """
        inputs = data.get("inputs", data)

        if isinstance(inputs, str):
            prompts = [inputs]
        elif isinstance(inputs, list):
            prompts = inputs
        else:
            raise ValueError(f"Unsupported inputs type: {type(inputs)}")

        # Loop-invariant: read the generation budget once, not per prompt.
        max_new_tokens = int(os.getenv("MAX_NEW_TOKENS", 64))

        outputs: List[Dict[str, Any]] = []
        for prompt in prompts:
            # generate() takes token-id tensors, not raw strings:
            # tokenize first and move the tensors onto the model's device.
            enc = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            out_ids = self.model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                pad_token_id=self.tokenizer.eos_token_id,
            )

            # generate() echoes the prompt tokens; slice them off and decode
            # only the continuation so "generated_text" is a plain string
            # as the contract above promises.
            gen_ids = out_ids[0][enc["input_ids"].shape[1]:]
            text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
            outputs.append({"generated_text": text})

        return outputs
|
|
|