# talicchatbot / app.py
# RonsonChau's picture
# initial commit
# 5e918fe verified
import gradio as gr
import random
import time
import openai
import os
# Azure/OpenAI client configuration. All four settings are required
# environment variables; a missing one raises KeyError at startup,
# exactly as the individual assignments would.
for _setting in ("api_type", "api_key", "api_base", "api_version"):
    setattr(openai, _setting, os.environ["OPENAI_" + _setting.upper()])
######################## Input TASK 1A ########################
from langchain.document_loaders import PyPDFLoader, OnlinePDFLoader
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.text_splitter import CharacterTextSplitter

# URL of the PDF to index, taken from the environment.
linkToPDF = os.environ['ONLINE_PDF_URL']
loader = OnlinePDFLoader(linkToPDF)
documents = loader.load()

# Split the loaded pages into overlapping character chunks.
chuck_size, chuck_overlap = 1000, 200  # NOTE(review): "chuck" is a typo for "chunk"
text_splitter = CharacterTextSplitter(
    chunk_size=chuck_size,
    chunk_overlap=chuck_overlap,
)
docs = text_splitter.split_documents(documents)
###########################END OF TASK 1A ##################################
######################## Input TASK 1B ####################################
from langchain.vectorstores import Chroma

# Open-source embedding function backed by a local sentence-transformer model.
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
# Simple sequential string ids ("1".."N"), one per chunk.
ids = list(map(str, range(1, len(docs) + 1)))
# Index the chunks into Chroma using cosine distance for similarity search.
db = Chroma.from_documents(
    docs,
    embedding_function,
    ids=ids,
    collection_metadata={"hnsw:space": "cosine"},
)
###########################END OF TASK 1B ##################################
###########################INPUT OF TASK 3B ##################################
# Maximum number of characters of retrieved context embedded into the prompt.
contextCharsLimit = 3072
promptStart = "Answer questions truthfully based on the information in sources provided below. \n If you cannot find the answer to a question based on the sources below, respond by saying “I apologize, but I am unable to provide an answer to your question, which is out of the scope of the document uploaded. Thank you! \n Sources:\n"

def construct_system_prompt_chromadb(userQuery):
    """Build the system prompt for *userQuery*.

    Retrieves the most similar document chunks from the Chroma store `db`
    and appends the longest prefix of them whose "-"-joined length stays
    under ``contextCharsLimit``, after the fixed ``promptStart`` preamble.

    Returns the assembled prompt string. If nothing was retrieved, or the
    first chunk alone already exceeds the limit, the prompt is returned
    with no sources appended.

    Bug fixes vs. the original: the loop now breaks at the limit (it used
    to keep overwriting its result), `responsePrompt` can no longer be
    unbound (NameError for 0- or 1-chunk results), and the full-context
    case is no longer skipped by the off-by-one prefix logic.
    """
    matches = db.similarity_search(userQuery)
    context = [match.page_content for match in matches]
    # Greedily accumulate chunks while the joined text stays under the limit.
    kept = []
    for chunk in context:
        candidate = kept + [chunk]
        if len("-".join(candidate)) >= contextCharsLimit:
            break
        kept = candidate
    return promptStart + "-".join(kept)
########################### END OF TASK 3B ##################################
# Initial (placeholder) system prompt; respond() overwrites entry [0] of the
# message history per query with context retrieved from the vector store.
systemMessageContent = "" # System prompt we talked before - e.g., "You are a teaching assistant of a programming course CS1117. Try to answer student's question on python only"
systemMessage = {"role": "system", "content": systemMessageContent}
userMessageContent = "" # place holder
chatbotMessageContent = "" # place holder
# OpenAI sampling parameters passed to ChatCompletion.create in respond().
temperature = 0.8
top_p = 0.95
max_tokens = 800
numOfHistory = 5 # Add in the number history windows here
with gr.Blocks() as simpleChatDemo:
    # Conversation state sent to the OpenAI API: [system, user, assistant, ...].
    inputMessages = gr.State([systemMessage])
    # Chat display, input textbox, and a button that clears both.
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.ClearButton([msg, chatbot])

    def respond(userMessageInput, inputMessagesHistory, chatbot_history):
        """Handle one user turn.

        Rebuilds the system prompt from the vector store for the new query,
        appends the user message, trims history to the last `numOfHistory`
        messages (plus the system message), calls the chat model, and
        records the reply.

        Returns ("", updated API message history, updated display history);
        the empty string clears the textbox.
        """
        # Refresh the system message so its context matches the new query.
        inputMessagesHistory[0] = {
            "role": "system",
            "content": construct_system_prompt_chromadb(userMessageInput),
        }
        inputMessagesHistory.append({"role": "user", "content": userMessageInput})
        # Keep the system message plus at most `numOfHistory` recent messages.
        if len(inputMessagesHistory) > numOfHistory + 1:
            overflow = len(inputMessagesHistory) - (numOfHistory + 1)
            inputMessagesHistory = [
                inputMessagesHistory[0],
                *inputMessagesHistory[1 + overflow:],
            ]
        # BUG FIX: the original printed the gr.State wrapper `inputMessages`
        # instead of the actual message list being sent to the API.
        print(inputMessagesHistory)
        completion = openai.ChatCompletion.create(
            engine="chatgpt",
            messages=inputMessagesHistory,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
        )
        chatbotMessageContent = completion.choices[0].message.content
        inputMessagesHistory.append(
            {"role": "assistant", "content": chatbotMessageContent}
        )
        # chatbot_history is the display list of (user, bot) string tuples.
        chatbot_history.append((userMessageInput, chatbotMessageContent))
        time.sleep(2)
        # Clear the textbox and propagate the updated states back to the UI.
        return "", inputMessagesHistory, chatbot_history

    # On Enter in the textbox: respond(msg, inputMessages, chatbot) -> (msg, inputMessages, chatbot).
    msg.submit(respond, [msg, inputMessages, chatbot], [msg, inputMessages, chatbot])
simpleChatDemo.launch(share=True)