File size: 1,431 Bytes
6718e1a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class EndpointHandler:
    """Text-generation handler wrapping a causal LM loaded from ``path``.

    The ``__init__(path)`` / ``__call__(data)`` shape appears to follow the
    Hugging Face Inference Endpoints custom-handler convention (assumption —
    confirm against the deployment configuration).
    """

    def __init__(self, path: str):
        """Load the tokenizer and model from *path*.

        The model is loaded in bfloat16 with ``device_map="auto"``. Both loads
        pass ``trust_remote_code=True``, so code shipped with the checkpoint is
        executed — only point this at trusted checkpoints.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    def _encode(self, inputs):
        """Tokenize *inputs*; return ``(input_ids, attention_mask)`` on the model device.

        A list is treated as chat messages and rendered with the tokenizer's
        chat template; anything else is tokenized as a raw prompt.
        ``attention_mask`` is ``None`` if the tokenizer did not produce one.
        """
        if isinstance(inputs, list):
            prompt = self.tokenizer.apply_chat_template(
                inputs,
                tokenize=False,
                add_generation_prompt=True,
            )
            # Chat templates typically render the model's special tokens
            # (e.g. BOS) themselves; letting the tokenizer add them again
            # would duplicate them in the prompt.
            add_special_tokens = False
        else:
            prompt = inputs
            add_special_tokens = True

        encoded = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=add_special_tokens,
        )
        device = self.model.device
        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)
        return input_ids, attention_mask

    def __call__(self, data):
        """Generate text for a single request.

        Args:
            data: request dict. ``data["inputs"]`` is either a prompt string or
                a list of chat messages for ``apply_chat_template``; if the key
                is missing, *data* itself is used as the inputs.
                ``data["parameters"]`` (optional, may be ``None``) accepts
                ``max_new_tokens`` (default 512), ``temperature`` (default 0.7;
                a value ``<= 0`` selects greedy decoding) and
                ``return_full_text`` (default ``True``; ``False`` returns only
                the tokens produced after the prompt).

        Returns:
            ``[{"generated_text": text}]``, decoded with special tokens skipped.
        """
        inputs = data.get("inputs", data)
        # A JSON body with ``"parameters": null`` must not crash the handler.
        parameters = data.get("parameters") or {}

        input_ids, attention_mask = self._encode(inputs)

        temperature = parameters.get("temperature", 0.7)
        do_sample = temperature > 0
        generate_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Temperature only matters when sampling; passing it alongside
            # greedy decoding makes transformers warn about an unused flag.
            generate_kwargs["temperature"] = temperature
        if attention_mask is not None:
            # pad_token_id equals eos_token_id, so generate() cannot reliably
            # infer the mask on its own — pass the tokenizer's mask explicitly.
            generate_kwargs["attention_mask"] = attention_mask

        with torch.no_grad():
            outputs = self.model.generate(input_ids, **generate_kwargs)

        sequence = outputs[0]
        if not parameters.get("return_full_text", True):
            # For a causal LM the output starts with the prompt tokens;
            # drop them to return only the continuation.
            sequence = sequence[input_ids.shape[-1]:]
        generated = self.tokenizer.decode(sequence, skip_special_tokens=True)
        return [{"generated_text": generated}]
|