import os
import json
import gradio as gr
import openai
from typing import Iterable
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.agents import load_tools, initialize_agent
from langchain.agents import AgentType
from langchain.tools import Tool
from langchain.utilities import GoogleSearchAPIWrapper
openai.api_key = os.environ['OPENAI_API_KEY']
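# Note: on a Hugging Face Space the OPENAI_API_KEY would typically be configured as a
# repository secret so that it is available in os.environ at startup.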

def save_docs_to_jsonl(array: Iterable[Document], file_path: str) -> None:
    with open(file_path, 'w') as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + '\n')

def load_docs_from_jsonl(file_path) -> Iterable[Document]:
    if not os.path.exists(file_path):
        print("Invalid file path.")
        return []
    array = []
    with open(file_path, 'r') as jsonl_file:
        for line in jsonl_file:
            data = json.loads(line)
            obj = Document(**data)
            array.append(obj)
    return array
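
# A hedged sketch (not executed by this app) of how the 'striim_docs.jsonl' snapshot
# could have been produced: scrape the documentation pages with the WebBaseLoader
# imported above and persist them with save_docs_to_jsonl. The URL list below is a
# placeholder assumption, not the actual source of the snapshot.
# doc_urls = [...]  # documentation page URLs (not included in this file)
# scraped_docs = WebBaseLoader(doc_urls).load()
# save_docs_to_jsonl(scraped_docs, 'striim_docs.jsonl')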

# Load the previously scraped documentation from the local JSONL snapshot
documents = load_docs_from_jsonl('striim_docs.jsonl')

# Split the documents into smaller chunks
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
docs = text_splitter.split_documents(documents)

# Convert the document chunks to embeddings and store them in the FAISS vector store
vectordb = FAISS.from_documents(docs, embedding=OpenAIEmbeddings())
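
# Optional sketch (an assumption, not part of this app as written): the FAISS index
# could be persisted so the chunks are not re-embedded on every restart. The
# 'faiss_index' folder name here is hypothetical.
# vectordb.save_local('faiss_index')
# vectordb = FAISS.load_local('faiss_index', OpenAIEmbeddings())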

# Create our Q&A chain
pdf_qa = ConversationalRetrievalChain.from_llm(
    ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo'),
    retriever=vectordb.as_retriever(search_type="similarity", search_kwargs={'k': 4}),
    return_generated_question=True,
    return_source_documents=True,
    verbose=False,
)
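
# For reference, this chain is invoked further down in the Gradio handler with a dict of
# "question" and "chat_history"; because of the two return_* flags above, the result
# also carries the retrieved chunks and the rewritten standalone question, e.g.:
# result = pdf_qa({"question": "How can I improve Striim performance?", "chat_history": []})
# answer = result["answer"]
# sources, generated_q = result["source_documents"], result["generated_question"]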

# Function to query Google if user selects allow internet access
def get_query_from_internet(user_query, temperature=0):
    delimiter = "```"
    # Checking if user query is flagged as inappropriate
    response = openai.Moderation.create(input=user_query["question"])
    moderation_output = response["results"][0]
    if moderation_output["flagged"]:
        return "Your query was flagged as inappropriate. Please try again."

    search = GoogleSearchAPIWrapper()
    tool = Tool(
        name="Google Search",
        description="Search Google for recent results.",
        func=search.run,
    )
    llm = ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo')
    tools = load_tools(["requests_all"])
    tools += [tool]
    agent_chain = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it conforms!",
    )
    return agent_chain.run({'input': user_query})
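
# Note: GoogleSearchAPIWrapper expects Google credentials in the environment
# (GOOGLE_API_KEY and GOOGLE_CSE_ID), which would have to be set alongside
# OPENAI_API_KEY for this internet path to work. Example call, mirroring the
# commented-out branch in the Gradio handler below:
# get_query_from_internet({"question": "What's new in Striim version 4.2.0?", "chat_history": []})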

# Front end web application using Gradio
chat_history = []

CSS = """
.contain { display: flex; flex-direction: column; }
.svelte-1ax1toq { display: none; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto; }
"""

with gr.Blocks(theme='samayg/StriimTheme', css=CSS) as demo:
    image = gr.Image('striim-logo-light.png', height=47, width=200, show_label=False, show_download_button=False, show_share_button=False)
    chatbot = gr.Chatbot(show_label=False, height=300)
    msg = gr.Textbox(label="Question:")
    examples = gr.Examples(
        examples=[
            ["What's new in Striim version 4.2.0?"],
            ['My Striim application keeps crashing. What should I do?'],
            ['How can I improve Striim performance?'],
            ['It says could not connect to source or target. What should I do?'],
        ],
        inputs=msg,
        label="Examples",
    )
    submit = gr.Button("Submit")

    # with gr.Accordion(label="Advanced options", open=False):
    #     slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0, label="Temperature", info="The temperature of StriimGPT, default at 0. Higher values may allow for better inference but may fabricate false information.")
    #     internet_access = gr.Checkbox(value=False, label="Allow Internet Access?", info="If the chatbot cannot answer your question, this setting allows for internet access. Warning: this may take longer and produce inaccurate results.")

    def user(query, history):
        # if allow_internet:
        #     # Get response from internet-based query function
        #     result = get_query_from_internet({"question": query, "chat_history": chat_history}, temperature=slider.value)
        #     answer = result
        # else:
        # Get response from QA chain
        result = pdf_qa({"question": query, "chat_history": chat_history})
        answer = result["answer"]
        # Append user message and response to chat history
        chat_history.append((query, answer))
        return gr.update(value=""), chat_history

    # Both the textbox submit and the Submit button call the same handler; the
    # internet_access checkbox wiring is disabled along with the Advanced options above.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False)

if __name__ == "__main__":
    # demo.launch(debug=True)
    demo.launch(debug=True, share=True)