"""Gradio front-end for querying ingested documents through LangChain chains."""
import json
import os
import shutil
import sys
from threading import Lock
from typing import List, Optional, Tuple
from zipfile import ZipFile

import gradio as gr
from query_data import (
    chain_options,
    get_basic_qa_chain,
    get_condense_prompt_qa_chain,
    get_custom_prompt_qa_chain,
    get_qa_with_sources_chain,
    get_retrievalqa_with_sources_chain,
)
from ingest_data import ingestData
from metadatainfo import metadata_field_info
from Constants import *
from apiKey import *


# Maps the dropdown value to the factory that builds that chain. Any value not
# listed here falls back to the basic QA chain.
_CHAIN_BUILDERS = {
    "with_sources": get_qa_with_sources_chain,
    "custom_prompt": get_custom_prompt_qa_chain,
    "condense_prompt": get_condense_prompt_qa_chain,
    "retrieval_sources_chain": get_retrievalqa_with_sources_chain,
}


def set_openai_api_key(api_key: str, chain_type: str = "basic"):
    """Export the OpenAI key and build the chain currently selected in the UI.

    Args:
        api_key: OpenAI API key typed by the user.
        chain_type: Value of the "Chain Type" dropdown. Defaults to "basic".

    Returns:
        The constructed chain, or None when ``api_key`` is empty. The chat
        handler treats a None chain as "no key provided".
    """
    if not api_key:
        return None
    os.environ["OPENAI_API_KEY"] = api_key
    return getChainSelectedByUser(chain_type)


def set_notion_api_key(api_key: str):
    """Export the Notion API key to the environment when one is given.

    Returns None; this handler is registered without outputs so it cannot
    overwrite the chain held in the session state.
    """
    if api_key:
        os.environ["NOTION_API_KEY"] = api_key


def getChainSelectedByUser(chainType: str):
    """Build and return the chain matching the dropdown value ``chainType``.

    Unknown values fall back to ``get_basic_qa_chain()``. Only the selected
    chain is constructed.
    """
    builder = _CHAIN_BUILDERS.get(chainType, get_basic_qa_chain)
    return builder()


class Logger:
    """Tee everything written to stdout into both the terminal and a log file.

    The file is opened in "w" mode, so every run starts with an empty log.
    """

    def __init__(self, filename):
        self.terminal = sys.stdout
        # Explicit UTF-8 so document text / emoji echoed via print() can always
        # be written, regardless of the platform's default encoding.
        self.log = open(filename, "w", encoding="utf-8")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False

    def __getattr__(self, name):
        # Only reached for attributes not defined on this class (e.g.
        # ``encoding``, ``fileno``); forward them to the real stdout so code
        # that probes sys.stdout keeps working. Guard against recursion if
        # ``terminal`` itself has not been set yet.
        if name == "terminal":
            raise AttributeError(name)
        return getattr(self.terminal, name)


sys.stdout = Logger(LOG_FILE)


def read_logs():
    """Flush stdout and return the full contents of the log file."""
    sys.stdout.flush()
    with open(LOG_FILE, "r", encoding="utf-8") as f:
        return f.read()


def upload_file(files):
    """Copy uploaded files into DATA_DIRECTORY.

    Args:
        files: Uploaded file objects from ``gr.UploadButton``; each exposes its
            temporary path as ``.name``. None or empty is tolerated.

    Returns:
        The list of temporary file paths, shown in the ``gr.File`` output.
    """
    file_paths = [file.name for file in (files or [])]
    for path in file_paths:
        print("moving file :" + path)
        shutil.copy(path, DATA_DIRECTORY)
    return file_paths


def ingest():
    """Run the ingestion pipeline over the files in the data directory."""
    ingestData()


