File size: 12,828 Bytes
96abbd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
"""
Memory Bank for storing and retrieving successful problem-solving cases
Uses LlamaIndex for RAG-based case retrieval
"""

import os
import json
from pathlib import Path
from typing import List, Dict, Optional
from llama_index.core import Document, VectorStoreIndex, StorageContext, load_index_from_storage
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import Settings

# Directory that contains this module, with symlinks resolved.
_PKG_DIR = Path(__file__).resolve().parent
# Two directory levels above the module's directory.
_PROJECT_ROOT = _PKG_DIR.parent.parent
# Default on-disk location for the case file and the index, kept as a plain
# string because it is used as a ``str`` default argument.
DEFAULT_MEMORY_DIR = str(_PROJECT_ROOT.joinpath("memory_storage"))


class MemoryBank:
    """
    Memory Bank for storing successful problem-solving experiences
    
    Design inspired by Memento (https://arxiv.org/pdf/2508.16153):
    - Episodic memory: Store past successful trajectories
    - Case-based reasoning: Retrieve similar cases to guide current problem
    - Non-parametric: No gradient updates, just memory read/write
    """
    
    def __init__(self, memory_dir: str = DEFAULT_MEMORY_DIR, embedding_model: str = "BAAI/bge-small-en-v1.5"):
        """
        Initialize Memory Bank

        Creates ``memory_dir`` if it does not exist, assigns the embedding
        model and chunking parameters on the shared LlamaIndex ``Settings``
        object, then loads (or creates) the vector index and counts the
        cases already stored in the cases file.

        Args:
            memory_dir: Directory to store memory index and cases
            embedding_model: HuggingFace embedding model name or local path

        Raises:
            Exception: Whatever the second load attempt raises when the
                embedding model can be loaded neither normally nor after
                HF_HUB_OFFLINE has been switched to "1".
        """
        self.memory_dir = memory_dir
        os.makedirs(memory_dir, exist_ok=True)

        self.cases_file = os.path.join(memory_dir, "cases.jsonl")
        self.index_dir = os.path.join(memory_dir, "index")

        # Allow online access by default, but respect a value the user
        # already exported.
        os.environ.setdefault("HF_HUB_OFFLINE", "0")

        cache_folder = os.path.expanduser("~/.cache/llama_index")

        def build_embed_model():
            # Single place for the constructor arguments so the first attempt
            # and the offline retry cannot drift apart.
            # trust_remote_code stays False for security.
            return HuggingFaceEmbedding(
                model_name=embedding_model,
                cache_folder=cache_folder,
                trust_remote_code=False
            )

        # A value is treated as a local path when it is absolute, or when it
        # contains a path separator and exists on disk. This only selects the
        # message that is printed; the model is constructed the same way.
        is_local_path = os.path.isabs(embedding_model) or (os.path.sep in embedding_model and os.path.exists(embedding_model))

        try:
            if is_local_path:
                print(f"📁 Using local embedding model from: {embedding_model}")
            else:
                print(f"🔍 Loading embedding model: {embedding_model}")
                print("   (If you want to avoid Hugging Face downloads, set HF_HUB_OFFLINE=1 or use a local model path)")
            Settings.embed_model = build_embed_model()
        except Exception as e:
            # First attempt failed: retry once with HF_HUB_OFFLINE=1.
            print(f"⚠️  Warning: Failed to load embedding model '{embedding_model}': {e}")
            print("   Attempting to use cached model only (setting HF_HUB_OFFLINE=1)...")
            # NOTE(review): this changes the process environment and is never
            # restored. Whether already-imported libraries re-read the
            # variable is not visible from this file - confirm.
            os.environ["HF_HUB_OFFLINE"] = "1"
            try:
                Settings.embed_model = build_embed_model()
                print("   ✅ Using cached model")
            except Exception as e2:
                print(f"❌ Error: Could not load embedding model: {e2}")
                print("   Please either:")
                # Show the model that actually failed, not a hard-coded name.
                print(f"   1. Download the model first: python -c \"from sentence_transformers import SentenceTransformer; SentenceTransformer('{embedding_model}')\"")
                print("   2. Set HF_HUB_OFFLINE=1 and ensure the model is cached")
                print("   3. Use a local model path: --embedding_model /path/to/local/model")
                raise
        # Disable chunking to ensure one document = one node (no duplicates)
        Settings.chunk_size = 8192  # Large enough to never split
        Settings.chunk_overlap = 0

        # Load or create index
        self.index = self._load_or_create_index()
        self.case_count = self._count_cases()

        print(f"Memory Bank initialized with {self.case_count} cases")
    
    def _load_or_create_index(self):
        """Load existing index or create new one

        If ``self.index_dir`` exists, the persisted index is loaded from it.
        When the directory is missing, or loading raises, an empty index is
        built and persisted to ``self.index_dir``.

        Returns:
            The loaded or newly created index object.
        """
        if os.path.exists(self.index_dir):
            try:
                storage_context = StorageContext.from_defaults(persist_dir=self.index_dir)
                index = load_index_from_storage(storage_context)
            except Exception as e:
                # Catch Exception rather than using a bare except so that
                # KeyboardInterrupt / SystemExit still propagate, and report
                # the cause instead of hiding it.
                print(f"Failed to load index, creating new one: {e}")
            else:
                print(f"Loaded existing memory index from {self.index_dir}")
                return index

        # Create new empty index.
        # NOTE: the replacement starts empty; cases already present in the
        # cases file are not re-inserted here.
        index = VectorStoreIndex.from_documents([])
        os.makedirs(self.index_dir, exist_ok=True)
        index.storage_context.persist(persist_dir=self.index_dir)
        print(f"Created new memory index at {self.index_dir}")
        return index
    
    def _count_cases(self) -> int:
        """Count number of cases in memory"""
        if not os.path.exists(self.cases_file):
            return 0
        with open(self.cases_file, 'r') as f:
            return sum(1 for _ in f)
    
    def add_case(self, problem_id: int, problem_desc: str, solution_code: str, 
                 objective_value: float, is_correct: bool, metadata: Optional[Dict] = None):
        """
        Add a successful case to memory
        
        Args:
            problem_id: Problem ID
            problem_desc: Problem description
            solution_code: Solution code
            objective_value: Computed objective value
            is_correct: Whether the solution is correct
            metadata: Additional metadata (model, debate_rounds, etc.)
        """
        if not is_correct:
            # Only store successful cases
            return
        
        case = {
            'problem_id': problem_id,
            'description': problem_desc,
            'solution_code': solution_code,
            'objective_value': objective_value,
            'is_correct': is_correct,
            'metadata': metadata or {}
        }
        
        # Write to cases file
        with open(self.cases_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(case, ensure_ascii=False) + '\n')
        
        # Create document for indexing
        # Combine description and key solution insights for better retrieval
        doc_text = f"""Problem: {problem_desc}

Solution approach:
{solution_code[:500]}...

Key features:
- Problem ID: {problem_id}
- Objective value: {objective_value}
- Status: Correct
"""
        
        doc = Document(
            text=doc_text,
            metadata={
                'problem_id': problem_id,
                'objective_value': objective_value,
                **case['metadata']
            }
        )
        
        # Add to index
        self.index.insert(doc)
        self.index.storage_context.persist(persist_dir=self.index_dir)
        
        self.case_count += 1
        print(f"✅ Added case {problem_id} to memory (Total: {self.case_count})")
    
    def retrieve_similar_cases(self, query: str, top_k: int = 3, preferred_dataset: Optional[str] = None) -> List[Dict]:
        """
        Retrieve similar cases from memory using RAG based on semantic similarity
        
        Args:
            query: Query text (usually the problem description)
            top_k: Number of similar cases to retrieve (0 = no retrieval)
            preferred_dataset: Preferred dataset name to prioritize (optional)
        
        Returns:
            List of similar cases with scores, sorted by semantic similarity
        """
        if self.case_count == 0 or top_k <= 0:
            return []
        
        # Query the index - purely based on semantic similarity
        retriever = self.index.as_retriever(similarity_top_k=top_k * 2 if preferred_dataset else top_k)
        nodes = retriever.retrieve(query)
        
        # Load corresponding cases from cases.jsonl based on semantic similarity
        similar_cases = []
        seen_keys = set()  # Track which (problem_id, dataset) combinations we've added
        
        # If preferred_dataset is specified, prioritize those cases
        preferred_cases = []
        other_cases = []
        
        for node in nodes:
            problem_id = node.metadata.get('problem_id')
            score = node.score
            node_dataset = node.metadata.get('dataset', '')
            
            # Build key for deduplication
            case_key = (problem_id, node_dataset)
            if case_key in seen_keys:
                continue
            
            # Load the case - use dataset from node metadata to get the exact match
            case_data = None
            if node_dataset:
                # Try to load by problem_id and dataset (more precise)
                case_data = self._load_case_by_id_and_dataset(problem_id, node_dataset)
            
            if not case_data:
                # Fallback: try to load by problem_id only
                case_data = self._load_case_by_id(problem_id)
            
            if case_data:
                seen_keys.add(case_key)
                case_item = {
                    'case': case_data,
                    'score': score,
                    'text_preview': node.text[:200]
                }
                
                # Separate preferred dataset cases from others
                if preferred_dataset and node_dataset == preferred_dataset:
                    preferred_cases.append(case_item)
                else:
                    other_cases.append(case_item)
        
        # Combine: preferred cases first, then others, all sorted by similarity score
        similar_cases = preferred_cases + other_cases
        
        # Return top_k results
        return similar_cases[:top_k]
    
    def _load_case_by_id(self, problem_id: int) -> Optional[Dict]:
        """Load a specific case by problem ID (returns first match)"""
        if not os.path.exists(self.cases_file):
            return None
        
        with open(self.cases_file, 'r', encoding='utf-8') as f:
            for line in f:
                case = json.loads(line)
                if case['problem_id'] == problem_id:
                    return case
        return None
    
    def _load_case_by_id_and_dataset(self, problem_id: int, dataset: str) -> Optional[Dict]:
        """Load a specific case by problem ID and dataset"""
        if not os.path.exists(self.cases_file):
            return None
        
        with open(self.cases_file, 'r', encoding='utf-8') as f:
            for line in f:
                case = json.loads(line)
                if case['problem_id'] == problem_id:
                    case_dataset = case.get('metadata', {}).get('dataset', '')
                    if case_dataset == dataset:
                        return case
        return None
    
    def get_memory_stats(self) -> Dict:
        """Get memory bank statistics"""
        return {
            'total_cases': self.case_count,
            'memory_dir': self.memory_dir,
            'cases_file': self.cases_file,
            'index_dir': self.index_dir
        }
    
    def format_retrieved_cases_for_prompt(self, cases: List[Dict]) -> str:
        """
        Format retrieved cases for inclusion in LLM prompt
        
        Args:
            cases: List of retrieved cases
        
        Returns:
            Formatted string for prompt
        """
        if not cases:
            return ""
        
        prompt = "# Retrieved Similar Cases from Memory\n\n"
        prompt += "The following successful cases from previous problems might be relevant:\n\n"
        
        for i, item in enumerate(cases, 1):
            case = item['case']
            score = item['score']
            
            prompt += f"## Case {i} (Similarity: {score:.3f})\n"
            prompt += f"**Problem:** {case['description']}\n\n"
            prompt += f"**Solution approach:**\n```python\n{case['solution_code']}\n```\n\n"
            prompt += f"**Result:** Objective value = {case['objective_value']}, Status = Correct\n\n"
            prompt += "---\n\n"
        
        return prompt