| """ |
| Memory Bank for storing and retrieving successful problem-solving cases |
| Uses LlamaIndex for RAG-based case retrieval |
| """ |
|
|
| import os |
| import json |
| from pathlib import Path |
| from typing import List, Dict, Optional |
from llama_index.core import (
    Document,
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
|
| _PKG_DIR = Path(__file__).resolve().parent |
| _PROJECT_ROOT = _PKG_DIR.parent.parent |
| DEFAULT_MEMORY_DIR = str(_PROJECT_ROOT / "memory_storage") |
|
|
| class MemoryBank: |
| """ |
| Memory Bank for storing successful problem-solving experiences |
| |
| Design inspired by Memento (https://arxiv.org/pdf/2508.16153): |
| - Episodic memory: Store past successful trajectories |
| - Case-based reasoning: Retrieve similar cases to guide current problem |
| - Non-parametric: No gradient updates, just memory read/write |
| """ |
| |
| def __init__(self, memory_dir: str = DEFAULT_MEMORY_DIR, embedding_model: str = "BAAI/bge-small-en-v1.5"): |
| """ |
| Initialize Memory Bank |
| |
| Args: |
| memory_dir: Directory to store memory index and cases |
| embedding_model: HuggingFace embedding model name or local path |
| """ |
| self.memory_dir = memory_dir |
| os.makedirs(memory_dir, exist_ok=True) |
| |
| self.cases_file = os.path.join(memory_dir, "cases.jsonl") |
| self.index_dir = os.path.join(memory_dir, "index") |
| |
        # Respect an existing HF_HUB_OFFLINE setting; default to online so the
        # embedding model can be downloaded on first use.
        os.environ.setdefault("HF_HUB_OFFLINE", "0")
| |
        # Treat the value as a local path if it is absolute, or if it contains
        # a path separator and actually exists on disk.
        is_local_path = os.path.isabs(embedding_model) or (
            os.path.sep in embedding_model and os.path.exists(embedding_model)
        )
| |
        try:
            if is_local_path:
                print(f"📁 Using local embedding model from: {embedding_model}")
            else:
                print(f"🔍 Loading embedding model: {embedding_model}")
                print("   (To avoid Hugging Face downloads, set HF_HUB_OFFLINE=1 or use a local model path)")
            # Both branches load the model the same way; only the logging differs.
            Settings.embed_model = HuggingFaceEmbedding(
                model_name=embedding_model,
                cache_folder=os.path.expanduser("~/.cache/llama_index"),
                trust_remote_code=False,
            )
        except Exception as e:
            # First attempt failed (commonly a network issue); retry strictly
            # from the local Hugging Face cache.
            print(f"⚠️  Warning: Failed to load embedding model '{embedding_model}': {e}")
            print("   Attempting to use cached model only (setting HF_HUB_OFFLINE=1)...")
            os.environ["HF_HUB_OFFLINE"] = "1"
| try: |
| Settings.embed_model = HuggingFaceEmbedding( |
| model_name=embedding_model, |
| cache_folder=os.path.expanduser("~/.cache/llama_index"), |
| trust_remote_code=False |
| ) |
| print(" ✅ Using cached model") |
| except Exception as e2: |
| print(f"❌ Error: Could not load embedding model: {e2}") |
| print(" Please either:") |
| print(" 1. Download the model first: python -c \"from sentence_transformers import SentenceTransformer; SentenceTransformer('BAAI/bge-small-en-v1.5')\"") |
| print(" 2. Set HF_HUB_OFFLINE=1 and ensure the model is cached") |
| print(" 3. Use a local model path: --embedding_model /path/to/local/model") |
| raise |
| |
        # Large chunk size with no overlap: each short case document is indexed
        # as a single node rather than being split into fragments.
        Settings.chunk_size = 8192
        Settings.chunk_overlap = 0
| |
| |
| self.index = self._load_or_create_index() |
| self.case_count = self._count_cases() |
| |
| print(f"Memory Bank initialized with {self.case_count} cases") |
| |
| def _load_or_create_index(self): |
| """Load existing index or create new one""" |
| if os.path.exists(self.index_dir): |
| try: |
| storage_context = StorageContext.from_defaults(persist_dir=self.index_dir) |
| index = load_index_from_storage(storage_context) |
| print(f"Loaded existing memory index from {self.index_dir}") |
| return index |
            except Exception as e:
                print(f"Failed to load index ({e}), creating new one")
| |
        # No persisted index found (or loading failed): create an empty index.
        index = VectorStoreIndex.from_documents([])
| os.makedirs(self.index_dir, exist_ok=True) |
| index.storage_context.persist(persist_dir=self.index_dir) |
| print(f"Created new memory index at {self.index_dir}") |
| return index |
| |
| def _count_cases(self) -> int: |
| """Count number of cases in memory""" |
| if not os.path.exists(self.cases_file): |
| return 0 |
        with open(self.cases_file, 'r', encoding='utf-8') as f:
            # Ignore blank lines so a stray trailing newline is not miscounted.
            return sum(1 for line in f if line.strip())
| |
| def add_case(self, problem_id: int, problem_desc: str, solution_code: str, |
| objective_value: float, is_correct: bool, metadata: Optional[Dict] = None): |
| """ |
| Add a successful case to memory |
| |
| Args: |
| problem_id: Problem ID |
| problem_desc: Problem description |
| solution_code: Solution code |
| objective_value: Computed objective value |
| is_correct: Whether the solution is correct |
| metadata: Additional metadata (model, debate_rounds, etc.) |
| """ |
        if not is_correct:
            # Memory keeps only successful cases; skip incorrect solutions.
            return
| |
| case = { |
| 'problem_id': problem_id, |
| 'description': problem_desc, |
| 'solution_code': solution_code, |
| 'objective_value': objective_value, |
| 'is_correct': is_correct, |
| 'metadata': metadata or {} |
| } |
| |
        # Append the full case to the JSONL file; it is the source of truth,
        # while the vector index stores only a preview for retrieval.
        with open(self.cases_file, 'a', encoding='utf-8') as f:
            f.write(json.dumps(case, ensure_ascii=False) + '\n')
| |
        # Build a compact text document for the vector index. Only a preview of
        # the code is embedded; the full solution is re-loaded from JSONL at
        # retrieval time.
        code_preview = solution_code[:500] + ('...' if len(solution_code) > 500 else '')
        doc_text = f"""Problem: {problem_desc}

Solution approach:
{code_preview}

Key features:
- Problem ID: {problem_id}
- Objective value: {objective_value}
- Status: Correct
"""
| |
| doc = Document( |
| text=doc_text, |
| metadata={ |
| 'problem_id': problem_id, |
| 'objective_value': objective_value, |
| **case['metadata'] |
| } |
| ) |
| |
        # Insert into the vector index and persist immediately so the memory
        # survives process restarts.
        self.index.insert(doc)
        self.index.storage_context.persist(persist_dir=self.index_dir)
| |
| self.case_count += 1 |
| print(f"✅ Added case {problem_id} to memory (Total: {self.case_count})") |
| |
| def retrieve_similar_cases(self, query: str, top_k: int = 3, preferred_dataset: Optional[str] = None) -> List[Dict]: |
| """ |
| Retrieve similar cases from memory using RAG based on semantic similarity |
| |
| Args: |
| query: Query text (usually the problem description) |
| top_k: Number of similar cases to retrieve (0 = no retrieval) |
| preferred_dataset: Preferred dataset name to prioritize (optional) |
| |
| Returns: |
| List of similar cases with scores, sorted by semantic similarity |
| """ |
| if self.case_count == 0 or top_k <= 0: |
| return [] |
| |
        # Over-fetch when a preferred dataset is given so enough candidates
        # remain after re-ranking.
        retriever = self.index.as_retriever(similarity_top_k=top_k * 2 if preferred_dataset else top_k)
        nodes = retriever.retrieve(query)
| |
        # Deduplicate hits by (problem_id, dataset) and bucket them so cases
        # from the preferred dataset can be ranked first.
        seen_keys = set()
        preferred_cases = []
        other_cases = []
| |
| for node in nodes: |
| problem_id = node.metadata.get('problem_id') |
| score = node.score |
| node_dataset = node.metadata.get('dataset', '') |
| |
| |
| case_key = (problem_id, node_dataset) |
| if case_key in seen_keys: |
| continue |
| |
            # Prefer an exact (problem_id, dataset) match from the JSONL file,
            # falling back to the first case with the same problem_id.
            case_data = None
            if node_dataset:
                case_data = self._load_case_by_id_and_dataset(problem_id, node_dataset)
            if not case_data:
                case_data = self._load_case_by_id(problem_id)
| |
| if case_data: |
| seen_keys.add(case_key) |
| case_item = { |
| 'case': case_data, |
| 'score': score, |
| 'text_preview': node.text[:200] |
| } |

                # Bucket by dataset preference for the final ordering.
                if preferred_dataset and node_dataset == preferred_dataset:
| preferred_cases.append(case_item) |
| else: |
| other_cases.append(case_item) |

        # Preferred-dataset cases first, then the rest; each bucket keeps its
        # similarity order. Trim to the requested top_k.
        similar_cases = preferred_cases + other_cases
        return similar_cases[:top_k]
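
    # Retrieval example (the query and dataset tag are illustrative; "nl4opt"
    # is a hypothetical dataset name):
    #
    #     hits = bank.retrieve_similar_cases(
    #         "Assign workers to tasks at minimum total cost",
    #         top_k=3,
    #         preferred_dataset="nl4opt",
    #     )
    #     for item in hits:
    #         print(item['score'], item['case']['problem_id'])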
| |
| def _load_case_by_id(self, problem_id: int) -> Optional[Dict]: |
| """Load a specific case by problem ID (returns first match)""" |
| if not os.path.exists(self.cases_file): |
| return None |
| |
| with open(self.cases_file, 'r', encoding='utf-8') as f: |
| for line in f: |
| case = json.loads(line) |
| if case['problem_id'] == problem_id: |
| return case |
| return None |
| |
| def _load_case_by_id_and_dataset(self, problem_id: int, dataset: str) -> Optional[Dict]: |
| """Load a specific case by problem ID and dataset""" |
| if not os.path.exists(self.cases_file): |
| return None |
| |
| with open(self.cases_file, 'r', encoding='utf-8') as f: |
| for line in f: |
| case = json.loads(line) |
| if case['problem_id'] == problem_id: |
| case_dataset = case.get('metadata', {}).get('dataset', '') |
| if case_dataset == dataset: |
| return case |
| return None |
| |
| def get_memory_stats(self) -> Dict: |
| """Get memory bank statistics""" |
| return { |
| 'total_cases': self.case_count, |
| 'memory_dir': self.memory_dir, |
| 'cases_file': self.cases_file, |
| 'index_dir': self.index_dir |
| } |
| |
| def format_retrieved_cases_for_prompt(self, cases: List[Dict]) -> str: |
| """ |
| Format retrieved cases for inclusion in LLM prompt |
| |
| Args: |
| cases: List of retrieved cases |
| |
| Returns: |
| Formatted string for prompt |
| """ |
| if not cases: |
| return "" |
| |
| prompt = "# Retrieved Similar Cases from Memory\n\n" |
| prompt += "The following successful cases from previous problems might be relevant:\n\n" |
| |
| for i, item in enumerate(cases, 1): |
| case = item['case'] |
| score = item['score'] |
| |
| prompt += f"## Case {i} (Similarity: {score:.3f})\n" |
| prompt += f"**Problem:** {case['description']}\n\n" |
| prompt += f"**Solution approach:**\n```python\n{case['solution_code']}\n```\n\n" |
| prompt += f"**Result:** Objective value = {case['objective_value']}, Status = Correct\n\n" |
| prompt += "---\n\n" |
| |
| return prompt |
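

if __name__ == "__main__":
    # Minimal smoke test (a sketch: requires the embedding model to be cached
    # locally or downloadable from Hugging Face; paths and data are illustrative).
    bank = MemoryBank(memory_dir="./memory_storage_demo")
    bank.add_case(
        problem_id=1,
        problem_desc="Minimize total shipping cost across three warehouses",
        solution_code="import pulp\n# ... build and solve the model ...",
        objective_value=1234.5,
        is_correct=True,
        metadata={"dataset": "demo"},
    )
    hits = bank.retrieve_similar_cases("Minimize transport cost", top_k=1)
    print(bank.format_retrieved_cases_for_prompt(hits))
    print(bank.get_memory_stats())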