Kishor Ramanan
commited on
Commit
·
0a25329
1
Parent(s):
6e68a92
Base
Browse files- .python-version +1 -0
- README.md +44 -1
- agent.py +210 -0
- app.py +20 -4
- clients.py +52 -0
- config.py +11 -0
- pyproject.toml +15 -0
- utility.py +144 -0
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
README.md
CHANGED
|
@@ -11,4 +11,47 @@ license: mit
|
|
| 11 |
short_description: Storing Memories and Agentic Retrieval with MCP
|
| 12 |
---
|
| 13 |
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
short_description: Storing Memories and Agentic Retrieval with MCP
|
| 12 |
---
|
| 13 |
|
| 14 |
+
# Central Memory Agent
|
| 15 |
+
|
| 16 |
+
Central Memory Agent is a Gradio-based chatbot application designed to store and retrieve information. It provides a user-friendly interface and exposes tools as MCP (Model Context Protocol) endpoints for seamless integration with MCP clients.
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Features
|
| 21 |
+
|
| 22 |
+
- **Chatbot Interface**: Interact with the memory system to store and retrieve information.
|
| 23 |
+
- **Memory Storage**: Add content to memory with metadata (category, topic).
|
| 24 |
+
- **Memory Retrieval**: Search stored information using agentic retrieval.
|
| 25 |
+
- **MCP Endpoints**: Access `populate_memory` and `search_memory` tools via MCP clients.
|
| 26 |
+
|
| 27 |
+
---
|
| 28 |
+
|
| 29 |
+
## Usage
|
| 30 |
+
|
| 31 |
+
1. **Run the Application**:
|
| 32 |
+
```bash
|
| 33 |
+
uv run main.py
|
| 34 |
+
```
|
| 35 |
+
|
| 36 |
+
2. **Interact with the Chatbot**:
|
| 37 |
+
- Use the chatbot interface to store and retrieve memories.
|
| 38 |
+
|
| 39 |
+
3. **Connect MCP Clients**:
|
| 40 |
+
- Access the `populate_memory` and `search_memory` tools via MCP endpoints.
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## Project Structure
|
| 45 |
+
|
| 46 |
+
- `main.py`: The main application file that launches the Gradio interface.
|
| 47 |
+
- `utility.py`: Contains the `populate_memory` and `search_memory` tools.
|
| 48 |
+
- `agent.py`: Manages retrieval states and builds retrieval graphs.
|
| 49 |
+
- `clients.py`: Defines the language model and vector store clients.
|
| 50 |
+
- `pyproject.toml`: Project configuration and dependencies.
|
| 51 |
+
|
| 52 |
+
---
|
| 53 |
+
|
| 54 |
+
## Acknowledgments
|
| 55 |
+
|
| 56 |
+
- Built with [Gradio](https://gradio.app/).
|
| 57 |
+
- Powered by [LangChain](https://langchain.com/) and [Qdrant](https://qdrant.tech/).
|
agent.py
ADDED
|
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Literal, Optional, TypedDict
|
| 2 |
+
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 5 |
+
from langgraph.graph import END, START, StateGraph
|
| 6 |
+
from pydantic import BaseModel, Field
|
| 7 |
+
from qdrant_client.http.models import (
|
| 8 |
+
FieldCondition,
|
| 9 |
+
Filter,
|
| 10 |
+
MatchValue,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from clients import LLM, VECTOR_STORE
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class RetrievalState(TypedDict):
|
| 17 |
+
"""State for the agentic retrieval graph."""
|
| 18 |
+
|
| 19 |
+
original_query: str
|
| 20 |
+
current_query: str
|
| 21 |
+
category: Optional[str]
|
| 22 |
+
topic: Optional[str]
|
| 23 |
+
documents: List[Document]
|
| 24 |
+
relevant_documents: List[Document]
|
| 25 |
+
generation: str
|
| 26 |
+
retry_count: int
|
| 27 |
+
max_retries: int
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class GradeDocuments(BaseModel):
|
| 31 |
+
"""Grade whether a document is relevant to the query."""
|
| 32 |
+
|
| 33 |
+
is_relevant: Literal["yes", "no"] = Field(
|
| 34 |
+
description="Is the document relevant to the query? 'yes' or 'no'"
|
| 35 |
+
)
|
| 36 |
+
reason: str = Field(description="Brief reason for the relevance decision")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def retrieve_documents(state: RetrievalState) -> RetrievalState:
|
| 40 |
+
"""Retrieve documents from vector store."""
|
| 41 |
+
query = state["current_query"]
|
| 42 |
+
category = state.get("category")
|
| 43 |
+
topic = state.get("topic")
|
| 44 |
+
|
| 45 |
+
# Build Qdrant filter
|
| 46 |
+
conditions = []
|
| 47 |
+
if category:
|
| 48 |
+
conditions.append(
|
| 49 |
+
FieldCondition(key="metadata.category", match=MatchValue(value=category))
|
| 50 |
+
)
|
| 51 |
+
if topic:
|
| 52 |
+
conditions.append(
|
| 53 |
+
FieldCondition(key="metadata.topic", match=MatchValue(value=topic))
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
qdrant_filter = Filter(must=conditions) if conditions else None
|
| 57 |
+
|
| 58 |
+
documents = VECTOR_STORE.similarity_search(
|
| 59 |
+
query,
|
| 60 |
+
k=5,
|
| 61 |
+
filter=qdrant_filter,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
return {**state, "documents": documents}
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def grade_documents(state: RetrievalState) -> RetrievalState:
|
| 68 |
+
"""Grade documents for relevance using LLM."""
|
| 69 |
+
query = state["original_query"]
|
| 70 |
+
documents = state["documents"]
|
| 71 |
+
|
| 72 |
+
if not documents:
|
| 73 |
+
return {**state, "relevant_documents": []}
|
| 74 |
+
|
| 75 |
+
# Create grader with structured output
|
| 76 |
+
grader_llm = LLM.with_structured_output(GradeDocuments)
|
| 77 |
+
|
| 78 |
+
grading_prompt = ChatPromptTemplate.from_messages(
|
| 79 |
+
[
|
| 80 |
+
(
|
| 81 |
+
"system",
|
| 82 |
+
"""You are a grader assessing relevance of a retrieved document to a user query.
|
| 83 |
+
|
| 84 |
+
If the document contains keywords or semantic meaning related to the query, grade it as relevant.
|
| 85 |
+
Be lenient - even partial relevance should be marked as 'yes'.
|
| 86 |
+
Only mark 'no' if the document is completely unrelated.""",
|
| 87 |
+
),
|
| 88 |
+
(
|
| 89 |
+
"human",
|
| 90 |
+
"""Query: {query}
|
| 91 |
+
|
| 92 |
+
Document content: {document}
|
| 93 |
+
|
| 94 |
+
Is this document relevant to the query?""",
|
| 95 |
+
),
|
| 96 |
+
]
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
relevant_docs = []
|
| 100 |
+
for doc in documents:
|
| 101 |
+
try:
|
| 102 |
+
result = grader_llm.invoke(
|
| 103 |
+
grading_prompt.format_messages(
|
| 104 |
+
query=query,
|
| 105 |
+
document=doc.page_content[:1000], # Limit content length
|
| 106 |
+
)
|
| 107 |
+
)
|
| 108 |
+
if result.is_relevant == "yes":
|
| 109 |
+
relevant_docs.append(doc)
|
| 110 |
+
except Exception:
|
| 111 |
+
# If grading fails, include the document (fail-safe)
|
| 112 |
+
relevant_docs.append(doc)
|
| 113 |
+
|
| 114 |
+
return {**state, "relevant_documents": relevant_docs}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def rewrite_query(state: RetrievalState) -> RetrievalState:
|
| 118 |
+
"""Rewrite the query for better retrieval."""
|
| 119 |
+
original_query = state["original_query"]
|
| 120 |
+
retry_count = state["retry_count"]
|
| 121 |
+
|
| 122 |
+
rewrite_prompt = ChatPromptTemplate.from_messages(
|
| 123 |
+
[
|
| 124 |
+
(
|
| 125 |
+
"system",
|
| 126 |
+
"""You are an expert at reformulating search queries.
|
| 127 |
+
Given the original query, generate a better search query that might retrieve more relevant documents.
|
| 128 |
+
|
| 129 |
+
Focus on:
|
| 130 |
+
- Extracting key concepts and entities
|
| 131 |
+
- Using synonyms or related terms
|
| 132 |
+
- Being more specific or more general as appropriate
|
| 133 |
+
|
| 134 |
+
Return ONLY the rewritten query, nothing else.""",
|
| 135 |
+
),
|
| 136 |
+
("human", "Original query: {query}\n\nRewritten query:"),
|
| 137 |
+
]
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
response = LLM.invoke(rewrite_prompt.format_messages(query=original_query))
|
| 141 |
+
|
| 142 |
+
new_query = response.content.strip()
|
| 143 |
+
|
| 144 |
+
return {
|
| 145 |
+
**state,
|
| 146 |
+
"current_query": new_query,
|
| 147 |
+
"retry_count": retry_count + 1,
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def generate_response(state: RetrievalState) -> RetrievalState:
|
| 152 |
+
"""Generate final response from relevant documents."""
|
| 153 |
+
relevant_docs = state["relevant_documents"]
|
| 154 |
+
|
| 155 |
+
if not relevant_docs:
|
| 156 |
+
return {**state, "generation": "No relevant memories found."}
|
| 157 |
+
|
| 158 |
+
# Format documents
|
| 159 |
+
formatted = []
|
| 160 |
+
for i, doc in enumerate(relevant_docs, 1):
|
| 161 |
+
meta = doc.metadata
|
| 162 |
+
formatted.append(
|
| 163 |
+
f"{i}. [{meta.get('category', 'N/A')}/{meta.get('topic', 'N/A')}]: {doc.page_content}"
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
return {**state, "generation": "\n".join(formatted)}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def should_retry(state: RetrievalState) -> Literal["rewrite", "generate"]:
|
| 170 |
+
"""Decide whether to retry with a rewritten query."""
|
| 171 |
+
relevant_docs = state["relevant_documents"]
|
| 172 |
+
retry_count = state["retry_count"]
|
| 173 |
+
max_retries = state["max_retries"]
|
| 174 |
+
|
| 175 |
+
# If we have relevant docs, generate response
|
| 176 |
+
if relevant_docs:
|
| 177 |
+
return "generate"
|
| 178 |
+
|
| 179 |
+
# If no relevant docs and we can still retry, rewrite query
|
| 180 |
+
if retry_count < max_retries:
|
| 181 |
+
return "rewrite"
|
| 182 |
+
|
| 183 |
+
# Max retries reached, generate (empty) response
|
| 184 |
+
return "generate"
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def build_retrieval_graph():
|
| 188 |
+
workflow = StateGraph(RetrievalState)
|
| 189 |
+
|
| 190 |
+
# Add nodes
|
| 191 |
+
workflow.add_node("retrieve", retrieve_documents)
|
| 192 |
+
workflow.add_node("grade", grade_documents)
|
| 193 |
+
workflow.add_node("rewrite", rewrite_query)
|
| 194 |
+
workflow.add_node("generate", generate_response)
|
| 195 |
+
|
| 196 |
+
# Add edges
|
| 197 |
+
workflow.add_edge(START, "retrieve")
|
| 198 |
+
workflow.add_edge("retrieve", "grade")
|
| 199 |
+
workflow.add_conditional_edges(
|
| 200 |
+
"grade",
|
| 201 |
+
should_retry,
|
| 202 |
+
{
|
| 203 |
+
"rewrite": "rewrite",
|
| 204 |
+
"generate": "generate",
|
| 205 |
+
},
|
| 206 |
+
)
|
| 207 |
+
workflow.add_edge("rewrite", "retrieve")
|
| 208 |
+
workflow.add_edge("generate", END)
|
| 209 |
+
|
| 210 |
+
return workflow.compile()
|
app.py
CHANGED
|
@@ -1,7 +1,23 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
|
| 4 |
-
return "Hello " + name + "!!"
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
+
from utility import chat, populate_memory, search_memory
|
|
|
|
| 4 |
|
| 5 |
+
with gr.Blocks(title="Central Memory") as app:
|
| 6 |
+
gr.ChatInterface(
|
| 7 |
+
fn=chat,
|
| 8 |
+
title="Central Memory ChatBot",
|
| 9 |
+
examples=[
|
| 10 |
+
"Remember that my favorite color is blue",
|
| 11 |
+
"Store this: I'm learning Rust to make an OS",
|
| 12 |
+
"Search the memorie about learning rust",
|
| 13 |
+
],
|
| 14 |
+
api_visibility="private",
|
| 15 |
+
)
|
| 16 |
+
gr.api(populate_memory.func)
|
| 17 |
+
gr.api(search_memory.func)
|
| 18 |
+
gr.Markdown("""---
|
| 19 |
+
**Note:** `search_memory` using agentic retrieval. This application exposes all tools as MCP endpoints.
|
| 20 |
+
Connect your MCP client to this server to access the `populate_memory` and `search_memory` tools.
|
| 21 |
+
""")
|
| 22 |
+
|
| 23 |
+
app.launch(mcp_server=True, share=False, theme=gr.themes.Soft())
|
clients.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
|
| 2 |
+
from langchain_qdrant import QdrantVectorStore
|
| 3 |
+
from qdrant_client import QdrantClient
|
| 4 |
+
from qdrant_client.http.models import (
|
| 5 |
+
Distance,
|
| 6 |
+
VectorParams,
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
from config import (
|
| 10 |
+
COLLECTION_NAME,
|
| 11 |
+
OPENAI_API_KEY,
|
| 12 |
+
OPENAI_BASE_URL,
|
| 13 |
+
QDRANT_API_KEY,
|
| 14 |
+
QDRANT_URL,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
EMBEDDING = OpenAIEmbeddings(
|
| 18 |
+
openai_api_key=OPENAI_API_KEY,
|
| 19 |
+
openai_api_base=OPENAI_BASE_URL,
|
| 20 |
+
model="Qwen/Qwen3-Embedding-8B",
|
| 21 |
+
check_embedding_ctx_length=False,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
QDRANT_CLIENT = QdrantClient(
|
| 25 |
+
url=QDRANT_URL,
|
| 26 |
+
api_key=QDRANT_API_KEY,
|
| 27 |
+
port=443,
|
| 28 |
+
https=True,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
if not QDRANT_CLIENT.collection_exists(COLLECTION_NAME):
|
| 32 |
+
QDRANT_CLIENT.create_collection(
|
| 33 |
+
collection_name=COLLECTION_NAME,
|
| 34 |
+
vectors_config=VectorParams(
|
| 35 |
+
size=4096,
|
| 36 |
+
distance=Distance.COSINE,
|
| 37 |
+
),
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
VECTOR_STORE = QdrantVectorStore(
|
| 41 |
+
client=QDRANT_CLIENT,
|
| 42 |
+
collection_name=COLLECTION_NAME,
|
| 43 |
+
embedding=EMBEDDING,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
LLM = ChatOpenAI(
|
| 47 |
+
openai_api_key=OPENAI_API_KEY,
|
| 48 |
+
openai_api_base=OPENAI_BASE_URL,
|
| 49 |
+
model="openai/gpt-oss-120b",
|
| 50 |
+
temperature=0.3,
|
| 51 |
+
streaming=True,
|
| 52 |
+
)
|
config.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
|
| 5 |
+
load_dotenv()
|
| 6 |
+
|
| 7 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 8 |
+
OPENAI_BASE_URL = os.getenv("OPENAI_BASE_URL")
|
| 9 |
+
QDRANT_URL = os.getenv("QDRANT_URL")
|
| 10 |
+
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
| 11 |
+
COLLECTION_NAME = "memories"
|
pyproject.toml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "central-memory"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Add your description here"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.12"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"gradio[mcp]>=6.0.1",
|
| 9 |
+
"langchain-openai>=1.1.0",
|
| 10 |
+
"langchain-qdrant>=1.1.0",
|
| 11 |
+
"langgraph>=1.0.4",
|
| 12 |
+
"mcp>=1.22.0",
|
| 13 |
+
"python-dotenv>=1.2.1",
|
| 14 |
+
"qdrant-client>=1.16.1",
|
| 15 |
+
]
|
utility.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Generator, Optional
|
| 2 |
+
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage, SystemMessage
|
| 5 |
+
from langchain_core.tools import tool
|
| 6 |
+
|
| 7 |
+
from agent import RetrievalState, build_retrieval_graph
|
| 8 |
+
from clients import LLM, VECTOR_STORE
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@tool
|
| 12 |
+
def populate_memory(
|
| 13 |
+
content: str,
|
| 14 |
+
category: str,
|
| 15 |
+
topic: str,
|
| 16 |
+
) -> str:
|
| 17 |
+
"""Add content with metadata to the memory for later retrieval. Use this to store important information the user wants to remember.
|
| 18 |
+
|
| 19 |
+
Args:
|
| 20 |
+
content: The content to store in memory
|
| 21 |
+
category: Category of the memory (e.g., 'personal', 'work', 'learning')
|
| 22 |
+
topic: Specific topic of the memory
|
| 23 |
+
"""
|
| 24 |
+
VECTOR_STORE.add_documents(
|
| 25 |
+
documents=[
|
| 26 |
+
Document(
|
| 27 |
+
page_content=content, metadata={"category": category, "topic": topic}
|
| 28 |
+
)
|
| 29 |
+
]
|
| 30 |
+
)
|
| 31 |
+
return f"Successfully stored memory about '{topic}' in category '{category}'"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@tool
|
| 35 |
+
def search_memory(
|
| 36 |
+
query: str,
|
| 37 |
+
category: Optional[str] = None,
|
| 38 |
+
topic: Optional[str] = None,
|
| 39 |
+
) -> str:
|
| 40 |
+
"""Search and retrieve relevant information from memory using intelligent agentic retrieval.
|
| 41 |
+
|
| 42 |
+
This tool uses advanced retrieval with:
|
| 43 |
+
- Document relevance grading
|
| 44 |
+
- Automatic query rewriting if no relevant results found
|
| 45 |
+
- Self-correction with retry logic
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
query: The search query to find relevant memories
|
| 49 |
+
category: Optional category filter
|
| 50 |
+
topic: Optional topic filter
|
| 51 |
+
"""
|
| 52 |
+
try:
|
| 53 |
+
initial_state: RetrievalState = {
|
| 54 |
+
"original_query": query,
|
| 55 |
+
"current_query": query,
|
| 56 |
+
"category": category,
|
| 57 |
+
"topic": topic,
|
| 58 |
+
"documents": [],
|
| 59 |
+
"relevant_documents": [],
|
| 60 |
+
"generation": "",
|
| 61 |
+
"retry_count": 0,
|
| 62 |
+
"max_retries": 2, # Allow up to 2 query rewrites
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
final_state = _get_retrieval_agent().invoke(initial_state)
|
| 66 |
+
result = final_state["generation"]
|
| 67 |
+
|
| 68 |
+
return result
|
| 69 |
+
except Exception as e:
|
| 70 |
+
error_msg = f"Error in search_memory: {str(e)}"
|
| 71 |
+
print(f"DEBUG: {error_msg}")
|
| 72 |
+
return error_msg
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# Create tools list and bound LLM
|
| 76 |
+
TOOLS = [search_memory, populate_memory]
|
| 77 |
+
CHAT_LLM = LLM.bind_tools(TOOLS)
|
| 78 |
+
|
| 79 |
+
# Lazy initialization to avoid circular imports
|
| 80 |
+
_retrieval_agent = None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _get_retrieval_agent():
|
| 84 |
+
global _retrieval_agent
|
| 85 |
+
if _retrieval_agent is None:
|
| 86 |
+
_retrieval_agent = build_retrieval_graph()
|
| 87 |
+
return _retrieval_agent
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def chat(
|
| 91 |
+
message: str,
|
| 92 |
+
history: list[dict],
|
| 93 |
+
) -> Generator[str, None, None]:
|
| 94 |
+
messages = [
|
| 95 |
+
SystemMessage(content="Whenever the user asks you a question, you must always use the search_memory tool first to look for relevant information in your memory. If you find relevant information, use it to answer the user's question. if you don't find any relevant information, answer the question to the best of your ability.")
|
| 96 |
+
]
|
| 97 |
+
for msg in history:
|
| 98 |
+
if msg["role"] == "user":
|
| 99 |
+
messages.append(HumanMessage(content=msg["content"]))
|
| 100 |
+
elif msg["role"] == "assistant":
|
| 101 |
+
messages.append(AIMessage(content=msg["content"]))
|
| 102 |
+
|
| 103 |
+
messages.append(HumanMessage(content=message))
|
| 104 |
+
|
| 105 |
+
max_iterations = 10
|
| 106 |
+
iteration = 0
|
| 107 |
+
|
| 108 |
+
while iteration < max_iterations:
|
| 109 |
+
iteration += 1
|
| 110 |
+
|
| 111 |
+
response = CHAT_LLM.invoke(messages)
|
| 112 |
+
messages.append(response)
|
| 113 |
+
|
| 114 |
+
if not response.tool_calls:
|
| 115 |
+
if response.content:
|
| 116 |
+
yield response.content
|
| 117 |
+
else:
|
| 118 |
+
yield "Done!"
|
| 119 |
+
return
|
| 120 |
+
|
| 121 |
+
tool_map = {t.name: t for t in TOOLS}
|
| 122 |
+
|
| 123 |
+
for tool_call in response.tool_calls:
|
| 124 |
+
tool_name = tool_call["name"]
|
| 125 |
+
tool_args = tool_call["args"]
|
| 126 |
+
|
| 127 |
+
yield f"🔧 Using {tool_name}..."
|
| 128 |
+
|
| 129 |
+
if tool_name in tool_map:
|
| 130 |
+
try:
|
| 131 |
+
result = tool_map[tool_name].invoke(tool_args)
|
| 132 |
+
except Exception as e:
|
| 133 |
+
result = f"Error: {str(e)}"
|
| 134 |
+
else:
|
| 135 |
+
result = f"Unknown tool: {tool_name}"
|
| 136 |
+
|
| 137 |
+
messages.append(
|
| 138 |
+
ToolMessage(
|
| 139 |
+
content=str(result),
|
| 140 |
+
tool_call_id=tool_call["id"],
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
yield "I processed your request but couldn't generate a final response."
|