ziphai committed on
Commit
facc243
·
verified ·
1 Parent(s): a53a4f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +149 -94
app.py CHANGED
@@ -4,40 +4,29 @@ import gradio as gr
4
  from langchain.chains import ConversationalRetrievalChain
5
  from langchain.text_splitter import CharacterTextSplitter
6
  from langchain_community.document_loaders import PyMuPDFLoader, PyPDFLoader
7
- from langchain_community.vectorstores import Chroma
8
  from langchain_community.embeddings import OpenAIEmbeddings
9
  from langchain_community.chat_models import ChatOpenAI
10
- from dotenv import load_dotenv
11
-
12
- # Load environment variables
13
- load_dotenv()
14
- os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
15
-
16
- # Validate OpenAI API Key
17
- api_key = os.getenv('OPENAI_API_KEY')
18
- if not api_key:
19
- raise ValueError("Please set the 'OPENAI_API_KEY' environment variable")
20
-
21
- # OpenAI API key
22
- openai_api_key = api_key
23
-
24
- # Transform chat history for LangChain format
25
- def transform_history_for_langchain(history):
26
- return [(chat[0], chat[1]) for chat in history if chat[0]]
27
-
28
- # Transform chat history for OpenAI format
29
- def transform_history_for_openai(history):
30
- new_history = []
31
- for chat in history:
32
- if chat[0]:
33
- new_history.append({"role": "user", "content": chat[0]})
34
- if chat[1]:
35
- new_history.append({"role": "assistant", "content": chat[1]})
36
- return new_history
37
-
38
- # Load and process documents function
39
- def load_and_process_documents(file_paths, loader_type='PyMuPDFLoader'):
40
  documents = []
 
41
  for file_path in file_paths:
42
  if not os.path.exists(file_path):
43
  continue
@@ -55,99 +44,165 @@ def load_and_process_documents(file_paths, loader_type='PyMuPDFLoader'):
55
  continue
56
 
57
  if not documents:
58
- raise ValueError("No documents found or could not load any documents.")
59
 
60
- # Split long texts
61
  text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
62
  documents = text_splitter.split_documents(documents)
63
 
64
  if not documents:
65
- raise ValueError("Document list is empty after splitting. Please check the content of the files.")
66
 
67
- # Initialize vector database
68
- vectordb = Chroma.from_documents(
69
- documents,
70
- embedding=OpenAIEmbeddings(),
71
- persist_directory="./tmp"
72
- )
73
- return vectordb
74
 
75
- # Initialize vector database as a global variable
76
- if 'vectordb' not in globals():
77
- vectordb = load_and_process_documents(["./sample_docs/sample.pdf"])
 
 
 
 
 
 
 
78
 
79
- # Define query handling function for RAG
80
- def handle_query(user_message, temperature, chat_history):
81
  try:
82
  if not user_message:
83
- return chat_history # Return unchanged chat history
84
 
85
- # Use LangChain's ConversationalRetrievalChain to handle the query
86
  preface = """
87
  Instruction: Answer in Traditional Chinese, within 200 characters.
88
  If the question is unrelated to the documents, respond with: 此事無可奉告,話說這件事須請教海虔王...
89
  """
90
- query = f"{preface} Query content: {user_message}"
91
-
92
- # Extract previous answers as context, converting them to LangChain format
93
- previous_answers = transform_history_for_langchain(chat_history)
94
 
 
95
  pdf_qa = ConversationalRetrievalChain.from_llm(
96
- ChatOpenAI(temperature=temperature, model_name='gpt-4'),
97
  retriever=vectordb.as_retriever(search_kwargs={'k': 6}),
98
- return_source_documents=True,
99
- verbose=False
100
  )
101
 
102
- # Invoke the model to handle the query
103
- result = pdf_qa.invoke({"question": query, "chat_history": previous_answers})
104
-
105
- # Ensure 'answer' is present in the result
106
- if "answer" not in result:
107
- return chat_history + [("System", "Sorry, an error occurred.")]
108
-
109
- # Update the AI response in chat history
110
- chat_history[-1] = (user_message, result["answer"]) # Update the last record, pairing user input with AI response
111
 
 
 
 
 
112
  return chat_history
113
 
114
  except Exception as e:
115
- return chat_history + [("System", f"An error occurred: {str(e)}")]
116
-
117
- # Create a custom chat interface using Gradio Blocks API
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  with gr.Blocks(css="body { background-color: #EBD6D6; }") as demo:
119
- gr.Markdown("<h1 style='text-align: center;'>AI Assistant for AI Forum</h1>")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
 
 
 
121
  chatbot = gr.Chatbot()
