| """ |
| task_memory.py - FAISS-based task memory for Phase 5 online learning. |
| |
| Stores: |
| - Task embeddings (512-dim float32) in FAISS IndexFlatL2 |
| - LoRA adapter weights on disk (one .pt file per task) |
| - Task metadata (type, input snippet, score, timestamp) in JSON |
| |
| Retrieval: top-k similar tasks by L2 distance -> return adapters + similarity weights. |
| Weights: w_i = 1 / (dist_i + epsilon), i.e. inverse-distance weighting. |
| """ |
| import json, time, os |
| from pathlib import Path |
| from typing import List, Tuple, Optional |
|
|
| import torch |
| import faiss |
| import numpy as np |
|
|
| from lora import LoRAAdapter |
|
|
|
|
class TaskMemory:
    """Persistent store of task embeddings and their LoRA adapters.

    Embeddings are held in a FAISS ``IndexFlatL2``; each task's adapter
    weights are written to ``adapter_<task_id>.pt`` and its metadata is kept
    in ``self.metadata`` (serialised to ``metadata.json``).

    Invariant: the row of an embedding in the index, the position of its
    entry in ``self.metadata`` and its ``task_id`` are all the same integer.
    ``add`` maintains this, and ``_load_existing`` re-establishes it when the
    files on disk disagree.
    """

    # Constructor arguments for ``LoRAAdapter(N_LAYERS, DIM, RANK, ALPHA)``.
    # DIM is also the dimensionality of the FAISS index.
    DIM = 512
    RANK = 4
    ALPHA = 32.0
    N_LAYERS = 8

    # retrieve() skips neighbours whose distance, as reported by the index
    # (squared L2 for IndexFlatL2), is larger than this.
    DIST_THRESHOLD = 5.0

    # add() treats an embedding whose nearest stored neighbour is closer than
    # this (squared L2) as the same task and overwrites that entry.
    DEDUP_DIST_EPS = 1e-3

    # Number of add() calls accumulated before index + metadata are written.
    SAVE_EVERY = 10

    def __init__(self, store_dir: str, top_k: int = 3):
        """Open (or create) the store under ``store_dir``.

        Args:
            store_dir: directory holding the index, metadata and adapter
                files; created if missing.
            top_k: maximum number of neighbours ``retrieve`` considers.
        """
        self.store_dir = Path(store_dir)
        self.store_dir.mkdir(parents=True, exist_ok=True)
        self.top_k = top_k
        self.index = faiss.IndexFlatL2(self.DIM)
        self.metadata: List[dict] = []
        # Count of changes not yet persisted by _save_index().
        self._pending = 0
        self._load_existing()

    # ------------------------------------------------------------------
    # Paths
    # ------------------------------------------------------------------
    def _meta_path(self) -> Path:
        """Return the path of the JSON metadata file."""
        return self.store_dir / 'metadata.json'

    def _adapter_path(self, task_id: int) -> Path:
        """Return the path of the adapter checkpoint for ``task_id``."""
        return self.store_dir / f'adapter_{task_id:06d}.pt'

    def _index_path(self) -> Path:
        """Return the path of the serialised FAISS index."""
        return self.store_dir / 'faiss.index'

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def _load_existing(self) -> None:
        """Load metadata and index from disk, if present, and align them."""
        meta_file = self._meta_path()
        if meta_file.exists():
            self.metadata = json.loads(meta_file.read_text())
        index_file = self._index_path()
        # An index without any metadata is ignored; a fresh one is kept.
        if index_file.exists() and self.metadata:
            self.index = faiss.read_index(str(index_file))
        self._reconcile()
        if self.metadata:
            print(f'[TaskMemory] Loaded {len(self.metadata)} tasks from {self.store_dir}')

    def _reconcile(self) -> None:
        """Trim index and metadata to their common prefix if they differ.

        ``add`` derives ``task_id`` from ``len(self.metadata)`` while FAISS
        numbers vectors by insertion order, so a length mismatch (missing or
        stale file) would make retrieval load the wrong adapter.
        """
        n_index = self.index.ntotal
        n_meta = len(self.metadata)
        if n_index == n_meta:
            return
        keep = min(n_index, n_meta)
        print(f'[TaskMemory] WARNING: index holds {n_index} vectors but metadata '
              f'holds {n_meta} entries; keeping the first {keep} of each')
        del self.metadata[keep:]
        if n_index > keep:
            # Dropping only the tail leaves the ids of earlier rows unchanged.
            self.index.remove_ids(np.arange(keep, n_index, dtype=np.int64))
        # Make sure the repaired state is written by the next flush().
        self._pending += 1

    def _save_index(self) -> None:
        """Write index and metadata to disk.

        Each file is written to a sibling ``.tmp`` file first and then moved
        into place, so an interrupted write cannot leave a truncated file
        that would fail to load on the next start-up.
        """
        index_file = self._index_path()
        meta_file = self._meta_path()
        index_tmp = index_file.with_name(index_file.name + '.tmp')
        meta_tmp = meta_file.with_name(meta_file.name + '.tmp')
        faiss.write_index(self.index, str(index_tmp))
        meta_tmp.write_text(json.dumps(self.metadata, indent=2))
        os.replace(index_tmp, index_file)
        os.replace(meta_tmp, meta_file)

    def _mark_dirty(self) -> None:
        """Record one unsaved change and persist once SAVE_EVERY is reached."""
        self._pending += 1
        if self._pending >= self.SAVE_EVERY:
            self._save_index()
            self._pending = 0

    def _write_adapter(self, task_id: int, adapter: 'LoRAAdapter') -> None:
        """Save ``adapter``'s state dict to the checkpoint file of ``task_id``."""
        torch.save({'state_dict': adapter.state_dict(), 'task_id': task_id},
                   self._adapter_path(task_id))

    def _as_row(self, embedding: torch.Tensor) -> np.ndarray:
        """Convert ``embedding`` to a C-contiguous float32 array of shape (1, DIM).

        Raises:
            ValueError: if ``embedding`` does not contain exactly DIM elements.
        """
        # detach(): .numpy() refuses tensors that are part of an autograd graph.
        values = embedding.detach().float().cpu().numpy()
        if values.size != self.DIM:
            raise ValueError(
                f'expected an embedding with {self.DIM} elements, '
                f'got shape {tuple(embedding.shape)}')
        return np.ascontiguousarray(values.reshape(1, self.DIM), dtype=np.float32)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def add(self, embedding: torch.Tensor, adapter: 'LoRAAdapter',
            meta: dict) -> int:
        """
        Store a task embedding + adapter. Returns task_id.
        meta: arbitrary dict (task_type, input_snippet, score, etc.)

        Dedup: if the nearest stored embedding is within DEDUP_DIST_EPS
        (squared L2) of `embedding`, the task is treated as a recurrence of
        that entry: its adapter file and metadata are overwritten in place
        and its timestamp refreshed, while the FAISS index is left untouched.
        Otherwise a new entry is appended to the index and the metadata.

        Index + metadata are only written to disk every SAVE_EVERY calls
        (the adapter file is written immediately); call flush() to force it.

        Raises:
            ValueError: if `embedding` does not contain exactly DIM elements.
        """
        emb_np = self._as_row(embedding)

        if self.index.ntotal > 0:
            dists, ids = self.index.search(emb_np, 1)
            if dists[0][0] < self.DEDUP_DIST_EPS:
                task_id = int(ids[0][0])
                self._write_adapter(task_id, adapter)
                entry = self.metadata[task_id]
                entry.update(meta)
                entry['timestamp'] = time.time()
                self._mark_dirty()
                return task_id

        task_id = len(self.metadata)
        self._write_adapter(task_id, adapter)
        self.index.add(emb_np)
        self.metadata.append({
            'task_id': task_id,
            'timestamp': time.time(),
            **meta,
        })
        self._mark_dirty()
        return task_id

    def flush(self):
        """Force-write FAISS index + metadata if there are unsaved changes."""
        if self._pending > 0:
            self._save_index()
            self._pending = 0

    def retrieve(self, query_emb: torch.Tensor) -> Tuple[List['LoRAAdapter'], List[float]]:
        """
        Find up to top_k similar tasks and load their adapters.

        Neighbours farther than DIST_THRESHOLD, or whose adapter file is
        missing, are skipped.

        Returns (adapters, weights). Each weight is the raw inverse distance
        1 / (dist + 1e-6); the weights are NOT normalised here.
        Returns ([], []) if memory is empty, top_k is not positive, or no
        neighbour qualifies.
        """
        k = min(self.top_k, self.index.ntotal)
        if k <= 0:
            return [], []

        dists, ids = self.index.search(self._as_row(query_emb), k)

        adapters = []
        weights = []
        eps = 1e-6
        for dist, tid in zip(dists[0].tolist(), ids[0].tolist()):
            # FAISS pads missing results with id -1.
            if tid < 0 or dist > self.DIST_THRESHOLD:
                continue
            path = self._adapter_path(tid)
            if not path.exists():
                continue
            ckpt = torch.load(path, map_location='cpu', weights_only=True)
            adapter = LoRAAdapter(self.N_LAYERS, self.DIM, self.RANK, self.ALPHA)
            adapter.load_state_dict(ckpt['state_dict'])
            adapters.append(adapter)
            # Clamp at 0 so float round-off can never yield a negative or
            # infinite weight.
            weights.append(1.0 / (max(dist, 0.0) + eps))

        return adapters, weights

    def retrieve_merged(self, query_emb: torch.Tensor) -> Optional['LoRAAdapter']:
        """
        Retrieve the nearest adapters and return a single adapter.

        Returns None when retrieve() yields nothing, the sole adapter when it
        yields exactly one, and otherwise the result of
        LoRAAdapter.merged(adapters, weights, ...).
        """
        adapters, weights = self.retrieve(query_emb)
        if not adapters:
            return None
        if len(adapters) == 1:
            return adapters[0]
        return LoRAAdapter.merged(adapters, weights,
                                  self.N_LAYERS, self.DIM, self.RANK, self.ALPHA)

    def __len__(self) -> int:
        """Return the number of stored tasks (metadata entries)."""
        return len(self.metadata)

    def stats(self) -> dict:
        """Return task count plus avg/max/min of the 'score' metadata field.

        Entries without a 'score' key count as 0. With no tasks stored only
        ``{'n_tasks': 0}`` is returned.
        """
        if not self.metadata:
            return {'n_tasks': 0}
        scores = [m.get('score', 0) for m in self.metadata]
        return {
            'n_tasks': len(self.metadata),
            'avg_score': sum(scores) / len(scores),
            'max_score': max(scores),
            'min_score': min(scores),
        }
|
|