ChienChung committed on
Commit
1c79352
·
verified ·
1 Parent(s): ef1450e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -90
app.py CHANGED
@@ -1244,33 +1244,22 @@ INTENT_LABELS = {
1244
  "General": ["who are you", "tell me something", "what can you do", "fun fact"],
1245
  }
1246
 
1247
- def parse_query(query: str) -> dict:
1248
- prompt = """Analyze the following query and determine required subtasks. Return a JSON object containing:
1249
- - summarize_files: list of document indices to summarize
1250
- - qa_pairs: list of QA objects [{"question": "question", "doc_indices": [relevant doc indices]}]
1251
- - compare_files: list of document index pairs to compare [[doc1_idx, doc2_idx]]
1252
- - find_relations: boolean, whether to analyze document relationships
1253
-
1254
- For example, query "What are the differences between document A and B, and summarize A" should return:
1255
- {
1256
- "summarize_files": [0],
1257
- "qa_pairs": [],
1258
- "compare_files": [[0, 1]],
1259
- "find_relations": false
1260
- }
1261
-
1262
- Query: """ + query
1263
-
1264
- response = llm_gpt4.invoke(prompt)
1265
- try:
1266
- return json.loads(response.content)
1267
- except:
1268
- return {
1269
- "summarize_files": [],
1270
- "qa_pairs": [{"question": query, "doc_indices": [0]}],
1271
- "compare_files": [],
1272
- "find_relations": False
1273
- }
1274
 
1275
  def autogen_multi_document_analysis(query: str, docs: list, file_names: list) -> str:
1276
  try:
@@ -1397,22 +1386,7 @@ def autogen_multi_document_analysis(query: str, docs: list, file_names: list) ->
1397
  print(f"ERROR in AutoGen processing: {str(e)}")
1398
  return f"Error analyzing documents: {str(e)}"
1399
 
1400
- # AutoGen Multi-Agent Collaboration Logic
1401
- def detect_intent_embedding(query, file_names=[]):
1402
- query_emb = embedding_model.encode(query, normalize_embeddings=True)
1403
- best_label = None
1404
- best_score = -1
1405
- all_phrases = INTENT_LABELS.copy()
1406
- if file_names:
1407
- all_phrases["DocQA"] += [name.lower() for name in file_names]
1408
- for label, examples in all_phrases.items():
1409
- for example in examples:
1410
- example_emb = embedding_model.encode(example, normalize_embeddings=True)
1411
- score = float(query_emb @ example_emb.T)
1412
- if score > best_score:
1413
- best_score = score
1414
- best_label = label
1415
- return best_label if best_label else "General"
1416
 
1417
  def decide_next(state):
1418
  query = state.get("query", "")
@@ -1420,34 +1394,6 @@ def decide_next(state):
1420
  label = detect_intent_embedding(query, file_names)
1421
  return label
1422
 
1423
- # === Define Task objects ===
1424
- docqa_task = Task(
1425
- description="Document QA Task: Answer questions based on the uploaded document.",
1426
- expected_output="Answer from Document QA Agent.",
1427
- agent=document_qa_agent,
1428
- input_variables=["query"]
1429
- )
1430
-
1431
- general_task = Task(
1432
- description="General Chat Task: Answer general queries.",
1433
- expected_output="Answer from General Agent.",
1434
- agent=general_agent,
1435
- input_variables=["query"]
1436
- )
1437
-
1438
- summariser_task = Task(
1439
- description="Summarisation Task: Summarise document content.",
1440
- expected_output="Summary output.",
1441
- agent=summarizer_agent, # Note: The name must match the defined agent (using 'z' if applicable)
1442
- input_variables=["query"]
1443
- )
1444
-
1445
- search_task = Task(
1446
- description="Search Task: Retrieve information from the web.",
1447
- expected_output="Answer from Search Agent.",
1448
- agent=search_agent,
1449
- input_variables=["query"]
1450
- )
1451
 