122
- state = gr.State([])
123
-
124
- with gr.Row():
125
- with gr.Column(scale=0.85):
126
- txt = gr.Textbox(show_label=False, placeholder="Please enter your question...")
127
- with gr.Column(scale=0.15, min_width=0):
128
- submit_btn = gr.Button("Ask")
129
-
130
- # Immediately show user input without response part, and clear input box
131
- def user_input(user_message, history):
132
- history.append((user_message, "")) # Show user message, response part as empty string
133
- return history, "", history # Return cleared input box and updated chat history
134
-
135
- # Handle AI response, update response part
136
- def bot_response(history):
137
- user_message = history[-1][0] # Get the latest user input
138
- history = handle_query(user_message, 0.7, history) # Call the query handler
139
- return history, history # Return updated chat history
140
-
141
- # First show user message, then handle AI response, clear input box
142
- submit_btn.click(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
143
- bot_response, state, [chatbot, state]
144
  )
145
 
146
- # Support pressing "Enter" to submit question, immediately show user input, clear input box
147
- txt.submit(user_input, [txt, state], [chatbot, txt, state], queue=False).then(
148
- bot_response, state, [chatbot, state]
 
149
  )
150
 
151
- # Launch Gradio app
152
  demo.launch()
153
 
 
4
  from langchain.chains import ConversationalRetrievalChain
5
  from langchain.text_splitter import CharacterTextSplitter
6
  from langchain_community.document_loaders import PyMuPDFLoader, PyPDFLoader
7
+ from langchain.vectorstores import Chroma
8
  from langchain_community.embeddings import OpenAIEmbeddings
9
  from langchain_community.chat_models import ChatOpenAI
10
+ import shutil # 用於文件複製
11
+
12
# Read a fixed OpenAI API key from the environment if one is set; otherwise
# fall back to per-user keys supplied through the UI (see save_api_key).
api_key_env = os.getenv("OPENAI_API_KEY")
if api_key_env:
    openai.api_key = api_key_env
else:
    print("未設置固定的 OpenAI API 密鑰。將使用使用者提供的密鑰。")

# Ensure the vector-store directory exists and is writable.
# NOTE(review): this same directory also receives uploaded PDFs (see
# process_files) and the Chroma persist files — confirm that mixing is intended.
VECTORDB_DIR = os.path.abspath("./data")
os.makedirs(VECTORDB_DIR, exist_ok=True)
os.chmod(VECTORDB_DIR, 0o755)
23
+
24
+ # 定義載入和處理 PDF 文件的函數
25
+ def load_and_process_documents(file_paths, loader_type='PyMuPDFLoader', api_key=None):
26
+ if not api_key:
27
+ raise ValueError("未提供 OpenAI API 密鑰。")
 
 
 
 
 
 
 
 
 
 
 
 
28
  documents = []
29
+
30
  for file_path in file_paths:
31
  if not os.path.exists(file_path):
32
  continue
 
44
  continue
45
 
46
  if not documents:
47
+ raise ValueError("沒有找到任何 PDF 文件或 PDF 文件無法載入。")
48
 
49
+ # 分割長文本
50
  text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=50)
51
  documents = text_splitter.split_documents(documents)
52
 
53
  if not documents:
54
+ raise ValueError("分割後的文檔列表為空。請檢查 PDF 文件內容。")
55
 
56
+ # 初始化向量資料庫
57
+ try:
58
+ embeddings = OpenAIEmbeddings(openai_api_key=api_key) # 使用使用者的 API 密鑰
59
+ except Exception as e:
60
+ raise ValueError(f"初始化 OpenAIEmbeddings 時出現錯誤: {e}")
 
 
61
 
62
+ try:
63
+ vectordb = Chroma.from_documents(
64
+ documents,
65
+ embedding=embeddings,
66
+ persist_directory=VECTORDB_DIR
67
+ )
68
+ except Exception as e:
69
+ raise ValueError(f"初始化 Chroma 向量資料庫時出現錯誤: {e}")
70
+
71
+ return vectordb
72
 
73
def handle_query(user_message, chat_history, vectordb, api_key):
    """Answer *user_message* against the indexed documents via RAG.

    Builds a ConversationalRetrievalChain over *vectordb* with the caller's
    OpenAI key, appends the (question, answer) pair to *chat_history*, and
    returns the updated history.  Failures are reported as a ("系統", ...)
    history entry instead of raising, so the UI never crashes.
    """
    try:
        # An empty/None message is a no-op: hand the history back untouched.
        if not user_message:
            return chat_history

        # Role and format instructions prepended to every query.
        preface = """
        Instruction: Answer in Traditional Chinese, within 200 characters.
        If the question is unrelated to the documents, respond with: 此事無可奉告,話說這件事須請教海虔王...
        """
        full_query = f"{preface} 查詢內容:{user_message}"

        # Retrieval chain over the user's vector store (top-6 chunks),
        # authenticated with the per-user API key.
        chain = ConversationalRetrievalChain.from_llm(
            ChatOpenAI(temperature=0.7, model="gpt-4", openai_api_key=api_key),
            retriever=vectordb.as_retriever(search_kwargs={'k': 6}),
            return_source_documents=True
        )

        outcome = chain.invoke({"question": full_query, "chat_history": chat_history})

        reply = outcome["answer"] if "answer" in outcome else "抱歉,未能獲得有效回應。"
        return chat_history + [(user_message, reply)]

    except Exception as e:
        return chat_history + [("系統", f"出現錯誤: {str(e)}")]
