Stop criteria relaxed (original title: "stop criteria esnetildi")
#32
by
ismailhakki37
- opened
- handler.py +16 -3
handler.py
CHANGED
|
@@ -347,13 +347,23 @@ def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, ma
|
|
| 347 |
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 348 |
).unsqueeze(0).to(our_chatbot.model.device)
|
| 349 |
|
| 350 |
-
# Set up stopping criteria
|
| 351 |
stop_str = (
|
| 352 |
our_chatbot.conversation.sep
|
| 353 |
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
|
| 354 |
else our_chatbot.conversation.sep2
|
| 355 |
)
|
| 356 |
-
|
| 357 |
stopping_criteria = KeywordsStoppingCriteria(
|
| 358 |
keywords, our_chatbot.tokenizer, input_ids
|
| 359 |
)
|
|
@@ -370,6 +380,9 @@ def generate_response(message_text, image_input, temperature=0.05, top_p=1.0, ma
|
|
| 370 |
repetition_penalty=repetition_penalty,
|
| 371 |
use_cache=False,
|
| 372 |
stopping_criteria=[stopping_criteria],
|
|
|
|
|
|
|
|
|
|
| 373 |
)
|
| 374 |
|
| 375 |
# Decode response
|
|
@@ -553,7 +566,7 @@ def query(payload):
|
|
| 553 |
top_p = float(payload.get("top_p", 1.0))
|
| 554 |
max_output_tokens = int(payload.get("max_output_tokens",
|
| 555 |
payload.get("max_new_tokens",
|
| 556 |
-
payload.get("max_tokens",
|
| 557 |
repetition_penalty = float(payload.get("repetition_penalty", 1.0))
|
| 558 |
conv_mode_override = payload.get("conv_mode", None)
|
| 559 |
|
|
|
|
| 347 |
prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
| 348 |
).unsqueeze(0).to(our_chatbot.model.device)
|
| 349 |
|
| 350 |
+
# Set up stopping criteria - more flexible to allow longer responses
|
| 351 |
stop_str = (
|
| 352 |
our_chatbot.conversation.sep
|
| 353 |
if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
|
| 354 |
else our_chatbot.conversation.sep2
|
| 355 |
)
|
| 356 |
+
|
| 357 |
+
# Use minimal stopping criteria to allow longer responses
|
| 358 |
+
keywords = []
|
| 359 |
+
if stop_str and stop_str.strip():
|
| 360 |
+
keywords.append(stop_str)
|
| 361 |
+
|
| 362 |
+
# Only add very basic stopping criteria to prevent infinite generation
|
| 363 |
+
if not keywords:
|
| 364 |
+
keywords = ["</s>", "<s>"]
|
| 365 |
+
|
| 366 |
+
print(f"[DEBUG] Using stopping criteria: {keywords}")
|
| 367 |
stopping_criteria = KeywordsStoppingCriteria(
|
| 368 |
keywords, our_chatbot.tokenizer, input_ids
|
| 369 |
)
|
|
|
|
| 380 |
repetition_penalty=repetition_penalty,
|
| 381 |
use_cache=False,
|
| 382 |
stopping_criteria=[stopping_criteria],
|
| 383 |
+
pad_token_id=our_chatbot.tokenizer.eos_token_id,
|
| 384 |
+
eos_token_id=our_chatbot.tokenizer.eos_token_id,
|
| 385 |
+
length_penalty=1.0, # Don't penalize longer sequences
|
| 386 |
)
|
| 387 |
|
| 388 |
# Decode response
|
|
|
|
| 566 |
top_p = float(payload.get("top_p", 1.0))
|
| 567 |
max_output_tokens = int(payload.get("max_output_tokens",
|
| 568 |
payload.get("max_new_tokens",
|
| 569 |
+
payload.get("max_tokens", 8192))))
|
| 570 |
repetition_penalty = float(payload.get("repetition_penalty", 1.0))
|
| 571 |
conv_mode_override = payload.get("conv_mode", None)
|
| 572 |
|