raaec commited on
Commit
5d520be
·
verified ·
1 Parent(s): 7ee22fe

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +380 -270
agent.py CHANGED
@@ -1,13 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import logging
3
- from typing import Tuple, List, Dict, Any, Optional
4
 
5
- import gradio as gr
6
- import requests
7
- import pandas as pd
8
- from langchain_core.messages import HumanMessage
9
-
10
- from agent import build_graph
 
 
 
 
 
 
 
11
 
12
  # Configure logging
13
  logging.basicConfig(
@@ -17,339 +42,424 @@ logging.basicConfig(
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
- # --- Constants ---
21
- DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
22
- REQUEST_TIMEOUT = 60 # seconds
23
 
24
 
25
- class BasicAgent:
26
- """
27
- A LangGraph-based agent that answers questions using a graph-based workflow.
28
-
29
- This agent takes natural language questions, processes them through a
30
- predefined graph workflow, and returns the answer.
31
-
32
- Attributes:
33
- graph: The LangGraph workflow that processes the questions
34
- """
35
 
36
- def __init__(self):
37
- """Initialize the agent with a graph-based workflow."""
38
- logger.info("Initializing BasicAgent")
39
- self.graph = build_graph()
40
-
41
- def __call__(self, question: str) -> str:
42
- """
43
- Process a question and return an answer.
44
-
45
- Args:
46
- question: The natural language question to process
47
-
48
- Returns:
49
- The agent's answer to the question
50
- """
51
- logger.info(f"Processing question (first 50 chars): {question[:50]}...")
52
-
53
- # Wrap the question in a HumanMessage from langchain_core
54
- messages = [HumanMessage(content=question)]
55
-
56
- # Process through the graph
57
- messages = self.graph.invoke({"messages": messages})
58
-
59
- # Extract and clean the answer
60
- answer = messages['messages'][-1].content
61
 
62
- # Remove the "FINAL ANSWER:" prefix if present
63
- return answer[14:] if answer.startswith("FINAL ANSWER:") else answer
 
 
64
 
65
 
66
- def fetch_questions(api_url: str) -> List[Dict[str, Any]]:
67
- """
68
- Fetch questions from the evaluation server.
69
 
70
  Args:
71
- api_url: Base URL of the evaluation API
 
72
 
73
  Returns:
74
- List of question data dictionaries
75
-
76
- Raises:
77
- requests.exceptions.RequestException: If there's an error fetching questions
78
  """
79
- questions_url = f"{api_url}/questions"
80
- logger.info(f"Fetching questions from: {questions_url}")
81
-
82
- response = requests.get(questions_url, timeout=REQUEST_TIMEOUT)
83
- response.raise_for_status()
84
-
85
- questions_data = response.json()
86
- if not questions_data:
87
- raise ValueError("Fetched questions list is empty or invalid format")
88
-
89
- logger.info(f"Successfully fetched {len(questions_data)} questions")
90
- return questions_data
91
 
92
 
93
- def run_agent_on_questions(
94
- agent: BasicAgent,
95
- questions_data: List[Dict[str, Any]]
96
- ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
97
- """
98
- Run the agent on a list of questions.
99
 
100
  Args:
101
- agent: The agent to run
102
- questions_data: List of question data dictionaries
103
 
104
  Returns:
105
- Tuple of (answers_payload, results_log)
106
  """
107
- results_log = []
108
- answers_payload = []
109
-
110
- logger.info(f"Running agent on {len(questions_data)} questions...")
 
 
111
 
112
- for item in questions_data:
113
- task_id = item.get("task_id")
114
- question_text = item.get("question")
115
 
116
- if not task_id or question_text is None:
117
- logger.warning(f"Skipping item with missing task_id or question: {item}")
118
- continue
119
-
120
- try:
121
- submitted_answer = agent(question_text)
122
-
123
- # Prepare answer for submission
124
- answers_payload.append({
125
- "task_id": task_id,
126
- "submitted_answer": submitted_answer
127
- })
128
-
129
- # Log result for display
130
- results_log.append({
131
- "Task ID": task_id,
132
- "Question": question_text,
133
- "Submitted Answer": submitted_answer
134
- })
135
-
136
- except Exception as e:
137
- logger.error(f"Error running agent on task {task_id}: {e}", exc_info=True)
138
-
139
- # Log error in results
140
- results_log.append({
141
- "Task ID": task_id,
142
- "Question": question_text,
143
- "Submitted Answer": f"AGENT ERROR: {e}"
144
- })
145
-
146
- return answers_payload, results_log
147
-
148
-
149
- def submit_answers(
150
- api_url: str,
151
- username: str,
152
- agent_code: str,
153
- answers_payload: List[Dict[str, Any]]
154
- ) -> Dict[str, Any]:
155
  """
156
- Submit answers to the evaluation server.
 
 
 
 
 
 
 
157
 
158
  Args:
159
- api_url: Base URL of the evaluation API
160
- username: Hugging Face username
161
- agent_code: URL to the agent code repository
162
- answers_payload: List of answer dictionaries
163
 
164
  Returns:
165
- Response data from the server
166
 
167
  Raises:
168
- requests.exceptions.RequestException: If there's an error during submission
169
  """
170
- submit_url = f"{api_url}/submit"
171
-
172
- # Prepare submission data
173
- submission_data = {
174
- "username": username.strip(),
175
- "agent_code": agent_code,
176
- "answers": answers_payload
177
- }
178
-
179
- logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
180
-
181
- # Submit answers
182
- response = requests.post(submit_url, json=submission_data, timeout=REQUEST_TIMEOUT)
183
- response.raise_for_status()
184
-
185
- result_data = response.json()
186
- logger.info("Submission successful")
187
-
188
- return result_data
189
 
190
 
191
- def run_and_submit_all(profile: Optional[gr.OAuthProfile] = None) -> Tuple[str, pd.DataFrame]:
192
- """
193
- Fetches all questions, runs the BasicAgent on them, submits all answers,
194
- and displays the results.
 
 
 
195
 
196
  Args:
197
- profile: Gradio OAuth profile containing user information
198
 
199
  Returns:
200
- Tuple of (status_message, results_dataframe)
201
  """
202
- # Check if user is logged in
203
- if not profile:
204
- logger.warning("User not logged in")
205
- return "Please Login to Hugging Face with the button.", None
206
-
207
- username = profile.username
208
- logger.info(f"User logged in: {username}")
209
-
210
- # Get the space ID for linking to code
211
- space_id = os.getenv("SPACE_ID")
212
- api_url = DEFAULT_API_URL
213
- agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
214
 
215
  try:
216
- # 1. Instantiate Agent
217
- agent = BasicAgent()
218
 
219
- # 2. Fetch Questions
220
- questions_data = fetch_questions(api_url)
221
 
222
- # 3. Run Agent on Questions
223
- answers_payload, results_log = run_agent_on_questions(agent, questions_data)
 
 
 
 
224
 
225
- if not answers_payload:
226
- logger.warning("Agent did not produce any answers to submit")
227
- return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
228
 
229
- # 4. Submit Answers
230
- result_data = submit_answers(api_url, username, agent_code, answers_payload)
 
 
 
 
 
 
 
 
 
231
 
232
- # 5. Format and Return Results
233
- final_status = (
234
- f"Submission Successful!\n"
235
- f"User: {result_data.get('username')}\n"
236
- f"Overall Score: {result_data.get('score', 'N/A')}% "
237
- f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
238
- f"Message: {result_data.get('message', 'No message received.')}"
 
 
 
 
 
 
 
 
 
239
  )
240
 
241
- results_df = pd.DataFrame(results_log)
242
- return final_status, results_df
243
 
244
- except requests.exceptions.HTTPError as e:
245
- # Handle HTTP errors with detailed error information
246
- error_detail = f"Server responded with status {e.response.status_code}."
247
- try:
248
- error_json = e.response.json()
249
- error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
250
- except requests.exceptions.JSONDecodeError:
251
- error_detail += f" Response: {e.response.text[:500]}"
252
-
253
- status_message = f"Submission Failed: {error_detail}"
254
- logger.error(status_message)
255
 
256
- results_df = pd.DataFrame(results_log if 'results_log' in locals() else [])
257
- return status_message, results_df
 
 
 
 
 
258
 
259
- except requests.exceptions.Timeout:
260
- status_message = f"Submission Failed: The request timed out after {REQUEST_TIMEOUT} seconds"
261
- logger.error(status_message)
262
 
263
- results_df = pd.DataFrame(results_log if 'results_log' in locals() else [])
264
- return status_message, results_df
 
 
 
 
265
 
266
- except Exception as e:
267
- status_message = f"An unexpected error occurred: {str(e)}"
268
- logger.error(status_message, exc_info=True)
269
 
270
- results_df = pd.DataFrame(results_log if 'results_log' in locals() else [])
271
- return status_message, results_df
 
272
 
273
 
274
- def create_gradio_interface() -> gr.Blocks:
 
 
 
 
275
  """
276
- Create and configure the Gradio interface.
277
 
278
  Returns:
279
- Configured Gradio Blocks interface
 
 
 
280
  """
281
- with gr.Blocks() as demo:
282
- gr.Markdown("# Agent Evaluation Runner")
283
- gr.Markdown(
284
- """
285
- ## Instructions
286
-
287
- 1. **Clone this space** and modify the code to define your agent's logic, tools, and dependencies
288
- 2. **Log in to your Hugging Face account** using the button below (required for submission)
289
- 3. **Run Evaluation** to fetch questions, run your agent, and submit answers
290
-
291
- ## Important Notes
292
-
293
- - The evaluation process may take several minutes to complete
294
- - This agent framework is intentionally minimal to allow for your own improvements
295
- - Consider implementing caching or async processing for better performance
296
- """
297
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
 
299
- gr.LoginButton()
300
 
301
- run_button = gr.Button("Run Evaluation & Submit All Answers", variant="primary")
 
 
302
 
303
- status_output = gr.Textbox(
304
- label="Run Status / Submission Result",
305
- lines=5,
306
- interactive=False
307
- )
 
308
 
309
- results_table = gr.DataFrame(
310
- label="Questions and Agent Answers",
311
- wrap=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  )
 
 
 
 
313
 
314
- run_button.click(
315
- fn=run_and_submit_all,
316
- outputs=[status_output, results_table]
317
- )
318
-
319
- return demo
320
 
 
 
 
321
 
322
- def check_environment() -> None:
323
  """
324
- Check and log environment variables at startup.
 
 
 
 
 
 
 
 
 
 
 
325
  """
326
- logger.info("-" * 30 + " App Starting " + "-" * 30)
327
 
328
- # Check for SPACE_HOST
329
- space_host = os.getenv("SPACE_HOST")
330
- if space_host:
331
- logger.info(f"✅ SPACE_HOST found: {space_host}")
332
- logger.info(f" Runtime URL should be: https://{space_host}.hf.space")
333
- else:
334
- logger.info("ℹ️ SPACE_HOST environment variable not found (running locally?).")
335
-
336
- # Check for SPACE_ID
337
- space_id = os.getenv("SPACE_ID")
338
- if space_id:
339
- logger.info(f"✅ SPACE_ID found: {space_id}")
340
- logger.info(f" Repo URL: https://huggingface.co/spaces/{space_id}")
341
- logger.info(f" Repo Tree URL: https://huggingface.co/spaces/{space_id}/tree/main")
342
- else:
343
- logger.info("ℹ️ SPACE_ID environment variable not found (running locally?).")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
344
 
345
- logger.info("-" * (60 + len(" App Starting ")) + "\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
346
 
 
 
 
347
 
348
  if __name__ == "__main__":
349
- # Check environment at startup
350
- check_environment()
 
 
 
 
 
 
 
 
351
 
352
- # Create and launch Gradio interface
353
- logger.info("Launching Gradio Interface for Agent Evaluation...")
354
- demo = create_gradio_interface()
355
- demo.launch(debug=True, share=False)
 
1
+ """
2
+ LLM Agent Graph Implementation
3
+ =============================
4
+ This module defines a graph-based LLM agent workflow with various tools and retrieval capabilities.
5
+
6
+ The agent can:
7
+ - Perform mathematical operations
8
+ - Search Wikipedia, web, and arXiv
9
+ - Retrieve similar questions from a vector database
10
+ - Process user queries using different LLM providers
11
+
12
+ Components:
13
+ - Tool definitions: Math operations, search tools
14
+ - Vector database retrieval
15
+ - Graph construction with different LLM options
16
+ - Workflow management with LangGraph
17
+ """
18
+
19
  import os
20
  import logging
21
+ from typing import Dict, List, Union, Optional, Any, Callable
22
 
23
+ from dotenv import load_dotenv
24
+ from langgraph.graph import START, StateGraph, MessagesState
25
+ from langgraph.prebuilt import tools_condition, ToolNode
26
+ from langchain_google_genai import ChatGoogleGenerativeAI
27
+ from langchain_groq import ChatGroq
28
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
29
+ from langchain_community.tools.tavily_search import TavilySearchResults
30
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
31
+ from langchain_community.vectorstores import SupabaseVectorStore
32
+ from langchain_core.messages import SystemMessage, HumanMessage
33
+ from langchain_core.tools import tool
34
+ from langchain.tools.retriever import create_retriever_tool
35
+ from supabase.client import Client, create_client
36
 
37
  # Configure logging
38
  logging.basicConfig(
 
42
  )
43
  logger = logging.getLogger(__name__)
44
 
45
+ # Load environment variables
46
+ load_dotenv()
 
47
 
48
 
49
+ # ===================
50
+ # Math Operation Tools
51
+ # ===================
52
+
53
@tool
def multiply(a: int, b: int) -> int:
    """Return the product of two integers.

    Args:
        a: First factor
        b: Second factor

    Returns:
        The product of a and b
    """
    product = a * b
    return product
65
 
66
 
67
@tool
def add(a: int, b: int) -> int:
    """Return the sum of two integers.

    Args:
        a: First addend
        b: Second addend

    Returns:
        The sum of a and b
    """
    total = a + b
    return total
 
 
 
 
 
 
 
 
 
 
 
79
 
80
 
81
@tool
def subtract(a: int, b: int) -> int:
    """Return the difference of two integers.

    Args:
        a: Integer to subtract from (minuend)
        b: Integer to subtract (subtrahend)

    Returns:
        The difference (a - b)
    """
    difference = a - b
    return difference
93
+
94
+
95
@tool
def divide(a: int, b: int) -> float:
    """Return the quotient of two integers as a float.

    Args:
        a: Numerator (dividend)
        b: Denominator (divisor)

    Returns:
        The quotient (a / b) as a float

    Raises:
        ValueError: If b is zero (division by zero)
    """
    # Guard clause: refuse a zero divisor up front.
    if not b:
        raise ValueError("Cannot divide by zero.")
    return a / b
112
+
113
+
114
@tool
def modulus(a: int, b: int) -> int:
    """Return the remainder of integer division.

    Args:
        a: Dividend
        b: Divisor

    Returns:
        The remainder of a divided by b

    Raises:
        ValueError: If b is zero (modulo by zero)
    """
    # Guard clause: modulo by zero is undefined.
    if not b:
        raise ValueError("Cannot calculate modulus with divisor zero.")
    return a % b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
 
133
+ # ===================
134
+ # Search Tools
135
+ # ===================
136
+
137
@tool
def wiki_search(query: str) -> Dict[str, str]:
    """Search Wikipedia for a query and return formatted results.

    Args:
        query: The search term to look up on Wikipedia

    Returns:
        Dictionary with a single "wiki_results" key containing either the
        formatted documents, a "no results" notice, or an error message.
    """
    logger.info(f"Searching Wikipedia for: {query}")

    try:
        docs = WikipediaLoader(query=query, load_max_docs=2).load()

        if not docs:
            return {"wiki_results": "No Wikipedia results found for this query."}

        # Wrap each hit in a pseudo-XML envelope so the LLM can attribute sources.
        sections = []
        for doc in docs:
            source = doc.metadata["source"]
            page = doc.metadata.get("page", "")
            sections.append(
                f'<Document source="{source}" page="{page}"/>\n{doc.page_content}\n</Document>'
            )

        logger.info(f"Found {len(docs)} Wikipedia results")
        return {"wiki_results": "\n\n---\n\n".join(sections)}

    except Exception as e:
        # Best-effort tool: report the failure as a result rather than crashing the graph.
        logger.error(f"Error searching Wikipedia: {e}", exc_info=True)
        return {"wiki_results": f"Error searching Wikipedia: {str(e)}"}
168
+
169
+
170
@tool
def web_search(query: str) -> Dict[str, str]:
    """Search the web using Tavily for a query and return formatted results.

    Args:
        query: The search term to look up on the web

    Returns:
        Dictionary with a single "web_results" key containing either the
        formatted results, a "no results" notice, or an error message.
    """
    logger.info(f"Searching the web for: {query}")

    try:
        # BUG FIX: BaseTool.invoke takes the tool input as its first positional
        # argument ("input"); the previous invoke(query=query) call raised
        # TypeError on every invocation and always fell into the error branch.
        search_results = TavilySearchResults(max_results=3).invoke(query)

        if not search_results:
            return {"web_results": "No web results found for this query."}

        # Wrap each hit in a pseudo-XML envelope so the LLM can attribute sources.
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{result["url"]}">\n{result["content"]}\n</Document>'
                for result in search_results
            ]
        )

        logger.info(f"Found {len(search_results)} web search results")
        return {"web_results": formatted_search_docs}

    except Exception as e:
        # Best-effort tool: report the failure as a result rather than crashing the graph.
        logger.error(f"Error searching the web: {e}", exc_info=True)
        return {"web_results": f"Error searching the web: {str(e)}"}
201
+
202
+
203
@tool
def arxiv_search(query: str) -> Dict[str, str]:
    """Search arXiv for academic papers and return formatted results.

    Args:
        query: The search term to look up on arXiv

    Returns:
        Dictionary with a single "arxiv_results" key containing either the
        formatted paper excerpts, a "no results" notice, or an error message.
    """
    logger.info(f"Searching arXiv for: {query}")

    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()

        if not search_docs:
            return {"arxiv_results": "No arXiv results found for this query."}

        # BUG FIX: ArxivLoader document metadata exposes "Published", "Title",
        # "Authors" and "Summary" keys — there is no "entry_id" key, so the
        # previous doc.metadata["entry_id"] raised KeyError on every successful
        # search and the tool always returned the error branch. Use .get() with
        # the documented keys (falling back to "Published" for attribution).
        formatted_search_docs = "\n\n---\n\n".join(
            [
                f'<Document source="{doc.metadata.get("entry_id", doc.metadata.get("Published", ""))}" title="{doc.metadata.get("Title", "")}">\n{doc.page_content[:1000]}\n</Document>'
                for doc in search_docs
            ]
        )

        logger.info(f"Found {len(search_docs)} arXiv results")
        return {"arxiv_results": formatted_search_docs}

    except Exception as e:
        # Best-effort tool: report the failure as a result rather than crashing the graph.
        logger.error(f"Error searching arXiv: {e}", exc_info=True)
        return {"arxiv_results": f"Error searching arXiv: {str(e)}"}
234
 
235
 
236
+ # ===================
237
+ # Vector Store Setup
238
+ # ===================
239
+
240
def setup_vector_store() -> SupabaseVectorStore:
    """
    Set up and configure the Supabase vector store for question retrieval.

    Reads SUPABASE_URL and SUPABASE_SERVICE_KEY from the environment,
    builds a HuggingFace embeddings model and a Supabase client, and wires
    them into a SupabaseVectorStore over the "documents" table.

    Returns:
        Configured SupabaseVectorStore instance

    Raises:
        ValueError: If required environment variables are missing
    """
    url = os.environ.get("SUPABASE_URL")
    service_key = os.environ.get("SUPABASE_SERVICE_KEY")

    # Fail fast with a clear message when credentials are absent.
    if not (url and service_key):
        raise ValueError(
            "Missing required environment variables: SUPABASE_URL and/or SUPABASE_SERVICE_KEY"
        )

    # Embeddings model used to embed queries for similarity search.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    # Supabase client backing the store.
    client: Client = create_client(url, service_key)

    store = SupabaseVectorStore(
        client=client,
        embedding=embeddings,
        table_name="documents",
        query_name="match_documents_langchain",
    )

    logger.info("Vector store initialized successfully")
    return store
275
 
 
276
 
277
+ # ===================
278
+ # LLM Provider Setup
279
+ # ===================
280
 
281
def get_llm(provider: str = "groq"):
    """
    Initialize and return an LLM based on the specified provider.

    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')

    Returns:
        Initialized LLM instance

    Raises:
        ValueError: If an invalid provider is specified
    """
    # Guard-style dispatch: each supported provider returns immediately.
    if provider == "google":
        logger.info("Using Google Gemini as LLM provider")
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)

    if provider == "groq":
        logger.info("Using Groq as LLM provider with qwen-qwq-32b model")
        return ChatGroq(model="qwen-qwq-32b", temperature=0)

    if provider == "huggingface":
        logger.info("Using Hugging Face as LLM provider with llama-2-7b-chat-hf model")
        endpoint = HuggingFaceEndpoint(
            url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
            temperature=0,
        )
        return ChatHuggingFace(llm=endpoint)

    available_providers = ['google', 'groq', 'huggingface']
    raise ValueError(f"Invalid provider: '{provider}'. Choose from {available_providers}")
314
 
 
 
 
 
 
 
315
 
316
+ # ===================
317
+ # Graph Building
318
+ # ===================
319
 
320
def build_graph(provider: str = "groq"):
    """
    Build and compile the agent workflow graph.

    This function creates a LangGraph workflow that includes:
    - A retriever node to find similar questions
    - An assistant node that uses an LLM to generate responses
    - A tools node for executing various tools

    Args:
        provider: The LLM provider to use ('google', 'groq', or 'huggingface')

    Returns:
        Compiled StateGraph ready for execution
    """
    logger.info(f"Building agent graph with {provider} as LLM provider")

    # Load system prompt (fall back to a built-in default when the file is absent).
    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            system_prompt = f.read()
        logger.info("Loaded system prompt from file")
    except FileNotFoundError:
        system_prompt = (
            "You are a helpful AI assistant that answers questions accurately and concisely.\n"
            "Use the available tools when appropriate to find information or perform calculations.\n"
            "Always cite your sources when you use search tools."
        )
        logger.warning("system_prompt.txt not found, using default system prompt")

    # Initialize system message
    sys_msg = SystemMessage(content=system_prompt)

    # BUG FIX: vector_store must be bound even when setup fails. Previously the
    # except branch only set retriever_tool = None, leaving vector_store unbound,
    # so the retriever node's `if vector_store:` raised NameError at runtime
    # instead of degrading gracefully to system-message-only behavior.
    vector_store = None
    retriever_tool = None
    try:
        vector_store = setup_vector_store()
        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(),
            name="Question Search",
            description="A tool to retrieve similar questions from a vector store.",
        )
        logger.info("Vector store retrieval tool initialized")
    except Exception as e:
        logger.error(f"Failed to set up vector store: {e}", exc_info=True)

    # Define available tools
    tools = [
        multiply,
        add,
        subtract,
        divide,
        modulus,
        wiki_search,
        web_search,
        arxiv_search,
    ]

    # Add retriever tool only when the vector store came up successfully.
    if retriever_tool:
        tools.append(retriever_tool)

    # Get LLM and bind tools
    llm = get_llm(provider)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState) -> Dict[str, List]:
        """Assistant node: run the tool-bound LLM on the current messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState) -> Dict[str, List]:
        """Retriever node: prepend the system message and, when the vector
        store is available, append a similar question/answer example."""
        if vector_store:
            try:
                similar_questions = vector_store.similarity_search(state["messages"][0].content)
                if similar_questions:
                    example_msg = HumanMessage(
                        content=f"Here I provide a similar question and answer for reference: \n\n{similar_questions[0].page_content}",
                    )
                    return {"messages": [sys_msg] + state["messages"] + [example_msg]}
            except Exception as e:
                logger.error(f"Error in retriever node: {e}", exc_info=True)

        # If vector_store is unavailable or retrieval fails, just add system message
        return {"messages": [sys_msg] + state["messages"]}

    # Build graph
    builder = StateGraph(MessagesState)

    # Add nodes
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    # Add edges: retrieval first, then assistant; the assistant loops through
    # the tools node whenever the LLM emits tool calls (tools_condition).
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")

    # Compile graph
    compiled_graph = builder.compile()
    logger.info("Agent graph compiled successfully")

    return compiled_graph
444
+
445
 
446
+ # ===================
447
+ # Testing
448
+ # ===================
449
 
450
if __name__ == "__main__":
    # Smoke test: build the graph and run a single hard-coded question.
    test_question = "When was the wiki entry of Boethius on De Philosophiae Consolatione first added?"

    logger.info("Starting test run")
    graph = build_graph(provider="groq")

    logger.info(f"Testing with question: {test_question}")
    result_messages = graph.invoke({"messages": [HumanMessage(content=test_question)]})

    # Pretty-print the full message trace produced by the run.
    logger.info("Test completed, printing messages:")
    for message in result_messages["messages"]:
        message.pretty_print()