"""Retrieval Agent - Handles information gathering and search tasks""" import os import requests from typing import Dict, Any, List from langchain_core.messages import SystemMessage, HumanMessage, AIMessage, ToolMessage from langchain_core.tools import tool from langchain_groq import ChatGroq from langchain_community.tools.tavily_search import TavilySearchResults from langchain_community.document_loaders import WikipediaLoader, ArxivLoader from langchain.tools.retriever import create_retriever_tool from src.memory import memory_manager from src.tracing import get_langfuse_callback_handler # Tool definitions (same as original) @tool def wiki_search(input: str) -> str: """Search Wikipedia for a query and return maximum 2 results. Args: input: The search query.""" try: search_docs = WikipediaLoader(query=input, load_max_docs=2).load() if not search_docs: return "No Wikipedia results found for the query." formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.page_content}\n' for doc in search_docs ]) return formatted_search_docs except Exception as e: print(f"Error in wiki_search: {e}") return f"Error searching Wikipedia: {e}" @tool def web_search(input: str) -> str: """Search Tavily for a query and return maximum 3 results. Args: input: The search query.""" try: search_docs = TavilySearchResults(max_results=3).invoke(input) if not search_docs: return "No web search results found for the query." formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.get("content", "No content")}\n' for doc in search_docs ]) return formatted_search_docs except Exception as e: print(f"Error in web_search: {e}") return f"Error searching web: {e}" @tool def arvix_search(input: str) -> str: """Search Arxiv for a query and return maximum 3 results. Args: input: The search query.""" try: search_docs = ArxivLoader(query=input, load_max_docs=3).load() if not search_docs: return "No Arxiv results found for the query." formatted_search_docs = "\n\n---\n\n".join( [ f'\n{doc.page_content[:1000]}\n' for doc in search_docs ]) return formatted_search_docs except Exception as e: print(f"Error in arvix_search: {e}") return f"Error searching Arxiv: {e}" def load_retrieval_prompt() -> str: """Load the retrieval prompt from file""" try: with open("./prompts/retrieval_prompt.txt", "r", encoding="utf-8") as f: return f.read().strip() except FileNotFoundError: return """You are a specialized retrieval agent. Use available tools to search for information and provide comprehensive answers.""" def get_retrieval_tools() -> List: """Get list of tools available to the retrieval agent""" tools = [wiki_search, web_search, arvix_search] # Add vector store retrieval tool if available if memory_manager.vector_store: try: retrieval_tool = create_retriever_tool( retriever=memory_manager.vector_store.as_retriever(), name="question_search", description="A tool to retrieve similar questions from a vector store.", ) tools.append(retrieval_tool) except Exception as e: print(f"Could not create retrieval tool: {e}") return tools def execute_tool_calls(tool_calls: list, tools: list) -> list: """Execute tool calls and return results""" tool_messages = [] # Create a mapping of tool names to tool functions tool_map = {tool.name: tool for tool in tools} for tool_call in tool_calls: tool_name = tool_call['name'] tool_args = tool_call['args'] tool_call_id = tool_call['id'] if tool_name in tool_map: try: print(f"Retrieval Agent: Executing {tool_name} with args: {tool_args}") result = tool_map[tool_name].invoke(tool_args) tool_messages.append( ToolMessage( content=str(result), tool_call_id=tool_call_id ) ) except Exception as e: print(f"Error executing {tool_name}: {e}") tool_messages.append( ToolMessage( content=f"Error executing {tool_name}: {e}", tool_call_id=tool_call_id ) ) else: tool_messages.append( ToolMessage( content=f"Unknown tool: {tool_name}", tool_call_id=tool_call_id ) ) return tool_messages def fetch_attachment_if_needed(query: str) -> str: """Fetch attachment content if the query matches a known task""" try: DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space" resp = requests.get(f"{DEFAULT_API_URL}/questions", timeout=30) resp.raise_for_status() questions = resp.json() for q in questions: if str(q.get("question")).strip() == str(query).strip(): task_id = str(q.get("task_id")) print(f"Retrieval Agent: Downloading attachment for task {task_id}") file_resp = requests.get(f"{DEFAULT_API_URL}/files/{task_id}", timeout=60) if file_resp.status_code == 200 and file_resp.content: try: file_text = file_resp.content.decode("utf-8", errors="replace") except Exception: file_text = "(binary or non-UTF8 file omitted)" MAX_CHARS = 8000 if len(file_text) > MAX_CHARS: file_text = file_text[:MAX_CHARS] + "\n… (truncated)" return f"Attached file content for task {task_id}:\n```python\n{file_text}\n```" else: print(f"No attachment for task {task_id}") return "" return "" except Exception as e: print(f"Error fetching attachment: {e}") return "" def retrieval_agent(state: Dict[str, Any]) -> Dict[str, Any]: """ Retrieval agent that handles information gathering tasks """ print("Retrieval Agent: Processing information retrieval request") try: # Get retrieval prompt retrieval_prompt = load_retrieval_prompt() # Initialize LLM with tools llm = ChatGroq(model="qwen-qwq-32b", temperature=0.3) tools = get_retrieval_tools() llm_with_tools = llm.bind_tools(tools) # Get callback handler for tracing callback_handler = get_langfuse_callback_handler() callbacks = [callback_handler] if callback_handler else [] # Build messages messages = state.get("messages", []) # Add retrieval system prompt retrieval_messages = [SystemMessage(content=retrieval_prompt)] # Get user query for context and attachment fetching user_query = None for msg in reversed(messages): if msg.type == "human": user_query = msg.content break # Check for similar questions in memory if user_query: similar_qa = memory_manager.get_similar_qa(user_query) if similar_qa: context_msg = HumanMessage( content=f"Here is a similar question and answer for reference:\n\n{similar_qa}" ) retrieval_messages.append(context_msg) # Fetch attachment if needed attachment_content = fetch_attachment_if_needed(user_query) if attachment_content: attachment_msg = HumanMessage(content=attachment_content) retrieval_messages.append(attachment_msg) # Add original messages (excluding system messages to avoid duplicates) for msg in messages: if msg.type != "system": retrieval_messages.append(msg) # Get initial response from LLM and iterate tool calls if necessary response = llm_with_tools.invoke(retrieval_messages, config={"callbacks": callbacks}) max_tool_iterations = 3 # safeguard to prevent infinite loops iteration = 0 while response.tool_calls and iteration < max_tool_iterations: iteration += 1 print(f"Retrieval Agent: LLM requested {len(response.tool_calls)} tool calls (iteration {iteration})") # Execute the tool calls tool_messages = execute_tool_calls(response.tool_calls, tools) # Append the LLM response and tool results to the conversation retrieval_messages.extend([response] + tool_messages) # Ask the model again with the new information response = llm_with_tools.invoke(retrieval_messages, config={"callbacks": callbacks}) # After iterating (or if no tool calls), we have our final response retrieval_messages.append(response) return { **state, "messages": retrieval_messages, "agent_response": response, "current_step": "verification" } except Exception as e: print(f"Retrieval Agent Error: {e}") error_response = AIMessage(content=f"I encountered an error while processing your request: {e}") return { **state, "messages": state.get("messages", []) + [error_response], "agent_response": error_response, "current_step": "verification" }