# File size: 1,941 Bytes
# 63dd1f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import numpy as np
import pickle
from pathlib import Path
from typing import List, Tuple, Dict, Optional

class HelixMemory:
    """Persistent memory of prototype vectors keyed by integer "helix codes".

    Each entry is a tuple ``(code, prototype, count, meta)``:

    * ``code`` -- integer id, assigned sequentially in insertion order;
    * ``prototype`` -- a private copy of the vector that created the entry
      (merges never modify it);
    * ``count`` -- 1 on creation, incremented each time :meth:`add` merges
      a vector into the entry;
    * ``meta`` -- metadata dict; on merge, incoming keys override old ones.

    Similarity between two vectors is the fraction of positions at which they
    are equal (``np.mean(a == b)``), i.e. a value in ``[0, 1]``. The whole
    entry list is pickled to ``storage_path`` after every :meth:`add`.

    NOTE(review): the store is read with ``pickle.load``; only point
    ``storage_path`` at files you trust -- unpickling can run arbitrary code.
    """

    # Similarity strictly above which add() merges into an existing entry.
    merge_threshold: float = 0.85

    entries: List[Tuple[int, np.ndarray, int, dict]]

    def __init__(self, storage_path: Path, merge_threshold: Optional[float] = None):
        """Open (or lazily create) the memory stored at *storage_path*.

        Args:
            storage_path: Pickle file holding the entries; ``str`` is accepted.
            merge_threshold: Optional override of the class-level default (0.85).
        """
        # Coerce so callers may pass plain strings; other methods need Path API.
        self.storage_path = Path(storage_path)
        if merge_threshold is not None:
            self.merge_threshold = merge_threshold
        self._load()

    def _load(self) -> None:
        """Load entries from ``storage_path``, or start empty if it does not exist."""
        if self.storage_path.exists():
            # SECURITY: pickle.load can execute arbitrary code -- trusted files only.
            with self.storage_path.open("rb") as f:
                self.entries = pickle.load(f)
        else:
            self.entries = []

    def _save(self) -> None:
        """Persist all entries to ``storage_path``.

        The pickle is written to a sibling ``<name>.tmp`` file which then
        replaces the target (``Path.replace``; atomic on POSIX), so a crash
        mid-write does not leave a truncated store for the next ``_load``.
        """
        self.storage_path.parent.mkdir(parents=True, exist_ok=True)
        tmp_path = self.storage_path.with_name(self.storage_path.name + ".tmp")
        try:
            with tmp_path.open("wb") as f:
                pickle.dump(self.entries, f)
            tmp_path.replace(self.storage_path)
        except BaseException:
            # Clean up the partial temp file, then propagate the original error.
            try:
                tmp_path.unlink()
            except FileNotFoundError:
                pass
            raise

    @staticmethod
    def _similarity(a: np.ndarray, b: np.ndarray) -> float:
        """Return the fraction of positions at which *a* and *b* are equal."""
        return float(np.mean(a == b))

    def add(self, hv: np.ndarray, meta: Optional[dict] = None) -> None:
        """Store *hv*, merging it into the most similar entry if close enough.

        If the best-matching prototype's similarity is strictly greater than
        ``merge_threshold``, that entry's count is incremented and *meta* is
        merged into its metadata (keys in *meta* win); its prototype is left
        unchanged. Otherwise a new entry is appended holding a copy of *hv*
        and a shallow copy of *meta*. The memory is saved afterwards.
        """
        # Copy so later mutation of the caller's dict cannot alter the memory.
        meta = dict(meta) if meta else {}

        best_idx, best_sim = -1, -1.0
        for i, (_, proto, _, _) in enumerate(self.entries):
            sim = self._similarity(hv, proto)
            # Strict '>' keeps the earliest entry when similarities tie.
            if sim > best_sim:
                best_idx, best_sim = i, sim

        if best_idx >= 0 and best_sim > self.merge_threshold:
            code, proto, cnt, old_meta = self.entries[best_idx]
            self.entries[best_idx] = (code, proto, cnt + 1, {**old_meta, **meta})
        else:
            new_code = len(self.entries)
            self.entries.append((new_code, hv.copy(), 1, meta))
        self._save()

    def retrieve(self, hv: np.ndarray, top_k: int = 3) -> List[Tuple[np.ndarray, dict]]:
        """Return up to *top_k* ``(prototype, meta)`` pairs most similar to *hv*.

        Results are ordered by decreasing similarity (stable for ties). Copies
        are returned so callers cannot mutate the stored memory. A
        non-positive *top_k* yields an empty list.
        """
        if top_k <= 0:
            return []
        scored = [(self._similarity(hv, proto), proto, meta) for _, proto, _, meta in self.entries]
        scored.sort(key=lambda x: x[0], reverse=True)
        return [(proto.copy(), dict(meta)) for _, proto, meta in scored[:top_k]]

    def reconstruct(self, code_id: int) -> np.ndarray:
        """Return a copy of the prototype stored under *code_id*.

        Raises:
            KeyError: if no entry has that code.
        """
        for cid, proto, _, _ in self.entries:
            if cid == code_id:
                return proto.copy()
        raise KeyError(f"Helix code {code_id} not found")

    def most_uncertain(self) -> Tuple[int, np.ndarray]:
        """Return ``(code, prototype copy)`` of the entry with the lowest count.

        Ties resolve to the earliest entry (``min`` returns the first minimum).

        Raises:
            RuntimeError: if the memory holds no entries.
        """
        if not self.entries:
            raise RuntimeError("HelixMemory empty")
        code, proto, _, _ = min(self.entries, key=lambda e: e[2])
        return code, proto.copy()