# agent/core/contextual_memory.py
"""
Contextual Memory Engine - Persistent learning and knowledge base integration
"""
import json
import logging
import os
from dataclasses import dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
@dataclass
class UserMemory:
"""Memory for a specific user"""
user_id: str
preferences: Dict[str, Any] = field(default_factory=dict)
successful_patterns: List[Dict[str, Any]] = field(default_factory=list)
quality_score: float = 0.0
def to_dict(self) -> Dict[str, Any]:
return {
"user_id": self.user_id,
"preferences": self.preferences,
"successful_patterns": self.successful_patterns,
"quality_score": self.quality_score
}
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'UserMemory':
return cls(
user_id=data.get("user_id", ""),
preferences=data.get("preferences", {}),
successful_patterns=data.get("successful_patterns", []),
quality_score=data.get("quality_score", 0.0)
)
@dataclass
class Context:
"""Retrieved context for a query"""
user_memory: Optional[UserMemory]
similar_examples: List[Dict[str, Any]]
domain_knowledge: List[Dict[str, Any]]
compressed_size: int
@dataclass
class ExecutionResult:
"""Result of an execution"""
query: str
success: bool
tools_used: List[str]
reasoning: List[str]
error: Optional[str] = None
class ContextualMemoryEngine:
"""
Maintains:
- Per-user knowledge base (what they care about)
- Global failure patterns (what goes wrong)
- Tool performance characteristics
- Domain-specific expertise
- Context compression (for long conversations)
"""
    def __init__(self, storage_path: str = "/tmp/agent_memory"):
        self.storage_path = storage_path
        self.user_memories: Dict[str, UserMemory] = {}
        self.failure_patterns: List[Dict[str, Any]] = []
        self.domain_knowledge: Dict[str, List[Dict[str, Any]]] = {}
        # Load existing memories
        self._load_memories()
    def _load_memories(self):
        """Load memories from storage"""
        try:
            if os.path.exists(f"{self.storage_path}/user_memories.json"):
                with open(f"{self.storage_path}/user_memories.json", "r") as f:
                    data = json.load(f)
                self.user_memories = {
                    k: UserMemory.from_dict(v)
                    for k, v in data.items()
                }
                logger.info(f"Loaded {len(self.user_memories)} user memories")
        except Exception as e:
            logger.warning(f"Failed to load memories: {e}")
    def _save_memories(self):
        """Save memories to storage"""
        try:
            os.makedirs(self.storage_path, exist_ok=True)
            with open(f"{self.storage_path}/user_memories.json", "w") as f:
                json.dump(
                    {k: v.to_dict() for k, v in self.user_memories.items()},
                    f,
                    indent=2
                )
        except Exception as e:
            logger.warning(f"Failed to save memories: {e}")
    def _compute_similarity(self, query1: str, query2: str) -> float:
        """Compute similarity between two queries"""
        # Simple word overlap similarity
        words1 = set(query1.lower().split())
        words2 = set(query2.lower().split())
        if not words1 or not words2:
            return 0.0
        intersection = words1 & words2
        union = words1 | words2
        return len(intersection) / len(union)
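    # Illustrative check of the overlap metric above: "build react app" vs
    # "build react site" share {"build", "react"} out of 4 distinct words,
    # giving a similarity of 2 / 4 = 0.5.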
    def _classify_domain(self, query: str) -> str:
        """Classify the domain of a query"""
        query_lower = query.lower()
        domains = {
            "web_development": ["react", "vue", "angular", "html", "css", "javascript", "frontend"],
            "backend": ["api", "server", "database", "sql", "nodejs", "python", "backend"],
            "devops": ["docker", "kubernetes", "ci/cd", "deployment", "aws", "cloud"],
            "data_science": ["pandas", "numpy", "machine learning", "ml", "data", "analysis"],
            "security": ["auth", "oauth", "security", "encryption", "vulnerability"],
            "mobile": ["ios", "android", "flutter", "react native", "mobile"]
        }
        scores = {}
        for domain, keywords in domains.items():
            score = sum(1 for kw in keywords if kw in query_lower)
            scores[domain] = score
        best_domain = max(scores, key=scores.get)
        return best_domain if scores[best_domain] > 0 else "general"
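    # Illustrative check of the keyword scoring above: "set up ci/cd deployment
    # on aws" matches three devops keywords ("ci/cd", "deployment", "aws") and
    # none from the other domains, so the method returns "devops".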
    def _estimate_tokens(self, texts: List[Any]) -> int:
        """Estimate token count for texts"""
        total_chars = sum(len(str(t)) for t in texts)
        # Rough estimate: 1 token ≈ 4 characters
        return total_chars // 4
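    # Rough sanity check of the heuristic above: ~800 characters of serialized
    # context is treated as roughly 200 tokens.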
    async def retrieve_context(
        self,
        query: str,
        user_id: str,
        max_tokens: int = 2000
    ) -> Context:
        """Retrieve most relevant context for query"""
        # Get user preferences and history
        user_memory = self.user_memories.get(user_id)
        # Find similar past queries from user's patterns
        similar_examples = []
        if user_memory:
            for pattern in user_memory.successful_patterns:
                similarity = self._compute_similarity(query, pattern.get("query", ""))
                if similarity > 0.3:  # Threshold for similarity
                    similar_examples.append({
                        "query": pattern.get("query"),
                        "similarity": similarity,
                        "tools_used": pattern.get("tools_used"),
                        "outcome": pattern.get("outcome")
                    })
        # Sort by similarity
        similar_examples.sort(key=lambda x: x["similarity"], reverse=True)
        similar_examples = similar_examples[:5]  # Top 5
        # Get domain knowledge if applicable
        domain = self._classify_domain(query)
        domain_knowledge = self.domain_knowledge.get(domain, [])
        # Compress if needed
        current_tokens = self._estimate_tokens([similar_examples, domain_knowledge])
        if current_tokens > max_tokens:
            # Simple compression: reduce examples
            target_examples = max(1, len(similar_examples) // 2)
            similar_examples = similar_examples[:target_examples]
            compressed_size = self._estimate_tokens([similar_examples, domain_knowledge])
        else:
            compressed_size = current_tokens
        logger.info(
            f"Context retrieved for user {user_id}: "
            f"domain={domain}, similar_examples={len(similar_examples)}, "
            f"tokens={compressed_size}"
        )
        return Context(
            user_memory=user_memory,
            similar_examples=similar_examples,
            domain_knowledge=domain_knowledge,
            compressed_size=compressed_size
        )
    async def learn_from_execution(
        self,
        user_id: str,
        execution_result: ExecutionResult
    ):
        """Store learnings from this execution"""
        # Get or create user memory
        if user_id not in self.user_memories:
            self.user_memories[user_id] = UserMemory(user_id=user_id)
        user_memory = self.user_memories[user_id]
        # If successful, remember the pattern
        if execution_result.success:
            pattern = {
                "query": execution_result.query,
                "tools_used": execution_result.tools_used,
                "reasoning": execution_result.reasoning,
                "outcome": "success",
                "timestamp": datetime.now(timezone.utc).isoformat()
            }
            user_memory.successful_patterns.append(pattern)
            # Keep only last 50 patterns
            if len(user_memory.successful_patterns) > 50:
                user_memory.successful_patterns = user_memory.successful_patterns[-50:]
            # Update quality score
            user_memory.quality_score = min(
                1.0,
                user_memory.quality_score + 0.05
            )
            logger.info(f"Learned successful pattern for user {user_id}")
        # If failed, record failure pattern for future avoidance
        else:
            failure_pattern = {
                "query": execution_result.query,
                "tools_attempted": execution_result.tools_used,
                "failure_reason": execution_result.error,
                "outcome": "failure"
            }
            self.failure_patterns.append(failure_pattern)
            # Keep only last 100 failure patterns
            if len(self.failure_patterns) > 100:
                self.failure_patterns = self.failure_patterns[-100:]
            # Update quality score
            user_memory.quality_score = max(
                0.0,
                user_memory.quality_score - 0.02
            )
            logger.info(f"Learned failure pattern for user {user_id}")
        # Save memories
        self._save_memories()
    def get_user_stats(self, user_id: str) -> Dict[str, Any]:
        """Get statistics for a user"""
        user_memory = self.user_memories.get(user_id)
        if not user_memory:
            return {
                "user_id": user_id,
                "has_memory": False,
                "quality_score": 0.0,
                "successful_patterns": 0
            }
        return {
            "user_id": user_id,
            "has_memory": True,
            "quality_score": user_memory.quality_score,
            "successful_patterns": len(user_memory.successful_patterns),
            "preferences": user_memory.preferences
        }
# Global memory engine
memory_engine = ContextualMemoryEngine()
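

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the engine itself): one learn -> retrieve
# round trip. Assumptions: the module is run directly, the storage path below
# is writable, and the user id and queries are purely illustrative values.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        engine = ContextualMemoryEngine(storage_path="/tmp/agent_memory_demo")

        # Record one successful execution so there is something to retrieve.
        await engine.learn_from_execution(
            user_id="demo-user",
            execution_result=ExecutionResult(
                query="add oauth login to a react frontend",
                success=True,
                tools_used=["code_generator", "linter"],
                reasoning=["scaffold login form", "wire up oauth flow"],
            ),
        )

        # A closely related query should clear the 0.3 similarity threshold
        # and surface the stored pattern in similar_examples.
        context = await engine.retrieve_context(
            query="add oauth login to a react app",
            user_id="demo-user",
        )
        print("similar examples:", len(context.similar_examples))
        print("stats:", engine.get_user_stats("demo-user"))

    asyncio.run(_demo())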