# striim-gpt / app.py — StriimGPT: a Gradio chat UI over a FAISS-backed
# conversational retrieval-QA chain built from scraped Striim documentation.
# (Non-Python residue from the Hugging Face blob-page header removed.)
import os
import json
import gradio as gr
import openai
from typing import Iterable
from langchain.document_loaders import WebBaseLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import FAISS
from langchain.chat_models import ChatOpenAI
from langchain.chains import ConversationalRetrievalChain
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.agents import load_tools, initialize_agent
from langchain.agents import AgentType
from langchain.tools import Tool
from langchain.utilities import GoogleSearchAPIWrapper
openai.api_key = os.environ['OPENAI_API_KEY']
def save_docs_to_jsonl(array: "Iterable[Document]", file_path: str) -> None:
    """Serialize documents to *file_path*, one JSON object per line (JSONL).

    Args:
        array: Documents exposing a ``.json()`` method returning a JSON
            string (e.g. langchain ``Document``).
        file_path: Destination path; any existing file is overwritten.
    """
    # Explicit utf-8: scraped doc text may contain non-ASCII characters and
    # the platform-default encoding is not guaranteed to accept them.
    with open(file_path, 'w', encoding='utf-8') as jsonl_file:
        for doc in array:
            jsonl_file.write(doc.json() + '\n')
def load_docs_from_jsonl(file_path: str) -> "list[Document]":
    """Load documents previously written by ``save_docs_to_jsonl``.

    Args:
        file_path: Path to a JSONL file with one serialized Document per line.

    Returns:
        The reconstructed ``Document`` objects, or an empty list when the
        file does not exist (best-effort behavior kept from the original:
        report and continue with an empty corpus instead of raising).
    """
    if not os.path.exists(file_path):
        print("Invalid file path.")
        return []
    array = []
    # Explicit utf-8 to match the encoding used when the file was written.
    with open(file_path, 'r', encoding='utf-8') as jsonl_file:
        for line in jsonl_file:
            # Skip blank lines (e.g. a trailing newline) — json.loads('')
            # would otherwise raise JSONDecodeError.
            if line.strip():
                array.append(Document(**json.loads(line)))
    return array
# --- Module-level RAG setup (runs once at import / Space startup) ---
# Load the pre-scraped Striim documentation corpus from a local JSONL file.
# NOTE(review): if 'striim_docs.jsonl' is absent this yields an empty corpus
# and the retriever has nothing to search — confirm the file ships with the app.
documents = load_docs_from_jsonl('striim_docs.jsonl')
# Split documents into overlapping chunks; the 500-char overlap (half the
# 1000-char chunk size) keeps context that straddles a boundary retrievable.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=500)
docs = text_splitter.split_documents(documents)
# Embed every chunk with OpenAI embeddings and index them in an in-memory
# FAISS vector store (rebuilt from scratch on every startup).
vectordb = FAISS.from_documents(docs, embedding=OpenAIEmbeddings())
# Conversational retrieval-QA chain: condenses chat history + new question,
# retrieves the 4 most similar chunks, and answers with gpt-3.5-turbo.
pdf_qa = ConversationalRetrievalChain.from_llm(
    ChatOpenAI(temperature=0, model_name='gpt-3.5-turbo'),  # deterministic answers
    retriever=vectordb.as_retriever(search_type="similarity", search_kwargs={'k': 4}),
    return_generated_question=True,   # expose the condensed question in results
    return_source_documents=True,     # expose retrieved chunks for attribution
    verbose=False,
)
# Function to query Google if user selects allow internet access
def get_query_from_internet(user_query, temperature=0):
    """Answer a query with a ReAct agent that can search Google and fetch URLs.

    Args:
        user_query: Dict with at least a "question" key (and optionally
            "chat_history"), the same shape passed to ``pdf_qa``.
        temperature: Sampling temperature for the agent's LLM; 0 keeps the
            answer deterministic.

    Returns:
        The agent's final answer string, or a refusal message when the
        question is flagged by the OpenAI moderation endpoint.
    """
    # Reject inappropriate queries before spending any agent/LLM calls.
    response = openai.Moderation.create(input=user_query["question"])
    moderation_output = response["results"][0]
    if moderation_output["flagged"]:
        return "Your query was flagged as inappropriate. Please try again."
    search = GoogleSearchAPIWrapper()
    google_tool = Tool(
        name="Google Search",
        description="Search Google for recent results.",
        func=search.run,
    )
    # Bug fix: the original hard-coded temperature=0 here, silently ignoring
    # the function's `temperature` parameter.
    llm = ChatOpenAI(temperature=temperature, model_name='gpt-3.5-turbo')
    # HTTP-request tools plus Google search for the agent's toolbox.
    tools = load_tools(["requests_all"]) + [google_tool]
    agent_chain = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors="Check your output and make sure it conforms!"
    )
    return agent_chain.run({'input': user_query})
