Update app.py
Browse files
app.py
CHANGED
|
@@ -40,7 +40,7 @@ class Config(object):
|
|
| 40 |
def __init__(self):
|
| 41 |
self.random_state = 42
|
| 42 |
self.max_len = 256
|
| 43 |
-
self.reasoning_max_len =
|
| 44 |
self.temperature = 0.1
|
| 45 |
self.repetition_penalty = 1.2
|
| 46 |
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -173,6 +173,7 @@ class AgentState(TypedDict):
|
|
| 173 |
confidence: float
|
| 174 |
judge_explanation: str
|
| 175 |
|
|
|
|
| 176 |
ALL_TOOLS = {
|
| 177 |
"web_search": ["query"],
|
| 178 |
"visit_webpage": ["url"],
|
|
@@ -363,7 +364,7 @@ def visit_webpage(url: str) -> str:
|
|
| 363 |
if content is not None:
|
| 364 |
for table in content.find_all("table", {"class": "wikitable"}):
|
| 365 |
for row in table.find_all("tr"):
|
| 366 |
-
cols = [c.get_text(strip=
|
| 367 |
if cols:
|
| 368 |
table_texts.append(" | ".join(cols))
|
| 369 |
|
|
@@ -512,12 +513,12 @@ Response: <answer>
|
|
| 512 |
|
| 513 |
DO NOT add anything additional and return ONLY what is asked and in the format asked.
|
| 514 |
|
| 515 |
-
If you output anything else, it is incorrect.
|
| 516 |
-
|
| 517 |
ONLY return a response if you are confident about the answer, otherwise return empty string.
|
| 518 |
|
| 519 |
If you output anything else, it is incorrect.
|
| 520 |
|
|
|
|
|
|
|
| 521 |
Example of valid json response for user request: Who was the winner of 2025 World Snooker Championship:
|
| 522 |
Response: Zhao Xintong.
|
| 523 |
|
|
@@ -531,8 +532,8 @@ Information:
|
|
| 531 |
{information}
|
| 532 |
"""
|
| 533 |
|
| 534 |
-
|
| 535 |
-
raw_output = generate(prompt)
|
| 536 |
|
| 537 |
logger.info(f"Raw Output: {raw_output}")
|
| 538 |
|
|
@@ -550,16 +551,32 @@ Information:
|
|
| 550 |
|
| 551 |
|
| 552 |
raw = raw_output.strip()
|
|
|
|
|
|
|
| 553 |
|
| 554 |
-
|
| 555 |
-
|
| 556 |
-
|
| 557 |
-
if match:
|
| 558 |
-
output = match.group(1).strip()
|
| 559 |
else:
|
| 560 |
-
#
|
| 561 |
-
|
| 562 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 563 |
if "Response:" in output:
|
| 564 |
output = output.split("Response:")[-1]
|
| 565 |
elif "Response" in output:
|
|
@@ -569,7 +586,31 @@ Information:
|
|
| 569 |
output = output.strip('"').strip()
|
| 570 |
if output.endswith("."):
|
| 571 |
output = output[:-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
state["output"] = output
|
| 575 |
|
|
@@ -655,7 +696,7 @@ Answer:
|
|
| 655 |
state["confidence"] = data["confidence"]
|
| 656 |
state["judge_explanation"] = data["explanation"]
|
| 657 |
|
| 658 |
-
logger.info(f"State (Judge Agent): {state}")
|
| 659 |
|
| 660 |
return state
|
| 661 |
|
|
@@ -725,34 +766,59 @@ def tool_executor(state: AgentState):
|
|
| 725 |
responsible for translating structured LLM intent into real system actions.
|
| 726 |
"""
|
| 727 |
|
| 728 |
-
|
| 729 |
-
|
| 730 |
-
|
| 731 |
-
|
| 732 |
-
|
| 733 |
-
|
| 734 |
-
|
| 735 |
-
|
| 736 |
-
|
| 737 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
|
| 739 |
-
|
| 740 |
-
|
| 741 |
-
|
| 742 |
-
|
| 743 |
-
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
| 750 |
-
|
| 751 |
-
|
|
|
|
|
|
|
| 752 |
try:
|
| 753 |
-
webpage_results = visit_webpage(
|
| 754 |
webpage_result = " \n ".join(webpage_results)
|
| 755 |
-
|
| 756 |
# for webpage_result in webpage_results:
|
| 757 |
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 758 |
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
|
@@ -768,43 +834,22 @@ def tool_executor(state: AgentState):
|
|
| 768 |
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 769 |
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 770 |
best_webpage_information = webpage_result
|
| 771 |
-
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
elif action.tool == "visit_webpage":
|
| 776 |
-
try:
|
| 777 |
-
webpage_results = visit_webpage(**action.args)
|
| 778 |
-
webpage_result = " \n ".join(webpage_results)
|
| 779 |
-
|
| 780 |
-
# for webpage_result in webpage_results:
|
| 781 |
-
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 782 |
-
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 783 |
-
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 784 |
-
|
| 785 |
-
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 786 |
-
|
| 787 |
-
if query_webpage_information_similarity_score > 0.65:
|
| 788 |
-
webpage_information_complete += webpage_result
|
| 789 |
-
webpage_information_complete += " \n "
|
| 790 |
-
webpage_information_complete += " \n "
|
| 791 |
-
|
| 792 |
-
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 793 |
-
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 794 |
-
best_webpage_information = webpage_result
|
| 795 |
-
except:
|
| 796 |
-
pass
|
| 797 |
-
else:
|
| 798 |
-
result = "Unknown tool"
|
| 799 |
-
|
| 800 |
-
if webpage_information_complete == "":
|
| 801 |
-
webpage_information_complete = best_webpage_information
|
| 802 |
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
|
| 806 |
-
|
| 807 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 808 |
|
| 809 |
return state
|
| 810 |
|
|
@@ -859,6 +904,8 @@ class BasicAgent:
|
|
| 859 |
"messages": question,
|
| 860 |
}
|
| 861 |
|
|
|
|
|
|
|
| 862 |
|
| 863 |
try:
|
| 864 |
response = self.safe_app.invoke(state)
|
|
|
|
| 40 |
def __init__(self):
|
| 41 |
self.random_state = 42
|
| 42 |
self.max_len = 256
|
| 43 |
+
self.reasoning_max_len = 256
|
| 44 |
self.temperature = 0.1
|
| 45 |
self.repetition_penalty = 1.2
|
| 46 |
self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 173 |
confidence: float
|
| 174 |
judge_explanation: str
|
| 175 |
|
| 176 |
+
|
| 177 |
ALL_TOOLS = {
|
| 178 |
"web_search": ["query"],
|
| 179 |
"visit_webpage": ["url"],
|
|
|
|
| 364 |
if content is not None:
|
| 365 |
for table in content.find_all("table", {"class": "wikitable"}):
|
| 366 |
for row in table.find_all("tr"):
|
| 367 |
+
cols = [c.get_text(strip=False) for c in row.find_all(["td", "th"])]
|
| 368 |
if cols:
|
| 369 |
table_texts.append(" | ".join(cols))
|
| 370 |
|
|
|
|
| 513 |
|
| 514 |
DO NOT add anything additional and return ONLY what is asked and in the format asked.
|
| 515 |
|
|
|
|
|
|
|
| 516 |
ONLY return a response if you are confident about the answer, otherwise return empty string.
|
| 517 |
|
| 518 |
If you output anything else, it is incorrect.
|
| 519 |
|
| 520 |
+
If there is no information provided or the information is not relevant then answer as best based on your own knowledge.
|
| 521 |
+
|
| 522 |
Example of valid json response for user request: Who was the winner of 2025 World Snooker Championship:
|
| 523 |
Response: Zhao Xintong.
|
| 524 |
|
|
|
|
| 532 |
{information}
|
| 533 |
"""
|
| 534 |
|
| 535 |
+
raw_output = reasoning_generate(prompt)
|
| 536 |
+
# raw_output = generate(prompt)
|
| 537 |
|
| 538 |
logger.info(f"Raw Output: {raw_output}")
|
| 539 |
|
|
|
|
| 551 |
|
| 552 |
|
| 553 |
raw = raw_output.strip()
|
| 554 |
+
|
| 555 |
+
matches = re.findall(r"Response:\s*([^\n]+)", raw)
|
| 556 |
|
| 557 |
+
if matches:
|
| 558 |
+
output = matches[-1].strip() # ✅ take LAST occurrence
|
|
|
|
|
|
|
|
|
|
| 559 |
else:
|
| 560 |
+
# Find the first valid "Response: ..." occurrence
|
| 561 |
+
match = re.search(r"Response:\s*([^\n\.]+)", raw)
|
| 562 |
+
|
| 563 |
+
if match:
|
| 564 |
+
output = match.group(1).strip()
|
| 565 |
+
else:
|
| 566 |
+
# fallback: take first line
|
| 567 |
+
output = raw.split("\n")[0].strip()
|
| 568 |
+
|
| 569 |
+
if "Response:" in output:
|
| 570 |
+
output = output.split("Response:")[-1]
|
| 571 |
+
elif "Response" in output:
|
| 572 |
+
output = output.split("Response")[-1]
|
| 573 |
+
|
| 574 |
+
# Clean quotes / trailing punctuation
|
| 575 |
+
output = output.strip('"').strip()
|
| 576 |
+
if output.endswith("."):
|
| 577 |
+
output = output[:-1]
|
| 578 |
+
|
| 579 |
+
# Clean
|
| 580 |
if "Response:" in output:
|
| 581 |
output = output.split("Response:")[-1]
|
| 582 |
elif "Response" in output:
|
|
|
|
| 586 |
output = output.strip('"').strip()
|
| 587 |
if output.endswith("."):
|
| 588 |
output = output[:-1]
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
if output == "":
|
| 592 |
+
# Find the first valid "Response: ..." occurrence
|
| 593 |
+
match = re.search(r"Response:\s*([^\n\.]+)", raw)
|
| 594 |
+
|
| 595 |
+
if match:
|
| 596 |
+
output = match.group(1).strip()
|
| 597 |
+
else:
|
| 598 |
+
# fallback: take first line
|
| 599 |
+
output = raw.split("\n")[0].strip()
|
| 600 |
|
| 601 |
+
if "Response:" in output:
|
| 602 |
+
output = output.split("Response:")[-1]
|
| 603 |
+
elif "Response" in output:
|
| 604 |
+
output = output.split("Response")[-1]
|
| 605 |
+
|
| 606 |
+
# Clean quotes / trailing punctuation
|
| 607 |
+
output = output.strip('"').strip()
|
| 608 |
+
if output.endswith("."):
|
| 609 |
+
output = output[:-1]
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
output = output.split(".")[0]
|
| 613 |
+
|
| 614 |
|
| 615 |
state["output"] = output
|
| 616 |
|
|
|
|
| 696 |
state["confidence"] = data["confidence"]
|
| 697 |
state["judge_explanation"] = data["explanation"]
|
| 698 |
|
| 699 |
+
# logger.info(f"State (Judge Agent): {state}")
|
| 700 |
|
| 701 |
return state
|
| 702 |
|
|
|
|
| 766 |
responsible for translating structured LLM intent into real system actions.
|
| 767 |
"""
|
| 768 |
|
| 769 |
+
try:
|
| 770 |
+
webpage_result = ""
|
| 771 |
+
action = Action.model_validate(state["proposed_action"])
|
| 772 |
+
|
| 773 |
+
best_query_webpage_information_similarity_score = -1.0
|
| 774 |
+
best_webpage_information = ""
|
| 775 |
+
|
| 776 |
+
webpage_information_complete = ""
|
| 777 |
+
|
| 778 |
+
if action.tool == "web_search":
|
| 779 |
+
logger.info(f"action.tool: {action.tool}")
|
| 780 |
+
|
| 781 |
+
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 782 |
+
query_arg_embeddings = sentence_transformer_model.encode_query(state["proposed_action"]["args"]["query"]).reshape(1, -1)
|
| 783 |
+
score = float(cosine_similarity(query_embeddings, query_arg_embeddings)[0][0])
|
| 784 |
+
|
| 785 |
+
if score > 0.80:
|
| 786 |
+
results = web_search(**action.args)
|
| 787 |
+
else:
|
| 788 |
+
logger.info(f"Overwriting user query because the Agent suggested query had score: {state["proposed_action"]["args"]["query"]} - {score}")
|
| 789 |
+
results = web_search(**{"query": state["messages"][-1].content})
|
| 790 |
+
|
| 791 |
+
logger.info(f"Webpages - Results: {results}")
|
| 792 |
+
|
| 793 |
+
for result in results:
|
| 794 |
+
try:
|
| 795 |
+
webpage_results = visit_webpage(result)
|
| 796 |
+
webpage_result = " \n ".join(webpage_results)
|
| 797 |
+
|
| 798 |
+
# for webpage_result in webpage_results:
|
| 799 |
+
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 800 |
+
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 801 |
+
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 802 |
|
| 803 |
+
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 804 |
+
|
| 805 |
+
if query_webpage_information_similarity_score > 0.65:
|
| 806 |
+
webpage_information_complete += webpage_result
|
| 807 |
+
webpage_information_complete += " \n "
|
| 808 |
+
webpage_information_complete += " \n "
|
| 809 |
+
|
| 810 |
+
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 811 |
+
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 812 |
+
best_webpage_information = webpage_result
|
| 813 |
+
|
| 814 |
+
except Exception as e:
|
| 815 |
+
logger.info(f"Tool Executor - Exception: {e}")
|
| 816 |
+
|
| 817 |
+
elif action.tool == "visit_webpage":
|
| 818 |
try:
|
| 819 |
+
webpage_results = visit_webpage(**action.args)
|
| 820 |
webpage_result = " \n ".join(webpage_results)
|
| 821 |
+
|
| 822 |
# for webpage_result in webpage_results:
|
| 823 |
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 824 |
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
|
|
|
| 834 |
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 835 |
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 836 |
best_webpage_information = webpage_result
|
| 837 |
+
except:
|
| 838 |
+
pass
|
| 839 |
+
else:
|
| 840 |
+
result = "Unknown tool"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 841 |
|
| 842 |
+
if webpage_information_complete == "" and best_query_webpage_information_similarity_score > 0.30:
|
| 843 |
+
webpage_information_complete = best_webpage_information
|
| 844 |
+
|
| 845 |
+
state["information"] = webpage_information_complete[:3000]
|
| 846 |
+
state["best_query_webpage_information_similarity_score"] = best_query_webpage_information_similarity_score
|
| 847 |
+
except:
|
| 848 |
+
state["information"] = ""
|
| 849 |
+
state["best_query_webpage_information_similarity_score"] = -1.0
|
| 850 |
+
|
| 851 |
+
# logger.info(f"Information: {state['information']}")
|
| 852 |
+
# logger.info(f"Information: {state['best_query_webpage_information_similarity_score']}")
|
| 853 |
|
| 854 |
return state
|
| 855 |
|
|
|
|
| 904 |
"messages": question,
|
| 905 |
}
|
| 906 |
|
| 907 |
+
if len(tokenizer.encode(state["messages"][::-1])) < len(tokenizer.encode(state["messages"])):
|
| 908 |
+
state["messages"] = state["messages"][::-1]
|
| 909 |
|
| 910 |
try:
|
| 911 |
response = self.safe_app.invoke(state)
|