debashis2007's picture
Upload folder using huggingface_hub
75bea1c verified
from __future__ import annotations
"""Context memory management for the agent."""
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
@dataclass
class MemoryEntry:
    """One record held in the agent's short-term memory.

    Attributes:
        key: Identifier the value is stored under.
        value: The stored payload; any object is accepted.
        timestamp: Creation time. Filled in with ``datetime.now()`` at
            construction when the caller does not supply one.
        source: Label describing where the information came from.
        relevance: Relevance score attached to the entry.
    """

    # Required fields come first so positional construction keeps working.
    key: str
    value: Any

    # Optional metadata; each has a default.
    timestamp: datetime = field(default_factory=datetime.now)
    source: str = "unknown"
    relevance: float = 1.0
class ContextMemory:
    """Manages context and working memory for the agent.

    Three independent stores are kept:

    * short-term memory: keyed ``MemoryEntry`` records, bounded by
      ``max_entries`` and trimmed oldest-first when the bound is exceeded;
    * working memory: a plain key/value scratch dictionary;
    * conversation history: a list of turn dictionaries in insertion order.
    """

    def __init__(self, max_entries: int = 100):
        """Initialize memory.

        Args:
            max_entries: Maximum number of short-term entries to keep.
        """
        self.max_entries = max_entries
        self._short_term: dict[str, MemoryEntry] = {}
        self._working: dict[str, Any] = {}
        self._conversation: list[dict[str, str]] = []

    def store(self, key: str, value: Any, source: str = "agent") -> None:
        """Store a value in short-term memory.

        An existing entry under the same key is replaced and receives a
        fresh timestamp.

        Args:
            key: Memory key.
            value: Value to store.
            source: Source of the information.
        """
        self._short_term[key] = MemoryEntry(
            key=key,
            value=value,
            source=source,
        )

        # Trim if over capacity.
        if len(self._short_term) > self.max_entries:
            self._trim_oldest()

    def retrieve(self, key: str) -> Any | None:
        """Retrieve a value from short-term memory.

        Args:
            key: Memory key.

        Returns:
            The stored value, or ``None`` when the key is absent.
        """
        entry = self._short_term.get(key)
        return entry.value if entry is not None else None

    def update_working(self, key: str, value: Any) -> None:
        """Set a key in working memory, overwriting any previous value.

        Args:
            key: Memory key.
            value: Value to store.
        """
        self._working[key] = value

    def get_working(self, key: str, default: Any = None) -> Any:
        """Get a value from working memory.

        Args:
            key: Memory key.
            default: Value returned when the key is not present.

        Returns:
            The stored value, or ``default``.
        """
        return self._working.get(key, default)

    def add_conversation_turn(self, role: str, content: str) -> None:
        """Append a turn to the conversation history.

        The turn is stamped with the current time in ISO 8601 format.

        Args:
            role: Message role (user/assistant).
            content: Message content.
        """
        self._conversation.append({
            "role": role,
            "content": content,
            "timestamp": datetime.now().isoformat(),
        })

    def get_conversation_history(self, limit: int = 10) -> list[dict[str, str]]:
        """Get the most recent conversation turns, oldest first.

        Args:
            limit: Maximum number of turns to return. Zero or a negative
                value yields an empty list.

        Returns:
            A new list holding at most ``limit`` of the latest turns.
        """
        # ``seq[-0:]`` is ``seq[0:]`` (the whole list) and a negative limit
        # would slice from the front, so non-positive limits are handled
        # explicitly.
        if limit <= 0:
            return []
        return self._conversation[-limit:]

    def get_context_summary(self) -> dict[str, Any]:
        """Get a summary of the current context.

        Returns:
            Dictionary with the short-term keys, the working-memory keys
            and the number of conversation turns.
        """
        return {
            "short_term_keys": list(self._short_term.keys()),
            "working_memory_keys": list(self._working.keys()),
            "conversation_length": len(self._conversation),
        }

    def clear_working(self) -> None:
        """Clear working memory only."""
        self._working.clear()

    def clear_all(self) -> None:
        """Clear short-term memory, working memory and conversation history."""
        self._short_term.clear()
        self._working.clear()
        self._conversation.clear()

    def _trim_oldest(self) -> None:
        """Evict the oldest short-term entries.

        At least one entry, and at least 10% of the entries, is removed on
        every call. When the store exceeds ``max_entries`` by more than
        that, enough entries are removed to bring it back within capacity.
        """
        if not self._short_term:
            return

        # Oldest first; ``sorted`` is stable, so entries with equal
        # timestamps keep their dictionary order.
        oldest_first = sorted(
            self._short_term,
            key=lambda k: self._short_term[k].timestamp,
        )

        total = len(oldest_first)
        # How far the store is above capacity (<= 0 when it is not).
        overflow = total - max(self.max_entries, 0)
        to_remove = max(1, total // 10, overflow)

        for key in oldest_first[:to_remove]:
            del self._short_term[key]

    def search(self, query: str) -> list[MemoryEntry]:
        """Search short-term memory with a case-insensitive substring match.

        An entry matches when the query occurs in its key or in the string
        form of its value. An empty query therefore matches every entry.

        Args:
            query: Search query.

        Returns:
            Matching entries, newest first.
        """
        needle = query.lower()

        # Simple keyword matching against the value text and the key.
        matches = [
            entry
            for entry in self._short_term.values()
            if needle in str(entry.value).lower() or needle in entry.key.lower()
        ]

        # Sort by relevance (for now, just by timestamp).
        return sorted(matches, key=lambda e: e.timestamp, reverse=True)