Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from typing import Optional, Tuple | |
| from threading import Lock | |
| import json | |
| import shutil | |
| import gradio as gr | |
| from query_data import chain_options | |
| from query_data import get_basic_qa_chain | |
| from zipfile import ZipFile | |
| from ingest_data import ingestData | |
| from query_data import (get_basic_qa_chain, | |
| get_qa_with_sources_chain, | |
| get_custom_prompt_qa_chain, | |
| get_condense_prompt_qa_chain, | |
| get_retrievalqa_with_sources_chain) | |
| from metadatainfo import metadata_field_info | |
| from Constants import * | |
| from apiKey import * | |
def set_openai_api_key(api_key: str, chain_type: str = "basic"):
    """Store the OpenAI API key in the environment and build a chain.

    Args:
        api_key: Key pasted into the textbox. A falsy value (``""``/``None``)
            leaves the environment untouched.
        chain_type: Chain-type dropdown *value* selecting which chain to build.
            Unrecognized values fall back to the basic chain. Defaults to
            ``"basic"``.

    Returns:
        The chain returned by ``getChainSelectedByUser``, or None when no key
        was given.
    """
    if not api_key:
        return None
    os.environ["OPENAI_API_KEY"] = api_key
    # Use the selected value handed in by the caller, not the module-level
    # ``chainType`` gr.Dropdown component: comparing the component object to
    # strings never matches, which silently forced the basic chain.
    return getChainSelectedByUser(chain_type)
def set_notion_api_key(api_key: str):
    """Store the Notion API key in the ``NOTION_API_KEY`` environment variable.

    A falsy ``api_key`` (``""``/``None``) leaves the environment untouched.
    Unlike ``set_openai_api_key``, this builds no chain; it always returns None.
    """
    if api_key:
        os.environ["NOTION_API_KEY"] = api_key
    return None
def getChainSelectedByUser(chainType: str):
    """Build the QA chain that matches the chain-type dropdown value.

    Args:
        chainType: Value selected in the "Chain Type" dropdown, e.g.
            ``"with_sources"``. Any unrecognized value, including ``"basic"``,
            yields the basic QA chain.

    Returns:
        Whatever the matching ``get_*_chain`` factory from ``query_data`` returns.
    """
    # Only the selected factory is invoked. Building the basic chain up front
    # and discarding it for every other choice was wasted work.
    # Pairs are compared with ``==`` (not a dict lookup) so non-hashable
    # arguments behave exactly as before.
    factories = (
        ("with_sources", get_qa_with_sources_chain),
        ("custom_prompt", get_custom_prompt_qa_chain),
        ("condense_prompt", get_condense_prompt_qa_chain),
        ("retrieval_sources_chain", get_retrievalqa_with_sources_chain),
    )
    for name, factory in factories:
        if chainType == name:
            return factory()
    return get_basic_qa_chain()
class Logger:
    """Tee for ``sys.stdout``: echoes every write to the real stream and a log file.

    The log file is opened in ``"w"`` mode, so it is truncated when the
    Logger is created.
    """

    def __init__(self, filename):
        # Capture whatever stdout is at construction time as the "terminal".
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        # Always report a non-interactive stream (presumably to keep
        # terminal escape codes out of the log file).
        return False

    def __getattr__(self, name):
        """Delegate any other stream attribute (``encoding``, ``fileno``, ...).

        Code that inspects ``sys.stdout`` beyond write/flush/isatty would
        otherwise fail with AttributeError once this object replaces it.
        """
        # Read through __dict__ so a lookup before __init__ has set
        # ``terminal`` raises AttributeError instead of recursing forever.
        terminal = self.__dict__.get("terminal")
        if terminal is None:
            raise AttributeError(name)
        return getattr(terminal, name)
# Tee everything printed to stdout into LOG_FILE so read_logs() can show it
# in the "Console" textbox.
sys.stdout = Logger(LOG_FILE)
def read_logs():
    """Return the full text logged to LOG_FILE so far.

    stdout is flushed first so output still sitting in the buffer is included.
    """
    sys.stdout.flush()
    with open(LOG_FILE, "r") as log_handle:
        contents = log_handle.read()
    return contents
def upload_file(files, data_directory=None):
    """Copy uploaded files into the data directory.

    Args:
        files: Uploaded file objects from ``gr.UploadButton``. Each exposes a
            ``.name`` attribute holding its temporary path. ``None`` is
            treated as "no files".
        data_directory: Destination folder. Defaults to ``DATA_DIRECTORY``.

    Returns:
        The list of source paths that were copied (shown in the ``gr.File``
        output).
    """
    target_dir = DATA_DIRECTORY if data_directory is None else data_directory
    # shutil.copy to a path that does not exist yet creates a *file* with that
    # name (later uploads would overwrite it), so ensure the directory exists.
    os.makedirs(target_dir, exist_ok=True)
    file_paths = [uploaded.name for uploaded in files or []]
    for path in file_paths:
        print("moving file :" + path)
        shutil.copy(path, target_dir)
    return file_paths
