ShoaibSSM committed on
Commit
44b08c9
·
verified ·
1 Parent(s): 85f0456

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +84 -56
agent.py CHANGED
@@ -8,7 +8,7 @@ from tools import (
8
  run_code, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
9
  )
10
  from typing import TypedDict, Annotated, List
11
- from langchain_core.messages import trim_messages
12
  from langchain.chat_models import init_chat_model
13
  from langgraph.graph.message import add_messages
14
  import os
@@ -19,7 +19,7 @@ EMAIL = os.getenv("EMAIL")
19
  SECRET = os.getenv("SECRET")
20
 
21
  RECURSION_LIMIT = 5000
22
- MAX_TOKENS = 180000
23
 
24
 
25
  # -------------------------------------------------
@@ -36,12 +36,12 @@ TOOLS = [
36
 
37
 
38
  # -------------------------------------------------
39
- # LLM INIT (NO SYSTEM PROMPT HERE)
40
  # -------------------------------------------------
41
  rate_limiter = InMemoryRateLimiter(
42
- requests_per_second=7 / 60,
43
  check_every_n_seconds=1,
44
- max_bucket_size=7
45
  )
46
 
47
  llm = init_chat_model(
@@ -51,9 +51,8 @@ llm = init_chat_model(
51
  ).bind_tools(TOOLS)
52
 
53
 
54
-
55
  # -------------------------------------------------
56
- # SYSTEM PROMPT (WILL BE INSERTED ONLY ONCE)
57
  # -------------------------------------------------
58
  SYSTEM_PROMPT = f"""
59
  You are an autonomous quiz-solving agent.
@@ -78,32 +77,56 @@ Rules:
78
  """
79
 
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # -------------------------------------------------
82
  # AGENT NODE
83
  # -------------------------------------------------
84
  def agent_node(state: AgentState):
85
- # time-handling
86
  cur_time = time.time()
87
  cur_url = os.getenv("url")
88
- prev_time = url_time[cur_url]
89
- offset = os.getenv("offset")
 
 
 
90
  if prev_time is not None:
91
  prev_time = float(prev_time)
92
  diff = cur_time - prev_time
93
 
94
  if diff >= 180 or (offset != "0" and (cur_time - float(offset)) > 90):
95
- print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", diff, "Offset=", offset)
96
 
97
  fail_instruction = """
98
- You have exceeded the time limit for this task (over 130 seconds).
99
  Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
100
  """
101
 
102
- # LLM will figure out the right endpoint + JSON structure itself
103
- result = llm.invoke([
104
- {"role": "user", "content": fail_instruction}
105
- ])
 
106
  return {"messages": [result]}
 
107
 
108
  trimmed_messages = trim_messages(
109
  messages=state["messages"],
@@ -111,28 +134,48 @@ def agent_node(state: AgentState):
111
  strategy="last",
112
  include_system=True,
113
  start_on="human",
114
- token_counter=llm, # Use the LLM to count actual tokens, not just list length
115
  )
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  result = llm.invoke(trimmed_messages)
118
 
119
  return {"messages": [result]}
120
 
 
121
  # -------------------------------------------------
122
- # ROUTE LOGIC (YOURS WITH MINOR SAFETY IMPROVES)
123
  # -------------------------------------------------
124
  def route(state):
125
  last = state["messages"][-1]
126
- # print("=== ROUTE DEBUG: last message type ===")
 
 
 
 
127
 
 
128
  tool_calls = getattr(last, "tool_calls", None)
129
-
130
  if tool_calls:
131
  print("Route → tools")
132
  return "tools"
133
 
 
134
  content = getattr(last, "content", None)
135
-
136
  if isinstance(content, str) and content.strip() == "END":
137
  return END
138
 
@@ -144,27 +187,34 @@ def route(state):
144
  return "agent"
145
 
146
 
147
-
148
  # -------------------------------------------------
149
  # GRAPH
150
  # -------------------------------------------------
151
  graph = StateGraph(AgentState)
152
 
 
 
153
  graph.add_node("tools", ToolNode(TOOLS))
 
154
 
 
155
  graph.add_edge(START, "agent")
156
  graph.add_edge("tools", "agent")
157
- graph.add_conditional_edges("agent", route)
158
- robust_retry = {
159
- "initial_interval": 1,
160
- "backoff_factor": 2,
161
- "max_interval": 60,
162
- "max_attempts": 10
163
- }
164
-
165
- graph.add_node("agent", agent_node, retry=robust_retry)
166
- app = graph.compile()
 
 
 
167
 
 
168
 
169
 
170
  # -------------------------------------------------
@@ -177,31 +227,9 @@ def run_agent(url: str):
177
  {"role": "user", "content": url}
178
  ]
179
 
180
- # run agent and CAPTURE the output
181
- result = app.invoke(
182
  {"messages": initial_messages},
183
  config={"recursion_limit": RECURSION_LIMIT}
184
  )
185
 
186
- # Try to detect final server response if present
187
- try:
188
- last = result["messages"][-1]
189
- content = getattr(last, "content", "")
190
-
191
- # If LLM already output END – good
192
- if isinstance(content, str) and content.strip() == "END":
193
- print("Tasks completed successfully!")
194
- return
195
-
196
- # If the last content is JSON from server submission
197
- import json
198
- parsed = json.loads(content) if isinstance(content, str) else {}
199
- if parsed.get("url") is None:
200
- print("Tasks completed successfully!")
201
- return
202
-
203
- except Exception:
204
- pass # fallback below
205
-
206
- # Default fallback
207
- print("Tasks completed successfully!")
 
8
  run_code, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
9
  )
