Sandiago21 commited on
Commit
555d88a
·
verified ·
1 Parent(s): 152c460

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -8
app.py CHANGED
@@ -114,10 +114,10 @@ def reasoning_generate(prompt):
114
  tokens removed and leading/trailing whitespace stripped.
115
 
116
  """
117
- inputs = reasoning_tokenizer(prompt, return_tensors="pt").to(model.device)
118
 
119
  with torch.no_grad():
120
- outputs = reasoning_model.generate(
121
  **inputs,
122
  max_new_tokens=config.reasoning_max_len,
123
  temperature=config.temperature,
@@ -126,7 +126,7 @@ def reasoning_generate(prompt):
126
 
127
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
128
 
129
- return reasoning_tokenizer.decode(generated, skip_special_tokens=True).strip()
130
 
131
 
132
  def reasoning_generate(prompt):
@@ -338,7 +338,7 @@ def visit_webpage(url: str) -> str:
338
  return [main_text[:1000],]
339
 
340
 
341
- def visit_webpage(url: str) -> str:
342
  headers = {
343
  "User-Agent": "Mozilla/5.0"
344
  }
@@ -374,6 +374,46 @@ def visit_webpage(url: str) -> str:
374
  return [main_text[:1000],]
375
 
376
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
377
  def web_search(query: str, num_results: int = 10):
378
  """
379
  Search the internet for the query provided
@@ -708,6 +748,7 @@ def route(state: AgentState):
708
  else:
709
  return "allow"
710
 
 
711
  def tool_executor(state: AgentState):
712
  """
713
  Tool execution node for a risk-aware LLM agent.
@@ -792,7 +833,7 @@ def tool_executor(state: AgentState):
792
 
793
  for result in results:
794
  try:
795
- webpage_results = visit_webpage(result)
796
  webpage_result = " \n ".join(webpage_results)
797
 
798
  # for webpage_result in webpage_results:
@@ -810,15 +851,58 @@ def tool_executor(state: AgentState):
810
  if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
811
  best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
812
  best_webpage_information = webpage_result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
 
814
  except Exception as e:
815
  logger.info(f"Tool Executor - Exception: {e}")
816
 
817
  elif action.tool == "visit_webpage":
818
  try:
819
- webpage_results = visit_webpage(**action.args)
820
  webpage_result = " \n ".join(webpage_results)
 
 
 
 
 
821
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822
  # for webpage_result in webpage_results:
823
  query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
824
  webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
@@ -936,8 +1020,8 @@ class BasicAgent:
936
 
937
  # agent_answer = str(df)
938
  # agent_answer = str(response.status_code) + " - " + task_id
939
- except:
940
- agent_answer = ""
941
 
942
  else:
943
  agent_answer = fixed_answer
@@ -948,6 +1032,7 @@ class BasicAgent:
948
 
949
  return agent_answer
950
 
 
951
  def run_and_submit_all( profile: gr.OAuthProfile | None):
952
  """
953
  Fetches all questions, runs the BasicAgent on them, submits all answers,
 
114
  tokens removed and leading/trailing whitespace stripped.
115
 
