Update app.py
Browse files
app.py
CHANGED
|
@@ -114,10 +114,10 @@ def reasoning_generate(prompt):
|
|
| 114 |
tokens removed and leading/trailing whitespace stripped.
|
| 115 |
|
| 116 |
"""
|
| 117 |
-
inputs =
|
| 118 |
|
| 119 |
with torch.no_grad():
|
| 120 |
-
outputs =
|
| 121 |
**inputs,
|
| 122 |
max_new_tokens=config.reasoning_max_len,
|
| 123 |
temperature=config.temperature,
|
|
@@ -126,7 +126,7 @@ def reasoning_generate(prompt):
|
|
| 126 |
|
| 127 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 128 |
|
| 129 |
-
return
|
| 130 |
|
| 131 |
|
| 132 |
def reasoning_generate(prompt):
|
|
@@ -338,7 +338,7 @@ def visit_webpage(url: str) -> str:
|
|
| 338 |
return [main_text[:1000],]
|
| 339 |
|
| 340 |
|
| 341 |
-
def
|
| 342 |
headers = {
|
| 343 |
"User-Agent": "Mozilla/5.0"
|
| 344 |
}
|
|
@@ -374,6 +374,46 @@ def visit_webpage(url: str) -> str:
|
|
| 374 |
return [main_text[:1000],]
|
| 375 |
|
| 376 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 377 |
def web_search(query: str, num_results: int = 10):
|
| 378 |
"""
|
| 379 |
Search the internet for the query provided
|
|
@@ -708,6 +748,7 @@ def route(state: AgentState):
|
|
| 708 |
else:
|
| 709 |
return "allow"
|
| 710 |
|
|
|
|
| 711 |
def tool_executor(state: AgentState):
|
| 712 |
"""
|
| 713 |
Tool execution node for a risk-aware LLM agent.
|
|
@@ -792,7 +833,7 @@ def tool_executor(state: AgentState):
|
|
| 792 |
|
| 793 |
for result in results:
|
| 794 |
try:
|
| 795 |
-
webpage_results =
|
| 796 |
webpage_result = " \n ".join(webpage_results)
|
| 797 |
|
| 798 |
# for webpage_result in webpage_results:
|
|
@@ -810,15 +851,58 @@ def tool_executor(state: AgentState):
|
|
| 810 |
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 811 |
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 812 |
best_webpage_information = webpage_result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 813 |
|
| 814 |
except Exception as e:
|
| 815 |
logger.info(f"Tool Executor - Exception: {e}")
|
| 816 |
|
| 817 |
elif action.tool == "visit_webpage":
|
| 818 |
try:
|
| 819 |
-
webpage_results =
|
| 820 |
webpage_result = " \n ".join(webpage_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 821 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 822 |
# for webpage_result in webpage_results:
|
| 823 |
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 824 |
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
|
@@ -936,8 +1020,8 @@ class BasicAgent:
|
|
| 936 |
|
| 937 |
# agent_answer = str(df)
|
| 938 |
# agent_answer = str(response.status_code) + " - " + task_id
|
| 939 |
-
except:
|
| 940 |
-
agent_answer =
|
| 941 |
|
| 942 |
else:
|
| 943 |
agent_answer = fixed_answer
|
|
@@ -948,6 +1032,7 @@ class BasicAgent:
|
|
| 948 |
|
| 949 |
return agent_answer
|
| 950 |
|
|
|
|
| 951 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 952 |
"""
|
| 953 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|
|
|
|
| 114 |
tokens removed and leading/trailing whitespace stripped.
|
| 115 |
|
| 116 |
"""
|
| 117 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
| 118 |
|
| 119 |
with torch.no_grad():
|
| 120 |
+
outputs = model.generate(
|
| 121 |
**inputs,
|
| 122 |
max_new_tokens=config.reasoning_max_len,
|
| 123 |
temperature=config.temperature,
|
|
|
|
| 126 |
|
| 127 |
generated = outputs[0][inputs["input_ids"].shape[-1]:]
|
| 128 |
|
| 129 |
+
return tokenizer.decode(generated, skip_special_tokens=True).strip()
|
| 130 |
|
| 131 |
|
| 132 |
def reasoning_generate(prompt):
|
|
|
|
| 338 |
return [main_text[:1000],]
|
| 339 |
|
| 340 |
|
| 341 |
+
def visit_webpage_wiki(url: str) -> str:
|
| 342 |
headers = {
|
| 343 |
"User-Agent": "Mozilla/5.0"
|
| 344 |
}
|
|
|
|
| 374 |
return [main_text[:1000],]
|
| 375 |
|
| 376 |
|
| 377 |
+
def visit_webpage_main(url: str) -> list[str]:
    """Fetch a web page and return its main text plus any table contents.

    The page is downloaded with a browser-like ``User-Agent`` header and a
    10 second timeout, then parsed with BeautifulSoup's ``html.parser``.
    ``<script>`` and ``<style>`` tags are removed before text extraction.

    Args:
        url: Address of the page to download.

    Returns:
        A list of one or two strings. The first is the newline-joined text of
        the ``p``/``dd``/``td``/``div`` elements whose stripped text is longer
        than 30 characters, truncated to 1500 characters. When the page has at
        least one non-empty table row, a second string holds those rows (cells
        joined by ``" | "``, rows joined by newlines), truncated to 5000
        characters.

    Raises:
        requests.RequestException: If the request fails or times out, or if
            ``raise_for_status()`` rejects the response status.
    """
    headers = {"User-Agent": "Mozilla/5.0"}

    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()

    soup = BeautifulSoup(response.text, "html.parser")

    # Remove scripts/styles so their source never leaks into the text.
    for tag in soup(["script", "style"]):
        tag.extract()

    # Focus on <body>; ``find`` returns None for fragments or malformed pages
    # without one, so fall back to the whole document instead of crashing.
    content = soup.find("body")
    if content is None:
        content = soup

    # Extract a broad set of text-bearing elements. A wrapper <div> yields the
    # same text as its only child, so identical strings are kept once (first
    # occurrence wins) to avoid wasting the truncated output on repeats.
    # NOTE(review): nested containers with differing text still overlap.
    texts = []
    seen = set()
    for el in content.find_all(["p", "dd", "td", "div"]):
        text = el.get_text(strip=True)
        if len(text) > 30 and text not in seen:  # filter noise and repeats
            seen.add(text)
            texts.append(text)

    main_text = "\n".join(texts)

    # Extract all tables (not just wikitable), one output line per row.
    table_texts = []
    for table in soup.find_all("table"):
        for row in table.find_all("tr"):
            cols = [c.get_text(strip=True) for c in row.find_all(["td", "th"])]
            if cols:
                table_texts.append(" | ".join(cols))

    if table_texts:
        return [main_text[:1500], "\n".join(table_texts)[:5000]]
    return [main_text[:1500]]
|
| 415 |
+
|
| 416 |
+
|
| 417 |
def web_search(query: str, num_results: int = 10):
|
| 418 |
"""
|
| 419 |
Search the internet for the query provided
|
|
|
|
| 748 |
else:
|
| 749 |
return "allow"
|
| 750 |
|
| 751 |
+
|
| 752 |
def tool_executor(state: AgentState):
|
| 753 |
"""
|
| 754 |
Tool execution node for a risk-aware LLM agent.
|
|
|
|
| 833 |
|
| 834 |
for result in results:
|
| 835 |
try:
|
| 836 |
+
webpage_results = visit_webpage_wiki(result)
|
| 837 |
webpage_result = " \n ".join(webpage_results)
|
| 838 |
|
| 839 |
# for webpage_result in webpage_results:
|
|
|
|
| 851 |
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 852 |
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 853 |
best_webpage_information = webpage_result
|
| 854 |
+
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
webpage_results = visit_webpage_main(result)
|
| 858 |
+
webpage_result = " \n ".join(webpage_results)
|
| 859 |
+
|
| 860 |
+
# for webpage_result in webpage_results:
|
| 861 |
+
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 862 |
+
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 863 |
+
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 864 |
+
|
| 865 |
+
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 866 |
+
|
| 867 |
+
if query_webpage_information_similarity_score > 0.65:
|
| 868 |
+
webpage_information_complete += webpage_result
|
| 869 |
+
webpage_information_complete += " \n "
|
| 870 |
+
webpage_information_complete += " \n "
|
| 871 |
+
|
| 872 |
+
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 873 |
+
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 874 |
+
best_webpage_information = webpage_result
|
| 875 |
+
|
| 876 |
|
| 877 |
except Exception as e:
|
| 878 |
logger.info(f"Tool Executor - Exception: {e}")
|
| 879 |
|
| 880 |
elif action.tool == "visit_webpage":
|
| 881 |
try:
|
| 882 |
+
webpage_results = visit_webpage_wiki(result)
|
| 883 |
webpage_result = " \n ".join(webpage_results)
|
| 884 |
+
|
| 885 |
+
# for webpage_result in webpage_results:
|
| 886 |
+
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 887 |
+
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
| 888 |
+
query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
|
| 889 |
|
| 890 |
+
# logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
|
| 891 |
+
|
| 892 |
+
if query_webpage_information_similarity_score > 0.65:
|
| 893 |
+
webpage_information_complete += webpage_result
|
| 894 |
+
webpage_information_complete += " \n "
|
| 895 |
+
webpage_information_complete += " \n "
|
| 896 |
+
|
| 897 |
+
if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
|
| 898 |
+
best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
|
| 899 |
+
best_webpage_information = webpage_result
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
|
| 903 |
+
webpage_results = visit_webpage_main(result)
|
| 904 |
+
webpage_result = " \n ".join(webpage_results)
|
| 905 |
+
|
| 906 |
# for webpage_result in webpage_results:
|
| 907 |
query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
|
| 908 |
webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
|
|
|
|
| 1020 |
|
| 1021 |
# agent_answer = str(df)
|
| 1022 |
# agent_answer = str(response.status_code) + " - " + task_id
|
| 1023 |
+
except Exception as e:
|
| 1024 |
+
agent_answer = str(e)
|
| 1025 |
|
| 1026 |
else:
|
| 1027 |
agent_answer = fixed_answer
|
|
|
|
| 1032 |
|
| 1033 |
return agent_answer
|
| 1034 |
|
| 1035 |
+
|
| 1036 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
| 1037 |
"""
|
| 1038 |
Fetches all questions, runs the BasicAgent on them, submits all answers,
|