| from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
|
| import torch
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for causal-LM text generation.

    Loads a tokenizer and model from *path* at startup and serves generation
    requests through ``__call__``.
    """

    def __init__(self, path=""):
        """Load the tokenizer, model, and generation pipeline from *path*.

        Args:
            path: Local directory (or hub id) containing the model artifacts.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # Many causal LMs ship without a pad token; generate() then warns on
        # every call and padding-dependent code paths can fail. Reuse EOS as
        # the pad token — the standard fix for decoder-only models.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",
            # fp16 only makes sense on GPU; CPU inference stays in fp32.
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )

        self.pipeline = TextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: Request payload. ``data["inputs"]`` is the prompt string
                (defaults to ""); ``data["parameters"]`` (optional, may be
                ``None``) overrides the generation defaults below.

        Returns:
            ``{"generated_text": ...}`` holding the pipeline's first candidate.
        """
        prompt = data.get("inputs", "")
        # "parameters" may be present but null in the request body.
        parameters = data.get("parameters", {}) or {}

        generation_args = {
            "max_new_tokens": parameters.get("max_new_tokens", 128),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "eos_token_id": self.tokenizer.eos_token_id,
            # Explicit pad token id silences the per-call
            # "Setting `pad_token_id` to `eos_token_id`" warning.
            "pad_token_id": self.tokenizer.pad_token_id,
        }

        outputs = self.pipeline(prompt, **generation_args)
        return {"generated_text": outputs[0]["generated_text"]}
|
|
|