ShoaibSSM committed on
Commit
4026086
·
verified ·
1 Parent(s): fab790b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +208 -185
agent.py CHANGED
@@ -1,185 +1,208 @@
1
- from langgraph.graph import StateGraph, END, START
2
- from shared_store import url_time
3
- import time
4
- from langchain_core.rate_limiters import InMemoryRateLimiter
5
- from langgraph.prebuilt import ToolNode
6
- from tools import (
7
- get_rendered_html, download_file, post_request,
8
- run_code, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
9
- )
10
- from typing import TypedDict, Annotated, List
11
- from langchain_core.messages import trim_messages
12
- from langchain.chat_models import init_chat_model
13
- from langgraph.graph.message import add_messages
14
- import os
15
- from dotenv import load_dotenv
16
- load_dotenv()
17
-
18
- EMAIL = os.getenv("EMAIL")
19
- SECRET = os.getenv("SECRET")
20
-
21
- RECURSION_LIMIT = 5000
22
- MAX_TOKENS = 180000
23
-
24
-
25
- # -------------------------------------------------
26
- # STATE
27
- # -------------------------------------------------
28
- class AgentState(TypedDict):
29
- messages: Annotated[List, add_messages]
30
-
31
-
32
- TOOLS = [
33
- run_code, get_rendered_html, download_file,
34
- post_request, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
35
- ]
36
-
37
-
38
- # -------------------------------------------------
39
- # LLM INIT (NO SYSTEM PROMPT HERE)
40
- # -------------------------------------------------
41
- rate_limiter = InMemoryRateLimiter(
42
- requests_per_second=7 / 60,
43
- check_every_n_seconds=1,
44
- max_bucket_size=7
45
- )
46
-
47
- llm = init_chat_model(
48
- model_provider="google_genai",
49
- model="gemini-2.5-flash",
50
- rate_limiter=rate_limiter
51
- ).bind_tools(TOOLS)
52
-
53
-
54
-
55
- # -------------------------------------------------
56
- # SYSTEM PROMPT (WILL BE INSERTED ONLY ONCE)
57
- # -------------------------------------------------
58
- SYSTEM_PROMPT = f"""
59
- You are an autonomous quiz-solving agent.
60
-
61
- Your job is to:
62
- 1. Load each quiz page from the given URL.
63
- 2. Extract instructions, parameters, and submit endpoint.
64
- 3. Solve tasks exactly.
65
- 4. Submit answers ONLY to the correct endpoint.
66
- 5. Follow new URLs until none remain, then output END.
67
-
68
- Rules:
69
- - For base64 generation of an image NEVER use your own code, always use the "encode_image_to_base64" tool that's provided
70
- - Never hallucinate URLs or fields.
71
- - Never shorten endpoints.
72
- - Always inspect server response.
73
- - Never stop early.
74
- - Use tools for HTML, downloading, rendering, OCR, or running code.
75
- - Include:
76
- email = {EMAIL}
77
- secret = {SECRET}
78
- """
79
-
80
-
81
- # -------------------------------------------------
82
- # AGENT NODE
83
- # -------------------------------------------------
84
- def agent_node(state: AgentState):
85
- # time-handling
86
- cur_time = time.time()
87
- cur_url = os.getenv("url")
88
- prev_time = url_time[cur_url]
89
- offset = os.getenv("offset")
90
- if prev_time is not None:
91
- prev_time = float(prev_time)
92
- diff = cur_time - prev_time
93
-
94
- if diff >= 180 or (offset != "0" and (cur_time - float(offset)) > 90):
95
- print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", diff, "Offset=", offset)
96
-
97
- fail_instruction = """
98
- You have exceeded the time limit for this task (over 130 seconds).
99
- Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
100
- """
101
-
102
- # LLM will figure out the right endpoint + JSON structure itself
103
- result = llm.invoke([
104
- {"role": "user", "content": fail_instruction}
105
- ])
106
- return {"messages": [result]}
107
-
108
- trimmed_messages = trim_messages(
109
- messages=state["messages"],
110
- max_tokens=MAX_TOKENS,
111
- strategy="last",
112
- include_system=True,
113
- start_on="human",
114
- token_counter=llm, # Use the LLM to count actual tokens, not just list length
115
- )
116
-
117
- result = llm.invoke(trimmed_messages)
118
-
119
- return {"messages": [result]}
120
-
121
- # -------------------------------------------------
122
- # ROUTE LOGIC (YOURS WITH MINOR SAFETY IMPROVES)
123
- # -------------------------------------------------
124
- def route(state):
125
- last = state["messages"][-1]
126
- # print("=== ROUTE DEBUG: last message type ===")
127
-
128
- tool_calls = getattr(last, "tool_calls", None)
129
-
130
- if tool_calls:
131
- print("Route → tools")
132
- return "tools"
133
-
134
- content = getattr(last, "content", None)
135
-
136
- if isinstance(content, str) and content.strip() == "END":
137
- return END
138
-
139
- if isinstance(content, list) and len(content) and isinstance(content[0], dict):
140
- if content[0].get("text", "").strip() == "END":
141
- return END
142
-
143
- print("Route → agent")
144
- return "agent"
145
-
146
-
147
-
148
- # -------------------------------------------------
149
- # GRAPH
150
- # -------------------------------------------------
151
- graph = StateGraph(AgentState)
152
-
153
- graph.add_node("tools", ToolNode(TOOLS))
154
-
155
- graph.add_edge(START, "agent")
156
- graph.add_edge("tools", "agent")
157
- graph.add_conditional_edges("agent", route)
158
- robust_retry = {
159
- "initial_interval": 1,
160
- "backoff_factor": 2,
161
- "max_interval": 60,
162
- "max_attempts": 10
163
- }
164
-
165
- graph.add_node("agent", agent_node, retry=robust_retry)
166
- app = graph.compile()
167
-
168
-
169
-
170
- # -------------------------------------------------
171
- # RUNNER
172
- # -------------------------------------------------
173
- def run_agent(url: str):
174
- # system message is seeded ONCE here
175
- initial_messages = [
176
- {"role": "system", "content": SYSTEM_PROMPT},
177
- {"role": "user", "content": url}
178
- ]
179
-
180
- app.invoke(
181
- {"messages": initial_messages},
182
- config={"recursion_limit": RECURSION_LIMIT}
183
- )
184
-
185
- print("Tasks completed successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END, START
2
+ from shared_store import url_time
3
+ import time
4
+ from langchain_core.rate_limiters import InMemoryRateLimiter
5
+ from langgraph.prebuilt import ToolNode
6
+ from tools import (
7
+ get_rendered_html, download_file, post_request,
8
+ run_code, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
9
+ )
10
+ from typing import TypedDict, Annotated, List
11
+ from langchain_core.messages import trim_messages
12
+ from langchain.chat_models import init_chat_model
13
+ from langgraph.graph.message import add_messages
14
+ import os
15
+ from dotenv import load_dotenv
16
+ load_dotenv()
17
+
18
+ EMAIL = os.getenv("EMAIL")
19
+ SECRET = os.getenv("SECRET")
20
+
21
+ RECURSION_LIMIT = 5000
22
+ MAX_TOKENS = 180000
23
+
24
+
25
+ # -------------------------------------------------
26
+ # STATE
27
+ # -------------------------------------------------
28
class AgentState(TypedDict):
    """State carried between the nodes of the agent graph."""

    # Conversation history; the `add_messages` annotation is the reducer
    # applied to this key when a node returns {"messages": [...]}.
    messages: Annotated[List, add_messages]
30
+
31
+
32
+ TOOLS = [
33
+ run_code, get_rendered_html, download_file,
34
+ post_request, add_dependencies, ocr_image_tool, transcribe_audio, encode_image_to_base64
35
+ ]
36
+
37
+
38
+ # -------------------------------------------------
39
+ # LLM INIT (NO SYSTEM PROMPT HERE)
40
+ # -------------------------------------------------
41
+ rate_limiter = InMemoryRateLimiter(
42
+ requests_per_second=7 / 60,
43
+ check_every_n_seconds=1,
44
+ max_bucket_size=7
45
+ )
46
+
47
+ llm = init_chat_model(
48
+ model_provider="google_genai",
49
+ model="gemini-2.5-flash",
50
+ rate_limiter=rate_limiter
51
+ ).bind_tools(TOOLS)
52
+
53
+
54
+
55
+ # -------------------------------------------------
56
+ # SYSTEM PROMPT (WILL BE INSERTED ONLY ONCE)
57
+ # -------------------------------------------------
58
+ SYSTEM_PROMPT = f"""
59
+ You are an autonomous quiz-solving agent.
60
+
61
+ Your job is to:
62
+ 1. Load each quiz page from the given URL.
63
+ 2. Extract instructions, parameters, and submit endpoint.
64
+ 3. Solve tasks exactly.
65
+ 4. Submit answers ONLY to the correct endpoint.
66
+ 5. Follow new URLs until none remain, then output END.
67
+
68
+ Rules:
69
+ - For base64 generation of an image NEVER use your own code, always use the "encode_image_to_base64" tool that's provided
70
+ - Never hallucinate URLs or fields.
71
+ - Never shorten endpoints.
72
+ - Always inspect server response.
73
+ - Never stop early.
74
+ - Use tools for HTML, downloading, rendering, OCR, or running code.
75
+ - Include:
76
+ email = {EMAIL}
77
+ secret = {SECRET}
78
+ """
79
+
80
+
81
+ # -------------------------------------------------
82
+ # AGENT NODE
83
+ # -------------------------------------------------
84
def agent_node(state: AgentState):
    """Run one LLM turn, or force a deliberate wrong submission on timeout.

    The current quiz URL and an optional ``offset`` timestamp are read from
    the environment variables ``url`` and ``offset``; the start time of the
    URL is looked up in ``url_time``. When a start time is known and either
    180 seconds have elapsed since it, or ``offset`` is set to something
    other than "0" and lies more than 90 seconds in the past, the LLM is told
    to submit a wrong answer right away. Otherwise the (trimmed) conversation
    is sent to the LLM as-is.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        ``{"messages": [reply]}`` where ``reply`` is the LLM response.
    """
    now = time.time()
    current_url = os.getenv("url")
    offset = os.getenv("offset")

    # A URL with no recorded start time simply skips the timeout check
    # instead of aborting the node with a KeyError.
    try:
        started_at = url_time[current_url]
    except KeyError:
        started_at = None

    # Keep the prompt inside the model's context window. The system message
    # is preserved and the kept window starts on a human message.
    trimmed_messages = trim_messages(
        messages=state["messages"],
        max_tokens=MAX_TOKENS,
        strategy="last",
        include_system=True,
        start_on="human",
        token_counter=llm,  # count real tokens with the model, not list length
    )

    if started_at is not None:
        elapsed = now - float(started_at)

        # An unset, empty or "0" offset disables the secondary deadline;
        # previously an unset variable crashed in float(None).
        offset_expired = (
            offset not in (None, "", "0") and (now - float(offset)) > 90
        )

        if elapsed >= 180 or offset_expired:
            print("Timeout exceeded — instructing LLM to purposely submit wrong answer.", elapsed, "Offset=", offset)

            # NOTE(review): the wording says 130 seconds while the check
            # above uses 180 — confirm which limit is intended.
            fail_instruction = """
            You have exceeded the time limit for this task (over 130 seconds).
            Immediately call the `post_request` tool and submit a WRONG answer for the CURRENT quiz.
            """

            # The history is sent along with the instruction: without it the
            # model has no way to know the submit endpoint or payload shape
            # of the current quiz.
            result = llm.invoke([
                *trimmed_messages,
                {"role": "user", "content": fail_instruction},
            ])
            return {"messages": [result]}

    result = llm.invoke(trimmed_messages)
    return {"messages": [result]}
120
+
121
+ # -------------------------------------------------
122
+ # ROUTE LOGIC (YOURS WITH MINOR SAFETY IMPROVES)
123
+ # -------------------------------------------------
124
def _is_end_signal(content):
    """Return True when message content is exactly the END marker."""
    if isinstance(content, str):
        return content.strip() == "END"

    # Content may also be a list of parts; only the first part is inspected.
    if isinstance(content, list) and content:
        first = content[0]
        if isinstance(first, dict):
            # "text" can be missing or explicitly None; treat both as empty.
            text = first.get("text") or ""
        else:
            text = first
        return isinstance(text, str) and text.strip() == "END"

    return False


def route(state):
    """Pick the next graph node from the last message in the state.

    Args:
        state: Graph state; ``state["messages"][-1]`` is inspected.

    Returns:
        ``"tools"`` when the last message carries tool calls, ``END`` when
        its content is the literal END marker (as a string, or as the first
        part of a list of content parts), otherwise ``"agent"``.
    """
    last = state["messages"][-1]

    if getattr(last, "tool_calls", None):
        print("Route → tools")
        return "tools"

    if _is_end_signal(getattr(last, "content", None)):
        return END

    print("Route → agent")
    return "agent"
145
+
146
+
147
+
148
+ # -------------------------------------------------
149
+ # GRAPH
150
+ # -------------------------------------------------
151
# Imported here so this block is self-contained; RetryPolicy is the type
# LangGraph expects for a node's retry configuration.
from langgraph.types import RetryPolicy

graph = StateGraph(AgentState)

# A plain dict is not a valid retry configuration: LangGraph reads the
# policy's attributes when a node raises, so a dict would fail at exactly
# the moment a retry is needed.
robust_retry = RetryPolicy(
    initial_interval=1,
    backoff_factor=2,
    max_interval=60,
    max_attempts=10,
)

# Register both nodes before wiring the edges that reference them.
graph.add_node("agent", agent_node, retry=robust_retry)
graph.add_node("tools", ToolNode(TOOLS))

graph.add_edge(START, "agent")
graph.add_edge("tools", "agent")
# `route` decides between "tools", "agent" and END after every agent turn.
graph.add_conditional_edges("agent", route)

app = graph.compile()
167
+
168
+
169
+
170
+ # -------------------------------------------------
171
+ # RUNNER
172
+ # -------------------------------------------------
173
def run_agent(url: str):
    """Run the agent graph on a quiz URL and report how the run ended.

    The conversation is seeded with the system prompt (once) and the URL as
    the first user message, then the compiled graph is invoked with the
    configured recursion limit. Afterwards the last message is inspected and
    a status line is printed.

    Args:
        url: The starting quiz URL handed to the agent.

    Returns:
        None. The outcome is only reported via ``print``.
    """
    import json  # only needed for the final-response check below

    initial_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": url},
    ]

    result = app.invoke(
        {"messages": initial_messages},
        config={"recursion_limit": RECURSION_LIMIT},
    )

    messages = result.get("messages") if isinstance(result, dict) else None
    last = messages[-1] if messages else None
    content = getattr(last, "content", "")

    # Content may be a list of parts; look at the first part's text, the
    # same way `route` does when it decides to stop the graph.
    if isinstance(content, list) and content:
        first = content[0]
        content = first.get("text") if isinstance(first, dict) else first

    if isinstance(content, str):
        # The LLM announced completion explicitly.
        if content.strip() == "END":
            print("Tasks completed successfully!")
            return

        # Otherwise the content may be a JSON server response; a response
        # without a follow-up "url" means there is nothing left to solve.
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, dict) and parsed.get("url") is None:
            print("Tasks completed successfully!")
            return

    # Neither an END marker nor a final server response was found, so do not
    # claim success.
    print("Agent run ended without an END signal or a final server response.")
208
+