ShoaibSSM commited on
Commit
a154e2b
·
verified ·
1 Parent(s): ee9487b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +96 -141
agent.py CHANGED
@@ -1,22 +1,18 @@
1
- import os
 
2
  import time
3
- import json
4
- from dotenv import load_dotenv
5
- from typing import TypedDict, Annotated, List
6
-
7
- from langgraph.graph import StateGraph, START, END
8
- from langgraph.prebuilt import ToolNode
9
- from langgraph.graph.message import add_messages
10
- from langchain.chat_models import init_chat_model
11
  from langchain_core.rate_limiters import InMemoryRateLimiter
12
- from langchain_core.messages import trim_messages
13
-
14
  from tools import (
15
- get_rendered_html, download_file, post_request, run_code,
16
- add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
17
  )
18
- from shared_store import url_time
19
-
 
 
 
 
20
  load_dotenv()
21
 
22
  EMAIL = os.getenv("EMAIL")
@@ -25,161 +21,110 @@ SECRET = os.getenv("SECRET")
25
  RECURSION_LIMIT = 5000
26
  MAX_TOKENS = 180000
27
 
28
- # ==============================================================
29
- # STATE
30
- # ==============================================================
31
 
 
 
 
32
  class AgentState(TypedDict):
33
  messages: Annotated[List, add_messages]
34
 
35
- TOOLS = [
36
- run_code, get_rendered_html, download_file, post_request,
37
- add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
38
- ]
39
-
40
- # ==============================================================
41
- # FALLBACK LLM
42
- # ==============================================================
43
 
44
- FALLBACK_MODELS = [
45
- "gemini-2.5-flash",
46
- "gemini-2.5-flash-lite",
47
- "gemini-2.0-flash-lite",
48
- "gemini-2.0-flash",
49
-
50
  ]
51
 
52
- def init_llm_with_fallback(tools):
53
- """Initialize an LLM with automatic fallback selection."""
54
 
55
- rate_limiter = InMemoryRateLimiter(
56
- requests_per_second=7/60,
57
- check_every_n_seconds=1,
58
- max_bucket_size=7
59
- )
60
-
61
- for model_name in FALLBACK_MODELS:
62
- try:
63
- print(f"[LLM] Trying model: {model_name}")
64
- llm = init_chat_model(
65
- model_provider="google_genai",
66
- model=model_name,
67
- rate_limiter=rate_limiter
68
- ).bind_tools(tools)
69
-
70
- llm.invoke("ping") # probe
71
- print(f"[LLM] Model ready: {model_name}")
72
- return llm
73
-
74
- except Exception as e:
75
- print(f"[LLM] Model failed ({model_name}): {e}")
76
-
77
- raise RuntimeError("❌ No Gemini model available!")
78
-
79
- # global LLM
80
- llm = init_llm_with_fallback(TOOLS)
81
-
82
- # ==============================================================
83
- # SAFE INVOKE (fallback switcher)
84
- # ==============================================================
85
-
86
- def safe_llm_invoke(input_message):
87
- global llm
88
- try:
89
- return llm.invoke(input_message)
90
-
91
- except Exception as e:
92
- err = str(e).lower()
93
-
94
- trigger_fallback = any([
95
- "429" in err,
96
- "quota" in err,
97
- "exceeded" in err,
98
- "rate" in err,
99
- "unavailable" in err,
100
- "deadline" in err,
101
- "resourceexhausted" in err
102
- ])
103
 
104
- if trigger_fallback:
105
- print("\n⚠️ Quota/Rate error → switching LLM model...\n")
106
- llm = init_llm_with_fallback(TOOLS)
107
- return llm.invoke(input_message)
 
108
 
109
- raise e
110
 
111
- # ==============================================================
112
- # SYSTEM PROMPT
113
- # ==============================================================
114
 
 
 
 
115
  SYSTEM_PROMPT = f"""
116
  You are an autonomous quiz-solving agent.
117
 
118
- Your job:
119
- 1. Load each quiz page.
120
- 2. Extract instructions, parameters & submit endpoint.
121
- 3. Solve tasks EXACTLY.
122
  4. Submit answers ONLY to the correct endpoint.
123
- 5. Follow returned URLs until none remain, then output END.
124
 
125
  Rules:
126
- - NEVER generate your own base64—use encode_image_to_base64 tool.
127
- - NEVER hallucinate endpoints.
128
- - NEVER shorten URLs.
129
  - Always inspect server response.
130
- - Use tools for HTML, code execution, OCR, downloading, etc.
131
- Include in every submission:
132
- email = {EMAIL}
133
- secret = {SECRET}
 
134
  """
135
 
136
- # ==============================================================
137
- # AGENT NODE
138
- # ==============================================================
139
 
 
 
 
140
  def agent_node(state: AgentState):
141
-
142
- # ---- TIMEOUT ----
143
  cur_time = time.time()
144
  cur_url = os.getenv("url")
145
- prev_time = url_time.get(cur_url)
146
- offset = float(os.getenv("offset", "0"))
147
-
148
  if prev_time is not None:
149
  prev_time = float(prev_time)
150
  diff = cur_time - prev_time
