Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1232,270 +1232,417 @@ def multi_agent_chat_advanced(query: str, file=None) -> str:
|
|
| 1232 |
except Exception as e:
|
| 1233 |
return f"Multi-Agent Error: {e}"
|
| 1234 |
|
| 1235 |
-
#
|
| 1236 |
-
|
| 1237 |
-
# Initialize the
|
| 1238 |
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 1239 |
|
| 1240 |
-
#
|
| 1241 |
INTENT_LABELS = {
|
| 1242 |
-
"DocQA":
|
| 1243 |
-
"Summarise":["summarise", "summary", "abstract", "key points", "overview", "main points"],
|
| 1244 |
-
"General":
|
| 1245 |
}
|
| 1246 |
|
|
|
|
| 1247 |
def detect_intent_embedding(query, file_names=[]):
|
| 1248 |
-
"""
|
| 1249 |
-
Compute embedding of the user query, compare against each intent's example embeddings,
|
| 1250 |
-
and return the label with highest cosine similarity.
|
| 1251 |
-
"""
|
| 1252 |
query_emb = embedding_model.encode(query, normalize_embeddings=True)
|
| 1253 |
-
best_label
|
| 1254 |
-
|
| 1255 |
-
|
| 1256 |
if file_names:
|
| 1257 |
-
|
| 1258 |
-
|
| 1259 |
-
|
| 1260 |
-
|
| 1261 |
-
|
| 1262 |
-
score = float(query_emb @ ex_emb.T)
|
| 1263 |
if score > best_score:
|
| 1264 |
-
|
| 1265 |
-
|
| 1266 |
-
|
| 1267 |
-
def decide_next(state):
|
| 1268 |
-
"""
|
| 1269 |
-
LangGraph router node: choose next node based on detected intent label.
|
| 1270 |
-
"""
|
| 1271 |
-
label = detect_intent_embedding(state["query"], state.get("file_names", []))
|
| 1272 |
-
return label
|
| 1273 |
|
| 1274 |
def autogen_multi_document_analysis(query: str, docs: list, file_names: list) -> str:
|
| 1275 |
-
|
| 1276 |
-
|
| 1277 |
-
|
| 1278 |
-
|
| 1279 |
-
|
| 1280 |
-
|
| 1281 |
-
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
|
| 1285 |
-
|
| 1286 |
-
|
| 1287 |
-
|
| 1288 |
-
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
|
| 1292 |
-
|
| 1293 |
-
|
| 1294 |
-
llm_config = {
|
| 1295 |
-
"config_list": [{"model":"gpt-4o-mini", "api_key": openai_api_key}],
|
| 1296 |
-
"temperature": 0
|
| 1297 |
-
}
|
| 1298 |
-
|
| 1299 |
-
# instantiate agents
|
| 1300 |
-
user_proxy = UserProxyAgent( name="User",
|
| 1301 |
-
system_message="User seeking cross-document analysis.",
|
| 1302 |
-
human_input_mode="NEVER",
|
| 1303 |
-
code_execution_config={"use_docker":False},
|
| 1304 |
-
llm_config=llm_config
|
| 1305 |
-
)
|
| 1306 |
-
doc_analyzer = AssistantAgent( name="DocumentAnalyzer",
|
| 1307 |
-
system_message="Expert on comparing document content and structure.",
|
| 1308 |
-
llm_config=llm_config
|
| 1309 |
-
)
|
| 1310 |
-
qa_expert = AssistantAgent( name="QAExpert",
|
| 1311 |
-
system_message="Expert at extracting precise answers from text.",
|
| 1312 |
-
llm_config=llm_config
|
| 1313 |
-
)
|
| 1314 |
-
summarizer = AssistantAgent( name="Summarizer",
|
| 1315 |
-
system_message="Expert at generating concise summaries.",
|
| 1316 |
-
llm_config=llm_config
|
| 1317 |
-
)
|
| 1318 |
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1325 |
|
| 1326 |
-
|
| 1327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1328 |
|
| 1329 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1330 |
|
| 1331 |
-
|
| 1332 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1333 |
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1341 |
|
| 1342 |
def general_run(state):
|
| 1343 |
-
"""
|
| 1344 |
-
|
| 1345 |
-
|
| 1346 |
-
|
| 1347 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1348 |
|
| 1349 |
def docqa_run(state):
|
| 1350 |
-
"""
|
| 1351 |
-
|
| 1352 |
-
|
| 1353 |
-
|
| 1354 |
-
|
| 1355 |
-
|
| 1356 |
-
|
| 1357 |
-
|
| 1358 |
-
|
| 1359 |
-
|
| 1360 |
-
|
| 1361 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1362 |
|
| 1363 |
def summariser_run(state):
|
| 1364 |
-
"""
|
| 1365 |
-
|
| 1366 |
-
|
| 1367 |
-
|
| 1368 |
-
|
| 1369 |
-
|
| 1370 |
-
|
| 1371 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1372 |
def build_langgraph_pipeline():
|
| 1373 |
-
"""
|
| 1374 |
-
Assemble the LangGraph state graph: Router -> {DocQA, Summarise, General}.
|
| 1375 |
-
"""
|
| 1376 |
graph = StateGraph(dict)
|
| 1377 |
-
graph.add_node("Router", lambda state: state)
|
| 1378 |
graph.add_node("DocQA", docqa_run)
|
| 1379 |
-
graph.add_node("Summarise", summariser_run)
|
| 1380 |
graph.add_node("General", general_run)
|
|
|
|
| 1381 |
graph.set_entry_point("Router")
|
| 1382 |
graph.add_conditional_edges("Router", decide_next, {
|
| 1383 |
"DocQA": "DocQA",
|
|
|
|
| 1384 |
"Summarise": "Summarise",
|
| 1385 |
-
"General": "General"
|
| 1386 |
})
|
| 1387 |
graph.set_finish_point("DocQA")
|
| 1388 |
-
graph.set_finish_point("Summarise")
|
| 1389 |
graph.set_finish_point("General")
|
|
|
|
| 1390 |
return graph.compile()
|
| 1391 |
|
| 1392 |
def get_file_path_tab6(file):
|
| 1393 |
if isinstance(file, str):
|
|
|
|
| 1394 |
if os.path.exists(file):
|
|
|
|
| 1395 |
return file
|
| 1396 |
else:
|
|
|
|
| 1397 |
return None
|
| 1398 |
elif isinstance(file, dict):
|
|
|
|
| 1399 |
data = file.get("data")
|
| 1400 |
name = file.get("name")
|
|
|
|
| 1401 |
if data:
|
| 1402 |
if isinstance(data, str) and os.path.exists(data):
|
|
|
|
| 1403 |
return data
|
| 1404 |
else:
|
| 1405 |
temp_dir = mkdtemp()
|
| 1406 |
file_path = os.path.join(temp_dir, name if name else "uploaded_file")
|
|
|
|
| 1407 |
with open(file_path, "wb") as f:
|
| 1408 |
if isinstance(data, str):
|
| 1409 |
f.write(data.encode("utf-8"))
|
| 1410 |
else:
|
| 1411 |
f.write(data)
|
| 1412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1413 |
else:
|
|
|
|
| 1414 |
return None
|
| 1415 |
elif hasattr(file, "save"):
|
|
|
|
| 1416 |
temp_dir = mkdtemp()
|
| 1417 |
file_path = os.path.join(temp_dir, file.name)
|
| 1418 |
file.save(file_path)
|
| 1419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1420 |
else:
|
| 1421 |
-
|
| 1422 |
-
|
|
|
|
|
|
|
| 1423 |
return None
|
| 1424 |
|
| 1425 |
-
@traceable(name="multi_doc")
|
| 1426 |
def langgraph_tab6_main(query: str, file=None):
|
| 1427 |
-
"""
|
| 1428 |
-
Main entrypoint for Tab 6.
|
| 1429 |
-
1. If no file: call general_run.
|
| 1430 |
-
2. Load one or more docs, chunk them.
|
| 1431 |
-
3. Initialize Pinecone index 'Rag_Docs' with dimension=768, metric=cosine.
|
| 1432 |
-
4. Upsert chunks into Pinecone under namespace 'Rag_Docs'.
|
| 1433 |
-
5. Build retriever and ConversationalRetrievalChain.
|
| 1434 |
-
6. If multi‐doc or comparison query → autogen_multi_document_analysis.
|
| 1435 |
-
7. Else route through LangGraph pipeline.
|
| 1436 |
-
"""
|
| 1437 |
try:
|
|
|
|
|
|
|
|
|
|
| 1438 |
if not file:
|
| 1439 |
return general_run({"query": query})["answer"]
|
| 1440 |
-
|
| 1441 |
-
#
|
| 1442 |
files = file if isinstance(file, list) else [file]
|
| 1443 |
-
all_docs
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1444 |
for f in files:
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
-
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
-
|
| 1464 |
-
|
| 1465 |
-
|
| 1466 |
-
|
| 1467 |
-
|
| 1468 |
-
chunks = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50).split_documents(all_docs)
|
| 1469 |
-
vectorstore = Pinecone.from_documents(
|
| 1470 |
-
documents=chunks,
|
| 1471 |
-
embedding=embeddings,
|
| 1472 |
-
index_name=index_name,
|
| 1473 |
-
namespace="Rag_Docs"
|
| 1474 |
-
)
|
| 1475 |
-
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
|
| 1476 |
-
|
| 1477 |
-
# set up conversational chain
|
| 1478 |
-
global session_retriever, session_qa_chain
|
| 1479 |
-
session_retriever = retriever
|
| 1480 |
-
session_qa_chain = ConversationalRetrievalChain.from_llm(
|
| 1481 |
-
llm=llm_gpt4,
|
| 1482 |
-
retriever=retriever,
|
| 1483 |
-
memory=ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
| 1484 |
-
)
|
| 1485 |
|
| 1486 |
-
|
| 1487 |
-
|
| 1488 |
-
return autogen_multi_document_analysis(query, docs_text, file_names)
|
| 1489 |
|
| 1490 |
-
#
|
| 1491 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1492 |
pipeline = build_langgraph_pipeline()
|
| 1493 |
-
|
| 1494 |
-
|
| 1495 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1496 |
except Exception as e:
|
| 1497 |
-
print(f"ERROR in
|
| 1498 |
-
return f"
|
| 1499 |
|
| 1500 |
# Gradio Interface Settings
|
| 1501 |
demo_description = """
|
|
|
|
| 1232 |
except Exception as e:
|
| 1233 |
return f"Multi-Agent Error: {e}"
|
| 1234 |
|
| 1235 |
+
# Tab 6: LangGraph routing plus AutoGen multi-document analysis.

