Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| import retrieval | |
| # UNCOMMENT ONLY WHEN RUNNING LOCALLY (not on Spaces) | |
| # from dotenv import load_dotenv | |
| from text_generation import Client, InferenceAPIClient | |
| from typing import List, Tuple | |
| # load API keys from a globally-available .env file | |
| # SECRETS_FILEPATH = "/mnt/project/chatbotai/huggingface_cache/internal_api_keys.env" | |
| # load_dotenv(dotenv_path=SECRETS_FILEPATH, override=True) | |
# Canned opening exchange prepended to every prompt sent to the
# GPT-NeoXT-Chat-Base-20B model (returned as the pre-prompt by get_usernames).
# The adjacent literals below concatenate into one string.
openchat_preprompt = (
    "\n"
    "<human>: Hi!\n"
    "<bot>: My name is Bot, model version is 0.15, part of an open-source kit for fine-tuning new bots! "
    "I was created by Together, LAION, and Ontocord.ai and the open-source community. "
    "I am not human, not evil and not alive, and thus have no thoughts and feelings, "
    "but I am programmed to be helpful, polite, honest, and friendly. "
    "I'm really smart at answering electrical engineering questions.\n"
)
# How many passages predict() asks the retriever for. predict() reads three of
# them, one per "Context" textbox in the UI.
NUM_ANSWERS_GENERATED = 3

# LOAD MODELS -- a single retrieval helper, built once at import time and used by
# both clip_img_search() and predict().
ta = retrieval.Retrieval()
def clip_img_search(img):
    """Reverse-image search for the lecture-slide gallery.

    Wired to ``inp_image.change`` with ``lec_gallery`` as its output.

    Args:
        img: The image from the upload widget, or ``None`` when it was cleared.

    Returns:
        The result of ``ta.reverse_img_search(img)``, or an empty list when
        there is no image.
    """
    if img is not None:
        return ta.reverse_img_search(img)
    # Nothing uploaded (or the upload was cleared): empty the gallery.
    return []
def get_client(model: str):
    """Return a text-generation client for *model*.

    Two models are served from dedicated endpoints whose URLs are read from
    environment variables. Every other model goes through the hosted
    Inference API, authenticated with ``HF_TOKEN`` when that variable is set.
    """
    endpoint_env_vars = {
        "Rallio67/joi2_20Be_instruct_alpha": "JOI_API_URL",
        "togethercomputer/GPT-NeoXT-Chat-Base-20B": "OPENCHAT_API_URL",
    }
    url_var = endpoint_env_vars.get(model)
    if url_var is None:
        return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))
    return Client(os.getenv(url_var))
