Files changed (1) hide show
  1. handler.py +16 -3
handler.py CHANGED
@@ -347,13 +347,23 @@ def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, ma
347
  prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
348
  ).unsqueeze(0).to(our_chatbot.model.device)
349
 
350
- # Set up stopping criteria
351
  stop_str = (
352
  our_chatbot.conversation.sep
353
  if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
354
  else our_chatbot.conversation.sep2
355
  )
356
- keywords = [stop_str]
 
 
 
 
 
 
 
 
 
 
357
  stopping_criteria = KeywordsStoppingCriteria(
358
  keywords, our_chatbot.tokenizer, input_ids
359
  )
@@ -370,6 +380,9 @@ def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, ma
370
  repetition_penalty=repetition_penalty,
371
  use_cache=False,
372
  stopping_criteria=[stopping_criteria],
 
 
 
373
  )
374
 
375
  # Decode response
@@ -553,7 +566,7 @@ def query(payload):
553
  top_p = float(payload.get("top_p", 1.0))
554
  max_output_tokens = int(payload.get("max_output_tokens",
555
  payload.get("max_new_tokens",
556
- payload.get("max_tokens", 4096))))
557
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
558
  conv_mode_override = payload.get("conv_mode", None)
559
 
 
347
  prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
348
  ).unsqueeze(0).to(our_chatbot.model.device)
349
 
350
+ # Choose the stop string: the conversation's sep, or sep2 when its separator style is TWO
351
  stop_str = (
352
  our_chatbot.conversation.sep
353
  if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
354
  else our_chatbot.conversation.sep2
355
  )
356
+
357
+ # Use the separator as a stop keyword only when it is non-empty after stripping whitespace
358
+ keywords = []
359
+ if stop_str and stop_str.strip():
360
+ keywords.append(stop_str)
361
+
362
+ # Fallback when no separator qualified: stop on the literal "</s>" / "<s>" strings so generation still has a stop keyword
363
+ if not keywords:
364
+ keywords = ["</s>", "<s>"]
365
+
366
+ print(f"[DEBUG] Using stopping criteria: {keywords}")
367
  stopping_criteria = KeywordsStoppingCriteria(
368
  keywords, our_chatbot.tokenizer, input_ids
369
  )
 
380
  repetition_penalty=repetition_penalty,
381
  use_cache=False,
382
  stopping_criteria=[stopping_criteria],
383
+ pad_token_id=our_chatbot.tokenizer.eos_token_id,
384
+ eos_token_id=our_chatbot.tokenizer.eos_token_id,
385
+ length_penalty=1.0, # Don't penalize longer sequences
386
  )
387
 
388
  # Decode response
 
566
  top_p = float(payload.get("top_p", 1.0))
567
  max_output_tokens = int(payload.get("max_output_tokens",
568
  payload.get("max_new_tokens",
569
+ payload.get("max_tokens", 8192))))
570
  repetition_penalty = float(payload.get("repetition_penalty", 1.0))
571
  conv_mode_override = payload.get("conv_mode", None)
572