yol146 committed on
Commit
290cf25
·
1 Parent(s): 72ed73b

modify the handler

Browse files
Files changed (1) hide show
  1. handler.py +8 -7
handler.py CHANGED
@@ -106,7 +106,7 @@ class EndpointHandler:
106
 
107
  # Tokenize the input safely
108
  inputs = self.tokenizer(prompt, return_tensors="pt")
109
- logger.info(f"Input tokens shape: {inputs.input_ids.shape}")
110
 
111
  # Move to device
112
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
@@ -135,9 +135,10 @@ class EndpointHandler:
135
 
136
  logger.info(f"Generating with config: {generation_config}")
137
 
 
138
  outputs = self.model.generate(
139
- inputs.input_ids,
140
- attention_mask=inputs.attention_mask if hasattr(inputs, 'attention_mask') else None,
141
  **generation_config
142
  )
143
 
@@ -145,7 +146,7 @@ class EndpointHandler:
145
  generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
146
 
147
  # Return only the newly generated text (without the prompt)
148
- input_text = self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
149
 
150
  if generated_text.startswith(input_text):
151
  response_text = generated_text[len(input_text):]
@@ -168,8 +169,8 @@ class EndpointHandler:
168
 
169
  # Set up generation in a separate thread
170
  generation_kwargs = {
171
- "input_ids": inputs.input_ids,
172
- "attention_mask": inputs.attention_mask if hasattr(inputs, 'attention_mask') else None,
173
  "streamer": streamer,
174
  "max_new_tokens": max_new_tokens,
175
  "temperature": temperature,
@@ -182,7 +183,7 @@ class EndpointHandler:
182
  thread.start()
183
 
184
  # Determine input text length to strip it from outputs
185
- input_text = self.tokenizer.decode(inputs.input_ids[0], skip_special_tokens=True)
186
 
187
  # Stream the output
188
  def generate_stream():
 
106
 
107
  # Tokenize the input safely
108
  inputs = self.tokenizer(prompt, return_tensors="pt")
109
+ logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
110
 
111
  # Move to device
112
  inputs = {k: v.to(self.device) for k, v in inputs.items()}
 
135
 
136
  logger.info(f"Generating with config: {generation_config}")
137
 
138
+ # Fix: inputs is a dictionary, not an object with attributes
139
  outputs = self.model.generate(
140
+ inputs["input_ids"],
141
+ attention_mask=inputs.get("attention_mask", None),
142
  **generation_config
143
  )
144
 
 
146
  generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
147
 
148
  # Return only the newly generated text (without the prompt)
149
+ input_text = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
150
 
151
  if generated_text.startswith(input_text):
152
  response_text = generated_text[len(input_text):]
 
169
 
170
  # Set up generation in a separate thread
171
  generation_kwargs = {
172
+ "input_ids": inputs["input_ids"],
173
+ "attention_mask": inputs.get("attention_mask", None),
174
  "streamer": streamer,
175
  "max_new_tokens": max_new_tokens,
176
  "temperature": temperature,
 
183
  thread.start()
184
 
185
  # Determine input text length to strip it from outputs
186
+ input_text = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
187
 
188
  # Stream the output
189
  def generate_stream():