def get_usernames(model: str):
    """
    Look up the prompt-formatting tokens for *model*.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    if model == "togethercomputer/GPT-NeoXT-Chat-Base-20B":
        # The only model that is primed with a canned opening exchange.
        return openchat_preprompt, "<human>: ", "<bot>: ", "\n"

    # (user tag, bot tag, separator) for models that use no pre-prompt.
    role_tokens = {
        "OpenAssistant/oasst-sft-1-pythia-12b": ("<|prompter|>", "<|assistant|>", "<|endoftext|>"),
        "Rallio67/joi2_20Be_instruct_alpha": ("User: ", "Joi: ", "\n\n"),
    }
    user_tag, bot_tag, separator = role_tokens.get(model, ("User: ", "Assistant: ", "\n"))
    return "", user_tag, bot_tag, separator
def _strip_suffix(text: str, suffix: str) -> str:
    """Return *text* with one trailing *suffix* removed (exact match), if present.

    Unlike ``str.rstrip(chars)``, which strips any trailing characters from a
    character set, this removes only the exact suffix.
    """
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def _history_to_chat(history) -> List[Tuple[str, str]]:
    """Pair a flat [user, bot, user, bot, ...] list into stripped (user, bot) tuples.

    A trailing unpaired user message is not included.
    """
    return [
        (history[k].strip(), history[k + 1].strip())
        for k in range(0, len(history) - 1, 2)
    ]


def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream a chat reply for *inputs*, then attach retrieved passages and slides.

    Generator wired to ``inputs.submit``. Every yielded tuple maps to the
    outputs ``[chatbot, state, context1, context2, context3, lec_gallery]``.

    Args:
        model: Selected model id; picks the client and the prompt format.
        inputs: The user's new message.
        typical_p, top_p, temperature, top_k, repetition_penalty, watermark:
            Sampling settings forwarded to ``client.generate_stream``. For
            the OpenAssistant model only ``typical_p`` and ``watermark`` are
            sent.
        chatbot: Previous (user, bot) message pairs from the chat widget.
        history: Flat list of alternating user/bot messages (the ``gr.State``).
            It is mutated in place: the new user message and the bot reply
            are appended.

    Yields:
        ``(chat, history, ctx1, ctx2, ctx3, images)``. First it yields once per
        streamed token, with contexts ``None`` and no images. Then it yields
        once with the cleaned chat and the retrieved passages. Finally it
        yields once more with the CLIP slide images added.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)
    # Keep the raw question for retrieval/CLIP. `inputs` gets the chat-format
    # user tag prepended below, and that tag must not leak into those queries.
    question = inputs
    history.append(inputs)

    # Re-serialize the earlier turns in the model's chat format.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)
    if not inputs.startswith(user_name):
        inputs = user_name + inputs
    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    if model == "OpenAssistant/oasst-sft-1-pythia-12b":
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=1000,
            watermark=watermark,
            max_new_tokens=500,
        )
    else:
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if top_p < 1.0 else None,
            top_k=top_k,
            truncate=1000,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    user_tag = user_name.rstrip()
    assistant_tag = assistant_name.rstrip()
    partial_words = ""
    # Track "bot turn opened" with a flag rather than the token index. Special
    # tokens are skipped, so the first visible token need not be index 0.
    reply_started = False
    chat_response = []
    for response in iterator:
        if response.token.special:
            continue
        partial_words = partial_words + response.token.text
        # Drop a role tag the model echoed at the end (exact suffix only).
        partial_words = _strip_suffix(partial_words, user_tag)
        partial_words = _strip_suffix(partial_words, assistant_tag)
        if not reply_started:
            history.append(" " + partial_words)
            reply_started = True
        elif response.token.text not in user_name:
            # Tokens that are fragments of the user tag are not pushed to the
            # display, so a half-emitted stop sequence does not flash on screen.
            history[-1] = partial_words
        chat_response = _history_to_chat(history)
        yield chat_response, history, None, None, None, []

    if not reply_started:
        # No visible tokens arrived. Still close the bot turn so `history`
        # stays in user/bot pairs for the next request.
        history.append("")
        chat_response = _history_to_chat(history)

    cleaned_final_chat_response = clean_chat_response(chat_response)

    # Pinecone context retrieval. The UI has exactly three context boxes, so
    # pad (or trim) to three rather than indexing past a short result.
    top_context_list = list(
        ta.retrieve_contexts_from_pinecone(user_question=question, topk=NUM_ANSWERS_GENERATED)
    )
    contexts = (top_context_list + [None] * 3)[:3]
    yield cleaned_final_chat_response, history, contexts[0], contexts[1], contexts[2], []

    # run CLIP on the raw question to find matching lecture slides
    images_list = ta.clip_text_to_image(question)
    yield cleaned_final_chat_response, history, contexts[0], contexts[1], contexts[2], images_list
def clean_chat_response(chat: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Tidy (user, bot) message pairs before the final display.

    In both messages of every pair, every ``<br>`` tag is removed and each
    run of consecutive newlines is collapsed to a single newline.

    Args:
        chat: (user message, bot message) pairs. ``None`` is treated as empty.

    Returns:
        A new list of cleaned pairs. The input is not modified.
    """
    def _clean(text: str) -> str:
        text = text.replace("<br>", "")
        # One replace("\n\n", "\n") pass only halves a run of newlines
        # ("\n\n\n\n" -> "\n\n"), so repeat until no double newline is left.
        while "\n\n" in text:
            text = text.replace("\n\n", "\n")
        return text

    return [(_clean(human_chat), _clean(bot_chat)) for human_chat, bot_chat in (chat or [])]
