Spaces:
Sleeping
Sleeping
Commit ·
6a0d993
1
Parent(s): fdab809
add chapter 6 files
Browse files- chapter_06_memory/01_session_agent.py +56 -0
- chapter_06_memory/02_core_memory_strategy.py +35 -0
- chapter_06_memory/03_core_memory_update.py +32 -0
- chapter_06_memory/04_sliding_window.py +43 -0
- chapter_06_memory/05_summarization.py +64 -0
- chapter_06_memory/06_conversation_search.py +59 -0
- chapter_06_memory/07_task_long_term.py +76 -0
- chapter_06_memory/08_user_long_term.py +98 -0
- pyproject.toml +1 -0
- scratch_agents/agents/execution_context_ch6.py +34 -0
- scratch_agents/agents/tool_calling_agent_ch4_base.py +2 -2
- scratch_agents/agents/tool_calling_agent_ch4_callback.py +1 -1
- scratch_agents/agents/tool_calling_agent_ch4_structured_output.py +2 -2
- scratch_agents/agents/tool_calling_agent_ch6.py +226 -0
- scratch_agents/memory/base_memory_strategy.py +13 -0
- scratch_agents/memory/core_memory_strategy.py +21 -0
- scratch_agents/memory/sliding_window_strategy.py +26 -0
- scratch_agents/memory/summarization_strategy.py +77 -0
- scratch_agents/models/openai.py +29 -3
- scratch_agents/sessions/base_cross_session_manager.py +297 -0
- scratch_agents/sessions/base_session_manager.py +28 -0
- scratch_agents/sessions/in_memory_session_manager.py +30 -0
- scratch_agents/sessions/session.py +23 -0
- scratch_agents/sessions/task_cross_session_manager.py +194 -0
- scratch_agents/sessions/user_cross_session_manager.py +185 -0
- scratch_agents/tools/base_tool.py +30 -3
- scratch_agents/tools/conversation_search.py +49 -0
- scratch_agents/tools/core_memory_upsert.py +33 -0
- scratch_agents/tools/function_tool.py +49 -9
- scratch_agents/types/contents.py +1 -1
- uv.lock +0 -0
chapter_06_memory/01_session_agent.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 3 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 4 |
+
from scratch_agents.tools import calculator, search_web
|
| 5 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def main():
|
| 13 |
+
"""Demonstrate session memory functionality"""
|
| 14 |
+
user_id = "test_123"
|
| 15 |
+
# Initialize components
|
| 16 |
+
model = OpenAILlm(model='gpt-5-mini')
|
| 17 |
+
tools = [calculator, search_web]
|
| 18 |
+
|
| 19 |
+
# Create agent with session manager
|
| 20 |
+
agent = ToolCallingAgent(
|
| 21 |
+
name="session_assistant",
|
| 22 |
+
model=model,
|
| 23 |
+
instructions="You are a helpful assistant that remembers our conversations.",
|
| 24 |
+
tools=tools,
|
| 25 |
+
session_manager=InMemorySessionManager()
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# First interaction - session 1
|
| 29 |
+
print("=== First Interaction (Session 1) ===")
|
| 30 |
+
answer1 = await agent.run(
|
| 31 |
+
"My name is Alice and I'm working on Project Alpha. What's 123 * 456?",
|
| 32 |
+
session_id="session_1",
|
| 33 |
+
user_id=user_id
|
| 34 |
+
)
|
| 35 |
+
print(f"Assistant: {answer1}\n")
|
| 36 |
+
|
| 37 |
+
# Second interaction - continue session 1
|
| 38 |
+
print("=== Second Interaction (Session 1) ===")
|
| 39 |
+
answer2 = await agent.run(
|
| 40 |
+
"What project am I working on and what was the result of the multiplication I asked about?",
|
| 41 |
+
session_id="session_1",
|
| 42 |
+
user_id=user_id
|
| 43 |
+
)
|
| 44 |
+
print(f"Assistant: {answer2}\n")
|
| 45 |
+
|
| 46 |
+
# New session - session 2
|
| 47 |
+
print("=== New Session (Session 2) ===")
|
| 48 |
+
answer3 = await agent.run(
|
| 49 |
+
"Do you remember my name?",
|
| 50 |
+
session_id="session_2",
|
| 51 |
+
user_id=user_id
|
| 52 |
+
)
|
| 53 |
+
print(f"Assistant: {answer3}\n")
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
|
| 56 |
+
asyncio.run(main())
|
chapter_06_memory/02_core_memory_strategy.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 3 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 4 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 5 |
+
from scratch_agents.memory.core_memory_strategy import CoreMemoryStrategy
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
async def test_core_memory_loading():
|
| 13 |
+
user_id = "test_123"
|
| 14 |
+
session_id = "test_session"
|
| 15 |
+
session_manager = InMemorySessionManager()
|
| 16 |
+
session = session_manager.get_or_create_session(session_id, user_id)
|
| 17 |
+
session.core_memory["user"] = "User's name is Alice"
|
| 18 |
+
|
| 19 |
+
agent = ToolCallingAgent(
|
| 20 |
+
name="memory_agent",
|
| 21 |
+
model=OpenAILlm(model="gpt-5-mini"),
|
| 22 |
+
instructions="You are a helpful assistant",
|
| 23 |
+
session_manager=session_manager,
|
| 24 |
+
before_llm_callbacks=[CoreMemoryStrategy()]
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
response = await agent.run(
|
| 28 |
+
"What's my name?",
|
| 29 |
+
session_id=session_id,
|
| 30 |
+
user_id=user_id
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
print(response)
|
| 34 |
+
|
| 35 |
+
asyncio.run(test_core_memory_loading())
|
chapter_06_memory/03_core_memory_update.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 2 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 3 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 4 |
+
from scratch_agents.tools.core_memory_upsert import core_memory_upsert
|
| 5 |
+
import asyncio
|
| 6 |
+
from dotenv import load_dotenv
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
+
user_id = "test_123"
|
| 12 |
+
session_id = "test_session"
|
| 13 |
+
|
| 14 |
+
async def test_automatic_memory_update():
|
| 15 |
+
agent = ToolCallingAgent(
|
| 16 |
+
name="learning_agent",
|
| 17 |
+
model=OpenAILlm(model="gpt-5-mini"),
|
| 18 |
+
instructions="Remember important user info with core_memory_upsert",
|
| 19 |
+
tools=[core_memory_upsert],
|
| 20 |
+
session_manager=InMemorySessionManager(),
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
await agent.run(
|
| 24 |
+
"Hi! My name is Alice and I work as a data scientist.",
|
| 25 |
+
session_id=session_id,
|
| 26 |
+
user_id=user_id
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
session = agent.session_manager.get_session(session_id)
|
| 30 |
+
print(session.core_memory['user'])
|
| 31 |
+
|
| 32 |
+
asyncio.run(test_automatic_memory_update())
|
chapter_06_memory/04_sliding_window.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 3 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 4 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 5 |
+
from scratch_agents.memory.sliding_window_strategy import SlidingWindowStrategy
|
| 6 |
+
from scratch_agents.types.contents import Message
|
| 7 |
+
from scratch_agents.types.events import Event
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
user_id = "test_123"
|
| 14 |
+
session_id = "test_session"
|
| 15 |
+
|
| 16 |
+
async def test_sliding_window():
|
| 17 |
+
|
| 18 |
+
session_manager = InMemorySessionManager()
|
| 19 |
+
session = session_manager.create_session(session_id, user_id)
|
| 20 |
+
|
| 21 |
+
session.events.append(Event(
|
| 22 |
+
execution_id="exec1",
|
| 23 |
+
author="user",
|
| 24 |
+
content=[Message(role="user", content="My name is Alice"),
|
| 25 |
+
Message(role="user", content="I live in Korea")]
|
| 26 |
+
))
|
| 27 |
+
|
| 28 |
+
agent = ToolCallingAgent(
|
| 29 |
+
name="window_agent",
|
| 30 |
+
model=OpenAILlm(model="gpt-5-mini"),
|
| 31 |
+
instructions="You are a helpful assistant",
|
| 32 |
+
session_manager=session_manager,
|
| 33 |
+
before_llm_callbacks=[SlidingWindowStrategy(max_messages=2)]
|
| 34 |
+
)
|
| 35 |
+
|
| 36 |
+
response = await agent.run(
|
| 37 |
+
"What's my name?",
|
| 38 |
+
session_id=session_id,
|
| 39 |
+
user_id=user_id
|
| 40 |
+
)
|
| 41 |
+
print(response)
|
| 42 |
+
|
| 43 |
+
asyncio.run(test_sliding_window())
|
chapter_06_memory/05_summarization.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 2 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 3 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 4 |
+
from scratch_agents.memory.summarization_strategy import SummarizationStrategy
|
| 5 |
+
from scratch_agents.types.contents import Message
|
| 6 |
+
from scratch_agents.types.events import Event
|
| 7 |
+
import asyncio
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
load_dotenv()
|
| 12 |
+
|
| 13 |
+
user_id = "test_123"
|
| 14 |
+
session_id = "test_session"
|
| 15 |
+
|
| 16 |
+
async def test_summarization_strategy():
|
| 17 |
+
"""Demonstrate summarization strategy in action"""
|
| 18 |
+
|
| 19 |
+
model = OpenAILlm(model="gpt-5-mini")
|
| 20 |
+
session_manager = InMemorySessionManager()
|
| 21 |
+
session = session_manager.create_session(session_id, user_id)
|
| 22 |
+
|
| 23 |
+
messages = [
|
| 24 |
+
Message(role="user", content="Hi, I'm Bob"),
|
| 25 |
+
Message(role="assistant", content="Nice to meet you, Bob!"),
|
| 26 |
+
Message(role="user", content="I work as a teacher"),
|
| 27 |
+
Message(role="assistant", content="Wow! What subject?"),
|
| 28 |
+
Message(role="user", content="I teach math"),
|
| 29 |
+
Message(role="assistant", content="Math is important!"),
|
| 30 |
+
Message(role="user", content="I have 30 students"),
|
| 31 |
+
Message(role="assistant", content="That's a good class size"),
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
for msg in messages:
|
| 35 |
+
event = Event(
|
| 36 |
+
execution_id="test_exec",
|
| 37 |
+
author="test",
|
| 38 |
+
content=[msg]
|
| 39 |
+
)
|
| 40 |
+
session.events.append(event)
|
| 41 |
+
|
| 42 |
+
agent = ToolCallingAgent(
|
| 43 |
+
name="summary_agent",
|
| 44 |
+
model=model,
|
| 45 |
+
instructions="You are a helpful assistant",
|
| 46 |
+
session_manager=session_manager,
|
| 47 |
+
before_llm_callbacks=[
|
| 48 |
+
SummarizationStrategy(model=model, trigger_count=8, keep_recent=2)
|
| 49 |
+
]
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
response = await agent.run(
|
| 53 |
+
"What subject do I teach?",
|
| 54 |
+
session_id=session_id,
|
| 55 |
+
user_id=user_id
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
if "conversation_summary" in session.state:
|
| 59 |
+
print(f"Summary: {session.state['conversation_summary']}")
|
| 60 |
+
print(f"Summary Index: {session.state['last_summarized_index']}")
|
| 61 |
+
|
| 62 |
+
print(f"\nAgent response: {response}")
|
| 63 |
+
|
| 64 |
+
asyncio.run(test_summarization_strategy())
|
chapter_06_memory/06_conversation_search.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 2 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 3 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 4 |
+
from scratch_agents.memory.sliding_window_strategy import SlidingWindowStrategy
|
| 5 |
+
from scratch_agents.tools.conversation_search import conversation_search
|
| 6 |
+
from scratch_agents.types.contents import Message
|
| 7 |
+
from scratch_agents.types.events import Event
|
| 8 |
+
import asyncio
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
user_id = "test_123"
|
| 15 |
+
session_id = "test_session"
|
| 16 |
+
|
| 17 |
+
async def test_search_with_sliding_window():
|
| 18 |
+
"""Demonstrate search recovering information lost to sliding window"""
|
| 19 |
+
|
| 20 |
+
model = OpenAILlm(model="gpt-5-mini")
|
| 21 |
+
session_manager = InMemorySessionManager()
|
| 22 |
+
session = session_manager.create_session(session_id, user_id)
|
| 23 |
+
|
| 24 |
+
conversation_history = [
|
| 25 |
+
("user", "My golden retriever puppy is named Max."),
|
| 26 |
+
("assistant", "Max is a lovely name for a golden retriever!"),
|
| 27 |
+
("user", "He loves playing fetch in the park."),
|
| 28 |
+
("assistant", "That's wonderful! Golden retrievers are great at fetch."),
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
for role, content in conversation_history:
|
| 32 |
+
event = Event(
|
| 33 |
+
execution_id="pre_loaded",
|
| 34 |
+
author=role,
|
| 35 |
+
content=[Message(role=role, content=content)]
|
| 36 |
+
)
|
| 37 |
+
session.events.append(event)
|
| 38 |
+
|
| 39 |
+
agent = ToolCallingAgent(
|
| 40 |
+
name="search_agent",
|
| 41 |
+
model=model,
|
| 42 |
+
instructions="""You are a helpful assistant. When asked about
|
| 43 |
+
information from earlier in our conversation, use the
|
| 44 |
+
conversation_search tool to find it.""",
|
| 45 |
+
tools=[conversation_search],
|
| 46 |
+
session_manager=session_manager,
|
| 47 |
+
before_llm_callbacks=[
|
| 48 |
+
SlidingWindowStrategy(max_messages=2)
|
| 49 |
+
]
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
response = await agent.run(
|
| 53 |
+
"What was my puppy's name?",
|
| 54 |
+
session_id=session_id,
|
| 55 |
+
user_id=user_id
|
| 56 |
+
)
|
| 57 |
+
print(f"Agent: {response}\n")
|
| 58 |
+
|
| 59 |
+
asyncio.run(test_search_with_sliding_window())
|
chapter_06_memory/07_task_long_term.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scratch_agents.agents.execution_context_ch6 import ExecutionContext
|
| 2 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 3 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 4 |
+
from scratch_agents.sessions.task_cross_session_manager import TaskCrossSessionManager
|
| 5 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 6 |
+
from scratch_agents.tools.base_tool import BaseTool
|
| 7 |
+
from scratch_agents.models.llm_request import LlmRequest
|
| 8 |
+
from scratch_agents.tools.search_web import search_web
|
| 9 |
+
import asyncio
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
import os
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
user_id = "test_123"
|
| 16 |
+
|
| 17 |
+
async def long_term_memory_save_callback(context:ExecutionContext):
|
| 18 |
+
cross_session_manager = context.cross_session_manager
|
| 19 |
+
session = context.session
|
| 20 |
+
execution_id = context.execution_id
|
| 21 |
+
|
| 22 |
+
await cross_session_manager.process_session(session=session, execution_id=execution_id)
|
| 23 |
+
|
| 24 |
+
class MemorySearchTool(BaseTool):
|
| 25 |
+
async def execute(self, context, **kwargs):
|
| 26 |
+
return None
|
| 27 |
+
|
| 28 |
+
async def process_llm_request(self, request: LlmRequest, context: ExecutionContext):
|
| 29 |
+
user_input = context.user_input
|
| 30 |
+
user_id = context.session.user_id
|
| 31 |
+
results = await context.cross_session_manager.search(user_input, user_id)
|
| 32 |
+
if results:
|
| 33 |
+
request.add_instructions(f"Use the following task memory to answer the user's question: {results}")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
async def test_long_term_memory_save():
|
| 37 |
+
"""Test long-term memory saving with a meaningful conversation"""
|
| 38 |
+
|
| 39 |
+
session_manager = InMemorySessionManager()
|
| 40 |
+
model = OpenAILlm(model="gpt-5-mini")
|
| 41 |
+
cross_session_manager = TaskCrossSessionManager(model=model)
|
| 42 |
+
|
| 43 |
+
memory_search_tool = MemorySearchTool()
|
| 44 |
+
|
| 45 |
+
agent = ToolCallingAgent(
|
| 46 |
+
name="memory_agent",
|
| 47 |
+
model=model,
|
| 48 |
+
instructions="You are a helpful assistant. Have a natural conversation and learn about the user's task. IMPORTANT: When the user asks about a specific term or technology, use the search results to provide a comprehensive answer. Do NOT ask for clarification if you find relevant search results. Only ask for clarification if search returns no results or the query is truly impossible to understand. If multiple meanings exist, provide information about the most common or relevant one based on the search results.",
|
| 49 |
+
tools=[search_web, memory_search_tool],
|
| 50 |
+
session_manager=session_manager,
|
| 51 |
+
cross_session_manager=cross_session_manager,
|
| 52 |
+
after_run_callbacks=[long_term_memory_save_callback]
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
print("=== Testing Long-term Memory Save ===\n")
|
| 56 |
+
|
| 57 |
+
test_conversations = [
|
| 58 |
+
"What is Mem0?",
|
| 59 |
+
"How does mem0 work?"
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
for i, message in enumerate(test_conversations, 1):
|
| 63 |
+
print(f"User: {message}")
|
| 64 |
+
session_id = f"test_session_{i}"
|
| 65 |
+
|
| 66 |
+
response = await agent.run(
|
| 67 |
+
message,
|
| 68 |
+
session_id=session_id,
|
| 69 |
+
user_id=user_id
|
| 70 |
+
)
|
| 71 |
+
print(response)
|
| 72 |
+
# print(cross_session_manager.collection.peek())
|
| 73 |
+
|
| 74 |
+
if __name__ == "__main__":
|
| 75 |
+
asyncio.run(test_long_term_memory_save())
|
| 76 |
+
|
chapter_06_memory/08_user_long_term.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scratch_agents.agents.execution_context_ch6 import ExecutionContext
|
| 2 |
+
from scratch_agents.agents.tool_calling_agent_ch6 import ToolCallingAgent
|
| 3 |
+
from scratch_agents.models.openai import OpenAILlm
|
| 4 |
+
from scratch_agents.sessions.user_cross_session_manager import UserCrossSessionManager
|
| 5 |
+
from scratch_agents.sessions.in_memory_session_manager import InMemorySessionManager
|
| 6 |
+
from scratch_agents.tools.base_tool import BaseTool
|
| 7 |
+
from scratch_agents.models.llm_request import LlmRequest
|
| 8 |
+
import asyncio
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
user_id = "test_user_123"
|
| 15 |
+
|
| 16 |
+
async def user_memory_save_callback(context: ExecutionContext):
|
| 17 |
+
"""Callback to save user memories after each interaction"""
|
| 18 |
+
cross_session_manager = context.cross_session_manager
|
| 19 |
+
session = context.session
|
| 20 |
+
execution_id = context.execution_id
|
| 21 |
+
|
| 22 |
+
await cross_session_manager.process_session(session=session, execution_id=execution_id)
|
| 23 |
+
|
| 24 |
+
class UserMemorySearchTool(BaseTool):
|
| 25 |
+
"""Tool to search and retrieve user memories"""
|
| 26 |
+
async def execute(self, context, **kwargs):
|
| 27 |
+
return None
|
| 28 |
+
|
| 29 |
+
async def process_llm_request(self, request: LlmRequest, context: ExecutionContext):
|
| 30 |
+
user_id = context.session.user_id
|
| 31 |
+
# Get all existing memories for the user
|
| 32 |
+
all_memories = await context.cross_session_manager.find_existing([], user_id)
|
| 33 |
+
if all_memories:
|
| 34 |
+
memory_contents = [mem['content'] for mem in all_memories]
|
| 35 |
+
memory_text = "\n".join(f"- {content}" for content in memory_contents)
|
| 36 |
+
request.add_instructions(f"You have the following memories about this user:\n{memory_text}\n\nUse these memories to personalize your responses.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
async def test_user_long_term_memory():
|
| 40 |
+
"""Test user long-term memory with location updates"""
|
| 41 |
+
|
| 42 |
+
session_manager = InMemorySessionManager()
|
| 43 |
+
model = OpenAILlm(model="gpt-4o-mini")
|
| 44 |
+
cross_session_manager = UserCrossSessionManager(model=model)
|
| 45 |
+
|
| 46 |
+
memory_search_tool = UserMemorySearchTool()
|
| 47 |
+
|
| 48 |
+
agent = ToolCallingAgent(
|
| 49 |
+
name="user_memory_agent",
|
| 50 |
+
model=model,
|
| 51 |
+
instructions="You are a helpful assistant that remembers information about the user. Have natural conversations and acknowledge what you know about the user when relevant.",
|
| 52 |
+
tools=[memory_search_tool],
|
| 53 |
+
session_manager=session_manager,
|
| 54 |
+
cross_session_manager=cross_session_manager,
|
| 55 |
+
after_run_callbacks=[user_memory_save_callback]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
print("=== Testing User Long-term Memory ===\n")
|
| 59 |
+
|
| 60 |
+
# Test conversation about location changes
|
| 61 |
+
test_conversations = [
|
| 62 |
+
"Hi! I'm living in New York City. I love the energy here!",
|
| 63 |
+
"Actually, I just moved to Los Angeles last month. The weather is so much better here.",
|
| 64 |
+
"What do you remember about where I live?"
|
| 65 |
+
]
|
| 66 |
+
|
| 67 |
+
for i, message in enumerate(test_conversations, 1):
|
| 68 |
+
print(f"\n--- Conversation {i} ---")
|
| 69 |
+
print(f"User: {message}")
|
| 70 |
+
session_id = f"user_session_{i}"
|
| 71 |
+
|
| 72 |
+
response = await agent.run(
|
| 73 |
+
message,
|
| 74 |
+
session_id=session_id,
|
| 75 |
+
user_id=user_id
|
| 76 |
+
)
|
| 77 |
+
print(f"Assistant: {response}")
|
| 78 |
+
|
| 79 |
+
# Show current memories in the database with timestamps
|
| 80 |
+
print("\n=> Current User Memories:")
|
| 81 |
+
memories = await cross_session_manager.find_existing([], user_id)
|
| 82 |
+
if memories:
|
| 83 |
+
for mem in memories:
|
| 84 |
+
created = mem.get('created_at', 'Unknown')[:19] if mem.get('created_at') != 'Unknown' else 'Unknown'
|
| 85 |
+
updated = mem.get('updated_at', 'Unknown')[:19] if mem.get('updated_at') != 'Unknown' else 'Unknown'
|
| 86 |
+
print(f" - {mem['content']}")
|
| 87 |
+
if created != updated:
|
| 88 |
+
print(f" (Created: {created}, Updated: {updated})")
|
| 89 |
+
else:
|
| 90 |
+
print(f" (Created: {created})")
|
| 91 |
+
else:
|
| 92 |
+
print(" (No memories yet)")
|
| 93 |
+
|
| 94 |
+
# Small delay to see the progression
|
| 95 |
+
await asyncio.sleep(1)
|
| 96 |
+
|
| 97 |
+
if __name__ == "__main__":
|
| 98 |
+
asyncio.run(test_user_long_term_memory())
|
pyproject.toml
CHANGED
|
@@ -5,6 +5,7 @@ description = "Add your description here"
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.11"
|
| 7 |
dependencies = [
|
|
|
|
| 8 |
"fastmcp>=2.11.3",
|
| 9 |
"mcp>=1.13.1",
|
| 10 |
"openai>=1.101.0",
|
|
|
|
| 5 |
readme = "README.md"
|
| 6 |
requires-python = ">=3.11"
|
| 7 |
dependencies = [
|
| 8 |
+
"chromadb>=1.0.20",
|
| 9 |
"fastmcp>=2.11.3",
|
| 10 |
"mcp>=1.13.1",
|
| 11 |
"openai>=1.101.0",
|
scratch_agents/agents/execution_context_ch6.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from ..sessions.session import Session
|
| 3 |
+
from ..sessions.in_memory_session_manager import InMemorySessionManager
|
| 4 |
+
from ..sessions.base_session_manager import BaseSessionManager
|
| 5 |
+
from dataclasses import field
|
| 6 |
+
import uuid
|
| 7 |
+
from pydantic import BaseModel
|
| 8 |
+
from typing import List, Dict, Any
|
| 9 |
+
from ..types.events import Event
|
| 10 |
+
from ..sessions.base_cross_session_manager import BaseCrossSessionManager
|
| 11 |
+
|
| 12 |
+
@dataclass
|
| 13 |
+
class ExecutionContext:
|
| 14 |
+
session: Session
|
| 15 |
+
session_manager: BaseSessionManager
|
| 16 |
+
cross_session_manager: BaseCrossSessionManager
|
| 17 |
+
execution_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 18 |
+
current_step: int = 0
|
| 19 |
+
max_steps: int = 10
|
| 20 |
+
user_input: str = ""
|
| 21 |
+
final_result: str | BaseModel = ""
|
| 22 |
+
|
| 23 |
+
def add_event(self, event: Event) -> None:
|
| 24 |
+
self.session_manager.add_event(self.session, event)
|
| 25 |
+
@property
|
| 26 |
+
def events(self) -> List[Event]:
|
| 27 |
+
return self.session.events
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def state(self) -> Dict[str, Any]:
|
| 31 |
+
return self.session.state
|
| 32 |
+
|
| 33 |
+
def increment_step(self) -> None:
|
| 34 |
+
self.current_step += 1
|
scratch_agents/agents/tool_calling_agent_ch4_base.py
CHANGED
|
@@ -38,7 +38,7 @@ class ToolCallingAgent:
|
|
| 38 |
tool_call_id=tool_call.tool_call_id,
|
| 39 |
name=tool_call.name,
|
| 40 |
status="success",
|
| 41 |
-
content=
|
| 42 |
)
|
| 43 |
)
|
| 44 |
except Exception as e:
|
|
@@ -47,7 +47,7 @@ class ToolCallingAgent:
|
|
| 47 |
tool_call_id=tool_call.tool_call_id,
|
| 48 |
name=tool_call.name,
|
| 49 |
status="error",
|
| 50 |
-
content=
|
| 51 |
)
|
| 52 |
)
|
| 53 |
return tool_results
|
|
|
|
| 38 |
tool_call_id=tool_call.tool_call_id,
|
| 39 |
name=tool_call.name,
|
| 40 |
status="success",
|
| 41 |
+
content=str(result_output),
|
| 42 |
)
|
| 43 |
)
|
| 44 |
except Exception as e:
|
|
|
|
| 47 |
tool_call_id=tool_call.tool_call_id,
|
| 48 |
name=tool_call.name,
|
| 49 |
status="error",
|
| 50 |
+
content=str(e),
|
| 51 |
)
|
| 52 |
)
|
| 53 |
return tool_results
|
scratch_agents/agents/tool_calling_agent_ch4_callback.py
CHANGED
|
@@ -107,7 +107,7 @@ class ToolCallingAgent:
|
|
| 107 |
tool_call_id=tool_call.tool_call_id,
|
| 108 |
name=tool_call.name,
|
| 109 |
status=status,
|
| 110 |
-
content=
|
| 111 |
)
|
| 112 |
tool_results.append(tool_result)
|
| 113 |
|
|
|
|
| 107 |
tool_call_id=tool_call.tool_call_id,
|
| 108 |
name=tool_call.name,
|
| 109 |
status=status,
|
| 110 |
+
content=str(tool_response),
|
| 111 |
)
|
| 112 |
tool_results.append(tool_result)
|
| 113 |
|
scratch_agents/agents/tool_calling_agent_ch4_structured_output.py
CHANGED
|
@@ -47,7 +47,7 @@ class ToolCallingAgent:
|
|
| 47 |
tool_call_id=tool_call.tool_call_id,
|
| 48 |
name=tool_call.name,
|
| 49 |
status="success",
|
| 50 |
-
content=
|
| 51 |
)
|
| 52 |
)
|
| 53 |
except Exception as e:
|
|
@@ -56,7 +56,7 @@ class ToolCallingAgent:
|
|
| 56 |
tool_call_id=tool_call.tool_call_id,
|
| 57 |
name=tool_call.name,
|
| 58 |
status="error",
|
| 59 |
-
content=
|
| 60 |
)
|
| 61 |
)
|
| 62 |
return tool_results
|
|
|
|
| 47 |
tool_call_id=tool_call.tool_call_id,
|
| 48 |
name=tool_call.name,
|
| 49 |
status="success",
|
| 50 |
+
content=str(result_output),
|
| 51 |
)
|
| 52 |
)
|
| 53 |
except Exception as e:
|
|
|
|
| 56 |
tool_call_id=tool_call.tool_call_id,
|
| 57 |
name=tool_call.name,
|
| 58 |
status="error",
|
| 59 |
+
content=str(e),
|
| 60 |
)
|
| 61 |
)
|
| 62 |
return tool_results
|
scratch_agents/agents/tool_calling_agent_ch6.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any, Optional
|
| 2 |
+
from ..models.base_llm import BaseLlm
|
| 3 |
+
from ..models.llm_request import LlmRequest
|
| 4 |
+
from ..models.llm_response import LlmResponse
|
| 5 |
+
from ..types.contents import Message, ToolCall
|
| 6 |
+
from ..types.events import Event
|
| 7 |
+
from .execution_context_ch6 import ExecutionContext
|
| 8 |
+
from ..tools.base_tool import BaseTool
|
| 9 |
+
from ..types.contents import ToolResult
|
| 10 |
+
from typing import Type
|
| 11 |
+
from pydantic import BaseModel
|
| 12 |
+
from ..tools.decorator import tool
|
| 13 |
+
import inspect
|
| 14 |
+
from ..sessions.base_session_manager import BaseSessionManager
|
| 15 |
+
from ..sessions.in_memory_session_manager import InMemorySessionManager
|
| 16 |
+
from ..sessions.base_cross_session_manager import BaseCrossSessionManager
|
| 17 |
+
|
| 18 |
+
class ToolCallingAgent:
|
| 19 |
+
def __init__(self, name: str, model: BaseLlm,
|
| 20 |
+
tools: List[BaseTool] = [],
|
| 21 |
+
instructions: str = "",
|
| 22 |
+
max_steps: int = 10,
|
| 23 |
+
output_type: Optional[Type[BaseModel]] = None,
|
| 24 |
+
before_llm_callbacks = None,
|
| 25 |
+
after_llm_callbacks = None,
|
| 26 |
+
before_tool_callbacks = None,
|
| 27 |
+
after_tool_callbacks = None,
|
| 28 |
+
after_run_callbacks = None,
|
| 29 |
+
session_manager: BaseSessionManager = None,
|
| 30 |
+
cross_session_manager: BaseCrossSessionManager = None):
|
| 31 |
+
self.name = name
|
| 32 |
+
self.model = model
|
| 33 |
+
self.max_steps = max_steps
|
| 34 |
+
self.instructions = instructions
|
| 35 |
+
self.output_type: Optional[Type[BaseModel]] = output_type
|
| 36 |
+
self.output_tool: Optional[str] = None
|
| 37 |
+
self.tools = self._setup_tools(tools)
|
| 38 |
+
self.before_llm_callbacks = before_llm_callbacks or []
|
| 39 |
+
self.after_llm_callbacks = after_llm_callbacks or []
|
| 40 |
+
self.before_tool_callbacks = before_tool_callbacks or []
|
| 41 |
+
self.after_tool_callbacks = after_tool_callbacks or []
|
| 42 |
+
self.after_run_callbacks = after_run_callbacks or []
|
| 43 |
+
self.session_manager = session_manager or InMemorySessionManager()
|
| 44 |
+
self.cross_session_manager = cross_session_manager
|
| 45 |
+
|
| 46 |
+
def _setup_tools(self, tools: List[BaseTool]):
|
| 47 |
+
if self.output_type is not None:
|
| 48 |
+
@tool(name="final_answer", description="Return the final structured answer matching the required schema.")
|
| 49 |
+
def final_answer(output: self.output_type) -> self.output_type:
|
| 50 |
+
return output
|
| 51 |
+
tools.append(final_answer)
|
| 52 |
+
self.output_tool = final_answer.name
|
| 53 |
+
return {t.name: t for t in tools}
|
| 54 |
+
|
| 55 |
+
async def think(self, context: ExecutionContext, llm_request: LlmRequest):
|
| 56 |
+
for callback in self.before_llm_callbacks:
|
| 57 |
+
result = callback(context, llm_request)
|
| 58 |
+
if inspect.isawaitable(result):
|
| 59 |
+
result = await result
|
| 60 |
+
if result is not None:
|
| 61 |
+
return result
|
| 62 |
+
|
| 63 |
+
llm_response = await self.model.generate(llm_request)
|
| 64 |
+
|
| 65 |
+
for callback in self.after_llm_callbacks:
|
| 66 |
+
result = callback(context, llm_response)
|
| 67 |
+
if inspect.isawaitable(result):
|
| 68 |
+
result = await result
|
| 69 |
+
if result is not None:
|
| 70 |
+
return result
|
| 71 |
+
|
| 72 |
+
return llm_response
|
| 73 |
+
|
| 74 |
+
async def _execute_tool(self, context: ExecutionContext, tool_name: str, tool_input: dict) -> Any:
|
| 75 |
+
"""Execute a tool with context injection if needed"""
|
| 76 |
+
tool = self.tools[tool_name]
|
| 77 |
+
|
| 78 |
+
# All tools now handle context properly in their execute method
|
| 79 |
+
return await tool.execute(context, **tool_input)
|
| 80 |
+
|
| 81 |
+
async def act(self, context: ExecutionContext, tool_calls: List[ToolCall]):
|
| 82 |
+
tool_results = []
|
| 83 |
+
for tool_call in tool_calls:
|
| 84 |
+
tool_name = tool_call.name
|
| 85 |
+
tool_input = tool_call.arguments
|
| 86 |
+
print(f" → Calling {tool_name} with {tool_input}")
|
| 87 |
+
|
| 88 |
+
# Step 1: before_tool_callbacks - can skip tool execution
|
| 89 |
+
tool_response = None
|
| 90 |
+
for callback in self.before_tool_callbacks:
|
| 91 |
+
result = callback(context, tool_call)
|
| 92 |
+
if inspect.isawaitable(result):
|
| 93 |
+
result = await result
|
| 94 |
+
if result is not None:
|
| 95 |
+
tool_response = result
|
| 96 |
+
break
|
| 97 |
+
|
| 98 |
+
# Step 2: Execute tool if no callback provided result
|
| 99 |
+
status = "success"
|
| 100 |
+
if tool_response is None:
|
| 101 |
+
try:
|
| 102 |
+
tool_response = await self._execute_tool(context, tool_name, tool_input)
|
| 103 |
+
except Exception as e:
|
| 104 |
+
tool_response = str(e)
|
| 105 |
+
status = "error"
|
| 106 |
+
|
| 107 |
+
# Step 3: after_tool_callbacks - only after actual tool execution
|
| 108 |
+
for callback in self.after_tool_callbacks:
|
| 109 |
+
result = callback(context, tool_response)
|
| 110 |
+
if inspect.isawaitable(result):
|
| 111 |
+
result = await result
|
| 112 |
+
if result is not None:
|
| 113 |
+
tool_response = result
|
| 114 |
+
break
|
| 115 |
+
|
| 116 |
+
# Step 4: Wrap in ToolResult at the end
|
| 117 |
+
if tool_response is not None:
|
| 118 |
+
tool_result = ToolResult(
|
| 119 |
+
tool_call_id=tool_call.tool_call_id,
|
| 120 |
+
name=tool_call.name,
|
| 121 |
+
status=status,
|
| 122 |
+
content=str(tool_response),
|
| 123 |
+
)
|
| 124 |
+
tool_results.append(tool_result)
|
| 125 |
+
|
| 126 |
+
return tool_results
|
| 127 |
+
|
| 128 |
+
async def step(self, context: ExecutionContext):
|
| 129 |
+
print(f"[Step {context.current_step + 1}]")
|
| 130 |
+
llm_request = await self._prepare_llm_request(context)
|
| 131 |
+
llm_response = await self.think(context, llm_request)
|
| 132 |
+
if llm_response.error_message:
|
| 133 |
+
raise RuntimeError(f"LLM error: {llm_response.error_message}")
|
| 134 |
+
response_event = Event(
|
| 135 |
+
execution_id=context.execution_id,
|
| 136 |
+
author=self.name,
|
| 137 |
+
required_output_tool=self.output_tool or None,
|
| 138 |
+
**llm_response.model_dump(),
|
| 139 |
+
)
|
| 140 |
+
context.add_event(response_event)
|
| 141 |
+
|
| 142 |
+
if tool_calls := response_event.get_tool_calls():
|
| 143 |
+
tool_results = await self.act(context, tool_calls)
|
| 144 |
+
tool_results_event = Event(
|
| 145 |
+
execution_id=context.execution_id,
|
| 146 |
+
author=self.name,
|
| 147 |
+
required_output_tool=self.output_tool or None,
|
| 148 |
+
content=tool_results,
|
| 149 |
+
)
|
| 150 |
+
context.add_event(tool_results_event)
|
| 151 |
+
|
| 152 |
+
context.increment_step()
|
| 153 |
+
|
| 154 |
+
async def run(self, user_input: str,
|
| 155 |
+
user_id: str = None,
|
| 156 |
+
session_id: str = None):
|
| 157 |
+
session = self.session_manager.get_or_create_session(session_id, user_id)
|
| 158 |
+
context = ExecutionContext(
|
| 159 |
+
user_input=user_input,
|
| 160 |
+
session=session,
|
| 161 |
+
session_manager=self.session_manager,
|
| 162 |
+
cross_session_manager=self.cross_session_manager,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
user_input_event = Event(
|
| 166 |
+
execution_id=context.execution_id,
|
| 167 |
+
author="user",
|
| 168 |
+
content=[
|
| 169 |
+
Message(
|
| 170 |
+
role="user",
|
| 171 |
+
content=user_input,
|
| 172 |
+
)
|
| 173 |
+
],
|
| 174 |
+
)
|
| 175 |
+
context.add_event(user_input_event)
|
| 176 |
+
|
| 177 |
+
while not context.final_result and context.current_step < self.max_steps:
|
| 178 |
+
await self.step(context)
|
| 179 |
+
|
| 180 |
+
last_event = context.events[-1]
|
| 181 |
+
if last_event.is_final_response():
|
| 182 |
+
context.final_result = self._extract_final_result(last_event)
|
| 183 |
+
|
| 184 |
+
for callback in self.after_run_callbacks:
|
| 185 |
+
result = callback(context)
|
| 186 |
+
if inspect.isawaitable(result):
|
| 187 |
+
await result
|
| 188 |
+
|
| 189 |
+
return context.final_result
|
| 190 |
+
|
| 191 |
+
async def _prepare_llm_request(self, context: ExecutionContext):
|
| 192 |
+
flat_contents = []
|
| 193 |
+
for event in context.events:
|
| 194 |
+
flat_contents.extend(event.content)
|
| 195 |
+
|
| 196 |
+
llm_request = LlmRequest(
|
| 197 |
+
instructions=[self.instructions] if self.instructions else [],
|
| 198 |
+
contents=flat_contents,
|
| 199 |
+
tools_dict={tool.name:tool for tool in self.tools.values() if tool.tool_definition},
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
for tool in self.tools.values():
|
| 203 |
+
await tool.process_llm_request(llm_request, context)
|
| 204 |
+
|
| 205 |
+
if self.output_tool:
|
| 206 |
+
llm_request.tool_choice = "required"
|
| 207 |
+
elif llm_request.tools_dict:
|
| 208 |
+
llm_request.tool_choice = "auto"
|
| 209 |
+
else:
|
| 210 |
+
llm_request.tool_choice = None
|
| 211 |
+
|
| 212 |
+
return llm_request
|
| 213 |
+
|
| 214 |
+
def _extract_final_result(self, event: Event):
|
| 215 |
+
if event.required_output_tool:
|
| 216 |
+
for item in event.content:
|
| 217 |
+
if (
|
| 218 |
+
isinstance(item, ToolResult)
|
| 219 |
+
and item.status == "success"
|
| 220 |
+
and item.name == event.required_output_tool
|
| 221 |
+
and item.content
|
| 222 |
+
):
|
| 223 |
+
return item.content[0]
|
| 224 |
+
for item in event.content:
|
| 225 |
+
if isinstance(item, Message) and item.role == "assistant":
|
| 226 |
+
return item.content
|
scratch_agents/memory/base_memory_strategy.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
|
| 3 |
+
class MemoryStrategy(ABC):
|
| 4 |
+
"""Base class for memory management strategies"""
|
| 5 |
+
|
| 6 |
+
@abstractmethod
|
| 7 |
+
async def apply(self, context, llm_request): #A
|
| 8 |
+
"""Apply memory management strategy to the request"""
|
| 9 |
+
pass
|
| 10 |
+
|
| 11 |
+
async def __call__(self, context, llm_request): #B
|
| 12 |
+
"""Make strategy callable as a before_llm_callback"""
|
| 13 |
+
return await self.apply(context, llm_request)
|
scratch_agents/memory/core_memory_strategy.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_memory_strategy import MemoryStrategy
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class CoreMemoryStrategy(MemoryStrategy):
|
| 5 |
+
"""Automatically add core memory to LLM context"""
|
| 6 |
+
|
| 7 |
+
async def apply(self, context, llm_request):
|
| 8 |
+
"""Add core memory as instructions if it exists"""
|
| 9 |
+
core_memory = context.session.core_memory
|
| 10 |
+
|
| 11 |
+
memory_parts = []
|
| 12 |
+
if core_memory.get("agent"):
|
| 13 |
+
memory_parts.append(f"[Your Persona]\n{core_memory['agent']}")
|
| 14 |
+
if core_memory.get("user"):
|
| 15 |
+
memory_parts.append(f"[User Info]\n{core_memory['user']}")
|
| 16 |
+
|
| 17 |
+
if memory_parts:
|
| 18 |
+
memory_text = "\n\n".join(memory_parts)
|
| 19 |
+
llm_request.add_instructions([memory_text])
|
| 20 |
+
|
| 21 |
+
return None
|
scratch_agents/memory/sliding_window_strategy.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_memory_strategy import MemoryStrategy
|
| 2 |
+
from ..models.llm_request import LlmRequest
|
| 3 |
+
from ..agents.execution_context_ch6 import ExecutionContext
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SlidingWindowStrategy(MemoryStrategy):
|
| 7 |
+
"""Keep only the most recent N messages in context"""
|
| 8 |
+
|
| 9 |
+
def __init__(self, max_messages: int = 20):
|
| 10 |
+
self.max_messages = max_messages
|
| 11 |
+
|
| 12 |
+
async def apply(self, context: ExecutionContext, llm_request: LlmRequest):
|
| 13 |
+
"""Apply sliding window to conversation history"""
|
| 14 |
+
contents = llm_request.contents
|
| 15 |
+
|
| 16 |
+
if len(contents) <= self.max_messages:
|
| 17 |
+
return None
|
| 18 |
+
|
| 19 |
+
# Keep only recent messages
|
| 20 |
+
recent_contents = contents[-self.max_messages:]
|
| 21 |
+
llm_request.contents = recent_contents
|
| 22 |
+
|
| 23 |
+
print(f"Trimmed messages")
|
| 24 |
+
print(f"from {len(contents)} to {self.max_messages}")
|
| 25 |
+
|
| 26 |
+
return None
|
scratch_agents/memory/summarization_strategy.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_memory_strategy import MemoryStrategy
|
| 2 |
+
from ..models.llm_request import LlmRequest
|
| 3 |
+
from ..types.contents import Message
|
| 4 |
+
|
| 5 |
+
class SummarizationStrategy(MemoryStrategy):
|
| 6 |
+
"""Summarize old messages to preserve information while reducing tokens"""
|
| 7 |
+
|
| 8 |
+
def __init__(self, model, trigger_count: int = 10, keep_recent: int = 3):
|
| 9 |
+
self.model = model
|
| 10 |
+
self.trigger_count = trigger_count #A
|
| 11 |
+
self.keep_recent = keep_recent #B
|
| 12 |
+
|
| 13 |
+
async def _generate_summary(self, messages_text: str):
|
| 14 |
+
request = LlmRequest(
|
| 15 |
+
instructions=[ #A
|
| 16 |
+
"Summarize the following conversation concisely.", #A
|
| 17 |
+
"Preserve key facts, decisions, and important context.", #A
|
| 18 |
+
"Keep the summary under 200 words." #A
|
| 19 |
+
],
|
| 20 |
+
contents=[Message(role="user", content=messages_text)] #B
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
response = await self.model.generate(request) #C
|
| 24 |
+
|
| 25 |
+
for item in response.content: #D
|
| 26 |
+
if isinstance(item, Message) and item.role == "assistant": #D
|
| 27 |
+
return item.content #D
|
| 28 |
+
|
| 29 |
+
return "Summary generation failed" #E
|
| 30 |
+
|
| 31 |
+
async def apply(self, context, llm_request):
|
| 32 |
+
"""Apply summarization when new messages since last summary exceed threshold"""
|
| 33 |
+
contents = llm_request.contents
|
| 34 |
+
|
| 35 |
+
messages_only = [item for item in contents if isinstance(item, Message)] #A
|
| 36 |
+
last_summarized = context.state.get("last_summarized_index", 0)
|
| 37 |
+
|
| 38 |
+
total_messages = len(messages_only) #B
|
| 39 |
+
new_messages_count = total_messages - last_summarized #B
|
| 40 |
+
|
| 41 |
+
if new_messages_count < self.trigger_count:
|
| 42 |
+
return None
|
| 43 |
+
|
| 44 |
+
summarize_until = total_messages - self.keep_recent #C
|
| 45 |
+
to_summarize = messages_only[last_summarized:summarize_until] #C
|
| 46 |
+
to_keep = contents[-self.keep_recent:] if len(contents) >= self.keep_recent else contents #C
|
| 47 |
+
|
| 48 |
+
if not to_summarize:
|
| 49 |
+
return None
|
| 50 |
+
|
| 51 |
+
existing_summary = context.state.get("conversation_summary")
|
| 52 |
+
|
| 53 |
+
summary_input = [] #D
|
| 54 |
+
if existing_summary: #D
|
| 55 |
+
summary_input.append(f"Previous Summary:\n{existing_summary}\n") #D
|
| 56 |
+
|
| 57 |
+
summary_input.append("New Messages to Summarize:\n") #D
|
| 58 |
+
for msg in to_summarize: #D
|
| 59 |
+
summary_input.append(f"{msg.role}: {msg.content}") #D
|
| 60 |
+
|
| 61 |
+
messages_text = "\n".join(summary_input) #D
|
| 62 |
+
|
| 63 |
+
new_summary = await self._generate_summary(messages_text) #E
|
| 64 |
+
|
| 65 |
+
context.state["conversation_summary"] = new_summary
|
| 66 |
+
context.state["last_summarized_index"] = summarize_until
|
| 67 |
+
|
| 68 |
+
if new_summary:
|
| 69 |
+
summary_instruction = f"[Previous Conversation Summary]\n{new_summary}"
|
| 70 |
+
llm_request.add_instructions([summary_instruction]) #F
|
| 71 |
+
|
| 72 |
+
llm_request.contents = to_keep #G
|
| 73 |
+
|
| 74 |
+
print(f"Compressed {len(to_summarize)} messages")
|
| 75 |
+
print(f"Keeping {len(to_keep)} recent items")
|
| 76 |
+
|
| 77 |
+
return None
|
scratch_agents/models/openai.py
CHANGED
|
@@ -4,7 +4,8 @@ from .llm_request import LlmRequest
|
|
| 4 |
from .llm_response import LlmResponse
|
| 5 |
from ..types.contents import Message, ToolCall, ToolResult
|
| 6 |
import json
|
| 7 |
-
from pydantic import Field
|
|
|
|
| 8 |
|
| 9 |
class OpenAILlm(BaseLlm):
|
| 10 |
"""OpenAI LLM implementation"""
|
|
@@ -136,7 +137,7 @@ class OpenAILlm(BaseLlm):
|
|
| 136 |
messages.append({
|
| 137 |
"role": "tool",
|
| 138 |
"tool_call_id": item.tool_call_id,
|
| 139 |
-
"content": str(item.content
|
| 140 |
})
|
| 141 |
|
| 142 |
# Flush any remaining assistant message
|
|
@@ -145,4 +146,29 @@ class OpenAILlm(BaseLlm):
|
|
| 145 |
# Extract model parameters
|
| 146 |
model_params = {**self.llm_config}
|
| 147 |
|
| 148 |
-
return messages, model_params
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from .llm_response import LlmResponse
|
| 5 |
from ..types.contents import Message, ToolCall, ToolResult
|
| 6 |
import json
|
| 7 |
+
from pydantic import Field, BaseModel
|
| 8 |
+
from typing import Dict, Any, List
|
| 9 |
|
| 10 |
class OpenAILlm(BaseLlm):
|
| 11 |
"""OpenAI LLM implementation"""
|
|
|
|
| 137 |
messages.append({
|
| 138 |
"role": "tool",
|
| 139 |
"tool_call_id": item.tool_call_id,
|
| 140 |
+
"content": str(item.content) if item.content else ""
|
| 141 |
})
|
| 142 |
|
| 143 |
# Flush any remaining assistant message
|
|
|
|
| 146 |
# Extract model parameters
|
| 147 |
model_params = {**self.llm_config}
|
| 148 |
|
| 149 |
+
return messages, model_params
|
| 150 |
+
|
| 151 |
+
async def generate_structured(self, messages: List[Dict[str, Any]], response_format: BaseModel):
|
| 152 |
+
"""Generate structured output using OpenAI's response_format"""
|
| 153 |
+
try:
|
| 154 |
+
response = await self.openai_client.chat.completions.parse(
|
| 155 |
+
model=self.model,
|
| 156 |
+
messages=messages,
|
| 157 |
+
response_format=response_format,
|
| 158 |
+
**self.llm_config
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
return response.choices[0].message.parsed
|
| 162 |
+
except Exception as e:
|
| 163 |
+
return {"error": str(e)}
|
| 164 |
+
|
| 165 |
+
async def embed(self, model, texts: List[str]) -> List[List[float]]:
|
| 166 |
+
"""Get embeddings using OpenAI API"""
|
| 167 |
+
try:
|
| 168 |
+
response = await self.openai_client.embeddings.create(
|
| 169 |
+
model=model,
|
| 170 |
+
input=texts
|
| 171 |
+
)
|
| 172 |
+
return [embedding.embedding for embedding in response.data]
|
| 173 |
+
except Exception as e:
|
| 174 |
+
return {"error": str(e)}
|
scratch_agents/sessions/base_cross_session_manager.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Base class for cross-session memory management."""
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import List, Dict, Optional, Any
|
| 5 |
+
import chromadb
|
| 6 |
+
from chromadb.utils import embedding_functions
|
| 7 |
+
from chromadb.config import Settings
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
import logging
|
| 10 |
+
import os
|
| 11 |
+
import uuid
|
| 12 |
+
from .session import Session
|
| 13 |
+
|
| 14 |
+
logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseCrossSessionManager(ABC):
|
| 18 |
+
"""Abstract base class for cross-session memory management."""
|
| 19 |
+
|
| 20 |
+
def __init__(
|
| 21 |
+
self,
|
| 22 |
+
model,
|
| 23 |
+
collection_name: str,
|
| 24 |
+
persist_directory: str = "./cross_session_db",
|
| 25 |
+
embedding_model: str = "text-embedding-3-small"
|
| 26 |
+
):
|
| 27 |
+
"""Initialize the base cross-session manager.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
model: LLM model for memory processing
|
| 31 |
+
collection_name: Name of the ChromaDB collection
|
| 32 |
+
persist_directory: Directory to persist ChromaDB data
|
| 33 |
+
embedding_model: Optional custom embedding model
|
| 34 |
+
"""
|
| 35 |
+
self.model = model
|
| 36 |
+
self.collection_name = collection_name
|
| 37 |
+
self.persist_directory = persist_directory
|
| 38 |
+
self.embedding_model = embedding_model
|
| 39 |
+
|
| 40 |
+
self.client = chromadb.PersistentClient(
|
| 41 |
+
path=persist_directory,
|
| 42 |
+
)
|
| 43 |
+
embedding_function = embedding_functions.OpenAIEmbeddingFunction(
|
| 44 |
+
api_key=os.getenv("OPENAI_API_KEY"),
|
| 45 |
+
model_name=self.embedding_model
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Get or create collection
|
| 49 |
+
try:
|
| 50 |
+
self.collection = self.client.get_or_create_collection(
|
| 51 |
+
name=collection_name,
|
| 52 |
+
metadata={"hnsw:space": "cosine"},
|
| 53 |
+
embedding_function=embedding_function
|
| 54 |
+
)
|
| 55 |
+
logger.info(f"Using existing collection: {collection_name}")
|
| 56 |
+
except Exception:
|
| 57 |
+
logger.error(f"Error getting or creating collection: {collection_name}")
|
| 58 |
+
raise
|
| 59 |
+
|
| 60 |
+
@abstractmethod
|
| 61 |
+
async def extract_memories(
|
| 62 |
+
self,
|
| 63 |
+
events: List[Dict[str, Any]],
|
| 64 |
+
) -> List[str]:
|
| 65 |
+
"""Extract memories from session events.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
events: List of session events
|
| 69 |
+
user_id: User identifier
|
| 70 |
+
|
| 71 |
+
Returns:
|
| 72 |
+
List of extracted memory strings
|
| 73 |
+
"""
|
| 74 |
+
pass
|
| 75 |
+
|
| 76 |
+
async def process_session(
|
| 77 |
+
self,
|
| 78 |
+
session: Session,
|
| 79 |
+
execution_id: str
|
| 80 |
+
) -> None:
|
| 81 |
+
"""Process a completed session and extract/merge memories.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
session: Session data containing events
|
| 85 |
+
execution_id: Unique execution identifier
|
| 86 |
+
"""
|
| 87 |
+
try:
|
| 88 |
+
user_id = session.user_id
|
| 89 |
+
events = session.events
|
| 90 |
+
|
| 91 |
+
events = [event for event in events if event.execution_id == execution_id]
|
| 92 |
+
|
| 93 |
+
memories = await self.extract_memories(events)
|
| 94 |
+
|
| 95 |
+
if memories:
|
| 96 |
+
existing = await self.find_existing(memories, user_id)
|
| 97 |
+
actions = await self.decide_actions(memories, existing, user_id)
|
| 98 |
+
await self.execute_memory_actions(actions)
|
| 99 |
+
else:
|
| 100 |
+
logger.info(f"No memories extracted for user {user_id}")
|
| 101 |
+
|
| 102 |
+
except Exception as e:
|
| 103 |
+
logger.error(f"Error processing session: {e}")
|
| 104 |
+
|
| 105 |
+
async def find_existing(
|
| 106 |
+
self,
|
| 107 |
+
memories: List[str],
|
| 108 |
+
user_id: str
|
| 109 |
+
) -> List[Dict[str, Any]]:
|
| 110 |
+
"""Find existing memories.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
memories: List of new memory strings to merge
|
| 114 |
+
user_id: User identifier
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
List of existing memories with metadata
|
| 118 |
+
"""
|
| 119 |
+
existing_memories = []
|
| 120 |
+
for memory in memories:
|
| 121 |
+
existing = await self.search(memory, user_id)
|
| 122 |
+
if existing:
|
| 123 |
+
existing_memories.append(existing)
|
| 124 |
+
return existing_memories
|
| 125 |
+
|
| 126 |
+
@abstractmethod
|
| 127 |
+
async def decide_actions(
|
| 128 |
+
self,
|
| 129 |
+
memories: List[str],
|
| 130 |
+
existing: List[Dict[str, Any]],
|
| 131 |
+
user_id: str
|
| 132 |
+
) -> List[Dict[str, Any]]:
|
| 133 |
+
"""Decide actions for new memories."""
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
async def execute_memory_actions(
|
| 137 |
+
self,
|
| 138 |
+
actions: List[Dict[str, Any]]
|
| 139 |
+
) -> None:
|
| 140 |
+
"""Execute memory actions."""
|
| 141 |
+
for action in actions:
|
| 142 |
+
if action["action"] == "ADD":
|
| 143 |
+
metadata = action.get("metadata", {})
|
| 144 |
+
await self.add(action["memory"], action["user_id"], action.get("embedding"), metadata)
|
| 145 |
+
elif action["action"] == "UPDATE":
|
| 146 |
+
metadata = action.get("metadata", {})
|
| 147 |
+
await self.update(action["memory_id"], action["memory"], action.get("embedding"), metadata)
|
| 148 |
+
elif action["action"] == "DELETE":
|
| 149 |
+
await self.delete(action["memory_id"])
|
| 150 |
+
elif action["action"] == "NOOP":
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
async def search(
|
| 154 |
+
self,
|
| 155 |
+
query: str,
|
| 156 |
+
user_id: str,
|
| 157 |
+
limit: int = 5
|
| 158 |
+
) -> List[Dict[str, Any]]:
|
| 159 |
+
"""Search for relevant memories.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
query: Search query
|
| 163 |
+
user_id: User identifier
|
| 164 |
+
limit: Maximum number of results
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
List of relevant memories with metadata
|
| 168 |
+
"""
|
| 169 |
+
try:
|
| 170 |
+
# Filter by user_id in metadata
|
| 171 |
+
where = {"user_id": user_id}
|
| 172 |
+
|
| 173 |
+
results = self.collection.query(
|
| 174 |
+
query_texts=[query],
|
| 175 |
+
n_results=limit,
|
| 176 |
+
where=where
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
memories = []
|
| 180 |
+
if results["documents"] and results["documents"][0]:
|
| 181 |
+
for i, doc in enumerate(results["documents"][0]):
|
| 182 |
+
memory = {
|
| 183 |
+
"id": results["ids"][0][i] if results["ids"] and results["ids"][0] else None,
|
| 184 |
+
"content": doc,
|
| 185 |
+
"metadata": results["metadatas"][0][i] if results["metadatas"] else {},
|
| 186 |
+
"distance": results["distances"][0][i] if results["distances"] else 0
|
| 187 |
+
}
|
| 188 |
+
memories.append(memory)
|
| 189 |
+
|
| 190 |
+
return memories
|
| 191 |
+
|
| 192 |
+
except Exception as e:
|
| 193 |
+
logger.error(f"Error searching memories: {e}")
|
| 194 |
+
return []
|
| 195 |
+
|
| 196 |
+
async def add(
|
| 197 |
+
self,
|
| 198 |
+
memory: str,
|
| 199 |
+
user_id: str,
|
| 200 |
+
embedding: Optional[List[float]] = None,
|
| 201 |
+
additional_metadata: Optional[Dict[str, Any]] = None,
|
| 202 |
+
) -> str:
|
| 203 |
+
"""Add a new memory.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
memory: Memory content (as string for ChromaDB)
|
| 207 |
+
user_id: User identifier
|
| 208 |
+
embedding: Optional embedding vector
|
| 209 |
+
additional_metadata: Additional metadata to store
|
| 210 |
+
|
| 211 |
+
Returns:
|
| 212 |
+
Memory ID
|
| 213 |
+
"""
|
| 214 |
+
memory_id = f"{uuid.uuid4()}"
|
| 215 |
+
|
| 216 |
+
final_metadata = {
|
| 217 |
+
"user_id": user_id,
|
| 218 |
+
"created_at": datetime.now().isoformat(),
|
| 219 |
+
"updated_at": datetime.now().isoformat()
|
| 220 |
+
}
|
| 221 |
+
|
| 222 |
+
# Add any additional metadata (like the original structured data)
|
| 223 |
+
if additional_metadata:
|
| 224 |
+
final_metadata.update(additional_metadata)
|
| 225 |
+
|
| 226 |
+
if embedding:
|
| 227 |
+
self.collection.upsert(
|
| 228 |
+
documents=[memory],
|
| 229 |
+
ids=[memory_id],
|
| 230 |
+
embeddings=[embedding],
|
| 231 |
+
metadatas=[final_metadata]
|
| 232 |
+
)
|
| 233 |
+
else:
|
| 234 |
+
self.collection.add(
|
| 235 |
+
documents=[memory],
|
| 236 |
+
ids=[memory_id],
|
| 237 |
+
metadatas=[final_metadata]
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
return memory_id
|
| 241 |
+
|
| 242 |
+
async def update(
|
| 243 |
+
self,
|
| 244 |
+
memory_id: str,
|
| 245 |
+
memory: str,
|
| 246 |
+
embedding: Optional[List[float]] = None,
|
| 247 |
+
additional_metadata: Optional[Dict[str, Any]] = None,
|
| 248 |
+
) -> None:
|
| 249 |
+
"""Update an existing memory.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
memory_id: ID of memory to update
|
| 253 |
+
memory: New memory content (as string for ChromaDB)
|
| 254 |
+
embedding: Optional embedding of the memory
|
| 255 |
+
additional_metadata: Additional metadata to update
|
| 256 |
+
"""
|
| 257 |
+
if not memory_id:
|
| 258 |
+
logger.error("Cannot update memory: memory_id is None")
|
| 259 |
+
return
|
| 260 |
+
|
| 261 |
+
# Get existing metadata
|
| 262 |
+
existing = self.collection.get(ids=[memory_id])
|
| 263 |
+
if existing["metadatas"] and existing["metadatas"][0]:
|
| 264 |
+
final_metadata = existing["metadatas"][0].copy()
|
| 265 |
+
final_metadata["updated_at"] = datetime.now().isoformat()
|
| 266 |
+
else:
|
| 267 |
+
final_metadata = {}
|
| 268 |
+
final_metadata["updated_at"] = datetime.now().isoformat()
|
| 269 |
+
|
| 270 |
+
# Update with any additional metadata
|
| 271 |
+
if additional_metadata:
|
| 272 |
+
final_metadata.update(additional_metadata)
|
| 273 |
+
|
| 274 |
+
if embedding:
|
| 275 |
+
self.collection.upsert(
|
| 276 |
+
ids=[memory_id],
|
| 277 |
+
documents=[memory],
|
| 278 |
+
embeddings=[embedding],
|
| 279 |
+
metadatas=[final_metadata]
|
| 280 |
+
)
|
| 281 |
+
else:
|
| 282 |
+
self.collection.upsert(
|
| 283 |
+
ids=[memory_id],
|
| 284 |
+
documents=[memory],
|
| 285 |
+
metadatas=[final_metadata]
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
async def delete(
|
| 289 |
+
self,
|
| 290 |
+
memory_id: str
|
| 291 |
+
) -> None:
|
| 292 |
+
"""Delete a memory.
|
| 293 |
+
|
| 294 |
+
Args:
|
| 295 |
+
memory_id: ID of memory to delete
|
| 296 |
+
"""
|
| 297 |
+
self.collection.delete(ids=[memory_id])
|
scratch_agents/sessions/base_session_manager.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Optional, Dict, List
|
| 3 |
+
from scratch_agents.types.events import Event
|
| 4 |
+
from scratch_agents.sessions.session import Session
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class BaseSessionManager(ABC):
|
| 8 |
+
"""Abstract base class for session management"""
|
| 9 |
+
|
| 10 |
+
@abstractmethod
|
| 11 |
+
def create_session(self, session_id: Optional[str] = None, user_id: str = None) -> Session:
|
| 12 |
+
"""Create a new session"""
|
| 13 |
+
pass
|
| 14 |
+
|
| 15 |
+
@abstractmethod
|
| 16 |
+
def get_session(self, session_id: str) -> Optional[Session]:
|
| 17 |
+
"""Load a session from storage"""
|
| 18 |
+
pass
|
| 19 |
+
|
| 20 |
+
@abstractmethod
|
| 21 |
+
def get_or_create_session(self, session_id: str, user_id: str = None) -> Session:
|
| 22 |
+
"""Get an existing session or create a new one"""
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
def add_event(self, session: Session, event: Event) -> None:
|
| 27 |
+
"""Add an event to the session"""
|
| 28 |
+
pass
|
scratch_agents/sessions/in_memory_session_manager.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base_session_manager import BaseSessionManager
|
| 2 |
+
from .session import Session
|
| 3 |
+
from scratch_agents.types.events import Event
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
class InMemorySessionManager(BaseSessionManager):
|
| 7 |
+
"""In-memory session manager"""
|
| 8 |
+
|
| 9 |
+
def __init__(self):
|
| 10 |
+
self.sessions = {}
|
| 11 |
+
|
| 12 |
+
def create_session(self, session_id: str, user_id: str = None) -> Session:
|
| 13 |
+
if session_id in self.sessions:
|
| 14 |
+
raise ValueError(f"Session with id {session_id} already exists")
|
| 15 |
+
self.sessions[session_id] = Session(session_id=session_id, user_id=user_id)
|
| 16 |
+
return self.sessions[session_id]
|
| 17 |
+
|
| 18 |
+
def get_session(self, session_id: str) -> Session:
|
| 19 |
+
if session_id not in self.sessions:
|
| 20 |
+
raise ValueError(f"Session with id {session_id} does not exist")
|
| 21 |
+
return self.sessions[session_id]
|
| 22 |
+
|
| 23 |
+
def get_or_create_session(self, session_id: str, user_id: str = None) -> Session:
|
| 24 |
+
if session_id not in self.sessions:
|
| 25 |
+
return self.create_session(session_id, user_id)
|
| 26 |
+
return self.sessions[session_id]
|
| 27 |
+
|
| 28 |
+
def add_event(self, session: Session, event: Event) -> None:
|
| 29 |
+
session.events.append(event)
|
| 30 |
+
session.last_updated_at = datetime.now()
|
scratch_agents/sessions/session.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uuid
|
| 2 |
+
from pydantic import BaseModel, Field
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from ..types.contents import ContentItem
|
| 6 |
+
|
| 7 |
+
class Session(BaseModel):
|
| 8 |
+
"""Container for short-term memory during a conversation session"""
|
| 9 |
+
user_id: str
|
| 10 |
+
session_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
| 11 |
+
events: List[ContentItem] = Field(default_factory=list)
|
| 12 |
+
state: Dict[str, Any] = Field(default_factory=dict)
|
| 13 |
+
last_updated_at: datetime = Field(default_factory=datetime.now)
|
| 14 |
+
|
| 15 |
+
@property
|
| 16 |
+
def core_memory(self) -> Dict[str, str]:
|
| 17 |
+
"""Access core memory with automatic initialization"""
|
| 18 |
+
if "core_memory" not in self.state:
|
| 19 |
+
self.state["core_memory"] = {
|
| 20 |
+
"persona": "You are a helpful AI assistant",
|
| 21 |
+
"human": ""
|
| 22 |
+
}
|
| 23 |
+
return self.state["core_memory"]
|
scratch_agents/sessions/task_cross_session_manager.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task-specific cross-session memory management."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Dict, Any, Optional, Literal
|
| 4 |
+
import logging
|
| 5 |
+
from pydantic import BaseModel, Field
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
from .base_cross_session_manager import BaseCrossSessionManager
|
| 9 |
+
from ..types.events import Event
|
| 10 |
+
from ..types.contents import Message, ToolCall, ToolResult
|
| 11 |
+
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
MEMORY_EXTRACT_PROMPT = """
|
| 15 |
+
You are a Task Memory Extractor specializing in tracking agent actions and problem-solving attempts.
|
| 16 |
+
Extract ONLY information about what the agent ACTUALLY DID in this conversation.
|
| 17 |
+
|
| 18 |
+
Focus on:
|
| 19 |
+
1. **Problem Identification**: What issue or challenge was the agent trying to address?
|
| 20 |
+
2. **Actions Taken**: What specific actions did the agent perform? (tools used, searches made, code written, etc.)
|
| 21 |
+
3. **Key Discoveries**: What important facts or information did the agent discover during the process?
|
| 22 |
+
4. **Success Status**: Was the task completed successfully?
|
| 23 |
+
|
| 24 |
+
DO NOT extract:
|
| 25 |
+
- Personal user information (name, preferences, etc.)
|
| 26 |
+
- General conversation or greetings
|
| 27 |
+
- User opinions or feelings
|
| 28 |
+
- Future plans or what should be done
|
| 29 |
+
|
| 30 |
+
Format each task as a structured memory with:
|
| 31 |
+
- problem: Clear description of what the agent was asked to do or investigate
|
| 32 |
+
- actions_taken: Specific actions the agent performed (not what it should do)
|
| 33 |
+
- key_discoveries: Important information discovered during the task
|
| 34 |
+
- success: true/false indicating if the task was completed
|
| 35 |
+
|
| 36 |
+
Examples of GOOD task memories:
|
| 37 |
+
{
|
| 38 |
+
"problem": "User asked about React component not rendering",
|
| 39 |
+
"actions_taken": "Examined useEffect hook, identified missing dependency in array, added state variable to dependency array",
|
| 40 |
+
"key_discoveries": "useEffect was missing 'count' state variable in dependency array causing stale closure",
|
| 41 |
+
"success": true
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
{
|
| 45 |
+
"problem": "User reported database connection timeouts in production",
|
| 46 |
+
"actions_taken": "Checked connection pool configuration, analyzed production logs, increased pool size from 10 to 50, implemented retry logic with exponential backoff",
|
| 47 |
+
"key_discoveries": "Production load peaked at 45 concurrent connections, default pool size was only 10",
|
| 48 |
+
"success": true
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
{
|
| 52 |
+
"problem": "User asked 'What is Mem0 and how does it work?'",
|
| 53 |
+
"actions_taken": "Performed multiple web searches with different query variations to find information about Mem0",
|
| 54 |
+
"key_discoveries": "Found that Mem0 is an open-source memory layer for LLM applications, has a GitHub repo (mem0ai/mem0), provides hybrid data storage and intelligent retrieval",
|
| 55 |
+
"success": false
|
| 56 |
+
}
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
MEMORY_ACTION_PROMPT = """
|
| 60 |
+
You are a Task Memory Action Decider specializing in tracking agent actions and problem-solving attempts.
|
| 61 |
+
You are given a list of new task memories and a list of existing task memories.
|
| 62 |
+
You need to decide whether to ADD, UPDATE, DELETE, or NOOP the new task memories.
|
| 63 |
+
|
| 64 |
+
Format your response as a list of actions with:
|
| 65 |
+
- action: ADD, UPDATE, DELETE, or NOOP
|
| 66 |
+
- memory_id: The id of the memory to update or delete
|
| 67 |
+
|
| 68 |
+
Action:
|
| 69 |
+
- ADD: Add the new task memory if it describes a different problem or significantly different approach
|
| 70 |
+
- UPDATE: Update the existing task memory if it's the same problem but with better/more complete actions or discoveries
|
| 71 |
+
- DELETE: Delete the existing task memory if it's outdated or no longer relevant
|
| 72 |
+
- NOOP: Do not add if it's essentially the same problem with similar actions and discoveries
|
| 73 |
+
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TaskMemory(BaseModel):
|
| 78 |
+
"""Structured task memory."""
|
| 79 |
+
problem: str = Field(description="The problem or task the agent was asked to address")
|
| 80 |
+
actions_taken: str = Field(description="The specific actions the agent performed")
|
| 81 |
+
success: bool = Field(description="Whether the task was completed successfully")
|
| 82 |
+
key_discoveries: Optional[str] = Field(default=None, description="Important information discovered during the task")
|
| 83 |
+
|
| 84 |
+
class MemoryAction(BaseModel):
|
| 85 |
+
"""Memory action."""
|
| 86 |
+
action: Literal["ADD", "UPDATE", "DELETE", "NOOP"] = Field(description="The action to take with the memory")
|
| 87 |
+
memory_id: Optional[str] = Field(description="The id of the memory to update or delete")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class TaskCrossSessionManager(BaseCrossSessionManager):
|
| 91 |
+
"""Manage task-specific memories across sessions."""
|
| 92 |
+
|
| 93 |
+
def __init__(self, model,
|
| 94 |
+
collection_name="task_memories",
|
| 95 |
+
persist_directory="./cross_session_db",
|
| 96 |
+
):
|
| 97 |
+
"""Initialize task cross-session manager.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
model: LLM model for memory extraction
|
| 101 |
+
collection_name: Name of the ChromaDB collection
|
| 102 |
+
persist_directory: Directory to persist ChromaDB data
|
| 103 |
+
"""
|
| 104 |
+
super().__init__(model, collection_name, persist_directory)
|
| 105 |
+
|
| 106 |
+
async def extract_memories(self, events: List[Event]):
|
| 107 |
+
conversation_parts = []
|
| 108 |
+
|
| 109 |
+
for event in events:
|
| 110 |
+
for item in event.content:
|
| 111 |
+
if isinstance(item, Message):
|
| 112 |
+
conversation_parts.append(f"{item.role}: {item.content}")
|
| 113 |
+
elif isinstance(item, ToolCall):
|
| 114 |
+
conversation_parts.append(f"{item.tool_call_id}: {item.name}")
|
| 115 |
+
elif isinstance(item, ToolResult):
|
| 116 |
+
conversation_parts.append(f"{item.tool_call_id}: {item.name} {item.content}")
|
| 117 |
+
|
| 118 |
+
conversation = "\n".join(conversation_parts)
|
| 119 |
+
|
| 120 |
+
user_prompt = f"""Conversation:
|
| 121 |
+
{conversation}
|
| 122 |
+
"""
|
| 123 |
+
messages = [
|
| 124 |
+
{"role": "system", "content": MEMORY_EXTRACT_PROMPT},
|
| 125 |
+
{"role": "user", "content": user_prompt}
|
| 126 |
+
]
|
| 127 |
+
|
| 128 |
+
try:
|
| 129 |
+
response = await self.model.generate_structured(messages, TaskMemory)
|
| 130 |
+
task_memory = TaskMemory.model_validate(response)
|
| 131 |
+
return [task_memory.model_dump()]
|
| 132 |
+
|
| 133 |
+
except Exception as e:
|
| 134 |
+
logger.error(f"Error extracting task memories: {e}")
|
| 135 |
+
return []
|
| 136 |
+
|
| 137 |
+
async def find_existing(self, memories: List[Dict], user_id: str) -> List[Dict[str, Any]]:
|
| 138 |
+
existing_memories = []
|
| 139 |
+
for memory in memories:
|
| 140 |
+
query = memory["problem"]
|
| 141 |
+
results = await self.search(query, user_id)
|
| 142 |
+
if results:
|
| 143 |
+
existing_memories.append(results[0])
|
| 144 |
+
return existing_memories
|
| 145 |
+
|
| 146 |
+
async def decide_actions(self, new_memory: List[Dict], existing: List[Dict[str, Any]], user_id: str) -> List[Dict[str, Any]]:
|
| 147 |
+
system_prompt = MEMORY_ACTION_PROMPT
|
| 148 |
+
user_prompt = f"""
|
| 149 |
+
Existing memory: {existing}
|
| 150 |
+
New memory: {new_memory}
|
| 151 |
+
"""
|
| 152 |
+
messages = [
|
| 153 |
+
{"role": "system", "content": system_prompt},
|
| 154 |
+
{"role": "user", "content": user_prompt}
|
| 155 |
+
]
|
| 156 |
+
action = await self.model.generate_structured(messages, MemoryAction)
|
| 157 |
+
result = []
|
| 158 |
+
if action.action == "UPDATE":
|
| 159 |
+
|
| 160 |
+
memory_id = action.memory_id
|
| 161 |
+
if not memory_id:
|
| 162 |
+
logger.error("Cannot update memory: no memory_id available")
|
| 163 |
+
return []
|
| 164 |
+
embeddings = await self.model.embed(self.embedding_model, [new_memory[0]["problem"]])
|
| 165 |
+
# Convert dict to string for ChromaDB document field
|
| 166 |
+
memory_str = json.dumps(new_memory[0], ensure_ascii=False)
|
| 167 |
+
result.append({
|
| 168 |
+
"action": "UPDATE",
|
| 169 |
+
"memory_id": memory_id,
|
| 170 |
+
"memory": memory_str,
|
| 171 |
+
"embedding": embeddings[0],
|
| 172 |
+
"metadata": new_memory[0] # Store original dict in metadata
|
| 173 |
+
})
|
| 174 |
+
elif action.action == "ADD":
|
| 175 |
+
embeddings = await self.model.embed(self.embedding_model, [new_memory[0]["problem"]])
|
| 176 |
+
# Convert dict to string for ChromaDB document field
|
| 177 |
+
memory_str = json.dumps(new_memory[0], ensure_ascii=False)
|
| 178 |
+
result.append({
|
| 179 |
+
"action": "ADD",
|
| 180 |
+
"memory": memory_str,
|
| 181 |
+
"user_id": user_id,
|
| 182 |
+
"embedding": embeddings[0],
|
| 183 |
+
"metadata": new_memory[0] # Store original dict in metadata
|
| 184 |
+
})
|
| 185 |
+
elif action.action == "DELETE":
|
| 186 |
+
result.append({
|
| 187 |
+
"action": "DELETE",
|
| 188 |
+
"memory_id": action.memory_id
|
| 189 |
+
})
|
| 190 |
+
elif action.action == "NOOP":
|
| 191 |
+
result.append({
|
| 192 |
+
"action": "NOOP"
|
| 193 |
+
})
|
| 194 |
+
return result
|
scratch_agents/sessions/user_cross_session_manager.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import chromadb
|
| 3 |
+
from chromadb.utils import embedding_functions
|
| 4 |
+
from typing import List, Optional, Literal, Dict, Any
|
| 5 |
+
from enum import Enum
|
| 6 |
+
import uuid
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
import os
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from .session import Session
|
| 13 |
+
from .base_cross_session_manager import BaseCrossSessionManager
|
| 14 |
+
from ..types.contents import Message
|
| 15 |
+
from ..types.events import Event
|
| 16 |
+
from ..models.llm_request import LlmRequest
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
+
|
| 20 |
+
MEMORY_EXTRACT_PROMPT = """
|
| 21 |
+
You are a User Memory Extractor specializing in accurately storing ONLY facts about the USER from their messages.
|
| 22 |
+
|
| 23 |
+
CRITICAL RULES:
|
| 24 |
+
1. ONLY extract factual information that the user explicitly states about themselves
|
| 25 |
+
2. NEVER extract questions the user asks
|
| 26 |
+
3. NEVER extract hypothetical scenarios or wishes
|
| 27 |
+
4. NEVER create memories from assistant responses
|
| 28 |
+
5. If the user is only asking questions, return an empty list
|
| 29 |
+
|
| 30 |
+
Types of Information to Remember:
|
| 31 |
+
|
| 32 |
+
1. **Personal Identity & Details**: Names, relationships, family information, important dates
|
| 33 |
+
2. **Professional Information**: Current job title, company name, work responsibilities, career goals, past work experience
|
| 34 |
+
3. **Personal Preferences**: Likes, dislikes, preferences in food, activities, entertainment, brands
|
| 35 |
+
4. **Goals & Plans**: Future intentions, upcoming events, trips, personal objectives
|
| 36 |
+
5. **Health & Wellness**: Dietary restrictions, fitness routines, health conditions
|
| 37 |
+
6. **Lifestyle & Activities**: Hobbies, regular activities, service preferences
|
| 38 |
+
7. **Location & Living Situation**: Where they live, recent moves, living arrangements
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
MEMORY_ACTION_PROMPT = """
|
| 42 |
+
You are a User Memory Action Decider specializing in accurately managing user facts and preferences.
|
| 43 |
+
|
| 44 |
+
CRITICAL RULES FOR CONFLICTING INFORMATION:
|
| 45 |
+
1. When new information CONTRADICTS or UPDATES existing information, you MUST use UPDATE action
|
| 46 |
+
2. Location changes: If user moves from Place A to Place B, UPDATE the existing location memory
|
| 47 |
+
3. Status changes: If user changes jobs, relationships, or any status, UPDATE the relevant memory
|
| 48 |
+
4. Preference changes: If user's preferences change, UPDATE the existing preference
|
| 49 |
+
5. Look for semantic conflicts, not just exact text matches
|
| 50 |
+
|
| 51 |
+
Examples of when to UPDATE:
|
| 52 |
+
- Existing: "User works at Company A" + New: "User works at Company B" → UPDATE existing memory
|
| 53 |
+
- Existing: "User likes coffee" + New: "User doesn't like coffee anymore" → UPDATE existing memory
|
| 54 |
+
|
| 55 |
+
Format your response as a list of actions with:
|
| 56 |
+
- action: ADD, UPDATE, DELETE, or NOOP
|
| 57 |
+
- memory_id: The id of the memory to update or delete (required for UPDATE/DELETE)
|
| 58 |
+
- content: The content of the memory to add or update (required for ADD/UPDATE)
|
| 59 |
+
|
| 60 |
+
Actions:
|
| 61 |
+
- ADD: Add new information that doesn't conflict with existing memories
|
| 62 |
+
- UPDATE: Replace existing memory when there's conflicting or updated information
|
| 63 |
+
- DELETE: Remove outdated or incorrect memory (use sparingly)
|
| 64 |
+
- NOOP: Skip if the information is already stored or not relevant
|
| 65 |
+
"""
|
| 66 |
+
|
| 67 |
+
class MemoryAction(BaseModel):
|
| 68 |
+
"""Structured output for memory action decision"""
|
| 69 |
+
action: Literal["ADD", "UPDATE", "DELETE", "NOOP"] = Field(
|
| 70 |
+
description="The action to take with the memory"
|
| 71 |
+
)
|
| 72 |
+
memory_id: Optional[str] = Field(
|
| 73 |
+
description="The id of the memory to update or delete"
|
| 74 |
+
)
|
| 75 |
+
content: Optional[str] = Field(
|
| 76 |
+
description="The content of the memory to add or update"
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
class MemoryActions(BaseModel):
|
| 80 |
+
"""A list of memory actions"""
|
| 81 |
+
actions: List[MemoryAction] = Field(
|
| 82 |
+
description="A list of memory actions"
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
class MemoryFacts(BaseModel):
|
| 86 |
+
"""A list of facts about the user"""
|
| 87 |
+
facts: List[str] = Field(
|
| 88 |
+
description="A list of facts about the user"
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
class UserCrossSessionManager(BaseCrossSessionManager):
|
| 92 |
+
"""Manage memories across sessions using ChromaDB"""
|
| 93 |
+
|
| 94 |
+
def __init__(self, model, collection_name="user_memory", persist_directory="./cross_session_db", embedding_model="text-embedding-3-small"):
|
| 95 |
+
# Initialize base class first
|
| 96 |
+
super().__init__(model, collection_name, persist_directory, embedding_model)
|
| 97 |
+
|
| 98 |
+
async def extract_memories(self, events: List[Any]) -> List[str]:
|
| 99 |
+
"""Extract important information from execution events using LLM"""
|
| 100 |
+
|
| 101 |
+
conversation_parts = []
|
| 102 |
+
for event in events:
|
| 103 |
+
for item in event.content:
|
| 104 |
+
if hasattr(item, 'role') and hasattr(item, 'content'):
|
| 105 |
+
if item.role == 'user':
|
| 106 |
+
conversation_parts.append(f"User: {item.content}")
|
| 107 |
+
|
| 108 |
+
conversation = "\n".join(conversation_parts)
|
| 109 |
+
|
| 110 |
+
if not conversation.strip():
|
| 111 |
+
return []
|
| 112 |
+
|
| 113 |
+
user_prompt = f"""Conversation:
|
| 114 |
+
{conversation}
|
| 115 |
+
"""
|
| 116 |
+
|
| 117 |
+
messages = [
|
| 118 |
+
{"role": "system", "content": MEMORY_EXTRACT_PROMPT},
|
| 119 |
+
{"role": "user", "content": user_prompt}
|
| 120 |
+
]
|
| 121 |
+
|
| 122 |
+
response = await self.model.generate_structured(
|
| 123 |
+
messages,
|
| 124 |
+
MemoryFacts
|
| 125 |
+
)
|
| 126 |
+
logger.debug(f"Extracted facts: {response}")
|
| 127 |
+
try:
|
| 128 |
+
return response.facts
|
| 129 |
+
except Exception as e:
|
| 130 |
+
logger.error(f"Error extracting facts: {e}")
|
| 131 |
+
return []
|
| 132 |
+
|
| 133 |
+
async def find_existing(
|
| 134 |
+
self,
|
| 135 |
+
memories: List[str],
|
| 136 |
+
user_id: str
|
| 137 |
+
) -> List[Dict[str, Any]]:
|
| 138 |
+
"""Find existing memories.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
user_id: User identifier
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
List of existing memories with metadata including timestamps
|
| 145 |
+
"""
|
| 146 |
+
existing_memories = []
|
| 147 |
+
results = self.collection.get(
|
| 148 |
+
where={"user_id": user_id},
|
| 149 |
+
include=["documents", "metadatas"]
|
| 150 |
+
)
|
| 151 |
+
if results["documents"]:
|
| 152 |
+
for i, doc in enumerate(results["documents"]):
|
| 153 |
+
metadata = results["metadatas"][i] if results["metadatas"] else {}
|
| 154 |
+
existing_memories.append({
|
| 155 |
+
"id": results["ids"][i],
|
| 156 |
+
"content": doc,
|
| 157 |
+
"metadata": metadata,
|
| 158 |
+
"created_at": metadata.get("created_at", "Unknown"),
|
| 159 |
+
"updated_at": metadata.get("updated_at", "Unknown")
|
| 160 |
+
})
|
| 161 |
+
return existing_memories
|
| 162 |
+
|
| 163 |
+
async def decide_actions(self, new_memories: List[str], existing: List[Dict[str, Any]], user_id: str) -> List[Dict[str, Any]]:
|
| 164 |
+
"""Decide actions for new memories."""
|
| 165 |
+
system_prompt = MEMORY_ACTION_PROMPT
|
| 166 |
+
|
| 167 |
+
user_prompt = f"""
|
| 168 |
+
Existing memory: {existing}
|
| 169 |
+
New memory: {new_memories}
|
| 170 |
+
"""
|
| 171 |
+
messages = [
|
| 172 |
+
{"role": "system", "content": system_prompt},
|
| 173 |
+
{"role": "user", "content": user_prompt}
|
| 174 |
+
]
|
| 175 |
+
actions = await self.model.generate_structured(messages, MemoryActions)
|
| 176 |
+
result = []
|
| 177 |
+
for action in actions.actions:
|
| 178 |
+
action_dict = action.model_dump()
|
| 179 |
+
if action_dict["action"] == "ADD":
|
| 180 |
+
action_dict["user_id"] = user_id
|
| 181 |
+
action_dict["memory"] = action_dict.pop("content", None)
|
| 182 |
+
elif action_dict["action"] == "UPDATE":
|
| 183 |
+
action_dict["memory"] = action_dict.pop("content", None)
|
| 184 |
+
result.append(action_dict)
|
| 185 |
+
return result
|
scratch_agents/tools/base_tool.py
CHANGED
|
@@ -2,6 +2,8 @@ from typing import Any, Dict, Type, Union, Optional
|
|
| 2 |
from abc import ABC, abstractmethod
|
| 3 |
import json
|
| 4 |
from .schema_utils import format_tool_definition
|
|
|
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class BaseTool(ABC):
|
|
@@ -11,11 +13,13 @@ class BaseTool(ABC):
|
|
| 11 |
name: str = None,
|
| 12 |
description: str = None,
|
| 13 |
tool_definition: Optional[Union[Dict[str, Any], str]] = None,
|
| 14 |
-
pydantic_input_model: Type = None
|
|
|
|
| 15 |
):
|
| 16 |
self.name = name or self.__class__.__name__
|
| 17 |
self.description = description or self.__doc__ or ""
|
| 18 |
self.pydantic_input_model = pydantic_input_model
|
|
|
|
| 19 |
|
| 20 |
if isinstance(tool_definition, str):
|
| 21 |
self._tool_definition = json.loads(tool_definition)
|
|
@@ -48,7 +52,30 @@ class BaseTool(ABC):
|
|
| 48 |
return await self.execute(**kwargs)
|
| 49 |
|
| 50 |
@abstractmethod
|
| 51 |
-
async def execute(self, **kwargs) -> Any:
|
| 52 |
raise NotImplementedError(
|
| 53 |
f"{self.__class__.__name__} must implement the execute method"
|
| 54 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from abc import ABC, abstractmethod
|
| 3 |
import json
|
| 4 |
from .schema_utils import format_tool_definition
|
| 5 |
+
from ..agents.execution_context_ch6 import ExecutionContext
|
| 6 |
+
from ..models.llm_request import LlmRequest
|
| 7 |
|
| 8 |
|
| 9 |
class BaseTool(ABC):
|
|
|
|
| 13 |
name: str = None,
|
| 14 |
description: str = None,
|
| 15 |
tool_definition: Optional[Union[Dict[str, Any], str]] = None,
|
| 16 |
+
pydantic_input_model: Type = None,
|
| 17 |
+
output_type: str = "str"
|
| 18 |
):
|
| 19 |
self.name = name or self.__class__.__name__
|
| 20 |
self.description = description or self.__doc__ or ""
|
| 21 |
self.pydantic_input_model = pydantic_input_model
|
| 22 |
+
self.output_type = output_type
|
| 23 |
|
| 24 |
if isinstance(tool_definition, str):
|
| 25 |
self._tool_definition = json.loads(tool_definition)
|
|
|
|
| 52 |
return await self.execute(**kwargs)
|
| 53 |
|
| 54 |
@abstractmethod
|
| 55 |
+
async def execute(self, context: ExecutionContext, **kwargs) -> Any:
|
| 56 |
raise NotImplementedError(
|
| 57 |
f"{self.__class__.__name__} must implement the execute method"
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
async def process_llm_request(self, request: LlmRequest, context: ExecutionContext):
|
| 61 |
+
return None
|
| 62 |
+
|
| 63 |
+
def to_code_prompt(self) -> str:
|
| 64 |
+
"""Generate tool description for code execution environment"""
|
| 65 |
+
params_desc = ""
|
| 66 |
+
if self._tool_definition and "function" in self._tool_definition:
|
| 67 |
+
func_def = self._tool_definition["function"]
|
| 68 |
+
if "parameters" in func_def and "properties" in func_def["parameters"]:
|
| 69 |
+
params = []
|
| 70 |
+
for param_name, param_info in func_def["parameters"]["properties"].items():
|
| 71 |
+
param_type = param_info.get("type", "Any")
|
| 72 |
+
param_desc = param_info.get("description", "")
|
| 73 |
+
required = param_name in func_def["parameters"].get("required", [])
|
| 74 |
+
req_str = " (required)" if required else " (optional)"
|
| 75 |
+
params.append(f" - {param_name}: {param_type}{req_str} - {param_desc}")
|
| 76 |
+
if params:
|
| 77 |
+
params_desc = "\n Parameters:\n" + "\n".join(params)
|
| 78 |
+
|
| 79 |
+
return f"""Tool: {self.name}
|
| 80 |
+
Description: {self.description}
|
| 81 |
+
Output Type: {self.output_type}{params_desc}"""
|
scratch_agents/tools/conversation_search.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .decorator import tool
|
| 2 |
+
from ..agents.execution_context_ch6 import ExecutionContext
|
| 3 |
+
from ..types.contents import Message
|
| 4 |
+
|
| 5 |
+
@tool
|
| 6 |
+
async def conversation_search(
|
| 7 |
+
query: str,
|
| 8 |
+
limit: int = 5,
|
| 9 |
+
context: ExecutionContext = None
|
| 10 |
+
):
|
| 11 |
+
"""Search through current session's conversation history using exact keyword matching
|
| 12 |
+
|
| 13 |
+
IMPORTANT: Use SHORT, SPECIFIC KEYWORDS that likely appear in the conversation.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
query: Short keyword to search for (use simple words that might appear in messages)
|
| 17 |
+
limit: Maximum number of results to return
|
| 18 |
+
context: Execution context with session access
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
Formatted string with search results or message if none found
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
query_lower = query.lower()
|
| 25 |
+
results = []
|
| 26 |
+
|
| 27 |
+
for event in context.session.events:
|
| 28 |
+
for item in event.content:
|
| 29 |
+
if isinstance(item, Message) and item.content:
|
| 30 |
+
if query_lower in item.content.lower():
|
| 31 |
+
results.append({
|
| 32 |
+
"role": item.role,
|
| 33 |
+
"content": item.content,
|
| 34 |
+
"event_id": event.id,
|
| 35 |
+
"timestamp": event.timestamp
|
| 36 |
+
})
|
| 37 |
+
break
|
| 38 |
+
|
| 39 |
+
results = results[-limit:]
|
| 40 |
+
|
| 41 |
+
if not results:
|
| 42 |
+
return f"No messages found containing '{query}'"
|
| 43 |
+
|
| 44 |
+
formatted = f"Found {len(results)} message(s) containing '{query}':\n\n"
|
| 45 |
+
for i, result in enumerate(results, 1):
|
| 46 |
+
formatted += f"{i}. [{result['role']}]: {result['content']}"
|
| 47 |
+
formatted += "\n\n"
|
| 48 |
+
|
| 49 |
+
return formatted
|
scratch_agents/tools/core_memory_upsert.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .decorator import tool
|
| 2 |
+
|
| 3 |
+
@tool
|
| 4 |
+
async def core_memory_upsert(
|
| 5 |
+
block: str,
|
| 6 |
+
content: str,
|
| 7 |
+
update_content: str = None,
|
| 8 |
+
context = None
|
| 9 |
+
) -> str:
|
| 10 |
+
"""Update or insert content in core memory blocks
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
block: Must be 'agent' or 'user'
|
| 14 |
+
content: Text to find or full replacement
|
| 15 |
+
update_content: New text for partial update
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
memory = context.session.core_memory
|
| 19 |
+
current = memory.get(block, "")
|
| 20 |
+
|
| 21 |
+
if update_content:
|
| 22 |
+
if content in current:
|
| 23 |
+
memory[block] = current.replace(content, update_content)
|
| 24 |
+
return f"Updated {block}"
|
| 25 |
+
else:
|
| 26 |
+
if current:
|
| 27 |
+
memory[block] = f"{current}\n{update_content}"
|
| 28 |
+
else:
|
| 29 |
+
memory[block] = update_content
|
| 30 |
+
return f"Added to {block}: {update_content}"
|
| 31 |
+
else:
|
| 32 |
+
memory[block] = content
|
| 33 |
+
return f"Set {block}"
|
scratch_agents/tools/function_tool.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from typing import Any, Dict, Type, Union, Callable, Optional
|
| 2 |
import inspect
|
| 3 |
import asyncio
|
| 4 |
from .base_tool import BaseTool
|
|
@@ -11,28 +11,38 @@ class FunctionTool(BaseTool):
|
|
| 11 |
func: Callable,
|
| 12 |
name: str = None,
|
| 13 |
description: str = None,
|
| 14 |
-
tool_definition: Union[Dict[str, Any], str] = None
|
|
|
|
| 15 |
):
|
| 16 |
self.func = func
|
| 17 |
-
self.pydantic_input_model = self._detect_pydantic_model(func)
|
| 18 |
|
| 19 |
-
name = name or func.__name__
|
| 20 |
-
description = description or (func.__doc__ or "").strip()
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
super().__init__(
|
| 23 |
name=name,
|
| 24 |
description=description,
|
| 25 |
tool_definition=tool_definition,
|
| 26 |
-
pydantic_input_model=self.pydantic_input_model
|
|
|
|
| 27 |
)
|
| 28 |
|
| 29 |
-
async def execute(self, **kwargs) -> Any:
|
|
|
|
|
|
|
|
|
|
| 30 |
if self.pydantic_input_model:
|
| 31 |
args = (self.pydantic_input_model.model_validate(kwargs),)
|
| 32 |
-
call_kwargs = {}
|
| 33 |
else:
|
| 34 |
args = ()
|
| 35 |
call_kwargs = kwargs
|
|
|
|
|
|
|
| 36 |
|
| 37 |
if inspect.iscoroutinefunction(self.func):
|
| 38 |
return await self.func(*args, **call_kwargs)
|
|
@@ -61,4 +71,34 @@ class FunctionTool(BaseTool):
|
|
| 61 |
return param_type
|
| 62 |
except ImportError:
|
| 63 |
pass
|
| 64 |
-
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Type, Union, Callable, Optional, get_type_hints
|
| 2 |
import inspect
|
| 3 |
import asyncio
|
| 4 |
from .base_tool import BaseTool
|
|
|
|
| 11 |
func: Callable,
|
| 12 |
name: str = None,
|
| 13 |
description: str = None,
|
| 14 |
+
tool_definition: Union[Dict[str, Any], str] = None,
|
| 15 |
+
output_type: str = None
|
| 16 |
):
|
| 17 |
self.func = func
|
| 18 |
+
self.pydantic_input_model = self._detect_pydantic_model(func)
|
| 19 |
|
| 20 |
+
name = name or func.__name__
|
| 21 |
+
description = description or (func.__doc__ or "").strip()
|
| 22 |
+
|
| 23 |
+
if output_type is None:
|
| 24 |
+
output_type = self._detect_output_type(func)
|
| 25 |
|
| 26 |
super().__init__(
|
| 27 |
name=name,
|
| 28 |
description=description,
|
| 29 |
tool_definition=tool_definition,
|
| 30 |
+
pydantic_input_model=self.pydantic_input_model,
|
| 31 |
+
output_type=output_type
|
| 32 |
)
|
| 33 |
|
| 34 |
+
async def execute(self, context, **kwargs) -> Any:
|
| 35 |
+
sig = inspect.signature(self.func)
|
| 36 |
+
expects_context = 'context' in sig.parameters
|
| 37 |
+
|
| 38 |
if self.pydantic_input_model:
|
| 39 |
args = (self.pydantic_input_model.model_validate(kwargs),)
|
| 40 |
+
call_kwargs = {'context': context} if expects_context else {}
|
| 41 |
else:
|
| 42 |
args = ()
|
| 43 |
call_kwargs = kwargs
|
| 44 |
+
if expects_context and 'context' not in call_kwargs:
|
| 45 |
+
call_kwargs['context'] = context
|
| 46 |
|
| 47 |
if inspect.iscoroutinefunction(self.func):
|
| 48 |
return await self.func(*args, **call_kwargs)
|
|
|
|
| 71 |
return param_type
|
| 72 |
except ImportError:
|
| 73 |
pass
|
| 74 |
+
return None
|
| 75 |
+
|
| 76 |
+
def _detect_output_type(self, func: Callable) -> str:
|
| 77 |
+
"""Detect the output type from function's return type hint"""
|
| 78 |
+
try:
|
| 79 |
+
type_hints = get_type_hints(func)
|
| 80 |
+
return_type = type_hints.get('return', None)
|
| 81 |
+
|
| 82 |
+
if return_type is None:
|
| 83 |
+
return "str"
|
| 84 |
+
|
| 85 |
+
type_mapping = {
|
| 86 |
+
str: "str",
|
| 87 |
+
int: "int",
|
| 88 |
+
float: "float",
|
| 89 |
+
bool: "bool",
|
| 90 |
+
list: "list",
|
| 91 |
+
dict: "dict",
|
| 92 |
+
tuple: "tuple",
|
| 93 |
+
type(None): "None"
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
if return_type in type_mapping:
|
| 97 |
+
return type_mapping[return_type]
|
| 98 |
+
|
| 99 |
+
raise ValueError(f"Unsupported return type: {return_type}. Only basic types are supported.")
|
| 100 |
+
|
| 101 |
+
except ValueError:
|
| 102 |
+
raise
|
| 103 |
+
except Exception:
|
| 104 |
+
return "str"
|
scratch_agents/types/contents.py
CHANGED
|
@@ -13,7 +13,7 @@ class ToolResult(BaseModel):
|
|
| 13 |
tool_call_id: str
|
| 14 |
name: str
|
| 15 |
status: Literal["success", "error"]
|
| 16 |
-
content:
|
| 17 |
|
| 18 |
class Message(BaseModel):
|
| 19 |
type: Literal["message"] = "message"
|
|
|
|
| 13 |
tool_call_id: str
|
| 14 |
name: str
|
| 15 |
status: Literal["success", "error"]
|
| 16 |
+
content: str
|
| 17 |
|
| 18 |
class Message(BaseModel):
|
| 19 |
type: Literal["message"] = "message"
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|