ShoaibSSM commited on
Commit
85f0456
·
verified ·
1 Parent(s): 6962466

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +38 -49
agent.py CHANGED
@@ -11,7 +11,6 @@ from typing import TypedDict, Annotated, List
11
  from langchain_core.messages import trim_messages
12
  from langchain.chat_models import init_chat_model
13
  from langgraph.graph.message import add_messages
14
- from langgraph.pregel.retry import RetryPolicy
15
  import os
16
  from dotenv import load_dotenv
17
  load_dotenv()
@@ -37,7 +36,7 @@ TOOLS = [
37
 
38
 
39
  # -------------------------------------------------
40
- # LLM INIT
41
  # -------------------------------------------------
42
  rate_limiter = InMemoryRateLimiter(
43
  requests_per_second=7 / 60,
@@ -47,17 +46,18 @@ rate_limiter = InMemoryRateLimiter(
47
 
48
  llm = init_chat_model(
49
  model_provider="google_genai",
50
- model="gemini-2.5-flash-lite",
51
  rate_limiter=rate_limiter
52
  ).bind_tools(TOOLS)
53
 
54
 
55
 
56
  # -------------------------------------------------
57
- # SYSTEM PROMPT
58
  # -------------------------------------------------
59
  SYSTEM_PROMPT = f"""
60
  You are an autonomous quiz-solving agent.
 
61
  Your job is to:
62
  1. Load each quiz page from the given URL.
63
  2. Extract instructions, parameters, and submit endpoint.
@@ -66,46 +66,40 @@ Your job is to:
66
  5. Follow new URLs until none remain, then output END.
67
 
68
  Rules:
69
- - For base64 generation NEVER use your own code use "encode_image_to_base64"
70
  - Never hallucinate URLs or fields.
71
  - Never shorten endpoints.
72
  - Always inspect server response.
73
  - Never stop early.
74
- - Use tools for HTML, downloading, rendering, OCR, running code.
75
  - Include:
76
  email = {EMAIL}
77
  secret = {SECRET}
78
  """
79
 
 
80
  # -------------------------------------------------
81
  # AGENT NODE
82
  # -------------------------------------------------
83
  def agent_node(state: AgentState):
84
- """Fixes: KeyError, offset None, float(None)"""
85
-
86
  cur_time = time.time()
87
- cur_url = os.getenv("url") or "" # FIX 1: safe load
88
- prev_time = url_time.get(cur_url) # FIX 2: no KeyError
89
- offset = os.getenv("offset") or "0" # FIX 3: no None issues
90
-
91
  if prev_time is not None:
92
  prev_time = float(prev_time)
93
  diff = cur_time - prev_time
94
 
95
- # timeout logic unchanged, only made safe
96
- try:
97
- offset_f = float(offset)
98
- except:
99
- offset_f = 0.0
100
-
101
- if diff >= 180 or (offset_f != 0 and (cur_time - offset_f) > 90):
102
  print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", diff, "Offset=", offset)
103
 
104
  fail_instruction = """
105
- You exceeded the time limit (130s).
106
- Immediately call `post_request` and submit a WRONG answer for the CURRENT quiz.
107
  """
108
 
 
109
  result = llm.invoke([
110
  {"role": "user", "content": fail_instruction}
111
  ])
@@ -117,18 +111,20 @@ def agent_node(state: AgentState):
117
  strategy="last",
118
  include_system=True,
119
  start_on="human",
120
- token_counter=llm,
121
  )
122
-
123
  result = llm.invoke(trimmed_messages)
124
- return {"messages": [result]}
125
 
 
126
 
127
  # -------------------------------------------------
128
- # ROUTE
129
  # -------------------------------------------------
130
  def route(state):
131
  last = state["messages"][-1]
 
 
132
  tool_calls = getattr(last, "tool_calls", None)
133
 
134
  if tool_calls:
@@ -137,17 +133,6 @@ def route(state):
137
 
138
  content = getattr(last, "content", None)
139
 
140
- # allow message dicts (post_request returns dicts)
141
- if isinstance(content, dict):
142
- if content.get("url") == "" or content.get("correct") is False:
143
- # not final, continue agent
144
- print("Route → agent (dict content)")
145
- return "agent"
146
-
147
- if content is None:
148
- print("Content is None → END")
149
- return END
150
-
151
  if isinstance(content, str) and content.strip() == "END":
152
  return END
153
 
@@ -159,6 +144,7 @@ def route(state):
159
  return "agent"
160
 
161
 
 
162
  # -------------------------------------------------
163
  # GRAPH
164
  # -------------------------------------------------
@@ -166,20 +152,17 @@ graph = StateGraph(AgentState)
166
 
167
  graph.add_node("tools", ToolNode(TOOLS))
168
 
169
- # FIX 4 — LangGraph retry policy MUST be RetryPolicy(...)
170
- retry_policy = RetryPolicy(
171
- max_attempts=10,
172
- initial_interval=1,
173
- backoff_factor=2,
174
- max_interval=60
175
- )
176
-
177
- graph.add_node("agent", agent_node, retry=retry_policy)
178
-
179
  graph.add_edge(START, "agent")
180
  graph.add_edge("tools", "agent")
181
  graph.add_conditional_edges("agent", route)
182
-
 
 
 
 
 
 
 
183
  app = graph.compile()
184
 
185
 
@@ -188,31 +171,37 @@ app = graph.compile()
188
  # RUNNER
189
  # -------------------------------------------------
190
  def run_agent(url: str):
 
191
  initial_messages = [
192
  {"role": "system", "content": SYSTEM_PROMPT},
193
  {"role": "user", "content": url}
194
  ]
195
 
 
196
  result = app.invoke(
197
  {"messages": initial_messages},
198
  config={"recursion_limit": RECURSION_LIMIT}
199
  )
200
 
 
201
  try:
202
  last = result["messages"][-1]
203
  content = getattr(last, "content", "")
204
 
 
205
  if isinstance(content, str) and content.strip() == "END":
206
  print("Tasks completed successfully!")
207
  return
208
 
 
209
  import json
210
- parsed = json.loads(content) if isinstance(content, str) else content
211
  if parsed.get("url") is None:
212
  print("Tasks completed successfully!")
213
  return
214
 
215
  except Exception:
216
- pass
217
 
 
218
  print("Tasks completed successfully!")
 
11
  from langchain_core.messages import trim_messages
12
  from langchain.chat_models import init_chat_model
13
  from langgraph.graph.message import add_messages
 
14
  import os
15
  from dotenv import load_dotenv
16
  load_dotenv()
 
36
 
37
 
38
  # -------------------------------------------------
39
+ # LLM INIT (NO SYSTEM PROMPT HERE)
40
  # -------------------------------------------------
41
  rate_limiter = InMemoryRateLimiter(
42
  requests_per_second=7 / 60,
 
46
 
47
  llm = init_chat_model(
48
  model_provider="google_genai",
49
+ model="gemini-2.5-flash",
50
  rate_limiter=rate_limiter
51
  ).bind_tools(TOOLS)
52
 
53
 
54
 
55
  # -------------------------------------------------
56
+ # SYSTEM PROMPT (WILL BE INSERTED ONLY ONCE)
57
  # -------------------------------------------------
58
  SYSTEM_PROMPT = f"""
59
  You are an autonomous quiz-solving agent.
60
+
61
  Your job is to:
62
  1. Load each quiz page from the given URL.
63
  2. Extract instructions, parameters, and submit endpoint.
 
66
  5. Follow new URLs until none remain, then output END.
67
 
68
  Rules:
69
+ - For base64 generation of an image NEVER use your own code, always use the "encode_image_to_base64" tool that's provided
70
  - Never hallucinate URLs or fields.
71
  - Never shorten endpoints.
72
  - Always inspect server response.
73
  - Never stop early.
74
+ - Use tools for HTML, downloading, rendering, OCR, or running code.
75
  - Include:
76
  email = {EMAIL}
77
  secret = {SECRET}
78
  """
79
 
80
+
81
  # -------------------------------------------------
82
  # AGENT NODE
83
  # -------------------------------------------------
84
  def agent_node(state: AgentState):
85
+ # time-handling
 
86
  cur_time = time.time()
87
+ cur_url = os.getenv("url")
88
+ prev_time = url_time[cur_url]
89
+ offset = os.getenv("offset")
 
90
  if prev_time is not None:
91
  prev_time = float(prev_time)
92
  diff = cur_time - prev_time
93
 
94
+ if diff >= 180 or (offset != "0" and (cur_time - float(offset)) > 90):
 
 
 
 
 
 
95
  print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", diff, "Offset=", offset)
96
 
97
  fail_instruction = """
98
+ You have exceeded the time limit for this task (over 130 seconds).
99
+ Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
100
  """
101
 
102
+ # LLM will figure out the right endpoint + JSON structure itself
103
  result = llm.invoke([
104
  {"role": "user", "content": fail_instruction}
105
  ])
 
111
  strategy="last",
112
  include_system=True,
113
  start_on="human",
114
+ token_counter=llm, # Use the LLM to count actual tokens, not just list length
115
  )
116
+
117
  result = llm.invoke(trimmed_messages)
 
118
 
119
+ return {"messages": [result]}
120
 
121
  # -------------------------------------------------
122
+ # ROUTE LOGIC (YOURS WITH MINOR SAFETY IMPROVES)
123
  # -------------------------------------------------
124
  def route(state):
125
  last = state["messages"][-1]
126
+ # print("=== ROUTE DEBUG: last message type ===")
127
+
128
  tool_calls = getattr(last, "tool_calls", None)
129
 
130
  if tool_calls:
 
133
 
134
  content = getattr(last, "content", None)
135
 
 
 
 
 
 
 
 
 
 
 
 
136
  if isinstance(content, str) and content.strip() == "END":
137
  return END
138
 
 
144
  return "agent"
145
 
146
 
147
+
148
  # -------------------------------------------------
149
  # GRAPH
150
  # -------------------------------------------------
 
152
 
153
  graph.add_node("tools", ToolNode(TOOLS))
154
 
 
 
 
 
 
 
 
 
 
 
155
  graph.add_edge(START, "agent")
156
  graph.add_edge("tools", "agent")
157
  graph.add_conditional_edges("agent", route)
158
+ robust_retry = {
159
+ "initial_interval": 1,
160
+ "backoff_factor": 2,
161
+ "max_interval": 60,
162
+ "max_attempts": 10
163
+ }
164
+
165
+ graph.add_node("agent", agent_node, retry=robust_retry)
166
  app = graph.compile()
167
 
168
 
 
171
  # RUNNER
172
  # -------------------------------------------------
173
  def run_agent(url: str):
174
+ # system message is seeded ONCE here
175
  initial_messages = [
176
  {"role": "system", "content": SYSTEM_PROMPT},
177
  {"role": "user", "content": url}
178
  ]
179
 
180
+ # run agent and CAPTURE the output
181
  result = app.invoke(
182
  {"messages": initial_messages},
183
  config={"recursion_limit": RECURSION_LIMIT}
184
  )
185
 
186
+ # Try to detect final server response if present
187
  try:
188
  last = result["messages"][-1]
189
  content = getattr(last, "content", "")
190
 
191
+ # If LLM already output END – good
192
  if isinstance(content, str) and content.strip() == "END":
193
  print("Tasks completed successfully!")
194
  return
195
 
196
+ # If the last content is JSON from server submission
197
  import json
198
+ parsed = json.loads(content) if isinstance(content, str) else {}
199
  if parsed.get("url") is None:
200
  print("Tasks completed successfully!")
201
  return
202
 
203
  except Exception:
204
+ pass # fallback below
205
 
206
+ # Default fallback
207
  print("Tasks completed successfully!")