File size: 12,658 Bytes
909b197 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 | """
Paysage d'Énergie Apprenant
La fonction d'énergie évalue la cohérence d'un état du système.
Elle doit être :
- Locale : les mises à jour ne dépendent que des voisins
- Apprenante : s'ajuste avec l'expérience
- Discriminative : basse énergie = cohérent, haute = incohérent
"""
import numpy as np
from numba import njit, prange
from typing import Dict, List, Tuple, Optional, Set
import logging
logger = logging.getLogger(__name__)
VECTOR_SIZE = 4096
def compute_local_hamming_energy_fast(
state: np.ndarray,
neighbors: np.ndarray,
weights: np.ndarray,
) -> float:
"""
Énergie locale de Hamming : somme pondérée des distances aux voisins.
Version numpy vectorisée rapide.
"""
N = neighbors.shape[0]
if N == 0:
return 0.0
# XOR vectorisé : state ^ neighbors[i]
xor_result = state.astype(np.uint8) ^ neighbors # (N, 4096)
# Somme par ligne
dists = np.sum(xor_result, axis=1)
return float(np.dot(weights, dists)) / N
def compute_bit_flip_delta_fast(
state: np.ndarray,
neighbors: np.ndarray,
weights: np.ndarray,
biases: np.ndarray,
) -> np.ndarray:
"""
Calcule le delta d'énergie pour chaque flip de bit possible.
Version numpy vectorisée.
"""
N = neighbors.shape[0]
if N == 0:
return biases.copy()
# Pour chaque bit, calculer le changement si on le flippe
# dist_actuel = sum(weights[i] * (state[j] ^ neighbors[i,j]))
# dist_nouveau = sum(weights[i] * ((1-state[j]) ^ neighbors[i,j]))
state_expanded = state[np.newaxis, :] # (1, 4096)
# XOR actuel
xor_current = state_expanded ^ neighbors # (N, 4096)
# XOR après flip
flipped = 1 - state_expanded
xor_flipped = flipped ^ neighbors # (N, 4096)
# Delta = new_dist - old_dist pour chaque bit
delta_xor = xor_flipped.astype(np.float64) - xor_current.astype(np.float64) # (N, 4096)
# Pondération
weights_col = weights[:, np.newaxis] # (N, 1)
weighted_delta = delta_xor * weights_col # (N, 4096)
deltas = np.sum(weighted_delta, axis=0) # (4096,)
deltas += biases * (2.0 * (1 - state.astype(np.float64)) - 1.0)
return deltas
class EnergyLandscape:
"""
Paysage d'énergie apprenant avec mises à jour locales.
Components:
- hamming_weight: poids de la distance de Hamming
- association_weight: poids des associations Hebbian
- structure_weight: poids de la cohérence structurelle (binding)
- biases: biais par bit (apprenant)
- local_weights: poids par voisin (apprenant)
Apprentissage:
- Renforcement des associations dans les états de basse énergie
- Affaiblissement dans les états instables
- Ajustement des biais locaux
"""
def __init__(
self,
hamming_weight: float = 1.0,
association_weight: float = 0.5,
structure_weight: float = 0.3,
bias_learning_rate: float = 0.01,
association_decay: float = 0.99,
energy_threshold_low: float = 500.0,
energy_threshold_high: float = 1500.0,
):
self.hamming_weight = hamming_weight
self.association_weight = association_weight
self.structure_weight = structure_weight
self.bias_lr = bias_learning_rate
self.assoc_decay = association_decay
self.low_threshold = energy_threshold_low
self.high_threshold = energy_threshold_high
# Biases par bit (apprenant)
self.biases = np.zeros(VECTOR_SIZE, dtype=np.float64)
# Associations apprises : (id1, id2) -> strength
self.associations: Dict[Tuple[int, int], float] = {}
# Poids structurels pour binding
self.structure_weights: Dict[Tuple[int, int], float] = {}
# Historique pour apprentissage
self.energy_history: List[float] = []
self.min_energy_history: List[float] = []
self.stable_states: List[np.ndarray] = []
# Stats
self.n_updates = 0
self.total_energy = 0.0
def compute_energy(
self,
state: np.ndarray,
neighbor_vectors: np.ndarray,
neighbor_ids: List[int],
neighbor_metadata: Optional[List] = None,
) -> float:
"""
Calcule l'énergie totale d'un état.
Args:
state: vecteur courant (4096,) uint8
neighbor_vectors: (N, 4096) voisins actifs
neighbor_ids: IDs des voisins
neighbor_metadata: métadonnées optionnelles
Returns:
énergie scalaire (plus bas = plus cohérent)
"""
N = len(neighbor_ids) if neighbor_ids else 0
if N == 0:
return float(np.sum(self.biases * (2.0 * state - 1.0)))
# Poids des voisins (apprenant)
weights = np.array([
self.associations.get(
tuple(sorted((-1, nid))), 0.5 # Default weight
) + 0.5
for nid in neighbor_ids
], dtype=np.float64)
# 1. Énergie de Hamming (vectorisée rapide)
hamming_energy = compute_local_hamming_energy_fast(
state, neighbor_vectors, weights
)
hamming_energy *= self.hamming_weight
# 2. Énergie d'association (Hebbian-like)
association_energy = 0.0
if len(neighbor_ids) > 1:
for i in range(len(neighbor_ids)):
for j in range(i+1, len(neighbor_ids)):
pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
strength = self.associations.get(pair, 0.0)
# Distance entre voisins : proches = faible énergie
dist = np.sum(neighbor_vectors[i] != neighbor_vectors[j])
association_energy += strength * dist
association_energy *= self.association_weight
# 3. Énergie de structure (binding)
structure_energy = 0.0
if len(neighbor_ids) > 1:
for i in range(len(neighbor_ids)):
for j in range(i+1, len(neighbor_ids)):
pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
struct_weight = self.structure_weights.get(pair, 0.0)
# Corrélation comme proxy de structure
corr = np.mean(neighbor_vectors[i] == neighbor_vectors[j])
structure_energy += struct_weight * (1.0 - corr)
structure_energy *= self.structure_weight
# 4. Biais local
bias_energy = float(np.sum(self.biases * (2.0 * state.astype(np.float64) - 1.0)))
total = hamming_energy + association_energy + structure_energy + bias_energy
self.energy_history.append(total)
if len(self.energy_history) > 1000:
self.energy_history = self.energy_history[-1000:]
return total
def get_bit_flip_deltas(
self,
state: np.ndarray,
neighbor_vectors: np.ndarray,
neighbor_ids: List[int],
) -> np.ndarray:
"""
Calcule le delta d'énergie pour chaque bit flip possible.
Utilisé par l'inférence pour trouver les meilleurs flips.
Returns:
deltas: (4096,) négatif = flip réduit l'énergie
"""
N = len(neighbor_ids)
if N == 0:
# Juste les biais
return self.biases.copy()
weights = np.array([
self.associations.get(tuple(sorted((-1, nid))), 0.5) + 0.5
for nid in neighbor_ids
], dtype=np.float64)
return compute_bit_flip_delta_fast(state, neighbor_vectors, weights, self.biases)
def update_from_state(
self,
state: np.ndarray,
neighbor_ids: List[int],
energy: float,
is_stable: bool = False,
):
"""
Met à jour le paysage d'énergie à partir d'un état observé.
Appelé après chaque itération d'inférence.
Règles:
- Basse énergie stable : renforce les associations
- Haute énergie instable : affaiblit ou réorganise
"""
self.n_updates += 1
self.total_energy += energy
if len(neighbor_ids) < 2:
return
# Normalise l'énergie relative
if len(self.energy_history) > 20:
recent_mean = np.mean(self.energy_history[-20:])
recent_std = np.std(self.energy_history[-20:]) + 1e-8
normalized_energy = (energy - recent_mean) / recent_std
else:
normalized_energy = 0.0
# Détermine si l'état est cohérent ou incohérent
is_coherent = normalized_energy < -0.5 or (energy < self.low_threshold and is_stable)
is_incoherent = normalized_energy > 0.5 or energy > self.high_threshold
# Met à jour les associations
for i in range(len(neighbor_ids)):
for j in range(i+1, len(neighbor_ids)):
pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
if is_coherent:
# Renforce l'association (Hebbian-like)
current = self.associations.get(pair, 0.0)
self.associations[pair] = min(1.0, current + self.bias_lr * (1.0 - current))
elif is_incoherent:
# Affaiblit l'association (anti-Hebbian)
current = self.associations.get(pair, 0.0)
self.associations[pair] = max(-0.5, current - self.bias_lr * 2.0)
# Met à jour les biais locaux
state_float = state.astype(np.float64)
if is_coherent:
# Renforce les bits qui sont cohérents avec le voisinage
self.biases += self.bias_lr * (2.0 * state_float - 1.0) * 0.1
elif is_incoherent:
# Affaiblit les bits qui sont incohérents
self.biases -= self.bias_lr * (2.0 * state_float - 1.0) * 0.2
# Normalise les biais pour éviter la dérive
self.biases = np.clip(self.biases, -2.0, 2.0)
# Si stable et basse énergie, mémorise l'état
if is_stable and is_coherent:
self.stable_states.append(state.copy())
if len(self.stable_states) > 100:
self.stable_states = self.stable_states[-100:]
# Déduit les associations avec le temps
if self.n_updates % 100 == 0:
self._decay_associations()
def update_structure_weights(
self,
bound_state: np.ndarray,
component_ids: List[int],
coherence: float,
):
"""
Met à jour les poids structurels pour le binding.
"""
for i in range(len(component_ids)):
for j in range(i+1, len(component_ids)):
pair = tuple(sorted((component_ids[i], component_ids[j])))
current = self.structure_weights.get(pair, 0.0)
# Mise à jour basée sur la cohérence
if coherence > 0.7:
self.structure_weights[pair] = min(1.0, current + self.bias_lr)
elif coherence < 0.3:
self.structure_weights[pair] = max(-0.5, current - self.bias_lr * 1.5)
def _decay_associations(self):
"""Décroissance périodique des associations peu utilisées."""
to_remove = []
for pair, strength in self.associations.items():
decayed = strength * self.assoc_decay
if abs(decayed) < 0.01:
to_remove.append(pair)
else:
self.associations[pair] = decayed
for pair in to_remove:
del self.associations[pair]
def get_association_strength(self, id1: int, id2: int) -> float:
"""Retourne la force d'association entre deux vecteurs."""
pair = tuple(sorted((id1, id2)))
return self.associations.get(pair, 0.0)
def get_stats(self) -> Dict:
n_assoc = len(self.associations)
n_struct = len(self.structure_weights)
recent_energy = np.mean(self.energy_history[-100:]) if self.energy_history else 0.0
return {
'n_associations': n_assoc,
'n_structure_weights': n_struct,
'mean_bias': float(np.mean(np.abs(self.biases))),
'recent_mean_energy': float(recent_energy),
'n_updates': self.n_updates,
'n_stable_states': len(self.stable_states),
} |