Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1452,122 +1452,76 @@ def build_langgraph_pipeline():
|
|
| 1452 |
from tempfile import mkdtemp
|
| 1453 |
|
| 1454 |
|
| 1455 |
-
|
| 1456 |
-
from pathlib import Path
|
| 1457 |
-
import logging
|
| 1458 |
-
|
| 1459 |
-
logger = logging.getLogger(__name__)
|
| 1460 |
-
|
| 1461 |
def get_file_path_tab6(file):
|
| 1462 |
-
"""改進的文件路徑處理函數"""
|
| 1463 |
try:
|
| 1464 |
-
|
| 1465 |
-
|
| 1466 |
-
# 如果輸入為 None
|
| 1467 |
if file is None:
|
| 1468 |
-
logger.warning("File input is None")
|
| 1469 |
return None
|
| 1470 |
|
| 1471 |
-
# 處理
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1472 |
if isinstance(file, str):
|
| 1473 |
-
#
|
| 1474 |
-
|
| 1475 |
-
|
| 1476 |
-
|
| 1477 |
-
|
| 1478 |
-
|
| 1479 |
-
os.path.join("/tmp/gradio/", safe_path),
|
| 1480 |
-
os.path.join(os.getcwd(), safe_path)
|
| 1481 |
]
|
| 1482 |
|
| 1483 |
-
for path in
|
| 1484 |
if os.path.exists(path):
|
| 1485 |
-
logger.info(f"Found file at: {path}")
|
| 1486 |
return path
|
| 1487 |
|
| 1488 |
-
|
| 1489 |
-
return
|
| 1490 |
-
|
| 1491 |
-
# 處理字典輸入(Gradio 上傳文件的常見格式)
|
| 1492 |
-
elif isinstance(file, dict):
|
| 1493 |
-
logger.debug("Processing dictionary input")
|
| 1494 |
-
|
| 1495 |
-
# 獲取文件數據和名稱
|
| 1496 |
-
data = file.get("data")
|
| 1497 |
-
name = file.get("name", "uploaded_file")
|
| 1498 |
-
|
| 1499 |
-
if data:
|
| 1500 |
-
# 創建臨時目錄
|
| 1501 |
-
temp_dir = mkdtemp()
|
| 1502 |
-
file_path = os.path.join(temp_dir, name)
|
| 1503 |
-
|
| 1504 |
-
# 寫入文件
|
| 1505 |
-
try:
|
| 1506 |
-
with open(file_path, "wb") as f:
|
| 1507 |
-
if isinstance(data, str):
|
| 1508 |
-
f.write(data.encode("utf-8"))
|
| 1509 |
-
else:
|
| 1510 |
-
f.write(data)
|
| 1511 |
-
|
| 1512 |
-
if os.path.exists(file_path):
|
| 1513 |
-
logger.info(f"Successfully created file at: {file_path}")
|
| 1514 |
-
return file_path
|
| 1515 |
-
except Exception as e:
|
| 1516 |
-
logger.error(f"Error writing file: {e}")
|
| 1517 |
-
return None
|
| 1518 |
-
|
| 1519 |
-
logger.warning("No data found in file dictionary")
|
| 1520 |
-
return None
|
| 1521 |
-
|
| 1522 |
-
# 處理具有 save 方法的對象(如 UploadedFile)
|
| 1523 |
-
elif hasattr(file, "save"):
|
| 1524 |
-
logger.debug("Processing file object with save method")
|
| 1525 |
-
try:
|
| 1526 |
-
temp_dir = mkdtemp()
|
| 1527 |
-
file_name = getattr(file, "name", "uploaded_file")
|
| 1528 |
-
file_path = os.path.join(temp_dir, file_name)
|
| 1529 |
-
file.save(file_path)
|
| 1530 |
-
|
| 1531 |
-
if os.path.exists(file_path):
|
| 1532 |
-
logger.info(f"Successfully saved file to: {file_path}")
|
| 1533 |
-
return file_path
|
| 1534 |
-
except Exception as e:
|
| 1535 |
-
logger.error(f"Error saving file: {e}")
|
| 1536 |
-
return None
|
| 1537 |
-
|
| 1538 |
-
# 處理其他情況
|
| 1539 |
-
else:
|
| 1540 |
-
logger.warning(f"Unsupported file type: {type(file)}")
|
| 1541 |
-
return None
|
| 1542 |
|
|
|
|
| 1543 |
except Exception as e:
|
| 1544 |
-
|
| 1545 |
return None
|
| 1546 |
|
| 1547 |
-
|
| 1548 |
-
|
| 1549 |
def langgraph_tab6_main(query: str, file=None):
|
| 1550 |
try:
|
|
|
|
| 1551 |
files = file if isinstance(file, list) else [file] if file else []
|
| 1552 |
all_docs = []
|
| 1553 |
file_names = []
|
| 1554 |
docs_by_file = []
|
| 1555 |
|
|
|
|
| 1556 |
for f in files:
|
| 1557 |
-
path = get_file_path_tab6(f)
|
| 1558 |
-
if not path:
|
| 1559 |
-
logger.warning(f"Could not process file: {f}")
|
| 1560 |
-
continue
|
| 1561 |
-
|
| 1562 |
-
logger.info(f"Processing file: {path}")
|
| 1563 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1564 |
if path.lower().endswith(".pdf"):
|
|
|
|
| 1565 |
loader = PyPDFLoader(path)
|
| 1566 |
elif path.lower().endswith(".docx"):
|
|
|
|
| 1567 |
loader = UnstructuredWordDocumentLoader(path)
|
| 1568 |
else:
|
|
|
|
| 1569 |
loader = TextLoader(path)
|
| 1570 |
-
|
|
|
|
| 1571 |
docs = loader.load()
|
| 1572 |
if docs:
|
| 1573 |
file_names.append(os.path.basename(path))
|
|
@@ -1577,26 +1531,28 @@ def langgraph_tab6_main(query: str, file=None):
|
|
| 1577 |
text = "\n".join(docs)
|
| 1578 |
docs_by_file.append(text)
|
| 1579 |
all_docs.extend(docs)
|
| 1580 |
-
|
| 1581 |
except Exception as e:
|
| 1582 |
-
|
| 1583 |
continue
|
| 1584 |
|
| 1585 |
-
#
|
| 1586 |
if not all_docs:
|
| 1587 |
return "No valid documents could be processed. Please check your file and try again."
|
| 1588 |
-
|
| 1589 |
-
|
| 1590 |
-
|
| 1591 |
-
|
| 1592 |
-
|
| 1593 |
-
|
| 1594 |
-
|
| 1595 |
-
|
| 1596 |
-
|
| 1597 |
-
|
| 1598 |
-
|
| 1599 |
-
|
|
|
|
|
|
|
| 1600 |
|
| 1601 |
parsed = parse_query(query)
|
| 1602 |
if (parsed.get("summarise") or parsed.get("compare")) and len(docs_by_file) > 0:
|
|
@@ -1607,14 +1563,16 @@ def langgraph_tab6_main(query: str, file=None):
|
|
| 1607 |
state = {"query": query, "file_names": file_names}
|
| 1608 |
if retriever is not None:
|
| 1609 |
state["retriever"] = retriever
|
|
|
|
| 1610 |
result = graph.invoke(state)
|
| 1611 |
if "answer" in result:
|
| 1612 |
return result["answer"]
|
| 1613 |
if "summary" in result:
|
| 1614 |
return result["summary"]
|
| 1615 |
return "No answer."
|
|
|
|
| 1616 |
except Exception as e:
|
| 1617 |
-
|
| 1618 |
return f"[Tab6 Error] {str(e)}"
|
| 1619 |
|
| 1620 |
# Gradio Interface Settings
|
|
|
|
| 1452 |
from tempfile import mkdtemp
|
| 1453 |
|
| 1454 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1455 |
def get_file_path_tab6(file):
|
| 1456 |
+
"""改進的文件路徑處理函數,專門處理 Gradio 上傳的文件"""
|
| 1457 |
try:
|
| 1458 |
+
# 如果是 None
|
|
|
|
|
|
|
| 1459 |
if file is None:
|
|
|
|
| 1460 |
return None
|
| 1461 |
|
| 1462 |
+
# 處理 Gradio 文件對象
|
| 1463 |
+
if hasattr(file, 'name'):
|
| 1464 |
+
return file.name
|
| 1465 |
+
|
| 1466 |
+
# 如果是字典(Gradio 文件上傳的另一種格式)
|
| 1467 |
+
if isinstance(file, dict):
|
| 1468 |
+
if 'name' in file:
|
| 1469 |
+
return file['name']
|
| 1470 |
+
return None
|
| 1471 |
+
|
| 1472 |
+
# 如果是字符串路徑
|
| 1473 |
if isinstance(file, str):
|
| 1474 |
+
# 檢查常見的上傳路徑
|
| 1475 |
+
possible_paths = [
|
| 1476 |
+
file,
|
| 1477 |
+
os.path.join('/tmp/gradio/', file),
|
| 1478 |
+
os.path.join(os.getcwd(), file),
|
| 1479 |
+
os.path.abspath(file)
|
|
|
|
|
|
|
| 1480 |
]
|
| 1481 |
|
| 1482 |
+
for path in possible_paths:
|
| 1483 |
if os.path.exists(path):
|
|
|
|
| 1484 |
return path
|
| 1485 |
|
| 1486 |
+
# 如果找不到文件,返回原始路徑
|
| 1487 |
+
return file
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1488 |
|
| 1489 |
+
return None
|
| 1490 |
except Exception as e:
|
| 1491 |
+
print(f"Error in get_file_path_tab6: {e}")
|
| 1492 |
return None
|
| 1493 |
|
|
|
|
|
|
|
| 1494 |
def langgraph_tab6_main(query: str, file=None):
|
| 1495 |
try:
|
| 1496 |
+
# 初始化文件處理
|
| 1497 |
files = file if isinstance(file, list) else [file] if file else []
|
| 1498 |
all_docs = []
|
| 1499 |
file_names = []
|
| 1500 |
docs_by_file = []
|
| 1501 |
|
| 1502 |
+
# 處理每個文件
|
| 1503 |
for f in files:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1504 |
try:
|
| 1505 |
+
# 獲取文件路徑
|
| 1506 |
+
path = get_file_path_tab6(f)
|
| 1507 |
+
if not path:
|
| 1508 |
+
print(f"Could not get valid path for file: {f}")
|
| 1509 |
+
continue
|
| 1510 |
+
|
| 1511 |
+
print(f"Attempting to process file: {path}")
|
| 1512 |
+
|
| 1513 |
+
# 根據文件類型選擇加載器
|
| 1514 |
if path.lower().endswith(".pdf"):
|
| 1515 |
+
from langchain.document_loaders import PyPDFLoader
|
| 1516 |
loader = PyPDFLoader(path)
|
| 1517 |
elif path.lower().endswith(".docx"):
|
| 1518 |
+
from langchain.document_loaders import UnstructuredWordDocumentLoader
|
| 1519 |
loader = UnstructuredWordDocumentLoader(path)
|
| 1520 |
else:
|
| 1521 |
+
from langchain.document_loaders import TextLoader
|
| 1522 |
loader = TextLoader(path)
|
| 1523 |
+
|
| 1524 |
+
# 加載文件
|
| 1525 |
docs = loader.load()
|
| 1526 |
if docs:
|
| 1527 |
file_names.append(os.path.basename(path))
|
|
|
|
| 1531 |
text = "\n".join(docs)
|
| 1532 |
docs_by_file.append(text)
|
| 1533 |
all_docs.extend(docs)
|
| 1534 |
+
print(f"Successfully processed file: {path}")
|
| 1535 |
except Exception as e:
|
| 1536 |
+
print(f"Error processing file {f}: {e}")
|
| 1537 |
continue
|
| 1538 |
|
| 1539 |
+
# 檢查是否有成功處理的文件
|
| 1540 |
if not all_docs:
|
| 1541 |
return "No valid documents could be processed. Please check your file and try again."
|
| 1542 |
+
|
| 1543 |
+
# 其餘代碼保持不變...
|
| 1544 |
+
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
|
| 1545 |
+
db = FAISS.from_documents(chunks, embeddings)
|
| 1546 |
+
retriever = db.as_retriever()
|
| 1547 |
+
|
| 1548 |
+
global session_retriever
|
| 1549 |
+
session_retriever = retriever
|
| 1550 |
+
global session_qa_chain
|
| 1551 |
+
session_qa_chain = ConversationalRetrievalChain.from_llm(
|
| 1552 |
+
llm=llm_gpt4,
|
| 1553 |
+
retriever=retriever,
|
| 1554 |
+
memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True),
|
| 1555 |
+
)
|
| 1556 |
|
| 1557 |
parsed = parse_query(query)
|
| 1558 |
if (parsed.get("summarise") or parsed.get("compare")) and len(docs_by_file) > 0:
|
|
|
|
| 1563 |
state = {"query": query, "file_names": file_names}
|
| 1564 |
if retriever is not None:
|
| 1565 |
state["retriever"] = retriever
|
| 1566 |
+
|
| 1567 |
result = graph.invoke(state)
|
| 1568 |
if "answer" in result:
|
| 1569 |
return result["answer"]
|
| 1570 |
if "summary" in result:
|
| 1571 |
return result["summary"]
|
| 1572 |
return "No answer."
|
| 1573 |
+
|
| 1574 |
except Exception as e:
|
| 1575 |
+
print(f"Error in main function: {e}")
|
| 1576 |
return f"[Tab6 Error] {str(e)}"
|
| 1577 |
|
| 1578 |
# Gradio Interface Settings
|