| from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
| import torch | |
class EndpointHandler:
    """Hugging Face Inference Endpoints custom handler for a seq2seq
    (text2text-generation) model such as FLAN-T5."""

    def __init__(self, path="google/flan-t5-large"):
        """Load the tokenizer and model once at endpoint startup.

        Args:
            path (str): Model id on the Hub or local path to the weights.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path)
        # Inference only: eval() disables dropout so outputs are not
        # perturbed by training-mode stochasticity.
        self.model.eval()

    def __call__(self, data):
        """Process one inference request.

        Args:
            data (dict | str): Request payload. The standard HF shape is
                {"inputs": <str>, "parameters": {<generate kwargs>}}.
                A bare string is accepted as the prompt. For backward
                compatibility, generation kwargs may also appear as
                top-level keys alongside "inputs".

        Returns:
            dict: {"generated_text": <str>} — the first decoded sequence.
        """
        if isinstance(data, dict):
            # Copy so we never mutate the caller's payload.
            data = dict(data)
            inputs = data.pop("inputs", "")
        else:
            # Bare-string payload: treat it as the prompt itself.
            inputs, data = data, {}

        # Default generation parameters; request values override these.
        parameters = {
            "max_length": 512,
            "min_length": 32,
            "temperature": 0.9,
            "top_p": 0.95,
            "top_k": 50,
            "do_sample": True,
            "num_return_sequences": 1,
        }
        # Standard HF payloads nest overrides under "parameters". The old
        # code merged the raw payload into generate() kwargs, so a nested
        # "parameters" dict would be passed to generate() and crash.
        parameters.update(data.pop("parameters", None) or {})
        # Legacy top-level overrides (anything left besides "inputs").
        parameters.update(data)

        input_ids = self.tokenizer(inputs, return_tensors="pt").input_ids
        # inference_mode skips autograd graph construction: less memory,
        # faster generation, identical outputs.
        with torch.inference_mode():
            outputs = self.model.generate(input_ids, **parameters)
        generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": generated_text}