# spur-chatbot / app.py
# Uploaded to the Hugging Face Hub by snehasquasher via huggingface_hub
# (commit becc174).
import os
import sys
from typing import Optional, Tuple
from threading import Lock
import json
import shutil
import gradio as gr
from query_data import chain_options
from query_data import get_basic_qa_chain
from zipfile import ZipFile
from ingest_data import ingestData
from query_data import (get_basic_qa_chain,
get_qa_with_sources_chain,
get_custom_prompt_qa_chain,
get_condense_prompt_qa_chain,
get_retrievalqa_with_sources_chain)
from metadatainfo import metadata_field_info
from Constants import *
from apiKey import *
def set_openai_api_key(api_key: str):
    """Store the OpenAI API key in the environment and build a chain.

    Args:
        api_key: the key pasted into the OpenAI textbox; may be empty.

    Returns:
        A QA chain when a key was provided, otherwise None (downstream
        code treats a None chain as "no key set yet").
    """
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key
        # NOTE(review): `chainType` here is the module-level gr.Dropdown
        # *component*, not its selected string value, so this always falls
        # through to the basic chain — confirm whether that is intended.
        return getChainSelectedByUser(chainType)
    return None
def set_notion_api_key(api_key: str):
    """Store the Notion API key in the process environment.

    A falsy (empty/None) key is ignored and the environment is left
    untouched.
    """
    if not api_key:
        return
    os.environ["NOTION_API_KEY"] = api_key
def getChainSelectedByUser(chainType: gr.Dropdown):
    """Return the QA chain matching the user's dropdown selection.

    The original implementation always constructed the basic chain first
    and threw it away when another branch matched; this builds only the
    chain that was actually selected.

    Args:
        chainType: the dropdown's selected value (a string at event time).

    Returns:
        The chain built by the factory mapped to ``chainType``; any
        unknown value (including "basic") falls back to the basic chain.
    """
    builders = {
        "with_sources": get_qa_with_sources_chain,
        "custom_prompt": get_custom_prompt_qa_chain,
        "condense_prompt": get_condense_prompt_qa_chain,
        "retrieval_sources_chain": get_retrievalqa_with_sources_chain,
    }
    return builders.get(chainType, get_basic_qa_chain)()
class Logger:
    """Tee object: mirrors everything written to it to both the real
    stdout and a log file, so console output can also be shown in the UI.
    """

    def __init__(self, filename):
        # Capture the real stdout before it gets replaced by this object.
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        for stream in (self.terminal, self.log):
            stream.write(message)

    def flush(self):
        for stream in (self.terminal, self.log):
            stream.flush()

    def isatty(self):
        # Report "not a terminal" so libraries skip ANSI styling/probing.
        return False
# Redirect stdout so everything printed is also captured in LOG_FILE
# (surfaced in the UI via read_logs / the Console textbox).
sys.stdout = Logger(LOG_FILE)
def read_logs():
    """Return the full contents of the application log file.

    Flushes the tee'd stdout first so the most recent prints are on disk
    before the file is read.
    """
    sys.stdout.flush()
    with open(LOG_FILE, "r") as log_file:
        return log_file.read()
def upload_file(files):
    """Copy each uploaded file into the data directory.

    Args:
        files: the gr.UploadButton payload — objects exposing a ``.name``
            path to the temp copy of each uploaded file.

    Returns:
        The list of source paths that were copied.
    """
    copied = []
    for uploaded in files:
        path = uploaded.name
        print("moving file :" + path)
        shutil.copy(path, DATA_DIRECTORY)
        copied.append(path)
    return copied
def ingest():
    """Trigger ingestion of the uploaded documents into the vector store."""
    ingestData()