# Sentence-embedding model used to score a query against the intent phrases.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

# Example phrases for each intent label. detect_intent_embedding() compares
# the query embedding with every phrase and returns the best-scoring label.
INTENT_LABELS = {
    "DocQA": [
        "document",
        "file",
        "paper",
        "cb",
        "proposal",
        "project",
    ],
    "Summarise": [
        "summarise",
        "summary",
        "abstract",
        "key points",
        "overview",
        "main points",
    ],
    "General": [
        "who are you",
        "tell me something",
        "what can you do",
        "fun fact",
    ],
}
|
| 1246 |
|
| 1247 |
+
# Intent detection via embedding similarity
|
| 1248 |
def detect_intent_embedding(query, file_names=None):
    """Pick the intent label whose example phrase best matches *query*.

    The query and every example phrase are encoded with ``embedding_model``
    (``normalize_embeddings=True``) and compared by dot product; the label
    owning the highest-scoring phrase wins.

    Args:
        query: The user's question.
        file_names: Optional names of uploaded files. Their lower-cased forms
            are treated as extra "DocQA" phrases for this call only.

    Returns:
        One of the keys of ``INTENT_LABELS``, or "General" when no phrase
        produced a score.
    """
    query_emb = embedding_model.encode(query, normalize_embeddings=True)

    # Copy each phrase list, not just the dict: a shallow dict copy shares the
    # lists, so extending "DocQA" would permanently grow the module-level
    # INTENT_LABELS with every uploaded file name.
    all_phrases = {
        label: list(examples) for label, examples in INTENT_LABELS.items()
    }
    if file_names:
        all_phrases.setdefault("DocQA", []).extend(
            name.lower() for name in file_names
        )

    best_label = None
    best_score = -1
    for label, examples in all_phrases.items():
        for example in examples:
            example_emb = embedding_model.encode(
                example, normalize_embeddings=True
            )
            score = float(query_emb @ example_emb.T)
            if score > best_score:
                best_score = score
                best_label = label

    return best_label if best_label else "General"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1263 |
