File size: 16,555 Bytes
22007d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
"""
Moteur d'Inférence avec Apprentissage en Ligne

L'inférence minimise l'énergie par descente stochastique locale.
À chaque itération :
1. Calcule les voisins via le routeur
2. Évalue l'énergie du paysage
3. Sélectionne les flips de bits qui réduisent l'énergie
4. Met à jour les associations (apprentissage en ligne)
5. Détecte motifs pour abstraction

La minimisation est un processus de Monte Carlo / Hopfield-like
mais avec mémoire adaptative et apprentissage continu.
"""

import numpy as np
from numba import njit, prange
from typing import List, Tuple, Dict, Optional, Callable
import logging
import time

logger = logging.getLogger(__name__)

VECTOR_SIZE = 4096


@njit(cache=True)
def random_flip_batch(state: np.ndarray, n_flips: int, rng_seed: int) -> np.ndarray:
    """Flip aléatoire de n_flips bits."""
    np.random.seed(rng_seed)
    new_state = state.copy()
    flip_indices = np.random.choice(VECTOR_SIZE, size=n_flips, replace=False)
    for idx in flip_indices:
        new_state[idx] = 1 - new_state[idx]
    return new_state


@njit(cache=True)
def hamming_distance(a: np.ndarray, b: np.ndarray) -> int:
    """Distance de Hamming entre deux vecteurs binaires."""
    dist = 0
    for i in range(len(a)):
        dist += a[i] ^ b[i]
    return dist


class InferenceResult:
    """Résultat d'une inférence complète."""
    
    def __init__(self):
        self.initial_state: Optional[np.ndarray] = None
        self.final_state: Optional[np.ndarray] = None
        self.energy_trajectory: List[float] = []
        self.neighbor_trajectory: List[List[Tuple[int, float]]] = []
        self.n_iterations = 0
        self.converged = False
        self.creation_events: List[Dict] = []
        self.learning_events: List[Dict] = []
        self.execution_time_ms = 0.0


