File size: 17,804 Bytes
ab07cb1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
"""
ERA5 MCP Memory System
======================

Session-based memory with smart compression for conversation history.
Dataset cache persists across sessions, but conversations are fresh each session.
"""

from __future__ import annotations

import json
import logging
import os
from dataclasses import asdict, dataclass, field, fields
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import tiktoken

from eurus.config import CONFIG, get_memory_dir

logger = logging.getLogger(__name__)


# ============================================================================
# CONFIGURATION
# ============================================================================

# Token limits for smart memory management
MAX_CONTEXT_TOKENS = 8000  # Upper bound reported in context summaries ("tokens/MAX")
COMPRESSION_THRESHOLD = 6000  # Conversation tokens above this trigger history compression
SUMMARY_TARGET_TOKENS = 500  # Compressed-summary message is trimmed toward this size


# ============================================================================
# DATA STRUCTURES
# ============================================================================

@dataclass
class DatasetRecord:
    """Record of a downloaded dataset.

    Persisted to JSON via ``to_dict``/``from_dict``. Bounds and shape are
    tuples in memory but round-trip through JSON as lists.
    """

    path: str
    variable: str
    query_type: str
    start_date: str
    end_date: str
    lat_bounds: tuple[float, float]
    lon_bounds: tuple[float, float]
    file_size_bytes: int
    download_timestamp: str
    shape: Optional[tuple[int, ...]] = None

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of this record."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "DatasetRecord":
        """Build a record from a dict (e.g. loaded from JSON).

        The input dict is NOT mutated (the original implementation wrote
        tuple conversions back into the caller's dict). Unknown keys are
        dropped so records written by other versions of this module still
        load, matching the behavior of ``Message.from_dict``.
        """
        valid = {f.name for f in fields(cls)}
        clean = {k: v for k, v in data.items() if k in valid}
        # JSON has no tuple type; restore tuples for the fields that need them.
        for key in ("lat_bounds", "lon_bounds", "shape"):
            if isinstance(clean.get(key), list):
                clean[key] = tuple(clean[key])
        return cls(**clean)


