Sandiago21 commited on
Commit
c17b923
·
verified ·
1 Parent(s): c79ae07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -76
app.py CHANGED
@@ -40,7 +40,7 @@ class Config(object):
40
  def __init__(self):
41
  self.random_state = 42
42
  self.max_len = 256
43
- self.reasoning_max_len = 128
44
  self.temperature = 0.1
45
  self.repetition_penalty = 1.2
46
  self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
@@ -173,6 +173,7 @@ class AgentState(TypedDict):
173
  confidence: float
174
  judge_explanation: str
175
 
 
176
  ALL_TOOLS = {
177
  "web_search": ["query"],
178
  "visit_webpage": ["url"],
@@ -363,7 +364,7 @@ def visit_webpage(url: str) -> str:
363
  if content is not None:
364
  for table in content.find_all("table", {"class": "wikitable"}):
365
  for row in table.find_all("tr"):
366
- cols = [c.get_text(strip=True) for c in row.find_all(["td", "th"])]
367
  if cols:
368
  table_texts.append(" | ".join(cols))
369
 
@@ -512,12 +513,12 @@ Response: <answer>
512
 
513
  DO NOT add anything additional and return ONLY what is asked and in the format asked.
514
 
515
- If you output anything else, it is incorrect.
516
-
517
  ONLY return a response if you are confident about the answer, otherwise return empty string.
518
 
519
  If you output anything else, it is incorrect.
520
 
 
 
521
  Example of valid json response for user request: Who was the winner of 2025 World Snooker Championship:
522
  Response: Zhao Xintong.
523
 
@@ -531,8 +532,8 @@ Information:
531
  {information}
532
  """
533
 
534
- # raw_output = reasoning_generate(prompt)
535
- raw_output = generate(prompt)
536
 
537
  logger.info(f"Raw Output: {raw_output}")
538
 
@@ -550,16 +551,32 @@ Information:
550
 
551
 
552
  raw = raw_output.strip()
 
 
553
 
554
- # Find the first valid "Response: ..." occurrence
555
- match = re.search(r"Response:\s*([^\n\.]+)", raw)
556
-
557
- if match:
558
- output = match.group(1).strip()
559
  else:
560
- # fallback: take first line
561
- output = raw.split("\n")[0].strip()
562
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
563
  if "Response:" in output:
564
  output = output.split("Response:")[-1]
565
  elif "Response" in output:
@@ -569,7 +586,31 @@ Information:
569
  output = output.strip('"').strip()
570
  if output.endswith("."):
571
  output = output[:-1]
 
 
 
 
 
 
 
 
 
 
 
572
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
 
574
  state["output"] = output
575
 
@@ -655,7 +696,7 @@ Answer:
655
  state["confidence"] = data["confidence"]
656
  state["judge_explanation"] = data["explanation"]
657
 
658
- logger.info(f"State (Judge Agent): {state}")
659
 
660
  return state
661
 
@@ -725,34 +766,59 @@ def tool_executor(state: AgentState):
725
  responsible for translating structured LLM intent into real system actions.
726
  """
727
 
728
- webpage_result = ""
729
- action = Action.model_validate(state["proposed_action"])
730
-
731
- best_query_webpage_information_similarity_score = -1.0
732
- best_webpage_information = ""
733
-
734
- webpage_information_complete = ""
735
-
736
- if action.tool == "web_search":
737
- logger.info(f"action.tool: {action.tool}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
 
739
- query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
740
- query_arg_embeddings = sentence_transformer_model.encode_query(state["proposed_action"]["args"]["query"]).reshape(1, -1)
741
- score = float(cosine_similarity(query_embeddings, query_arg_embeddings)[0][0])
742
-
743
- if score > 0.80:
744
- results = web_search(**action.args)
745
- else:
746
- logger.info(f"Overwriting user query because the Agent suggested query had score: {state["proposed_action"]["args"]["query"]} - {score}")
747
- results = web_search(**{"query": state["messages"][-1].content})
748
-
749
- logger.info(f"Webpages - Results: {results}")
750
-
751
- for result in results:
 
 
752
  try:
753
- webpage_results = visit_webpage(result)
754
  webpage_result = " \n ".join(webpage_results)
755
-
756
  # for webpage_result in webpage_results:
757
  query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
758
  webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
@@ -768,43 +834,22 @@ def tool_executor(state: AgentState):
768
  if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
769
  best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
770
  best_webpage_information = webpage_result
771
-
772
- except Exception as e:
773
- logger.info(f"Tool Executor - Exception: {e}")
774
-
775
- elif action.tool == "visit_webpage":
776
- try:
777
- webpage_results = visit_webpage(**action.args)
778
- webpage_result = " \n ".join(webpage_results)
779
-
780
- # for webpage_result in webpage_results:
781
- query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
782
- webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
783
- query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
784
-
785
- # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
786
-
787
- if query_webpage_information_similarity_score > 0.65:
788
- webpage_information_complete += webpage_result
789
- webpage_information_complete += " \n "
790
- webpage_information_complete += " \n "
791
-
792
- if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
793
- best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
794
- best_webpage_information = webpage_result
795
- except:
796
- pass
797
- else:
798
- result = "Unknown tool"
799
-
800
- if webpage_information_complete == "":
801
- webpage_information_complete = best_webpage_information
802
 
803
- state["information"] = webpage_information_complete[:3000]
804
- state["best_query_webpage_information_similarity_score"] = best_query_webpage_information_similarity_score
805
-
806
- logger.info(f"Information: {state['information']}")
807
- logger.info(f"Information: {state['best_query_webpage_information_similarity_score']}")
 
 
 
 
 
 
808
 
809
  return state
810
 
@@ -859,6 +904,8 @@ class BasicAgent:
859
  "messages": question,
860
  }
861
 
 
 
862
 
863
  try:
864
  response = self.safe_app.invoke(state)
 
40
  def __init__(self):
41
  self.random_state = 42
42
  self.max_len = 256
43
+ self.reasoning_max_len = 256
44
  self.temperature = 0.1
45
  self.repetition_penalty = 1.2
46
  self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
173
  confidence: float
174
  judge_explanation: str
175
 
176
+
177
  ALL_TOOLS = {
178
  "web_search": ["query"],
179
  "visit_webpage": ["url"],
 
364
  if content is not None:
365
  for table in content.find_all("table", {"class": "wikitable"}):
366
  for row in table.find_all("tr"):
367
+ cols = [c.get_text(strip=False) for c in row.find_all(["td", "th"])]
368
  if cols:
369
  table_texts.append(" | ".join(cols))
370
 
 
513
 
514
  DO NOT add anything additional and return ONLY what is asked and in the format asked.
515
 
 
 
516
  ONLY return a response if you are confident about the answer, otherwise return empty string.
517
 
518
  If you output anything else, it is incorrect.
519
 
520
+ If there is no information provided or the information is not relevant then answer as best based on your own knowledge.
521
+
522
  Example of valid json response for user request: Who was the winner of 2025 World Snooker Championship:
523
  Response: Zhao Xintong.
524
 
 
532
  {information}
533
  """
534
 
535
+ raw_output = reasoning_generate(prompt)
536
+ # raw_output = generate(prompt)
537
 
538
  logger.info(f"Raw Output: {raw_output}")
539
 
 
551
 
552
 
553
  raw = raw_output.strip()
554
+
555
+ matches = re.findall(r"Response:\s*([^\n]+)", raw)
556
 
557
+ if matches:
558
+ output = matches[-1].strip() # ✅ take LAST occurrence
 
 
 
559
  else:
560
+ # Find the first valid "Response: ..." occurrence
561
+ match = re.search(r"Response:\s*([^\n\.]+)", raw)
562
+
563
+ if match:
564
+ output = match.group(1).strip()
565
+ else:
566
+ # fallback: take first line
567
+ output = raw.split("\n")[0].strip()
568
+
569
+ if "Response:" in output:
570
+ output = output.split("Response:")[-1]
571
+ elif "Response" in output:
572
+ output = output.split("Response")[-1]
573
+
574
+ # Clean quotes / trailing punctuation
575
+ output = output.strip('"').strip()
576
+ if output.endswith("."):
577
+ output = output[:-1]
578
+
579
+ # Clean
580
  if "Response:" in output:
581
  output = output.split("Response:")[-1]
582
  elif "Response" in output:
 
586
  output = output.strip('"').strip()
587
  if output.endswith("."):
588
  output = output[:-1]
589
+
590
+
591
+ if output == "":
592
+ # Find the first valid "Response: ..." occurrence
593
+ match = re.search(r"Response:\s*([^\n\.]+)", raw)
594
+
595
+ if match:
596
+ output = match.group(1).strip()
597
+ else:
598
+ # fallback: take first line
599
+ output = raw.split("\n")[0].strip()
600
 
601
+ if "Response:" in output:
602
+ output = output.split("Response:")[-1]
603
+ elif "Response" in output:
604
+ output = output.split("Response")[-1]
605
+
606
+ # Clean quotes / trailing punctuation
607
+ output = output.strip('"').strip()
608
+ if output.endswith("."):
609
+ output = output[:-1]
610
+
611
+
612
+ output = output.split(".")[0]
613
+
614
 
615
  state["output"] = output
616
 
 
696
  state["confidence"] = data["confidence"]
697
  state["judge_explanation"] = data["explanation"]
698
 
699
+ # logger.info(f"State (Judge Agent): {state}")
700
 
701
  return state
702
 
 
766
  responsible for translating structured LLM intent into real system actions.
767
  """
768
 
769
+ try:
770
+ webpage_result = ""
771
+ action = Action.model_validate(state["proposed_action"])
772
+
773
+ best_query_webpage_information_similarity_score = -1.0
774
+ best_webpage_information = ""
775
+
776
+ webpage_information_complete = ""
777
+
778
+ if action.tool == "web_search":
779
+ logger.info(f"action.tool: {action.tool}")
780
+
781
+ query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
782
+ query_arg_embeddings = sentence_transformer_model.encode_query(state["proposed_action"]["args"]["query"]).reshape(1, -1)
783
+ score = float(cosine_similarity(query_embeddings, query_arg_embeddings)[0][0])
784
+
785
+ if score > 0.80:
786
+ results = web_search(**action.args)
787
+ else:
788
+ logger.info(f"Overwriting user query because the Agent suggested query had score: {state["proposed_action"]["args"]["query"]} - {score}")
789
+ results = web_search(**{"query": state["messages"][-1].content})
790
+
791
+ logger.info(f"Webpages - Results: {results}")
792
+
793
+ for result in results:
794
+ try:
795
+ webpage_results = visit_webpage(result)
796
+ webpage_result = " \n ".join(webpage_results)
797
+
798
+ # for webpage_result in webpage_results:
799
+ query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
800
+ webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
801
+ query_webpage_information_similarity_score = float(cosine_similarity(query_embeddings, webpage_information_embeddings)[0][0])
802
 
803
+ # logger.info(f"Webpage Information and Similarity Score: {result} - {webpage_result} - {query_webpage_information_similarity_score}")
804
+
805
+ if query_webpage_information_similarity_score > 0.65:
806
+ webpage_information_complete += webpage_result
807
+ webpage_information_complete += " \n "
808
+ webpage_information_complete += " \n "
809
+
810
+ if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
811
+ best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
812
+ best_webpage_information = webpage_result
813
+
814
+ except Exception as e:
815
+ logger.info(f"Tool Executor - Exception: {e}")
816
+
817
+ elif action.tool == "visit_webpage":
818
  try:
819
+ webpage_results = visit_webpage(**action.args)
820
  webpage_result = " \n ".join(webpage_results)
821
+
822
  # for webpage_result in webpage_results:
823
  query_embeddings = sentence_transformer_model.encode_query(state["messages"][-1].content).reshape(1, -1)
824
  webpage_information_embeddings = sentence_transformer_model.encode_query(webpage_result).reshape(1, -1)
 
834
  if query_webpage_information_similarity_score > best_query_webpage_information_similarity_score:
835
  best_query_webpage_information_similarity_score = query_webpage_information_similarity_score
836
  best_webpage_information = webpage_result
837
+ except:
838
+ pass
839
+ else:
840
+ result = "Unknown tool"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
841
 
842
+ if webpage_information_complete == "" and best_query_webpage_information_similarity_score > 0.30:
843
+ webpage_information_complete = best_webpage_information
844
+
845
+ state["information"] = webpage_information_complete[:3000]
846
+ state["best_query_webpage_information_similarity_score"] = best_query_webpage_information_similarity_score
847
+ except:
848
+ state["information"] = ""
849
+ state["best_query_webpage_information_similarity_score"] = -1.0
850
+
851
+ # logger.info(f"Information: {state['information']}")
852
+ # logger.info(f"Information: {state['best_query_webpage_information_similarity_score']}")
853
 
854
  return state
855
 
 
904
  "messages": question,
905
  }
906
 
907
+ if len(tokenizer.encode(state["messages"][::-1])) < len(tokenizer.encode(state["messages"])):
908
+ state["messages"] = state["messages"][::-1]
909
 
910
  try:
911
  response = self.safe_app.invoke(state)