OzTianlu commited on
Commit
cd370de
·
verified ·
1 Parent(s): 6198884

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -7
handler.py CHANGED
@@ -86,19 +86,22 @@ class EndpointHandler:
86
  item = item["messages"]
87
 
88
  if _is_messages(item):
89
- # Chat template path exists in repo; tokenizer.apply_chat_template will use it if configured
90
- input_ids = self.tokenizer.apply_chat_template(
91
  item,
92
- return_tensors="pt",
93
  add_generation_prompt=True,
94
  )
 
 
 
95
  else:
96
- if not isinstance(item, str):
97
- item = str(item)
98
- enc = self.tokenizer(item, return_tensors="pt")
99
  input_ids = enc["input_ids"]
100
-
101
  input_ids = input_ids.to(self.model.device)
 
 
 
102
  input_len = input_ids.shape[-1]
103
 
104
  gen_ids = self.model.generate(
 
86
  item = item["messages"]
87
 
88
  if _is_messages(item):
89
+ rendered = self.tokenizer.apply_chat_template(
 
90
  item,
91
+ tokenize=False,
92
  add_generation_prompt=True,
93
  )
94
+ enc = self.tokenizer(rendered, return_tensors="pt")
95
+ input_ids = enc["input_ids"]
96
+ attention_mask = enc.get("attention_mask", None)
97
  else:
98
+ enc = self.tokenizer(str(item), return_tensors="pt")
 
 
99
  input_ids = enc["input_ids"]
100
+ attention_mask = enc.get("attention_mask", None)
101
  input_ids = input_ids.to(self.model.device)
102
+ if attention_mask is not None:
103
+ attention_mask = attention_mask.to(self.model.device)
104
+
105
  input_len = input_ids.shape[-1]
106
 
107
  gen_ids = self.model.generate(