jla25 committed on
Commit
d161eae
verified
1 Parent(s): 86fc284

Create handler.py

Browse files
Files changed (1) hide show
  1. handler.py +70 -0
handler.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ import torch
3
+ import json
4
+
5
+
6
# Hub repository id of the fine-tuned seq2seq model to serve.
model_name = "jla25/squareV4"

# NOTE(review): the tokenizer and model are loaded again inside
# EndpointHandler.__init__ (which is called with model_name at the bottom
# of this file), so everything is downloaded/loaded twice at import time.
# Consider dropping these module-level copies — but first verify nothing
# else imports `tokenizer`/`model` from this module.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
+
11
+
12
class EndpointHandler:
    """Inference endpoint handler: turns free-form text into a JSON string.

    Pipeline: validate/tokenize the request (``preprocess``), run beam-search
    generation (``inference``), then decode and trim the output to the first
    balanced-looking JSON span (``postprocess``).
    """

    def __init__(self, model_dir):
        """Load tokenizer and model from *model_dir* and put the model in eval mode.

        *model_dir* may be a local path or a Hub repo id accepted by
        ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        """Validate the request payload and tokenize the generation prompt.

        Expects *data* to be a dict with a non-None ``"inputs"`` value;
        raises ValueError otherwise. Returns the tokenizer's tensor batch.
        """
        if not isinstance(data, dict) or data.get("inputs") is None:
            raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")

        # Custom prompt that instructs the model to emit JSON.
        input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
        print(f"Prompt generado para el modelo: {input_text}")
        # NOTE: padding to max_length pads a single request to 1024 tokens;
        # the attention mask makes this harmless, but it wastes compute.
        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=1024,
        )
        return tokens

    def inference(self, tokens):
        """Run deterministic beam-search generation; returns output token ids."""
        # With do_sample=False generation is pure beam search; sampling-only
        # knobs (temperature / top_k / top_p) are ignored by transformers and
        # only produce warnings, so they are intentionally not passed.
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            "early_stopping": True,
            "repetition_penalty": 2.5,
        }
        # no_grad: inference only — skip autograd bookkeeping.
        with torch.no_grad():
            outputs = self.model.generate(**tokens, **generate_kwargs)
        return outputs

    def clean_output(self, output):
        """Trim *output* to the span from the first '{' to the last '}'.

        Returns the text unchanged when no brace pair is found. Note this is
        a heuristic: it does not guarantee the span is valid JSON.
        """
        try:
            start_index = output.index("{")
            end_index = output.rindex("}") + 1
            return output[start_index:end_index]
        except ValueError:
            # No '{' or '}' present — return the raw text untouched.
            return output

    def postprocess(self, outputs):
        """Decode generated ids and wrap the cleaned text in a response dict."""
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        cleaned_output = self.clean_output(decoded_output)

        # Always print the generated text for debugging.
        print(f"Texto generado por el modelo: {decoded_output}")
        print(f"JSON limpiado: {cleaned_output}")

        return {"response": cleaned_output}

    def __call__(self, data):
        """Full request cycle: preprocess -> generate -> postprocess."""
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
67
+
68
+
69
# Create the handler instance at import time — presumably the serving
# runtime (e.g. a Hugging Face Inference Endpoint) looks up this
# module-level `handler` name; verify against the deployment config.
# NOTE(review): this passes the Hub repo id, not a local model directory,
# so the model is (re-)downloaded here in addition to the module-level load.
handler = EndpointHandler(model_name)