# AIDA / app / core / context_manager.py
# destinyebuka's picture
# dora
# 4c9881b
# ============================================================
# app/core/context_manager.py - Context Window Management
# ============================================================
import logging
import json
from typing import List, Dict, Optional, Tuple
from datetime import datetime, timedelta
import tiktoken
from app.core.error_handling import LojizError
logger = logging.getLogger(__name__)
# ============================================================
# Token Counter
# ============================================================
class TokenCounter:
    """Count tokens with tiktoken, falling back to a character estimate.

    If the requested tiktoken encoding cannot be loaded, ``self.encoding`` is
    set to ``None`` and counts are approximated from the text length.
    """

    # Fallback ratio when no tokenizer is available (4 chars ≈ 1 token).
    FALLBACK_CHARS_PER_TOKEN = 4
    # Fixed overhead added per message (role + content markers).
    TOKENS_PER_MESSAGE = 4
    # Fixed overhead added once per message list (message framing).
    FRAMING_TOKENS = 2

    def __init__(self, encoding_name: str = "cl100k_base"):
        """Load the tiktoken encoding named ``encoding_name``.

        Any failure to load is logged and leaves the counter in fallback mode
        rather than raising.
        """
        try:
            self.encoding = tiktoken.get_encoding(encoding_name)
        except Exception as e:
            logger.warning(f"⚠️ Failed to load tiktoken: {e}, using fallback")
            self.encoding = None

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``.

        Args:
            text: Text to measure. ``None`` and the empty string count as 0;
                non-string values are measured via their ``str()`` form.

        Returns:
            The token count, or ``len(text) // 4`` in fallback mode.
        """
        if text is None or text == "":
            return 0
        if not isinstance(text, str):
            text = str(text)

        if not self.encoding:
            # Fallback: rough estimate (4 chars ≈ 1 token)
            return len(text) // self.FALLBACK_CHARS_PER_TOKEN

        # By default tiktoken raises ValueError when the text contains a
        # special-token string such as "<|endoftext|>". Chat content is
        # arbitrary user text, so treat such strings as ordinary text
        # instead of failing the count.
        return len(self.encoding.encode(text, disallowed_special=()))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Return the estimated token count of a chat message list.

        The total is a fixed per-message overhead plus the tokens of each
        message's ``role`` and ``content`` (missing or empty values add
        nothing), plus a fixed framing overhead for the list as a whole.

        Args:
            messages: Message dicts with optional ``role``/``content`` keys.

        Returns:
            The estimated token count; an empty list still costs the framing
            overhead.
        """
        total = self.FRAMING_TOKENS
        for msg in messages:
            total += self.TOKENS_PER_MESSAGE
            role = msg.get("role")
            if role:
                total += self.count_tokens(role)
            content = msg.get("content")
            if content:
                total += self.count_tokens(content)
        return total
