File size: 10,070 Bytes
4c9881b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
# ============================================================
# app/core/context_manager.py - Context Window Management
# ============================================================

import logging
import json
from typing import List, Dict, Optional, Tuple
from datetime import datetime, timedelta
import tiktoken

from app.core.error_handling import LojizError

logger = logging.getLogger(__name__)

# ============================================================
# Token Counter
# ============================================================

class TokenCounter:
    """Count tokens using tiktoken (OpenAI's tokenizer).

    If loading the requested encoding raises, a warning is logged and a rough
    ``len(text) // 4`` character-based estimate is used instead.
    """

    def __init__(self, encoding_name: str = "cl100k_base"):
        try:
            self.encoding = tiktoken.get_encoding(encoding_name)
        except Exception as e:
            logger.warning(f"⚠️ Failed to load tiktoken: {e}, using fallback")
            self.encoding = None

    def count_tokens(self, text: str) -> int:
        """Return the number of tokens in ``text`` (0 for empty/None text)."""
        if not text:
            return 0

        if self.encoding is None:
            # Fallback: rough estimate (4 chars ≈ 1 token)
            return len(text) // 4

        # By default tiktoken raises ValueError when the text contains a
        # special-token string such as "<|endoftext|>". Counted text may be
        # arbitrary user input, so treat such strings as ordinary text
        # instead of letting them crash token counting.
        return len(self.encoding.encode(text, disallowed_special=()))

    def count_messages_tokens(self, messages: List[Dict[str, str]]) -> int:
        """Estimate the token cost of a chat message list.

        Each message costs 4 overhead tokens plus the tokens of its ``role``
        and ``content`` (missing or empty fields are skipped). The whole list
        costs an extra 2 tokens of framing overhead.
        """
        total = 0

        for msg in messages:
            # Add overhead per message (role + content markers)
            total += 4

            if msg.get("role"):
                total += self.count_tokens(msg["role"])
            if msg.get("content"):
                total += self.count_tokens(msg["content"])

        # Add overhead for message framing
        total += 2

        return total

# ============================================================
# Context Manager
# ============================================================

