from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class EndpointHandler:
    def __init__(self, model_dir):
        """
        Initialize the handler with the model and tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()
    def preprocess(self, data):
        """
        Preprocess the input data for the model.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Invalid input. It must be a dictionary with the key 'inputs'.")

        # Wrap the raw text in an instruction prompt for JSON extraction.
        input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}"

        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        return tokens
    def inference(self, inputs):
        """
        Run inference with the model.
        """
        # Deterministic beam search. Sampling-only arguments (temperature,
        # top_k, top_p) are omitted: they are ignored when do_sample=False
        # and only produce warnings from transformers.
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            "repetition_penalty": 2.0,
            "early_stopping": True,
        }
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generate_kwargs)
        return outputs
    def postprocess(self, model_outputs):
        """
        Process the model outputs and return the result.
        """
        decoded_output = self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
        return {"response": decoded_output}
    def __call__(self, data):
        """
        Run the preprocessing, inference, and postprocessing pipeline.
        """
        tokens = self.preprocess(data)
        model_outputs = self.inference(tokens)
        return self.postprocess(model_outputs)
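
# Minimal local usage sketch (not part of the handler contract). It assumes
# "./model" is a placeholder for wherever the fine-tuned seq2seq checkpoint
# lives, and mirrors the {"inputs": ...} payload that preprocess() expects.
if __name__ == "__main__":
    handler = EndpointHandler(model_dir="./model")  # hypothetical local path
    sample_payload = {"inputs": "Order #123 from Jane Doe, 2 items, total 45.90 EUR."}
    result = handler(sample_payload)
    print(result["response"])  # decoded text, expected to be a JSON string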