# ============================================================
# Context Manager
# ============================================================
class ContextManager:
    """Keep a chat message list within a model's context window.

    The usable budget is the model's context limit minus a fixed reserve for
    the model's response.
    """

    # Model limits (tokens)
    MODEL_LIMITS = {
        "deepseek-chat": 4096,
        "mistralai/mistral-7b-instruct": 8192,
        "xai-org/grok-beta": 8192,
        "meta-llama/llama-2-70b-chat": 4096,
    }

    # Reserve space for response
    RESPONSE_RESERVE = 600

    def __init__(self, model: str = "deepseek-chat"):
        """Set up token counting and limits for ``model``.

        Models missing from ``MODEL_LIMITS`` fall back to a 4096-token window.
        """
        self.model = model
        self.token_counter = TokenCounter()
        self.context_limit = self.MODEL_LIMITS.get(model, 4096)
        self.usable_limit = self.context_limit - self.RESPONSE_RESERVE

    def get_available_context(self, current_tokens: int) -> int:
        """Return how many tokens remain in the usable budget (never negative)."""
        return max(0, self.usable_limit - current_tokens)

    def is_context_full(self, messages: List[Dict[str, str]]) -> bool:
        """Return True when ``messages`` meet or exceed the usable budget."""
        tokens = self.token_counter.count_messages_tokens(messages)
        return tokens >= self.usable_limit

    @staticmethod
    def _content_text(msg: Dict[str, str]) -> str:
        """Return a message's content as a string ("" when missing or None)."""
        content = msg.get("content")
        if content is None:
            return ""
        return content if isinstance(content, str) else str(content)

    @staticmethod
    def _partition_messages(
        messages: List[Dict[str, str]],
    ) -> Tuple[List[Dict[str, str]], List[Dict[str, str]], List[Dict[str, str]]]:
        """Split messages into (system messages, history, last user message).

        The last user message is selected by position, not by equality, so an
        earlier message with the same role and content stays in the history.
        The third element is an empty list when there is no user message.
        """
        system_msgs: List[Dict[str, str]] = []
        others: List[Dict[str, str]] = []
        last_user_idx = None
        for msg in messages:
            role = msg.get("role")
            if role == "system":
                system_msgs.append(msg)
                continue
            if role == "user":
                last_user_idx = len(others)
            others.append(msg)

        if last_user_idx is None:
            return system_msgs, others, []

        pinned = [others[last_user_idx]]
        history = others[:last_user_idx] + others[last_user_idx + 1:]
        return system_msgs, history, pinned

    async def manage_context(
        self,
        messages: List[Dict[str, str]],
        max_history_messages: int = 20,
    ) -> List[Dict[str, str]]:
        """Trim ``messages`` so they fit the usable token budget.

        Strategy:
        1. If the messages already fit, return them unchanged.
        2. Always keep every system message and the last user message.
        3. Keep at most ``max_history_messages`` of the most recent other
           messages (zero or a negative value keeps none).
        4. Drop the oldest remaining history until the token budget is met.

        The result is ordered as system messages, history, last user message.

        Args:
            messages: Full message list for the request.
            max_history_messages: Upper bound on retained history messages.

        Returns:
            The original list when it fits, otherwise a new trimmed list. The
            result can still exceed the budget when the system messages and
            the last user message alone are too large; that case is logged.
        """
        if not messages:
            return messages

        count = self.token_counter.count_messages_tokens
        tokens = count(messages)
        if tokens <= self.usable_limit:
            logger.debug(
                f"✅ Context OK: {tokens}/{self.usable_limit} tokens, "
                f"{len(messages)} messages"
            )
            return messages

        logger.warning(
            f"⚠️ Context overflow: {tokens}/{self.usable_limit} tokens, "
            f"{len(messages)} messages"
        )

        system_msgs, history, pinned = self._partition_messages(messages)

        # Trim history to the most recent N messages. Slice from an explicit
        # start index: history[-0:] would keep everything when N is 0.
        keep = max(0, max_history_messages)
        if len(history) > keep:
            logger.info(f"📦 Trimming history from {len(history)} to {keep}")
            history = history[len(history) - keep:]

        # Fit the remaining history into the token budget, newest first.
        # Assumes the counter is additive: one framing cost per list plus an
        # independent cost per message.
        framing = count([])
        budget = self.usable_limit - count(system_msgs + pinned)
        fitted: List[Dict[str, str]] = []
        for msg in reversed(history):
            cost = count([msg]) - framing
            if cost > budget:
                break
            fitted.append(msg)
            budget -= cost
        fitted.reverse()

        if len(fitted) < len(history):
            logger.info(
                f"📦 Dropping {len(history) - len(fitted)} oldest history "
                f"messages to fit token budget"
            )
            history = fitted

        managed_messages = system_msgs + history + pinned
        final_tokens = count(managed_messages)
        logger.info(
            f"📦 Context managed: {final_tokens}/{self.usable_limit} tokens, "
            f"{len(managed_messages)} messages"
        )
        if final_tokens > self.usable_limit:
            logger.warning(
                "⚠️ System and current user messages alone exceed the usable context"
            )
        return managed_messages

    async def summarize_conversation(
        self,
        messages: List[Dict[str, str]],
        summarizer_fn=None,
    ) -> str:
        """Summarize the non-system messages of a conversation.

        Args:
            messages: Message history.
            summarizer_fn: Optional async callable that receives a transcript
                (one ``ROLE: content`` line per message, content cut to 200
                characters) and returns the summary.

        Returns:
            "" when fewer than three messages are given; otherwise the
            summarizer's result, or the basic extractive summary when no
            summarizer is given or the summarizer raises.
        """
        if not messages or len(messages) < 3:
            return ""

        # Extract conversation content (skip system message)
        conversation = [m for m in messages if m.get("role") != "system"]

        # If no custom summarizer, use basic extraction
        if not summarizer_fn:
            return self._basic_summary(conversation)

        conversation_text = "\n".join(
            f"{(m.get('role') or 'unknown').upper()}: {self._content_text(m)[:200]}"
            for m in conversation
        )

        # Deliberately best-effort: any summarizer failure falls back to the
        # basic summary instead of propagating.
        try:
            return await summarizer_fn(conversation_text)
        except Exception as e:
            logger.warning(f"⚠️ Summarization failed: {e}, using basic summary")
            return self._basic_summary(conversation)

    def _basic_summary(self, messages: List[Dict[str, str]]) -> str:
        """Build an extractive summary from the last 10 messages.

        Content of up to 100 characters is used verbatim. Longer content is
        reduced to its first two lines that are longer than 20 characters.
        Per-message pieces are joined with " | ".
        """
        summaries = []
        for msg in messages[-10:]:  # Last 10 messages
            content = self._content_text(msg)
            if len(content) > 100:
                # Extract key points
                key_lines = [
                    line for line in content.split("\n") if len(line) > 20
                ][:2]
                summaries.append(" ".join(key_lines))
            else:
                summaries.append(content)
        return " | ".join(summaries)
