Update handler.py
Browse files- handler.py +10 -24
handler.py
CHANGED
|
@@ -4,8 +4,8 @@ import json
|
|
| 4 |
|
| 5 |
class EndpointHandler:
|
| 6 |
def __init__(self, model_dir):
|
| 7 |
-
self.tokenizer = AutoTokenizer.from_pretrained("
|
| 8 |
-
self.model = AutoModelForSeq2SeqLM.from_pretrained("
|
| 9 |
self.model.eval()
|
| 10 |
|
| 11 |
def preprocess(self, data):
|
|
@@ -15,37 +15,23 @@ class EndpointHandler:
|
|
| 15 |
# Prompt personalizado para guiar al modelo
|
| 16 |
input_text = (
|
| 17 |
f"""
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
- Opciones para 'id': firstName, lastName, jobTitle, address, email, phone, notes, roleFunction.
|
| 21 |
-
- Si 'id' es address, email o phone, debe incluir subclaves: MOBILE, WORK, PERSONAL, MAIN, OTHER.
|
| 22 |
-
- 'roleFunction' debe ser una de estas: BUYER, SELLER, SUPPLIER, PARTNER, COLLABORATOR, PROVIDER, CUSTOMER.
|
| 23 |
-
Ejemplo:
|
| 24 |
-
Entrada: "Contacté a Juan Pérez, Gerente de Finanzas."
|
| 25 |
-
Salida esperada:
|
| 26 |
-
{{
|
| 27 |
-
"values": [
|
| 28 |
-
{{"id": "firstName", "value": "Juan"}},
|
| 29 |
-
{{"id": "lastName", "value": "Pérez"}},
|
| 30 |
-
{{"id": "jobTitle", "value": "Gerente de Finanzas"}}
|
| 31 |
-
]
|
| 32 |
-
}}
|
| 33 |
-
Procesa este texto: "{data['inputs']}"
|
| 34 |
""")
|
| 35 |
# Imprimir el texto generado para el prompt
|
| 36 |
print(f"Prompt generado para el modelo: {input_text}")
|
| 37 |
-
tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=
|
| 38 |
return tokens
|
| 39 |
|
| 40 |
def inference(self, tokens):
|
| 41 |
generate_kwargs = {
|
| 42 |
-
"max_length":
|
| 43 |
-
"num_beams":
|
| 44 |
"do_sample": False,
|
| 45 |
-
"temperature": 0.
|
| 46 |
-
"top_k":
|
| 47 |
"top_p": 0.7,
|
| 48 |
-
"repetition_penalty": 2.
|
| 49 |
}
|
| 50 |
with torch.no_grad():
|
| 51 |
outputs = self.model.generate(**tokens, **generate_kwargs)
|
|
|
|
| 4 |
|
| 5 |
class EndpointHandler:
|
| 6 |
def __init__(self, model_dir):
|
| 7 |
+
self.tokenizer = AutoTokenizer.from_pretrained("jla25/squareV3")
|
| 8 |
+
self.model = AutoModelForSeq2SeqLM.from_pretrained("jla25/squareV3")
|
| 9 |
self.model.eval()
|
| 10 |
|
| 11 |
def preprocess(self, data):
|
|
|
|
| 15 |
# Prompt personalizado para guiar al modelo
|
| 16 |
input_text = (
|
| 17 |
f"""
|
| 18 |
+
### Procesa el siguiente texto y genera un JSON válido:
|
| 19 |
+
"{data['inputs']}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
""")
|
| 21 |
# Imprimir el texto generado para el prompt
|
| 22 |
print(f"Prompt generado para el modelo: {input_text}")
|
| 23 |
+
tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
|
| 24 |
return tokens
|
| 25 |
|
| 26 |
def inference(self, tokens):
|
| 27 |
generate_kwargs = {
|
| 28 |
+
"max_length": 1024,
|
| 29 |
+
"num_beams": 5,
|
| 30 |
"do_sample": False,
|
| 31 |
+
"temperature": 0.3,
|
| 32 |
+
"top_k": 50,
|
| 33 |
"top_p": 0.7,
|
| 34 |
+
"repetition_penalty": 2.5
|
| 35 |
}
|
| 36 |
with torch.no_grad():
|
| 37 |
outputs = self.model.generate(**tokens, **generate_kwargs)
|