jla25 commited on
Commit
d6ed5c7
·
verified ·
1 Parent(s): 23c6a94

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -20,7 +20,7 @@ class EndpointHandler:
20
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
21
 
22
  # Prompt personalizado para guiar al modelo
23
- input_text = f"{data['inputs']}"
24
  print(f"Prompt generado para el modelo: {input_text}")
25
 
26
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
@@ -30,10 +30,10 @@ class EndpointHandler:
30
  generate_kwargs = {
31
  "max_length": 1024,
32
  "num_beams": 5,
33
- "do_sample": False,
34
  "temperature": 0.3,
35
  "top_k": 50,
36
- "top_p": 0.7,
37
  "repetition_penalty": 2.5
38
  }
39
  with torch.no_grad():
 
20
  raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor válido.")
21
 
22
  # Prompt personalizado para guiar al modelo
23
+ input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
24
  print(f"Prompt generado para el modelo: {input_text}")
25
 
26
  tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
 
30
  generate_kwargs = {
31
  "max_length": 1024,
32
  "num_beams": 5,
33
+ "do_sample": True,
34
  "temperature": 0.3,
35
  "top_k": 50,
36
+ "top_p": 0.8,
37
  "repetition_penalty": 2.5
38
  }
39
  with torch.no_grad():