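"""Chat tools for a memory-augmented assistant.

Exposes two LangChain tools: populate_memory (stores a note with metadata in
the vector store) and search_memory (agentic retrieval via the graph built by
agent.build_retrieval_graph), plus a chat() generator that runs the
tool-calling loop and streams status updates followed by the final answer.
"""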

from typing import Generator, Optional

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool

from agent import RetrievalState, build_retrieval_graph
from clients import LLM, VECTOR_STORE


@tool
def populate_memory(
    content: str,
    category: str,
    topic: str,
) -> str:
    """Add content with metadata to the memory for later retrieval.

    Use this to store important information the user wants to remember.

    Args:
        content: The content to store in memory
        category: Category of the memory (e.g., 'personal', 'work', 'learning')
        topic: Specific topic of the memory
    """
    VECTOR_STORE.add_documents(
        documents=[
            Document(
                page_content=content, metadata={"category": category, "topic": topic}
            )
        ]
    )
    return f"Successfully stored memory about '{topic}' in category '{category}'"


@tool
def search_memory(
    query: str,
    category: Optional[str] = None,
    topic: Optional[str] = None,
) -> str:
    """Search and retrieve relevant information from memory using intelligent agentic retrieval.

    This tool uses advanced retrieval with:
    - Document relevance grading
    - Automatic query rewriting if no relevant results found
    - Self-correction with retry logic

    Args:
        query: The search query to find relevant memories
        category: Optional category filter
        topic: Optional topic filter
    """
    try:
        initial_state: RetrievalState = {
            "original_query": query,
            "current_query": query,
            "category": category,
            "topic": topic,
            "documents": [],
            "relevant_documents": [],
            "generation": "",
            "retry_count": 0,
            "max_retries": 2,  # Allow up to 2 query rewrites
        }
        final_state = _get_retrieval_agent().invoke(initial_state)
        return final_state["generation"]
    except Exception as e:
        error_msg = f"Error in search_memory: {str(e)}"
        print(f"DEBUG: {error_msg}")
        return error_msg
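
# Usage sketch (illustrative values): the optional filters are forwarded into
# the retrieval state defined above.
#
#   search_memory.invoke({"query": "When is Anna's birthday?", "category": "personal"})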


# Tools list and an LLM bound to it, so the model can emit tool calls
TOOLS = [search_memory, populate_memory]
CHAT_LLM = LLM.bind_tools(TOOLS)

# Lazy initialization to avoid circular imports
_retrieval_agent = None


def _get_retrieval_agent():
    global _retrieval_agent
    if _retrieval_agent is None:
        _retrieval_agent = build_retrieval_graph()
    return _retrieval_agent


def chat(
    message: str,
    history: list[dict],
) -> Generator[str, None, None]:
    messages = [
        SystemMessage(
            content=(
                "Whenever the user asks you a question, you must always use the "
                "search_memory tool first to look for relevant information in "
                "your memory. If you find relevant information, use it to answer "
                "the user's question. If you don't find any relevant information, "
                "answer the question to the best of your ability."
            )
        )
    ]
    # Rebuild the conversation from the role/content history.
    for msg in history:
        if msg["role"] == "user":
            messages.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            messages.append(AIMessage(content=msg["content"]))
    messages.append(HumanMessage(content=message))

    tool_map = {t.name: t for t in TOOLS}
    max_iterations = 10
    iteration = 0
    while iteration < max_iterations:
        iteration += 1
        response = CHAT_LLM.invoke(messages)
        messages.append(response)

        # No tool calls means the model has produced its final answer.
        if not response.tool_calls:
            if response.content:
                yield response.content
            else:
                yield "Done!"
            return

        # Execute each requested tool and feed the result back to the model.
        for tool_call in response.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            yield f"🔧 Using {tool_name}..."
            if tool_name in tool_map:
                try:
                    result = tool_map[tool_name].invoke(tool_args)
                except Exception as e:
                    result = f"Error: {str(e)}"
            else:
                result = f"Unknown tool: {tool_name}"
            messages.append(
                ToolMessage(
                    content=str(result),
                    tool_call_id=tool_call["id"],
                )
            )

    yield "I processed your request but couldn't generate a final response."
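

# Minimal sketch of driving chat() outside a UI. A plain console loop is
# assumed here for illustration; the real entry point (for example a Gradio
# ChatInterface, which uses the same role/content history format) is not
# part of this module.
if __name__ == "__main__":
    history: list[dict] = []
    while True:
        user_input = input("you> ")
        if user_input.strip().lower() in {"exit", "quit"}:
            break
        reply = ""
        for chunk in chat(user_input, history):
            reply = chunk  # each yield supersedes the last: tool status, then the answer
            print(f"bot> {chunk}")
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": reply})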