Commit 9d9d55a · Parent(s): b3d741d
Adding memory management using sentenceTransformer
src/manager/manager.py CHANGED (+42 -1)

@@ -8,6 +8,9 @@ from src.manager.tool_manager import ToolManager
 from src.manager.utils.suppress_outputs import suppress_output
 import logging
 import gradio as gr
+from sentence_transformers import SentenceTransformer
+import torch
+from src.tools.default_tools.memory_manager import MemoryManager
 
 logger = logging.getLogger(__name__)
 handler = logging.StreamHandler(sys.stdout)
@@ -35,6 +38,7 @@ class GeminiManager:
         self.client = genai.Client(api_key=self.API_KEY)
         self.toolsLoader.load_tools()
         self.model_name = gemini_model
+        self.memory_manager = MemoryManager()
         with open(system_prompt_file, 'r', encoding="utf8") as f:
             self.system_prompt = f.read()
         self.messages = []
@@ -131,6 +135,9 @@ class GeminiManager:
             match message.get("role"):
                 case "user":
                     role = "user"
+                case "memories":
+                    role = "user"
+                    parts = [types.Part.from_text(text="User memories: "+message.get("content", ""))]
                 case "tool":
                     role = "tool"
                     formatted_history.append(
@@ -149,7 +156,41 @@ class GeminiManager:
             ))
         return formatted_history
 
+    def get_k_memories(self, query, k=5, threshold=0.0):
+        memories = MemoryManager().get_memories()
+        if len(memories) == 0:
+            return []
+        top_k = min(k, len(memories))
+        # Semantic Retrieval with GPU
+        device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
+        doc_embeddings = model.encode(memories, convert_to_tensor=True, device=device)
+        query_embedding = model.encode(query, convert_to_tensor=True, device=device)
+        similarity_scores = model.similarity(query_embedding, doc_embeddings)[0]
+        scores, indices = torch.topk(similarity_scores, k=top_k)
+        results = []
+        for score, idx in zip(scores, indices):
+            print(memories[idx], f"(Score: {score:.4f})")
+            if score >= threshold:
+                results.append(memories[idx])
+        return results
+
     def run(self, messages):
+        memories = self.get_k_memories(messages[-1]['content'], k=5, threshold=0.0)
+        if len(memories) > 0:
+            messages.append({
+                "role": "memories",
+                "content": f"{memories}",
+            })
+            messages.append({
+                "role": "assistant",
+                "content": f"Memories: {memories}",
+                "metadata": {"title": "Memories"}
+            })
+        yield messages
+        yield from self.invoke_manager(messages)
+
+    def invoke_manager(self, messages):
         chat_history = self.format_chat_history(messages)
         logger.debug(f"Chat history: {chat_history}")
         try:
@@ -195,6 +236,6 @@ class GeminiManager:
             if (call.get("role") == "tool"
                     or (call.get("role") == "assistant" and call.get("metadata", {}).get("status") == "done")):
                 messages.append(call)
-                yield from self.run(messages)
+                yield from self.invoke_manager(messages)
                 return
         yield messages
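
For reference, the retrieval step that get_k_memories adds can be exercised on its own. A minimal sketch, assuming sentence-transformers >= 3.0 (which provides SentenceTransformer.similarity) and a toy in-memory list standing in for MemoryManager().get_memories():

from sentence_transformers import SentenceTransformer
import torch

# Hypothetical stand-in for MemoryManager().get_memories()
memories = [
    "User prefers Python over JavaScript",
    "User is deploying the app as a Gradio Space",
    "User's name is Alex",
]
query = "Which language should the code samples use?"

top_k = min(5, len(memories))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

# Embed the memories and the query, then rank memories by cosine similarity.
doc_embeddings = model.encode(memories, convert_to_tensor=True, device=device)
query_embedding = model.encode(query, convert_to_tensor=True, device=device)
similarity_scores = model.similarity(query_embedding, doc_embeddings)[0]
scores, indices = torch.topk(similarity_scores, k=top_k)
for score, idx in zip(scores, indices):
    print(memories[idx], f"(Score: {score:.4f})")

On sentence-transformers versions before 3.0, model.similarity is not available; sentence_transformers.util.cos_sim(query_embedding, doc_embeddings) computes the same cosine-similarity matrix.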
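
The new "memories" case in format_chat_history replays retrieved memories to Gemini as an extra user turn. A minimal sketch of that mapping, using the google-genai types module the file already relies on; wrapping the parts in types.Content is an assumption here, since the diff truncates the surrounding formatted_history.append(...) call:

from google.genai import types

# A "memories" chat message as produced by run() (hypothetical example value)
message = {"role": "memories", "content": "['User prefers Python over JavaScript']"}

# Mirrors the new case branch: memories become a user-role turn
parts = [types.Part.from_text(text="User memories: " + message.get("content", ""))]
content = types.Content(role="user", parts=parts)  # assumed wrapper; truncated in the diff

Note also that the tool-call continuation at the bottom of the diff now recurses into invoke_manager rather than run, so memory retrieval happens once per user turn instead of on every tool round-trip.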