File size: 7,770 Bytes
4700286
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""
task_memory.py β€” FAISS-based task memory for Phase 5 online learning.

Stores:
  - Task embeddings (512-dim float32) in FAISS IndexFlatL2
  - LoRA adapter weights on disk (one .pt file per task)
  - Task metadata (type, input snippet, score, timestamp) in JSON

Retrieval: top-k similar tasks by L2 distance β†’ return adapters + similarity weights.
Weights: w_i = 1 / (dist_i + epsilon) β€” inverse distance weighting.
"""
import json, time, os
from pathlib import Path
from typing import List, Tuple, Optional

import torch
import faiss
import numpy as np

from lora import LoRAAdapter


class TaskMemory:
    """Disk-backed memory of task embeddings and their LoRA adapters.

    Files kept inside ``store_dir``:
      - ``faiss.index``        FAISS IndexFlatL2 over the task embeddings
      - ``metadata.json``      list of per-task dicts
      - ``adapter_NNNNNN.pt``  adapter ``state_dict`` for task id NNNNNN

    Invariant relied on by every method: a task's id is simultaneously its
    FAISS row id and its position in ``self.metadata``.
    """
    DIM       = 512
    RANK      = 4
    ALPHA     = 32.0  # MUST match online_learner.LORA_ALPHA — train/inference scale must agree
    N_LAYERS  = 8

    # IndexFlatL2 returns SQUARED L2 distance. Empirically measured on the 50
    # benchmark task embeddings (phase4_latest.pt, no adapter):
    #   same-task repeat (unperturbed embed, adapter=None both times): 0.0 (squared)
    #   nearest *different* benchmark task: >=70 (squared, L2 norm >=8.4)
    # Threshold gates out irrelevant adapters (huge margin on both sides).
    # Keys are stored UNPERTURBED (embed_task(adapter=None)), matching the
    # query-time embedding exactly — see online_learner.py.
    DIST_THRESHOLD = 5.0

    # A recurring task (same canonical prompt) re-embeds to a BIT-IDENTICAL
    # vector (deterministic forward pass, eval mode, no dropout) -> dist ==
    # 0.0 exactly. Different tasks measure >=70. 1e-3 sits in the enormous
    # gap between the two, so this can never misfire on a genuinely
    # different (but nearby) task.
    DEDUP_DIST_EPS = 1e-3

    SAVE_EVERY = 10   # flush FAISS index + metadata every N adds

    def __init__(self, store_dir: str, top_k: int = 3):
        """
        Open (or create) a task memory rooted at ``store_dir``.

        store_dir: directory holding the index, metadata and adapter files;
                   created if missing.
        top_k:     maximum number of neighbours considered by ``retrieve``.
        """
        self.store_dir    = Path(store_dir)
        self.store_dir.mkdir(parents=True, exist_ok=True)
        self.top_k        = top_k
        self.index        = faiss.IndexFlatL2(self.DIM)
        self.metadata: List[dict] = []
        self._pending     = 0   # unsaved changes since last _save_index
        self._load_existing()

    # ── Persistence ──────────────────────────────────────────────────────────
    def _meta_path(self) -> Path:
        """Path of the JSON metadata file."""
        return self.store_dir / 'metadata.json'

    def _adapter_path(self, task_id: int) -> Path:
        """Path of the adapter checkpoint for ``task_id`` (zero-padded to 6 digits)."""
        return self.store_dir / f'adapter_{task_id:06d}.pt'

    def _index_path(self) -> Path:
        """Path of the serialized FAISS index."""
        return self.store_dir / 'faiss.index'

    @staticmethod
    def _tmp_path(path: Path) -> Path:
        """Sibling temp file used for write-then-rename saves."""
        return path.with_name(path.name + '.tmp')

    def _load_existing(self):
        """Load metadata and index from disk (if present) and make them agree."""
        meta_path  = self._meta_path()
        index_path = self._index_path()
        if meta_path.exists():
            self.metadata = json.loads(meta_path.read_text())
        # An index without metadata is ignored: its rows cannot be described.
        if index_path.exists() and len(self.metadata) > 0:
            self.index = faiss.read_index(str(index_path))
        self._reconcile()
        if len(self.metadata) > 0:
            print(f'[TaskMemory] Loaded {len(self.metadata)} tasks from {self.store_dir}')

    def _reconcile(self):
        """
        Restore the task_id == FAISS row == metadata position invariant.

        The two files can disagree after a missing/deleted file or a crash
        between the two writes of a save. Entries are only ever appended, so
        the common prefix is still aligned; anything beyond it is dropped from
        the in-memory state (adapter files on disk are left alone).
        """
        n_index = self.index.ntotal
        n_meta  = len(self.metadata)
        if n_index == n_meta:
            return
        keep = min(n_index, n_meta)
        print(f'[TaskMemory] WARNING: index holds {n_index} vectors but metadata '
              f'holds {n_meta} entries; keeping the first {keep} of each')
        del self.metadata[keep:]
        if n_index > keep:
            # Removing a tail range leaves the ids of the kept rows unchanged.
            self.index.remove_ids(np.arange(keep, n_index, dtype=np.int64))
        # Ensure the repaired state is written out by the next flush().
        self._pending = max(self._pending, 1)

    def _save_index(self):
        """Write index + metadata to temp files, then rename over the old ones."""
        index_path = self._index_path()
        meta_path  = self._meta_path()
        index_tmp  = self._tmp_path(index_path)
        meta_tmp   = self._tmp_path(meta_path)
        faiss.write_index(self.index, str(index_tmp))
        meta_tmp.write_text(json.dumps(self.metadata, indent=2))
        # Rename only after both payloads are fully written, so a crash while
        # serializing never leaves a truncated index or metadata file.
        index_tmp.replace(index_path)
        meta_tmp.replace(meta_path)

    def _write_adapter(self, task_id: int, adapter: 'LoRAAdapter'):
        """Save ``adapter``'s state_dict for ``task_id`` via write-then-rename."""
        path = self._adapter_path(task_id)
        tmp  = self._tmp_path(path)
        torch.save({'state_dict': adapter.state_dict(), 'task_id': task_id}, tmp)
        tmp.replace(path)

    def _mark_dirty(self):
        """Count one unsaved change; persist once SAVE_EVERY have accumulated."""
        self._pending += 1
        if self._pending >= self.SAVE_EVERY:
            self._save_index()
            self._pending = 0

    def _as_row(self, emb: torch.Tensor) -> np.ndarray:
        """
        Convert an embedding tensor to the (1, DIM) float32 array FAISS expects.

        Raises ValueError if ``emb`` does not contain exactly DIM elements.
        """
        if emb.numel() != self.DIM:
            raise ValueError(
                f'expected an embedding with {self.DIM} elements, '
                f'got shape {tuple(emb.shape)}')
        # detach(): tensors that still require grad cannot be turned into numpy.
        row = emb.detach().float().cpu().numpy().reshape(1, self.DIM)
        return np.ascontiguousarray(row, dtype=np.float32)

    # ── Core operations ──────────────────────────────────────────────────────
    def add(self, embedding: torch.Tensor, adapter: 'LoRAAdapter',
            meta: dict) -> int:
        """
        Store a task embedding + adapter. Returns task_id.
        meta: arbitrary dict (task_type, input_snippet, score, etc.). A
        'task_id' key in ``meta`` is ignored: the id is assigned here.

        Dedup: if an existing entry's embedding is within DEDUP_DIST_EPS
        (squared L2) of `embedding` — i.e. the SAME recurring task — overwrite
        that entry's adapter file + metadata in place instead of appending a
        new one. The FAISS index is left untouched (the embedding is
        unchanged for a recurring task). This bounds memory size at the
        number of UNIQUE tasks ever seen, regardless of how many times each
        one recurs.

        Raises ValueError if ``embedding`` does not have DIM elements.
        """
        emb_np = self._as_row(embedding)

        if self.index.ntotal > 0:
            dists, ids = self.index.search(emb_np, 1)
            hit = int(ids[0][0])
            if hit >= 0 and dists[0][0] < self.DEDUP_DIST_EPS:
                self._write_adapter(hit, adapter)
                entry = self.metadata[hit]
                entry.update(meta)
                entry['task_id']   = hit          # meta must not break the id invariant
                entry['timestamp'] = time.time()
                self._mark_dirty()
                return hit

        task_id = len(self.metadata)

        # Adapter first, so an index row never refers to a missing file.
        self._write_adapter(task_id, adapter)
        self.index.add(emb_np)

        entry = {
            'task_id':    task_id,
            'timestamp':  time.time(),
            **meta,
        }
        entry['task_id'] = task_id                # meta must not break the id invariant
        self.metadata.append(entry)
        self._mark_dirty()
        return task_id

    def flush(self):
        """Force-write FAISS index + metadata if there are unsaved changes."""
        if self._pending > 0:
            self._save_index()
            self._pending = 0

    def retrieve(self, query_emb: torch.Tensor) -> Tuple[List['LoRAAdapter'], List[float]]:
        """
        Find top-k similar tasks, load their adapters.

        Returns (adapters, weights). Each weight is the raw inverse squared
        distance 1 / (dist + 1e-6); weights are NOT normalised here.
        Neighbours farther than DIST_THRESHOLD, or whose adapter file is
        missing, are skipped. Returns ([], []) if memory is empty, top_k is
        not positive, or nothing passes the filters.

        Raises ValueError if ``query_emb`` does not have DIM elements.
        """
        k = min(self.top_k, self.index.ntotal)
        if k <= 0:
            return [], []

        q = self._as_row(query_emb)
        dists, ids = self.index.search(q, k)   # each is [1, k]

        adapters = []
        weights  = []
        eps = 1e-6
        for dist, tid in zip(dists[0].tolist(), ids[0].tolist()):
            if tid < 0:
                continue   # FAISS uses -1 for "no result"
            if not dist <= self.DIST_THRESHOLD:
                continue   # too dissimilar (or NaN) — irrelevant adapter would inject noise
            path = self._adapter_path(tid)
            if not path.exists():
                continue
            ckpt    = torch.load(path, map_location='cpu', weights_only=True)
            adapter = LoRAAdapter(self.N_LAYERS, self.DIM, self.RANK, self.ALPHA)
            adapter.load_state_dict(ckpt['state_dict'])
            adapters.append(adapter)
            # Clamp so rounding noise below zero can never cancel eps.
            weights.append(1.0 / (max(dist, 0.0) + eps))

        return adapters, weights

    def retrieve_merged(self, query_emb: torch.Tensor) -> Optional['LoRAAdapter']:
        """
        Retrieve top-k adapters and return a single weighted-merged adapter.

        Returns None if ``retrieve`` yields no adapter (empty memory or no
        neighbour within DIST_THRESHOLD). A single match is returned as-is;
        several are combined with LoRAAdapter.merged using the raw
        inverse-distance weights.
        """
        adapters, weights = self.retrieve(query_emb)
        if not adapters:
            return None
        if len(adapters) == 1:
            return adapters[0]
        return LoRAAdapter.merged(adapters, weights,
                                  self.N_LAYERS, self.DIM, self.RANK, self.ALPHA)

    def __len__(self) -> int:
        """Number of stored tasks."""
        return len(self.metadata)

    def stats(self) -> dict:
        """
        Summarise stored scores: n_tasks plus avg/max/min of the 'score' field.

        Entries with a missing or None score count as 0. Returns only
        {'n_tasks': 0} when the memory is empty.
        """
        if not self.metadata:
            return {'n_tasks': 0}
        scores = [m.get('score') or 0 for m in self.metadata]
        return {
            'n_tasks':   len(self.metadata),
            'avg_score': sum(scores) / len(scores),
            'max_score': max(scores),
            'min_score': min(scores),
        }