joseAndres777 committed on
Commit
c164643
·
verified ·
1 Parent(s): c0bc97c

change handler

Browse files
Files changed (1) hide show
  1. handler.py +7 -6
handler.py CHANGED
@@ -36,9 +36,9 @@ class EndpointHandler:
36
  inputs = data.get("inputs", "")
37
  parameters = data.get("parameters", {})
38
 
39
- # Prepare the conversation
40
  messages = [
41
- {"role": "system", "content": "Eres un asistente conversacional amigable especializado en conversaciones tipo WhatsApp en español."},
42
  {"role": "user", "content": inputs}
43
  ]
44
 
@@ -50,7 +50,7 @@ class EndpointHandler:
50
  add_generation_prompt=True
51
  )
52
  else:
53
- text = f"Usuario: {inputs}\nAsistente:"
54
 
55
  # Tokenize
56
  model_inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
@@ -59,11 +59,12 @@ class EndpointHandler:
59
  with torch.no_grad():
60
  outputs = self.model.generate(
61
  **model_inputs,
62
- max_new_tokens=parameters.get("max_new_tokens", 200),
63
- temperature=parameters.get("temperature", 0.7),
64
  top_p=parameters.get("top_p", 0.9),
65
  do_sample=True,
66
- pad_token_id=self.tokenizer.eos_token_id
 
67
  )
68
 
69
  # Decode response
 
36
  inputs = data.get("inputs", "")
37
  parameters = data.get("parameters", {})
38
 
39
+ # Prepare the conversation for text splitting task
40
  messages = [
41
+ {"role": "system", "content": "Split messages at natural breaks into JSON array. Common patterns: greeting+question, statement+question, topic+followup. Keep original words, only add logical splits."},
42
  {"role": "user", "content": inputs}
43
  ]
44
 
 
50
  add_generation_prompt=True
51
  )
52
  else:
53
+ text = f"Split messages at natural breaks into JSON array. Common patterns: greeting+question, statement+question, topic+followup. Keep original words, only add logical splits.\nUser: {inputs}\nAssistant:"
54
 
55
  # Tokenize
56
  model_inputs = self.tokenizer(text, return_tensors="pt").to(self.model.device)
 
59
  with torch.no_grad():
60
  outputs = self.model.generate(
61
  **model_inputs,
62
+ max_new_tokens=parameters.get("max_new_tokens", 100),
63
+ temperature=parameters.get("temperature", 0.3), # Lower for consistent JSON format
64
  top_p=parameters.get("top_p", 0.9),
65
  do_sample=True,
66
+ pad_token_id=self.tokenizer.eos_token_id,
67
+ repetition_penalty=1.1
68
  )
69
 
70
  # Decode response