| from typing import Iterator |
|
|
| import gradio as gr |
| import torch |
|
|
| from model import get_input_token_length, run |
|
|
# System prompt pre-filled in the "Advanced options" panel and used when the
# example answers are computed by process_example().
DEFAULT_SYSTEM_PROMPT = """\
You are a healthcare bot. You will give friendly advice for the Prevention and Treatment of various illnesses.\
"""
# Upper bound of the "Max new tokens" slider; generate() rejects larger values.
MAX_MAX_NEW_TOKENS = 1024
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 512
# Largest value of get_input_token_length() accepted by
# check_input_token_length() before the user is told to clear the history.
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown rendered as the first element of the page.
DESCRIPTION = """
# Healthcare Bot

This interface demonstrates a healthcare bot using the [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) model by Meta. This Llama 2 model, with 7B parameters, has been fine-tuned for providing advice on prevention and treatment of various illnesses.

Feel free to ask any health-related questions. You can also [deploy this model on Inference Endpoints](https://huggingface.co/inference-endpoints) to run your own healthcare service.

⚠️ Please remember that while this bot provides health advice based on its training data, it is not a substitute for professional medical advice.
"""

# Markdown rendered after the examples; points at the upstream license and
# acceptable use policy.
LICENSE = """
<p/>

---
As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
"""

# Tell visitors up front when no CUDA device is available.
if not torch.cuda.is_available():
    DESCRIPTION += '\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>'
|
|
|
|
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Empty the input box while passing its text on unchanged.

    Returns a pair: the new textbox content (always the empty string)
    followed by the submitted message.
    """
    cleared_textbox = ''
    return cleared_textbox, message
|
|
|
|
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Add *message* to the chat as a new turn whose reply is still empty.

    The list passed in is extended in place and that same list is returned.
    """
    pending_turn = (message, '')
    history.extend([pending_turn])
    return history
|
|
|
|
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the most recent turn from *history* and recover its user message.

    The list is modified in place. Returns the (shortened) history together
    with the removed turn's user message; the message is ``''`` when the
    history was already empty or the stored message was falsy (e.g. ``None``).
    """
    try:
        last_turn = history.pop()
    except IndexError:
        # Nothing to undo.
        return history, ''
    user_message, _ = last_turn
    return history, user_message or ''
|
|
|
|
def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
) -> Iterator[list[tuple[str, str]]]:
    """Stream the chat history as the reply to *message* is produced.

    Args:
        message: The user message to answer.
        history_with_input: Chat history whose last entry is dropped before
            calling ``run`` (in the UI that entry is the pending turn added
            by ``display_input``). An empty list is accepted.
        system_prompt: Forwarded to ``run``.
        max_new_tokens: Forwarded to ``run``; must not exceed
            ``MAX_MAX_NEW_TOKENS``.
        temperature: Forwarded to ``run``.
        top_p: Forwarded to ``run``.
        top_k: Forwarded to ``run``.

    Yields:
        The earlier turns followed by ``(message, response)`` for every value
        produced by ``run``. If ``run`` produces nothing, a single history
        ending in ``(message, '')`` is yielded.

    Raises:
        ValueError: If *max_new_tokens* exceeds ``MAX_MAX_NEW_TOKENS``. As
            this is a generator function, the error surfaces when iteration
            starts, not when ``generate`` is called.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(
            f'max_new_tokens must not exceed {MAX_MAX_NEW_TOKENS}, '
            f'got {max_new_tokens}')

    history = history_with_input[:-1]
    generator = run(message, history, system_prompt, max_new_tokens,
                    temperature, top_p, top_k)
    # Only next() is guarded, so the handler cannot react to anything raised
    # at the yield below.
    try:
        first_response = next(generator)
    except StopIteration:
        first_response = ''
    yield history + [(message, first_response)]
    for response in generator:
        yield history + [(message, response)]
