Update app.py
app.py CHANGED
@@ -7,23 +7,23 @@ from langchain_community.vectorstores import Chroma
 from langchain_openai import ChatOpenAI, OpenAIEmbeddings
 from dotenv import load_dotenv
 
-#
+# Load environment variables
 load_dotenv()
 os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
 
-#
+# Validate OpenAI API Key
 api_key = os.getenv('OPENAI_API_KEY')
 if not api_key:
-    raise ValueError("
+    raise ValueError("Please set the 'OPENAI_API_KEY' environment variable")
 
 # OpenAI API key
 openai_api_key = api_key
 
-#
+# Transform chat history for LangChain format
 def transform_history_for_langchain(history):
-    return [(chat[0], chat[1]) for chat in history if chat[0]]
+    return [(chat[0], chat[1]) for chat in history if chat[0]]
 
-#
+# Transform chat history for OpenAI format
 def transform_history_for_openai(history):
     new_history = []
     for chat in history:
@@ -33,7 +33,7 @@ def transform_history_for_openai(history):
             new_history.append({"role": "assistant", "content": chat[1]})
     return new_history
 
-#
+# Load and process documents function
 def load_and_process_documents(folder_path):
     documents = []
     for file in os.listdir(folder_path):
@@ -58,24 +58,24 @@ def load_and_process_documents(folder_path):
     )
     return vectordb
 
-#
+# Initialize vector database as a global variable
 if 'vectordb' not in globals():
     vectordb = load_and_process_documents("./")
 
-#
+# Define query handling function for RAG
 def handle_query(user_message, temperature, chat_history):
     try:
         if not user_message:
-            return chat_history #
+            return chat_history  # Return unchanged chat history
 
-        #
+        # Use LangChain's ConversationalRetrievalChain to handle the query
         preface = """
-
-
+        Instruction: Answer in Traditional Chinese, within 200 characters.
+        If the question is unrelated to the documents, respond with: 此事無可奉告，話說這件事須請教海虔王...
         """
-        query = f"{preface}
+        query = f"{preface} Query content: {user_message}"
 
-        #
+        # Extract previous answers as context, converting them to LangChain format
         previous_answers = transform_history_for_langchain(chat_history)
 
         pdf_qa = ConversationalRetrievalChain.from_llm(
@@ -85,54 +85,54 @@ def handle_query(user_message, temperature, chat_history):
             verbose=False
         )
 
-        #
+        # Invoke the model to handle the query
         result = pdf_qa.invoke({"question": query, "chat_history": previous_answers})
 
-        #
+        # Ensure 'answer' is present in the result
         if "answer" not in result:
-            return chat_history + [("
+            return chat_history + [("System", "Sorry, an error occurred.")]
 
-        #
-        chat_history[-1] = (user_message, result["answer"]) #
+        # Update the AI response in chat history
+        chat_history[-1] = (user_message, result["answer"])  # Update the last record, pairing user input with AI response
 
         return chat_history
 
     except Exception as e:
-        return chat_history + [("
+        return chat_history + [("System", f"An error occurred: {str(e)}")]
 
-#
+# Create a custom chat interface using Gradio Blocks API
 with gr.Blocks() as demo:
-    gr.Markdown("<h1 style='text-align: center;'>AI
+    gr.Markdown("<h1 style='text-align: center;'>AI Assistant for AI Forum</h1>")
 
     chatbot = gr.Chatbot()
     state = gr.State([])
 
     with gr.Row():
         with gr.Column(scale=0.85):
-            txt = gr.Textbox(show_label=False, placeholder="
+            txt = gr.Textbox(show_label=False, placeholder="Please enter your question...")
         with gr.Column(scale=0.15, min_width=0):
-            submit_btn = gr.Button("
+            submit_btn = gr.Button("Ask")
 
-    #
+    # Immediately show user input without the response part, and clear the input box
     def user_input(user_message, history):
-        history.append((user_message, "")) #
-        return history, "", history #
+        history.append((user_message, ""))  # Show user message, response part as empty string
+        return history, "", history  # Return cleared input box and updated chat history
 
-    #
+    # Handle the AI response, filling in the response part
     def bot_response(history):
-        user_message = history[-1][0] #
-        history = handle_query(user_message, 0.7, history) #
-        return history, history #
+        user_message = history[-1][0]  # Get the latest user input
+        history = handle_query(user_message, 0.7, history)  # Call the query handler
+        return history, history  # Return updated chat history
 
-    #
+    # First show the user message, then handle the AI response, clearing the input box
     submit_btn.click(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
         bot_response, state, [chatbot, state]
     )
 
-    #
+    # Support pressing "Enter" to submit the question, immediately showing user input and clearing the input box
     txt.submit(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
         bot_response, state, [chatbot, state]
     )
 
-#
-demo.launch()
+# Launch Gradio app
+demo.launch()
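
One caveat in the configuration block above: if OPENAI_API_KEY is unset, os.getenv() returns None, and assigning None to os.environ raises a TypeError before the ValueError guard is ever reached. A minimal reordering sketch (not part of this commit) that validates first:

import os

# Read the key first and validate it; os.environ values must be strings,
# so assigning None (an unset variable) raises TypeError.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("Please set the 'OPENAI_API_KEY' environment variable")
os.environ["OPENAI_API_KEY"] = api_key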
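
For reference, transform_history_for_langchain reduces Gradio's tuple-style history to the (question, answer) pairs that ConversationalRetrievalChain expects as chat_history. A small illustrative check, with hypothetical sample messages:

# Tuple-style Gradio history: (user_message, assistant_reply) pairs.
sample_history = [
    ("Hello", "Hi there"),
    ("", "note with no user turn"),       # dropped: empty user slot
    ("What does the forum cover?", ""),   # kept: reply still pending
]
assert transform_history_for_langchain(sample_history) == [
    ("Hello", "Hi there"),
    ("What does the forum cover?", ""),
]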