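# handler.py -- custom handler for a Hugging Face Inference Endpoint.
# Implements the EndpointHandler contract (construct with the model
# repository path, call with the request payload) for a causal LM.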
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Serve a causal language model behind an Inference Endpoint."""

    def __init__(self, path=""):
        # Load the tokenizer and model from the repository path that the
        # endpoint passes in at startup.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})

        # Handle chat format: a list of {"role": ..., "content": ...} messages
        # is rendered with the model's chat template; plain strings pass through.
        if isinstance(inputs, list) and len(inputs) > 0 and isinstance(inputs[0], dict):
            text = self.tokenizer.apply_chat_template(
                inputs,
                tokenize=False,
                add_generation_prompt=True,
            )
        else:
            text = inputs

        encoded = self.tokenizer(text, return_tensors="pt").to(self.model.device)
        # Default generation parameters; values in "parameters" override these.
        gen_kwargs = {
            "max_new_tokens": parameters.get("max_new_tokens", 512),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "pad_token_id": self.tokenizer.eos_token_id,
        }

        with torch.no_grad():
            outputs = self.model.generate(**encoded, **gen_kwargs)
        # Return only the generated part: slice off the prompt tokens before
        # decoding. A string offset (decoded[len(inputs):]) drifts whenever the
        # tokenizer round-trip alters the text, and does not work at all for
        # chat inputs, where the rendered template differs from the raw request.
        prompt_length = encoded["input_ids"].shape[-1]
        generated = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()

        return [{"generated_text": generated}]
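

if __name__ == "__main__":
    # Minimal local smoke test, a sketch rather than part of the endpoint
    # contract. The model path below is a placeholder assumption; point it
    # at any causal LM whose tokenizer defines a chat template to exercise
    # the handler outside of Inference Endpoints.
    handler = EndpointHandler(path="path/to/model")  # hypothetical path
    payload = {
        "inputs": [{"role": "user", "content": "Hello, who are you?"}],
        "parameters": {"max_new_tokens": 64},
    }
    print(handler(payload))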