|
| 1264 |
def autogen_multi_document_analysis(query: str, docs: list, file_names: list) -> str:
    """Answer *query* across several documents with an AutoGen group chat.

    A context string is built from the first 2000 characters of each
    document, then a group chat (max 5 rounds) is run between a user proxy
    and three assistants: DocumentAnalyzer, QAExpert and Summarizer.

    Args:
        query: The user's question.
        docs: Document texts, one string per file.
        file_names: Names paired positionally with ``docs`` (via ``zip``, so
            any surplus entries on either side are ignored).

    Returns:
        The content of the user proxy's last message, or an
        "Error analyzing documents: ..." string if anything raises.
    """
    try:
        # Point every cache location at a fresh temp dir instead of ./.cache.
        # NOTE(review): the directory is intentionally left in place because
        # the environment variables (and autogen.set_cache_dir, when present)
        # keep referring to it after this call returns.
        temp_dir = tempfile.mkdtemp(dir="/tmp")
        for env_name in (
            "OPENAI_CACHE_DIR",
            "AUTOGEN_CACHE_PATH",
            "AUTOGEN_CACHEDIR",
            "OPENAI_CACHE_PATH",
        ):
            os.environ[env_name] = temp_dir

        if hasattr(autogen, "set_cache_dir"):
            autogen.set_cache_dir(temp_dir)

        # Only the first 2000 characters of each document go into the prompt.
        context = "\n\n".join(
            f"Document {name}:\n{doc[:2000]}..."
            for name, doc in zip(file_names, docs)
        )

        llm_config = {
            "config_list": [{
                "model": "gpt-4o-mini",
                "api_key": openai_api_key,
            }],
            "temperature": 0,
        }

        # Run AutoGen from inside the temp dir. os.chdir affects the whole
        # process, so the original directory is always restored below.
        original_dir = os.getcwd()
        os.chdir(temp_dir)

        try:
            user_proxy = UserProxyAgent(
                name="User",
                system_message="A user seeking information from multiple documents.",
                human_input_mode="NEVER",
                code_execution_config={"use_docker": False},
                llm_config=llm_config,
            )

            doc_analyzer = AssistantAgent(
                name="DocumentAnalyzer",
                system_message=(
                    "You are an expert at analyzing and comparing documents. Focus on:\n"
                    "1. Key similarities and differences\n"
                    "2. Main themes and topics\n"
                    "3. Relationships between documents\n"
                    "4. Evidence-based analysis"
                ),
                llm_config=llm_config,
            )

            qa_expert = AssistantAgent(
                name="QAExpert",
                system_message=(
                    "You are an expert at extracting specific information. Focus on:\n"
                    "1. Finding relevant details\n"
                    "2. Answering specific questions\n"
                    "3. Cross-referencing information\n"
                    "4. Providing evidence"
                ),
                llm_config=llm_config,
            )

            summarizer = AssistantAgent(
                name="Summarizer",
                system_message=(
                    "You are an expert at summarizing content. Focus on:\n"
                    "1. Key points and findings\n"
                    "2. Important relationships\n"
                    "3. Critical conclusions\n"
                    "4. Comprehensive overview"
                ),
                llm_config=llm_config,
            )

            groupchat = GroupChat(
                agents=[user_proxy, doc_analyzer, qa_expert, summarizer],
                messages=[],
                max_round=5,
            )
            manager = GroupChatManager(
                groupchat=groupchat,
                llm_config=llm_config,
            )

            task_prompt = (
                "Analyze these documents and answer the query:\n"
                "\n"
                f"Query: {query}\n"
                "\n"
                "Documents Context:\n"
                f"{context}\n"
                "\n"
                "Requirements:\n"
                "1. Provide a direct and clear answer\n"
                "2. Support all claims with evidence from the documents\n"
                "3. Consider relationships between all documents\n"
                "4. If comparing, analyze all relevant aspects\n"
                "5. If summarizing, cover all important points\n"
                "6. If looking for specific content, search thoroughly\n"
                "7. If analyzing relationships, consider all connections\n"
                "\n"
                "Please provide a comprehensive and well-structured answer."
            )

            user_proxy.initiate_chat(manager, message=task_prompt)
            return user_proxy.last_message()["content"]
        finally:
            os.chdir(original_dir)

    except Exception as e:
        print(f"ERROR in AutoGen processing: {str(e)}")
        return f"Error analyzing documents: {str(e)}"