class InferenceEngine:
    """
    Moteur d'inférence par minimisation d'énergie avec apprentissage en ligne.
    
    Paramètres clés:
    - temperature: contrôle le bruit dans la descente (plus haut = plus exploratoire)
    - max_iterations: nombre max d'itérations de minimisation
    - energy_tolerance: seuil de convergence
    - learning_rate: vitesse d'apprentissage pendant l'inférence
    """
    
    def __init__(
        self,
        temperature: float = 0.5,
        cooling_rate: float = 0.995,
        max_iterations: int = 100,
        energy_tolerance: float = 1.0,
        learning_rate: float = 0.01,
        online_learning: bool = True,
        pattern_detection_interval: int = 10,
        convergence_window: int = 5,
        early_stop_threshold: float = 0.001,
    ):
        self.temperature = temperature
        self.cooling_rate = cooling_rate
        self.max_iterations = max_iterations
        self.energy_tolerance = energy_tolerance
        self.learning_rate = learning_rate
        self.online_learning = online_learning
        self.pattern_detection_interval = pattern_detection_interval
        self.convergence_window = convergence_window
        self.early_stop_threshold = early_stop_threshold
        
        # Stats
        self.total_inferences = 0
        self.total_iterations = 0
        self.total_converged = 0
        self.avg_inference_time_ms = 0.0
    
    def infer(
        self,
        initial_state: np.ndarray,
        memory_table,
        router,
        energy_landscape,
        binder,
        k_neighbors: int = 10,
        external_callback: Optional[Callable] = None,
    ) -> InferenceResult:
        """
        Inférence complète avec minimisation d'énergie et apprentissage en ligne.
        
        Args:
            initial_state: état initial du système (4096 bits)
            memory_table: SparseAddressTable
            router: HammingRouter
            energy_landscape: EnergyLandscape
            binder: CircularBinder
            k_neighbors: nombre de voisins à considérer
            external_callback: fonction optionnelle appelée à chaque itération
        
        Returns:
            InferenceResult avec trajectoire et événements d'apprentissage
        """
        import time
        t0 = time.time()
        
        result = InferenceResult()
        result.initial_state = initial_state.copy()
        
        current_state = initial_state.copy()
        temperature = self.temperature
        
        prev_energy = float('inf')
        energy_window = []
        
        # Trajectoire des états pour détection de motifs
        state_trajectory = [current_state.copy()]
        
        for iteration in range(self.max_iterations):
            # 1. Route vers les voisins les plus proches
            neighbors_info = router.route(
                current_state,
                k=k_neighbors,
                use_cache=True
            )
            
            if len(neighbors_info) == 0:
                # Pas de voisins : état nouveau, potentiellement créer
                break
            
            neighbor_indices = [idx for idx, _ in neighbors_info]
            neighbor_distances = [dist for _, dist in neighbors_info]
            
            # Récupère les vecteurs voisins depuis la mémoire
            neighbor_vectors = np.array([
                memory_table.vectors[idx]
                for idx in neighbor_indices
                if memory_table.active_mask[idx]
            ], dtype=np.uint8)
            
            if len(neighbor_vectors) == 0:
                break
            
            neighbor_ids = [
                memory_table.metadata[idx].id
                for idx in neighbor_indices
                if memory_table.active_mask[idx]
            ]
            
            # 2. Calcule l'énergie actuelle
            energy = energy_landscape.compute_energy(
                current_state,
                neighbor_vectors,
                neighbor_ids,
            )
            
            result.energy_trajectory.append(energy)
            result.neighbor_trajectory.append(neighbors_info)
            
            # 3. Détermine les flips optimaux
            deltas = energy_landscape.get_bit_flip_deltas(
                current_state,
                neighbor_vectors,
                neighbor_ids,
            )
            
            # Sélectionne les flips qui réduisent l'énergie
            # avec bruit thermique pour exploration
            flip_probs = np.exp(-deltas / max(temperature, 0.01))
            flip_probs = flip_probs / np.sum(flip_probs)
            
            # Choix déterministe + stochastique
            n_candidates = max(1, int(VECTOR_SIZE * 0.005))  # ~20 bits
            top_candidates = np.argsort(-flip_probs)[:n_candidates * 2]
            
            # Favorise les flips qui réduisent l'énergie
            beneficial = deltas[top_candidates] < 0
            if np.any(beneficial):
                # Fait tous les flips bénéfiques avec probabilité selon température
                selected = top_candidates[
                    np.random.random(len(top_candidates)) < flip_probs[top_candidates]
                ]
            else:
                # Échappement local : flip aléatoire contrôlé
                selected = np.random.choice(
                    VECTOR_SIZE,
                    size=max(1, int(n_candidates * temperature)),
                    replace=False,
                    p=flip_probs
                )
            
            if len(selected) > 0:
                new_state = current_state.copy()
                new_state[selected] = 1 - new_state[selected]
                
                # Calcule la nouvelle énergie
                new_energy = energy_landscape.compute_energy(
                    new_state,
                    neighbor_vectors,
                    neighbor_ids,
                )
                
                # Acceptation Metropolis-Hastings
                delta_e = new_energy - energy
                if delta_e < 0 or np.random.random() < np.exp(-delta_e / max(temperature, 0.01)):
                    current_state = new_state
                    energy = new_energy
            
            state_trajectory.append(current_state.copy())
            
            # 4. Apprentissage en ligne
            if self.online_learning:
                learning_events = self._online_learning_step(
                    current_state,
                    neighbor_vectors,
                    neighbor_ids,
                    neighbor_indices,
                    energy,
                    iteration,
                    memory_table,
                    energy_landscape,
                )
                result.learning_events.extend(learning_events)
            
            # 5. Détection périodique de motifs pour abstraction
            if iteration > 0 and iteration % self.pattern_detection_interval == 0:
                patterns = memory_table.detect_frequent_patterns(
                    [st for st in state_trajectory[-self.pattern_detection_interval:]],
                    min_frequency=3
                )
                
                for pattern in patterns:
                    # Crée une abstraction si le pattern est fréquent
                    new_id = memory_table.create_vector(
                        context=pattern,
                        abstraction_level=1,
                    )
                    result.creation_events.append({
                        'type': 'abstraction',
                        'id': new_id,
                        'iteration': iteration,
                    })
            
            # 6. Callback externe
            if external_callback:
                external_callback({
                    'iteration': iteration,
                    'energy': energy,
                    'state': current_state,
                    'neighbors': neighbors_info,
                })
            
            # 7. Vérification convergence
            energy_window.append(energy)
            if len(energy_window) > self.convergence_window:
                energy_window.pop(0)
            
            if len(energy_window) >= self.convergence_window:
                energy_std = np.std(energy_window)
                energy_mean = np.mean(energy_window)
                if energy_std / max(abs(energy_mean), 1.0) < self.early_stop_threshold:
                    result.converged = True
                    break
            
            # Refroidissement
            temperature *= self.cooling_rate
            prev_energy = energy
        
        # Inférence terminée
        result.final_state = current_state.copy()
        result.n_iterations = iteration + 1
        
        # Apprentissage post-inférence : renforce les associations
        # si l'inférence a convergé vers un état stable
        if self.online_learning and result.converged:
            self._post_inference_learning(
                result,
                memory_table,
                energy_landscape,
                router,
            )
        
        t1 = time.time()
        result.execution_time_ms = (t1 - t0) * 1000
        
        # Stats
        self.total_inferences += 1
        self.total_iterations += result.n_iterations
        if result.converged:
            self.total_converged += 1
        self.avg_inference_time_ms = (
            self.avg_inference_time_ms * (self.total_inferences - 1) + result.execution_time_ms
        ) / self.total_inferences
        
        return result
    
    def _online_learning_step(
        self,
        state: np.ndarray,
        neighbor_vectors: np.ndarray,
        neighbor_ids: List[int],
        neighbor_indices: List[int],
        energy: float,
        iteration: int,
        memory_table,
        energy_landscape,
    ) -> List[Dict]:
        """
        Effectue un pas d'apprentissage pendant l'inférence.
        Mises à jour locales uniquement.
        
        Returns:
            Liste d'événements d'apprentissage
        """
        events = []
        
        # Met à jour les métadonnées des voisins
        for idx in neighbor_indices:
            if memory_table.active_mask[idx]:
                meta = memory_table.metadata[idx]
                meta.record_access(memory_table.time_step, energy)
        
        # Met à jour le paysage d'énergie
        is_stable = iteration > 5 and len(energy_landscape.energy_history) > 10
        energy_landscape.update_from_state(
            state,
            neighbor_ids,
            energy,
            is_stable=is_stable,
        )
        
        # Met à jour les coactivations entre voisins
        for i, idx1 in enumerate(neighbor_indices):
            for idx2 in neighbor_indices[i+1:]:
                if memory_table.active_mask[idx1] and memory_table.active_mask[idx2]:
                    id1 = memory_table.metadata[idx1].id
                    id2 = memory_table.metadata[idx2].id
                    # Renforce l'association si coactivation fréquente
                    strength = 1.0 / (1.0 + energy / 1000.0)
                    memory_table.metadata[idx1].update_coactivation(id2, strength)
                    memory_table.metadata[idx2].update_coactivation(id1, strength)
        
        # Crée un nouveau vecteur si l'état est suffisamment différent
        # de tous les voisins (configuration récurrente ou nouvelle)
        if iteration > 3:
            min_neighbor_dist = min([
                float(np.sum(state != memory_table.vectors[idx]))
                for idx in neighbor_indices
                if memory_table.active_mask[idx]
            ]) if neighbor_indices else float('inf')
            
            if min_neighbor_dist > memory_table.creation_threshold:
                # Nouvelle configuration intéressante
                new_id = memory_table.create_vector(context=state)
                events.append({
                    'type': 'creation',
                    'id': new_id,
                    'reason': 'novel_pattern',
                    'distance': float(min_neighbor_dist),
                    'iteration': iteration,
                })
        
        return events
    
    def _post_inference_learning(
        self,
        result: InferenceResult,
        memory_table,
        energy_landscape,
        router,
    ):
        """
        Apprentissage après convergence.
        Renforce les associations dans la trajectoire de basse énergie.
        """
        if len(result.neighbor_trajectory) < 3:
            return
        
        # Identifie la phase de basse énergie
        energies = np.array(result.energy_trajectory)
        min_energy_idx = int(np.argmin(energies))
        
        # Les voisins à ce point sont "la réponse"
        if min_energy_idx < len(result.neighbor_trajectory):
            stable_neighbors = result.neighbor_trajectory[min_energy_idx]
            stable_ids = [nid for nid, _ in stable_neighbors]
            
            # Renforce les associations entre voisins stables
            for i, id1 in enumerate(stable_ids):
                for id2 in stable_ids[i+1:]:
                    pair = tuple(sorted((id1, id2)))
                    current = energy_landscape.associations.get(pair, 0.0)
                    energy_landscape.associations[pair] = min(
                        1.0,
                        current + self.learning_rate * 2.0
                    )
            
            # Met à jour le routeur avec les nouvelles routes apprises
            final_state = result.final_state
            final_packed = router.pack_bits_to_uint64(final_state) if hasattr(router, 'pack_bits_to_uint64') else None
            if final_packed is not None:
                ph = router._pattern_hash(final_packed) if hasattr(router, '_pattern_hash') else None
                if ph is not None:
                    router.route_cache[ph] = [
                        (nid, 1.0 / (1.0 + dist))
                        for nid, dist in stable_neighbors
                    ]
    
    def get_stats(self) -> Dict:
        convergence_rate = (
            self.total_converged / self.total_inferences
            if self.total_inferences > 0 else 0.0
        )
        avg_iterations = (
            self.total_iterations / self.total_inferences
            if self.total_inferences > 0 else 0.0
        )
        
        return {
            'total_inferences': self.total_inferences,
            'convergence_rate': convergence_rate,
            'avg_iterations': avg_iterations,
            'avg_inference_time_ms': self.avg_inference_time_ms,
        }