ChienChung commited on
Commit
4fd8be5
·
verified ·
1 Parent(s): 5aaf550

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -102
app.py CHANGED
@@ -1452,122 +1452,76 @@ def build_langgraph_pipeline():
1452
  from tempfile import mkdtemp
1453
 
1454
 
1455
-
1456
- from pathlib import Path
1457
- import logging
1458
-
1459
- logger = logging.getLogger(__name__)
1460
-
1461
  def get_file_path_tab6(file):
1462
- """改進的文件路徑處理函數"""
1463
  try:
1464
- logger.debug(f"Processing file input: {file}")
1465
-
1466
- # 如果輸入為 None
1467
  if file is None:
1468
- logger.warning("File input is None")
1469
  return None
1470
 
1471
- # 處理字串輸入
 
 
 
 
 
 
 
 
 
 
1472
  if isinstance(file, str):
1473
- # 移除路徑中的特殊字符
1474
- safe_path = file.replace('\n', '').replace('\r', '').strip()
1475
-
1476
- # 檢查多個可能的路徑
1477
- potential_paths = [
1478
- safe_path,
1479
- os.path.join("/tmp/gradio/", safe_path),
1480
- os.path.join(os.getcwd(), safe_path)
1481
  ]
1482
 
1483
- for path in potential_paths:
1484
  if os.path.exists(path):
1485
- logger.info(f"Found file at: {path}")
1486
  return path
1487
 
1488
- logger.warning(f"Could not find file in any of these locations: {potential_paths}")
1489
- return None
1490
-
1491
- # 處理字典輸入(Gradio 上傳文件的常見格式)
1492
- elif isinstance(file, dict):
1493
- logger.debug("Processing dictionary input")
1494
-
1495
- # 獲取文件數據和名稱
1496
- data = file.get("data")
1497
- name = file.get("name", "uploaded_file")
1498
-
1499
- if data:
1500
- # 創建臨時目錄
1501
- temp_dir = mkdtemp()
1502
- file_path = os.path.join(temp_dir, name)
1503
-
1504
- # 寫入文件
1505
- try:
1506
- with open(file_path, "wb") as f:
1507
- if isinstance(data, str):
1508
- f.write(data.encode("utf-8"))
1509
- else:
1510
- f.write(data)
1511
-
1512
- if os.path.exists(file_path):
1513
- logger.info(f"Successfully created file at: {file_path}")
1514
- return file_path
1515
- except Exception as e:
1516
- logger.error(f"Error writing file: {e}")
1517
- return None
1518
-
1519
- logger.warning("No data found in file dictionary")
1520
- return None
1521
-
1522
- # 處理具有 save 方法的對象(如 UploadedFile)
1523
- elif hasattr(file, "save"):
1524
- logger.debug("Processing file object with save method")
1525
- try:
1526
- temp_dir = mkdtemp()
1527
- file_name = getattr(file, "name", "uploaded_file")
1528
- file_path = os.path.join(temp_dir, file_name)
1529
- file.save(file_path)
1530
-
1531
- if os.path.exists(file_path):
1532
- logger.info(f"Successfully saved file to: {file_path}")
1533
- return file_path
1534
- except Exception as e:
1535
- logger.error(f"Error saving file: {e}")
1536
- return None
1537
-
1538
- # 處理其他情況
1539
- else:
1540
- logger.warning(f"Unsupported file type: {type(file)}")
1541
- return None
1542
 
 
1543
  except Exception as e:
1544
- logger.error(f"Error in get_file_path_tab6: {e}")
1545
  return None
1546
 
1547
-
1548
-
1549
  def langgraph_tab6_main(query: str, file=None):
1550
  try:
 
1551
  files = file if isinstance(file, list) else [file] if file else []
1552
  all_docs = []
1553
  file_names = []
1554
  docs_by_file = []
1555
 
 
1556
  for f in files:
1557
- path = get_file_path_tab6(f)
1558
- if not path:
1559
- logger.warning(f"Could not process file: {f}")
1560
- continue
1561
-
1562
- logger.info(f"Processing file: {path}")
1563
  try:
 
 
 
 
 
 
 
 
 
1564
  if path.lower().endswith(".pdf"):
 
1565
  loader = PyPDFLoader(path)
1566
  elif path.lower().endswith(".docx"):
 
1567
  loader = UnstructuredWordDocumentLoader(path)
1568
  else:
 
1569
  loader = TextLoader(path)
1570
-
 
1571
  docs = loader.load()
1572
  if docs:
1573
  file_names.append(os.path.basename(path))
@@ -1577,26 +1531,28 @@ def langgraph_tab6_main(query: str, file=None):
1577
  text = "\n".join(docs)
1578
  docs_by_file.append(text)
1579
  all_docs.extend(docs)
1580
-
1581
  except Exception as e:
1582
- logger.error(f"Error loading file {path}: {e}")
1583
  continue
1584
 
1585
- # 如果沒有成功處理任何文件
1586
  if not all_docs:
1587
  return "No valid documents could be processed. Please check your file and try again."
1588
- else:
1589
- chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
1590
- db = FAISS.from_documents(chunks, embeddings)
1591
- retriever = db.as_retriever()
1592
- global session_retriever
1593
- session_retriever = retriever
1594
- global session_qa_chain
1595
- session_qa_chain = ConversationalRetrievalChain.from_llm(
1596
- llm=llm_gpt4,
1597
- retriever=retriever,
1598
- memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
1599
- )
 
 
1600
 
