"""
Phase Transition Engine — Adaptive Computation via Physics.

Instead of fixed-depth feed-forward layers, computation adapts via phase
transitions:

    EXPLORATION (high temp) → CRYSTALLIZATION (cooling) → SOLID (stable)

Simple problems crystallize quickly (few steps). Hard problems take longer to
reach stability. This is NOT a halting mechanism — it's a physical process
where the system naturally finds its ground state.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Dict, Optional

from .particle import ParticleState


class PhaseEngine(nn.Module):
    """Adaptive computation via simulated phase transitions.

    The system has three phases:

    1. EXPLORATION (temperature = high)
       - Particles move freely
       - Many temporary bonds form and break
       - Broad search over meaning space
       - High entropy, high energy

    2. CRYSTALLIZATION (temperature = cooling)
       - Strong patterns emerge
       - Weak bonds break, strong bonds solidify
       - Meaning begins to crystallize
       - Decreasing entropy, decreasing energy

    3. SOLID (temperature = low)
       - Structure is stable
       - Output is determined
       - Low entropy, low energy

    Phase labels reported in the diagnostics are derived purely from the
    per-element temperature: ``> 0.7`` is exploration, ``> 0.3`` is
    crystallization, anything lower is solid.

    If the system gets stuck (its order parameter stalls for several
    consecutive steps while still below ``stability_threshold``), it reheats
    and tries again.
    """

    # Temperature boundaries used only for the diagnostic phase labels.
    _EXPLORATION_TEMP = 0.7
    _CRYSTALLIZATION_TEMP = 0.3
    # Number of consecutive stalled steps tolerated before reheating.
    _MAX_STUCK_STEPS = 3

    def __init__(
        self,
        d_semantic: int = 256,
        min_steps: int = 2,
        max_steps: int = 32,
        initial_temp: float = 1.0,
        cooling_rate: float = 0.9,
        reheat_threshold: float = 0.1,
        stability_threshold: float = 0.95,
    ):
        """Build the phase engine.

        Args:
            d_semantic: Width of the particles' ``semantic`` vectors.
            min_steps: Stability is only checked once the step index is at
                least this value.
            max_steps: Hard upper bound on the number of steps.
            initial_temp: Starting temperature; reheats reset to half of it.
            cooling_rate: Base multiplicative cooling factor per step.
            reheat_threshold: An absolute change in the order parameter
                smaller than this counts as a stalled step.
            stability_threshold: Order parameter above which an element is
                considered crystallized (SOLID).
        """
        super().__init__()
        self.d_semantic = d_semantic
        self.min_steps = min_steps
        self.max_steps = max_steps
        self.initial_temp = initial_temp
        self.cooling_rate = cooling_rate
        self.reheat_threshold = reheat_threshold
        self.stability_threshold = stability_threshold

        # Order parameter network — measures how "crystallized" the system is
        # (sigmoid output, so values lie in (0, 1)).
        self.order_net = nn.Sequential(
            nn.Linear(d_semantic, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

        # Energy network — computes system energy (lower = more stable).
        self.energy_net = nn.Sequential(
            nn.Linear(d_semantic, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )

        # Perturbation network — the per-step, temperature-conditioned update
        # applied to every particle.
        self.perturbation_net = nn.Sequential(
            nn.Linear(d_semantic + 1, 128),  # +1 for temperature
            nn.GELU(),
            nn.Linear(128, d_semantic),
            nn.Tanh(),
        )

        # NOTE(review): refine_net is never called by forward(); the per-step
        # update uses perturbation_net instead. It is kept so existing
        # checkpoints (state_dict keys) still load.
        self.refine_net = nn.Sequential(
            nn.Linear(d_semantic * 2, 256),
            nn.GELU(),
            nn.Linear(256, d_semantic),
        )

    @staticmethod
    def _with_semantic(particles: ParticleState, semantic: torch.Tensor) -> ParticleState:
        """Return a new ParticleState equal to *particles* but with *semantic* swapped in."""
        return ParticleState(
            semantic=semantic,
            position=particles.position,
            charge=particles.charge,
            mass=particles.mass,
            spin=particles.spin,
            amplitude=particles.amplitude,
            phase=particles.phase,
            memory_trace=particles.memory_trace,
        )

    @classmethod
    def _phase_labels(cls, temperature: torch.Tensor) -> list:
        """Map each element's temperature to its phase name.

        The comparisons run on the tensor (same dtype semantics as comparing
        ``temperature[b]`` element by element) and are transferred to the host
        once, instead of forcing one device sync per batch element.
        """
        hot = (temperature > cls._EXPLORATION_TEMP).tolist()
        warm = (temperature > cls._CRYSTALLIZATION_TEMP).tolist()
        return [
            'exploration' if is_hot else 'crystallization' if is_warm else 'solid'
            for is_hot, is_warm in zip(hot, warm)
        ]

    def forward(
        self,
        particles: ParticleState,
        flow_fn=None,
        bond_fn=None,
    ) -> Tuple[ParticleState, Dict]:
        """Run phase transition computation.

        The system starts hot (exploration), cools down (crystallization),
        and stabilizes (solid). If stuck, it reheats.

        Each step:
          1. measures the order parameter and energy of the mean-pooled state;
          2. applies a temperature-scaled perturbation to every particle;
          3. optionally runs one flow step and one bond step;
          4. cools (faster when the order parameter changed little), and
             reheats elements that have stalled below the stability threshold
             for more than ``_MAX_STUCK_STEPS`` consecutive steps.
        The loop stops early once every element's order parameter exceeds
        ``stability_threshold`` (checked from step index ``min_steps`` on).

        NOTE: the order parameter used for the stability check is measured at
        the start of the step, so the returned state has had one more update
        applied than the state that was judged stable.

        Args:
            particles: Initial particle state; ``semantic`` must be a rank-3
                tensor ``[batch, seq_len, d_semantic]``.
            flow_fn: Optional callable ``ParticleState -> (ParticleState, diag)``
                run once per step; its output ``semantic`` replaces the
                current one.
            bond_fn: Optional callable ``(ParticleState, bonds) ->
                (bonds, bond_output, diag)`` run once per step; ``bonds``
                starts as ``None`` and is threaded through steps, and
                ``0.1 * bond_output`` is added to ``semantic``.

        Returns:
            Final particle state, and a diagnostics dict with:
              - ``steps``, ``temperatures``, ``order_parameters``,
                ``energies``: per-step Python values (batch means);
              - ``phases``: per-step list of per-element phase names;
              - ``reheats``: ``[batch]`` tensor counting reheats;
              - ``temperature_trace``, ``order_trace``, ``energy_trace``:
                per-step ``[batch]`` tensors that keep their autograd graph
                (used by ``PhaseTransitionLoss``);
              - ``total_steps``, ``final_temperature``, ``final_order``.
        """
        batch, seq_len, _ = particles.semantic.shape
        device = particles.semantic.device

        temperature = torch.ones(batch, device=device) * self.initial_temp

        diagnostics = {
            'steps': [],
            'temperatures': [],
            'order_parameters': [],
            'energies': [],
            'phases': [],
            'reheats': torch.zeros(batch, device=device),
            # The lists above hold Python floats (no gradient). These keep the
            # graph so the phase-transition loss can actually backpropagate.
            'temperature_trace': [],
            'order_trace': [],
            'energy_trace': [],
        }

        current = particles
        prev_order = torch.zeros(batch, device=device)
        stuck_count = torch.zeros(batch, device=device)
        bonds = None

        for step in range(self.max_steps):
            # === Measure current state ===
            pooled = current.semantic.mean(dim=1)  # [batch, d]
            order = self.order_net(pooled).squeeze(-1)  # [batch]
            energy = self.energy_net(pooled).squeeze(-1)  # [batch]

            # === Determine phase (diagnostic only) ===
            phase_name = self._phase_labels(temperature)

            # === Temperature-conditioned update ===
            # High temperature -> large perturbations (exploration);
            # low temperature -> small refinements (fine-tuning).
            temp_col = temperature[:, None, None]  # [batch, 1, 1]
            refine_input = torch.cat(
                [current.semantic, temp_col.expand(-1, seq_len, 1)], dim=-1
            )
            refinement = self.perturbation_net(refine_input) * temp_col
            new_semantic = current.semantic + refinement * 0.1

            if flow_fn is not None:
                flowed, _flow_diag = flow_fn(self._with_semantic(current, new_semantic))
                new_semantic = flowed.semantic

            if bond_fn is not None:
                bonds, bond_output, _bond_diag = bond_fn(
                    self._with_semantic(current, new_semantic), bonds
                )
                new_semantic = new_semantic + bond_output * 0.1

            # === Update temperature ===
            # Cool faster when the order parameter barely moved.
            order_delta = order - prev_order
            temperature = temperature * (self.cooling_rate * (1 - order_delta.abs()))

            # An element is stuck only if its order stalled while it is still
            # below the stability threshold — a stalled *stable* element has
            # reached SOLID and must not be reheated.
            is_stuck = (order_delta.abs() < self.reheat_threshold) & (
                order <= self.stability_threshold
            )
            # Count consecutive stalled steps; any progress resets the count.
            stuck_count = (stuck_count + 1) * is_stuck.float()

            needs_reheat = stuck_count > self._MAX_STUCK_STEPS
            temperature = temperature.masked_fill(needs_reheat, self.initial_temp * 0.5)
            stuck_count = stuck_count.masked_fill(needs_reheat, 0.0)
            diagnostics['reheats'] += needs_reheat.float()

            temperature = torch.clamp(temperature, min=0.01)

            # === Advance state ===
            current = self._with_semantic(current, new_semantic)
            prev_order = order

            diagnostics['temperatures'].append(temperature.mean().item())
            diagnostics['order_parameters'].append(order.mean().item())
            diagnostics['energies'].append(energy.mean().item())
            diagnostics['phases'].append(phase_name)
            diagnostics['steps'].append(step)
            diagnostics['temperature_trace'].append(temperature)
            diagnostics['order_trace'].append(order)
            diagnostics['energy_trace'].append(energy)

            # === Check stability: every element crystallized ===
            if step >= self.min_steps and bool((order > self.stability_threshold).all()):
                break

        diagnostics['total_steps'] = len(diagnostics['steps'])
        diagnostics['final_temperature'] = temperature.mean().item()
        # Order measured at the start of the last executed step.
        diagnostics['final_order'] = prev_order.mean().item()

        return current, diagnostics


class PhaseTransitionLoss(nn.Module):
    """Loss function that encourages proper phase transition behavior.

    The system should:
    1. Start with high entropy (exploration)
    2. Progressively reduce entropy (crystallization)
    3. Reach low energy state (solid)
    4. Not get stuck in local minima

    Each per-step series is reduced to its batch mean per step. When the
    diagnostics contain the graph-carrying ``*_trace`` tensors produced by
    ``PhaseEngine.forward`` the loss is differentiable; otherwise it falls
    back to the Python-float lists (same value, no gradient).
    """

    # (tensor trace key, float list key, sign): sign=+1 penalizes step-to-step
    # increases, sign=-1 penalizes decreases.
    _TERMS = (
        ('temperature_trace', 'temperatures', 1.0),  # cooling should be monotonic
        ('order_trace', 'order_parameters', -1.0),  # order should keep growing
        ('energy_trace', 'energies', 1.0),  # energy should keep falling
    )

    def __init__(self):
        super().__init__()

    @staticmethod
    def _per_step_means(diagnostics: Dict, trace_key: str, float_key: str, device) -> Optional[torch.Tensor]:
        """Return a 1-D tensor of per-step batch means, or None if absent."""
        trace = diagnostics.get(trace_key)
        if trace:
            return torch.stack([t.mean() for t in trace])
        values = diagnostics.get(float_key)
        if values:
            return torch.tensor(values, device=device)
        return None

    def forward(self, diagnostics: Dict) -> torch.Tensor:
        """Compute phase transition regularization loss.

        Args:
            diagnostics: The dict returned by ``PhaseEngine.forward``.

        Returns:
            Scalar tensor: the mean of the available penalty terms (reheating,
            decreasing order, increasing energy, reheat count * 0.1), or 0.0
            if none are available.
        """
        reheats = diagnostics.get('reheats')
        device = reheats.device if torch.is_tensor(reheats) else None

        losses = []
        for trace_key, float_key, sign in self._TERMS:
            series = self._per_step_means(diagnostics, trace_key, float_key, device)
            if series is not None and series.numel() > 1:
                diffs = series[1:] - series[:-1]
                losses.append(F.relu(sign * diffs).mean())

        # Penalize too many reheats.
        if torch.is_tensor(reheats):
            losses.append(reheats.mean() * 0.1)

        if losses:
            return sum(losses) / len(losses)
        return torch.tensor(0.0, device=device)