1452
  # === LangGraph Node Functions ===
1453
 
@@ -1508,25 +1454,6 @@ def summariser_run(state):
1508
  print(f"ERROR in summariser_run: {str(e)}")
1509
  return {"summary": "Error generating summary."}
1510
 
1511
- # === LangGraph Definition ===
1512
- def build_langgraph_pipeline():
1513
- graph = StateGraph(dict)
1514
- graph.add_node("Router", lambda state: state) # Router: simply pass the state
1515
- graph.add_node("DocQA", docqa_run)
1516
- graph.add_node("General", general_run)
1517
- graph.add_node("Summarise", summariser_run)
1518
- graph.set_entry_point("Router")
1519
- graph.add_conditional_edges("Router", decide_next, {
1520
- "DocQA": "DocQA",
1521
- "General": "General",
1522
- "Summarise": "Summarise",
1523
- })
1524
- graph.set_finish_point("DocQA")
1525
- graph.set_finish_point("General")
1526
- graph.set_finish_point("Summarise")
1527
- return graph.compile()
1528
-
1529
-
1530
 
1531
  def get_file_path_tab6(file):
1532
  if isinstance(file, str):
 
1244
  "General": ["who are you", "tell me something", "what can you do", "fun fact"],
1245
  }
1246
 
1247
+ # AutoGen Multi-Agent Collaboration Logic
1248
def detect_intent_embedding(query, file_names=None):
    """Return the intent label whose example phrase best matches *query*.

    The query and every example phrase are encoded with the module-level
    ``embedding_model`` and compared by dot product; the label owning the
    highest-scoring phrase wins.

    Args:
        query: The user query text to classify.
        file_names: Optional iterable of uploaded file names. When given,
            their lower-cased forms are added as extra "DocQA" example
            phrases for this call only.

    Returns:
        The best-matching key of ``INTENT_LABELS``, or ``"General"`` when no
        phrase scored above the initial threshold of -1.
    """
    # Copy each phrase list, not just the dict. A shallow ``dict.copy()``
    # followed by ``+=`` would extend the shared ``INTENT_LABELS["DocQA"]``
    # list in place, so file names from earlier calls would accumulate in
    # the global table and bias every later classification.
    phrases_by_label = {
        label: list(examples) for label, examples in INTENT_LABELS.items()
    }
    if file_names:
        phrases_by_label.setdefault("DocQA", []).extend(
            name.lower() for name in file_names
        )

    # NOTE(review): with normalize_embeddings=True the dot product below is
    # presumably cosine similarity (sentence-transformers style API) - confirm
    # against the embedding_model actually configured in this file.
    query_emb = embedding_model.encode(query, normalize_embeddings=True)

    best_label = None
    best_score = -1
    for label, examples in phrases_by_label.items():
        for example in examples:
            example_emb = embedding_model.encode(example, normalize_embeddings=True)
            score = float(query_emb @ example_emb.T)
            # Strict ">" keeps the first phrase seen when scores tie.
            if score > best_score:
                best_score = score
                best_label = label
    return best_label if best_label else "General"
 
 
 
 
 
 
 
 
 
 
 
1263
 
1264
  def autogen_multi_document_analysis(query: str, docs: list, file_names: list) -> str:
1265
  try:
 
1386
  print(f"ERROR in AutoGen processing: {str(e)}")
1387
  return f"Error analyzing documents: {str(e)}"
1388
 
1389
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1390
 
1391
  def decide_next(state):
1392
  query = state.get("query", "")
 
1394
  label = detect_intent_embedding(query, file_names)
1395
  return label
1396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1397
 
1398
  # === LangGraph Node Functions ===
1399
 
 
1454
  print(f"ERROR in summariser_run: {str(e)}")
1455
  return {"summary": "Error generating summary."}
1456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1457
 
1458
  def get_file_path_tab6(file):
1459
  if isinstance(file, str):