| """ |
| Phase Transition Engine — Adaptive Computation via Physics. |
| |
| Instead of fixed-depth feed-forward layers, computation adapts via |
| phase transitions: |
| |
| EXPLORATION (high temp) → CRYSTALLIZATION (cooling) → SOLID (stable) |
| |
| Simple problems crystallize quickly (few steps). |
| Hard problems take longer to reach stability. |
| |
| This is NOT a halting mechanism — it's a physical process where |
| the system naturally finds its ground state. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Tuple, Dict, Optional |
|
|
| from .particle import ParticleState |
|
|
|
|
| class PhaseEngine(nn.Module): |
| """Adaptive computation via simulated phase transitions. |
| |
| The system has three phases: |
| |
| 1. EXPLORATION (temperature = high) |
| - Particles move freely |
| - Many temporary bonds form and break |
| - Broad search over meaning space |
| - High entropy, high energy |
| |
| 2. CRYSTALLIZATION (temperature = cooling) |
| - Strong patterns emerge |
| - Weak bonds break, strong bonds solidify |
| - Meaning begins to crystallize |
| - Decreasing entropy, decreasing energy |
| |
| 3. SOLID (temperature = low) |
| - Structure is stable |
| - Output is determined |
| - Low entropy, low energy |
| |
| If the system gets stuck (local minimum), it reheats and tries again. |
| """ |
| |
| def __init__( |
| self, |
| d_semantic: int = 256, |
| min_steps: int = 2, |
| max_steps: int = 32, |
| initial_temp: float = 1.0, |
| cooling_rate: float = 0.9, |
| reheat_threshold: float = 0.1, |
| stability_threshold: float = 0.95, |
| ): |
| super().__init__() |
| self.d_semantic = d_semantic |
| self.min_steps = min_steps |
| self.max_steps = max_steps |
| self.initial_temp = initial_temp |
| self.cooling_rate = cooling_rate |
| self.reheat_threshold = reheat_threshold |
| self.stability_threshold = stability_threshold |
| |
| |
| self.order_net = nn.Sequential( |
| nn.Linear(d_semantic, 128), |
| nn.GELU(), |
| nn.Linear(128, 1), |
| nn.Sigmoid(), |
| ) |
| |
| |
| self.energy_net = nn.Sequential( |
| nn.Linear(d_semantic, 128), |
| nn.GELU(), |
| nn.Linear(128, 1), |
| ) |
| |
| |
| self.perturbation_net = nn.Sequential( |
| nn.Linear(d_semantic + 1, 128), |
| nn.GELU(), |
| nn.Linear(128, d_semantic), |
| nn.Tanh(), |
| ) |
| |
| |
| self.refine_net = nn.Sequential( |
| nn.Linear(d_semantic * 2, 256), |
| nn.GELU(), |
| nn.Linear(256, d_semantic), |
| ) |
| |
| def forward( |
| self, |
| particles: ParticleState, |
| flow_fn=None, |
| bond_fn=None, |
| ) -> Tuple[ParticleState, Dict]: |
| """Run phase transition computation. |
| |
| The system starts hot (exploration), cools down (crystallization), |
| and stabilizes (solid). If stuck, it reheats. |
| |
| Args: |
| particles: Initial particle state |
| flow_fn: Optional flow field function to call each step |
| bond_fn: Optional bond system function to call each step |
| |
| Returns: |
| Final particle state after reaching solid phase |
| Diagnostics about the phase transition process |
| """ |
| batch, seq_len, d = particles.semantic.shape |
| device = particles.semantic.device |
| |
| temperature = torch.ones(batch, device=device) * self.initial_temp |
| |
| diagnostics = { |
| 'steps': [], |
| 'temperatures': [], |
| 'order_parameters': [], |
| 'energies': [], |
| 'phases': [], |
| 'reheats': torch.zeros(batch, device=device), |
| } |
| |
| current = particles |
| prev_order = torch.zeros(batch, device=device) |
| stuck_count = torch.zeros(batch, device=device) |
| |
| bonds = None |
| |
| for step in range(self.max_steps): |
| |
| |
| pooled = current.semantic.mean(dim=1) |
| order = self.order_net(pooled).squeeze(-1) |
| |
| |
| energy = self.energy_net(pooled).squeeze(-1) |
| |
| |
| phase_name = [] |
| for b in range(batch): |
| if temperature[b] > 0.7: |
| phase_name.append('exploration') |
| elif temperature[b] > 0.3: |
| phase_name.append('crystallization') |
| else: |
| phase_name.append('solid') |
| |
| |
| |
| |
| |
| |
| temp_expanded = temperature.unsqueeze(-1).unsqueeze(-1).expand(-1, seq_len, 1) |
| refine_input = torch.cat([current.semantic, temp_expanded], dim=-1) |
| refinement = self.perturbation_net(refine_input) |
| |
| |
| refinement = refinement * temperature.unsqueeze(-1).unsqueeze(-1) |
| |
| |
| new_semantic = current.semantic + refinement * 0.1 |
| |
| |
| if flow_fn is not None: |
| current_for_flow = ParticleState( |
| semantic=new_semantic, |
| position=current.position, |
| charge=current.charge, |
| mass=current.mass, |
| spin=current.spin, |
| amplitude=current.amplitude, |
| phase=current.phase, |
| memory_trace=current.memory_trace, |
| ) |
| current_for_flow, flow_diag = flow_fn(current_for_flow) |
| new_semantic = current_for_flow.semantic |
| |
| |
| if bond_fn is not None: |
| bonds, bond_output, bond_diag = bond_fn( |
| ParticleState( |
| semantic=new_semantic, |
| position=current.position, |
| charge=current.charge, |
| mass=current.mass, |
| spin=current.spin, |
| amplitude=current.amplitude, |
| phase=current.phase, |
| memory_trace=current.memory_trace, |
| ), |
| bonds |
| ) |
| new_semantic = new_semantic + bond_output * 0.1 |
| |
| |
| |
| order_increase = order - prev_order |
| cooling = self.cooling_rate * (1 - order_increase.abs()) |
| temperature = temperature * cooling |
| |
| |
| is_stuck = (order - prev_order).abs() < self.reheat_threshold |
| stuck_count = stuck_count + is_stuck.float() |
| |
| |
| needs_reheat = stuck_count > 3 |
| temperature[needs_reheat] = self.initial_temp * 0.5 |
| stuck_count[needs_reheat] = 0 |
| diagnostics['reheats'] += needs_reheat.float() |
| |
| |
| temperature = torch.clamp(temperature, min=0.01) |
| |
| |
| current = ParticleState( |
| semantic=new_semantic, |
| position=current.position, |
| charge=current.charge, |
| mass=current.mass, |
| spin=current.spin, |
| amplitude=current.amplitude, |
| phase=current.phase, |
| memory_trace=current.memory_trace, |
| ) |
| |
| prev_order = order |
| |
| |
| diagnostics['temperatures'].append(temperature.mean().item()) |
| diagnostics['order_parameters'].append(order.mean().item()) |
| diagnostics['energies'].append(energy.mean().item()) |
| diagnostics['phases'].append(phase_name) |
| diagnostics['steps'].append(step) |
| |
| |
| |
| if step >= self.min_steps: |
| is_stable = order > self.stability_threshold |
| if is_stable.all(): |
| break |
| |
| diagnostics['total_steps'] = len(diagnostics['steps']) |
| diagnostics['final_temperature'] = temperature.mean().item() |
| diagnostics['final_order'] = prev_order.mean().item() |
| |
| return current, diagnostics |
|
|
|
|
class PhaseTransitionLoss(nn.Module):
    """Loss function that encourages proper phase transition behavior.

    Penalises, averaged over whichever terms are available:
    1. Temperature increases between consecutive steps
    2. Order-parameter decreases between consecutive steps
    3. Energy increases between consecutive steps
    4. Reheats (mean reheat count, weighted by 0.1)

    History entries may be Python floats or scalar tensors. Float histories
    yield a value with no autograd graph; tensor entries keep theirs, so
    gradients can flow back through them.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _step_diffs(values, device) -> Optional[torch.Tensor]:
        """Return consecutive differences of a history, or None if too short."""
        if values is None or len(values) < 2:
            return None
        if any(torch.is_tensor(v) for v in values):
            # Stack rather than rebuild with torch.tensor so that entries
            # carrying an autograd graph stay differentiable.
            dtype = torch.get_default_dtype()
            series = torch.stack([
                torch.as_tensor(v, dtype=dtype, device=device).reshape(())
                for v in values
            ])
        else:
            series = torch.tensor(
                values, dtype=torch.get_default_dtype(), device=device
            )
        return series[1:] - series[:-1]

    def forward(self, diagnostics: Dict) -> torch.Tensor:
        """Compute phase transition regularization loss.

        Args:
            diagnostics: Mapping that may contain the per-step histories
                ``'temperatures'``, ``'order_parameters'`` and
                ``'energies'`` (sequences of floats or scalar tensors) and a
                ``'reheats'`` tensor. Missing keys and histories shorter
                than two steps contribute no term.

        Returns:
            Scalar tensor: the mean of the available terms, or zero when no
            term is available.
        """
        reheats = diagnostics.get('reheats')
        # Build the history tensors on the reheats device so every term lives
        # on one device; None lets torch choose its default placement.
        device = reheats.device if torch.is_tensor(reheats) else None

        losses = []

        # (history key, sign): sign +1 penalises increases, -1 decreases.
        for key, sign in (
            ('temperatures', 1.0),
            ('order_parameters', -1.0),
            ('energies', 1.0),
        ):
            diffs = self._step_diffs(diagnostics.get(key), device)
            if diffs is not None:
                losses.append(F.relu(sign * diffs).mean())

        if reheats is not None:
            reheats = torch.as_tensor(reheats)
            if not reheats.is_floating_point():
                reheats = reheats.float()
            losses.append(reheats.mean() * 0.1)

        if losses:
            return sum(losses) / len(losses)
        return torch.tensor(0.0)
|
|