morpho-logic-engine / mle /energy.py
Harry00's picture
Upload mle/energy.py
909b197 verified
"""
Paysage d'Énergie Apprenant
La fonction d'énergie évalue la cohérence d'un état du système.
Elle doit être :
- Locale : les mises à jour ne dépendent que des voisins
- Apprenante : s'ajuste avec l'expérience
- Discriminative : basse énergie = cohérent, haute = incohérent
"""
import numpy as np
from numba import njit, prange
from typing import Dict, List, Tuple, Optional, Set
import logging
logger = logging.getLogger(__name__)
VECTOR_SIZE = 4096
def compute_local_hamming_energy_fast(
state: np.ndarray,
neighbors: np.ndarray,
weights: np.ndarray,
) -> float:
"""
Énergie locale de Hamming : somme pondérée des distances aux voisins.
Version numpy vectorisée rapide.
"""
N = neighbors.shape[0]
if N == 0:
return 0.0
# XOR vectorisé : state ^ neighbors[i]
xor_result = state.astype(np.uint8) ^ neighbors # (N, 4096)
# Somme par ligne
dists = np.sum(xor_result, axis=1)
return float(np.dot(weights, dists)) / N
def compute_bit_flip_delta_fast(
state: np.ndarray,
neighbors: np.ndarray,
weights: np.ndarray,
biases: np.ndarray,
) -> np.ndarray:
"""
Calcule le delta d'énergie pour chaque flip de bit possible.
Version numpy vectorisée.
"""
N = neighbors.shape[0]
if N == 0:
return biases.copy()
# Pour chaque bit, calculer le changement si on le flippe
# dist_actuel = sum(weights[i] * (state[j] ^ neighbors[i,j]))
# dist_nouveau = sum(weights[i] * ((1-state[j]) ^ neighbors[i,j]))
state_expanded = state[np.newaxis, :] # (1, 4096)
# XOR actuel
xor_current = state_expanded ^ neighbors # (N, 4096)
# XOR après flip
flipped = 1 - state_expanded
xor_flipped = flipped ^ neighbors # (N, 4096)
# Delta = new_dist - old_dist pour chaque bit
delta_xor = xor_flipped.astype(np.float64) - xor_current.astype(np.float64) # (N, 4096)
# Pondération
weights_col = weights[:, np.newaxis] # (N, 1)
weighted_delta = delta_xor * weights_col # (N, 4096)
deltas = np.sum(weighted_delta, axis=0) # (4096,)
deltas += biases * (2.0 * (1 - state.astype(np.float64)) - 1.0)
return deltas
class EnergyLandscape:
"""
Paysage d'énergie apprenant avec mises à jour locales.
Components:
- hamming_weight: poids de la distance de Hamming
- association_weight: poids des associations Hebbian
- structure_weight: poids de la cohérence structurelle (binding)
- biases: biais par bit (apprenant)
- local_weights: poids par voisin (apprenant)
Apprentissage:
- Renforcement des associations dans les états de basse énergie
- Affaiblissement dans les états instables
- Ajustement des biais locaux
"""
def __init__(
self,
hamming_weight: float = 1.0,
association_weight: float = 0.5,
structure_weight: float = 0.3,
bias_learning_rate: float = 0.01,
association_decay: float = 0.99,
energy_threshold_low: float = 500.0,
energy_threshold_high: float = 1500.0,
):
self.hamming_weight = hamming_weight
self.association_weight = association_weight
self.structure_weight = structure_weight
self.bias_lr = bias_learning_rate
self.assoc_decay = association_decay
self.low_threshold = energy_threshold_low
self.high_threshold = energy_threshold_high
# Biases par bit (apprenant)
self.biases = np.zeros(VECTOR_SIZE, dtype=np.float64)
# Associations apprises : (id1, id2) -> strength
self.associations: Dict[Tuple[int, int], float] = {}
# Poids structurels pour binding
self.structure_weights: Dict[Tuple[int, int], float] = {}
# Historique pour apprentissage
self.energy_history: List[float] = []
self.min_energy_history: List[float] = []
self.stable_states: List[np.ndarray] = []
# Stats
self.n_updates = 0
self.total_energy = 0.0
def compute_energy(
self,
state: np.ndarray,
neighbor_vectors: np.ndarray,
neighbor_ids: List[int],
neighbor_metadata: Optional[List] = None,
) -> float:
"""
Calcule l'énergie totale d'un état.
Args:
state: vecteur courant (4096,) uint8
neighbor_vectors: (N, 4096) voisins actifs
neighbor_ids: IDs des voisins
neighbor_metadata: métadonnées optionnelles
Returns:
énergie scalaire (plus bas = plus cohérent)
"""
N = len(neighbor_ids) if neighbor_ids else 0
if N == 0:
return float(np.sum(self.biases * (2.0 * state - 1.0)))
# Poids des voisins (apprenant)
weights = np.array([
self.associations.get(
tuple(sorted((-1, nid))), 0.5 # Default weight
) + 0.5
for nid in neighbor_ids
], dtype=np.float64)
# 1. Énergie de Hamming (vectorisée rapide)
hamming_energy = compute_local_hamming_energy_fast(
state, neighbor_vectors, weights
)
hamming_energy *= self.hamming_weight
# 2. Énergie d'association (Hebbian-like)
association_energy = 0.0
if len(neighbor_ids) > 1:
for i in range(len(neighbor_ids)):
for j in range(i+1, len(neighbor_ids)):
pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
strength = self.associations.get(pair, 0.0)
# Distance entre voisins : proches = faible énergie
dist = np.sum(neighbor_vectors[i] != neighbor_vectors[j])
association_energy += strength * dist
association_energy *= self.association_weight
# 3. Énergie de structure (binding)
structure_energy = 0.0
if len(neighbor_ids) > 1:
for i in range(len(neighbor_ids)):
for j in range(i+1, len(neighbor_ids)):
pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
struct_weight = self.structure_weights.get(pair, 0.0)
# Corrélation comme proxy de structure
corr = np.mean(neighbor_vectors[i] == neighbor_vectors[j])
structure_energy += struct_weight * (1.0 - corr)
structure_energy *= self.structure_weight
# 4. Biais local
bias_energy = float(np.sum(self.biases * (2.0 * state.astype(np.float64) - 1.0)))
total = hamming_energy + association_energy + structure_energy + bias_energy
self.energy_history.append(total)
if len(self.energy_history) > 1000:
self.energy_history = self.energy_history[-1000:]
return total
def get_bit_flip_deltas(
self,
state: np.ndarray,
neighbor_vectors: np.ndarray,
neighbor_ids: List[int],
) -> np.ndarray:
"""
Calcule le delta d'énergie pour chaque bit flip possible.
Utilisé par l'inférence pour trouver les meilleurs flips.
Returns:
deltas: (4096,) négatif = flip réduit l'énergie
"""
N = len(neighbor_ids)
if N == 0:
# Juste les biais
return self.biases.copy()
weights = np.array([
self.associations.get(tuple(sorted((-1, nid))), 0.5) + 0.5
for nid in neighbor_ids
], dtype=np.float64)
return compute_bit_flip_delta_fast(state, neighbor_vectors, weights, self.biases)
def update_from_state(
self,
state: np.ndarray,
neighbor_ids: List[int],
energy: float,
is_stable: bool = False,
):
"""
Met à jour le paysage d'énergie à partir d'un état observé.
Appelé après chaque itération d'inférence.
Règles:
- Basse énergie stable : renforce les associations
- Haute énergie instable : affaiblit ou réorganise
"""
self.n_updates += 1
self.total_energy += energy
if len(neighbor_ids) < 2:
return
# Normalise l'énergie relative
if len(self.energy_history) > 20:
recent_mean = np.mean(self.energy_history[-20:])
recent_std = np.std(self.energy_history[-20:]) + 1e-8
normalized_energy = (energy - recent_mean) / recent_std
else:
normalized_energy = 0.0
# Détermine si l'état est cohérent ou incohérent
is_coherent = normalized_energy < -0.5 or (energy < self.low_threshold and is_stable)
is_incoherent = normalized_energy > 0.5 or energy > self.high_threshold
# Met à jour les associations
for i in range(len(neighbor_ids)):
for j in range(i+1, len(neighbor_ids)):
pair = tuple(sorted((neighbor_ids[i], neighbor_ids[j])))
if is_coherent:
# Renforce l'association (Hebbian-like)
current = self.associations.get(pair, 0.0)
self.associations[pair] = min(1.0, current + self.bias_lr * (1.0 - current))
elif is_incoherent:
# Affaiblit l'association (anti-Hebbian)
current = self.associations.get(pair, 0.0)
self.associations[pair] = max(-0.5, current - self.bias_lr * 2.0)
# Met à jour les biais locaux
state_float = state.astype(np.float64)
if is_coherent:
# Renforce les bits qui sont cohérents avec le voisinage
self.biases += self.bias_lr * (2.0 * state_float - 1.0) * 0.1
elif is_incoherent:
# Affaiblit les bits qui sont incohérents
self.biases -= self.bias_lr * (2.0 * state_float - 1.0) * 0.2
# Normalise les biais pour éviter la dérive
self.biases = np.clip(self.biases, -2.0, 2.0)
# Si stable et basse énergie, mémorise l'état
if is_stable and is_coherent:
self.stable_states.append(state.copy())
if len(self.stable_states) > 100:
self.stable_states = self.stable_states[-100:]
# Déduit les associations avec le temps
if self.n_updates % 100 == 0:
self._decay_associations()
def update_structure_weights(
self,
bound_state: np.ndarray,
component_ids: List[int],
coherence: float,
):
"""
Met à jour les poids structurels pour le binding.
"""
for i in range(len(component_ids)):
for j in range(i+1, len(component_ids)):
pair = tuple(sorted((component_ids[i], component_ids[j])))
current = self.structure_weights.get(pair, 0.0)
# Mise à jour basée sur la cohérence
if coherence > 0.7:
self.structure_weights[pair] = min(1.0, current + self.bias_lr)
elif coherence < 0.3:
self.structure_weights[pair] = max(-0.5, current - self.bias_lr * 1.5)
def _decay_associations(self):
"""Décroissance périodique des associations peu utilisées."""
to_remove = []
for pair, strength in self.associations.items():
decayed = strength * self.assoc_decay
if abs(decayed) < 0.01:
to_remove.append(pair)
else:
self.associations[pair] = decayed
for pair in to_remove:
del self.associations[pair]
def get_association_strength(self, id1: int, id2: int) -> float:
"""Retourne la force d'association entre deux vecteurs."""
pair = tuple(sorted((id1, id2)))
return self.associations.get(pair, 0.0)
def get_stats(self) -> Dict:
n_assoc = len(self.associations)
n_struct = len(self.structure_weights)
recent_energy = np.mean(self.energy_history[-100:]) if self.energy_history else 0.0
return {
'n_associations': n_assoc,
'n_structure_weights': n_struct,
'mean_bias': float(np.mean(np.abs(self.biases))),
'recent_mean_energy': float(recent_energy),
'n_updates': self.n_updates,
'n_stable_states': len(self.stable_states),
}