# NOTE: this file was recovered from a Hugging Face Spaces capture whose
# status page reported "Runtime error"; the table formatting from the
# capture has been stripped so the module parses as Python.
# Standard library
import json
import os
from typing import Iterable

# Third-party
import gradio as gr
import openai
from langchain.agents import AgentType, initialize_agent, load_tools
from langchain.chains import ConversationalRetrievalChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import Tool
from langchain.utilities import GoogleSearchAPIWrapper
from langchain.vectorstores import FAISS
# Fail fast at startup with a KeyError if the OpenAI key is missing, rather
# than failing later on the first API call.
openai.api_key = os.environ['OPENAI_API_KEY']
def save_docs_to_jsonl(array: Iterable["Document"], file_path: str) -> None:
    """Serialize each document in *array* as one JSON object per line.

    Overwrites *file_path* if it already exists. Each element must expose a
    ``.json()`` method returning a single-line JSON string (as langchain's
    ``Document`` does).
    """
    # Explicit UTF-8 so the cache round-trips identically on any platform.
    with open(file_path, 'w', encoding='utf-8') as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + '\n')
def load_docs_from_jsonl(file_path: str) -> Iterable["Document"]:
    """Load Documents from a JSON-lines file written by ``save_docs_to_jsonl``.

    Returns an empty list (after printing a notice) when the file does not
    exist, preserving the script's original best-effort behaviour.
    """
    if not os.path.exists(file_path):
        print("Invalid file path.")
        return []
    documents = []
    with open(file_path, 'r', encoding='utf-8') as jsonl_file:
        for line in jsonl_file:
            # Skip blank lines (e.g. a trailing newline) instead of letting
            # json.loads raise on empty input.
            if not line.strip():
                continue
            documents.append(Document(**json.loads(line)))
    return documents
# ---------------------------------------------------------------------------
# Build the retrieval pipeline at import time.
# ---------------------------------------------------------------------------
# Load the pre-scraped Striim documentation from the local JSONL cache
# (empty list if the cache file is absent — see load_docs_from_jsonl).
documents = load_docs_from_jsonl('striim_docs.jsonl')

# Split the documents into overlapping chunks so retrieval returns focused
# passages.  NOTE(review): 500/1000 is a 50% overlap, which is unusually
# high — confirm this is intentional before tuning.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
docs = text_splitter.split_documents(documents)

# Embed the chunks with OpenAI embeddings and index them in an in-memory
# FAISS vector store.
vectordb = FAISS.from_documents(docs, embedding=OpenAIEmbeddings())

# Conversational Q&A chain: retrieves the 4 most similar chunks per query
# and answers with gpt-3.5-turbo at temperature 0 (deterministic).
pdf_qa = ConversationalRetrievalChain.from_llm(
    ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo'),
    retriever=vectordb.as_retriever(search_type="similarity", search_kwargs={'k': 4}),
    return_generated_question=True,
    return_source_documents=True,
    verbose=False,
)
def get_query_from_internet(user_query, temperature=0):
    """Answer a query with a Google-search-capable ReAct agent.

    Parameters
    ----------
    user_query : dict
        Must contain at least a ``"question"`` key with the raw user text;
        the whole dict is forwarded to the agent as its input.
    temperature : float, optional
        Sampling temperature for the underlying chat model (default 0).

    Returns
    -------
    str
        The agent's answer, or a refusal message when the question is
        flagged by the OpenAI moderation endpoint.
    """
    # Reject queries the moderation endpoint flags as inappropriate.
    moderation = openai.Moderation.create(input=user_query["question"])
    if moderation["results"][0]["flagged"]:
        return "Your query was flagged as inappropriate. Please try again."

    google_tool = Tool(
        name="Google Search",
        description="Search Google for recent results.",
        func=GoogleSearchAPIWrapper().run,
    )
    # BUG FIX: the temperature argument was previously ignored (hard-coded 0).
    llm = ChatOpenAI(temperature=temperature, model_name='gpt-3.5-turbo')
    tools = load_tools(["requests_all"]) + [google_tool]
    agent_chain = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it conforms!",
    )
    return agent_chain.run({'input': user_query})
# Front-end web application using Gradio
# Custom CSS: hide the Gradio footer, size the chat window, and style the
# submit button to match Striim branding.
CSS = """
footer.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq.svelte-1ax1toq { display: none; }
#chatbot { height: 70vh !important;}
#submit-button { background: #00A7E5; color: white; }
#submit-button:hover { background: #00A7E5; color: white; box-shadow: 0 8px 10px 1px #9d9ea124, 0 3px 14px 2px #9d9ea11f, 0 5px 5px -3px #9d9ea133; }
"""
with gr.Blocks(theme='samayg/StriimTheme', css=CSS) as demo:
    chatbot = gr.Chatbot(show_label=False, elem_id="chatbot")
    msg = gr.Textbox(label="Question:")
    user = gr.State("gradio")
    examples = gr.Examples(
        examples=[
            ['What\'s new in Striim version 4.2.0?'],
            ['My Striim application keeps crashing. What should I do?'],
            ['How can I improve Striim performance?'],
            ['It says could not connect to source or target. What should I do?'],
        ],
        inputs=msg,
        label="Examples",
    )
    submit = gr.Button("Submit", elem_id="submit-button")

    # Internet-augmented answering (get_query_from_internet) exists but is
    # currently disabled; all questions go through the local QA chain.
    chat_history = []

    def getResponse(query, history, userId):
        """Run *query* through the QA chain and append to the transcript.

        Returns (cleared textbox, updated chatbot pairs, unchanged user id)
        matching the outputs wired to msg.submit / submit.click below.
        """
        global chat_history
        result = pdf_qa({"question": query, "chat_history": chat_history})
        answer = result["answer"]
        chat_history.append((query, answer))
        # Keep only the last 5 exchanges so the prompt stays within the
        # model's token limit.
        chat_history = chat_history[-5:]
        # NOTE(review): chat_history is module-global, so all concurrent
        # users share one transcript; the per-session `history` argument is
        # unused. Confirm whether session isolation is intended.
        return gr.update(value=""), chat_history, userId

    msg.submit(getResponse, [msg, chatbot, user], [msg, chatbot, user], queue=False)
    submit.click(getResponse, [msg, chatbot, user], [msg, chatbot, user], queue=False)
if __name__ == "__main__":
    # debug=True surfaces errors in the console; share=True additionally
    # publishes a temporary public Gradio link.
    demo.launch(debug=True, share=True)