Spaces:
Running
Running
| # ============================================================ | |
| # app/core/context_manager.py - Context Window Management | |
| # ============================================================ | |
| import logging | |
| import json | |
| from typing import List, Dict, Optional, Tuple | |
| from datetime import datetime, timedelta | |
| import tiktoken | |
| from app.core.error_handling import LojizError | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
| # ============================================================ | |
| # Token Counter | |
| # ============================================================ | |
class TokenCounter:
    """Count tokens with tiktoken, falling back to a character heuristic.

    If the requested tiktoken encoding cannot be loaded, ``encoding`` is set
    to ``None`` and every count is estimated as ``len(text) // 4``.
    """

    def __init__(self, encoding_name: str = "cl100k_base"):
        """
        Args:
            encoding_name: Name of the tiktoken encoding to load.
        """
        try:
            self.encoding = tiktoken.get_encoding(encoding_name)
        except Exception as e:
            # Best-effort: any load failure degrades to the heuristic
            # instead of breaking construction.
            logger.warning(f"⚠️ Failed to load tiktoken: {e}, using fallback")
            self.encoding = None

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text``.

        Args:
            text: Text to measure. Empty text or ``None`` counts as 0.

        Returns:
            The tiktoken count, or ``len(text) // 4`` when no encoding is loaded.
        """
        if not text:
            return 0
        if self.encoding is None:
            # Fallback: rough estimate (4 chars ≈ 1 token)
            return len(text) // 4
        # By default tiktoken's encode() raises ValueError when the text
        # contains special-token strings such as "<|endoftext|>". Chat content
        # can legitimately contain them, so count them as ordinary text.
        return len(self.encoding.encode(text, disallowed_special=()))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Return the token count of a chat message list.

        Each message costs 4 overhead tokens plus the tokens of its "role"
        and "content" values (falsy values are skipped). The whole list adds
        2 more tokens of framing overhead.

        Args:
            messages: Dicts with optional "role" and "content" keys.

        Returns:
            Total estimated token count.
        """
        total = 0
        for msg in messages:
            # Add overhead per message (role + content markers)
            total += 4
            if msg.get("role"):
                total += self.count_tokens(msg["role"])
            if msg.get("content"):
                total += self.count_tokens(msg["content"])
        # Add overhead for message framing
        total += 2
        return total