class ContextManager:
    """Manage the context window so requests stay within a model's token limit.

    The usable budget is the model's context limit minus ``RESPONSE_RESERVE``
    tokens, which are kept free for the model's reply.
    """

    # Context window size per model, in tokens. Unknown models fall back to 4096.
    MODEL_LIMITS = {
        "deepseek-chat": 4096,
        "mistralai/mistral-7b-instruct": 8192,
        "xai-org/grok-beta": 8192,
        "meta-llama/llama-2-70b-chat": 4096,
    }

    # Reserve space for response
    RESPONSE_RESERVE = 600

    def __init__(self, model: str = "deepseek-chat"):
        self.model = model
        self.token_counter = TokenCounter()
        self.context_limit = self.MODEL_LIMITS.get(model, 4096)
        self.usable_limit = self.context_limit - self.RESPONSE_RESERVE

    def get_available_context(self, current_tokens: int) -> int:
        """Return the tokens left in the usable budget (never negative)."""
        return max(0, self.usable_limit - current_tokens)

    def is_context_full(self, messages: List[Dict[str, str]]) -> bool:
        """Return True if ``messages`` already use the whole usable budget."""
        tokens = self.token_counter.count_messages_tokens(messages)
        return tokens >= self.usable_limit

    async def manage_context(
        self,
        messages: List[Dict[str, str]],
        max_history_messages: int = 20,
    ) -> List[Dict[str, str]]:
        """Fit ``messages`` into the usable token budget.

        If the messages already fit, the same list is returned unchanged.
        Otherwise a new list is built from:

        1. all system messages,
        2. the most recent non-system history (at most
           ``max_history_messages`` entries; older entries are dropped
           further until the result fits the budget),
        3. the most recent user message (current user input).

        No summarization is performed here. The input list is never mutated.
        If the system messages plus the latest user message alone exceed the
        budget, they are still returned.

        Args:
            messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            max_history_messages: Upper bound on history entries kept;
                values <= 0 keep no history.

        Returns:
            The (possibly trimmed) message list.
        """

        if not messages:
            return messages

        tokens = self.token_counter.count_messages_tokens(messages)

        if tokens <= self.usable_limit:
            logger.debug(
                f"✅ Context OK: {tokens}/{self.usable_limit} tokens, "
                f"{len(messages)} messages"
            )
            return messages

        logger.warning(
            f"⚠️ Context overflow: {tokens}/{self.usable_limit} tokens, "
            f"{len(messages)} messages"
        )

        # Locate the latest user message by position. Excluding it by index
        # (instead of `m not in [...]`, which compares dicts by value) keeps
        # earlier messages that happen to have identical role/content.
        last_user_idx = next(
            (
                i
                for i in range(len(messages) - 1, -1, -1)
                if messages[i].get("role") == "user"
            ),
            None,
        )

        system_msg = [m for m in messages if m.get("role") == "system"]
        user_msg = [] if last_user_idx is None else [messages[last_user_idx]]

        history = [
            m
            for i, m in enumerate(messages)
            if m.get("role") != "system" and i != last_user_idx
        ]

        # Trim history to the most recent `keep` messages. Slicing from an
        # explicit start index avoids the `[-0:]` pitfall, which would keep
        # everything when the limit is 0.
        keep = max(0, max_history_messages)
        if len(history) > keep:
            logger.info(f"📦 Trimming history from {len(history)} to {keep}")
            history = history[len(history) - keep:]

        # A message-count cap alone does not guarantee a fit (a few long
        # messages can still overflow), so drop the oldest history entries
        # until the rebuilt context is within the usable budget.
        dropped = 0
        while history and self.token_counter.count_messages_tokens(
            system_msg + history + user_msg
        ) > self.usable_limit:
            history.pop(0)
            dropped += 1
        if dropped:
            logger.info(f"📦 Dropped {dropped} oldest history messages to fit token budget")

        # Rebuild messages
        managed_messages = system_msg + history + user_msg

        final_tokens = self.token_counter.count_messages_tokens(managed_messages)
        logger.info(
            f"📦 Context managed: {final_tokens}/{self.usable_limit} tokens, "
            f"{len(managed_messages)} messages"
        )

        return managed_messages

    async def summarize_conversation(
        self,
        messages: List[Dict[str, str]],
        summarizer_fn = None,
    ) -> str:
        """
        Summarize conversation history

        Args:
            messages: Message history
            summarizer_fn: Optional async function that receives the
                conversation as text (one "ROLE: content" line per message,
                content truncated to 200 chars) and returns a summary string

        Returns:
            Summary of conversation ("" when fewer than 3 messages are given).
            Falls back to a basic extractive summary when no summarizer is
            given or the summarizer raises.
        """

        if not messages or len(messages) < 3:
            return ""

        # Extract conversation content (skip system messages)
        conversation = [
            m for m in messages
            if m.get("role") != "system"
        ]

        # If no custom summarizer, use basic extraction
        if not summarizer_fn:
            return self._basic_summary(conversation)

        # `or` guards against keys that are present but set to None.
        conversation_text = "\n".join(
            f"{(m.get('role') or 'unknown').upper()}: {(m.get('content') or '')[:200]}"
            for m in conversation
        )

        # Use custom summarizer
        try:
            summary = await summarizer_fn(conversation_text)
            return summary
        except Exception as e:
            logger.warning(f"⚠️ Summarization failed: {e}, using basic summary")
            return self._basic_summary(conversation)

    def _basic_summary(self, messages: List[Dict[str, str]]) -> str:
        """Build a cheap extractive summary of the last 10 messages.

        Short messages (<= 100 chars) are kept verbatim. For longer ones, the
        first two lines longer than 20 chars are used, or the first 100 chars
        if no such line exists. Segments are joined with " | ".
        """

        summaries = []

        for msg in messages[-10:]:  # Last 10 messages
            content = msg.get("content") or ""
            if len(content) > 100:
                # Extract key points
                lines = content.split("\n")
                key_lines = [line for line in lines if len(line) > 20][:2]
                # Long messages made only of short lines would otherwise
                # contribute an empty segment; use a truncated prefix instead.
                summaries.append(" ".join(key_lines) or content[:100])
            else:
                summaries.append(content)

        return " | ".join(summaries)

