from typing import Any, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


class EndpointHandler:
    """Inference handler that serves a causal language model.

    The tokenizer and model are loaded once from ``path`` at construction time;
    each call to the handler generates a continuation for one prompt.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model.

        Args:
            path: Location passed to ``from_pretrained`` for both the tokenizer
                and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
        self.model.eval()

    @staticmethod
    def _get_param(data: Dict[str, Any], name: str, default: Any) -> Any:
        """Look up a generation parameter in the request payload.

        Top-level keys take precedence (the original request format). Otherwise
        the value is read from an optional nested ``"parameters"`` dict, i.e. the
        conventional ``{"inputs": ..., "parameters": {...}}`` payload shape.
        Falls back to ``default`` when neither location has the key.
        """
        if name in data:
            return data[name]
        nested = data.get("parameters")
        if isinstance(nested, dict) and name in nested:
            return nested[name]
        return default

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate a continuation for a single prompt.

        Args:
            data: Request payload. Must contain ``"inputs"`` holding the prompt
                text (expected to be a single string). Optional generation
                parameters may appear at the top level or inside a nested
                ``"parameters"`` dict (top-level wins):

                - ``max_new_tokens``: max tokens to generate (default 512)
                - ``temperature``: sampling temperature (default 0.7); a value
                  <= 0 switches to greedy decoding
                - ``top_p``: nucleus sampling probability (default 0.9)
                - ``do_sample``: whether to sample (default True)

                The payload dict is not modified.

        Returns:
            ``{"generated_text": str}`` with only the newly generated text; the
            prompt is not echoed back.

        Raises:
            ValueError: if the payload has no ``"inputs"`` entry.
        """
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError(
                "Request payload must contain an 'inputs' field with the prompt text."
            )

        max_new_tokens = self._get_param(data, "max_new_tokens", 512)
        temperature = self._get_param(data, "temperature", 0.7)
        top_p = self._get_param(data, "top_p", 0.9)
        do_sample = self._get_param(data, "do_sample", True)

        # JSON payloads can deliver integral numbers (e.g. "temperature": 2);
        # transformers' temperature warper only accepts floats, so normalize.
        if max_new_tokens is not None:
            max_new_tokens = int(max_new_tokens)
        if temperature is not None:
            temperature = float(temperature)
        if top_p is not None:
            top_p = float(top_p)

        # A non-positive temperature is invalid for sampling; interpret it as a
        # request for greedy decoding instead of failing the whole request.
        if do_sample and temperature is not None and temperature <= 0:
            do_sample = False

        # Keep the tokenizer's attention mask: pad_token_id is set to EOS below,
        # so generate() should not have to infer the mask from input_ids alone.
        encoded = self.tokenizer(inputs, return_tensors="pt").to(self.model.device)
        input_ids = encoded["input_ids"]

        with torch.no_grad():
            outputs = self.model.generate(
                input_ids,
                attention_mask=encoded.get("attention_mask"),
                max_new_tokens=max_new_tokens,
                # Sampling-only knobs are passed as None when not sampling so
                # they are not applied to greedy decoding.
                temperature=temperature if do_sample else None,
                top_p=top_p if do_sample else None,
                do_sample=do_sample,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # The output sequence starts with the prompt tokens; keep only the
        # tokens generated after them.
        generated_tokens = outputs[0][input_ids.shape[1]:]
        generated_text = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
        return {"generated_text": generated_text}