# ============================================================
# Message Window (sliding window)
# ============================================================
class MessageWindow:
    """Sliding window over the most recent conversation messages.

    The window holds at most ``window_size`` messages and is considered
    expired once ``max_age`` has elapsed since it was created or last cleared.
    """

    def __init__(self, window_size: int = 20, max_age_minutes: int = 120):
        """Create an empty window.

        Args:
            window_size: Maximum number of messages retained.
            max_age_minutes: Age in minutes after which the window expires.
        """
        self.window_size = window_size
        self.max_age = timedelta(minutes=max_age_minutes)
        self.messages: List[Dict[str, str]] = []
        self.created_at = self._utc_now()

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current UTC time as a naive datetime.

        Produces the same values as the deprecated ``datetime.utcnow()``, so
        stored timestamps keep their existing format.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None)

    def add_message(self, role: str, content: str) -> None:
        """Append a timestamped message and evict the oldest beyond the limit."""
        self.messages.append(
            {
                "role": role,
                "content": content,
                "timestamp": self._utc_now().isoformat(),
            }
        )

        # Maintain window size. Remove every excess message, not just one, so
        # the window also recovers if window_size was lowered after creation.
        excess = len(self.messages) - self.window_size
        if excess > 0:
            del self.messages[:excess]
            logger.debug(f"📤 Removed old message from window")

    def get_messages(self, include_timestamps: bool = False) -> List[Dict[str, str]]:
        """Return copies of the messages in the window, oldest first.

        Args:
            include_timestamps: When False (default) the ``timestamp`` key is
                stripped, e.g. for API calls.

        Returns:
            A new list of new dicts; mutating it does not affect the window.
        """
        if include_timestamps:
            return [dict(m) for m in self.messages]
        return [
            {k: v for k, v in m.items() if k != "timestamp"}
            for m in self.messages
        ]

    def is_expired(self) -> bool:
        """Return True once the window is older than ``max_age``."""
        return self._utc_now() - self.created_at > self.max_age

    def clear(self) -> None:
        """Remove all messages and restart the window's age."""
        self.messages = []
        self.created_at = self._utc_now()

    def get_stats(self) -> Dict[str, int]:
        """Return message count, configured size and age in whole seconds."""
        age = self._utc_now() - self.created_at
        return {
            "message_count": len(self.messages),
            "max_size": self.window_size,
            "age_seconds": int(age.total_seconds()),
        }
# ============================================================
# Global Context Manager
# ============================================================
# Module-level registries shared by the accessor functions below.
# Maps model name -> ContextManager; filled by get_context_manager().
_context_managers: Dict[str, ContextManager] = {}
# Maps user id -> MessageWindow; filled by get_message_window() and pruned by
# cleanup_expired_windows().
_message_windows: Dict[str, MessageWindow] = {}
def get_context_manager(model: str = "deepseek-chat") -> ContextManager:
    """Return the shared ContextManager for ``model``, building it on first use."""
    try:
        return _context_managers[model]
    except KeyError:
        # First request for this model: build the manager and remember it.
        manager = ContextManager(model)
        _context_managers[model] = manager
        return manager
def get_message_window(user_id: str, create_if_missing: bool = True) -> Optional[MessageWindow]:
    """Look up the message window registered for ``user_id``.

    Args:
        user_id: Key of the window in the module-level registry.
        create_if_missing: Register a fresh window when none exists.

    Returns:
        The user's window, or None when it is missing and
        ``create_if_missing`` is False. An expired window is cleared in place
        before it is returned.
    """
    is_missing = user_id not in _message_windows
    if is_missing and not create_if_missing:
        return None
    if is_missing:
        _message_windows[user_id] = MessageWindow()

    user_window = _message_windows[user_id]

    # Expired windows are reset rather than replaced.
    if user_window.is_expired():
        logger.info(f"🗑️ Clearing expired window for user {user_id}")
        user_window.clear()
    return user_window
def cleanup_expired_windows() -> int:
    """Delete every expired window from the registry.

    Returns:
        The number of windows that were removed.
    """
    # Collect the keys first so the dict is not modified while iterating it.
    stale_ids = []
    for uid, win in _message_windows.items():
        if win.is_expired():
            stale_ids.append(uid)

    for uid in stale_ids:
        del _message_windows[uid]

    removed_count = len(stale_ids)
    if removed_count:
        logger.info(f"🧹 Cleaned up {removed_count} expired windows")
    return removed_count