Spaces:
Running on CPU Upgrade
Running on CPU Upgrade
Remove trailing punctuation 80% of the time
Browse files- chat_application/main.py +14 -8
chat_application/main.py
CHANGED
|
@@ -62,6 +62,7 @@ class datasetHandler():
|
|
| 62 |
CHAT_CONTEXT = 20 #how many messages from chat history to append to inference prompt
|
| 63 |
#minimum number of chars where we start checking for duplicate messages
|
| 64 |
DUP_LEN = 25 #since short messages may reasonably be the same
|
|
|
|
| 65 |
|
| 66 |
# Directory alignment
|
| 67 |
BASE_DIR = Path(__file__).resolve().parent
|
|
@@ -368,8 +369,10 @@ def ask_bot(room_id, bot, bot_display_name, initial_prompt, instruct_prompt):
|
|
| 368 |
print("PASSED")
|
| 369 |
return True # a pass is still recorded in the database, but not sent to the client
|
| 370 |
|
|
|
|
|
|
|
| 371 |
#remove encapsulating quotes
|
| 372 |
-
no_quotes = remove_quotes(
|
| 373 |
#humanize the response (remove obvious AI formatting styles)
|
| 374 |
humanized_response = humanize(no_quotes)
|
| 375 |
#replace most semicolons
|
|
@@ -378,11 +381,14 @@ def ask_bot(room_id, bot, bot_display_name, initial_prompt, instruct_prompt):
|
|
| 378 |
corrupted_response = corrupt(less_semicolons_response, misspell_aug_p=0.01, typo_aug_p=0.005)
|
| 379 |
#remove weird chars
|
| 380 |
no_weird_chars = remove_weird_characters(corrupted_response)
|
| 381 |
-
#
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
| 383 |
|
| 384 |
#check that there are no reccent duplicate messages
|
| 385 |
-
if len(
|
| 386 |
print("****DUPLICATE MESSAGE DETECTED")
|
| 387 |
print("Treating this bot's response as a pass.")
|
| 388 |
# Do not store/send messages if the chat has ended
|
|
@@ -392,7 +398,7 @@ def ask_bot(room_id, bot, bot_display_name, initial_prompt, instruct_prompt):
|
|
| 392 |
# Store the error response in the database
|
| 393 |
bot_message = {
|
| 394 |
"sender": bot_display_name,
|
| 395 |
-
"message": f"DUPLICATE message detected - treated as a (pass) : {
|
| 396 |
"timestamp": datetime.utcnow()
|
| 397 |
}
|
| 398 |
rooms_collection.update_one(
|
|
@@ -407,7 +413,7 @@ def ask_bot(room_id, bot, bot_display_name, initial_prompt, instruct_prompt):
|
|
| 407 |
print(corrupted_response)
|
| 408 |
|
| 409 |
# Add latency/wait time for bot responses
|
| 410 |
-
delay = get_response_delay(
|
| 411 |
print(delay)
|
| 412 |
time.sleep(delay)
|
| 413 |
|
|
@@ -419,7 +425,7 @@ def ask_bot(room_id, bot, bot_display_name, initial_prompt, instruct_prompt):
|
|
| 419 |
# Store the response in the database
|
| 420 |
bot_message = {
|
| 421 |
"sender": bot_display_name,
|
| 422 |
-
"message":
|
| 423 |
"timestamp": datetime.utcnow()
|
| 424 |
}
|
| 425 |
rooms_collection.update_one(
|
|
@@ -428,7 +434,7 @@ def ask_bot(room_id, bot, bot_display_name, initial_prompt, instruct_prompt):
|
|
| 428 |
)
|
| 429 |
|
| 430 |
# Send the bot's response to the client
|
| 431 |
-
socketio.emit("message", {"sender": bot_display_name, "message":
|
| 432 |
return False
|
| 433 |
|
| 434 |
def ask_bot_round(room_id):
|
|
|
|
| 62 |
CHAT_CONTEXT = 20 #how many messages from chat history to append to inference prompt
|
| 63 |
#minimum number of chars where we start checking for duplicate messages
|
| 64 |
DUP_LEN = 25 #since short messages may reasonably be the same
|
| 65 |
+
REMOVE_PUNC_RATE = .8 #how often to remove final punctuation
|
| 66 |
|
| 67 |
# Directory alignment
|
| 68 |
BASE_DIR = Path(__file__).resolve().parent
|
|
|
|
| 369 |
print("PASSED")
|
| 370 |
return True # a pass is still recorded in the database, but not sent to the client
|
| 371 |
|
| 372 |
+
#sub letters for names, so if the bot addressed A -> Apple
|
| 373 |
+
named_response = let_to_name(room_id, parsed_response)
|
| 374 |
#remove encapsulating quotes
|
| 375 |
+
no_quotes = remove_quotes(named_response)
|
| 376 |
#humanize the response (remove obvious AI formatting styles)
|
| 377 |
humanized_response = humanize(no_quotes)
|
| 378 |
#replace most semicolons
|
|
|
|
| 381 |
corrupted_response = corrupt(less_semicolons_response, misspell_aug_p=0.01, typo_aug_p=0.005)
|
| 382 |
#remove weird chars
|
| 383 |
no_weird_chars = remove_weird_characters(corrupted_response)
|
| 384 |
+
#remove trailing punctuation % of the time
|
| 385 |
+
if random.random() < REMOVE_PUNC_RATE:
|
| 386 |
+
no_weird_chars = re.sub(r'[^\w\s]+$', '', no_weird_chars)
|
| 387 |
+
|
| 388 |
+
final_response = no_weird_chars
|
| 389 |
|
| 390 |
#check that there are no reccent duplicate messages
|
| 391 |
+
if len(final_response) > DUP_LEN and duplicate_check(final_response, context):
|
| 392 |
print("****DUPLICATE MESSAGE DETECTED")
|
| 393 |
print("Treating this bot's response as a pass.")
|
| 394 |
# Do not store/send messages if the chat has ended
|
|
|
|
| 398 |
# Store the error response in the database
|
| 399 |
bot_message = {
|
| 400 |
"sender": bot_display_name,
|
| 401 |
+
"message": f"DUPLICATE message detected - treated as a (pass) : {final_response}",
|
| 402 |
"timestamp": datetime.utcnow()
|
| 403 |
}
|
| 404 |
rooms_collection.update_one(
|
|
|
|
| 413 |
print(corrupted_response)
|
| 414 |
|
| 415 |
# Add latency/wait time for bot responses
|
| 416 |
+
delay = get_response_delay(final_response);
|
| 417 |
print(delay)
|
| 418 |
time.sleep(delay)
|
| 419 |
|
|
|
|
| 425 |
# Store the response in the database
|
| 426 |
bot_message = {
|
| 427 |
"sender": bot_display_name,
|
| 428 |
+
"message": final_response, #save fruits in db so page reload shows proper names
|
| 429 |
"timestamp": datetime.utcnow()
|
| 430 |
}
|
| 431 |
rooms_collection.update_one(
|
|
|
|
| 434 |
)
|
| 435 |
|
| 436 |
# Send the bot's response to the client
|
| 437 |
+
socketio.emit("message", {"sender": bot_display_name, "message": final_response}, to=room_id)
|
| 438 |
return False
|
| 439 |
|
| 440 |
def ask_bot_round(room_id):
|