Frazer2810 committed
Commit ff95a9a · verified · 1 Parent(s): aef7231

Update agent.py

Files changed (1)
  1. agent.py +167 -110
agent.py CHANGED
@@ -1,114 +1,171 @@
- """ Basic Agent Evaluation Runner – always submits all answers """
-
- import os
- import requests
- import gradio as gr
- import pandas as pd
- from langchain_core.messages import HumanMessage
- from agent import build_graph
-
-
- # --- Constants ------------------------------------------------------------ #
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
-
-
- # --- Agent wrapper -------------------------------------------------------- #
- class BasicAgent:
-     """LangGraph agent ready for evaluation."""
-     def __init__(self):
-         print("BasicAgent initialized (provider=groq).")
-         self.graph = build_graph(provider="groq")
-
-     def __call__(self, question: str) -> str:
-         print(f"Agent received question (first 50 chars): {question[:50]}...")
-         msgs = [HumanMessage(content=question)]
-         result = self.graph.invoke({"messages": msgs})
-         answer = result["messages"][-1].content
-         # strip the "FINAL ANSWER: " prefix
-         return answer[14:]
-
-
- # --- Main evaluation logic ------------------------------------------------ #
- def run_and_submit_all(profile: gr.OAuthProfile | None):
-     # 0. Check login
-     if not profile:
-         return "Please Login to Hugging Face with the button.", None
-     username = profile.username
-     print(f"User logged in: {username}")
-
-     # 1. Instantiate agent
-     try:
-         agent = BasicAgent()
-     except Exception as e:
-         return f"Error initializing agent: {e}", None
-
-     # 2. Fetch questions
-     try:
-         resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=15)
-         resp.raise_for_status()
-         questions_data = resp.json()
-         if not questions_data:
-             return "Fetched questions list is empty.", None
-     except Exception as e:
-         return f"Error fetching questions: {e}", None
-
-     # 3. Run agent and build payload
-     answers_payload = []
-     results_log = []
-
-     for item in questions_data:
-         task_id = item.get("task_id")
-         q_text = item.get("question")
-
-         submitted_answer = "errore"  # default in case of failure
-         try:
-             submitted_answer = agent(q_text)
-         except Exception as e:
-             print(f"Error running agent on task {task_id}: {e}")
-
-         # record the answer in any case (success or error)
-         answers_payload.append(
-             {"task_id": task_id, "submitted_answer": submitted_answer}
-         )
-         results_log.append(
-             {
-                 "Task ID": task_id,
-                 "Question": q_text,
-                 "Submitted Answer": submitted_answer,
-             }
          )
 
-     # 4. Submit answers
-     submission = {
-         "username": username,
-         "agent_code": f"https://huggingface.co/spaces/{os.getenv('SPACE_ID', '')}/tree/main",
-         "answers": answers_payload,
-     }
-
-     try:
-         resp = requests.post(f"{DEFAULT_API_URL}/submit", json=submission, timeout=60)
-         resp.raise_for_status()
-         data = resp.json()
-         status_msg = (
-             f"Submission Successful!\nUser: {data.get('username')}\n"
-             f"Overall Score: {data.get('score', 'N/A')}% "
-             f"({data.get('correct_count', '?')}/{data.get('total_attempted', '?')} correct)\n"
-             f"Message: {data.get('message', 'No message received.')}"
          )
-     except Exception as e:
-         status_msg = f"Submission Failed: {e}"
-
-     return status_msg, pd.DataFrame(results_log)
-
-
- # --- Gradio UI ------------------------------------------------------------ #
- with gr.Blocks() as demo:
-     gr.Markdown("# Basic Agent Evaluation Runner (retry & error-safe)")
-     gr.LoginButton()
-     run_btn = gr.Button("Run Evaluation & Submit All Answers")
-     status_box = gr.Textbox(lines=5, label="Run Status / Submission Result")
-     results_tbl = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
-     run_btn.click(fn=run_and_submit_all, outputs=[status_box, results_tbl])
-
  if __name__ == "__main__":
