from typing import Generator, Optional

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool

from agent import RetrievalState, build_retrieval_graph
from clients import LLM, VECTOR_STORE


@tool
def populate_memory(
    content: str,
    category: str,
    topic: str,
) -> str:
    """Add content with metadata to the memory for later retrieval.

    Use this to store important information the user wants to remember.

    Args:
        content: The content to store in memory
        category: Category of the memory (e.g., 'personal', 'work', 'learning')
        topic: Specific topic of the memory
    """
    VECTOR_STORE.add_documents(
        documents=[
            Document(
                page_content=content,
                metadata={"category": category, "topic": topic},
            )
        ]
    )
    return f"Successfully stored memory about '{topic}' in category '{category}'"


@tool
def search_memory(
    query: str,
    category: Optional[str] = None,
    topic: Optional[str] = None,
) -> str:
    """Search and retrieve relevant information from memory using agentic retrieval.

    This tool uses:
    - Document relevance grading
    - Automatic query rewriting if no relevant results are found
    - Self-correction with retry logic

    Args:
        query: The search query to find relevant memories
        category: Optional category filter
        topic: Optional topic filter
    """
    try:
        initial_state: RetrievalState = {
            "original_query": query,
            "current_query": query,
            "category": category,
            "topic": topic,
            "documents": [],
            "relevant_documents": [],
            "generation": "",
            "retry_count": 0,
            "max_retries": 2,  # Allow up to 2 query rewrites
        }
        final_state = _get_retrieval_agent().invoke(initial_state)
        return final_state["generation"]
    except Exception as e:
        error_msg = f"Error in search_memory: {str(e)}"
        print(f"DEBUG: {error_msg}")
        return error_msg


# Register the tools and bind them to the chat model so it can emit tool calls.
TOOLS = [search_memory, populate_memory]
CHAT_LLM = LLM.bind_tools(TOOLS)

# Lazy initialization to avoid circular imports between this module and `agent`.
_retrieval_agent = None


def _get_retrieval_agent():
    global _retrieval_agent
    if _retrieval_agent is None:
        _retrieval_agent = build_retrieval_graph()
    return _retrieval_agent


def chat(
    message: str,
    history: list[dict],
) -> Generator[str, None, None]:
    """Run one conversational turn, streaming status updates and the final answer."""
    messages = [
        SystemMessage(
            content=(
                "Whenever the user asks you a question, you must always use the "
                "search_memory tool first to look for relevant information in your "
                "memory. If you find relevant information, use it to answer the "
                "user's question. If you don't find any relevant information, "
                "answer the question to the best of your ability."
            )
        )
    ]

    # Rebuild the conversation from the history of {"role", "content"} dicts.
    for msg in history:
        if msg["role"] == "user":
            messages.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            messages.append(AIMessage(content=msg["content"]))
    messages.append(HumanMessage(content=message))

    # Agent loop: let the model call tools until it produces a final answer,
    # capped at max_iterations to avoid infinite tool-call cycles.
    tool_map = {t.name: t for t in TOOLS}
    max_iterations = 10
    for _ in range(max_iterations):
        response = CHAT_LLM.invoke(messages)
        messages.append(response)

        # No tool calls means the model is done; yield its answer and stop.
        if not response.tool_calls:
            yield response.content if response.content else "Done!"
            return

        for tool_call in response.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            yield f"🔧 Using {tool_name}..."

            if tool_name in tool_map:
                try:
                    result = tool_map[tool_name].invoke(tool_args)
                except Exception as e:
                    result = f"Error: {str(e)}"
            else:
                result = f"Unknown tool: {tool_name}"

            # Feed the tool result back so the model can use it on the next turn.
            messages.append(
                ToolMessage(
                    content=str(result),
                    tool_call_id=tool_call["id"],
                )
            )

    yield "I processed your request but couldn't generate a final response."
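

# A minimal usage sketch, not part of the original flow. Assumptions this module
# does not guarantee: `clients.LLM` is a chat model that supports .bind_tools(),
# `clients.VECTOR_STORE` is a configured LangChain vector store, and
# `agent.build_retrieval_graph()` returns a compiled graph with .invoke().
# The (message, history) signature also matches Gradio's ChatInterface
# messages convention, so one hypothetical wiring would be:
#
#     import gradio as gr
#     gr.ChatInterface(fn=chat, type="messages").launch()
#
if __name__ == "__main__":
    # Drive the generator directly; each yielded chunk is either a tool-status
    # line (e.g. "🔧 Using search_memory...") or the final answer.
    history: list[dict] = []
    for chunk in chat("What do you remember about my work projects?", history):
        print(chunk)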