# handbook-qa / app.py
# varun96's picture
# Update app.py
# 002f1e7
import os
from typing import Iterator
# from retreiver import md_header_splits, create_embeddings
from retreiver import vectordb
from retreiver import examples
import gradio as gr
from model import run
import requests
def retreive(user_query: str, ret_k: int, mmr: bool) -> list[str]:
    """Return the text of the passages retrieved from ``vectordb`` for a query.

    Args:
        user_query: The question used to search the vector store.
        ret_k: Number of passages requested.
        mmr: When true, use max-marginal-relevance search; otherwise use
            plain similarity search.

    Returns:
        The ``page_content`` of every retrieved item, in the order the
        store returned them.
    """
    if mmr:
        # fetch_k is the candidate pool MMR re-ranks. Keep the historical
        # pool of 10, but never let it be smaller than the number of
        # passages requested, otherwise fewer than ret_k can come back.
        retreivals = vectordb.max_marginal_relevance_search(
            query=user_query,
            k=ret_k,
            fetch_k=max(10, ret_k),
        )
    else:
        retreivals = vectordb.similarity_search(query=user_query, k=ret_k)
    return [item.page_content for item in retreivals]
# Whether to launch with a public share link. Environment variables are
# always strings, so the raw value cannot be used as a boolean directly:
# "false" or "0" would be truthy. Any value other than the usual "off"
# spellings (or an unset/empty variable) enables sharing.
HF_PUBLIC = os.environ.get("HF_PUBLIC", "").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}

# System prompt handed to the model on every request.
# NOTE(review): the text ends with a stray single quote before the closing
# double quote; left untouched because it is part of the prompt sent at runtime.
DEFAULT_SYSTEM_PROMPT = "You are Mistral. You are a Question-Answer bot. Answer the question as truthfully as possible.'"

# Upper bound and default for the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 4096
DEFAULT_MAX_NEW_TOKENS = 512

# Threshold used by check_input_token_length.
MAX_INPUT_TOKEN_LENGTH = 4000

# Markdown rendered at the top of the page.
DESCRIPTION = """
# Handbook QA Bot
"""
def clear_and_save_textbox(message: str) -> tuple[str, str]:
    """Return an empty string (new textbox value) followed by the message to keep."""
    cleared_text = ''
    return cleared_text, message
