# handler.py — custom Hugging Face Inference Endpoints handler (repo: LLM-2).
# Provenance: uploaded via "Upload 2 files" by Magjot, commit b798577 (verified).
from transformers import AutoTokenizer, AutoModelForCausalLM, TextGenerationPipeline
import torch
class EndpointHandler:
    """Custom Hugging Face Inference Endpoints handler for causal-LM text generation.

    Loads a tokenizer and model from ``path`` and serves requests of the form
    ``{"inputs": <prompt>, "parameters": {...}}``, returning
    ``{"generated_text": <str>}``.
    """

    def __init__(self, path=""):
        # path: local model directory provided by the Endpoints runtime.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        # device_map="auto" lets Accelerate place the weights (needed for LoRA
        # adapters / large models); fp16 on GPU, fp32 on CPU where fp16 is
        # slow or unsupported.
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            device_map="auto",  # loaded/dispatched by accelerate
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )
        # NOTE: do NOT pass device=... here — an accelerate-dispatched model
        # already owns its device placement, and passing a device crashes
        # the pipeline.
        self.pipeline = TextGenerationPipeline(
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def __call__(self, data):
        """Handle one inference request.

        Args:
            data: dict with ``"inputs"`` (prompt string) and an optional
                ``"parameters"`` dict of generation overrides.

        Returns:
            ``{"generated_text": <str>}`` — first sequence from the pipeline,
            or an empty string if the pipeline produced no output.
        """
        prompt = data.get("inputs", "")
        parameters = data.get("parameters", {}) or {}
        generation_args = {
            "max_new_tokens": parameters.get("max_new_tokens", 128),
            "temperature": parameters.get("temperature", 0.7),
            "top_p": parameters.get("top_p", 0.9),
            "do_sample": parameters.get("do_sample", True),
            "eos_token_id": self.tokenizer.eos_token_id,
            # Fix: many causal-LM tokenizers (Llama/GPT-style) define no pad
            # token, which makes generate() warn or fail — fall back to EOS.
            "pad_token_id": getattr(self.tokenizer, "pad_token_id", None)
            or self.tokenizer.eos_token_id,
        }
        outputs = self.pipeline(prompt, **generation_args)
        # Fix: guard an empty pipeline result instead of raising IndexError.
        if not outputs:
            return {"generated_text": ""}
        return {"generated_text": outputs[0]["generated_text"]}