-     demo.launch(debug=True, share=False)
+ """LangGraph Agent – retry 5s, 30s, 60s; without Supabase"""
+
+ import os, time
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import ToolNode, tools_condition
+
+ # LLM providers
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+
+ # Tools & loaders
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
+
+ load_dotenv()
+
+ # --------------------------------------------------------------------------- #
+ # TOOLS #
+ # --------------------------------------------------------------------------- #
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two integers and return the product."""
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two integers and return the sum."""
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract the second integer from the first and return the difference."""
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide a by b and return the quotient (error if b == 0)."""
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Return the remainder of the division of a by b."""
+     return a % b
+
+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia (max 2 docs) and return formatted content."""
+     docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     return "\n\n---\n\n".join(
+         f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
+         f"{d.page_content}\n</Document>"
+         for d in docs
+     )
+
+ @tool
+ def web_search(query: str) -> str:
+     """Perform a web search with Tavily (max 3 docs) and return formatted content."""
+     docs = TavilySearchResults(max_results=3).invoke(query=query)
+     return "\n\n---\n\n".join(
+         f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
+         f"{d.page_content}\n</Document>"
+         for d in docs
+     )
+
+ @tool
+ def arxiv_search(query: str) -> str:
+     """Search ArXiv (max 3 docs) and return first 1000 characters per paper."""
+     docs = ArxivLoader(query=query, load_max_docs=3).load()
+     return "\n\n---\n\n".join(
+         f'<Document source="{d.metadata["source"]}" page="{d.metadata.get("page","")}"/>\n'
+         f"{d.page_content[:1000]}\n</Document>"
+         for d in docs
+     )
+
+ # --------------------------------------------------------------------------- #
+ # System prompt #
+ # --------------------------------------------------------------------------- #
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+ sys_msg = SystemMessage(content=system_prompt)
+
+ tools = [
+     multiply, add, subtract, divide, modulus,
+     wiki_search, web_search, arxiv_search,
+ ]
+
+ # --------------------------------------------------------------------------- #
+ # Retry parameters #
+ # --------------------------------------------------------------------------- #
+ RETRY_DELAYS = [0, 5, 30, 60]  # seconds: attempts 0, 1, 2, 3
+ MAX_ATTEMPTS = len(RETRY_DELAYS)
+
+ # --------------------------------------------------------------------------- #
+ # Build LangGraph #
+ # --------------------------------------------------------------------------- #
+ def build_graph(provider: str = "groq"):
+     """Return a LangGraph graph with explicit retry logic (5s, 30s, 60s)."""
+
+     # ----------- LLM selection -------------------------------------------- #
+     if provider == "google":
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+
+     elif provider == "groq":
+         llm = ChatGroq(
+             model="qwen-qwq-32b",
+             temperature=0,
+             max_retries=0,  # disable the client's internal retries
          )
 
+     elif provider == "huggingface":
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             )
          )
+     else:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+
+     llm_with_tools = llm.bind_tools(tools)
+
+     # ---------------- Retry wrapper -------------------------------------- #
+     def invoke_with_retry(messages):
+         last_err = None
+         for attempt, delay in enumerate(RETRY_DELAYS):
+             if delay > 0:
+                 print(f"[Retry {attempt}/{MAX_ATTEMPTS-1}] waiting {delay}s")
+                 time.sleep(delay)
+             try:
+                 return llm_with_tools.invoke(messages)
+             except Exception as e:
+                 err_text = str(e)
+                 if ("503" in err_text or "Service Unavailable" in err_text) and attempt < MAX_ATTEMPTS - 1:
+                     last_err = e
+                     continue  # move on to the next attempt
+                 raise  # a different error, or the last attempt is exhausted
+         # if for some reason the loop exits without raising
+         raise last_err or RuntimeError("Unknown error during LLM invocation")
+
+     # ---------------- Nodes ---------------------------------------------- #
+     def assistant(state: MessagesState):
+         messages = [sys_msg] + state["messages"]
+         return {"messages": [invoke_with_retry(messages)]}
+
+     # ---------------- Graph ---------------------------------------------- #
+     builder = StateGraph(MessagesState)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+
+     builder.add_edge(START, "assistant")
+     builder.add_conditional_edges("assistant", tools_condition)
+     builder.add_edge("tools", "assistant")
+
+     return builder.compile()
+
+ # --------------------------------------------------------------------------- #
+ # Quick test #
+ # --------------------------------------------------------------------------- #
  if __name__ == "__main__":
+     g = build_graph(provider="groq")
+     q = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+     msgs = [HumanMessage(content=q)]
+     res = g.invoke({"messages": msgs})
+     for m in res["messages"]:
+         m.pretty_print()