|
| 1388 |
+
|
| 1389 |
+
|
| 1390 |
+
|
| 1391 |
+
def decide_next(state):
    """Return the intent label used to choose the next LangGraph node.

    Reads ``query`` (default "") and ``file_names`` (default []) from the
    state and delegates to ``detect_intent_embedding``.
    """
    return detect_intent_embedding(
        state.get("query", ""),
        state.get("file_names", []),
    )
|
| 1396 |
+
|
| 1397 |
+
|
| 1398 |
+
# === LangGraph Node Functions ===
|
| 1399 |
|
| 1400 |
def general_run(state):
    """Answer ``state["query"]`` with a direct LLM call (no documents).

    Returns:
        ``{"answer": <text>}``; on any exception the error is printed and a
        fixed apology message is returned instead.
    """
    try:
        question = state["query"]
        prompt = (
            "You are a helpful AI assistant. Please answer the following question:\n"
            f"{question}\n"
            "\n"
            "Provide a clear and informative answer."
        )
        reply = llm_gpt4.invoke(prompt)
        # Some LLM wrappers return a message object, others a plain value.
        if hasattr(reply, 'content'):
            text = reply.content
        else:
            text = str(reply)
        return {"answer": text}
    except Exception as e:
        print(f"ERROR in general_run: {str(e)}")
        return {"answer": "I apologize, but I'm having trouble processing your request."}