10
  from typing import TypedDict, Annotated, List
11
+ from langchain_core.messages import trim_messages, HumanMessage
12
  from langchain.chat_models import init_chat_model
13
  from langgraph.graph.message import add_messages
14
  import os
 
19
  SECRET = os.getenv("SECRET")
20
 
21
  RECURSION_LIMIT = 5000
22
+ MAX_TOKENS = 60000
23
 
24
 
25
  # -------------------------------------------------
 
36
 
37
 
38
  # -------------------------------------------------
39
+ # LLM INIT
40
  # -------------------------------------------------
41
  rate_limiter = InMemoryRateLimiter(
42
+ requests_per_second=4 / 60,
43
  check_every_n_seconds=1,
44
+ max_bucket_size=4
45
  )
46
 
47
  llm = init_chat_model(
 
51
  ).bind_tools(TOOLS)
52
 
53
 
 
54
  # -------------------------------------------------
55
+ # SYSTEM PROMPT
56
  # -------------------------------------------------
57
  SYSTEM_PROMPT = f"""
58
  You are an autonomous quiz-solving agent.
 
77
  """
78
 
79
 
80
+ # -------------------------------------------------
81
+ # NEW NODE: HANDLE MALFORMED JSON
82
+ # -------------------------------------------------
83
+ def handle_malformed_node(state: AgentState):
84
+ """
85
+ If the LLM generates invalid JSON, this node sends a correction message
86
+ so the LLM can try again.
87
+ """
88
+ print("--- DETECTED MALFORMED JSON. ASKING AGENT TO RETRY ---")
89
+ return {
90
+ "messages": [
91
+ {
92
+ "role": "user",
93
+ "content": "SYSTEM ERROR: Your last tool call was Malformed (Invalid JSON). Please rewrite the code and try again. Ensure you escape newlines and quotes correctly inside the JSON."
94
+ }
95
+ ]
96
+ }
97
+
98
+
99
  # -------------------------------------------------
100
  # AGENT NODE
101
  # -------------------------------------------------
102
  def agent_node(state: AgentState):
103
+ # --- TIME HANDLING START ---
104
  cur_time = time.time()
105
  cur_url = os.getenv("url")
106
+
107
+ # SAFE GET: Prevents crash if url is None or not in dict
108
+ prev_time = url_time.get(cur_url)
109
+ offset = os.getenv("offset", "0")
110
+
111
  if prev_time is not None:
112
  prev_time = float(prev_time)