def display_input(message: str,
                  history: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Add ``message`` with an empty reply to ``history`` and return it.

    The list passed in is modified in place; the returned object is that
    same list.
    """
    pending_turn = (message, '')
    history.extend([pending_turn])
    return history
def delete_prev_fn(
        history: list[tuple[str, str]]) -> tuple[list[tuple[str, str]], str]:
    """Remove the last turn from ``history`` and return its user message.

    Returns the (mutated) history together with the removed turn's message,
    or an empty string when the history was already empty or the message
    was falsy.
    """
    try:
        last_turn = history.pop()
    except IndexError:
        # Nothing to remove.
        return history, ''
    message, _reply = last_turn
    return history, (message if message else '')
def _log_interaction(message: str, context: str, response: str) -> None:
    """Best-effort upload of one query/context/response record to jsonbin.io.

    Nothing is sent when the ``JSON_BIN`` access key is not configured.
    Network errors are logged rather than raised so that a logging failure
    never turns a finished answer into an error.
    """
    import logging  # local import keeps this block self-contained

    access_key = os.environ.get('JSON_BIN')
    if not access_key:
        # A missing key would otherwise put a non-string value in the
        # headers, which requests rejects.
        return

    url = 'https://api.jsonbin.io/v3/b'
    headers = {
        'Content-Type': 'application/json',
        'X-Access-Key': access_key,
        'X-Collection-Id': '6572cc4c0574da7622d1f738'
    }
    data = {"query": message, "context": context, "response": response}
    try:
        # Bounded wait: this runs inside the streaming generator.
        requests.post(url, json=data, headers=headers, timeout=10)
    except requests.RequestException:
        logging.getLogger(__name__).warning(
            'Could not log interaction to jsonbin', exc_info=True)


def generate(
    message: str,
    history_with_input: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    ret_k: int,
    mmr: bool,
) -> Iterator[list[tuple[str, str]]]:
    """Stream an answer to ``message`` grounded in retrieved passages.

    Args:
        message: The user's question.
        history_with_input: Chat history whose last entry is the pending
            turn for ``message``; that entry is dropped before calling the
            model.
        system_prompt: System prompt forwarded to ``run``.
        max_new_tokens: Generation length limit forwarded to ``run``.
        temperature: Sampling temperature forwarded to ``run``.
        top_p: Nucleus-sampling value forwarded to ``run``.
        top_k: Top-k value forwarded to ``run``.
        ret_k: Number of passages to retrieve.
        mmr: Whether retrieval uses max-marginal-relevance search.

    Yields:
        The prior history plus ``(message, response)`` for every value
        produced by ``run``; a single ``(message, '')`` turn if ``run``
        produces nothing.

    Raises:
        ValueError: If ``max_new_tokens`` exceeds ``MAX_MAX_NEW_TOKENS``.
    """
    if max_new_tokens > MAX_MAX_NEW_TOKENS:
        raise ValueError(
            f'max_new_tokens ({max_new_tokens}) exceeds the limit of '
            f'{MAX_MAX_NEW_TOKENS}')

    retrieved_passages = retreive(message, ret_k, mmr)
    context = ' \n '.join(retrieved_passages)
    full_message = f'''
Context information is below.
{context}
Given the context information and not prior knowledge, answer the question
Question: {message}
Answer:
'''
    # The last history entry is the placeholder for the current message.
    history = history_with_input[:-1]
    generator = run(full_message, history, system_prompt, max_new_tokens, temperature, top_p, top_k)

    # Track the latest response explicitly so it is always defined for the
    # logging step, even when the model yields zero or one chunk.
    response = ''
    produced_output = False
    for response in generator:
        produced_output = True
        yield history + [(message, response)]
    if not produced_output:
        yield history + [(message, '')]

    # Example questions are not logged.
    if message not in examples:
        _log_interaction(message, context, response)
def process_example(message: str) -> tuple[str, list[tuple[str, str]]]:
    """Run ``generate`` to completion for an example question.

    Args:
        message: The example question.

    Returns:
        An empty string (to clear the textbox) and the final chat history
        yielded by ``generate``.
    """
    # generate() requires ret_k and mmr; use 5 passages (the slider default)
    # with plain similarity search.
    generator = generate(
        message,
        [],
        DEFAULT_SYSTEM_PROMPT,
        1024,
        1,
        0.95,
        50,
        ret_k=5,
        mmr=False,
    )
    # Defined up front so the return value exists even if nothing is yielded.
    final_history: list[tuple[str, str]] = []
    for final_history in generator:
        pass
    return '', final_history
def check_input_token_length(message: str, chat_history: list[tuple[str, str]], system_prompt: str) -> None:
    """Raise a Gradio error when the input size exceeds MAX_INPUT_TOKEN_LENGTH.

    The size is the number of characters in ``message`` plus the number of
    entries in ``chat_history``; ``system_prompt`` is accepted but not used.

    NOTE(review): despite the name, no tokenization happens and the history
    contributes its entry count rather than its text length - confirm that
    this is the intended measure.
    """
    message_chars = len(message)
    history_entries = len(chat_history)
    input_token_length = message_chars + history_entries
    if input_token_length <= MAX_INPUT_TOKEN_LENGTH:
        return
    raise gr.Error(f'The accumulated input is too long ({input_token_length} > {MAX_INPUT_TOKEN_LENGTH}). Clear your chat history and try again.')
# Build the Gradio UI: widgets are declared first, then their events are wired.
# NOTE(review): the nesting of the `with` blocks below cannot be seen here
# because leading whitespace is missing from this copy of the file.
with gr.Blocks(css='style.css') as demo:
gr.Markdown(DESCRIPTION)
# Link to an external feedback document.
gr.HTML(
'''
<center>
<a href='https://docs.google.com/document/d/1nF4mNJV0eTa2XjhzmhrfRBmh3zs8rmMNYn9K8zwtj6I/edit?usp=sharing'> Please add your feedback here. Thanks!</a>
</center>
'''
)
# Chat display, message textbox and submit button.
with gr.Group():
chatbot = gr.Chatbot(label='QA bot', show_copy_button=True)
with gr.Row():
textbox = gr.Textbox(
container=False,
show_label=False,
placeholder='Hi!',
scale=10,
)
submit_button = gr.Button('Submit',
variant='primary',
scale=1,
min_width=0)
# Example questions; selecting one fills the textbox.
gr.Examples(
examples=examples,
inputs=textbox,
)
# Retry / Undo / Clear controls.
with gr.Row():
retry_button = gr.Button('🔄 Retry', variant='secondary')
undo_button = gr.Button('↩️ Undo', variant='secondary')
clear_button = gr.Button('🗑️ Clear', variant='secondary')
# Holds the most recently submitted message between chained event steps.
saved_input = gr.State()
# Generation and retrieval settings passed to generate().
with gr.Accordion(label='⚙️ Advanced options', open=False):
# The system prompt is fixed: the field is hidden and not editable.
system_prompt = gr.Textbox(label='System prompt',
value=DEFAULT_SYSTEM_PROMPT,
lines=5,
interactive=False,
visible=False,)
max_new_tokens = gr.Slider(
label='Max new tokens',
minimum=1,
maximum=MAX_MAX_NEW_TOKENS,
step=1,
value=DEFAULT_MAX_NEW_TOKENS,
info="The maximum numbers of new tokens",
)
temperature = gr.Slider(
label='Temperature',
minimum=0.1,
maximum=4.0,
step=0.1,
value=0.1,
info="Higher values produce more diverse outputs",
)
top_p = gr.Slider(
label='Top-p (nucleus sampling)',
minimum=0.05,
maximum=1.0,
step=0.05,
value=0.9,
info="Higher values sample more low-probability tokens",
)
top_k = gr.Slider(
label='Top-k',
minimum=1,
maximum=1000,
step=1,
value=10,
info="Limits the model’s predictions to the top k most probable tokens at each step of generation",
)
# Number of passages requested from the retriever (1-20, default 5).
ret_k = gr.Slider(
label='Retreiver-k Results',
minimum=1,
maximum=20,
step=1,
value=5,
info="Top k results retreived by the retreiver",
)
# Switches retrieval to max-marginal-relevance search.
# NOTE(review): the info text contains a duplicated word ("is is") and the
# labels above spell "Retreiver"; these are user-facing strings and are left
# unchanged here.
mmr = gr.Checkbox(
info=" By Default Cosine Similarity is is used. MMR finds the top-10 using cosine similarity and ret_k using MMR",
label="Use Maximum Marginal Relevance (MMR)",
)
# Pressing Enter in the textbox: clear the box and stash the message, show
# the message in the chat, validate its length, then - only if validation
# did not raise - stream the answer.
textbox.submit(
fn=clear_and_save_textbox,
inputs=textbox,
outputs=[textbox, saved_input],
api_name=False,
queue=False,
).then(
fn=display_input,
inputs=[saved_input, chatbot],
outputs=chatbot,
api_name=False,
queue=False,
).then(
fn=check_input_token_length,
inputs=[saved_input, chatbot, system_prompt],
api_name=False,
queue=False,
).success(
fn=generate,
inputs=[
saved_input,
chatbot,
system_prompt,
max_new_tokens,
temperature,
top_p,
top_k,
ret_k,
mmr,
],
outputs=chatbot,
api_name=False,
)
# Clicking Submit runs the same four-step chain as pressing Enter.
button_event_preprocess = submit_button.click(
fn=clear_and_save_textbox,
inputs=textbox,
outputs=[textbox, saved_input],
api_name=False,
queue=False,
).then(
fn=display_input,
inputs=[saved_input, chatbot],
outputs=chatbot,
api_name=False,
queue=False,
).then(
fn=check_input_token_length,
inputs=[saved_input, chatbot, system_prompt],
api_name=False,
queue=False,
).success(
fn=generate,
inputs=[
saved_input,
chatbot,
system_prompt,
max_new_tokens,
temperature,
top_p,
top_k,
ret_k,
mmr,
],
outputs=chatbot,
api_name=False,
)
# Retry: remove the last turn, re-add its message with an empty reply, and
# generate again. This chain has no length-check step.
retry_button.click(
fn=delete_prev_fn,
inputs=chatbot,
outputs=[chatbot, saved_input],
api_name=False,
queue=False,
).then(
fn=display_input,
inputs=[saved_input, chatbot],
outputs=chatbot,
api_name=False,
queue=False,
).then(
fn=generate,
inputs=[
saved_input,
chatbot,
system_prompt,
max_new_tokens,
temperature,
top_p,
top_k,
ret_k,
mmr,
],
outputs=chatbot,
api_name=False,
)
# Undo: remove the last turn and put its message back into the textbox.
undo_button.click(
fn=delete_prev_fn,
inputs=chatbot,
outputs=[chatbot, saved_input],
api_name=False,
queue=False,
).then(
fn=lambda x: x,
inputs=[saved_input],
outputs=textbox,
api_name=False,
queue=False,
)
# Clear: reset the chat to an empty list and the saved message to ''.
clear_button.click(
fn=lambda: ([], ''),
outputs=[chatbot, saved_input],
queue=False,
api_name=False,
)
# Turn on request queuing, then start the app; a public share link is
# requested according to HF_PUBLIC and the API page is hidden.
queued_demo = demo.queue(concurrency_count=75, max_size=32)
queued_demo.launch(share=HF_PUBLIC, show_api=False)