|
| 1414 |
|
| 1415 |
def docqa_run(state):
    """Answer ``state["query"]`` from document context (LangGraph node).

    Context comes from ``state["retriever"]`` when one is available;
    otherwise, or if retrieval raises, the raw texts in ``state["docs"]``
    are used. Only the first 3000 characters of context reach the prompt.

    Returns:
        ``{"answer": <text>}``. If anything else fails, the error is printed
        and the result of ``general_run(state)`` is returned.
    """
    try:
        question = state["query"]

        # The caller stores retriever=None when vector-store setup failed, so
        # test the value rather than the presence of the key.
        retriever = state.get("retriever")
        context = None
        if retriever is not None:
            try:
                relevant_docs = retriever.get_relevant_documents(question)
                context = "\n".join(d.page_content for d in relevant_docs)
            except Exception as e:
                # Retrieval is best-effort; the uploaded text is still usable.
                print(f"ERROR in docqa_run retrieval: {str(e)}")
        if context is None:
            context = "\n".join(state["docs"])

        prompt = (
            "Based on the following context, please answer the question:\n"
            f"Question: {question}\n"
            "\n"
            "Context:\n"
            f"{context[:3000]}\n"
            "\n"
            "Provide a detailed and accurate answer based on the context."
        )

        response = llm_gpt4.invoke(prompt)
        return {"answer": response.content if hasattr(response, 'content') else str(response)}
    except Exception as e:
        print(f"ERROR in docqa_run: {str(e)}")
        return general_run(state)
|
| 1438 |
|
| 1439 |
def summariser_run(state):
    """Summarise the texts in ``state["docs"]`` (LangGraph node).

    The texts are joined with newlines and only the first 3000 characters
    are sent to the LLM.

    Returns:
        ``{"summary": <text>}``; on any exception the error is printed and a
        fixed error message is returned under the same key.
    """
    try:
        joined = "\n".join(state["docs"])
        prompt = (
            "Please provide a comprehensive summary of the following document:\n"
            f"{joined[:3000]}\n"
            "\n"
            "Focus on:\n"
            "1. Main topics and key points\n"
            "2. Important findings or conclusions\n"
            "3. Significant details"
        )
        reply = llm_gpt4.invoke(prompt)
        if hasattr(reply, 'content'):
            summary_text = reply.content
        else:
            summary_text = str(reply)
        return {"summary": summary_text}
    except Exception as e:
        print(f"ERROR in summariser_run: {str(e)}")
        return {"summary": "Error generating summary."}
