ShoaibSSM commited on
Commit
ee9487b
·
verified ·
1 Parent(s): 782c24f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +140 -96
agent.py CHANGED
@@ -1,18 +1,22 @@
1
- from langgraph.graph import StateGraph, END, START
2
- from shared_store import url_time
3
  import time
4
- from langchain_core.rate_limiters import InMemoryRateLimiter
 
 
 
 
5
  from langgraph.prebuilt import ToolNode
 
 
 
 
 
6
  from tools import (
7
- get_rendered_html, download_file, post_request,
8
- run_code, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
9
  )
10
- from typing import TypedDict, Annotated, List
11
- from langchain_core.messages import trim_messages
12
- from langchain.chat_models import init_chat_model
13
- from langgraph.graph.message import add_messages
14
- import os
15
- from dotenv import load_dotenv
16
  load_dotenv()
17
 
18
  EMAIL = os.getenv("EMAIL")
@@ -21,110 +25,161 @@ SECRET = os.getenv("SECRET")
21
  RECURSION_LIMIT = 5000
22
  MAX_TOKENS = 180000
23
 
 
 
 
24
 
25
- # -------------------------------------------------
26
- # STATE
27
- # -------------------------------------------------
28
  class AgentState(TypedDict):
29
  messages: Annotated[List, add_messages]
30
 
31
-
32
  TOOLS = [
33
- run_code, get_rendered_html, download_file,
34
- post_request, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
35
  ]
36
 
 
 
 
37
 
38
- # -------------------------------------------------
39
- # LLM INIT (NO SYSTEM PROMPT HERE)
40
- # -------------------------------------------------
41
- rate_limiter = InMemoryRateLimiter(
42
- requests_per_second=7 / 60,
43
- check_every_n_seconds=1,
44
- max_bucket_size=7
45
- )
 
 
46
 
47
- llm = init_chat_model(
48
- model_provider="google_genai",
49
- model="gemini-2.5-flash",
50
- rate_limiter=rate_limiter
51
- ).bind_tools(TOOLS)
52
 
 
 
 
 
 
 
 
 
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- # -------------------------------------------------
56
- # SYSTEM PROMPT (WILL BE INSERTED ONLY ONCE)
57
- # -------------------------------------------------
58
  SYSTEM_PROMPT = f"""
59
  You are an autonomous quiz-solving agent.
60
 
61
- Your job is to:
62
- 1. Load each quiz page from the given URL.
63
- 2. Extract instructions, parameters, and submit endpoint.
64
- 3. Solve tasks exactly.
65
  4. Submit answers ONLY to the correct endpoint.
66
- 5. Follow new URLs until none remain, then output END.
67
 
68
  Rules:
69
- - For base64 generation of an image NEVER use your own code, always use the "encode_image_to_base64" tool that's provided
70
- - Never hallucinate URLs or fields.
71
- - Never shorten endpoints.
72
  - Always inspect server response.
73
- - Never stop early.
74
- - Use tools for HTML, downloading, rendering, OCR, or running code.
75
- - Include:
76
- email = {EMAIL}
77
- secret = {SECRET}
78
  """
79
 
 
 
 
80
 
81
- # -------------------------------------------------
82
- # AGENT NODE
83
- # -------------------------------------------------
84
  def agent_node(state: AgentState):
85
- # time-handling
 
86
  cur_time = time.time()
87
  cur_url = os.getenv("url")
88
- prev_time = url_time[cur_url]
89
- offset = os.getenv("offset")
 
90
  if prev_time is not None:
91
  prev_time = float(prev_time)
92
  diff = cur_time - prev_time
93
 
94
- if diff >= 180 or (offset != "0" and (cur_time - float(offset)) > 90):
95
- print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", diff, "Offset=", offset)
96
 
97
- fail_instruction = """
98
- You have exceeded the time limit for this task (over 130 seconds).
99
- Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
 
100
  """
101
 
102
- # LLM will figure out the right endpoint + JSON structure itself
103
- result = llm.invoke([
104
- {"role": "user", "content": fail_instruction}
105
- ])
106
- return {"messages": [result]}
107
 
