File size: 4,428 Bytes
0a25329
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from typing import Generator, Optional

from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, SystemMessage
from langchain_core.tools import tool

from agent import RetrievalState, build_retrieval_graph
from clients import LLM, VECTOR_STORE


@tool
def populate_memory(
    content: str,
    category: str,
    topic: str,
) -> str:
    """Add content with metadata to the memory for later retrieval. Use this to store important information the user wants to remember.

    Args:
        content: The content to store in memory
        category: Category of the memory (e.g., 'personal', 'work', 'learning')
        topic: Specific topic of the memory
    """
    # Wrap the text and its classification metadata into a single document.
    memory_doc = Document(
        page_content=content,
        metadata={"category": category, "topic": topic},
    )
    # Persist into the shared vector store so search_memory can find it later.
    VECTOR_STORE.add_documents(documents=[memory_doc])
    return f"Successfully stored memory about '{topic}' in category '{category}'"


@tool
def search_memory(
    query: str,
    category: Optional[str] = None,
    topic: Optional[str] = None,
) -> str:
    """Search and retrieve relevant information from memory using intelligent agentic retrieval.

    This tool uses advanced retrieval with:
    - Document relevance grading
    - Automatic query rewriting if no relevant results found
    - Self-correction with retry logic

    Args:
        query: The search query to find relevant memories
        category: Optional category filter
        topic: Optional topic filter
    """
    try:
        # Seed the retrieval graph's state; the graph may rewrite
        # current_query while original_query stays fixed.
        state: RetrievalState = {
            "original_query": query,
            "current_query": query,
            "category": category,
            "topic": topic,
            "documents": [],
            "relevant_documents": [],
            "generation": "",
            "retry_count": 0,
            "max_retries": 2,  # Allow up to 2 query rewrites
        }
        # Run the graph to completion and surface only its final answer text.
        return _get_retrieval_agent().invoke(state)["generation"]
    except Exception as e:
        # Best-effort: report the failure as the tool result instead of
        # crashing the calling agent loop.
        error_msg = f"Error in search_memory: {str(e)}"
        print(f"DEBUG: {error_msg}")
        return error_msg


# Create tools list and bound LLM
TOOLS = [search_memory, populate_memory]
# LLM with the tool schemas attached so it can emit structured tool calls.
CHAT_LLM = LLM.bind_tools(TOOLS)

# Lazy initialization to avoid circular imports
_retrieval_agent = None  # built on first use by _get_retrieval_agent()


def _get_retrieval_agent():
    """Return the retrieval graph, constructing it on first access (lazy singleton)."""
    global _retrieval_agent
    # Fast path: already built on a previous call.
    if _retrieval_agent is not None:
        return _retrieval_agent
    _retrieval_agent = build_retrieval_graph()
    return _retrieval_agent


def chat(
    message: str,
    history: list[dict],
) -> Generator[str, None, None]:
    """Run one chat turn through a tool-calling agent loop, yielding progress and the reply.

    Args:
        message: The new user message for this turn.
        history: Prior turns as dicts with "role" ("user"/"assistant") and "content".

    Yields:
        A progress string for each tool invocation, then the model's final text
        answer; a fallback message if the tool loop hits its iteration cap.
    """
    messages = [
        SystemMessage(content="Whenever the user asks you a question, you must always use the search_memory tool first to look for relevant information in your memory. If you find relevant information, use it to answer the user's question. if you don't find any relevant information, answer the question to the best of your ability.")
    ]
    # Replay prior turns so the model has conversational context.
    for msg in history:
        if msg["role"] == "user":
            messages.append(HumanMessage(content=msg["content"]))
        elif msg["role"] == "assistant":
            messages.append(AIMessage(content=msg["content"]))

    messages.append(HumanMessage(content=message))

    # The dispatch table is loop-invariant — build it once, not every iteration.
    tool_map = {t.name: t for t in TOOLS}

    # Cap iterations so a model that keeps emitting tool calls can't loop forever.
    max_iterations = 10
    iteration = 0

    while iteration < max_iterations:
        iteration += 1

        response = CHAT_LLM.invoke(messages)
        messages.append(response)

        # No tool calls means the model has produced its final answer.
        if not response.tool_calls:
            if response.content:
                yield response.content
            else:
                yield "Done!"
            return

        for tool_call in response.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]

            yield f"🔧 Using {tool_name}..."

            if tool_name in tool_map:
                try:
                    result = tool_map[tool_name].invoke(tool_args)
                except Exception as e:
                    # Feed the failure back to the model as the tool result
                    # instead of aborting the whole turn.
                    result = f"Error: {str(e)}"
            else:
                result = f"Unknown tool: {tool_name}"

            # ToolMessage must echo the originating call id so the model can
            # pair results with its requests.
            messages.append(
                ToolMessage(
                    content=str(result),
                    tool_call_id=tool_call["id"],
                )
            )

    yield "I processed your request but couldn't generate a final response."