File size: 12,658 Bytes
909b197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
"""
Paysage d'Énergie Apprenant

La fonction d'énergie évalue la cohérence d'un état du système.
Elle doit être :
- Locale : les mises à jour ne dépendent que des voisins
- Apprenante : s'ajuste avec l'expérience
- Discriminative : basse énergie = cohérent, haute = incohérent
"""

import numpy as np
from numba import njit, prange
from typing import Dict, List, Tuple, Optional, Set
import logging

logger = logging.getLogger(__name__)

VECTOR_SIZE = 4096


def compute_local_hamming_energy_fast(
    state: np.ndarray,
    neighbors: np.ndarray,
    weights: np.ndarray,
) -> float:
    """
    Énergie locale de Hamming : somme pondérée des distances aux voisins.
    Version numpy vectorisée rapide.
    """
    N = neighbors.shape[0]
    if N == 0:
        return 0.0
    # XOR vectorisé : state ^ neighbors[i]
    xor_result = state.astype(np.uint8) ^ neighbors  # (N, 4096)
    # Somme par ligne
    dists = np.sum(xor_result, axis=1)
    return float(np.dot(weights, dists)) / N


def compute_bit_flip_delta_fast(
    state: np.ndarray,
    neighbors: np.ndarray,
    weights: np.ndarray,
    biases: np.ndarray,
) -> np.ndarray:
    """
    Calcule le delta d'énergie pour chaque flip de bit possible.
    Version numpy vectorisée.
    """
    N = neighbors.shape[0]
    if N == 0:
        return biases.copy()
    
    # Pour chaque bit, calculer le changement si on le flippe
    # dist_actuel = sum(weights[i] * (state[j] ^ neighbors[i,j]))
    # dist_nouveau = sum(weights[i] * ((1-state[j]) ^ neighbors[i,j]))
    
    state_expanded = state[np.newaxis, :]  # (1, 4096)
    # XOR actuel
    xor_current = state_expanded ^ neighbors  # (N, 4096)
    # XOR après flip
    flipped = 1 - state_expanded
    xor_flipped = flipped ^ neighbors  # (N, 4096)
    
    # Delta = new_dist - old_dist pour chaque bit
    delta_xor = xor_flipped.astype(np.float64) - xor_current.astype(np.float64)  # (N, 4096)
    # Pondération
    weights_col = weights[:, np.newaxis]  # (N, 1)
    weighted_delta = delta_xor * weights_col  # (N, 4096)
    
    deltas = np.sum(weighted_delta, axis=0)  # (4096,)
    deltas += biases * (2.0 * (1 - state.astype(np.float64)) - 1.0)
    
    return deltas


