ChienChung committed on
Commit
9889ccb
·
verified ·
1 Parent(s): 78eeeca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -54,6 +54,7 @@ from geopy.geocoders import Nominatim
54
  from timezonefinder import TimezoneFinder
55
  from langchain_pinecone import Pinecone as PineconeLangchain
56
  import pinecone
 
57
 
58
  pinecone.init(
59
  api_key=os.environ["PINECONE_API_KEY"],
@@ -235,23 +236,31 @@ session_file_hash = None
235
  session_retriever = None
236
  session_qa_chain = None
237
 
238
- def upload_and_chat(file, query):
239
- global session_file_hash, session_retriever, session_qa_chain
 
 
 
 
 
 
 
 
 
 
 
 
240
 
241
  file_path = get_file_path(file)
242
  if file_path is None:
243
  return "Unable to obtain the uploaded file path."
244
 
245
- # 計算目前上傳檔案的 hash 值
246
  import hashlib
247
  with open(file_path, "rb") as f:
248
  file_hash = hashlib.md5(f.read()).hexdigest()
249
 
250
- # 如果是新文件 重建 retriever 和 chain
251
- if file_hash != session_file_hash:
252
- session_file_hash = file_hash
253
-
254
- # Load and chunk the new document
255
  if file_path.lower().endswith(".pdf"):
256
  loader = PyPDFLoader(file_path)
257
  elif file_path.lower().endswith(".docx"):
@@ -261,18 +270,19 @@ def upload_and_chat(file, query):
261
  docs = loader.load()
262
  chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
263
  pine_db = PineconeLangchain.from_documents(chunks, embeddings, index_name="rag-docs")
264
- session_retriever = pine_db.as_retriever()
265
 
266
- session_qa_chain = RetrievalQA.from_chain_type(
267
  llm=llm_gpt4,
268
- chain_type="stuff",
269
- retriever=session_retriever,
270
- return_source_documents=False,
271
- chain_type_kwargs={"prompt": custom_prompt}
272
  )
273
 
274
- # 用現有的 chain 執行 query
275
- return session_qa_chain.run(query)
 
 
276
 
277
  # tab 4 & 5 summary
278
  initial_prompt = PromptTemplate(
 
54
  from timezonefinder import TimezoneFinder
55
  from langchain_pinecone import Pinecone as PineconeLangchain
56
  import pinecone
57
+ from uuid import uuid4
58
 
59
  pinecone.init(
60
  api_key=os.environ["PINECONE_API_KEY"],
 
236
  session_retriever = None
237
  session_qa_chain = None
238
 
239
+ user_sessions = {} # 用 dict 儲存每個 user_id 對應的 chain、retriever、hash
240
+
241
+ def upload_and_chat(file, query, user_id=None):
242
+ if user_id is None:
243
+ user_id = str(uuid4()) # fallback
244
+
245
+ if user_id not in user_sessions:
246
+ user_sessions[user_id] = {
247
+ "file_hash": None,
248
+ "retriever": None,
249
+ "qa_chain": None
250
+ }
251
+
252
+ session = user_sessions[user_id]
253
 
254
  file_path = get_file_path(file)
255
  if file_path is None:
256
  return "Unable to obtain the uploaded file path."
257
 
 
258
  import hashlib
259
  with open(file_path, "rb") as f:
260
  file_hash = hashlib.md5(f.read()).hexdigest()
261
 
262
+ if file_hash != session["file_hash"]:
263
+ session["file_hash"] = file_hash
 
 
 
264
  if file_path.lower().endswith(".pdf"):
265
  loader = PyPDFLoader(file_path)
266
  elif file_path.lower().endswith(".docx"):
 
270
  docs = loader.load()
271
  chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(docs)
272
  pine_db = PineconeLangchain.from_documents(chunks, embeddings, index_name="rag-docs")
273
+ retriever = pine_db.as_retriever()
274
 
275
+ qa_chain = ConversationalRetrievalChain.from_llm(
276
  llm=llm_gpt4,
277
+ retriever=retriever,
278
+ memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
279
+ combine_docs_chain_kwargs={"prompt": custom_prompt}
 
280
  )
281
 
282
+ session["retriever"] = retriever
283
+ session["qa_chain"] = qa_chain
284
+
285
+ return session["qa_chain"].run(query)
286
 
287
  # tab 4 & 5 summary
288
  initial_prompt = PromptTemplate(