from transformers import AutoTokenizer, AutoModelForSeq2SeqLM import torch class EndpointHandler: def __init__(self, model_dir): """ Inicializa el handler con el modelo y tokenizador. """ # Cargar el tokenizador y el modelo desde el directorio proporcionado self.tokenizer = AutoTokenizer.from_pretrained(model_dir) self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir) self.model.eval() # Poner el modelo en modo evaluación def preprocess(self, data): """ Preprocesa los datos de entrada para el modelo. """ # Validar entrada if not isinstance(data, dict) or "inputs" not in data: raise ValueError("Entrada inválida. Debe ser un diccionario con la clave 'inputs'.") input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}" # Tokenizar entrada tokens = self.tokenizer( input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=512 ) return tokens def inference(self, inputs): """ Realiza la inferencia con el modelo. """ generate_kwargs = { "max_length": 512, "num_beams": 5, "do_sample": False, "temperature": 0.7, "top_k": 50, "top_p": 0.9, "repetition_penalty": 2.0, "early_stopping": True # Asegurar que no sea None } with torch.no_grad(): outputs = self.model.generate(**inputs, **generate_kwargs) return outputs def postprocess(self, model_outputs): """ Procesa las salidas del modelo para devolver resultados. """ # Decodificar la salida generada por el modelo decoded_output = self.tokenizer.decode(model_outputs[0], skip_special_tokens=True) return {"response": decoded_output} def __call__(self, data): """ Ejecuta el pipeline de preprocesamiento, inferencia y postprocesamiento. """ # Preprocesar entrada tokens = self.preprocess(data) # Realizar inferencia model_outputs = self.inference(tokens) # Postprocesar y devolver resultados return self.postprocess(model_outputs)