class EnergyLandscape:
    """
    Paysage d'énergie apprenant avec mises à jour locales.
    
    Components:
    - hamming_weight: poids de la distance de Hamming
    - association_weight: poids des associations Hebbian
    - structure_weight: poids de la cohérence structurelle (binding)
    - biases: biais par bit (apprenant)
    - local_weights: poids par voisin (apprenant)
    
    Apprentissage:
    - Renforcement des associations dans les états de basse énergie
    - Affaiblissement dans les états instables
    - Ajustement des biais locaux
    """
    
    def __init__(
        self,
        hamming_weight: float = 1.0,
        association_weight: float = 0.5,
        structure_weight: float = 0.3,
        bias_learning_rate: float = 0.01,
        association_decay: float = 0.99,
        energy_threshold_low: float = 500.0,
        energy_threshold_high: float = 1500.0,
    ):
        self.hamming_weight = hamming_weight
        self.association_weight = association_weight
        self.structure_weight = structure_weight
        self.bias_lr = bias_learning_rate
        self.assoc_decay = association_decay
        self.low_threshold = energy_threshold_low
        self.high_threshold = energy_threshold_high
        
        # Biases par bit (apprenant)
        self.biases = np.zeros(VECTOR_SIZE, dtype=np.float64)
        
        # Associations apprises : (id1, id2) -> strength
        self.associations: Dict[Tuple[int, int], float] = {}
        
        # Poids structurels pour binding
        self.structure_weights: Dict[Tuple[int, int], float] = {}
        
        # Historique pour apprentissage
        self.energy_history: List[float] = []
        self.min_energy_history: List[float] = []
        self.stable_states: List[np.ndarray] = []
        
        # Stats
        self.n_updates = 0
        self.total_energy = 0.0
    
    def compute_energy(
        self,
        state: np.ndarray,
        neighbor_vectors: np.ndarray,
        neighbor_ids: List[int],
        neighbor_metadata: Optional[List] = None,
    ) -> float:
        """
        Calcule l'énergie totale d'un état.
        
        Args:
            state: vecteur courant (4096,) uint8
            neighbor_vectors: (N, 4096) voisins actifs
            neighbor_ids: IDs des voisins
            neighbor_metadata: métadonnées optionnelles
        
        Returns:
            énergie scalaire (plus bas = plus cohérent)
        """
        N = len(neighbor_ids) if neighbor_ids else 0
        if N == 0:
            return float(np.sum(self.biases * (2.0 * state - 1.0)))
        
        # Poids des voisins (apprenant)
        weights = np.array([
            self.associations.get(
                tuple(sorted((-1, nid))), 0.5  # Default weight
            ) + 0.5
            for nid in neighbor_ids
        ], dtype=np.float64)
        
        # 1. Énergie de Hamming (vectorisée rapide)
        hamming_energy = compute_local_hamming_energy_fast(
            state, neighbor_vectors, weights
        )
        hamming_energy *= self.hamming_weight
        
        # 2. Énergie d'association (Hebbian-like)
        association_energy = 0.0
        if len(neighbor_ids) > 1:
            for i in range(len(neighbor_ids)):
                for j in range(i+1, len(neighbor_ids)):
                    pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
                    strength = self.associations.get(pair, 0.0)
                    # Distance entre voisins : proches = faible énergie
                    dist = np.sum(neighbor_vectors[i] != neighbor_vectors[j])
                    association_energy += strength * dist
        
        association_energy *= self.association_weight
        
        # 3. Énergie de structure (binding)
        structure_energy = 0.0
        if len(neighbor_ids) > 1:
            for i in range(len(neighbor_ids)):
                for j in range(i+1, len(neighbor_ids)):
                    pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
                    struct_weight = self.structure_weights.get(pair, 0.0)
                    # Corrélation comme proxy de structure
                    corr = np.mean(neighbor_vectors[i] == neighbor_vectors[j])
                    structure_energy += struct_weight * (1.0 - corr)
        
        structure_energy *= self.structure_weight
        
        # 4. Biais local
        bias_energy = float(np.sum(self.biases * (2.0 * state.astype(np.float64) - 1.0)))
        
        total = hamming_energy + association_energy + structure_energy + bias_energy
        
        self.energy_history.append(total)
        if len(self.energy_history) > 1000:
            self.energy_history = self.energy_history[-1000:]
        
        return total
    
    def get_bit_flip_deltas(
        self,
        state: np.ndarray,
        neighbor_vectors: np.ndarray,
        neighbor_ids: List[int],
    ) -> np.ndarray:
        """
        Calcule le delta d'énergie pour chaque bit flip possible.
        Utilisé par l'inférence pour trouver les meilleurs flips.
        
        Returns:
            deltas: (4096,) négatif = flip réduit l'énergie
        """
        N = len(neighbor_ids)
        if N == 0:
            # Juste les biais
            return self.biases.copy()
        
        weights = np.array([
            self.associations.get(tuple(sorted((-1, nid))), 0.5) + 0.5
            for nid in neighbor_ids
        ], dtype=np.float64)
        
        return compute_bit_flip_delta_fast(state, neighbor_vectors, weights, self.biases)
    
    def update_from_state(
        self,
        state: np.ndarray,
        neighbor_ids: List[int],
        energy: float,
        is_stable: bool = False,
    ):
        """
        Met à jour le paysage d'énergie à partir d'un état observé.
        Appelé après chaque itération d'inférence.
        
        Règles:
        - Basse énergie stable : renforce les associations
        - Haute énergie instable : affaiblit ou réorganise
        """
        self.n_updates += 1
        self.total_energy += energy
        
        if len(neighbor_ids) < 2:
            return
        
        # Normalise l'énergie relative
        if len(self.energy_history) > 20:
            recent_mean = np.mean(self.energy_history[-20:])
            recent_std = np.std(self.energy_history[-20:]) + 1e-8
            normalized_energy = (energy - recent_mean) / recent_std
        else:
            normalized_energy = 0.0
        
        # Détermine si l'état est cohérent ou incohérent
        is_coherent = normalized_energy < -0.5 or (energy < self.low_threshold and is_stable)
        is_incoherent = normalized_energy > 0.5 or energy > self.high_threshold
        
        # Met à jour les associations
        for i in range(len(neighbor_ids)):
            for j in range(i+1, len(neighbor_ids)):
                pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
                
                if is_coherent:
                    # Renforce l'association (Hebbian-like)
                    current = self.associations.get(pair, 0.0)
                    self.associations[pair] = min(1.0, current + self.bias_lr * (1.0 - current))
                elif is_incoherent:
                    # Affaiblit l'association (anti-Hebbian)
                    current = self.associations.get(pair, 0.0)
                    self.associations[pair] = max(-0.5, current - self.bias_lr * 2.0)
        
        # Met à jour les biais locaux
        state_float = state.astype(np.float64)
        if is_coherent:
            # Renforce les bits qui sont cohérents avec le voisinage
            self.biases += self.bias_lr * (2.0 * state_float - 1.0) * 0.1
        elif is_incoherent:
            # Affaiblit les bits qui sont incohérents
            self.biases -= self.bias_lr * (2.0 * state_float - 1.0) * 0.2
        
        # Normalise les biais pour éviter la dérive
        self.biases = np.clip(self.biases, -2.0, 2.0)
        
        # Si stable et basse énergie, mémorise l'état
        if is_stable and is_coherent:
            self.stable_states.append(state.copy())
            if len(self.stable_states) > 100:
                self.stable_states = self.stable_states[-100:]
        
        # Déduit les associations avec le temps
        if self.n_updates % 100 == 0:
            self._decay_associations()
    
    def update_structure_weights(
        self,
        bound_state: np.ndarray,
        component_ids: List[int],
        coherence: float,
    ):
        """
        Met à jour les poids structurels pour le binding.
        """
        for i in range(len(component_ids)):
            for j in range(i+1, len(component_ids)):
                pair = tuple(sorted((component_ids[i], component_ids[j])))
                current = self.structure_weights.get(pair, 0.0)
                # Mise à jour basée sur la cohérence
                if coherence > 0.7:
                    self.structure_weights[pair] = min(1.0, current + self.bias_lr)
                elif coherence < 0.3:
                    self.structure_weights[pair] = max(-0.5, current - self.bias_lr * 1.5)
    
    def _decay_associations(self):
        """Décroissance périodique des associations peu utilisées."""
        to_remove = []
        for pair, strength in self.associations.items():
            decayed = strength * self.assoc_decay
            if abs(decayed) < 0.01:
                to_remove.append(pair)
            else:
                self.associations[pair] = decayed
        
        for pair in to_remove:
            del self.associations[pair]
    
    def get_association_strength(self, id1: int, id2: int) -> float:
        """Retourne la force d'association entre deux vecteurs."""
        pair = tuple(sorted((id1, id2)))
        return self.associations.get(pair, 0.0)
    
    def get_stats(self) -> Dict:
        n_assoc = len(self.associations)
        n_struct = len(self.structure_weights)
        recent_energy = np.mean(self.energy_history[-100:]) if self.energy_history else 0.0
        
        return {
            'n_associations': n_assoc,
            'n_structure_weights': n_struct,
            'mean_bias': float(np.mean(np.abs(self.biases))),
            'recent_mean_energy': float(recent_energy),
            'n_updates': self.n_updates,
            'n_stable_states': len(self.stable_states),
        }