Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import os | |
| from langchain.chains import ConversationalRetrievalChain | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain_community.document_loaders import PyPDFLoader, Docx2txtLoader, TextLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from dotenv import load_dotenv | |
# Load environment variables from a local .env file (if present).
load_dotenv()

# Validate the OpenAI API key BEFORE exporting it: assigning None to
# os.environ raises a confusing TypeError, masking the intended ValueError
# below when the variable is missing.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise ValueError("Please set the 'OPENAI_API_KEY' environment variable")

# Make the key visible to libraries that read it from the environment.
os.environ["OPENAI_API_KEY"] = api_key

# OpenAI API key (alias kept for backward compatibility with existing readers)
openai_api_key = api_key
def transform_history_for_langchain(history):
    """Convert Gradio chat history into LangChain (question, answer) pairs.

    Entries whose user side is empty/falsy are dropped; the assistant side
    is passed through unchanged (it may be None or "").
    """
    pairs = []
    for entry in history:
        if entry[0]:
            pairs.append((entry[0], entry[1]))
    return pairs
def transform_history_for_openai(history):
    """Flatten Gradio (user, assistant) pairs into OpenAI chat messages.

    Empty/falsy sides are omitted; within each pair the user message
    precedes the assistant message.
    """
    messages = []
    for entry in history:
        user_text, assistant_text = entry[0], entry[1]
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})
    return messages
# Load and process documents function
def load_and_process_documents(folder_path):
    """Load PDF/DOCX/DOC/TXT files from *folder_path* into a Chroma store.

    Files are split into ~1000-character chunks and embedded with OpenAI
    embeddings; the store is persisted under ./tmp.

    Parameters:
        folder_path: directory scanned (non-recursively) for documents.

    Returns:
        A Chroma vector store built from the loaded document chunks.
    """
    # Dispatch table: lowercase extension -> loader class.
    loaders = {
        ".pdf": PyPDFLoader,
        ".docx": Docx2txtLoader,
        ".doc": Docx2txtLoader,
        ".txt": TextLoader,
    }
    documents = []
    for file in os.listdir(folder_path):
        file_path = os.path.join(folder_path, file)
        # Skip subdirectories (a directory named e.g. "x.pdf" would
        # otherwise crash the loader).
        if not os.path.isfile(file_path):
            continue
        # Match extensions case-insensitively so "REPORT.PDF" is picked up.
        ext = os.path.splitext(file)[1].lower()
        loader_cls = loaders.get(ext)
        if loader_cls is not None:
            documents.extend(loader_cls(file_path).load())
    # Small overlap keeps neighbouring chunks loosely connected.
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=10)
    documents = text_splitter.split_documents(documents)
    vectordb = Chroma.from_documents(
        documents,
        embedding=OpenAIEmbeddings(),
        persist_directory="./tmp",
    )
    return vectordb
# Build the vector store only once per process (EAFP: probe the name
# directly instead of checking globals()); re-executions reuse the
# existing instance.
try:
    vectordb
except NameError:
    vectordb = load_and_process_documents("./")
# Define query handling function for RAG
def handle_query(user_message, temperature, chat_history):
    """Answer *user_message* via RAG over the global vector store.

    Parameters:
        user_message: the question text; empty/falsy input is a no-op.
        temperature: sampling temperature passed to the chat model.
        chat_history: list of (user, assistant) tuples; mutated in place so
            its last entry carries the new answer.

    Returns:
        The updated chat history. Failures are reported in-band as appended
        ("System", message) entries rather than raised.
    """
    try:
        if not user_message:
            return chat_history  # Nothing to answer; leave history untouched.

        # Use LangChain's ConversationalRetrievalChain to handle the query.
        # Preface kept verbatim: answer in Traditional Chinese, AI-forum
        # topics only, with a fixed deflection for off-topic questions.
        preface = """
        Instruction: Answer in Traditional Chinese, within 200 characters.這是AI論壇,只回答AI相關問題
        If the question is unrelated to the documents, respond with: 此事無可奉告,話說這件事須請教海虔王...
        """
        query = f"{preface} Query content: {user_message}"

        # Extract previous answers as context, converting them to LangChain format.
        previous_answers = transform_history_for_langchain(chat_history)

        pdf_qa = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=temperature, model_name='gpt-4'),
            retriever=vectordb.as_retriever(search_kwargs={'k': 6}),
            return_source_documents=True,
            verbose=False
        )
        # Invoke the model to handle the query.
        result = pdf_qa.invoke({"question": query, "chat_history": previous_answers})

        # Ensure 'answer' is present in the result.
        if "answer" not in result:
            return chat_history + [("System", "Sorry, an error occurred.")]

        # Pair the user input with the AI response. The UI normally appends a
        # (user_message, "") placeholder that we overwrite; if the history is
        # empty (direct callers), append instead — previously this indexed
        # past the end and surfaced as a generic IndexError message.
        if chat_history:
            chat_history[-1] = (user_message, result["answer"])
        else:
            chat_history.append((user_message, result["answer"]))
        return chat_history
    except Exception as e:
        # Best-effort UI: report the failure in-band rather than crashing Gradio.
        return chat_history + [("System", f"An error occurred: {str(e)}")]
# Build the chat UI with Gradio's Blocks API.
with gr.Blocks() as demo:
    gr.Markdown("<h1 style='text-align: center;'>AI Assistant for AI Forum</h1>")
    chatbot = gr.Chatbot()
    state = gr.State([])

    with gr.Row():
        with gr.Column(scale=0.85):
            txt = gr.Textbox(show_label=False, placeholder="Please enter your question...")
        with gr.Column(scale=0.15, min_width=0):
            submit_btn = gr.Button("Ask")

    def user_input(user_message, history):
        """Echo the user's message at once (empty bot slot) and clear the box."""
        history.append((user_message, ""))
        return history, "", history

    def bot_response(history):
        """Fill in the assistant reply for the most recent user message."""
        latest_question = history[-1][0]
        updated = handle_query(latest_question, 0.7, history)
        return updated, updated

    # Button click and Enter-to-submit share identical wiring: first show the
    # user's message and clear the input (unqueued for snappy feedback), then
    # chain the AI response.
    for trigger in (submit_btn.click, txt.submit):
        trigger(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
            bot_response, state, [chatbot, state]
        )

# Launch Gradio app
demo.launch()