@dataclass
class Message:
    """A single conversation message (role + content + metadata)."""

    role: str
    content: str
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
    is_compressed: bool = False  # set only on synthetic compressed-summary messages

    def to_dict(self) -> dict:
        """Serialize to a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "Message":
        """Deserialize, silently dropping any keys the dataclass does not define."""
        allowed = ("role", "content", "timestamp", "is_compressed")
        return cls(**{key: data[key] for key in allowed if key in data})

    def to_langchain(self) -> dict:
        """Convert to LangChain message format."""
        return {"role": self.role, "content": self.content}


@dataclass
class AnalysisRecord:
    """Record of an analysis performed (code, output, provenance)."""

    description: str
    code: str
    output: str
    timestamp: str
    datasets_used: List[str] = field(default_factory=list)
    plots_generated: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a JSON-serializable dict of this record."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: dict) -> "AnalysisRecord":
        """Build a record from a dict loaded from ``analyses.json``.

        Unknown keys are dropped so records persisted by older or newer
        versions of this module still load without a TypeError, consistent
        with ``Message.from_dict``.
        """
        valid = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in data.items() if k in valid})


# ============================================================================
# TOKEN COUNTER
# ============================================================================

class TokenCounter:
    """Token counting backed by tiktoken, with a cheap estimate as fallback."""

    _encoder = None  # process-wide cached encoder, built lazily

    @classmethod
    def get_encoder(cls):
        """Return the shared tiktoken encoder, creating it on first use."""
        encoder = cls._encoder
        if encoder is None:
            try:
                encoder = tiktoken.encoding_for_model("gpt-4")
            except Exception:
                # Model lookup failed; fall back to the generic GPT-4 encoding.
                encoder = tiktoken.get_encoding("cl100k_base")
            cls._encoder = encoder
        return encoder

    @classmethod
    def count(cls, text: str) -> int:
        """Return the token count of *text* (rough ~4-chars/token estimate on failure)."""
        try:
            tokens = cls.get_encoder().encode(text)
        except Exception:
            # tiktoken unavailable or encoding failed: estimate instead of raising.
            return len(text) // 4
        return len(tokens)


# ============================================================================
# SMART CONVERSATION MEMORY
# ============================================================================

class SmartConversationMemory:
    """
    Session-scoped conversation memory with automatic compression.

    Behavior:
    - Starts empty every session; nothing is persisted.
    - Tracks an approximate token total for all stored content.
    - Once the total passes COMPRESSION_THRESHOLD, older messages are
      folded into a single system "summary" message while the most
      recent messages are kept verbatim.
    """

    def __init__(self):
        self.messages: List[Message] = []
        self.compressed_summary: Optional[str] = None
        self._token_count = 0
        logger.info("SmartConversationMemory initialized (fresh session)")

    def add_message(self, role: str, content: str) -> Message:
        """Append a message; compress history once the token budget is exceeded."""
        message = Message(role=role, content=content)
        self.messages.append(message)
        self._token_count += TokenCounter.count(content)
        if self._token_count > COMPRESSION_THRESHOLD:
            self._compress_history()
        return message

    def _compress_history(self) -> None:
        """Fold older messages into one summary message, keeping the last 4 intact."""
        recent_count = 4
        if len(self.messages) < recent_count + 2:
            return  # too few messages for compression to be worthwhile

        older = self.messages[:-recent_count]
        recent = self.messages[-recent_count:]
        if not older:
            return

        # One line per old message, with long content truncated to 200 chars.
        entries = []
        for message in older:
            text = message.content
            if len(text) > 200:
                text = text[:200] + "..."
            entries.append(f"[{message.role.upper()}]: {text}")
        summary = "[Previous conversation summary]\n" + "\n".join(entries)

        # Shrink toward the target budget by dropping the oldest entry each pass
        # (line 0 is the header, line 1 is the oldest entry).
        while summary and TokenCounter.count(summary) > SUMMARY_TARGET_TOKENS:
            parts = summary.split('\n')
            if len(parts) <= 2:
                break
            summary = parts[0] + '\n' + '\n'.join(parts[2:])

        summary_msg = Message(role="system", content=summary, is_compressed=True)
        self.messages = [summary_msg] + recent

        # Token total is recomputed from scratch after rewriting the list.
        self._token_count = sum(TokenCounter.count(m.content) for m in self.messages)

        logger.info(f"Compressed {len(older)} messages. Current tokens: {self._token_count}")

    def get_messages(self, n_messages: Optional[int] = None) -> List[Message]:
        """Return a copy of all messages, or just the last *n_messages* if given."""
        snapshot = list(self.messages)
        if n_messages is None:
            return snapshot
        return snapshot[-n_messages:]

    def get_langchain_messages(self, n_messages: Optional[int] = None) -> List[dict]:
        """Return messages converted to the LangChain dict format."""
        return [message.to_langchain() for message in self.get_messages(n_messages)]

    def clear(self) -> None:
        """Drop all messages and reset the token counter."""
        self.messages.clear()
        self.compressed_summary = None
        self._token_count = 0
        logger.info("Conversation memory cleared")

    def get_token_count(self) -> int:
        """Return the current approximate token total."""
        return self._token_count


# ============================================================================
# MEMORY MANAGER
# ============================================================================

class MemoryManager:
    """
    Central memory facade for ERA5 MCP.

    Responsibilities:
    - Dataset cache registry, persisted to ``datasets.json`` across sessions.
    - Analysis history, persisted to ``analyses.json`` (last 20 kept).
    - Conversation history, held only in memory and rebuilt fresh each
      session via SmartConversationMemory — deliberately NOT persisted,
      to avoid stale context leaking between sessions.
    """

    def __init__(self, memory_dir: Optional[Path] = None, persist_conversations: bool = False):
        self.memory_dir = memory_dir or get_memory_dir()
        self.memory_dir.mkdir(parents=True, exist_ok=True)
        self.persist_conversations = persist_conversations

        # Only the dataset/analysis registries live on disk.
        self.datasets_file = self.memory_dir / "datasets.json"
        self.analyses_file = self.memory_dir / "analyses.json"

        self.datasets: Dict[str, DatasetRecord] = {}
        self.analyses: List[AnalysisRecord] = []

        # Conversation state is session-scoped: a brand-new memory each time.
        self.conversation_memory = SmartConversationMemory()

        self._load_datasets()
        self._load_analyses()

        logger.info(
            f"MemoryManager initialized: {len(self.datasets)} datasets, "
            f"FRESH conversation (session-based)"
        )

    # ========================================================================
    # PERSISTENCE (Datasets only)
    # ========================================================================

    def _load_datasets(self) -> None:
        """Populate self.datasets from datasets.json, if present."""
        if not self.datasets_file.exists():
            return
        try:
            raw = json.loads(self.datasets_file.read_text())
            for path, entry in raw.items():
                self.datasets[path] = DatasetRecord.from_dict(entry)
        except Exception as e:
            logger.warning(f"Failed to load datasets: {e}")

    def _save_datasets(self) -> None:
        """Write the dataset registry to datasets.json (best-effort)."""
        try:
            payload = {path: record.to_dict() for path, record in self.datasets.items()}
            self.datasets_file.write_text(json.dumps(payload, indent=2))
        except Exception as e:
            logger.error(f"Failed to save datasets: {e}")

    def _load_analyses(self) -> None:
        """Load the most recent analyses (at most 20) from analyses.json."""
        if not self.analyses_file.exists():
            return
        try:
            raw = json.loads(self.analyses_file.read_text())
            # Only the 20 newest records are kept in memory.
            self.analyses = [AnalysisRecord.from_dict(entry) for entry in raw[-20:]]
        except Exception as e:
            logger.warning(f"Failed to load analyses: {e}")

    def _save_analyses(self) -> None:
        """Persist the 20 most recent analyses to analyses.json (best-effort)."""
        try:
            payload = [record.to_dict() for record in self.analyses[-20:]]
            self.analyses_file.write_text(json.dumps(payload, indent=2))
        except Exception as e:
            logger.error(f"Failed to save analyses: {e}")

    # ========================================================================
    # DATASET MANAGEMENT
    # ========================================================================

    def register_dataset(
        self,
        path: str,
        variable: str,
        query_type: str,
        start_date: str,
        end_date: str,
        lat_bounds: tuple[float, float],
        lon_bounds: tuple[float, float],
        file_size_bytes: int = 0,
        shape: Optional[tuple[int, ...]] = None,
    ) -> DatasetRecord:
        """Create a record for a downloaded dataset, persist it, and return it."""
        record = DatasetRecord(
            path=path,
            variable=variable,
            query_type=query_type,
            start_date=start_date,
            end_date=end_date,
            lat_bounds=lat_bounds,
            lon_bounds=lon_bounds,
            file_size_bytes=file_size_bytes,
            download_timestamp=datetime.now().isoformat(),
            shape=shape,
        )
        self.datasets[path] = record
        self._save_datasets()
        logger.info(f"Registered dataset: {path}")
        return record

    def get_dataset(self, path: str) -> Optional[DatasetRecord]:
        """Look up a dataset record by its file path (None if unknown)."""
        return self.datasets.get(path)

    def list_datasets(self) -> str:
        """Return a human-readable listing of all registered datasets."""
        if not self.datasets:
            return "No datasets in cache."

        lines = ["Cached Datasets:", "=" * 70]
        for path, record in self.datasets.items():
            # Registered files may have been deleted out from under us.
            if not os.path.exists(path):
                lines.append(f"  [MISSING] {path}")
                continue
            size_str = self._format_size(record.file_size_bytes)
            lines.append(
                f"  {record.variable:5} | {record.start_date} to {record.end_date} | "
                f"{record.query_type:8} | {size_str:>10}"
            )
            lines.append(f"        Path: {path}")

        return "\n".join(lines)

    def cleanup_missing_datasets(self) -> int:
        """Drop records whose files no longer exist; return how many were removed."""
        stale = [path for path in self.datasets if not os.path.exists(path)]
        for path in stale:
            del self.datasets[path]
            logger.info(f"Removed missing dataset: {path}")
        if stale:
            self._save_datasets()
        return len(stale)

    # ========================================================================
    # CONVERSATION MANAGEMENT (Session-based)
    # ========================================================================

    def add_message(self, role: str, content: str) -> Message:
        """Append a message to the session conversation history."""
        return self.conversation_memory.add_message(role, content)

    def get_conversation_history(self, n_messages: Optional[int] = None) -> List[Message]:
        """Return the conversation history (all, or the last *n_messages*)."""
        return self.conversation_memory.get_messages(n_messages)

    def clear_conversation(self) -> None:
        """Wipe the session conversation history."""
        self.conversation_memory.clear()
        logger.info("Conversation history cleared")

    def get_langchain_messages(self, n_messages: Optional[int] = None) -> List[dict]:
        """Return conversation messages in LangChain dict format."""
        return self.conversation_memory.get_langchain_messages(n_messages)

    @property
    def conversations(self) -> List[Message]:
        """Legacy alias for the live message list (kept for compatibility)."""
        return self.conversation_memory.messages

    # ========================================================================
    # ANALYSIS TRACKING
    # ========================================================================

    def record_analysis(
        self,
        description: str,
        code: str,
        output: str,
        datasets_used: Optional[List[str]] = None,
        plots_generated: Optional[List[str]] = None,
    ) -> AnalysisRecord:
        """Store an analysis in the persistent history and return its record."""
        record = AnalysisRecord(
            description=description,
            code=code,
            output=output[:2000],  # cap stored output to keep the JSON file small
            timestamp=datetime.now().isoformat(),
            datasets_used=datasets_used or [],
            plots_generated=plots_generated or [],
        )
        self.analyses.append(record)
        self._save_analyses()
        return record

    def get_recent_analyses(self, n: int = 10) -> List[AnalysisRecord]:
        """Return the *n* most recent analysis records."""
        return self.analyses[-n:]

    # ========================================================================
    # CONTEXT SUMMARY
    # ========================================================================

    def get_context_summary(self) -> str:
        """Build a short textual summary of session state for the agent prompt."""
        lines: List[str] = []

        tokens = self.conversation_memory.get_token_count()
        if tokens > 0:
            lines.append(f"Session tokens: {tokens}/{MAX_CONTEXT_TOKENS}")

        recent = self.get_conversation_history(3)
        if recent:
            lines.append("\nRecent in this session:")
            for msg in recent:
                preview = msg.content if len(msg.content) <= 80 else msg.content[:80] + "..."
                lines.append(f"  [{msg.role}]: {preview}")

        # Only advertise datasets whose files still exist on disk.
        valid_datasets = {p: r for p, r in self.datasets.items() if os.path.exists(p)}
        if valid_datasets:
            lines.append(f"\nCached Datasets ({len(valid_datasets)}):")
            for path, record in list(valid_datasets.items())[:5]:
                lines.append(f"  - {record.variable}: {record.start_date} to {record.end_date}")

        return "\n".join(lines) if lines else "Fresh session - no context yet."

    # ========================================================================
    # UTILITIES
    # ========================================================================

    @staticmethod
    def _format_size(size_bytes: int) -> str:
        """Render a byte count as a human-readable size string (e.g. '2.0 KB')."""
        value = size_bytes
        for unit in ("B", "KB", "MB", "GB"):
            if value < 1024:
                return f"{value:.1f} {unit}"
            value /= 1024
        return f"{value:.1f} TB"


# ============================================================================
# GLOBAL INSTANCE
# ============================================================================

# Lazily-created module-level singleton; access via get_memory(), discard via reset_memory().
_memory_instance: Optional[MemoryManager] = None


def get_memory() -> MemoryManager:
    """Return the process-wide MemoryManager, creating it lazily on first call."""
    global _memory_instance
    instance = _memory_instance
    if instance is None:
        instance = MemoryManager()
        _memory_instance = instance
    return instance


def reset_memory() -> None:
    """Discard the global memory instance so the next get_memory() starts a fresh session."""
    global _memory_instance
    # Dropping the reference is enough; nothing session-scoped is persisted.
    _memory_instance = None
    logger.info("Memory reset - next get_memory() will create fresh session")