113
  diff = cur_time - prev_time
114
 
115
  if diff >= 180 or (offset != "0" and (cur_time - float(offset)) > 90):
116
+ print(f"Timeout exceeded ({diff}s) — instructing LLM to purposely submit wrong answer.")
117
 
118
  fail_instruction = """
119
+ You have exceeded the time limit for this task (over 180 seconds).
120
  Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
121
  """
122
 
123
+ # Using HumanMessage (as you correctly implemented)
124
+ fail_msg = HumanMessage(content=fail_instruction)
125
+
126
+ # We invoke the LLM immediately with this new instruction
127
+ result = llm.invoke(state["messages"] + [fail_msg])
128
  return {"messages": [result]}
129
+ # --- TIME HANDLING END ---
130
 
131
  trimmed_messages = trim_messages(
132
  messages=state["messages"],
 
134
  strategy="last",
135
  include_system=True,
136
  start_on="human",
137
+ token_counter=llm,
138
  )
139
 
140
+ # Better check: Does it have a HumanMessage?
141
+ has_human = any(msg.type == "human" for msg in trimmed_messages)
142
+
143
+ if not has_human:
144
+ print("WARNING: Context was trimmed too far. Injecting state reminder.")
145
+ # We remind the agent of the current URL from the environment
146
+ current_url = os.getenv("url", "Unknown URL")
147
+ reminder = HumanMessage(content=f"Context cleared due to length. Continue processing URL: {current_url}")
148
+
149
+ # We append this to the trimmed list (temporarily for this invoke)
150
+ trimmed_messages.append(reminder)
151
+ # ----------------------------------------
152
+
153
+ print(f"--- INVOKING AGENT (Context: {len(trimmed_messages)} items) ---")
154
+
155
  result = llm.invoke(trimmed_messages)
156
 
157
  return {"messages": [result]}
158
 
159
+
160
  # -------------------------------------------------
161
+ # ROUTE LOGIC (UPDATED FOR MALFORMED CALLS)
162
  # -------------------------------------------------
163
  def route(state):
164
  last = state["messages"][-1]
165
+
166
+ # 1. CHECK FOR MALFORMED FUNCTION CALLS
167
+ if "finish_reason" in last.response_metadata:
168
+ if last.response_metadata["finish_reason"] == "MALFORMED_FUNCTION_CALL":
169
+ return "handle_malformed"
170
 
171
+ # 2. CHECK FOR VALID TOOLS
172
  tool_calls = getattr(last, "tool_calls", None)
 
173
  if tool_calls:
174
  print("Route → tools")
175
  return "tools"
176
 
177
+ # 3. CHECK FOR END
178
  content = getattr(last, "content", None)
 
179
  if isinstance(content, str) and content.strip() == "END":
180
  return END
181
 
 
187
  return "agent"
188
 
189
 
 
# -------------------------------------------------
# GRAPH
# -------------------------------------------------
graph = StateGraph(AgentState)

# Nodes: the LLM agent, the tool executor, and the malformed-JSON repair node.
graph.add_node("agent", agent_node)
graph.add_node("tools", ToolNode(TOOLS))
graph.add_node("handle_malformed", handle_malformed_node)

# Unconditional edges: execution starts at the agent; both the tool
# executor and the repair node hand control straight back to the agent.
graph.add_edge(START, "agent")
graph.add_edge("tools", "agent")
graph.add_edge("handle_malformed", "agent")

# After every agent turn, `route` picks the next node via this mapping.
route_targets = {
    "tools": "tools",
    "agent": "agent",
    "handle_malformed": "handle_malformed",  # malformed tool call → repair loop
    END: END,
}
graph.add_conditional_edges("agent", route, route_targets)

app = graph.compile()
218
 
219
 
220
  # -------------------------------------------------
 
227
  {"role": "user", "content": url}
228
  ]
229
 
230
+ app.invoke(
 
231
  {"messages": initial_messages},
232
  config={"recursion_limit": RECURSION_LIMIT}
233
  )
234
 
235
+ print("Tasks completed successfully!")