| # ============================================================ | |
| # Context Manager | |
| # ============================================================ | |
class ContextManager:
    """Keep a chat message list within a model's context window.

    ``usable_limit`` is the model's context size minus ``RESPONSE_RESERVE``
    tokens held back for the model's reply.
    """

    # Model limits (tokens); models not listed fall back to 4096.
    MODEL_LIMITS = {
        "deepseek-chat": 4096,
        "mistralai/mistral-7b-instruct": 8192,
        "xai-org/grok-beta": 8192,
        "meta-llama/llama-2-70b-chat": 4096,
    }
    # Reserve space for response
    RESPONSE_RESERVE = 600

    def __init__(self, model: str = "deepseek-chat"):
        """
        Args:
            model: Model name looked up in ``MODEL_LIMITS`` (4096 if absent).
        """
        self.model = model
        self.token_counter = TokenCounter()
        self.context_limit = self.MODEL_LIMITS.get(model, 4096)
        self.usable_limit = self.context_limit - self.RESPONSE_RESERVE

    def get_available_context(self, current_tokens: int) -> int:
        """Return the tokens left below ``usable_limit`` (never negative)."""
        return max(0, self.usable_limit - current_tokens)

    def is_context_full(self, messages: List[Dict[str, str]]) -> bool:
        """Return True if ``messages`` use at least ``usable_limit`` tokens."""
        tokens = self.token_counter.count_messages_tokens(messages)
        return tokens >= self.usable_limit

    async def manage_context(
        self,
        messages: List[Dict[str, str]],
        max_history_messages: int = 20,
    ) -> List[Dict[str, str]]:
        """
        Trim ``messages`` so they fit within ``usable_limit`` tokens.

        If the list already fits, it is returned unchanged. Otherwise:
        1. Keep every system message.
        2. Keep the most recent user message (the current input).
        3. Keep at most the newest ``max_history_messages`` of the remaining
           messages, then drop the oldest of those until the total fits.
        Kept messages retain their original relative order. No summarization
        is performed here.

        Args:
            messages: Chat messages with "role" and "content" keys.
            max_history_messages: Maximum number of other messages kept
                besides system messages and the latest user message;
                0 (or less) keeps none.

        Returns:
            The managed message list. It can still exceed ``usable_limit``
            if the system messages plus the latest user message alone do.
        """
        if not messages:
            return messages

        count = self.token_counter.count_messages_tokens
        tokens = count(messages)
        if tokens <= self.usable_limit:
            logger.debug(
                f"✅ Context OK: {tokens}/{self.usable_limit} tokens, "
                f"{len(messages)} messages"
            )
            return messages

        logger.warning(
            f"⚠️ Context overflow: {tokens}/{self.usable_limit} tokens, "
            f"{len(messages)} messages"
        )

        # Work with positions, not dict equality: an earlier message whose
        # role/content equals the latest user message must not be dropped.
        pinned = {i for i, m in enumerate(messages) if m.get("role") == "system"}
        last_user = next(
            (i for i in reversed(range(len(messages)))
             if messages[i].get("role") == "user"),
            None,
        )
        if last_user is not None:
            pinned.add(last_user)
        history = [i for i in range(len(messages)) if i not in pinned]

        # Trim history to the most recent max_history_messages. Slice from an
        # explicit start index, because history[-0:] would keep everything.
        if len(history) > max_history_messages:
            logger.info(f"📦 Trimming history from {len(history)} to {max_history_messages}")
            history = history[len(history) - max(max_history_messages, 0):]

        def rebuild(kept_history: List[int]) -> List[Dict[str, str]]:
            # Rebuild in original order from pinned + kept history positions.
            keep = pinned.union(kept_history)
            return [m for i, m in enumerate(messages) if i in keep]

        managed_messages = rebuild(history)
        # Count-based trimming alone can still overflow, so drop the oldest
        # history messages until the token budget is met.
        while history and count(managed_messages) > self.usable_limit:
            history = history[1:]
            managed_messages = rebuild(history)

        final_tokens = count(managed_messages)
        if final_tokens > self.usable_limit:
            logger.warning(
                f"⚠️ Context still over limit after trimming: "
                f"{final_tokens}/{self.usable_limit} tokens"
            )
        logger.info(
            f"📦 Context managed: {final_tokens}/{self.usable_limit} tokens, "
            f"{len(managed_messages)} messages"
        )
        return managed_messages

    async def summarize_conversation(
        self,
        messages: List[Dict[str, str]],
        summarizer_fn = None,
    ) -> str:
        """
        Summarize conversation history.

        Args:
            messages: Message history; system messages are excluded from the
                summary. Fewer than 3 messages yields "".
            summarizer_fn: Optional async callable that receives a
                "ROLE: content" transcript (content truncated to 200 chars)
                and returns a summary string.

        Returns:
            The summarizer's output. If no summarizer is given, or the
            summarizer raises, a basic extractive summary is returned instead.
        """
        if not messages or len(messages) < 3:
            return ""

        # Extract conversation content (skip system messages)
        conversation = [m for m in messages if m.get("role") != "system"]

        # If no custom summarizer, use basic extraction
        if not summarizer_fn:
            return self._basic_summary(conversation)

        # Only build the transcript when a summarizer will consume it.
        # `or ''` also covers content=None (e.g. tool-call messages).
        conversation_text = "\n".join(
            f"{m.get('role', 'unknown').upper()}: {(m.get('content') or '')[:200]}"
            for m in conversation
        )

        try:
            return await summarizer_fn(conversation_text)
        except Exception as e:
            logger.warning(f"⚠️ Summarization failed: {e}, using basic summary")
            return self._basic_summary(conversation)

    def _basic_summary(self, messages: List[Dict[str, str]]) -> str:
        """Build a " | "-joined digest of the last 10 messages.

        Messages longer than 100 characters are reduced to their first two
        lines longer than 20 characters. Shorter messages are kept verbatim.
        Missing or ``None`` content counts as "".
        """
        summaries = []
        for msg in messages[-10:]:  # Last 10 messages
            content = msg.get("content") or ""
            if len(content) > 100:
                # Extract key points
                key_lines = [line for line in content.split("\n") if len(line) > 20][:2]
                summaries.append(" ".join(key_lines))
            else:
                summaries.append(content)
        return " | ".join(summaries)