108
- trimmed_messages = trim_messages(
109
  messages=state["messages"],
110
  max_tokens=MAX_TOKENS,
111
  strategy="last",
112
  include_system=True,
113
  start_on="human",
114
- token_counter=llm, # Use the LLM to count actual tokens, not just list length
115
  )
116
-
117
- result = llm.invoke(trimmed_messages)
118
 
 
119
  return {"messages": [result]}
120
 
121
- # -------------------------------------------------
122
- # ROUTE LOGIC (YOURS WITH MINOR SAFETY IMPROVES)
123
- # -------------------------------------------------
 
124
  def route(state):
125
  last = state["messages"][-1]
126
- # print("=== ROUTE DEBUG: last message type ===")
127
-
128
  tool_calls = getattr(last, "tool_calls", None)
129
 
130
  if tool_calls:
@@ -136,73 +191,62 @@ def route(state):
136
  if isinstance(content, str) and content.strip() == "END":
137
  return END
138
 
139
- if isinstance(content, list) and len(content) and isinstance(content[0], dict):
140
  if content[0].get("text", "").strip() == "END":
141
  return END
142
 
143
  print("Route → agent")
144
  return "agent"
145
 
 
 
 
146
 
147
-
148
- # -------------------------------------------------
149
- # GRAPH
150
- # -------------------------------------------------
151
  graph = StateGraph(AgentState)
152
-
153
  graph.add_node("tools", ToolNode(TOOLS))
154
-
155
  graph.add_edge(START, "agent")
156
  graph.add_edge("tools", "agent")
157
  graph.add_conditional_edges("agent", route)
158
- robust_retry = {
 
159
  "initial_interval": 1,
160
  "backoff_factor": 2,
161
  "max_interval": 60,
162
  "max_attempts": 10
163
- }
164
 
165
- graph.add_node("agent", agent_node, retry=robust_retry)
166
  app = graph.compile()
167
 
 
 
 
168
 
169
-
170
- # -------------------------------------------------
171
- # RUNNER
172
- # -------------------------------------------------
173
  def run_agent(url: str):
174
- # system message is seeded ONCE here
175
- initial_messages = [
176
  {"role": "system", "content": SYSTEM_PROMPT},
177
  {"role": "user", "content": url}
178
  ]
179
 
180
- # run agent and CAPTURE the output
181
  result = app.invoke(
182
- {"messages": initial_messages},
183
  config={"recursion_limit": RECURSION_LIMIT}
184
  )
185
 
186
- # Try to detect final server response if present
187
  try:
188
  last = result["messages"][-1]
189
  content = getattr(last, "content", "")
190
 
191
- # If LLM already output END – good
192
  if isinstance(content, str) and content.strip() == "END":
193
  print("Tasks completed successfully!")
194
  return
195
 
196
- # If the last content is JSON from server submission
197
- import json
198
  parsed = json.loads(content) if isinstance(content, str) else {}
199
  if parsed.get("url") is None:
200
  print("Tasks completed successfully!")
201
  return
202
 
203
  except Exception:
204
- pass # fallback below
205
 
206
- # Default fallback
207
  print("Tasks completed successfully!")
208
-
 
1
+ import os
 
2
  import time
3
+ import json
4
+ from dotenv import load_dotenv
5
+ from typing import TypedDict, Annotated, List
6
+
7
+ from langgraph.graph import StateGraph, START, END
8
  from langgraph.prebuilt import ToolNode
9
+ from langgraph.graph.message import add_messages
10
+ from langchain.chat_models import init_chat_model
11
+ from langchain_core.rate_limiters import InMemoryRateLimiter
12
+ from langchain_core.messages import trim_messages
13
+
14
  from tools import (
15
+ get_rendered_html, download_file, post_request, run_code,
16
+ add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
17
  )
18
+ from shared_store import url_time
19
+
 
 
 
 
20
  load_dotenv()
21
 
22
  EMAIL = os.getenv("EMAIL")
 
25
  RECURSION_LIMIT = 5000
26
  MAX_TOKENS = 180000
27
 
28
# ==============================================================
# STATE
# ==============================================================

class AgentState(TypedDict):
    """Graph state: a message list merged via langgraph's add_messages reducer."""
    # add_messages appends/merges new messages instead of replacing the list
    messages: Annotated[List, add_messages]


