gabejavitt commited on
Commit
b59bc25
·
verified ·
1 Parent(s): 8430d9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -25
app.py CHANGED
@@ -614,7 +614,7 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
614
  chat_llm = ChatGroq(
615
  temperature=0, # Maximum determinism
616
  groq_api_key=GROQ_API_KEY,
617
- model_name="llama-3.3-70b-versatile", # Best reasoning model
618
  max_tokens=4096,
619
  timeout=60
620
  )
@@ -651,35 +651,102 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
651
  )
652
  return {"messages": [error_msg], "turn": current_turn}
653
  time.sleep(2 ** attempt) # Exponential backoff
654
-
655
- # --- Fallback JSON parsing ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
656
  if not ai_message.tool_calls and isinstance(ai_message.content, str) and ai_message.content.strip():
657
- json_match = re.search(
658
- r"```(?:json)?\s*(\{.*?\})\s*```|(\{.*?\})",
659
- ai_message.content,
 
 
 
 
 
 
660
  re.DOTALL | re.IGNORECASE
661
  )
662
 
663
- if json_match:
664
- json_str = json_match.group(1) or json_match.group(2)
665
  try:
666
- parsed_json = json.loads(json_str)
667
- if isinstance(parsed_json, dict) and "tool" in parsed_json and "tool_input" in parsed_json:
668
- tool_name = parsed_json.get("tool")
669
- tool_input = parsed_json.get("tool_input", {})
670
-
671
- if any(t.name == tool_name for t in self.tools):
672
- print(f"🔧 Fallback: Parsed tool call for '{tool_name}' from JSON in content")
673
- tool_call = ToolCall(
674
- name=tool_name,
675
- args=tool_input,
676
- id=str(uuid.uuid4())
677
- )
678
- ai_message.tool_calls = [tool_call]
679
- ai_message.content = ""
680
- except json.JSONDecodeError:
681
- pass
682
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
683
  # --- Logging ---
684
  if ai_message.tool_calls:
685
  for tc in ai_message.tool_calls:
 
614
  chat_llm = ChatGroq(
615
  temperature=0, # Maximum determinism
616
  groq_api_key=GROQ_API_KEY,
617
+ model_name="qwen/qwen3-32b", # Best reasoning model
618
  max_tokens=4096,
619
  timeout=60
620
  )
 
651
  )
652
  return {"messages": [error_msg], "turn": current_turn}
653
  time.sleep(2 ** attempt) # Exponential backoff
654
+ # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
655
+ # --- ADD THIS RULE ENFORCEMENT BLOCK ---
656
+ #
657
+ # Force compliance with the "Plan on Turn 1" rule.
658
+ # If the agent tries to call a tool on turn 1, we
659
+ # strip the tool call and send a correction message.
660
+ if current_turn == 1 and ai_message.tool_calls:
661
+ print("⚠️ AGENT VIOLATION: Tried to call tools on Turn 1. Forcing replan.")
662
+
663
+ # Strip the illegal tool call
664
+ ai_message.tool_calls = []
665
+
666
+ # Create a correction message
667
+ correction_message = SystemMessage(
668
+ content="SYSTEM: Protocol Violation. Your FIRST turn MUST be a plan with NO tool calls. "
669
+ "You are not allowed to call any tools on your first turn. "
670
+ "Re-read the protocol and provide your 2-3 sentence plan now."
671
+ )
672
+
673
+ # Return the (now harmless) AI message + the correction
674
+ return {"messages": [ai_message, correction_message], "turn": current_turn}
675
+ # --- END OF RULE ENFORCEMENT BLOCK ---
676
+ # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
677
+ # --- FIX #2: REPLACE THE FALLBACK PARSING BLOCK ---
678
+ #
679
+ # --- Fallback Parsing ---
680
+ # Check if LLM failed to format tool call and put it in 'content'
681
  if not ai_message.tool_calls and isinstance(ai_message.content, str) and ai_message.content.strip():
682
+ content = ai_message.content
683
+ tool_name = None
684
+ tool_input = None
685
+
686
+ # 1. Try to parse the new <function(tool_name)>{json}</function> format
687
+ # Note: We look for </function> optionally, as it might be truncated
688
+ func_match = re.search(
689
+ r"<function\(([^)]+)\)>(\{.*?\})(?:</function>)?",
690
+ content,
691
  re.DOTALL | re.IGNORECASE
692
  )
693
 
694
+ if func_match:
 
695
  try:
696
+ tool_name = func_match.group(1).strip()
697
+ json_str = func_match.group(2)
698
+ tool_input = json.loads(json_str)
699
+ print(f"🔧 Fallback (Format 1): Parsed tool call for '{tool_name}'")
700
+ except json.JSONDecodeError as e:
701
+ print(f"⚠️ Fallback (Format 1): Failed to parse JSON: {e}")
702
+ tool_name = None # Reset
703
+
704
+ # 2. If Format 1 failed, try to parse bare JSON (old fallback)
705
+ if not tool_name:
706
+ json_match = re.search(
707
+ r"```(?:json)?\s*(\{.*?\})\s*```|(\{.*?\})",
708
+ content,
709
+ re.DOTALL | re.IGNORECASE
710
+ )
711
+ if json_match:
712
+ json_str = json_match.group(1) or json_match.group(2)
713
+ try:
714
+ parsed_json = json.loads(json_str)
715
+ # This format is less structured; we guess tool from keys
716
+ if isinstance(parsed_json, dict):
717
+ if "tool" in parsed_json and "tool_input" in parsed_json:
718
+ tool_name = parsed_json.get("tool")
719
+ tool_input = parsed_json.get("tool_input", {})
720
+ elif "code" in parsed_json: # Guess code_interpreter
721
+ tool_name = "code_interpreter"
722
+ tool_input = parsed_json
723
+ elif "answer" in parsed_json: # Guess final_answer
724
+ tool_name = "final_answer_tool"
725
+ tool_input = parsed_json
726
+
727
+ if tool_name:
728
+ print(f"🔧 Fallback (Format 2): Parsed tool call for '{tool_name}'")
729
+
730
+ except json.JSONDecodeError as e:
731
+ print(f"⚠️ Fallback (Format 2): Failed to parse JSON: {e}")
732
+
733
+ # --- If any fallback parser succeeded, build the tool call ---
734
+ if tool_name and tool_input is not None and any(t.name == tool_name for t in self.tools):
735
+ print(f"🔧 Fallback SUCCESS: Rebuilding tool call for '{tool_name}'")
736
+ tool_call = ToolCall(
737
+ name=tool_name,
738
+ args=tool_input,
739
+ id=str(uuid.uuid4())
740
+ )
741
+ ai_message.tool_calls = [tool_call]
742
+ ai_message.content = "" # Clear content field
743
+
744
+ elif not tool_name:
745
+ print(f"⚠️ Fallback FAILED: Could not parse any tool call from content:\n{content[:200]}...")
746
+ # --- END OF REPLACEMENT BLOCK ---
747
+ # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
748
+
749
+
750
  # --- Logging ---
751
  if ai_message.tool_calls:
752
  for tc in ai_message.tool_calls: