yol146 committed on
Commit ·
290cf25
1
Parent(s): 72ed73b
modify the handler
Browse files- handler.py +8 -7
handler.py
CHANGED
|
@@ -106,7 +106,7 @@ class EndpointHandler:
|
|
| 106 |
|
| 107 |
# Tokenize the input safely
|
| 108 |
inputs = self.tokenizer(prompt, return_tensors="pt")
|
| 109 |
-
logger.info(f"Input tokens shape: {inputs
|
| 110 |
|
| 111 |
# Move to device
|
| 112 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
@@ -135,9 +135,10 @@ class EndpointHandler:
|
|
| 135 |
|
| 136 |
logger.info(f"Generating with config: {generation_config}")
|
| 137 |
|
|
|
|
| 138 |
outputs = self.model.generate(
|
| 139 |
-
inputs
|
| 140 |
-
attention_mask=inputs.
|
| 141 |
**generation_config
|
| 142 |
)
|
| 143 |
|
|
@@ -145,7 +146,7 @@ class EndpointHandler:
|
|
| 145 |
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 146 |
|
| 147 |
# Return only the newly generated text (without the prompt)
|
| 148 |
-
input_text = self.tokenizer.decode(inputs
|
| 149 |
|
| 150 |
if generated_text.startswith(input_text):
|
| 151 |
response_text = generated_text[len(input_text):]
|
|
@@ -168,8 +169,8 @@ class EndpointHandler:
|
|
| 168 |
|
| 169 |
# Set up generation in a separate thread
|
| 170 |
generation_kwargs = {
|
| 171 |
-
"input_ids": inputs
|
| 172 |
-
"attention_mask": inputs.
|
| 173 |
"streamer": streamer,
|
| 174 |
"max_new_tokens": max_new_tokens,
|
| 175 |
"temperature": temperature,
|
|
@@ -182,7 +183,7 @@ class EndpointHandler:
|
|
| 182 |
thread.start()
|
| 183 |
|
| 184 |
# Determine input text length to strip it from outputs
|
| 185 |
-
input_text = self.tokenizer.decode(inputs
|
| 186 |
|
| 187 |
# Stream the output
|
| 188 |
def generate_stream():
|
|
|
|
| 106 |
|
| 107 |
# Tokenize the input safely
|
| 108 |
inputs = self.tokenizer(prompt, return_tensors="pt")
|
| 109 |
+
logger.info(f"Input tokens shape: {inputs['input_ids'].shape}")
|
| 110 |
|
| 111 |
# Move to device
|
| 112 |
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
|
|
|
| 135 |
|
| 136 |
logger.info(f"Generating with config: {generation_config}")
|
| 137 |
|
| 138 |
+
# Fix: inputs is a dictionary, not an object with attributes
|
| 139 |
outputs = self.model.generate(
|
| 140 |
+
inputs["input_ids"],
|
| 141 |
+
attention_mask=inputs.get("attention_mask", None),
|
| 142 |
**generation_config
|
| 143 |
)
|
| 144 |
|
|
|
|
| 146 |
generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 147 |
|
| 148 |
# Return only the newly generated text (without the prompt)
|
| 149 |
+
input_text = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
| 150 |
|
| 151 |
if generated_text.startswith(input_text):
|
| 152 |
response_text = generated_text[len(input_text):]
|
|
|
|
| 169 |
|
| 170 |
# Set up generation in a separate thread
|
| 171 |
generation_kwargs = {
|
| 172 |
+
"input_ids": inputs["input_ids"],
|
| 173 |
+
"attention_mask": inputs.get("attention_mask", None),
|
| 174 |
"streamer": streamer,
|
| 175 |
"max_new_tokens": max_new_tokens,
|
| 176 |
"temperature": temperature,
|
|
|
|
| 183 |
thread.start()
|
| 184 |
|
| 185 |
# Determine input text length to strip it from outputs
|
| 186 |
+
input_text = self.tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
|
| 187 |
|
| 188 |
# Stream the output
|
| 189 |
def generate_stream():
|