# Every tool exposed to the LLM; ToolNode executes whichever one it calls.
TOOLS = [
    run_code, get_rendered_html, download_file, post_request,
    add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
]
39
 
40
# ==============================================================
# FALLBACK LLM
# ==============================================================

# Gemini models to try, in order of preference.
FALLBACK_MODELS = [
    "gemini-2.5-flash",
    "gemini-2.5-flash-lite",
    "gemini-2.0-flash-lite",
    "gemini-2.0-flash",
]


def init_llm_with_fallback(tools):
    """Initialize an LLM with automatic fallback selection.

    Probes each model in FALLBACK_MODELS in order with a tiny "ping"
    invocation and returns the first one that responds, already bound
    to *tools*.

    Args:
        tools: sequence of tool callables to bind to the model.

    Returns:
        A tool-bound chat model ready for ``.invoke()``.

    Raises:
        RuntimeError: if no model in the list is usable; the last probe
            failure is chained as the cause so the root error survives.
    """
    # ~7 requests per minute, shared by every model we probe and use.
    rate_limiter = InMemoryRateLimiter(
        requests_per_second=7 / 60,
        check_every_n_seconds=1,
        max_bucket_size=7
    )

    last_error = None
    for model_name in FALLBACK_MODELS:
        try:
            print(f"[LLM] Trying model: {model_name}")
            llm = init_chat_model(
                model_provider="google_genai",
                model=model_name,
                rate_limiter=rate_limiter
            ).bind_tools(tools)

            llm.invoke("ping")  # cheap live probe — fails fast on quota/auth problems
            print(f"[LLM] Model ready: {model_name}")
            return llm

        except Exception as e:
            last_error = e
            print(f"[LLM] Model failed ({model_name}): {e}")

    # Chain the last probe failure so the underlying cause is not lost.
    raise RuntimeError("❌ No Gemini model available!") from last_error


# global LLM — selected once at import time.
llm = init_llm_with_fallback(TOOLS)
81
+
82
# ==============================================================
# SAFE INVOKE (fallback switcher)
# ==============================================================

# Error substrings indicating a quota/availability problem that is worth
# retrying on a different model instead of surfacing to the caller.
_FALLBACK_ERROR_MARKERS = (
    "429", "quota", "exceeded", "rate",
    "unavailable", "deadline", "resourceexhausted",
)


def safe_llm_invoke(input_message):
    """Invoke the global LLM, transparently re-selecting a fallback model
    on quota/rate/availability errors.

    Args:
        input_message: anything ``llm.invoke`` accepts (string or message list).

    Returns:
        The model's response message.

    Raises:
        Exception: any non-quota error is re-raised unchanged, with its
            original traceback preserved.
    """
    global llm
    try:
        return llm.invoke(input_message)

    except Exception as e:
        err = str(e).lower()

        # Generator avoids building a throwaway list; markers hoisted above.
        if any(marker in err for marker in _FALLBACK_ERROR_MARKERS):
            print("\n⚠️ Quota/Rate error → switching LLM model...\n")
            llm = init_llm_with_fallback(TOOLS)  # rebinds the module-global
            return llm.invoke(input_message)

        # Not a quota problem — bare `raise` keeps the original traceback
        # (original used `raise e`, which adds a redundant frame).
        raise
110
+
111
# ==============================================================
# SYSTEM PROMPT
# ==============================================================

# Seeded exactly once per run (see run_agent). NOTE: this is an f-string,
# so EMAIL and SECRET are interpolated at import time — changing the env
# vars afterwards has no effect on the prompt.
SYSTEM_PROMPT = f"""
You are an autonomous quiz-solving agent.

Your job:
1. Load each quiz page.
2. Extract instructions, parameters & submit endpoint.
3. Solve tasks EXACTLY.
4. Submit answers ONLY to the correct endpoint.
5. Follow returned URLs until none remain, then output END.

Rules:
- NEVER generate your own base64—use encode_image_to_base64 tool.
- NEVER hallucinate endpoints.
- NEVER shorten URLs.
- Always inspect server response.
- Use tools for HTML, code execution, OCR, downloading, etc.
Include in every submission:
email = {EMAIL}
secret = {SECRET}
"""
135
 
136
# ==============================================================
# AGENT NODE
# ==============================================================

