Spaces:
No application file
No application file
File size: 10,851 Bytes
4f24301 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 |
import json
import re
import time
from typing import Any, Dict, List, Optional

from deepforest_agent.conf.config import Config
from deepforest_agent.models.smollm3_3b import SmolLM3ModelManager
from deepforest_agent.prompts.prompt_templates import format_memory_prompt
from deepforest_agent.utils.cache_utils import tool_call_cache
from deepforest_agent.utils.logging_utils import multi_agent_logger
from deepforest_agent.utils.parsing_utils import parse_memory_agent_response
from deepforest_agent.utils.state_manager import session_state_manager
class MemoryAgent:
    """
    Memory agent responsible for analyzing conversation history in new format.

    Uses a SmolLM3-3B model to extract the context from prior turns that is
    relevant to the latest user message, returning a structured result dict.
    """

    # Sentinel stored (and expected back) when a turn has no tool-call cache entry.
    _NO_TOOL_CACHE_ID = "No tool cache ID"

    def __init__(self) -> None:
        """Initialize the Memory Agent with model manager and configuration."""
        self.agent_config = Config.AGENT_CONFIGS["memory"]
        self.model_manager = SmolLM3ModelManager(Config.AGENT_MODELS["memory"])

    @staticmethod
    def _extract_text(content: Any) -> str:
        """
        Flatten a message ``content`` field to plain text.

        Multimodal content arrives as a list of part dicts; only parts whose
        ``type`` is ``"text"`` contribute, joined by single spaces. Strings
        pass through unchanged; anything else is coerced with ``str()``.
        """
        if isinstance(content, list):
            return " ".join(
                part.get("text", "") for part in content if part.get("type") == "text"
            )
        return content if isinstance(content, str) else str(content)

    def _filter_conversation_history(self, conversation_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Filter conversation history to include user and assistant messages.

        Args:
            conversation_history: Full conversation history

        Returns:
            Filtered history with only user/assistant messages, each with its
            content flattened to a plain string via ``_extract_text``.
        """
        return [
            {"role": message["role"], "content": self._extract_text(message.get("content", ""))}
            for message in conversation_history
            if message.get("role") in ("user", "assistant")
        ]

    def _describe_tool_call(self, tool_cache_id: Optional[str]) -> str:
        """
        Build a human-readable summary of the cached tool call for a turn.

        Args:
            tool_cache_id: Cache key for the turn's tool call, or the
                ``_NO_TOOL_CACHE_ID`` sentinel / None when absent.

        Returns:
            A one-line summary, or a placeholder/error message.
        """
        # Skip the sentinel explicitly: it is a truthy string and would
        # otherwise be looked up in the cache as if it were a real key.
        if not tool_cache_id or tool_cache_id == self._NO_TOOL_CACHE_ID:
            return "No tool call information available"
        try:
            if tool_cache_id not in tool_call_cache.cache_data:
                return "No tool call information available"
            cached_entry = tool_call_cache.cache_data[tool_cache_id]
            tool_name = cached_entry.get("tool_name", "unknown")
            # Overlay the stored arguments on the defaults so the summary
            # shows the complete effective argument set, not just overrides.
            all_arguments = Config.DEEPFOREST_DEFAULTS.copy()
            all_arguments.update(cached_entry.get("arguments", {}))
            args_str = ", ".join(f"{k}={v}" for k, v in all_arguments.items())
            return f"Tool: {tool_name} called with arguments: {args_str}"
        except Exception as e:
            return f"Error retrieving tool call info: {str(e)}"

    def _format_turn(self, session_id: str, turn_number: int, user_query: str) -> str:
        """
        Render one completed conversation turn, pulling stored per-turn context.

        Args:
            session_id: Session identifier
            turn_number: 1-based turn index used in the per-turn state keys
            user_query: Flattened text of the user's message for this turn

        Returns:
            Multi-line text block describing the turn.
        """
        visual_context = session_state_manager.get(
            session_id, f"turn_{turn_number}_visual_context", "No visual analysis available")
        detection_narrative = session_state_manager.get(
            session_id, f"turn_{turn_number}_detection_narrative", "No detection narrative available")
        tool_cache_id = session_state_manager.get(
            session_id, f"turn_{turn_number}_tool_cache_id", self._NO_TOOL_CACHE_ID)
        tool_call_info = self._describe_tool_call(tool_cache_id)
        return (
            f"--- Turn {turn_number}: ---\n"
            f"Turn {turn_number} User query: {user_query}\n"
            f"Turn {turn_number} Visual analysis full image or per tile: {visual_context}\n"
            f"Turn {turn_number} Tool cache ID: {tool_cache_id}\n"
            f"Turn {turn_number} Tool call details: {tool_call_info}\n"
            f"Turn {turn_number} Detection Data Analysis: {detection_narrative}\n"
            f"--- Turn {turn_number} Completed ---\n"
        )

    def _get_conversation_history_context(self, session_id: str) -> str:
        """
        Get formatted conversation history with turn-based structure.

        Walks the stored history pairing consecutive (user, assistant)
        messages into turns; unpaired or out-of-order messages are skipped.

        Args:
            session_id: Session identifier

        Returns:
            Formatted conversation history with turn structure, or a
            placeholder string when no history / no complete turns exist.
        """
        conversation_history = session_state_manager.get(session_id, "conversation_history", [])
        print(f"Session {session_id} - Conversation length: {len(conversation_history)}")
        if not conversation_history:
            return "No previous conversation history available."
        formatted_history: List[str] = []
        turn_number = 1
        i = 0
        # A turn needs two messages, so stop once fewer than two remain.
        while i + 1 < len(conversation_history):
            user_msg = conversation_history[i]
            assistant_msg = conversation_history[i + 1]
            if user_msg.get("role") == "user" and assistant_msg.get("role") == "assistant":
                user_query = self._extract_text(user_msg.get("content", ""))
                formatted_history.append(self._format_turn(session_id, turn_number, user_query))
                turn_number += 1
                i += 2
            else:
                # Not a user->assistant pair; advance one message and retry.
                i += 1
        if not formatted_history:
            return "No complete conversation turns available."
        print(f"Formatted {len(formatted_history)} conversation turns")
        return "\n\n".join(formatted_history)

    @staticmethod
    def _fallback_result(relevant_context: str, raw_response: str) -> Dict[str, Any]:
        """Structured "no answer" result used when model analysis is not possible."""
        return {
            "answer_present": False,
            "direct_answer": "NO",
            "tool_cache_id": None,
            "relevant_context": relevant_context,
            "raw_response": raw_response,
        }

    def process_conversation_history_structured(
        self,
        conversation_history: List[Dict[str, Any]],
        latest_message: str,
        session_id: str
    ) -> Dict[str, Any]:
        """
        Process conversation history and extract relevant context with structured output.

        Args:
            conversation_history: Full conversation history
            latest_message: Current user message requiring context analysis
            session_id: Unique session identifier for this user

        Returns:
            Dict with structured output including tool_cache_id and relevant context
        """
        if not session_state_manager.session_exists(session_id):
            return self._fallback_result(
                f"Session {session_id} not found. Current query: {latest_message}",
                f"Session {session_id} not found"
            )
        filtered_history = self._filter_conversation_history(conversation_history)
        conversation_context = self._get_conversation_history_context(session_id)
        memory_prompt = format_memory_prompt(filtered_history, latest_message, conversation_context)
        print(f"Memory Agent Prompt:\n{memory_prompt}\n")
        messages = [
            {"role": "system", "content": memory_prompt},
            {"role": "user", "content": latest_message}
        ]
        memory_execution_start = time.perf_counter()
        try:
            response = self.model_manager.generate_response(
                messages=messages,
                max_new_tokens=self.agent_config["max_new_tokens"],
                temperature=self.agent_config["temperature"],
                top_p=self.agent_config["top_p"]
            )
            memory_execution_time = time.perf_counter() - memory_execution_start
            print(f"Session {session_id} - Memory Agent: Raw response received")
            print(f"Raw Response: {response}")
            parsed_result = parse_memory_agent_response(response)
            multi_agent_logger.log_agent_execution(
                session_id=session_id,
                agent_name="memory",
                agent_input=f"Latest message: {latest_message}",
                agent_output=response,
                execution_time=memory_execution_time
            )
            print(f"Session {session_id} - Memory Agent: Analysis completed")
            print(f"Has Answer: {parsed_result['answer_present']}")
            return parsed_result
        except Exception as e:
            memory_execution_time = time.perf_counter() - memory_execution_start
            error_msg = f"Error processing conversation history in session {session_id}: {str(e)}"
            print(f"Session {session_id} - Memory Agent Error: {e}")
            multi_agent_logger.log_error(
                session_id=session_id,
                error_type="memory_agent_error",
                error_message=f"Memory agent failed after {memory_execution_time:.2f}s: {str(e)}"
            )
            return self._fallback_result(
                f"{error_msg}. Current query: {latest_message}",
                str(e)
            )

    def store_turn_context(self, session_id: str, turn_number: int, visual_context: str,
                          detection_narrative: str, tool_cache_id: Optional[str]) -> None:
        """
        Store context data for a specific conversation turn.

        Args:
            session_id: Session identifier
            turn_number: Turn number in conversation
            visual_context: Visual analysis context
            detection_narrative: Detection narrative
            tool_cache_id: Tool cache identifier (sentinel stored when None/empty)
        """
        session_state_manager.set(session_id, f"turn_{turn_number}_visual_context", visual_context)
        session_state_manager.set(session_id, f"turn_{turn_number}_detection_narrative", detection_narrative)
        session_state_manager.set(session_id, f"turn_{turn_number}_tool_cache_id", tool_cache_id or self._NO_TOOL_CACHE_ID)
        print(f"Session {session_id} - Stored context for turn {turn_number}")