Frazer2810 commited on
Commit
378c064
Β·
verified Β·
1 Parent(s): 833c04d

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +59 -167
agent.py CHANGED
@@ -1,173 +1,65 @@
1
- """LangGraph Agent – retry 5s, 30s, 60s β€’ no Supabase"""
2
-
3
- import os
4
- import time
5
- from dotenv import load_dotenv
6
- from langgraph.graph import START, StateGraph, MessagesState
7
- from langgraph.prebuilt import ToolNode, tools_condition
8
-
9
- # LLM providers
10
- from langchain_google_genai import ChatGoogleGenerativeAI
11
- from langchain_groq import ChatGroq
12
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
13
-
14
- # Tools & loaders
15
- from langchain_community.tools.tavily_search import TavilySearchResults
16
- from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
17
- from langchain_core.messages import SystemMessage, HumanMessage
18
- from langchain_core.tools import tool
19
-
20
- load_dotenv()
21
-
22
- # --------------------------------------------------------------------------- #
23
- # TOOL DEFINITIONS #
24
- # --------------------------------------------------------------------------- #
25
- @tool
26
- def multiply(a: int, b: int) -> int:
27
- """Multiply two integers and return the product."""
28
- return a * b
29
-
30
- @tool
31
- def add(a: int, b: int) -> int:
32
- """Add two integers and return the sum."""
33
- return a + b
34
-
35
- @tool
36
- def subtract(a: int, b: int) -> int:
37
- """Subtract the second integer from the first and return the difference."""
38
- return a - b
39
-
40
- @tool
41
- def divide(a: int, b: int) -> float:
42
- """Divide a by b and return the quotient (error if b == 0)."""
43
- if b == 0:
44
- raise ValueError("Cannot divide by zero.")
45
- return a / b
46
-
47
- @tool
48
- def modulus(a: int, b: int) -> int:
49
- """Return the remainder of the division of a by b."""
50
- return a % b
51
-
52
- @tool
53
- def wiki_search(query: str) -> str:
54
- """Search Wikipedia (max 2 docs) and return formatted content."""
55
- docs = WikipediaLoader(query=query, load_max_docs=2).load()
56
- return "\n\n---\n\n".join(
57
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
58
- f"{d.page_content}\n</Document>"
59
- for d in docs
60
- )
61
-
62
- @tool
63
- def web_search(query: str) -> str:
64
- """Perform a web search with Tavily (max 3 docs) and return formatted content."""
65
- docs = TavilySearchResults(max_results=3).invoke(query=query)
66
- return "\n\n---\n\n".join(
67
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
68
- f"{d.page_content}\n</Document>"
69
- for d in docs
70
- )
71
-
72
- @tool
73
- def arxiv_search(query: str) -> str:
74
- """Search ArXiv (max 3 docs) and return first 1000 characters per paper."""
75
- docs = ArxivLoader(query=query, load_max_docs=3).load()
76
- return "\n\n---\n\n".join(
77
- f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
78
- f"{d.page_content[:1000]}\n</Document>"
79
- for d in docs
80
  )
81
 
82
- # --------------------------------------------------------------------------- #
83
- # System prompt #
84
- # --------------------------------------------------------------------------- #
85
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
86
- system_prompt = f.read()
87
- sys_msg = SystemMessage(content=system_prompt)
 
 
 
 
 
 
 
 
 
88
 
89
  tools = [
90
- multiply, add, subtract, divide, modulus,
91
- wiki_search, web_search, arxiv_search,
 
92
  ]
93
 