1601
  parsed = parse_query(query)
1602
  if (parsed.get("summarise") or parsed.get("compare")) and len(docs_by_file) > 0:
@@ -1607,14 +1563,16 @@ def langgraph_tab6_main(query: str, file=None):
1607
  state = {"query": query, "file_names": file_names}
1608
  if retriever is not None:
1609
  state["retriever"] = retriever
 
1610
  result = graph.invoke(state)
1611
  if "answer" in result:
1612
  return result["answer"]
1613
  if "summary" in result:
1614
  return result["summary"]
1615
  return "No answer."
 
1616
  except Exception as e:
1617
- logger.error(f"Error in main function: {e}")
1618
  return f"[Tab6 Error] {str(e)}"
1619
 
1620
  # Gradio Interface Settings
 
1452
  from tempfile import mkdtemp
1453
 
1454
 
 
 
 
 
 
 
1455
  def get_file_path_tab6(file):
1456
+ """改進的文件路徑處理函數,專門處理 Gradio 上傳的文件"""
1457
  try:
1458
+ # 如果是 None
 
 
1459
  if file is None:
 
1460
  return None
1461
 
1462
+ # 處理 Gradio 文件對象
1463
+ if hasattr(file, 'name'):
1464
+ return file.name
1465
+
1466
+ # 如果是字典(Gradio 文件上傳的另一種格式)
1467
+ if isinstance(file, dict):
1468
+ if 'name' in file:
1469
+ return file['name']
1470
+ return None
1471
+
1472
+ # 如果是字符串路徑
1473
  if isinstance(file, str):
1474
+ # 檢查常見的上傳路徑
1475
+ possible_paths = [
1476
+ file,
1477
+ os.path.join('/tmp/gradio/', file),
1478
+ os.path.join(os.getcwd(), file),
1479
+ os.path.abspath(file)
 
 
1480
  ]
1481
 
1482
+ for path in possible_paths:
1483
  if os.path.exists(path):
 
1484
  return path
1485
 
1486
+ # 如果找不到文件,返回原始路徑
1487
+ return file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1488
 
1489
+ return None
1490
  except Exception as e:
1491
+ print(f"Error in get_file_path_tab6: {e}")
1492
  return None
1493
 
 
 
1494
  def langgraph_tab6_main(query: str, file=None):
1495
  try:
1496
+ # 初始化文件處理
1497
  files = file if isinstance(file, list) else [file] if file else []
1498
  all_docs = []
1499
  file_names = []
1500
  docs_by_file = []
1501
 
1502
+ # 處理每個文件
1503
  for f in files:
 
 
 
 
 
 
1504
  try:
1505
+ # 獲取文件路徑
1506
+ path = get_file_path_tab6(f)
1507
+ if not path:
1508
+ print(f"Could not get valid path for file: {f}")
1509
+ continue
1510
+
1511
+ print(f"Attempting to process file: {path}")
1512
+
1513
+ # 根據文件類型選擇加載器
1514
  if path.lower().endswith(".pdf"):
1515
+ from langchain.document_loaders import PyPDFLoader
1516
  loader = PyPDFLoader(path)
1517
  elif path.lower().endswith(".docx"):
1518
+ from langchain.document_loaders import UnstructuredWordDocumentLoader
1519
  loader = UnstructuredWordDocumentLoader(path)
1520
  else:
1521
+ from langchain.document_loaders import TextLoader
1522
  loader = TextLoader(path)
1523
+
1524
+ # 加載文件
1525
  docs = loader.load()
1526
  if docs:
1527
  file_names.append(os.path.basename(path))
 
1531
  text = "\n".join(docs)
1532
  docs_by_file.append(text)
1533
  all_docs.extend(docs)
1534
+ print(f"Successfully processed file: {path}")
1535
  except Exception as e:
1536
+ print(f"Error processing file {f}: {e}")
1537
  continue
1538
 
1539
+ # 檢查是否有成功處理文件
1540
  if not all_docs:
1541
  return "No valid documents could be processed. Please check your file and try again."
1542
+
1543
+ # 其餘代碼保持不變...
1544
+ chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
1545
+ db = FAISS.from_documents(chunks, embeddings)
1546
+ retriever = db.as_retriever()
1547
+
1548
+ global session_retriever
1549
+ session_retriever = retriever
1550
+ global session_qa_chain
1551
+ session_qa_chain = ConversationalRetrievalChain.from_llm(
1552
+ llm=llm_gpt4,
1553
+ retriever=retriever,
1554
+ memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
1555
+ )
1556
 
1557
  parsed = parse_query(query)
1558
  if (parsed.get("summarise") or parsed.get("compare")) and len(docs_by_file) > 0:
 
1563
  state = {"query": query, "file_names": file_names}
1564
  if retriever is not None:
1565
  state["retriever"] = retriever
1566
+
1567
  result = graph.invoke(state)
1568
  if "answer" in result:
1569
  return result["answer"]
1570
  if "summary" in result:
1571
  return result["summary"]
1572
  return "No answer."
1573
+
1574
  except Exception as e:
1575
+ print(f"Error in main function: {e}")
1576
  return f"[Tab6 Error] {str(e)}"
1577
 
1578
  # Gradio Interface Settings