| from typing import Iterator |
|
|
| import gradio as gr |
|
|
|
|
| from model import run |
|
|
| DEFAULT_SYSTEM_PROMPT = "" |
| MAX_MAX_NEW_TOKENS = 2048 |
| DEFAULT_MAX_NEW_TOKENS = 1024 |
| MAX_INPUT_TOKEN_LENGTH = 4000 |
|
|
| DESCRIPTION = """ |
| # 玉刚六号改/yugangVI-Chat |
| """ |
| LICENSE="基于Baichuan-13B-Chat以及https://github.com/ouwei2013/baichuan13b.cpp" |
|
|
|
|
|
|
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Blank out the textbox while stashing its contents in `saved_input`."""
    stashed = message
    return '', stashed
|
|
|
|
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Echo the user's message into the chat with an empty bot slot.

    Mutates `history` in place (the Chatbot component re-renders it) and
    returns the same list.
    """
    history += [(message, '')]
    return history
|
|
|
|
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the newest (message, response) turn from the chat.

    Returns the shortened history plus the removed user message (so Retry
    can re-submit it); yields '' for the message when history is empty or
    the stored message is falsy.
    """
    if not history:
        return history, ''
    last_message, _unused_response = history.pop()
    return history, last_message or ''
|
|
|
|
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream successive chat histories while the model writes its reply.

    `history_with_input` already contains the pending (message, '') turn
    appended by `display_input`, so the last entry is sliced off and
    rebuilt with each partial response from the model.
    """
    prior_turns = history_with_input[:-1]
    for partial_response in run(message, prior_turns, system_prompt,
                                max_new_tokens, temperature, top_p, top_k):
        yield prior_turns + [(message, partial_response)]
|
|
|
|
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run one cached example to completion.

    Drains the streaming `generate` generator, keeping only the final chat
    history so the cached example shows the finished answer.

    Returns:
        ('', final_history) — empty string clears the textbox output.
    """
    # Initialize before the loop: the original code left the loop variable
    # unbound (UnboundLocalError on `return`) if the generator yielded
    # nothing.
    history: list[tuple[str, str]] = []
    for history in generate(message, [], DEFAULT_SYSTEM_PROMPT, 8192, 1, 0.95, 50):
        pass
    return '', history
|
|
|
|
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Validate the prompt length before generation starts.

    Currently a deliberate no-op stub: it is wired into the submit chains
    via `.then(...).success(...)`, so merely returning (not raising) lets
    generation proceed.  The original placeholder assignment (`a = 1`) was
    dead code and has been removed.

    TODO: tokenize system_prompt + chat_history + message and raise
    gr.Error when the count exceeds MAX_INPUT_TOKEN_LENGTH, mirroring the
    upstream Llama-2 demo this layout follows.
    """
    return None
|
|
|
|
# ---------------------------------------------------------------------------
# Gradio UI layout and event wiring.
#
# Both the textbox Enter key and the Submit button run the same four-step
# chain: clear/save the input -> echo it into the chat -> token-length check
# (`.success()` gates the next step on it not raising) -> stream the reply.
#
# NOTE(review): SOURCE indentation was lost in transit; the nesting below
# (input row inside the Group, action-button row at Blocks level) follows
# the standard Gradio chat-demo layout — confirm against the original file.
# ---------------------------------------------------------------------------
with gr.Blocks(css='style.css') as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')

    # Main chat area: history display plus the input row.
    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='请输入/Type a message...',
                scale=10,
            )
            submit_button = gr.Button('提交/Submit',
                                      variant='primary',
                                      scale=1,
                                      min_width=0)
    # Secondary actions under the chat window.
    with gr.Row():
        retry_button = gr.Button('🔄 重来/Retry', variant='secondary')
        undo_button = gr.Button('↩️ 撤销/Undo', variant='secondary')
        clear_button = gr.Button('🗑️ 清除/Clear', variant='secondary')

    # Holds the user's message between chain steps, since the textbox is
    # cleared in the first step of every chain.
    saved_input = gr.State()

    # Sampling controls; values feed straight into `generate`.
    with gr.Accordion(label='进阶设置/Advanced options', open=False):
        system_prompt = gr.Textbox(label='预设引导词/System prompt',
                                   value=DEFAULT_SYSTEM_PROMPT,
                                   lines=6)
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='情感温度/Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.3,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.85,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=5,
        )

    # Pre-computed example; `cache_examples=True` runs `process_example`
    # once at startup and stores the result.
    gr.Examples(
        examples=[
            '中华人民共和国的首都是?',
        ],
        inputs=textbox,
        outputs=[textbox, chatbot],
        fn=process_example,
        cache_examples=True,
    )

    gr.Markdown(LICENSE)

    # Enter-key chain: save input -> echo -> length check -> stream reply.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Submit-button chain: identical to the Enter-key chain above.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Retry: pop the last turn, re-echo its message, regenerate the reply.
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=[
            saved_input,
            chatbot,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
        ],
        outputs=chatbot,
        api_name=False,
    )

    # Undo: pop the last turn and restore its message into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: wipe both the visible chat and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        queue=False,
        api_name=False,
    )

# Bounded request queue; `launch()` blocks and serves the app.
demo.queue(max_size=20).launch()
|
|