|
| 1456 |
+
|
| 1457 |
def build_langgraph_pipeline():
    """Build and compile the Tab 6 graph: Router -> DocQA | General | Summarise.

    The Router node passes the state through unchanged; ``decide_next``
    supplies the label that selects the branch. Each branch node is also a
    finish point.
    """
    handlers = {
        "DocQA": docqa_run,
        "General": general_run,
        "Summarise": summariser_run,
    }

    graph = StateGraph(dict)
    # Router only forwards the state; routing happens on its outgoing edges.
    graph.add_node("Router", lambda state: state)
    for node_name, handler in handlers.items():
        graph.add_node(node_name, handler)

    graph.set_entry_point("Router")
    graph.add_conditional_edges(
        "Router",
        decide_next,
        {node_name: node_name for node_name in handlers},
    )
    for node_name in handlers:
        graph.set_finish_point(node_name)

    return graph.compile()
|
| 1473 |
|
| 1474 |
def get_file_path_tab6(file):
    """Resolve an uploaded-file value to a path on local disk.

    Accepted inputs:
      * ``str``: treated as a path and returned only if it exists.
      * ``dict``: reads the ``"data"`` and ``"name"`` keys. If ``data`` is an
        existing path it is returned; otherwise ``data`` (``str`` is encoded
        as UTF-8, anything else is written as-is) is written into a new
        temporary directory.
      * object with a ``save`` method: saved into a new temporary directory.
      * any other object whose ``name`` attribute is an existing path.

    Returns:
        The resolved file path, or ``None`` when nothing usable was found.
    """

    def _leaf_name(raw, fallback="uploaded_file"):
        # Keep only the final path component so a supplied name such as
        # "../../x" or "/abs/x" cannot place the file outside the temp dir.
        leaf = os.path.basename(str(raw).replace("\\", "/")) if raw else ""
        return fallback if leaf in ("", ".", "..") else leaf

    if isinstance(file, str):
        print("DEBUG: File is a string:", file)
        if os.path.exists(file):
            print("DEBUG: File exists:", file)
            return file
        print("DEBUG: File does not exist:", file)
        return None

    if isinstance(file, dict):
        data = file.get("data")
        name = file.get("name")
        # Log metadata only; the payload may be large or binary.
        print("DEBUG: File is a dict. Name:", name, "Data type:", type(data).__name__)
        if not data:
            print("DEBUG: No data in dict, returning None")
            return None
        if isinstance(data, str) and os.path.exists(data):
            print("DEBUG: Data is a valid file path:", data)
            return data

        file_path = os.path.join(mkdtemp(), _leaf_name(name))
        print("DEBUG: Writing data to temporary file:", file_path)
        payload = data.encode("utf-8") if isinstance(data, str) else data
        with open(file_path, "wb") as f:
            f.write(payload)
        print("DEBUG: Temporary file created:", file_path)
        return file_path

    if hasattr(file, "save"):
        print("DEBUG: File has save attribute")
        file_path = os.path.join(mkdtemp(), _leaf_name(getattr(file, "name", None)))
        file.save(file_path)
        # save() is opaque here, so confirm it actually produced the file.
        if os.path.exists(file_path):
            print("DEBUG: File saved to:", file_path)
            return file_path
        print("ERROR: File not saved properly:", file_path)
        return None

    print("DEBUG: File type unrecognized")
    if hasattr(file, "name") and os.path.exists(file.name):
        return file.name
    return None