116
  """
117
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
118
 
119
  with torch.no_grad():
120
+ outputs = model.generate(
121
  **inputs,
122
  max_new_tokens=config.reasoning_max_len,
123
  temperature=config.temperature,
 
126
 
127
  generated = outputs[0][inputs["input_ids"].shape[-1]:]
128
 
129
+ return tokenizer.decode(generated, skip_special_tokens=True).strip()
130
 
131
 
132
  def reasoning_generate(prompt):
 
338
  return [main_text[:1000],]
339
 
340
 
341
+ def visit_webpage_wiki(url: str) -> str:
342
  headers = {
343
  "User-Agent": "Mozilla/5.0"
344
  }
 
374
  return [main_text[:1000],]
375
 
376
 
377
def visit_webpage_main(url: str):
    """Fetch *url* and return its readable body text (plus any table text).

    Parameters:
        url: Address of the page to fetch.

    Returns:
        A list of one or two strings:
        ``[main_text[:1500]]`` when the page contains no tables, or
        ``[main_text[:1500], table_text[:5000]]`` when it does, where
        ``table_text`` is every table row joined as ``"col | col | ..."``.

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection failure or timeout.
    """
    headers = {"User-Agent": "Mozilla/5.0"}

    response = requests.get(url, headers=headers, timeout=10)
    response.raise_for_status()

    soup = BeautifulSoup(response.text, "html.parser")

    # Strip scripts/styles so only human-readable text remains.
    for tag in soup(["script", "style"]):
        tag.extract()

    # Focus on <body>, but fall back to the whole document when the page
    # has no <body> element — fixes AttributeError on `None.find_all`.
    content = soup.find("body") or soup

    # Extract a broad set of text-bearing elements.
    # NOTE(review): including "div" duplicates text of nested <p>/<td>
    # descendants — kept as-is to preserve the existing output shape.
    elements = content.find_all(["p", "dd", "td", "div"])

    texts = []
    for el in elements:
        text = el.get_text(strip=True)
        if text and len(text) > 30:  # filter short/noisy fragments
            texts.append(text)

    main_text = "\n".join(texts)

    # Extract all tables (not just wikitable), one " | "-joined row per line.
    table_texts = []
    for table in soup.find_all("table"):
        for row in table.find_all("tr"):
            cols = [c.get_text(strip=True) for c in row.find_all(["td", "th"])]
            if cols:
                table_texts.append(" | ".join(cols))

    if table_texts:
        return [main_text[:1500], "\n".join(table_texts)[:5000]]
    else:
        return [main_text[:1500]]
415
+
416
+
417
  def web_search(query: str, num_results: int = 10):
418
  """
419
  Search the internet for the query provided
 
748
  else:
749
  return "allow"
750
 
751
+
752
  def tool_executor(state: AgentState):
753
  """
754
  Tool execution node for a risk-aware LLM agent.
 
833
 
834
  for result in results:
835
  try:
836
+ webpage_results = visit_webpage_wiki(result)
837
  webpage_result = " \n ".join(webpage_results)
838
 
839
  # for webpage_result in webpage_results:
 
851
  if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
852
  best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
853
  best_webpage_information = webpage_result
854
+
855
+
856
+
857
+ webpage_results = visit_webpage_main(result)
858
+ webpage_result = " \n ".join(webpage_results)
859
+
860
+ # for webpage_result in webpage_results:
861
+ query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
862
+ webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
863
+ query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
864
+
865
+ # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
866
+
867
+ if query_webpage_information_similarity_score > 0.65:
868
+ webpage_information_complete += webpage_result
869
+ webpage_information_complete += " \n "
870
+ webpage_information_complete += " \n "
871
+
872
+ if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
873
+ best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
874
+ best_webpage_information = webpage_result
875
+
876
 
877
  except Exception as e:
878
  logger.info(f"Tool Executor - Exception: {e}")
879
 
880
  elif action.tool == "visit_webpage":
881
  try:
882
+ webpage_results = visit_webpage_wiki(result)
883
  webpage_result = " \n ".join(webpage_results)
884
+
885
+ # for webpage_result in webpage_results:
886
+ query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
887
+ webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
888
+ query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
889
 
890
+ # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
891
+
892
+ if query_webpage_information_similarity_score > 0.65:
893
+ webpage_information_complete += webpage_result
894
+ webpage_information_complete += " \n "
895
+ webpage_information_complete += " \n "
896
+
897
+ if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
898
+ best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
899
+ best_webpage_information = webpage_result
900
+
901
+
902
+
903
+ webpage_results = visit_webpage_main(result)
904
+ webpage_result = " \n ".join(webpage_results)
905
+
906
  # for webpage_result in webpage_results:
907
  query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
908
  webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
 
1020
 
1021
  # agent_answer = str(df)
1022
  # agent_answer = str(response.status_code) + " - " + task_id
1023
+ except Exception as e:
1024
+ agent_answer = str(e)
1025
 
1026
  else:
1027
  agent_answer = fixed_answer
 
1032
 
1033
  return agent_answer
1034
 
1035
+
1036
  def run_and_submit_all( profile: gr.OAuthProfile | None):
1037
  """
1038
  Fetches all questions, runs the BasicAgent on them, submits all answers,