class ChatWrapper:
    """Callable Gradio handler that runs one question through a chain.

    A lock serializes calls so that only one chain invocation (and one
    mutation of ``openai.api_key``) happens at a time.
    """

    def __init__(self):
        self.lock = Lock()

    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[List[Tuple[str, str]]],
        chain,
        chainType,
    ):
        """Execute the chat functionality.

        Args:
            api_key: OpenAI key from the textbox (may be empty).
            inp: The user's question.
            history: Previous (question, answer) pairs, or None on first use.
            chain: Chain object held in the session state; None means no key
                has been provided yet.
            chainType: Dropdown value; controls how the answer is formatted.

        Returns:
            ``(history, history)``: once for the Chatbot display and once for
            the session state.
        """
        with self.lock:
            history = history or []

            # A None chain means set_openai_api_key never produced one.
            if chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history

            # Only override the client key when the user actually typed one;
            # assigning "" would clobber a key configured elsewhere.
            if api_key:
                import openai

                openai.api_key = api_key

            print("calling chain of type " + str(type(chain)))
            results = chain({"question": inp})
            print("result keys :")
            print(*results, sep=" ")

            output = results["answer"]
            if chainType == "with_sources":
                source_documents = results["source_documents"]
                print("document source count :" + str(len(source_documents)))
                for doc in source_documents:
                    for key, value in doc.metadata.items():
                        # str(): metadata values are not guaranteed to be strings.
                        output += "\n" + key + ":" + str(value) + "\n"
            elif chainType == "retrieval_sources_chain":
                # NOTE(review): sources returned by this chain are currently
                # not appended to the displayed answer.
                print("results")

            history.append((inp, output))
        return history, history


chat = ChatWrapper()

block = gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=True)

with block:
    with gr.Row():
        openai_api_key_textbox = gr.Textbox(
            placeholder="Paste your OpenAI API key (sk-...)",
            show_label=False,
            lines=1,
            type="password",
        )
        notion_api_key_textbox = gr.Textbox(
            placeholder="Paste your Notion API key (secret-...)",
            show_label=False,
            lines=1,
            type="password",
        )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            value="ask me something about your data",
            label="What's your question?",
            placeholder="Ask questions about the most recent state of the union",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(scale=1)

    gr.Examples(
        examples=[
            "Who is Tanmay Chopra?",
            "Which persons know about the topics LLM?",
            "What did Navid say about LLM?",
        ],
        inputs=message,
    )

    with gr.Row():
        chainType = gr.Dropdown(
            list(chain_options.keys()), label="Chain Type", value="basic"
        )

    with gr.Accordion(label="show_logs"):
        logs = gr.Textbox(label="Console")
        # Poll the log file every second to mirror console output in the UI.
        block.load(read_logs, None, logs, every=1)

    file_output = gr.File()
    upload_button = gr.UploadButton(
        "Click to Upload a File",
        file_types=[".docx", ".pdf", ".txt", ".json"],
        file_count="multiple",
    )
    files = upload_button.upload(upload_file, upload_button, file_output)

    btn = gr.Button(value="Ingest")
    btn.click(ingest)

    gr.HTML("Demo application of a LangChain chain.")
    gr.HTML("\nPowered by LangChain 🦜️🔗\n")

    state = gr.State()  # chat history
    agent_state = gr.State()  # the chain built for this session

    submit.click(
        chat,
        inputs=[openai_api_key_textbox, message, state, agent_state, chainType],
        outputs=[chatbot, state],
    )
    message.submit(
        chat,
        inputs=[openai_api_key_textbox, message, state, agent_state, chainType],
        outputs=[chatbot, state],
    )

    # Pass the dropdown *value* so the chain built on key entry matches the
    # selected chain type.
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox, chainType],
        outputs=[agent_state],
    )
    # No outputs: set_notion_api_key returns None and must not reset the chain.
    notion_api_key_textbox.change(
        set_notion_api_key,
        inputs=[notion_api_key_textbox],
    )
    chainType.change(
        getChainSelectedByUser,
        inputs=[chainType],
        outputs=[agent_state],
    )

block.queue().launch(debug=True)