151
 
152
- if diff >= 180 or (offset != 0 and (cur_time - offset) > 90):
153
- print("Timeout exceeded — forcing WRONG submission.", diff)
154
 
155
- instruction = """
156
- You exceeded allowed time.
157
- Immediately call post_request with a WRONG answer
158
- for the CURRENT quiz.
159
  """
160
 
161
- return {"messages": [safe_llm_invoke(instruction)]}
162
-
163
- # ---- NORMAL FLOW ----
 
 
164
 
165
- trimmed = trim_messages(
166
  messages=state["messages"],
167
  max_tokens=MAX_TOKENS,
168
  strategy="last",
169
  include_system=True,
170
  start_on="human",
171
- token_counter=llm
172
  )
 
 
173
 
174
- result = safe_llm_invoke(trimmed)
175
  return {"messages": [result]}
176
 
177
- # ==============================================================
178
- # ROUTING
179
- # ==============================================================
180
-
181
  def route(state):
182
  last = state["messages"][-1]
 
 
183
  tool_calls = getattr(last, "tool_calls", None)
184
 
185
  if tool_calls:
@@ -191,62 +136,72 @@ def route(state):
191
  if isinstance(content, str) and content.strip() == "END":
192
  return END
193
 
194
- if isinstance(content, list) and len(content):
195
  if content[0].get("text", "").strip() == "END":
196
  return END
197
 
198
  print("Route → agent")
199
  return "agent"
200
 
201
- # ==============================================================
202
- # GRAPH BUILD
203
- # ==============================================================
204
 
 
 
 
 
205
  graph = StateGraph(AgentState)
 
206
  graph.add_node("tools", ToolNode(TOOLS))
 
207
  graph.add_edge(START, "agent")
208
  graph.add_edge("tools", "agent")
209
  graph.add_conditional_edges("agent", route)
210
-
211
- graph.add_node("agent", agent_node, retry={
212
  "initial_interval": 1,
213
  "backoff_factor": 2,
214
  "max_interval": 60,
215
  "max_attempts": 10
216
- })
217
 
 
218
  app = graph.compile()
219
 
220
- # ==============================================================
221
- # RUN AGENT
222
- # ==============================================================
223
 
224
- def run_agent(url: str):
225
 
226
- initial = [
 
 
 
 
 
227
  {"role": "system", "content": SYSTEM_PROMPT},
228
  {"role": "user", "content": url}
229
  ]
230
 
 
231
  result = app.invoke(
232
- {"messages": initial},
233
  config={"recursion_limit": RECURSION_LIMIT}
234
  )
235
 
 
236
  try:
237
  last = result["messages"][-1]
238
  content = getattr(last, "content", "")
239
 
 
240
  if isinstance(content, str) and content.strip() == "END":
241
  print("Tasks completed successfully!")
242
  return
243
 
 
 
244
  parsed = json.loads(content) if isinstance(content, str) else {}
245
  if parsed.get("url") is None:
246
  print("Tasks completed successfully!")
247
  return
248
 
249
  except Exception:
250
- pass
251
 
252
- print("Tasks completed successfully!")
 
 
1
+ from langgraph.graph import StateGraph, END, START
2
+ from shared_store import url_time
3
  import time
 
 
 
 
 
 
 
 
4
  from langchain_core.rate_limiters import InMemoryRateLimiter
5
+ from langgraph.prebuilt import ToolNode
 
6
  from tools import (
7
+ get_rendered_html, download_file, post_request,
8
+ run_code, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
9
  )
10
+ from typing import TypedDict, Annotated, List
11
+ from langchain_core.messages import trim_messages
12
+ from langchain.chat_models import init_chat_model
13
+ from langgraph.graph.message import add_messages
14
+ import os
15
+ from dotenv import load_dotenv
16
  load_dotenv()
17
 
18
  EMAIL = os.getenv("EMAIL")
 
21
  RECURSION_LIMIT = 5000
22
  MAX_TOKENS = 180000
23
 
 
 
 
24
 
25
+ # -------------------------------------------------
26
+ # STATE
27
+ # -------------------------------------------------
28
  class AgentState(TypedDict):
29
  messages: Annotated[List, add_messages]
30
 
 
 
 
 
 
 
 
 
31
 
32
+ TOOLS = [
33
+ run_code, get_rendered_html, download_file,
34
+ post_request, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
 
 
 
35
  ]
36
 
 
 
37
 