def agent_node(state: AgentState):
    """Core LLM step: enforce the per-quiz time budget, then invoke the
    model on the (token-trimmed) conversation.

    Returns a partial state update: {"messages": [new_message]}.
    """

    # ---- TIMEOUT ----
    cur_time = time.time()
    cur_url = os.getenv("url")          # current quiz URL, set externally by the runner
    prev_time = url_time.get(cur_url)   # timestamp recorded when this URL was first seen
    offset = float(os.getenv("offset", "0"))

    if prev_time is not None:
        prev_time = float(prev_time)
        diff = cur_time - prev_time

        # 180 s hard cap per URL, or 90 s past the external "offset" mark.
        # NOTE(review): both checks only run when prev_time exists, so the
        # offset-based timeout is skipped for URLs absent from url_time —
        # confirm this is intended.
        if diff >= 180 or (offset != 0 and (cur_time - offset) > 90):
            print("Timeout exceeded — forcing WRONG submission.", diff)

            instruction = """
            You exceeded allowed time.
            Immediately call post_request with a WRONG answer
            for the CURRENT quiz.
            """

            # Fresh single-message call: the LLM must locate the submit
            # endpoint itself from its tool context.
            return {"messages": [safe_llm_invoke(instruction)]}

    # ---- NORMAL FLOW ----

    trimmed = trim_messages(
        messages=state["messages"],
        max_tokens=MAX_TOKENS,
        strategy="last",          # keep the most recent messages
        include_system=True,      # never drop the seeded system prompt
        start_on="human",
        token_counter=llm         # count real model tokens, not list length
    )

    result = safe_llm_invoke(trimmed)
    return {"messages": [result]}
 
177
+ # ==============================================================
178
+ # ROUTING
179
+ # ==============================================================
180
+
181
  def route(state):
182
  last = state["messages"][-1]
 
 
183
  tool_calls = getattr(last, "tool_calls", None)
184
 
185
  if tool_calls:
 
191
  if isinstance(content, str) and content.strip() == "END":
192
  return END
193
 
194
+ if isinstance(content, list) and len(content):
195
  if content[0].get("text", "").strip() == "END":
196
  return END
197
 
198
  print("Route → agent")
199
  return "agent"
200
 
201
# ==============================================================
# GRAPH BUILD
# ==============================================================

graph = StateGraph(AgentState)

# Register both nodes BEFORE wiring edges. The original added edges that
# referenced "agent" before that node existed; some langgraph versions
# validate edge endpoints eagerly, and node-first ordering is clearer
# regardless.
graph.add_node("agent", agent_node, retry={
    "initial_interval": 1,   # seconds before the first retry
    "backoff_factor": 2,     # exponential backoff
    "max_interval": 60,      # cap between attempts
    "max_attempts": 10
})
graph.add_node("tools", ToolNode(TOOLS))

graph.add_edge(START, "agent")
graph.add_edge("tools", "agent")
graph.add_conditional_edges("agent", route)  # route() decides: tools / agent / END

app = graph.compile()
219
 
220
# ==============================================================
# RUN AGENT
# ==============================================================

def run_agent(url: str):
    """Drive the compiled graph for one quiz run starting at *url*.

    The system prompt is seeded exactly once here; the agent then follows
    returned URLs until the LLM outputs "END" or the recursion limit hits.
    """

    initial = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": url}
    ]

    result = app.invoke(
        {"messages": initial},
        config={"recursion_limit": RECURSION_LIMIT}
    )

    # Best-effort inspection of the final message to decide how we finished.
    # NOTE(review): every path below prints the same success message, so the
    # branching currently only affects *when* we return, not what is printed.
    try:
        last = result["messages"][-1]
        content = getattr(last, "content", "")

        # Explicit END sentinel from the LLM.
        if isinstance(content, str) and content.strip() == "END":
            print("Tasks completed successfully!")
            return

        # Otherwise try the content as a server JSON response; a missing
        # "url" field means there is no next quiz to follow.
        parsed = json.loads(content) if isinstance(content, str) else {}
        if parsed.get("url") is None:
            print("Tasks completed successfully!")
            return

    except Exception:
        pass  # unparseable content — fall through to the default message

    print("Tasks completed successfully!")