def ingest():
    """Button callback that runs ``ingestData``.

    Whatever ``ingestData`` returns is discarded; the callback yields None.
    """
    ingestData()
    return None
class ChatWrapper:
    """Callable chat handler that serializes requests through a lock."""

    def __init__(self):
        # One request at a time runs a chain and appends to history.
        self.lock = Lock()

    def __call__(
        self, api_key: str, inp: str, history: Optional[list], chain, chainType
    ):
        """Run ``chain`` on ``inp`` and append the exchange to ``history``.

        Args:
            api_key: OpenAI key from the textbox; assigned to ``openai.api_key``.
            inp: The user's question.
            history: List of ``(question, answer)`` tuples, or None/empty on
                the first call. It is appended to in place.
            chain: Callable chain taking ``{"question": ...}`` and returning a
                mapping with an ``"answer"`` key, or None when no chain exists.
            chainType: Selected dropdown value. ``"with_sources"`` appends each
                source document's metadata to the answer.

        Returns:
            ``(history, history)``: once for the chatbot display and once for
            the state.
        """
        with self.lock:
            history = history or []
            # No chain yet: ask for the key instead of answering.
            if chain is None:
                history.append((inp, "Please paste your OpenAI key to use"))
                return history, history

            import openai

            openai.api_key = api_key
            print("calling chain of type " + str(type(chain)))
            results = chain({"question": inp})
            print("result keys :")
            print(*results, sep=" ")

            output = results["answer"]
            if chainType == "with_sources":
                sources = results["source_documents"]
                print("document source count :" + str(len(sources)))
                for doc in sources:
                    for key, value in doc.metadata.items():
                        # Format, don't concatenate: metadata values are not
                        # necessarily strings (e.g. page numbers).
                        output += f"<br>{key}:{value}<br>"
            elif chainType == "retrieval_sources_chain":
                print("results")

            history.append((inp, output))
        return history, history
# A single shared handler instance serves both the Send button and message submit.
chat = ChatWrapper()
# The Soft theme is passed by keyword (it is Blocks' first parameter).
block = gr.Blocks(theme=gr.themes.Soft(), analytics_enabled=True)
with block:
    with gr.Row():
        openai_api_key_textbox = gr.Textbox(
            placeholder="Paste your OpenAI API key (sk-...)",
            show_label=False,
            lines=1,
            type="password",
        )
        notion_api_key_textbox = gr.Textbox(
            placeholder="Paste your Notion API key (secret-...)",
            show_label=False,
            lines=1,
            type="password",
        )

    chatbot = gr.Chatbot()

    with gr.Row():
        message = gr.Textbox(
            value="ask me something about your data",
            label="What's your question?",
            placeholder="Ask questions about the most recent state of the union",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(scale=1)

    gr.Examples(
        examples=[
            "Who is Tanmay Chopra?",
            "Which persons know about the topics LLM?",
            "What did Navid say about LLM?",
        ],
        inputs=message,
    )

    with gr.Row():
        chainType = gr.Dropdown(
            list(chain_options.keys()), label="Chain Type", value="basic"
        )

    with gr.Accordion(label="show_logs"):
        logs = gr.Textbox(label="Console")
    # Poll the log file once per second so console output appears in the UI.
    block.load(read_logs, None, logs, every=1)

    file_output = gr.File()
    upload_button = gr.UploadButton(
        "Click to Upload a File",
        file_types=[".docx", ".pdf", ".txt", ".json"],
        file_count="multiple",
    )
    files = upload_button.upload(upload_file, upload_button, file_output)
    btn = gr.Button(value="Ingest")
    btn.click(ingest)

    gr.HTML("Demo application of a LangChain chain.")
    gr.HTML(
        "<center>Powered by <a href='https://github.com/hwchase17/langchain'>LangChain 🦜️🔗</a></center>"
    )

    state = gr.State()  # chat history list
    agent_state = gr.State()  # chain built by the key / dropdown handlers

    # Both the Send button and pressing Enter run the same handler.
    chat_inputs = [openai_api_key_textbox, message, state, agent_state, chainType]
    submit.click(chat, inputs=chat_inputs, outputs=[chatbot, state])
    message.submit(chat, inputs=chat_inputs, outputs=[chatbot, state])

    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[agent_state],
    )
    # set_notion_api_key only updates the environment and returns None.
    # Routing that into agent_state would replace the chain built from the
    # OpenAI key with None, so it gets no outputs.
    notion_api_key_textbox.change(
        set_notion_api_key,
        inputs=[notion_api_key_textbox],
        outputs=None,
    )
    chainType.change(
        getChainSelectedByUser,
        inputs=[chainType],
        outputs=[agent_state],
    )
# Enable gradio's request queue and start the server in debug mode.
# NOTE: gradio documents that `every=` polling (used by block.load for the
# console log) requires the queue to be enabled.
block.queue().launch(debug=True)