Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1241,7 +1241,7 @@ def multi_agent_chat_advanced(query: str, file=None) -> str:
|
|
| 1241 |
|
| 1242 |
|
| 1243 |
# Node functions used by LangGraph (they hand off to your Crew agents)
|
| 1244 |
-
|
| 1245 |
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 1246 |
|
| 1247 |
# Embedding-based intent classification (uploaded file names are also matched)
|
|
@@ -1253,13 +1253,11 @@ INTENT_LABELS = {
|
|
| 1253 |
|
| 1254 |
def detect_intent_embedding(query, file_names=[]):
|
| 1255 |
query_emb = embedding_model.encode(query, normalize_embeddings=True)
|
| 1256 |
-
|
| 1257 |
best_label = None
|
| 1258 |
best_score = -1
|
| 1259 |
all_phrases = INTENT_LABELS.copy()
|
| 1260 |
if file_names:
|
| 1261 |
all_phrases["DocQA"] += [name.lower() for name in file_names]
|
| 1262 |
-
|
| 1263 |
for label, examples in all_phrases.items():
|
| 1264 |
for example in examples:
|
| 1265 |
example_emb = embedding_model.encode(example, normalize_embeddings=True)
|
|
@@ -1275,12 +1273,40 @@ def decide_next(state):
|
|
| 1275 |
label = detect_intent_embedding(query, file_names)
|
| 1276 |
return label
|
| 1277 |
|
| 1278 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1279 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1280 |
def docqa_run(state):
    """Run the Document QA agent on the query; defer to ``general_run`` on a miss.

    Args:
        state: Mapping that provides the user question under ``"query"``.

    Returns:
        ``{"answer": <agent output>}`` when the agent produced a usable reply,
        otherwise whatever ``general_run(state)`` returns.
    """
    # NOTE(review): confirm that the installed CrewAI version's
    # ``Agent.execute_task`` accepts ``inputs=`` and returns an object
    # exposing ``.output``.
    doc_result = document_qa_agent.execute_task(
        docqa_task, inputs={"query": state["query"]}
    )
    lowered = doc_result.output.lower()

    # Fallback: phrases that signal the document held no answer.
    miss_markers = ("no relevant info", "not found", "no answer")
    found_nothing = any(marker in lowered for marker in miss_markers)
    if found_nothing:
        return general_run(state)
    return {"answer": doc_result.output}
|
|
@@ -1288,47 +1314,42 @@ def docqa_run(state):
|
|
| 1288 |
def general_run(state):
    """Answer the query with the General agent; retry with Search if it is unsure.

    Args:
        state: Mapping that provides the user question under ``"query"``.

    Returns:
        ``{"answer": <output>}`` where the output comes from the Search agent
        when the General agent's reply contains an uncertainty phrase, and
        from the General agent otherwise.
    """
    query_inputs = {"query": state["query"]}
    reply = general_agent.execute_task(general_task, inputs=query_inputs)
    lowered = reply.output.lower()

    unsure_phrases = ("i don't know", "no idea", "not sure", "can't answer")
    is_unsure = any(phrase in lowered for phrase in unsure_phrases)
    if is_unsure:
        reply = search_agent.execute_task(
            search_task, inputs={"query": state["query"]}
        )
    return {"answer": reply.output}
|
| 1294 |
|
| 1295 |
def summariser_run(state):
|
| 1296 |
-
result =
|
| 1297 |
return {"summary": result.output}
|
| 1298 |
|
| 1299 |
-
#
|
| 1300 |
-
|
| 1301 |
def build_langgraph_pipeline():
    """Assemble and compile the Router -> {DocQA, General, Summarise} graph.

    The "Router" node passes the state through unchanged; ``decide_next``
    picks which worker node runs next, and every worker node is registered
    as a finish point.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    workflow = StateGraph(dict)

    # The router only forwards the state; routing happens on its outgoing edges.
    workflow.add_node("Router", lambda state: state)

    workers = {
        "DocQA": docqa_run,
        "General": general_run,
        "Summarise": summariser_run,
    }
    for node_name, handler in workers.items():
        workflow.add_node(node_name, handler)

    workflow.set_entry_point("Router")
    # Each label returned by decide_next maps to the node of the same name.
    workflow.add_conditional_edges(
        "Router", decide_next, {node_name: node_name for node_name in workers}
    )

    for node_name in workers:
        workflow.set_finish_point(node_name)

    return workflow.compile()
|
| 1321 |
|
| 1322 |
-
#
|
| 1323 |
-
|
| 1324 |
def langgraph_tab6_main(query: str, file=None):
|
| 1325 |
try:
|
| 1326 |
files = file if isinstance(file, list) else [file] if file else []
|
| 1327 |
all_docs, file_names = [], []
|
| 1328 |
-
|
| 1329 |
for f in files:
|
| 1330 |
path = get_file_path(f)
|
| 1331 |
-
if not path:
|
|
|
|
| 1332 |
file_names.append(os.path.basename(path))
|
| 1333 |
if path.lower().endswith(".pdf"):
|
| 1334 |
loader = PyPDFLoader(path)
|
|
@@ -1337,21 +1358,16 @@ def langgraph_tab6_main(query: str, file=None):
|
|
| 1337 |
else:
|
| 1338 |
loader = TextLoader(path)
|
| 1339 |
all_docs.extend(loader.load())
|
| 1340 |
-
|
| 1341 |
if not all_docs:
|
| 1342 |
retriever = None # 空 retriever
|
| 1343 |
else:
|
| 1344 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
|
| 1345 |
db = FAISS.from_documents(chunks, embeddings)
|
| 1346 |
retriever = db.as_retriever()
|
| 1347 |
-
|
| 1348 |
-
# 設定 retriever 到 global Agent,如果你需要可傳給 Agent
|
| 1349 |
-
# 可選:document_qa_agent.retriever = retriever
|
| 1350 |
-
|
| 1351 |
graph = build_langgraph_pipeline()
|
| 1352 |
state = {"query": query, "file_names": file_names}
|
| 1353 |
result = graph.invoke(state)
|
| 1354 |
-
|
| 1355 |
if "answer" in result:
|
| 1356 |
return result["answer"]
|
| 1357 |
if "summary" in result:
|
|
|
|
| 1241 |
|
| 1242 |
|
| 1243 |
# Node functions used by LangGraph (they hand off to your Crew agents)
|
| 1244 |
+
# Initialise the embedding model
|
| 1245 |
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 1246 |
|
| 1247 |
# Embedding-based intent classification (uploaded file names are also matched)
|
|
|
|
| 1253 |
|
| 1254 |
def detect_intent_embedding(query, file_names=[]):
|
| 1255 |
query_emb = embedding_model.encode(query, normalize_embeddings=True)
|
|
|
|
| 1256 |
best_label = None
|
| 1257 |
best_score = -1
|
| 1258 |
all_phrases = INTENT_LABELS.copy()
|
| 1259 |
if file_names:
|
| 1260 |
all_phrases["DocQA"] += [name.lower() for name in file_names]
|
|
|
|
| 1261 |
for label, examples in all_phrases.items():
|
| 1262 |
for example in examples:
|
| 1263 |
example_emb = embedding_model.encode(example, normalize_embeddings=True)
|
|
|
|
| 1273 |
label = detect_intent_embedding(query, file_names)
|
| 1274 |
return label
|
| 1275 |
|
| 1276 |
+
# === Task definitions ===
def _query_task(description, expected_output, agent):
    """Build a Task bound to ``agent`` that declares "query" as its input variable."""
    # NOTE(review): confirm that the installed CrewAI ``Task`` model accepts
    # an ``input_variables`` argument.
    return Task(
        description=description,
        expected_output=expected_output,
        agent=agent,
        input_variables=["query"],
    )


docqa_task = _query_task(
    "Document QA Task: Answer questions based on the uploaded document.",
    "Answer from Document QA Agent.",
    document_qa_agent,
)

general_task = _query_task(
    "General Chat Task: Answer general queries.",
    "Answer from General Agent.",
    general_agent,
)

# Bound to summarizer_agent (note the "z" in the agent's name).
summariser_task = _query_task(
    "Summarisation Task: Summarise document content.",
    "Summary output.",
    summarizer_agent,
)

search_task = _query_task(
    "Search Task: Retrieve information from the web.",
    "Answer from Search Agent.",
    search_agent,
)
|
| 1304 |
+
|
| 1305 |
+
# === LangGraph node functions ===
|
| 1306 |
def docqa_run(state):
    """Answer the query from the uploaded document, falling back to General.

    Args:
        state: Mapping that provides the user question under ``"query"``.

    Returns:
        ``{"answer": <agent output>}`` unless the agent's reply signals that
        nothing was found, in which case the result of ``general_run(state)``
        is returned instead.
    """
    # NOTE(review): confirm that the installed CrewAI version's
    # ``Agent.execute_task`` accepts ``inputs=`` and returns an object
    # exposing ``.output``.
    doc_result = document_qa_agent.execute_task(
        docqa_task, inputs={"query": state["query"]}
    )
    lowered = doc_result.output.lower()

    # Fallback: if the reply contains "no relevant info", "not found" or
    # "no answer", hand the query to the General module.
    for marker in ("no relevant info", "not found", "no answer"):
        if marker in lowered:
            return general_run(state)

    return {"answer": doc_result.output}
|
|
|
|
| 1314 |
def general_run(state):
    """Answer the query with the General agent, escalating to Search if needed.

    Args:
        state: Mapping that provides the user question under ``"query"``.

    Returns:
        ``{"answer": <output>}``; the output is the Search agent's when the
        General agent's reply contains an uncertainty phrase, otherwise the
        General agent's own.
    """
    reply = general_agent.execute_task(
        general_task, inputs={"query": state["query"]}
    )
    lowered = reply.output.lower()

    # If the General agent gave a weak answer, switch to the Search module.
    needs_search = False
    for phrase in ("i don't know", "no idea", "not sure", "can't answer"):
        if phrase in lowered:
            needs_search = True
            break

    if needs_search:
        reply = search_agent.execute_task(
            search_task, inputs={"query": state["query"]}
        )
    return {"answer": reply.output}
|
| 1321 |
|
| 1322 |
def summariser_run(state):
    """Run the summariser agent on the query and wrap its output.

    Args:
        state: Mapping that provides the user question under ``"query"``.

    Returns:
        ``{"summary": <agent output>}``.
    """
    outcome = summarizer_agent.execute_task(
        summariser_task,
        inputs={"query": state["query"]},
    )
    summary_text = outcome.output
    return {"summary": summary_text}
|
| 1325 |
|
| 1326 |
+
# === LangGraph graph definition ===
|
|
|
|
| 1327 |
def build_langgraph_pipeline():
    """Build and compile the routing graph used by the LangGraph tab.

    Layout: a pass-through "Router" entry node, followed by one of the
    "DocQA", "General" or "Summarise" nodes chosen by ``decide_next``.
    Each of those three nodes is registered as a finish point.

    Returns:
        The object returned by ``StateGraph.compile()``.
    """
    pipeline = StateGraph(dict)

    # Router only passes the state along; the branch is chosen on its edges.
    pipeline.add_node("Router", lambda state: state)

    branch_handlers = (
        ("DocQA", docqa_run),
        ("General", general_run),
        ("Summarise", summariser_run),
    )
    for label, handler in branch_handlers:
        pipeline.add_node(label, handler)

    pipeline.set_entry_point("Router")

    # decide_next returns a label; each label routes to the same-named node.
    routes = {label: label for label, _ in branch_handlers}
    pipeline.add_conditional_edges("Router", decide_next, routes)

    for label, _ in branch_handlers:
        pipeline.set_finish_point(label)

    return pipeline.compile()
|
| 1343 |
|
| 1344 |
+
# === Main entry point (Tab 6) ===
|
|
|
|
| 1345 |
def langgraph_tab6_main(query: str, file=None):
|
| 1346 |
try:
|
| 1347 |
files = file if isinstance(file, list) else [file] if file else []
|
| 1348 |
all_docs, file_names = [], []
|
|
|
|
| 1349 |
for f in files:
|
| 1350 |
path = get_file_path(f)
|
| 1351 |
+
if not path:
|
| 1352 |
+
continue
|
| 1353 |
file_names.append(os.path.basename(path))
|
| 1354 |
if path.lower().endswith(".pdf"):
|
| 1355 |
loader = PyPDFLoader(path)
|
|
|
|
| 1358 |
else:
|
| 1359 |
loader = TextLoader(path)
|
| 1360 |
all_docs.extend(loader.load())
|
|
|
|
| 1361 |
if not all_docs:
|
| 1362 |
retriever = None # 空 retriever
|
| 1363 |
else:
|
| 1364 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
|
| 1365 |
db = FAISS.from_documents(chunks, embeddings)
|
| 1366 |
retriever = db.as_retriever()
|
| 1367 |
+
# 可選:設定 retriever 到 global Agent(若需要)
|
|
|
|
|
|
|
|
|
|
| 1368 |
graph = build_langgraph_pipeline()
|
| 1369 |
state = {"query": query, "file_names": file_names}
|
| 1370 |
result = graph.invoke(state)
|
|
|
|
| 1371 |
if "answer" in result:
|
| 1372 |
return result["answer"]
|
| 1373 |
if "summary" in result:
|