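"""Custom handler for Hugging Face Inference Endpoints.

Loads a causal language model and tokenizer from the repository path and
serves text generation requests with a simple {"inputs": ...} payload.
"""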
from typing import Dict, Any
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

class EndpointHandler:
def __init__(self, path: str = ""):
# load model and tokenizer from path
self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,  # halves memory vs. fp32; needs bf16-capable hardware
            device_map="auto",  # place layers across available devices (requires accelerate)
        )
self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
"""
Args:
data: dictionary with 'inputs' key containing the prompt text
optional keys:
- max_new_tokens: max tokens to generate (default 512)
- temperature: sampling temperature (default 0.7)
- top_p: nucleus sampling probability (default 0.9)
- do_sample: whether to sample (default True)
Returns:
dictionary with 'generated_text' key
"""
        # extract the prompt; fall back to the raw payload if no "inputs" key is present
        inputs = data.pop("inputs", data)
# generation parameters
max_new_tokens = data.pop("max_new_tokens", 512)
temperature = data.pop("temperature", 0.7)
top_p = data.pop("top_p", 0.9)
do_sample = data.pop("do_sample", True)
        # tokenize; keep the attention mask so generate() can tell real tokens from padding
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
# generate
with torch.no_grad():
outputs = self.model.generate(
                **encoded,  # passes input_ids and attention_mask
max_new_tokens=max_new_tokens,
                # pass sampling knobs only when sampling, to avoid generate() warnings
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,  # many causal LMs define no pad token
)
        # decode only the newly generated tokens, skipping the echoed prompt
        generated_tokens = outputs[0][encoded.input_ids.shape[1]:]
generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
return {"generated_text": generated_text}
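

if __name__ == "__main__":
    # Minimal local smoke test: a sketch of how the handler is invoked, assuming
    # a checkpoint at the hypothetical path "./model". Inference Endpoints calls
    # EndpointHandler(path) once at startup and then handler(payload) per request
    # in the same way.
    handler = EndpointHandler(path="./model")
    payload = {"inputs": "Hello, my name is", "max_new_tokens": 32, "temperature": 0.7}
    print(handler(payload)["generated_text"])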