| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import torch |
| import json |
|
|
class EndpointHandler:
    """Custom inference handler for a seq2seq text-to-JSON extraction model.

    Implements the HuggingFace Inference Endpoints handler protocol: the
    endpoint constructs this class once with the model directory and then
    invokes the instance per request via ``__call__``.

    Expected request payload::

        {"inputs": "<free text>", "parameters": {...optional generate kwargs...}}

    Response payload::

        {"generated_text": "<decoded model output>"}
    """

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir* and switch to eval mode.

        Args:
            model_dir: Path (or hub id) of the fine-tuned seq2seq checkpoint.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        # Inference only: disable dropout / batch-norm training behavior.
        self.model.eval()

    def preprocess(self, data):
        """Validate the request payload and tokenize the prompt.

        Args:
            data: Request dict; must contain a string under the key
                ``"inputs"``.

        Returns:
            A ``BatchEncoding`` of PyTorch tensors ready for ``generate``.

        Raises:
            ValueError: If *data* is not a dict with a string ``"inputs"``.
        """
        # Guard: reject any payload that is not {"inputs": <str>} up front so
        # the failure is a clear 4xx-style error instead of a tokenizer crash.
        if not (isinstance(data, dict) and isinstance(data.get("inputs"), str)):
            raise ValueError("Esperando un diccionario con la clave 'inputs'")

        # Task prefix the model was fine-tuned with; must match training.
        input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]

        # Truncation keeps over-long requests within the model's max length;
        # padding is a no-op for a single sequence but harmless.
        return self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)

    def inference(self, tokens, gen_kwargs=None):
        """Run generation on tokenized input.

        Args:
            tokens: ``BatchEncoding`` produced by :meth:`preprocess`.
            gen_kwargs: Optional dict of ``model.generate`` keyword arguments
                (e.g. ``max_new_tokens``, ``num_beams``). Defaults to ``{}``,
                which preserves the original default-generation behavior.

        Returns:
            Tensor of generated token ids, shape (batch, seq_len).
        """
        # no_grad: skip autograd bookkeeping — pure inference.
        with torch.no_grad():
            return self.model.generate(**tokens, **(gen_kwargs or {}))

    def postprocess(self, outputs):
        """Decode the first generated sequence into the response payload.

        Args:
            outputs: Generated token-id tensor from :meth:`inference`.

        Returns:
            Dict with a single ``"generated_text"`` string entry.
        """
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": decoded_output}

    def __call__(self, data):
        """Handle one request end-to-end: validate, generate, decode.

        Args:
            data: Request dict with ``"inputs"`` (str) and optionally
                ``"parameters"`` (dict of generation kwargs).

        Returns:
            ``{"generated_text": ...}`` response dict.
        """
        tokens = self.preprocess(data)
        # Optional per-request generation settings (HF endpoint convention);
        # absent or empty -> identical to the previous fixed behavior.
        gen_kwargs = data.get("parameters") or {} if isinstance(data, dict) else {}
        outputs = self.inference(tokens, gen_kwargs)
        return self.postprocess(outputs)
|
|