def reset_textbox():
    """Empty the message box; wired as a second handler on ``inputs.submit``."""
    cleared = gr.update(value="")
    return cleared
def radio_on_change(
    value: str,
    disclaimer,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Reconfigure the sampling controls after a different model is picked.

    The component arguments are the Gradio components themselves. Each one is
    asked for an ``update(...)`` describing its new value and/or visibility.

    Returns:
        Updates for ``(disclaimer, typical_p, top_p, top_k, temperature,
        repetition_penalty, watermark)``, in that order.
    """
    if value == "OpenAssistant/oasst-sft-1-pythia-12b":
        # Driven by typical-p alone: show that slider, hide the other samplers.
        hide = {"visible": False}
        return (
            disclaimer.update(**hide),
            typical_p.update(value=0.2, visible=True),
            top_p.update(**hide),
            top_k.update(**hide),
            temperature.update(**hide),
            repetition_penalty.update(**hide),
            watermark.update(False),
        )

    is_openchat = value == "togethercomputer/GPT-NeoXT-Chat-Base-20B"
    if is_openchat:
        p, k, temp, penalty = 0.25, 50, 0.6, 1.01
    else:
        p, k, temp, penalty = 0.95, 4, 0.5, 1.03
    return (
        # The OpenChatKit disclaimer only applies to the NeoXT chat model.
        disclaimer.update(visible=is_openchat),
        typical_p.update(visible=False),
        top_p.update(value=p, visible=True),
        top_k.update(value=k, visible=True),
        temperature.update(value=temp, visible=True),
        repetition_penalty.update(value=penalty, visible=True),
        # Watermarking defaults on for every model except NeoXT.
        watermark.update(not is_openchat),
    )
# Page header rendered with gr.HTML.
# NOTE(review): the two characters before "Teaching" look like a mis-encoded
# emoji (UTF-8 bytes shown through a legacy code page). Confirm the intended
# icon and replace them.
title = """<h1 align="center">π₯Teaching Assistant Chatbot</h1>"""
# Markdown rendered at the bottom of the page; currently blank.
description = """
"""
# Created hidden in the layout; radio_on_change shows it only while the
# GPT-NeoXT-Chat-Base-20B model is selected.
openchat_disclaimer = """
<div align="center">Check out the official <a href="https://huggingface.co/spaces/togethercomputer/OpenChatKit">OpenChatKit feedback app</a> for the full experience.</div>
"""
# ---------------------------------------------------------------------------
# UI layout and event wiring.
# Components are rendered in the order they are created inside the `with`
# contexts below, so statement order here defines the page layout.
# ---------------------------------------------------------------------------
with gr.Blocks(css="""#col_container {margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML(title)
    # Collapsed model picker. Changing it reconfigures the sampling sliders
    # through radio_on_change (wired further down).
    with gr.Row():
        with gr.Accordion("Model choices", open=False, visible=True):
            model = gr.Radio(
                value="OpenAssistant/oasst-sft-1-pythia-12b",
                choices=[
                    "OpenAssistant/oasst-sft-1-pythia-12b",
                    # "togethercomputer/GPT-NeoXT-Chat-Base-20B",
                    "Rallio67/joi2_20Be_instruct_alpha",
                    "google/flan-t5-xxl",
                    "google/flan-ul2",
                    "bigscience/bloom",
                    "bigscience/bloomz",
                    "EleutherAI/gpt-neox-20b",
                ],
                label="",
                interactive=True,
            )
    # with gr.Row():
    #     with gr.Column():
    #         use_gpt3_checkbox = gr.Checkbox(label="Include GPT-3 (paid)?")
    #     with gr.Column():
    #         use_equation_checkbox = gr.Checkbox(label="Prioritize equations?")
    # Per-session flat list of alternating user/bot messages; passed to
    # predict() as `history` and returned by it as the second output.
    state = gr.State([])
    with gr.Row():
        with gr.Column():
            chatbot = gr.Chatbot(elem_id="chatbot")
            inputs = gr.Textbox(placeholder="Ask an Electrical Engineering question!", label="Send a message...")
            # Clicking an example only fills `inputs`: no fn is attached and
            # there are no outputs.
            examples = gr.Examples(
                examples=[
                    "What is a Finite State Machine?",
                    "How do you design a functional a Two-Bit Gray Code Counter?",
                    "How can we compare an 8-bit 2's complement number to the value -1 using AND, OR, and NOT?",
                    "What does the uninterrupted counting cycle label mean?",
                ],
                inputs=[inputs],
                outputs=[],
            )
    # Three passage boxes, filled by the final yields of predict().
    gr.Markdown("## Relevant Textbook Passages & Lecture Transcripts")
    with gr.Row():
        with gr.Column():
            context1 = gr.Textbox(label="Context 1")
        with gr.Column():
            context2 = gr.Textbox(label="Context 2")
        with gr.Column():
            context3 = gr.Textbox(label="Context 3")
    # Slide gallery, filled either by predict() (text -> image) or by the
    # reverse image search on the right.
    gr.Markdown("## Relevant Lecture Slides")
    with gr.Row():
        with gr.Column(scale=2.6):
            lec_gallery = gr.Gallery(label="Lecture images", show_label=False, elem_id="gallery").style(grid=[2], height="auto")
        with gr.Column(scale=1):
            inp_image = gr.Image(type="pil", label="Reverse Image Search (optional)", shape=(224, 398))
    # Any change to the uploaded image replaces the gallery contents.
    inp_image.change(fn=clip_img_search, inputs=inp_image, outputs=lec_gallery, scroll_to_output=True)
    # Hidden at start; radio_on_change toggles its visibility.
    disclaimer = gr.Markdown(openchat_disclaimer, visible=False)
    # state = gr.State([])
    # Sampling controls. Initial visibility matches the default OpenAssistant
    # model (only typical-p shown); radio_on_change adjusts them per model.
    # Note: `minimum=-0` below evaluates to 0.
    with gr.Row():
        with gr.Accordion("Parameters", open=False, visible=True):
            typical_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=-0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=-0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")
    # The lambda closes over the component objects themselves (not their
    # values) so radio_on_change can call .update() on each. The returned
    # tuple order must match `outputs`.
    model.change(
        lambda value: radio_on_change(
            value,
            disclaimer,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ),
        inputs=model,
        outputs=[
            disclaimer,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ],
    )
    # The input list is positional: it must follow predict()'s parameter
    # order (note temperature comes before top_k). The outputs must match
    # the 6-tuples predict() yields.
    inputs.submit(
        predict,
        [
            model,
            inputs,
            typical_p,
            top_p,
            temperature,
            top_k,
            repetition_penalty,
            watermark,
            chatbot,
            state,
        ],
        [chatbot, state, context1, context2, context3, lec_gallery],
    )
    # A second listener on the same event clears the textbox.
    inputs.submit(reset_textbox, [], [inputs])
    gr.Markdown(description)

# NOTE(review): queueing is presumably enabled so the generator handler
# predict() can stream (a gradio 3.x requirement); concurrency_count=16 caps
# parallel workers. Confirm against the installed gradio version.
demo.queue(concurrency_count=16).launch(debug=True)