# ============================================================
# Message Window (sliding window)
# ============================================================

class MessageWindow:
    """Sliding window holding the most recent messages of one conversation.

    At most ``window_size`` messages are kept; the oldest are dropped first.
    The window reports itself expired once ``max_age_minutes`` have elapsed
    since it was created or last cleared (message activity does not extend it).
    """

    def __init__(self, window_size: int = 20, max_age_minutes: int = 120):
        self.window_size = window_size
        self.max_age = timedelta(minutes=max_age_minutes)
        self.messages: List[Dict[str, str]] = []
        self.created_at = datetime.utcnow()

    def add_message(self, role: str, content: str) -> None:
        """Append a timestamped message, evicting the oldest beyond ``window_size``."""
        msg = {
            "role": role,
            "content": content,
            "timestamp": datetime.utcnow().isoformat(),
        }
        self.messages.append(msg)

        # Maintain window size. Loop rather than a single pop so the limit
        # holds even if window_size was lowered after messages were added.
        while self.messages and len(self.messages) > self.window_size:
            self.messages.pop(0)
            logger.debug("📤 Removed old message from window")

    def get_messages(self, include_timestamps: bool = False) -> List[Dict[str, str]]:
        """Return a copy of the window's messages.

        Args:
            include_timestamps: Keep the ``timestamp`` key on each message.
                When False, messages are returned as fresh dicts without it
                (the shape used for API calls).

        Returns:
            A new list; mutating it does not change the window.
        """
        if include_timestamps:
            # Copy so callers cannot add/remove entries in the window itself.
            return list(self.messages)

        # Remove timestamps for API calls
        return [
            {k: v for k, v in m.items() if k != "timestamp"}
            for m in self.messages
        ]

    def is_expired(self) -> bool:
        """Return True once more than ``max_age`` has passed since creation/clear."""
        return datetime.utcnow() - self.created_at > self.max_age

    def clear(self) -> None:
        """Drop all messages and restart the age clock."""
        self.messages = []
        self.created_at = datetime.utcnow()

    def get_stats(self) -> Dict[str, int]:
        """Return message count, configured max size, and age in whole seconds."""
        return {
            "message_count": len(self.messages),
            "max_size": self.window_size,
            "age_seconds": int((datetime.utcnow() - self.created_at).total_seconds()),
        }

# ============================================================
# Global Context Manager
# ============================================================

_context_managers = {}
_message_windows = {}

def get_context_manager(model: str = "deepseek-chat") -> ContextManager:
    """Return the shared ContextManager for ``model``, building it on first request."""
    manager = _context_managers.get(model)
    if manager is None:
        manager = ContextManager(model)
        _context_managers[model] = manager
    return manager

def get_message_window(user_id: str, create_if_missing: bool = True) -> Optional[MessageWindow]:
    """Look up the message window registered for ``user_id``.

    Args:
        user_id: Key into the module-level window registry.
        create_if_missing: When True, register a fresh MessageWindow for an
            unknown user; when False, return None for an unknown user.

    Returns:
        The user's window, reset in place if it had expired, or None.
    """
    if user_id in _message_windows:
        window = _message_windows[user_id]
    elif create_if_missing:
        window = MessageWindow()
        _message_windows[user_id] = window
    else:
        return None

    # Expired windows are cleared and reused, not removed from the registry.
    if window.is_expired():
        logger.info(f"🗑️  Clearing expired window for user {user_id}")
        window.clear()

    return window

def cleanup_expired_windows() -> int:
    """Remove every expired window from the registry; return how many were removed."""
    # Collect ids first: the dict must not change size while it is iterated.
    stale_ids = [uid for uid, win in _message_windows.items() if win.is_expired()]

    for uid in stale_ids:
        _message_windows.pop(uid)

    removed_count = len(stale_ids)
    if removed_count:
        logger.info(f"🧹 Cleaned up {removed_count} expired windows")

    return removed_count