38
+ # -------------------------------------------------
39
+ # LLM INIT (NO SYSTEM PROMPT HERE)
40
+ # -------------------------------------------------
41
+ rate_limiter = InMemoryRateLimiter(
42
+ requests_per_second=7 / 60,
43
+ check_every_n_seconds=1,
44
+ max_bucket_size=7
45
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
+ llm = init_chat_model(
48
+ model_provider="google_genai",
49
+ model="gemini-2.5-flash",
50
+ rate_limiter=rate_limiter
51
+ ).bind_tools(TOOLS)
52
 
 
53
 
 
 
 
54
 
55
+ # -------------------------------------------------
56
+ # SYSTEM PROMPT (WILL BE INSERTED ONLY ONCE)
57
+ # -------------------------------------------------
58
  SYSTEM_PROMPT = f"""
59
  You are an autonomous quiz-solving agent.
60
 
61
+ Your job is to:
62
+ 1. Load each quiz page from the given URL.
63
+ 2. Extract instructions, parameters, and submit endpoint.
64
+ 3. Solve tasks exactly.
65
  4. Submit answers ONLY to the correct endpoint.
66
+ 5. Follow new URLs until none remain, then output END.
67
 
68
  Rules:
69
+ - For base64 generation of an image NEVER use your own code, always use the "encode_image_to_base64" tool that's provided
70
+ - Never hallucinate URLs or fields.
71
+ - Never shorten endpoints.
72
  - Always inspect server response.
73
+ - Never stop early.
74
+ - Use tools for HTML, downloading, rendering, OCR, or running code.
75
+ - Include:
76
+ email = {EMAIL}
77
+ secret = {SECRET}
78
  """
79
 
 
 
 
80
 
81
+ # -------------------------------------------------
82
+ # AGENT NODE
83
+ # -------------------------------------------------
84
  def agent_node(state: AgentState):
85
+ # time-handling
 
86
  cur_time = time.time()
87
  cur_url = os.getenv("url")
88
+ prev_time = url_time[cur_url]
89
+ offset = os.getenv("offset")
 
90
  if prev_time is not None:
91
  prev_time = float(prev_time)
92
  diff = cur_time - prev_time
93
 
94
+ if diff >= 180 or (offset != "0" and (cur_time - float(offset)) > 90):
95
+ print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", diff, "Offset=", offset)
96
 
97
+ fail_instruction = """
98
+ You have exceeded the time limit for this task (over 130 seconds).
99
+ Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
 
100
  """
101
 
102
+ # LLM will figure out the right endpoint + JSON structure itself
103
+ result = llm.invoke([
104
+ {"role": "user", "content": fail_instruction}
105
+ ])
106
+ return {"messages": [result]}
107
 
108
+ trimmed_messages = trim_messages(
109
  messages=state["messages"],
110
  max_tokens=MAX_TOKENS,
111
  strategy="last",
112
  include_system=True,
113
  start_on="human",
114
+ token_counter=llm, # Use the LLM to count actual tokens, not just list length
115
  )
116
+
117
+ result = llm.invoke(trimmed_messages)
118
 
 
119
  return {"messages": [result]}
120
 
121
+ # -------------------------------------------------
122
+ # ROUTE LOGIC (YOURS WITH MINOR SAFETY IMPROVES)
123
+ # -------------------------------------------------
 
124
  def route(state):
125
  last = state["messages"][-1]
126
+ # print("=== ROUTE DEBUG: last message type ===")
127
+
128
  tool_calls = getattr(last, "tool_calls", None)
129
 
130
  if tool_calls:
 
136
  if isinstance(content, str) and content.strip() == "END":
137
  return END
138
 
139
+ if isinstance(content, list) and len(content) and isinstance(content[0], dict):
140
  if content[0].get("text", "").strip() == "END":
141
  return END
142
 
143
  print("Route → agent")
144
  return "agent"
145
 
 
 
 
146
 
147
+
148
+ # -------------------------------------------------
149
+ # GRAPH
150
+ # -------------------------------------------------
151
  graph = StateGraph(AgentState)
152
+
153
  graph.add_node("tools", ToolNode(TOOLS))
154
+
155
  graph.add_edge(START, "agent")
156
  graph.add_edge("tools", "agent")
157
  graph.add_conditional_edges("agent", route)
158
+ robust_retry = {
 
159
  "initial_interval": 1,
160
  "backoff_factor": 2,
161
  "max_interval": 60,
162
  "max_attempts": 10
163
+ }
164
 
165
+ graph.add_node("agent", agent_node, retry=robust_retry)
166
  app = graph.compile()
167
 
 
 
 
168
 
 
169
 
170
+ # -------------------------------------------------
171
+ # RUNNER
172
+ # -------------------------------------------------
173
+ def run_agent(url: str):
174
+ # system message is seeded ONCE here
175
+ initial_messages = [
176
  {"role": "system", "content": SYSTEM_PROMPT},
177
  {"role": "user", "content": url}
178
  ]
179
 
180
+ # run agent and CAPTURE the output
181
  result = app.invoke(
182
+ {"messages": initial_messages},
183
  config={"recursion_limit": RECURSION_LIMIT}
184
  )
185
 
186
+ # Try to detect final server response if present
187
  try:
188
  last = result["messages"][-1]
189
  content = getattr(last, "content", "")
190
 
191
+ # If LLM already output END – good
192
  if isinstance(content, str) and content.strip() == "END":
193
  print("Tasks completed successfully!")
194
  return
195
 
196
+ # If the last content is JSON from server submission
197
+ import json
198
  parsed = json.loads(content) if isinstance(content, str) else {}
199
  if parsed.get("url") is None:
200
  print("Tasks completed successfully!")
201
  return
202
 
203
  except Exception:
204
+ pass # fallback below
205
 
206
+ # Default fallback
207
+ print("Tasks completed successfully!")