# Front end web application using Gradio
# Module-level conversation state shared by every request.
# NOTE(review): this is ONE global list for ALL visitors of the Space —
# concurrent users will share and interleave each other's history; confirm
# this is intended (gr.State would give per-session history).
chat_history = []
# CSS tweaks: stack components vertically, hide the footer element
# (.svelte-1ax1toq — a generated class name tied to this Gradio version),
# and let the chatbot expand to fill the remaining height.
CSS ="""
.contain { display: flex; flex-direction: column; }
.svelte-1ax1toq { display: none; }
#component-0 { height: 100%; }
#chatbot { flex-grow: 1; overflow: auto;}
"""
with gr.Blocks(theme='samayg/StriimTheme', css=CSS) as demo:
    # Static branding header (logo image, no interactions).
    image = gr.Image('striim-logo-light.png', height=47, width=200, show_label=False, show_download_button=False, show_share_button=False)
    chatbot = gr.Chatbot(show_label=False, height=300)
    msg = gr.Textbox(label="Question:")
    # Clickable example questions that pre-fill the textbox.
    examples = gr.Examples(examples=[['What\'s new in Striim version 4.2.0?'], ['My Striim application keeps crashing. What should I do?'], ['How can I improve Striim performance?'], ['It says could not connect to source or target. What should I do?']], inputs=msg, label="Examples")
    submit = gr.Button("Submit")
    # Disabled "advanced options" UI (temperature slider + internet-access
    # checkbox) kept for reference; it would route through
    # get_query_from_internet instead of the local QA chain.
    #with gr.Accordion(label="Advanced options", open=False):
    #slider = gr.Slider(minimum=0, maximum=1, step=0.01, value=0, label="Temperature", info="The temperature of StriimGPT, default at 0. Higher values may allow for better inference but may fabricate false information.")
    #internet_access = gr.Checkbox(value=False, label="Allow Internet Access?", info="If the chatbot cannot answer your question, this setting allows for internet access. Warning: this may take longer and produce inaccurate results.")
    def user(query, history):
        """Run one chat turn: ask the QA chain and update the transcript.

        Args:
            query: The user's question from the textbox.
            history: The chatbot component's current value (unused — the
                module-level ``chat_history`` list is the source of truth).

        Returns:
            Tuple of (textbox update that clears the input, updated history
            list rendered by the chatbot component).
        """
        #if allow_internet:
        # Get response from internet-based query function
        # result = get_query_from_internet({"question": query, "chat_history": chat_history}, temperature=slider.value)
        # answer = result
        #else:
        # Get response from QA chain
        result = pdf_qa({"question": query, "chat_history": chat_history})
        answer = result["answer"]
        # Append user message and response to chat history
        chat_history.append((query, answer))
        return gr.update(value=""), chat_history
    # Both pressing Enter in the textbox and clicking Submit trigger the
    # same handler; queue=False bypasses Gradio's request queue.
    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False)
    submit.click(user, [msg, chatbot], [msg, chatbot], queue=False)
if __name__ == "__main__":
    # demo.launch(debug=True)
    # share=True additionally exposes a temporary public Gradio tunnel
    # alongside the local server; debug=True surfaces handler tracebacks.
    demo.launch(debug=True, share=True)