Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -628,7 +628,7 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
|
|
| 628 |
chat_llm = ChatGroq(
|
| 629 |
temperature=0, # Maximum determinism
|
| 630 |
groq_api_key=GROQ_API_KEY,
|
| 631 |
-
model_name="
|
| 632 |
max_tokens=4096,
|
| 633 |
timeout=60
|
| 634 |
)
|
|
@@ -641,77 +641,77 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
|
|
| 641 |
print("✅ Tools bound to LLM")
|
| 642 |
|
| 643 |
# --- Agent Node ---
|
| 644 |
-
def agent_node(state: AgentState):
|
| 645 |
-
|
| 646 |
-
# --- Turn Counter Logic ---
|
| 647 |
-
# We need to check if this is a retry of a failed turn (e.g., Turn 1 violation)
|
| 648 |
-
# We identify a retry if the *last* message was our "Protocol Violation" message
|
| 649 |
-
last_msg = state['messages'][-1]
|
| 650 |
-
is_a_retry = False
|
| 651 |
-
if isinstance(last_msg, SystemMessage) and "Protocol Violation" in last_msg.content:
|
| 652 |
-
is_a_retry = True
|
| 653 |
-
|
| 654 |
-
# Get the state's current turn number
|
| 655 |
-
current_turn = state.get('turn', 0)
|
| 656 |
-
|
| 657 |
-
# If this is NOT a retry, increment the turn.
|
| 658 |
-
# If it IS a retry, we *stay on the same turn number*
|
| 659 |
-
if not is_a_retry:
|
| 660 |
-
current_turn += 1
|
| 661 |
-
|
| 662 |
-
# Handle the very first run (where state['turn'] is 0)
|
| 663 |
-
if current_turn == 0:
|
| 664 |
-
current_turn = 1
|
| 665 |
-
# --- End Turn Counter Logic ---
|
| 666 |
-
|
| 667 |
-
print(f"\n{'='*60}")
|
| 668 |
-
print(f"AGENT TURN {current_turn}/{MAX_TURNS}")
|
| 669 |
-
if is_a_retry:
|
| 670 |
-
print("--- (Re-trying after protocol violation) ---")
|
| 671 |
-
print('='*60)
|
| 672 |
-
|
| 673 |
-
messages_to_send = state["messages"]
|
| 674 |
-
|
| 675 |
-
# Retry logic with exponential backoff
|
| 676 |
-
max_retries = 3
|
| 677 |
-
ai_message = None
|
| 678 |
-
|
| 679 |
-
for attempt in range(max_retries):
|
| 680 |
-
try:
|
| 681 |
-
ai_message = self.llm_with_tools.invoke(messages_to_send)
|
| 682 |
-
break
|
| 683 |
-
except Exception as e:
|
| 684 |
-
print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {e}")
|
| 685 |
-
if attempt == max_retries - 1:
|
| 686 |
-
error_msg = AIMessage(
|
| 687 |
-
content=f"Error: LLM failed after {max_retries} attempts: {str(e)}"
|
| 688 |
-
)
|
| 689 |
-
return {"messages": [error_msg], "turn": current_turn}
|
| 690 |
-
time.sleep(2 ** attempt) # Exponential backoff
|
| 691 |
-
|
| 692 |
-
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
| 693 |
-
# --- (FIX #1) RULE ENFORCEMENT BLOCK ---
|
| 694 |
-
#
|
| 695 |
-
# If it's Turn 1 AND the agent tried to call tools, we reject it
|
| 696 |
-
# and force it to re-do Turn 1.
|
| 697 |
-
if current_turn == 1 and ai_message.tool_calls:
|
| 698 |
-
print("⚠️ AGENT VIOLATION: Tried to call tools on Turn 1. Forcing replan.")
|
| 699 |
|
| 700 |
-
#
|
| 701 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 702 |
|
| 703 |
-
#
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 709 |
|
| 710 |
-
#
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 715 |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
| 716 |
# --- FIX #2: REPLACE THE FALLBACK PARSING BLOCK ---
|
| 717 |
#
|
|
|
|
# Deterministic Groq-hosted chat model used as the agent's LLM backend.
# temperature=0 pins sampling for maximum reproducibility; the generous
# max_tokens/timeout budget accommodates long tool-use reasoning turns.
chat_llm = ChatGroq(
    groq_api_key=GROQ_API_KEY,
    model_name="openai/gpt-oss-120b",  # Best reasoning model
    temperature=0,                     # Maximum determinism
    max_tokens=4096,
    timeout=60,
)
|
|
|
|
| 641 |
print("✅ Tools bound to LLM")
|
| 642 |
|
| 643 |
# --- Agent Node ---
|
| 644 |
+
def agent_node(state: AgentState):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 645 |
|
| 646 |
+
# --- Turn Counter Logic ---
|
| 647 |
+
# We need to check if this is a retry of a failed turn (e.g., Turn 1 violation)
|
| 648 |
+
# We identify a retry if the *last* message was our "Protocol Violation" message
|
| 649 |
+
last_msg = state['messages'][-1]
|
| 650 |
+
is_a_retry = False
|
| 651 |
+
if isinstance(last_msg, SystemMessage) and "Protocol Violation" in last_msg.content:
|
| 652 |
+
is_a_retry = True
|
| 653 |
|
| 654 |
+
# Get the state's current turn number
|
| 655 |
+
current_turn = state.get('turn', 0)
|
| 656 |
+
|
| 657 |
+
# If this is NOT a retry, increment the turn.
|
| 658 |
+
# If it IS a retry, we *stay on the same turn number*
|
| 659 |
+
if not is_a_retry:
|
| 660 |
+
current_turn += 1
|
| 661 |
+
|
| 662 |
+
# Handle the very first run (where state['turn'] is 0)
|
| 663 |
+
if current_turn == 0:
|
| 664 |
+
current_turn = 1
|
| 665 |
+
# --- End Turn Counter Logic ---
|
| 666 |
+
|
| 667 |
+
print(f"\n{'='*60}")
|
| 668 |
+
print(f"AGENT TURN {current_turn}/{MAX_TURNS}")
|
| 669 |
+
if is_a_retry:
|
| 670 |
+
print("--- (Re-trying after protocol violation) ---")
|
| 671 |
+
print('='*60)
|
| 672 |
+
|
| 673 |
+
messages_to_send = state["messages"]
|
| 674 |
|
| 675 |
+
# Retry logic with exponential backoff
|
| 676 |
+
max_retries = 3
|
| 677 |
+
ai_message = None
|
| 678 |
+
|
| 679 |
+
for attempt in range(max_retries):
|
| 680 |
+
try:
|
| 681 |
+
ai_message = self.llm_with_tools.invoke(messages_to_send)
|
| 682 |
+
break
|
| 683 |
+
except Exception as e:
|
| 684 |
+
print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {e}")
|
| 685 |
+
if attempt == max_retries - 1:
|
| 686 |
+
error_msg = AIMessage(
|
| 687 |
+
content=f"Error: LLM failed after {max_retries} attempts: {str(e)}"
|
| 688 |
+
)
|
| 689 |
+
return {"messages": [error_msg], "turn": current_turn}
|
| 690 |
+
time.sleep(2 ** attempt) # Exponential backoff
|
| 691 |
+
|
| 692 |
+
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
| 693 |
+
# --- (FIX #1) RULE ENFORCEMENT BLOCK ---
|
| 694 |
+
#
|
| 695 |
+
# If it's Turn 1 AND the agent tried to call tools, we reject it
|
| 696 |
+
# and force it to re-do Turn 1.
|
| 697 |
+
if current_turn == 1 and ai_message.tool_calls:
|
| 698 |
+
print("⚠️ AGENT VIOLATION: Tried to call tools on Turn 1. Forcing replan.")
|
| 699 |
+
|
| 700 |
+
# Strip the illegal tool call
|
| 701 |
+
ai_message.tool_calls = []
|
| 702 |
+
|
| 703 |
+
# Create the correction message that forces the plan
|
| 704 |
+
correction_message = SystemMessage(
|
| 705 |
+
content="SYSTEM: Protocol Violation. Your FIRST turn MUST be a plan with NO tool calls. "
|
| 706 |
+
"You are not allowed to call any tools on your first turn. "
|
| 707 |
+
"Re-read the protocol and provide your 2-3 sentence plan now."
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# Return the messages.
|
| 711 |
+
# Critically, we set the state's turn counter back to 1.
|
| 712 |
+
# This ensures the *next* run of this node is *still* Turn 1.
|
| 713 |
+
return {"messages": [ai_message, correction_message], "turn": 1}
|
| 714 |
+
# --- END OF RULE ENFORCEMENT BLOCK ---
|
| 715 |
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
|
| 716 |
# --- FIX #2: REPLACE THE FALLBACK PARSING BLOCK ---
|
| 717 |
#
|