ChienChung commited on
Commit
bb18fa0
·
verified ·
1 Parent(s): ceb89b7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -23
app.py CHANGED
@@ -1241,7 +1241,7 @@ def multi_agent_chat_advanced(query: str, file=None) -> str:
1241
 
1242
 
1243
  # LangGraph 使用的節點函數(會接續你的 Crew Agent)
1244
-
1245
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
1246
 
1247
  # Intent Embedding 分類(支援檔名)
@@ -1253,13 +1253,11 @@ INTENT_LABELS = {
1253
 
1254
  def detect_intent_embedding(query, file_names=[]):
1255
  query_emb = embedding_model.encode(query, normalize_embeddings=True)
1256
-
1257
  best_label = None
1258
  best_score = -1
1259
  all_phrases = INTENT_LABELS.copy()
1260
  if file_names:
1261
  all_phrases["DocQA"] += [name.lower() for name in file_names]
1262
-
1263
  for label, examples in all_phrases.items():
1264
  for example in examples:
1265
  example_emb = embedding_model.encode(example, normalize_embeddings=True)
@@ -1275,12 +1273,40 @@ def decide_next(state):
1275
  label = detect_intent_embedding(query, file_names)
1276
  return label
1277
 
1278
- # node functions
 
 
 
 
 
 
1279
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1280
  def docqa_run(state):
1281
  result = document_qa_agent.execute_task(docqa_task, inputs={"query": state["query"]})
1282
  output = result.output.lower()
1283
- # fallback: no info found general
1284
  if any(x in output for x in ["no relevant info", "not found", "no answer"]):
1285
  return general_run(state)
1286
  return {"answer": result.output}
@@ -1288,47 +1314,42 @@ def docqa_run(state):
1288
  def general_run(state):
1289
  result = general_agent.execute_task(general_task, inputs={"query": state["query"]})
1290
  output = result.output.lower()
 
1291
  if any(x in output for x in ["i don't know", "no idea", "not sure", "can't answer"]):
1292
  result = search_agent.execute_task(search_task, inputs={"query": state["query"]})
1293
  return {"answer": result.output}
1294
 
1295
  def summariser_run(state):
1296
- result = summariser_agent.execute_task(summariser_task, inputs={"query": state["query"]})
1297
  return {"summary": result.output}
1298
 
1299
- # LangGraph 定義
1300
-
1301
  def build_langgraph_pipeline():
1302
  graph = StateGraph(dict)
1303
-
1304
- graph.add_node("Router", lambda state: state) # router 只傳入狀態即可
1305
  graph.add_node("DocQA", docqa_run)
1306
  graph.add_node("General", general_run)
1307
  graph.add_node("Summarise", summariser_run)
1308
-
1309
  graph.set_entry_point("Router")
1310
  graph.add_conditional_edges("Router", decide_next, {
1311
  "DocQA": "DocQA",
1312
  "General": "General",
1313
  "Summarise": "Summarise",
1314
  })
1315
-
1316
  graph.set_finish_point("DocQA")
1317
  graph.set_finish_point("General")
1318
  graph.set_finish_point("Summarise")
1319
-
1320
  return graph.compile()
1321
 
1322
- # 主執行函數
1323
-
1324
  def langgraph_tab6_main(query: str, file=None):
1325
  try:
1326
  files = file if isinstance(file, list) else [file] if file else []
1327
  all_docs, file_names = [], []
1328
-
1329
  for f in files:
1330
  path = get_file_path(f)
1331
- if not path: continue
 
1332
  file_names.append(os.path.basename(path))
1333
  if path.lower().endswith(".pdf"):
1334
  loader = PyPDFLoader(path)
@@ -1337,21 +1358,16 @@ def langgraph_tab6_main(query: str, file=None):
1337
  else:
1338
  loader = TextLoader(path)
1339
  all_docs.extend(loader.load())
1340
-
1341
  if not all_docs:
1342
  retriever = None # 空 retriever
1343
  else:
1344
  chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
1345
  db = FAISS.from_documents(chunks, embeddings)
1346
  retriever = db.as_retriever()
1347
-
1348
- # 設定 retriever 到 global Agent,如果你需要可傳給 Agent
1349
- # 可選:document_qa_agent.retriever = retriever
1350
-
1351
  graph = build_langgraph_pipeline()
1352
  state = {"query": query, "file_names": file_names}
1353
  result = graph.invoke(state)
1354
-
1355
  if "answer" in result:
1356
  return result["answer"]
1357
  if "summary" in result:
 
1241
 
1242
 
1243
  # LangGraph 使用的節點函數(會接續你的 Crew Agent)
1244
+ # 初始化 embedding model
1245
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
1246
 
1247
  # Intent Embedding 分類(支援檔名)
 
1253
 
1254
  def detect_intent_embedding(query, file_names=[]):
1255
  query_emb = embedding_model.encode(query, normalize_embeddings=True)
 
1256
  best_label = None
1257
  best_score = -1
1258
  all_phrases = INTENT_LABELS.copy()
1259
  if file_names:
1260
  all_phrases["DocQA"] += [name.lower() for name in file_names]
 
1261
  for label, examples in all_phrases.items():
1262
  for example in examples:
1263
  example_emb = embedding_model.encode(example, normalize_embeddings=True)
 
1273
  label = detect_intent_embedding(query, file_names)
1274
  return label
1275
 
1276
+ # === 定義 Task 物件 ===
1277
+ docqa_task = Task(
1278
+ description="Document QA Task: Answer questions based on the uploaded document.",
1279
+ expected_output="Answer from Document QA Agent.",
1280
+ agent=document_qa_agent,
1281
+ input_variables=["query"]
1282
+ )
1283
 
1284
+ general_task = Task(
1285
+ description="General Chat Task: Answer general queries.",
1286
+ expected_output="Answer from General Agent.",
1287
+ agent=general_agent,
1288
+ input_variables=["query"]
1289
+ )
1290
+
1291
+ summariser_task = Task(
1292
+ description="Summarisation Task: Summarise document content.",
1293
+ expected_output="Summary output.",
1294
+ agent=summarizer_agent, # 使用 summarizer_agent(注意字母 z)
1295
+ input_variables=["query"]
1296
+ )
1297
+
1298
+ search_task = Task(
1299
+ description="Search Task: Retrieve information from the web.",
1300
+ expected_output="Answer from Search Agent.",
1301
+ agent=search_agent,
1302
+ input_variables=["query"]
1303
+ )
1304
+
1305
+ # === LangGraph 節點函數 ===
1306
  def docqa_run(state):
1307
  result = document_qa_agent.execute_task(docqa_task, inputs={"query": state["query"]})
1308
  output = result.output.lower()
1309
+ # fallback: 若回答中含 "no relevant info", "not found", "no answer",則使用 General 模組
1310
  if any(x in output for x in ["no relevant info", "not found", "no answer"]):
1311
  return general_run(state)
1312
  return {"answer": result.output}
 
1314
  def general_run(state):
1315
  result = general_agent.execute_task(general_task, inputs={"query": state["query"]})
1316
  output = result.output.lower()
1317
+ # 若 General Agent 回答不佳,則改用 Search 模組
1318
  if any(x in output for x in ["i don't know", "no idea", "not sure", "can't answer"]):
1319
  result = search_agent.execute_task(search_task, inputs={"query": state["query"]})
1320
  return {"answer": result.output}
1321
 
1322
  def summariser_run(state):
1323
+ result = summarizer_agent.execute_task(summariser_task, inputs={"query": state["query"]})
1324
  return {"summary": result.output}
1325
 
1326
+ # === LangGraph 定義 ===
 
1327
  def build_langgraph_pipeline():
1328
  graph = StateGraph(dict)
1329
+ graph.add_node("Router", lambda state: state) # Router 僅傳遞狀態
 
1330
  graph.add_node("DocQA", docqa_run)
1331
  graph.add_node("General", general_run)
1332
  graph.add_node("Summarise", summariser_run)
 
1333
  graph.set_entry_point("Router")
1334
  graph.add_conditional_edges("Router", decide_next, {
1335
  "DocQA": "DocQA",
1336
  "General": "General",
1337
  "Summarise": "Summarise",
1338
  })
 
1339
  graph.set_finish_point("DocQA")
1340
  graph.set_finish_point("General")
1341
  graph.set_finish_point("Summarise")
 
1342
  return graph.compile()
1343
 
1344
+ # === 主執行函數 (Tab6) ===
 
1345
  def langgraph_tab6_main(query: str, file=None):
1346
  try:
1347
  files = file if isinstance(file, list) else [file] if file else []
1348
  all_docs, file_names = [], []
 
1349
  for f in files:
1350
  path = get_file_path(f)
1351
+ if not path:
1352
+ continue
1353
  file_names.append(os.path.basename(path))
1354
  if path.lower().endswith(".pdf"):
1355
  loader = PyPDFLoader(path)
 
1358
  else:
1359
  loader = TextLoader(path)
1360
  all_docs.extend(loader.load())
 
1361
  if not all_docs:
1362
  retriever = None # 空 retriever
1363
  else:
1364
  chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
1365
  db = FAISS.from_documents(chunks, embeddings)
1366
  retriever = db.as_retriever()
1367
+ # 可選:設定 retriever 到 global Agent(若需要)
 
 
 
1368
  graph = build_langgraph_pipeline()
1369
  state = {"query": query, "file_names": file_names}
1370
  result = graph.invoke(state)
 
1371
  if "answer" in result:
1372
  return result["answer"]
1373
  if "summary" in result: