from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json


model_name = "jla25/squareV3"

# Loaded at import time; the EndpointHandler below reloads its own tokenizer and model from model_dir.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


class EndpointHandler:
    def __init__(self, model_dir):
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
            raise ValueError("The input must be a dictionary with an 'inputs' key and a valid value.")

        input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
        print(f"Prompt generated for the model: {input_text}")
        input_text = input_text.encode("utf-8").decode("utf-8")
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
        return tokens

    def inference(self, tokens):
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            # Note: temperature, top_k and top_p only take effect when do_sample=True;
            # with do_sample=False they are ignored during beam search.
            "temperature": 0.3,
            "top_k": 50,
            "top_p": 0.8,
            "early_stopping": True,
            "repetition_penalty": 2.5
        }
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        # Keep only the substring between the first '{' and the last '}';
        # fall back to the raw text if no braces are found.
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            return output

    def postprocess(self, outputs):
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        print(f"Text generated by the model: {decoded_output}")
        print(f"Cleaned JSON: {cleaned_output}")

        return {"response": cleaned_output}

    def __call__(self, data):
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result


handler = EndpointHandler(model_name)
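

# Illustrative local smoke test, not part of the original handler contract: the
# Inference Endpoints runtime normally instantiates EndpointHandler itself and calls
# it with a payload of the form {"inputs": "..."}. The sample text below is made up.
if __name__ == "__main__":
    sample_request = {"inputs": "John Smith ordered three laptops on 2024-05-01 for ACME Corp."}
    response = handler(sample_request)
    print(response["response"])

    # The model is expected (but not guaranteed) to return valid JSON, so parse defensively.
    try:
        parsed = json.loads(response["response"])
        print(parsed)
    except json.JSONDecodeError:
        print("Model output was not valid JSON.")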