104
+
105
def save_api_key(api_key, state):
    """Validate and store the user's OpenAI API key in the session *state*.

    Returns a (status_message, state) pair for the Gradio outputs.  A key is
    accepted only if it looks like an OpenAI secret key ("sk-...").

    Fix: guard against a None/empty textbox value before calling
    ``startswith`` — the original raised AttributeError on None input.
    """
    if not api_key or not api_key.startswith("sk-"):
        return "請輸入有效的 OpenAI API 密鑰。", state
    state['api_key'] = api_key
    return "API 密鑰已成功保存。您現在可以上傳 PDF 文件並開始提問。", state
111
+
112
def process_files(files, state):
    """Persist uploaded PDF bytes to disk, build the vector store, update state.

    *files* is a list of raw file contents (Gradio File with type="binary").
    Returns a (status_message, state) pair; errors surface as status text
    rather than exceptions so the UI stays responsive.
    """
    # Guard clauses: need at least one file and a previously saved API key.
    if not files:
        return "請上傳至少一個 PDF 文件。", state

    api_key = state.get('api_key', None)
    if not api_key:
        return "請先輸入並保存您的 OpenAI API 密鑰。", state

    try:
        saved_paths = []
        # Write each upload into the data directory under a predictable name.
        for position, raw_bytes in enumerate(files):
            target = os.path.join(VECTORDB_DIR, f"uploaded_{position}.pdf")
            with open(target, "wb") as handle:
                handle.write(raw_bytes)
            saved_paths.append(target)

        # Index the saved PDFs and stash the resulting store in session state.
        state['vectordb'] = load_and_process_documents(
            saved_paths, loader_type='PyMuPDFLoader', api_key=api_key
        )
        return "PDF 文件已成功上傳並處理。您現在可以開始提問。", state
    except Exception as e:
        return f"處理文件時出現錯誤: {e}", state
134
+
135
def chat_interface(user_message, chat_history, state):
    """Gradio callback: route a question to the RAG chain.

    Returns (chat_history, state, textbox_value).  The third element is
    written back into the input textbox — a hint string when prerequisites
    are missing, or "" to clear the box after a successful query.
    """
    store = state.get('vectordb', None)
    key = state.get('api_key', None)

    # Chatting requires both an indexed document set and a saved API key.
    if not store:
        return chat_history, state, "請先上傳 PDF 文件以進行處理。"
    if not key:
        return chat_history, state, "請先輸入並保存您的 OpenAI API 密鑰。"

    return handle_query(user_message, chat_history, store, key), state, ""
145
+
146
# Build the Gradio interface.
with gr.Blocks(css="body { background-color: #EBD6D6; }") as demo:
    gr.Markdown("<h1 style='text-align: center;'>AI論壇助理</h1>")

    # Per-session state: the Chroma vector store and the user's API key.
    state = gr.State({"vectordb": None, "api_key": None})

    # OpenAI API key entry (masked) with an explicit save button + status box.
    api_key_input = gr.Textbox(
        label="輸入您的 OpenAI API 密鑰",
        placeholder="sk-...",
        type="password",
        interactive=True
    )
    save_api_key_btn = gr.Button("保存 API 密鑰")
    api_key_status = gr.Textbox(label="狀態", interactive=False)

    # PDF upload widgets; type="binary" delivers raw bytes to process_files.
    gr.Markdown("<span style='font-size: 1.5em; font-weight: bold;'>請上傳AI論壇相關文檔,提供AI相關問題解答</span>")
    upload = gr.File(
        file_count="multiple",
        file_types=[".pdf"],
        label="上傳AI論壇 PDF 文件",
        interactive=True,
        type="binary"
    )
    upload_btn = gr.Button("上傳並處理")
    upload_status = gr.Textbox(label="上傳狀態", interactive=False)

    # Chat area.
    gr.Markdown("### AI論壇助理")
    chatbot = gr.Chatbot()

    txt = gr.Textbox(show_label=False, placeholder="請輸入您的AI問題...")
    submit_btn = gr.Button("提問")

    # Event wiring: save key, process uploads, and answer questions
    # (both the button click and pressing Enter in the textbox submit).
    save_api_key_btn.click(
        save_api_key,
        inputs=[api_key_input, state],
        outputs=[api_key_status, state]
    )

    upload_btn.click(
        process_files,
        inputs=[upload, state],
        outputs=[upload_status, state]
    )

    submit_btn.click(
        chat_interface,
        inputs=[txt, chatbot, state],
        outputs=[chatbot, state, txt]
    )

    txt.submit(
        chat_interface,
        inputs=[txt, chatbot, state],
        outputs=[chatbot, state, txt]
    )

# Launch the Gradio app.
demo.launch()
208