ChienChung commited on
Commit
ceb89b7
·
verified ·
1 Parent(s): 875807c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -42
app.py CHANGED
@@ -94,6 +94,9 @@ from langgraph.graph import StateGraph
94
  from langchain_core.runnables import RunnableLambda
95
  from langchain.chains import LLMChain
96
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
 
 
 
97
 
98
  try:
99
  from phoenix.trace.langchain import LangChainInstrumentor
@@ -1238,12 +1241,10 @@ def multi_agent_chat_advanced(query: str, file=None) -> str:
1238
 
1239
 
1240
  # LangGraph 使用的節點函數(會接續你的 Crew Agent)
1241
- # ✅ Tab6 最終設計:General + Search fallback 接在 DocQA 後面,Summarise 平行觸發
1242
- from sentence_transformers import SentenceTransformer
1243
 
1244
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
1245
 
1246
- # 用於 LangGraph 路由的意圖標籤
1247
  INTENT_LABELS = {
1248
  "DocQA": ["document", "file", "paper", "cb", "proposal"],
1249
  "Summarise": ["summarise", "summary", "abstract", "key points", "要點", "總結", "摘要"],
@@ -1252,8 +1253,9 @@ INTENT_LABELS = {
1252
 
1253
  def detect_intent_embedding(query, file_names=[]):
1254
  query_emb = embedding_model.encode(query, normalize_embeddings=True)
1255
- best_label, best_score = None, -1
1256
 
 
 
1257
  all_phrases = INTENT_LABELS.copy()
1258
  if file_names:
1259
  all_phrases["DocQA"] += [name.lower() for name in file_names]
@@ -1263,59 +1265,70 @@ def detect_intent_embedding(query, file_names=[]):
1263
  example_emb = embedding_model.encode(example, normalize_embeddings=True)
1264
  score = float(query_emb @ example_emb.T)
1265
  if score > best_score:
1266
- best_score, best_label = score, label
1267
- return best_label if best_label else "DocQA"
 
1268
 
1269
  def decide_next(state):
1270
  query = state.get("query", "")
1271
  file_names = state.get("file_names", [])
1272
- return detect_intent_embedding(query, file_names)
 
 
 
1273
 
1274
- # ⛓️ LangGraph 節點函數
 
 
 
 
 
 
1275
 
1276
- def docqa_run(state: dict) -> dict:
1277
- result = document_qa_agent.execute_task(router_task, inputs={"query": state["query"]})
1278
  output = result.output.lower()
1279
- if any(x in output for x in ["no relevant", "can't find", "not in document"]):
1280
- result = general_agent.execute_task(router_task, inputs={"query": state["query"]})
1281
- output = result.output.lower()
1282
- if any(x in output for x in ["i don't know", "no idea", "not sure", "can't answer", "no info"]):
1283
- result = search_agent.execute_task(router_task, inputs={"query": state["query"]})
1284
- return {"query": state["query"], "answer": result.output}
1285
 
1286
- def summarizer_run(state: dict) -> dict:
1287
- result = summarizer_agent.execute_task(router_task, inputs={"query": state["query"]})
1288
- return {"query": state["query"], "summary": result.output}
1289
 
1290
- # LangGraph 設計
1291
 
1292
- def build_langgraph_gpt_like():
1293
  graph = StateGraph(dict)
1294
 
 
1295
  graph.add_node("DocQA", docqa_run)
1296
- graph.add_node("Summarise", summarizer_run)
 
1297
 
1298
- graph.set_entry_point("DocQA")
1299
- graph.add_conditional_edges("DocQA", decide_next, {
 
 
1300
  "Summarise": "Summarise",
1301
- "DocQA": "DocQA" # 防止空跳轉
1302
  })
1303
 
1304
  graph.set_finish_point("DocQA")
 
1305
  graph.set_finish_point("Summarise")
 
1306
  return graph.compile()
1307
 
1308
- # 主入口函數
1309
 
1310
- def langgraph_tab6_main(query: str, file=None) -> str:
1311
  try:
1312
  files = file if isinstance(file, list) else [file] if file else []
1313
  all_docs, file_names = [], []
1314
 
1315
  for f in files:
1316
  path = get_file_path(f)
1317
- if not path:
1318
- continue
1319
  file_names.append(os.path.basename(path))
1320
  if path.lower().endswith(".pdf"):
1321
  loader = PyPDFLoader(path)
@@ -1325,29 +1338,29 @@ def langgraph_tab6_main(query: str, file=None) -> str:
1325
  loader = TextLoader(path)
1326
  all_docs.extend(loader.load())
1327
 
1328
- # 即使沒檔案也進入 QA,但內容會讓 QA 回答不了 → fallback 到 general → fallback 到 search
1329
- if all_docs:
 
1330
  chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
1331
  db = FAISS.from_documents(chunks, embeddings)
1332
  retriever = db.as_retriever()
1333
- # 替換 DocQA agent retriever
1334
- document_qa_agent.tools[0].retriever = retriever
1335
 
1336
- graph = build_langgraph_gpt_like()
1337
- result = graph.invoke({"query": query, "file_names": file_names})
 
 
 
 
1338
 
1339
- if "answer" in result and "summary" in result:
1340
- return f"\nAnswer: {result['answer']}\n\nSummary: {result['summary']}"
1341
- elif "answer" in result:
1342
  return result["answer"]
1343
- elif "summary" in result:
1344
  return result["summary"]
1345
- else:
1346
- return "No result."
1347
-
1348
  except Exception as e:
1349
  return f"[Tab6 Error] {e}"
1350
 
 
1351
 
1352
 
1353
  # Gradio Interface Settings
 
94
  from langchain_core.runnables import RunnableLambda
95
  from langchain.chains import LLMChain
96
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
97
+ from sentence_transformers import SentenceTransformer
98
+
99
+
100
 
101
  try:
102
  from phoenix.trace.langchain import LangChainInstrumentor
 
1241
 
1242
 
1243
  # LangGraph 使用的節點函數(會接續你的 Crew Agent)
 
 
1244
 
1245
  embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
1246
 
1247
+ # Intent Embedding 分類(支援檔名)
1248
  INTENT_LABELS = {
1249
  "DocQA": ["document", "file", "paper", "cb", "proposal"],
1250
  "Summarise": ["summarise", "summary", "abstract", "key points", "要點", "總結", "摘要"],
 
1253
 
1254
  def detect_intent_embedding(query, file_names=[]):
1255
  query_emb = embedding_model.encode(query, normalize_embeddings=True)
 
1256
 
1257
+ best_label = None
1258
+ best_score = -1
1259
  all_phrases = INTENT_LABELS.copy()
1260
  if file_names:
1261
  all_phrases["DocQA"] += [name.lower() for name in file_names]
 
1265
  example_emb = embedding_model.encode(example, normalize_embeddings=True)
1266
  score = float(query_emb @ example_emb.T)
1267
  if score > best_score:
1268
+ best_score = score
1269
+ best_label = label
1270
+ return best_label if best_label else "General"
1271
 
1272
  def decide_next(state):
1273
  query = state.get("query", "")
1274
  file_names = state.get("file_names", [])
1275
+ label = detect_intent_embedding(query, file_names)
1276
+ return label
1277
+
1278
+ # ✅ node functions
1279
 
1280
+ def docqa_run(state):
1281
+ result = document_qa_agent.execute_task(docqa_task, inputs={"query": state["query"]})
1282
+ output = result.output.lower()
1283
+ # fallback: no info found → general
1284
+ if any(x in output for x in ["no relevant info", "not found", "no answer"]):
1285
+ return general_run(state)
1286
+ return {"answer": result.output}
1287
 
1288
+ def general_run(state):
1289
+ result = general_agent.execute_task(general_task, inputs={"query": state["query"]})
1290
  output = result.output.lower()
1291
+ if any(x in output for x in ["i don't know", "no idea", "not sure", "can't answer"]):
1292
+ result = search_agent.execute_task(search_task, inputs={"query": state["query"]})
1293
+ return {"answer": result.output}
 
 
 
1294
 
1295
+ def summariser_run(state):
1296
+ result = summariser_agent.execute_task(summariser_task, inputs={"query": state["query"]})
1297
+ return {"summary": result.output}
1298
 
1299
+ # LangGraph 定義
1300
 
1301
+ def build_langgraph_pipeline():
1302
  graph = StateGraph(dict)
1303
 
1304
+ graph.add_node("Router", lambda state: state) # router 只傳入狀態即可
1305
  graph.add_node("DocQA", docqa_run)
1306
+ graph.add_node("General", general_run)
1307
+ graph.add_node("Summarise", summariser_run)
1308
 
1309
+ graph.set_entry_point("Router")
1310
+ graph.add_conditional_edges("Router", decide_next, {
1311
+ "DocQA": "DocQA",
1312
+ "General": "General",
1313
  "Summarise": "Summarise",
 
1314
  })
1315
 
1316
  graph.set_finish_point("DocQA")
1317
+ graph.set_finish_point("General")
1318
  graph.set_finish_point("Summarise")
1319
+
1320
  return graph.compile()
1321
 
1322
+ # 執行函數
1323
 
1324
+ def langgraph_tab6_main(query: str, file=None):
1325
  try:
1326
  files = file if isinstance(file, list) else [file] if file else []
1327
  all_docs, file_names = [], []
1328
 
1329
  for f in files:
1330
  path = get_file_path(f)
1331
+ if not path: continue
 
1332
  file_names.append(os.path.basename(path))
1333
  if path.lower().endswith(".pdf"):
1334
  loader = PyPDFLoader(path)
 
1338
  loader = TextLoader(path)
1339
  all_docs.extend(loader.load())
1340
 
1341
+ if not all_docs:
1342
+ retriever = None # 空 retriever
1343
+ else:
1344
  chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
1345
  db = FAISS.from_documents(chunks, embeddings)
1346
  retriever = db.as_retriever()
 
 
1347
 
1348
+ # 設定 retriever 到 global Agent,如果你需要可傳給 Agent
1349
+ # 可選:document_qa_agent.retriever = retriever
1350
+
1351
+ graph = build_langgraph_pipeline()
1352
+ state = {"query": query, "file_names": file_names}
1353
+ result = graph.invoke(state)
1354
 
1355
+ if "answer" in result:
 
 
1356
  return result["answer"]
1357
+ if "summary" in result:
1358
  return result["summary"]
1359
+ return "No answer."
 
 
1360
  except Exception as e:
1361
  return f"[Tab6 Error] {e}"
1362
 
1363
+
1364
 
1365
 
1366
  # Gradio Interface Settings