|
|
|
|
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Answer an example prompt from an empty chat and return the final state.

    Runs ``generate`` to completion with ``DEFAULT_SYSTEM_PROMPT`` and fixed
    sampling settings, keeping only the last history it yields.

    Returns:
        ``('', final_history)``: the empty string for the textbox and the
        last history yielded by ``generate`` for the chatbot.
    """
    generator = generate(
        message=message,
        history_with_input=[],
        system_prompt=DEFAULT_SYSTEM_PROMPT,
        # Use the shared limit rather than a literal so this call can never
        # exceed what generate() accepts.
        max_new_tokens=MAX_MAX_NEW_TOKENS,
        temperature=1,
        top_p=0.95,
        top_k=50,
    )
    # Pre-bound so the return below is defined even if nothing is yielded.
    final_history: list[tuple[str, str]] = []
    for final_history in generator:
        pass
    return '', final_history
|
|
|
|
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Reject input whose token count is above ``MAX_INPUT_TOKEN_LENGTH``.

    The count is whatever ``get_input_token_length`` reports for the message,
    chat history and system prompt.

    Raises:
        gr.Error: If the count exceeds the limit; the message tells the user
            to clear the chat history.
    """
    input_token_length = get_input_token_length(message, chat_history, system_prompt)
    too_long = input_token_length > MAX_INPUT_TOKEN_LENGTH
    if not too_long:
        return
    raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
|
|
|
|
# NOTE(review): the source carried no indentation, so the nesting of the
# Group/Row containers below is reconstructed -- confirm against the intended
# layout.
with gr.Blocks(css='style.css') as demo:
    # Page header.
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(
        value='Duplicate Space for private use',
        elem_id='duplicate-button',
    )

    # Conversation view with the message box and Submit button.
    with gr.Group():
        chatbot = gr.Chatbot(label='Chatbot')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type a message...',
                scale=10,
            )
            submit_button = gr.Button(
                'Submit',
                variant='primary',
                scale=1,
                min_width=0,
            )
    with gr.Row():
        retry_button = gr.Button('🔄 Retry', variant='secondary')
        undo_button = gr.Button('↩️ Undo', variant='secondary')
        clear_button = gr.Button('🗑️ Clear', variant='secondary')

    # Carries the submitted message from one chained event step to the next.
    saved_input = gr.State()

    # Generation settings.
    with gr.Accordion(label='Advanced options', open=False):
        system_prompt = gr.Textbox(
            label='System prompt',
            value=DEFAULT_SYSTEM_PROMPT,
            lines=6,
        )
        max_new_tokens = gr.Slider(
            label='Max new tokens',
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature = gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=1.0,
        )
        top_p = gr.Slider(
            label='Top-p (nucleus sampling)',
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.95,
        )
        top_k = gr.Slider(
            label='Top-k',
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        )

    # Clickable sample questions, answered through process_example().
    gr.Examples(
        examples=[
            'What are some common symptoms of diabetes?',
            'Can you tell me about the benefits of regular exercise?',
            'Explain the importance of a balanced diet.',
            'How many hours of sleep should an adult get each night?',
            'What are some ways to reduce stress?',
        ],
        inputs=textbox,
        outputs=[textbox, chatbot],
        fn=process_example,
        cache_examples=True,
    )

    gr.Markdown(LICENSE)

    # Components fed to generate(), in the order of its parameters.
    _generation_inputs = [
        saved_input,
        chatbot,
        system_prompt,
        max_new_tokens,
        temperature,
        top_p,
        top_k,
    ]

    # Enter in the textbox: stash the message, show it in the chat, check the
    # prompt length, then generate only if that check succeeded.
    textbox.submit(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=_generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Submit button: same chain as pressing Enter in the textbox.
    button_event_preprocess = submit_button.click(
        fn=clear_and_save_textbox,
        inputs=textbox,
        outputs=[textbox, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=check_input_token_length,
        inputs=[saved_input, chatbot, system_prompt],
        api_name=False,
        queue=False,
    ).success(
        fn=generate,
        inputs=_generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Retry: drop the last turn, re-add its message and generate again
    # (no length check in this chain).
    retry_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=display_input,
        inputs=[saved_input, chatbot],
        outputs=chatbot,
        api_name=False,
        queue=False,
    ).then(
        fn=generate,
        inputs=_generation_inputs,
        outputs=chatbot,
        api_name=False,
    )

    # Undo: drop the last turn and put its message back into the textbox.
    undo_button.click(
        fn=delete_prev_fn,
        inputs=chatbot,
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    ).then(
        fn=lambda x: x,
        inputs=[saved_input],
        outputs=textbox,
        api_name=False,
        queue=False,
    )

    # Clear: reset the chat and the saved message.
    clear_button.click(
        fn=lambda: ([], ''),
        outputs=[chatbot, saved_input],
        api_name=False,
        queue=False,
    )

# Enable the request queue (at most 20 waiting) and start the app.
demo.queue(max_size=20).launch()
|
|