94
- # --------------------------------------------------------------------------- #
95
- # Retry parameters #
96
- # --------------------------------------------------------------------------- #
97
- RETRY_DELAYS = [0, 5, 30, 60] # seconds for attempts 0-3
98
- MAX_ATTEMPTS = len(RETRY_DELAYS)
99
-
100
- # --------------------------------------------------------------------------- #
101
- # Build LangGraph #
102
- # --------------------------------------------------------------------------- #
103
- def build_graph(provider: str = "groq"):
104
- """Return a LangGraph graph with explicit retry logic."""
105
-
106
- # ----------- LLM selection -------------------------------------------- #
107
- if provider == "google":
108
- llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
109
-
110
- elif provider == "groq":
111
- llm = ChatGroq(
112
- model="qwen-qwq-32b",
113
- temperature=0,
114
- max_retries=0, # we handle retries manually
115
- )
116
-
117
- elif provider == "huggingface":
118
- llm = ChatHuggingFace(
119
- llm=HuggingFaceEndpoint(
120
- url="https://api-inference.huggingface.co/models/"
121
- "Meta-DeepLearning/llama-2-7b-chat-hf",
122
- temperature=0,
123
- )
124
- )
125
- else:
126
- raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
127
-
128
- llm_with_tools = llm.bind_tools(tools)
129
-
130
- # ---------------- Retry wrapper -------------------------------------- #
131
- def invoke_with_retry(messages):
132
- last_err = None
133
- for attempt, delay in enumerate(RETRY_DELAYS):
134
- if delay:
135
- print(f"[Retry {attempt}/{MAX_ATTEMPTS-1}] waiting {delay}s")
136
- time.sleep(delay)
137
- try:
138
- return llm_with_tools.invoke(messages)
139
- except Exception as e:
140
- if ("503" in str(e) or "Service Unavailable" in str(e) or "429" in str(e)) and attempt < MAX_ATTEMPTS - 1:
141
- last_err = e
142
- continue
143
- raise
144
- raise last_err or RuntimeError("Unknown error during LLM invocation")
145
-
146
- # ---------------- Nodes ---------------------------------------------- #
147
- def assistant(state: MessagesState):
148
- messages = [sys_msg] + state["messages"]
149
- return {"messages": [invoke_with_retry(messages)]}
150
-
151
- # ---------------- Graph ---------------------------------------------- #
152
- builder = StateGraph(MessagesState)
153
- builder.add_node("assistant", assistant)
154
- builder.add_node("tools", ToolNode(tools))
155
-
156
- builder.add_edge(START, "assistant")
157
- builder.add_conditional_edges("assistant", tools_condition)
158
- builder.add_edge("tools", "assistant")
159
-
160
- return builder.compile()
161
-
162
-
163
- # --------------------------------------------------------------------------- #
164
- # Stand-alone test #
165
- # --------------------------------------------------------------------------- #
166
- if __name__ == "__main__":
167
- g = build_graph(provider="groq")
168
- q = ("When was a picture of St. Thomas Aquinas first added to the Wikipedia "
169
- "page on the Principle of double effect?")
170
- msgs = [HumanMessage(content=q)]
171
- res = g.invoke({"messages": msgs})
172
- for m in res["messages"]:
173
- m.pretty_print()
 
1
+ """
2
+ GAIA Level-1 agent powered by smolagents.
3
+
4
+ * Planner/esecutore: CodeAgent (smolagents)
5
+ * LLM backend : GPT-4.1 via OpenAI
6
+ * Tools : DuckDuckGo (builtin), WikipediaTool, ArxivTool
7
+ * Output : UNA sola riga (exact-match)
8
+ """
9
+ from __future__ import annotations
10
+ import os, textwrap
11
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, OpenAIModel
12
+ from tools import WikipediaTool, ArxivTool
13
+ import openai
14
+
15
+ # ─── API key check ──────────────────────────────────────────────
16
+ openai.api_key = os.getenv("OPENAI_API_KEY") or ""
17
+ if not openai.api_key:
18
+ raise EnvironmentError(
19
+ "OPENAI_API_KEY non impostata: aggiungila nei Secrets dello Space "
20
+ "o in un file .env locale."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  )
22
 
23
+ # ─── Prompt di sistema rigido (exact-match) ─────────────────────
24
+ SYSTEM_PROMPT = textwrap.dedent("""
25
+ You are a helpful assistant tasked with answering questions using a set of tools.
26
+ Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
27
+ If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
28
+ If you are asked for a string, don't use articles, neither abbreviations, and write digits in plain text unless specified otherwise.
29
+ Return ONLY the final answer line.
30
+ """).strip()
31
+
32
+ # ─── Costruzione del β€œcore” CodeAgent ───────────────────────────
33
+ model = OpenAIModel(
34
+ model_id="gpt-4.1",
35
+ temperature=0,
36
+ system_prompt=SYSTEM_PROMPT
37
+ )
38
 
39
  tools = [
40
+ DuckDuckGoSearchTool(), # incorporato in smolagents
41
+ WikipediaTool(),
42
+ ArxivTool(),
43
  ]
44
 
45
+ core_agent = CodeAgent(
46
+ model=model,
47
+ tools=tools,
48
+ max_steps=6, # previene loop infiniti
49
+ scratchpad="minimal" # log conciso
50
+ )
51
+
52
+ # ─── Thin wrapper usato da app.py ───────────────────────────────
53
+ class BasicAgent: # (mantiene lo stesso nome giΓ  importato in app.py)
54
+ def __init__(self):
55
+ print("βœ… smolagents BasicAgent inizializzato")
56
+
57
+ def __call__(self, question: str) -> str:
58
+ """
59
+ Esegue CodeAgent e restituisce SOLO la prima riga,
60
+ così il grader riceve una stringa exact-match.
61
+ """
62
+ raw_answer: str = core_agent.run(question)
63
+ answer = raw_answer.strip().split("\n", 1)[0]
64
+ print(f"[ANSWER] {answer}")
65
+ return answer