class ChatWrapper:
    """Thread-safe callable that runs one chat turn through a chain.

    Gradio may invoke the handler concurrently; a lock serializes access
    to the underlying chain.
    """

    def __init__(self):
        self.lock = Lock()

    def __call__(
        self, api_key: str, inp: str, history: Optional[Tuple[str, str]], chain, chainType
    ):
        """Execute the chat functionality.

        Args:
            api_key: OpenAI API key pasted by the user.
            inp: the user's question.
            history: accumulated (question, answer) pairs, or None on the
                first turn.
            chain: the chain built by set_openai_api_key / chainType.change;
                None means no API key has been provided yet.
            chainType: dropdown value controlling how sources are appended.

        Returns:
            ``(history, history)`` — duplicated because the Gradio event
            feeds both the chatbot display and the state component.
        """
        # `with` guarantees the lock is released even if the chain raises,
        # replacing the manual acquire/try/finally/release dance.
        with self.lock:
            history = history or []
            # If chain is None, no API key was provided.
            if chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history
            # Set OpenAI key
            import openai
            openai.api_key = api_key
            print("calling chain of type " + str(type(chain)))
            # Run chain and append input.
            results = chain({"question": inp})
            print("result keys :")
            print(*results, sep=" ")
            output = results["answer"]
            if chainType == "with_sources":
                print("document source count :" + str(len(results["source_documents"])))
                for doc in results["source_documents"]:
                    for key in doc.metadata:
                        # str() guards against non-string metadata values,
                        # which would otherwise raise TypeError on concat.
                        output = output + "<br>" + key + ":" + str(doc.metadata[key]) + "<br>"
            elif chainType == "retrieval_sources_chain":
                print("results")
            history.append((inp, output))
            return history, history
# Single shared handler instance; its internal lock serializes chat calls.
chat = ChatWrapper()
# --- Gradio UI wiring -------------------------------------------------------
block = gr.Blocks(gr.themes.Soft(),
                  analytics_enabled=True)

with block :
    with gr.Row():
        # Masked key inputs; their change() events (wired below) set env
        # vars and rebuild the chain stored in agent_state.
        openai_api_key_textbox = gr.Textbox(
            placeholder="Paste your OpenAI API key (sk-...)",
            show_label=False,
            lines=1,
            type="password",
        )
        notion_api_key_textbox = gr.Textbox(
            placeholder="Paste your Notion API key (secret-...)",
            show_label=False,
            lines=1,
            type="password",
        )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            value="ask me something about your data",
            label="What's your question?",
            placeholder="Ask questions about the most recent state of the union",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(
            scale=1)

    gr.Examples(
        examples=[
            "Who is Tanmay Chopra?",
            "Which persons know about the topics LLM?",
            "What did Navid say about LLM?",
        ],
        inputs=message,
    )

    with gr.Row():
        # Selecting a different chain type rebuilds the agent (change() below).
        chainType = gr.Dropdown(list(chain_options.keys()),
                                label="Chain Type", value="basic"
                                )

    with gr.Accordion(label="show_logs"):
        logs = gr.Textbox(label="Console")
        # Poll the tee'd stdout into the console textbox once per second.
        block.load(read_logs, None, logs, every=1)

    file_output = gr.File()
    upload_button = gr.UploadButton("Click to Upload a File", file_types=[".docx", ".pdf",".txt",".json"], file_count="multiple")
    files = upload_button.upload(upload_file, upload_button, file_output)

    btn = gr.Button(value="Ingest")
    btn.click(ingest)

    gr.HTML("Demo application of a LangChain chain.")
    gr.HTML(
        "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
    )

    # state holds the chat history; agent_state holds the active chain object.
    state = gr.State()
    agent_state = gr.State()

    submit.click(chat, inputs=[openai_api_key_textbox,message,
                               state, agent_state, chainType], outputs=[chatbot, state])
    message.submit(chat, inputs=[
        openai_api_key_textbox, message, state, agent_state, chainType], outputs=[chatbot, state])
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[agent_state],
    )
    notion_api_key_textbox.change(
        set_notion_api_key,
        inputs=[notion_api_key_textbox],
        outputs=[agent_state],
    )
    chainType.change(
        getChainSelectedByUser,
        inputs=[chainType],
        outputs=[agent_state],
    )

block.queue().launch(debug=True)