|
| 1527 |
|
| 1528 |
+
@traceable(name="multi_doc")
def langgraph_tab6_main(query: str, file=None):
    """Entry point for Tab 6: answer *query*, optionally over uploaded files.

    Flow:
      1. No file: answer with ``general_run``.
      2. Load each file (PDF, DOCX, otherwise plain text); files that fail
         to resolve or load are skipped.
      3. Try to index the chunks in Pinecone and build a retriever and a
         conversational chain (stored in module globals). On failure the
         retriever is ``None`` and processing continues.
      4. More than one document, or a query containing "compare" or
         "relation": delegate to ``autogen_multi_document_analysis``.
      5. Otherwise run the LangGraph pipeline and return its answer/summary.

    Args:
        query: The user's question.
        file: A single upload or a list of uploads, as accepted by
            ``get_file_path_tab6``.

    Returns:
        The answer text, or an apology string containing the error message.
    """
    global session_retriever, session_qa_chain

    try:
        print(f"DEBUG: Starting processing with query: {query}")

        # If no file is uploaded, directly use general_run
        if not file:
            return general_run({"query": query})["answer"]

        files = file if isinstance(file, list) else [file]
        all_docs = []
        file_names = []
        docs_by_file = []

        for f in files:
            try:
                path = get_file_path_tab6(f)
                if not path:
                    continue

                # Choose loader based on file type
                lowered = path.lower()
                if lowered.endswith('.pdf'):
                    loader = PyPDFLoader(path)
                elif lowered.endswith('.docx'):
                    loader = UnstructuredWordDocumentLoader(path)
                else:
                    loader = TextLoader(path)

                docs = loader.load()
                if docs:
                    text = "\n".join(doc.page_content for doc in docs if hasattr(doc, 'page_content'))
                    # Record the name only once the text exists, so that
                    # file_names[i] always labels docs_by_file[i].
                    file_names.append(os.path.basename(path))
                    docs_by_file.append(text)
                    all_docs.extend(docs)
            except Exception as e:
                print(f"ERROR processing file: {str(e)}")
                continue

        if not docs_by_file:
            return general_run({"query": query})["answer"]

        # Build the retriever using Pinecone (best-effort).
        try:
            import pinecone

            pinecone.init(
                api_key=os.getenv("PINECONE_API_KEY"),
                environment=os.getenv("PINECONE_ENVIRONMENT")
            )

            # NOTE(review): Pinecone documents index names as lowercase
            # alphanumerics and hyphens only; confirm "Rag_Docs" is accepted,
            # and that dimension=768 matches the `embeddings` model.
            index_name = "Rag_Docs"
            if index_name not in pinecone.list_indexes():
                pinecone.create_index(
                    name=index_name,
                    dimension=768,
                    metric="cosine"
                )

            chunks = RecursiveCharacterTextSplitter(
                chunk_size=500,
                chunk_overlap=50
            ).split_documents(all_docs)

            vectorstore = Pinecone.from_documents(
                documents=chunks,
                embedding=embeddings,
                index_name=index_name,
                namespace="Rag_Docs"
            )

            retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

            session_retriever = retriever
            session_qa_chain = ConversationalRetrievalChain.from_llm(
                llm=llm_gpt4,
                retriever=retriever,
                memory=ConversationBufferMemory(
                    memory_key="chat_history",
                    return_messages=True
                ),
            )
        except Exception as e:
            print(f"ERROR setting up Pinecone retriever: {str(e)}")
            retriever = None

        # Multi-document or comparison-style queries go to AutoGen collaboration
        if len(docs_by_file) > 1 or "compare" in query.lower() or "relation" in query.lower():
            return autogen_multi_document_analysis(query, docs_by_file, file_names)

        # Single-document queries are handled by the LangGraph pipeline
        state = {
            "query": query,
            "file_names": file_names,
            "docs": docs_by_file,
            "retriever": retriever
        }

        # Get the compiled LangGraph pipeline
        pipeline = build_langgraph_pipeline()

        # Run the pipeline on the state
        result = pipeline.invoke(state)

        # Extract the answer or summary from the result
        if "answer" in result:
            return result["answer"]
        elif "summary" in result:
            return result["summary"]
        else:
            return "Processing completed but no specific answer or summary was generated."

    except Exception as e:
        print(f"ERROR in main function: {str(e)}")
        return f"I apologize, but I encountered an error: {str(e)}"
|
| 1646 |
|
| 1647 |
# Gradio Interface Settings
|
| 1648 |
demo_description = """
|