| # ============================================================ | |
| # Message Window (sliding window) | |
| # ============================================================ | |
class MessageWindow:
    """Sliding window holding the most recent conversation messages.

    Keeps at most ``window_size`` messages (oldest evicted first). The window
    reports itself expired once ``max_age_minutes`` have elapsed since it was
    created or last cleared.
    """

    def __init__(self, window_size: int = 20, max_age_minutes: int = 120):
        """
        Args:
            window_size: Maximum number of messages retained.
            max_age_minutes: Window lifetime, measured from creation or the
                last ``clear()``.
        """
        self.window_size = window_size
        self.max_age = timedelta(minutes=max_age_minutes)
        self.messages: List[Dict[str, str]] = []
        # Naive UTC datetimes are used consistently throughout this class.
        self.created_at = datetime.utcnow()

    def add_message(self, role: str, content: str) -> None:
        """Append a timestamped message and evict the oldest beyond ``window_size``."""
        self.messages.append({
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow().isoformat(),
        })
        # Maintain window size. Evict all excess messages, not just one, so
        # the limit still holds if window_size was lowered after creation.
        excess = len(self.messages) - self.window_size
        if excess > 0:
            del self.messages[:excess]
            logger.debug("📤 Removed old message from window")

    def get_messages(self, include_timestamps: bool = False) -> List[Dict[str, str]]:
        """Return copies of the messages in the window.

        Args:
            include_timestamps: Keep the "timestamp" key. By default it is
                stripped so the dicts can be passed straight to a chat API.

        Returns:
            A new list of new dicts. Mutating the result does not affect
            the window.
        """
        if include_timestamps:
            return [dict(m) for m in self.messages]
        # Remove timestamps for API calls
        return [
            {k: v for k, v in m.items() if k != "timestamp"}
            for m in self.messages
        ]

    def is_expired(self) -> bool:
        """Return True once ``max_age`` has passed since creation/last clear."""
        return datetime.utcnow() - self.created_at > self.max_age

    def clear(self) -> None:
        """Drop all messages and restart the expiry clock."""
        self.messages = []
        self.created_at = datetime.utcnow()

    def get_stats(self) -> Dict[str, int]:
        """Return message count, configured max size and age in whole seconds."""
        return {
            "message_count": len(self.messages),
            "max_size": self.window_size,
            "age_seconds": int((datetime.utcnow() - self.created_at).total_seconds()),
        }
| # ============================================================ | |
# Global Registries (context managers and message windows)
| # ============================================================ | |
# Process-wide cache of ContextManager instances, keyed by model name.
_context_managers: Dict[str, "ContextManager"] = dict()
# Process-wide registry of MessageWindow instances, keyed by user_id.
_message_windows: Dict[str, "MessageWindow"] = dict()
def get_context_manager(model: str = "deepseek-chat") -> ContextManager:
    """Return the cached ContextManager for ``model``, creating it on first request."""
    try:
        return _context_managers[model]
    except KeyError:
        manager = _context_managers[model] = ContextManager(model)
        return manager
def get_message_window(user_id: str, create_if_missing: bool = True) -> Optional[MessageWindow]:
    """Look up the MessageWindow registered for ``user_id``.

    Args:
        user_id: Key into the module-level window registry.
        create_if_missing: Register a fresh window when none exists yet.

    Returns:
        The user's window, cleared first if it had expired. Returns ``None``
        when no window exists and ``create_if_missing`` is False.
    """
    try:
        window = _message_windows[user_id]
    except KeyError:
        if not create_if_missing:
            return None
        window = _message_windows[user_id] = MessageWindow()

    # An expired window is reset in place rather than replaced.
    if window.is_expired():
        logger.info(f"🗑️ Clearing expired window for user {user_id}")
        window.clear()
    return window
def cleanup_expired_windows() -> int:
    """Delete every expired MessageWindow from the registry.

    Returns:
        The number of windows removed.
    """
    # Collect ids first; the dict cannot change size while being iterated.
    stale_ids = [uid for uid, win in _message_windows.items() if win.is_expired()]
    for uid in stale_ids:
        del _message_windows[uid]

    removed = len(stale_ids)
    if removed:
        logger.info(f"🧹 Cleaned up {removed} expired windows")
    return removed