Yong Liu committed on
Commit ·
ead8711
1
Parent(s): 0b6ae9b
handler.py updated
Browse files- handler.py +26 -1
handler.py
CHANGED
|
@@ -237,11 +237,16 @@ class EndpointHandler:
|
|
| 237 |
logger.info(f"Input prompt length: {len(input_text)} characters")
|
| 238 |
|
| 239 |
# Generate one token at a time to avoid index errors
|
| 240 |
-
|
|
|
|
| 241 |
current_ids = input_ids.clone()
|
| 242 |
|
| 243 |
logger.info(f"Generating up to {max_steps} tokens")
|
| 244 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
for i in range(max_steps):
|
| 246 |
if i % 50 == 0:
|
| 247 |
logger.info(f"Generated {i} tokens so far")
|
|
@@ -284,6 +289,19 @@ class EndpointHandler:
|
|
| 284 |
current_ids = torch.cat([current_ids, next_token], dim=-1)
|
| 285 |
attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
|
| 286 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
# Check if we've generated an EOS token
|
| 288 |
if next_token[0, 0].item() == self.tokenizer.eos_token_id:
|
| 289 |
logger.info(f"EOS token generated after {i+1} tokens")
|
|
@@ -298,6 +316,13 @@ class EndpointHandler:
|
|
| 298 |
if len(split_text) > 1:
|
| 299 |
response_text = split_text[1].strip()
|
| 300 |
logger.info(f"Extracted assistant response: {len(response_text)} characters")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
else:
|
| 302 |
# Fallback if the expected format is not found
|
| 303 |
logger.warning("Could not find assistant tag in generated text")
|
|
|
|
| 237 |
logger.info(f"Input prompt length: {len(input_text)} characters")
|
| 238 |
|
| 239 |
# Generate one token at a time to avoid index errors
|
| 240 |
+
# Increase from 250 to 500 to allow for longer completions
|
| 241 |
+
max_steps = min(max_new_tokens, 500)
|
| 242 |
current_ids = input_ids.clone()
|
| 243 |
|
| 244 |
logger.info(f"Generating up to {max_steps} tokens")
|
| 245 |
|
| 246 |
+
# Keep track of last 5 tokens to detect repetition
|
| 247 |
+
last_tokens = []
|
| 248 |
+
repetition_detected = False
|
| 249 |
+
|
| 250 |
for i in range(max_steps):
|
| 251 |
if i % 50 == 0:
|
| 252 |
logger.info(f"Generated {i} tokens so far")
|
|
|
|
| 289 |
current_ids = torch.cat([current_ids, next_token], dim=-1)
|
| 290 |
attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
|
| 291 |
|
| 292 |
+
# Add to last tokens list for repetition detection
|
| 293 |
+
last_tokens.append(next_token.item())
|
| 294 |
+
if len(last_tokens) > 5:
|
| 295 |
+
last_tokens.pop(0)
|
| 296 |
+
|
| 297 |
+
# Check for repetition (if we have at least 5 tokens)
|
| 298 |
+
if len(last_tokens) >= 5:
|
| 299 |
+
# Check if all last 5 tokens are the same
|
| 300 |
+
if len(set(last_tokens)) == 1:
|
| 301 |
+
logger.warning(f"Repetition detected after {i+1} tokens, stopping generation")
|
| 302 |
+
repetition_detected = True
|
| 303 |
+
break
|
| 304 |
+
|
| 305 |
# Check if we've generated an EOS token
|
| 306 |
if next_token[0, 0].item() == self.tokenizer.eos_token_id:
|
| 307 |
logger.info(f"EOS token generated after {i+1} tokens")
|
|
|
|
| 316 |
if len(split_text) > 1:
|
| 317 |
response_text = split_text[1].strip()
|
| 318 |
logger.info(f"Extracted assistant response: {len(response_text)} characters")
|
| 319 |
+
|
| 320 |
+
# Check if the response text ends with a complete sentence
|
| 321 |
+
if not repetition_detected and not response_text.endswith(('.', '!', '?', ':', ';', '"', "'", ')', ']', '}')):
|
| 322 |
+
# Add an ellipsis to indicate truncation
|
| 323 |
+
response_text += "..."
|
| 324 |
+
logger.info("Added ellipsis to incomplete sentence")
|
| 325 |
+
|
| 326 |
else:
|
| 327 |
# Fallback if the expected format is not found
|
| 328 |
logger.warning("Could not find assistant tag in generated text")
|