Add comments, refactoring.
Browse files- index.html +38 -16
index.html
CHANGED
|
@@ -85,7 +85,6 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
|
|
| 85 |
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True)
|
| 86 |
|
| 87 |
|
| 88 |
-
import inspect
|
| 89 |
import gradio as gr
|
| 90 |
import base64
|
| 91 |
import json
|
|
@@ -361,6 +360,7 @@ def load_pages(page_numbers):
|
|
| 361 |
return found_pages
|
| 362 |
|
| 363 |
|
|
|
|
| 364 |
CHAT_TOOLS = [
|
| 365 |
# ページ検索
|
| 366 |
{
|
|
@@ -409,10 +409,22 @@ CHAT_TOOLS = [
|
|
| 409 |
}
|
| 410 |
]
|
| 411 |
|
|
|
|
| 412 |
CHAT_TOOLS_TOKENS = 139
|
| 413 |
|
| 414 |
|
| 415 |
def get_openai_messages(prompt, history, context):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 416 |
global SEARCH_ENGINE
|
| 417 |
if SEARCH_ENGINE is not None:
|
| 418 |
context = "".join([page.content for page in SEARCH_ENGINE.pages])
|
|
@@ -430,7 +442,10 @@ def get_openai_messages(prompt, history, context):
|
|
| 430 |
return messages
|
| 431 |
|
| 432 |
|
|
|
|
| 433 |
actual_total_cost_prompt = 0
|
|
|
|
|
|
|
| 434 |
actual_total_cost_completion = 0
|
| 435 |
|
| 436 |
|
|
@@ -440,7 +455,7 @@ async def process_prompt(prompt, history, context, platform, endpoint, azure_dep
|
|
| 440 |
|
| 441 |
Args:
|
| 442 |
prompt (str): ユーザーからの入力プロンプト
|
| 443 |
-
history (list): チャット履歴
|
| 444 |
context (str): チャットコンテキスト
|
| 445 |
platform (str): 使用するAIプラットフォーム
|
| 446 |
endpoint (str): AIサービスのエンドポイント
|
|
@@ -602,6 +617,15 @@ def load_api_key(file_obj):
|
|
| 602 |
|
| 603 |
|
| 604 |
def get_cost_info(prompt_token_count):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 605 |
return f"Estimated input cost: {prompt_token_count + CHAT_TOOLS_TOKENS:,} tokens, Actual total input cost: {actual_total_cost_prompt:,} tokens, Actual total output cost: {actual_total_cost_completion:,} tokens"
|
| 606 |
|
| 607 |
|
|
@@ -859,8 +883,8 @@ def main():
|
|
| 859 |
char_counter = gr.Textbox(label="Statistics", value=get_context_info("", []),
|
| 860 |
lines=2, max_lines=2, interactive=False, container=True)
|
| 861 |
|
| 862 |
-
pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter])
|
| 863 |
-
pdf_file.clear(lambda: None, inputs=None, outputs=context, show_progress="hidden")
|
| 864 |
|
| 865 |
with gr.Column(scale=2):
|
| 866 |
|
|
@@ -905,7 +929,7 @@ def main():
|
|
| 905 |
|
| 906 |
return gr.update(value=get_cost_info(token_count))
|
| 907 |
|
| 908 |
-
message_textbox.change(estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, show_progress="hidden")
|
| 909 |
|
| 910 |
example_title_textbox = gr.Textbox(visible=False, interactive=True)
|
| 911 |
gr.Examples([[k] for k, v in examples.items()],
|
|
@@ -931,18 +955,16 @@ def main():
|
|
| 931 |
|
| 932 |
generator = process_prompt(*inputs)
|
| 933 |
|
| 934 |
-
|
| 935 |
-
first_response = await gr.utils.async_iteration(generator)
|
| 936 |
-
update = history + [[message, first_response]]
|
| 937 |
-
yield update, update
|
| 938 |
-
except StopIteration:
|
| 939 |
-
update = history + [[message, None]]
|
| 940 |
-
yield update, update
|
| 941 |
-
|
| 942 |
async for response in generator:
|
|
|
|
| 943 |
update = history + [[message, response]]
|
| 944 |
yield update, update
|
| 945 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 946 |
submit_triggers = [message_textbox.submit, submit_button.click]
|
| 947 |
|
| 948 |
submit_event = gr.events.on(submit_triggers, lambda message: ("", message),
|
|
@@ -990,13 +1012,13 @@ def main():
|
|
| 990 |
|
| 991 |
chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
|
| 992 |
# チャット履歴をクエリパラメータに保存する。
|
| 993 |
-
js=save_or_delete_chat_history, show_progress="hidden")
|
| 994 |
|
| 995 |
save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
|
| 996 |
-
js=save_or_delete_chat_history, show_progress="hidden")
|
| 997 |
|
| 998 |
context.change(
|
| 999 |
-
count_characters, inputs=context, outputs=char_counter, show_progress="hidden"
|
| 1000 |
).then(
|
| 1001 |
create_search_engine, inputs=context, outputs=None
|
| 1002 |
).then(
|
|
|
|
| 85 |
await micropip.install("https://raw.githubusercontent.com/sonoisa/pyodide_wheels/main/tiktoken/tiktoken-0.5.1-cp311-cp311-emscripten_3_1_45_wasm32.whl", keep_going=True)
|
| 86 |
|
| 87 |
|
|
|
|
| 88 |
import gradio as gr
|
| 89 |
import base64
|
| 90 |
import json
|
|
|
|
| 360 |
return found_pages
|
| 361 |
|
| 362 |
|
| 363 |
+
# function calling用ツール
|
| 364 |
CHAT_TOOLS = [
|
| 365 |
# ページ検索
|
| 366 |
{
|
|
|
|
| 409 |
}
|
| 410 |
]
|
| 411 |
|
| 412 |
+
# function callingなど、固定で消費するトークン数
|
| 413 |
CHAT_TOOLS_TOKENS = 139
|
| 414 |
|
| 415 |
|
| 416 |
def get_openai_messages(prompt, history, context):
|
| 417 |
+
"""
|
| 418 |
+
与えられた対話用データを、ChatGPT APIの入力に用いられるメッセージデータ形式に変換して返す。
|
| 419 |
+
|
| 420 |
+
Args:
|
| 421 |
+
prompt (str): ユーザーからの入力プロンプト
|
| 422 |
+
history (list[list[str]]): チャット履歴
|
| 423 |
+
context (str): チャットコンテキスト
|
| 424 |
+
|
| 425 |
+
Returns:
|
| 426 |
+
str: ChatGPT APIの入力に用いられるメッセージデータ
|
| 427 |
+
"""
|
| 428 |
global SEARCH_ENGINE
|
| 429 |
if SEARCH_ENGINE is not None:
|
| 430 |
context = "".join([page.content for page in SEARCH_ENGINE.pages])
|
|
|
|
| 442 |
return messages
|
| 443 |
|
| 444 |
|
| 445 |
+
# それまでの全入力トークン数
|
| 446 |
actual_total_cost_prompt = 0
|
| 447 |
+
|
| 448 |
+
# それまでの全出力トークン数
|
| 449 |
actual_total_cost_completion = 0
|
| 450 |
|
| 451 |
|
|
|
|
| 455 |
|
| 456 |
Args:
|
| 457 |
prompt (str): ユーザーからの入力プロンプト
|
| 458 |
+
history (list[list[str]]): チャット履歴
|
| 459 |
context (str): チャットコンテキスト
|
| 460 |
platform (str): 使用するAIプラットフォーム
|
| 461 |
endpoint (str): AIサービスのエンドポイント
|
|
|
|
| 617 |
|
| 618 |
|
| 619 |
def get_cost_info(prompt_token_count):
|
| 620 |
+
"""
|
| 621 |
+
チャットのトークン数情報を表示するための文字列を返す。
|
| 622 |
+
|
| 623 |
+
Args:
|
| 624 |
+
prompt_token_count (int): プロンプト(履歴込み)のトークン数
|
| 625 |
+
|
| 626 |
+
Returns:
|
| 627 |
+
str: チャットのトークン数情報を表示するための文字列
|
| 628 |
+
"""
|
| 629 |
return f"Estimated input cost: {prompt_token_count + CHAT_TOOLS_TOKENS:,} tokens, Actual total input cost: {actual_total_cost_prompt:,} tokens, Actual total output cost: {actual_total_cost_completion:,} tokens"
|
| 630 |
|
| 631 |
|
|
|
|
| 883 |
char_counter = gr.Textbox(label="Statistics", value=get_context_info("", []),
|
| 884 |
lines=2, max_lines=2, interactive=False, container=True)
|
| 885 |
|
| 886 |
+
pdf_file.upload(update_context_element, inputs=pdf_file, outputs=[context, char_counter], queue=False)
|
| 887 |
+
pdf_file.clear(lambda: None, inputs=None, outputs=context, queue=False, show_progress="hidden")
|
| 888 |
|
| 889 |
with gr.Column(scale=2):
|
| 890 |
|
|
|
|
| 929 |
|
| 930 |
return gr.update(value=get_cost_info(token_count))
|
| 931 |
|
| 932 |
+
message_textbox.change(estimate_message_cost, inputs=[message_textbox, chatbot, context], outputs=cost_info, queue=False, show_progress="hidden")
|
| 933 |
|
| 934 |
example_title_textbox = gr.Textbox(visible=False, interactive=True)
|
| 935 |
gr.Examples([[k] for k, v in examples.items()],
|
|
|
|
| 955 |
|
| 956 |
generator = process_prompt(*inputs)
|
| 957 |
|
| 958 |
+
has_response = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
async for response in generator:
|
| 960 |
+
has_response = True
|
| 961 |
update = history + [[message, response]]
|
| 962 |
yield update, update
|
| 963 |
|
| 964 |
+
if not has_response:
|
| 965 |
+
update = history + [[message, None]]
|
| 966 |
+
yield update, update
|
| 967 |
+
|
| 968 |
submit_triggers = [message_textbox.submit, submit_button.click]
|
| 969 |
|
| 970 |
submit_event = gr.events.on(submit_triggers, lambda message: ("", message),
|
|
|
|
| 1012 |
|
| 1013 |
chatbot.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
|
| 1014 |
# チャット履歴をクエリパラメータに保存する。
|
| 1015 |
+
js=save_or_delete_chat_history, queue=False, show_progress="hidden")
|
| 1016 |
|
| 1017 |
save_chat_history_to_url.change(None, inputs=[chatbot, save_chat_history_to_url], outputs=None,
|
| 1018 |
+
js=save_or_delete_chat_history, queue=False, show_progress="hidden")
|
| 1019 |
|
| 1020 |
context.change(
|
| 1021 |
+
count_characters, inputs=context, outputs=char_counter, queue=False, show_progress="hidden"
|
| 1022 |
).then(
|
| 1023 |
create_search_engine, inputs=context, outputs=None
|
| 1024 |
).then(
|