Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -94,6 +94,9 @@ from langgraph.graph import StateGraph
|
|
| 94 |
from langchain_core.runnables import RunnableLambda
|
| 95 |
from langchain.chains import LLMChain
|
| 96 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
try:
|
| 99 |
from phoenix.trace.langchain import LangChainInstrumentor
|
|
@@ -1238,12 +1241,10 @@ def multi_agent_chat_advanced(query: str, file=None) -> str:
|
|
| 1238 |
|
| 1239 |
|
| 1240 |
# LangGraph 使用的節點函數(會接續你的 Crew Agent)
|
| 1241 |
-
# ✅ Tab6 最終設計:General + Search fallback 接在 DocQA 後面,Summarise 平行觸發
|
| 1242 |
-
from sentence_transformers import SentenceTransformer
|
| 1243 |
|
| 1244 |
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 1245 |
|
| 1246 |
-
#
|
| 1247 |
INTENT_LABELS = {
|
| 1248 |
"DocQA": ["document", "file", "paper", "cb", "proposal"],
|
| 1249 |
"Summarise": ["summarise", "summary", "abstract", "key points", "要點", "總結", "摘要"],
|
|
@@ -1252,8 +1253,9 @@ INTENT_LABELS = {
|
|
| 1252 |
|
| 1253 |
def detect_intent_embedding(query, file_names=[]):
|
| 1254 |
query_emb = embedding_model.encode(query, normalize_embeddings=True)
|
| 1255 |
-
best_label, best_score = None, -1
|
| 1256 |
|
|
|
|
|
|
|
| 1257 |
all_phrases = INTENT_LABELS.copy()
|
| 1258 |
if file_names:
|
| 1259 |
all_phrases["DocQA"] += [name.lower() for name in file_names]
|
|
@@ -1263,59 +1265,70 @@ def detect_intent_embedding(query, file_names=[]):
|
|
| 1263 |
example_emb = embedding_model.encode(example, normalize_embeddings=True)
|
| 1264 |
score = float(query_emb @ example_emb.T)
|
| 1265 |
if score > best_score:
|
| 1266 |
-
best_score
|
| 1267 |
-
|
|
|
|
| 1268 |
|
| 1269 |
def decide_next(state):
|
| 1270 |
query = state.get("query", "")
|
| 1271 |
file_names = state.get("file_names", [])
|
| 1272 |
-
|
|
|
|
|
|
|
|
|
|
| 1273 |
|
| 1274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1275 |
|
| 1276 |
-
def
|
| 1277 |
-
result =
|
| 1278 |
output = result.output.lower()
|
| 1279 |
-
if any(x in output for x in ["
|
| 1280 |
-
result =
|
| 1281 |
-
|
| 1282 |
-
if any(x in output for x in ["i don't know", "no idea", "not sure", "can't answer", "no info"]):
|
| 1283 |
-
result = search_agent.execute_task(router_task, inputs={"query": state["query"]})
|
| 1284 |
-
return {"query": state["query"], "answer": result.output}
|
| 1285 |
|
| 1286 |
-
def
|
| 1287 |
-
result =
|
| 1288 |
-
return {"
|
| 1289 |
|
| 1290 |
-
# LangGraph
|
| 1291 |
|
| 1292 |
-
def
|
| 1293 |
graph = StateGraph(dict)
|
| 1294 |
|
|
|
|
| 1295 |
graph.add_node("DocQA", docqa_run)
|
| 1296 |
-
graph.add_node("
|
|
|
|
| 1297 |
|
| 1298 |
-
graph.set_entry_point("
|
| 1299 |
-
graph.add_conditional_edges("
|
|
|
|
|
|
|
| 1300 |
"Summarise": "Summarise",
|
| 1301 |
-
"DocQA": "DocQA" # 防止空跳轉
|
| 1302 |
})
|
| 1303 |
|
| 1304 |
graph.set_finish_point("DocQA")
|
|
|
|
| 1305 |
graph.set_finish_point("Summarise")
|
|
|
|
| 1306 |
return graph.compile()
|
| 1307 |
|
| 1308 |
-
# 主
|
| 1309 |
|
| 1310 |
-
def langgraph_tab6_main(query: str, file=None)
|
| 1311 |
try:
|
| 1312 |
files = file if isinstance(file, list) else [file] if file else []
|
| 1313 |
all_docs, file_names = [], []
|
| 1314 |
|
| 1315 |
for f in files:
|
| 1316 |
path = get_file_path(f)
|
| 1317 |
-
if not path:
|
| 1318 |
-
continue
|
| 1319 |
file_names.append(os.path.basename(path))
|
| 1320 |
if path.lower().endswith(".pdf"):
|
| 1321 |
loader = PyPDFLoader(path)
|
|
@@ -1325,29 +1338,29 @@ def langgraph_tab6_main(query: str, file=None) -> str:
|
|
| 1325 |
loader = TextLoader(path)
|
| 1326 |
all_docs.extend(loader.load())
|
| 1327 |
|
| 1328 |
-
|
| 1329 |
-
|
|
|
|
| 1330 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
|
| 1331 |
db = FAISS.from_documents(chunks, embeddings)
|
| 1332 |
retriever = db.as_retriever()
|
| 1333 |
-
# 替換 DocQA agent retriever
|
| 1334 |
-
document_qa_agent.tools[0].retriever = retriever
|
| 1335 |
|
| 1336 |
-
|
| 1337 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1338 |
|
| 1339 |
-
if "answer" in result
|
| 1340 |
-
return f"\nAnswer: {result['answer']}\n\nSummary: {result['summary']}"
|
| 1341 |
-
elif "answer" in result:
|
| 1342 |
return result["answer"]
|
| 1343 |
-
|
| 1344 |
return result["summary"]
|
| 1345 |
-
|
| 1346 |
-
return "No result."
|
| 1347 |
-
|
| 1348 |
except Exception as e:
|
| 1349 |
return f"[Tab6 Error] {e}"
|
| 1350 |
|
|
|
|
| 1351 |
|
| 1352 |
|
| 1353 |
# Gradio Interface Settings
|
|
|
|
| 94 |
from langchain_core.runnables import RunnableLambda
|
| 95 |
from langchain.chains import LLMChain
|
| 96 |
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
|
| 97 |
+
from sentence_transformers import SentenceTransformer
|
| 98 |
+
|
| 99 |
+
|
| 100 |
|
| 101 |
try:
|
| 102 |
from phoenix.trace.langchain import LangChainInstrumentor
|
|
|
|
| 1241 |
|
| 1242 |
|
| 1243 |
# LangGraph 使用的節點函數(會接續你的 Crew Agent)
|
|
|
|
|
|
|
| 1244 |
|
| 1245 |
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 1246 |
|
| 1247 |
+
# Intent Embedding 分類(支援檔名)
|
| 1248 |
INTENT_LABELS = {
|
| 1249 |
"DocQA": ["document", "file", "paper", "cb", "proposal"],
|
| 1250 |
"Summarise": ["summarise", "summary", "abstract", "key points", "要點", "總結", "摘要"],
|
|
|
|
| 1253 |
|
| 1254 |
def detect_intent_embedding(query, file_names=[]):
|
| 1255 |
query_emb = embedding_model.encode(query, normalize_embeddings=True)
|
|
|
|
| 1256 |
|
| 1257 |
+
best_label = None
|
| 1258 |
+
best_score = -1
|
| 1259 |
all_phrases = INTENT_LABELS.copy()
|
| 1260 |
if file_names:
|
| 1261 |
all_phrases["DocQA"] += [name.lower() for name in file_names]
|
|
|
|
| 1265 |
example_emb = embedding_model.encode(example, normalize_embeddings=True)
|
| 1266 |
score = float(query_emb @ example_emb.T)
|
| 1267 |
if score > best_score:
|
| 1268 |
+
best_score = score
|
| 1269 |
+
best_label = label
|
| 1270 |
+
return best_label if best_label else "General"
|
| 1271 |
|
| 1272 |
def decide_next(state):
|
| 1273 |
query = state.get("query", "")
|
| 1274 |
file_names = state.get("file_names", [])
|
| 1275 |
+
label = detect_intent_embedding(query, file_names)
|
| 1276 |
+
return label
|
| 1277 |
+
|
| 1278 |
+
# ✅ node functions
|
| 1279 |
|
| 1280 |
+
def docqa_run(state):
|
| 1281 |
+
result = document_qa_agent.execute_task(docqa_task, inputs={"query": state["query"]})
|
| 1282 |
+
output = result.output.lower()
|
| 1283 |
+
# fallback: no info found → general
|
| 1284 |
+
if any(x in output for x in ["no relevant info", "not found", "no answer"]):
|
| 1285 |
+
return general_run(state)
|
| 1286 |
+
return {"answer": result.output}
|
| 1287 |
|
| 1288 |
+
def general_run(state):
|
| 1289 |
+
result = general_agent.execute_task(general_task, inputs={"query": state["query"]})
|
| 1290 |
output = result.output.lower()
|
| 1291 |
+
if any(x in output for x in ["i don't know", "no idea", "not sure", "can't answer"]):
|
| 1292 |
+
result = search_agent.execute_task(search_task, inputs={"query": state["query"]})
|
| 1293 |
+
return {"answer": result.output}
|
|
|
|
|
|
|
|
|
|
| 1294 |
|
| 1295 |
+
def summariser_run(state):
|
| 1296 |
+
result = summariser_agent.execute_task(summariser_task, inputs={"query": state["query"]})
|
| 1297 |
+
return {"summary": result.output}
|
| 1298 |
|
| 1299 |
+
# ✅ LangGraph 定義
|
| 1300 |
|
| 1301 |
+
def build_langgraph_pipeline():
|
| 1302 |
graph = StateGraph(dict)
|
| 1303 |
|
| 1304 |
+
graph.add_node("Router", lambda state: state) # router 只傳入狀態即可
|
| 1305 |
graph.add_node("DocQA", docqa_run)
|
| 1306 |
+
graph.add_node("General", general_run)
|
| 1307 |
+
graph.add_node("Summarise", summariser_run)
|
| 1308 |
|
| 1309 |
+
graph.set_entry_point("Router")
|
| 1310 |
+
graph.add_conditional_edges("Router", decide_next, {
|
| 1311 |
+
"DocQA": "DocQA",
|
| 1312 |
+
"General": "General",
|
| 1313 |
"Summarise": "Summarise",
|
|
|
|
| 1314 |
})
|
| 1315 |
|
| 1316 |
graph.set_finish_point("DocQA")
|
| 1317 |
+
graph.set_finish_point("General")
|
| 1318 |
graph.set_finish_point("Summarise")
|
| 1319 |
+
|
| 1320 |
return graph.compile()
|
| 1321 |
|
| 1322 |
+
# ✅ 主執行函數
|
| 1323 |
|
| 1324 |
+
def langgraph_tab6_main(query: str, file=None):
|
| 1325 |
try:
|
| 1326 |
files = file if isinstance(file, list) else [file] if file else []
|
| 1327 |
all_docs, file_names = [], []
|
| 1328 |
|
| 1329 |
for f in files:
|
| 1330 |
path = get_file_path(f)
|
| 1331 |
+
if not path: continue
|
|
|
|
| 1332 |
file_names.append(os.path.basename(path))
|
| 1333 |
if path.lower().endswith(".pdf"):
|
| 1334 |
loader = PyPDFLoader(path)
|
|
|
|
| 1338 |
loader = TextLoader(path)
|
| 1339 |
all_docs.extend(loader.load())
|
| 1340 |
|
| 1341 |
+
if not all_docs:
|
| 1342 |
+
retriever = None # 空 retriever
|
| 1343 |
+
else:
|
| 1344 |
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
|
| 1345 |
db = FAISS.from_documents(chunks, embeddings)
|
| 1346 |
retriever = db.as_retriever()
|
|
|
|
|
|
|
| 1347 |
|
| 1348 |
+
# 設定 retriever 到 global Agent,如果你需要可傳給 Agent
|
| 1349 |
+
# 可選:document_qa_agent.retriever = retriever
|
| 1350 |
+
|
| 1351 |
+
graph = build_langgraph_pipeline()
|
| 1352 |
+
state = {"query": query, "file_names": file_names}
|
| 1353 |
+
result = graph.invoke(state)
|
| 1354 |
|
| 1355 |
+
if "answer" in result:
|
|
|
|
|
|
|
| 1356 |
return result["answer"]
|
| 1357 |
+
if "summary" in result:
|
| 1358 |
return result["summary"]
|
| 1359 |
+
return "No answer."
|
|
|
|
|
|
|
| 1360 |
except Exception as e:
|
| 1361 |
return f"[Tab6 Error] {e}"
|
| 1362 |
|
| 1363 |
+
|
| 1364 |
|
| 1365 |
|
| 1366 |
# Gradio Interface Settings
|