Yong Liu committed on
Commit
ead8711
·
1 Parent(s): 0b6ae9b

handler.py updated

Browse files
Files changed (1) hide show
  1. handler.py +26 -1
handler.py CHANGED
@@ -237,11 +237,16 @@ class EndpointHandler:
237
  logger.info(f"Input prompt length: {len(input_text)} characters")
238
 
239
  # Generate one token at a time to avoid index errors
240
- max_steps = min(max_new_tokens, 250) # Limit to 250 tokens for reliability
 
241
  current_ids = input_ids.clone()
242
 
243
  logger.info(f"Generating up to {max_steps} tokens")
244
 
 
 
 
 
245
  for i in range(max_steps):
246
  if i % 50 == 0:
247
  logger.info(f"Generated {i} tokens so far")
@@ -284,6 +289,19 @@ class EndpointHandler:
284
  current_ids = torch.cat([current_ids, next_token], dim=-1)
285
  attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
287
  # Check if we've generated an EOS token
288
  if next_token[0, 0].item() == self.tokenizer.eos_token_id:
289
  logger.info(f"EOS token generated after {i+1} tokens")
@@ -298,6 +316,13 @@ class EndpointHandler:
298
  if len(split_text) > 1:
299
  response_text = split_text[1].strip()
300
  logger.info(f"Extracted assistant response: {len(response_text)} characters")
 
 
 
 
 
 
 
301
  else:
302
  # Fallback if the expected format is not found
303
  logger.warning("Could not find assistant tag in generated text")
 
237
  logger.info(f"Input prompt length: {len(input_text)} characters")
238
 
239
  # Generate one token at a time to avoid index errors
240
+ # Increase from 250 to 500 to allow for longer completions
241
+ max_steps = min(max_new_tokens, 500)
242
  current_ids = input_ids.clone()
243
 
244
  logger.info(f"Generating up to {max_steps} tokens")
245
 
246
+ # Keep track of last 5 tokens to detect repetition
247
+ last_tokens = []
248
+ repetition_detected = False
249
+
250
  for i in range(max_steps):
251
  if i % 50 == 0:
252
  logger.info(f"Generated {i} tokens so far")
 
289
  current_ids = torch.cat([current_ids, next_token], dim=-1)
290
  attention_mask = torch.cat([attention_mask, torch.ones_like(next_token)], dim=-1)
291
 
292
+ # Add to last tokens list for repetition detection
293
+ last_tokens.append(next_token.item())
294
+ if len(last_tokens) > 5:
295
+ last_tokens.pop(0)
296
+
297
+ # Check for repetition (if we have at least 5 tokens)
298
+ if len(last_tokens) >= 5:
299
+ # Check if all last 5 tokens are the same
300
+ if len(set(last_tokens)) == 1:
301
+ logger.warning(f"Repetition detected after {i+1} tokens, stopping generation")
302
+ repetition_detected = True
303
+ break
304
+
305
  # Check if we've generated an EOS token
306
  if next_token[0, 0].item() == self.tokenizer.eos_token_id:
307
  logger.info(f"EOS token generated after {i+1} tokens")
 
316
  if len(split_text) > 1:
317
  response_text = split_text[1].strip()
318
  logger.info(f"Extracted assistant response: {len(response_text)} characters")
319
+
320
+ # Check if the response text ends with a complete sentence
321
+ if not repetition_detected and not response_text.endswith(('.', '!', '?', ':', ';', '"', "'", ')', ']', '}')):
322
+ # Add an ellipsis to indicate truncation
323
+ response_text += "..."
324
+ logger.info("Added ellipsis to incomplete sentence")
325
+
326
  else:
327
  # Fallback if the expected format is not found
328
  logger.warning("Could not find assistant tag in generated text")