# flownet/core/phase_engine.py
# Source commit d4fff7c: "Add FlowNet: Post-Transformer Architecture" (Ashu9675), 11.3 kB
"""
Phase Transition Engine — Adaptive Computation via Physics.
Instead of fixed-depth feed-forward layers, computation adapts via
phase transitions:
EXPLORATION (high temp) → CRYSTALLIZATION (cooling) → SOLID (stable)
Simple problems crystallize quickly (few steps).
Hard problems take longer to reach stability.
This is NOT a halting mechanism — it's a physical process where
the system naturally finds its ground state.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Dict, Optional
from .particle import ParticleState
class PhaseEngine(nn.Module):
    """Adaptive computation via simulated phase transitions.

    The system has three phases:

    1. EXPLORATION (temperature = high)
       - Particles move freely
       - Many temporary bonds form and break
       - Broad search over meaning space
       - High entropy, high energy
    2. CRYSTALLIZATION (temperature = cooling)
       - Strong patterns emerge
       - Weak bonds break, strong bonds solidify
       - Meaning begins to crystallize
       - Decreasing entropy, decreasing energy
    3. SOLID (temperature = low)
       - Structure is stable
       - Output is determined
       - Low entropy, low energy

    If the system gets stuck (local minimum), it reheats and tries again.

    Args:
        d_semantic: Size of the last dimension of ``ParticleState.semantic``.
        min_steps: Step index from which the early-exit stability check is
            evaluated.
        max_steps: Upper bound on the number of refinement steps.
        initial_temp: Starting temperature for every batch element.
        cooling_rate: Base multiplicative cooling factor applied each step.
        reheat_threshold: An absolute change in the order parameter below this
            value counts the step as "stuck".
        stability_threshold: The loop exits early once every batch element's
            order parameter exceeds this value.
    """

    # Temperature boundaries; they only decide the phase *label* stored in the
    # diagnostics and do not alter the computation.
    EXPLORATION_TEMP = 0.7
    CRYSTALLIZATION_TEMP = 0.3
    # Scale applied to the perturbation and to the bond output before they are
    # added to the semantic state.
    UPDATE_SCALE = 0.1
    # A batch element is reheated once its stuck counter exceeds this value.
    STUCK_LIMIT = 3
    # Reheated elements restart at this fraction of ``initial_temp``.
    REHEAT_FRACTION = 0.5
    # Floor applied to the temperature at the end of every step.
    MIN_TEMP = 0.01

    def __init__(
        self,
        d_semantic: int = 256,
        min_steps: int = 2,
        max_steps: int = 32,
        initial_temp: float = 1.0,
        cooling_rate: float = 0.9,
        reheat_threshold: float = 0.1,
        stability_threshold: float = 0.95,
    ):
        super().__init__()
        self.d_semantic = d_semantic
        self.min_steps = min_steps
        self.max_steps = max_steps
        self.initial_temp = initial_temp
        self.cooling_rate = cooling_rate
        self.reheat_threshold = reheat_threshold
        self.stability_threshold = stability_threshold

        # Order parameter network — measures how "crystallized" the system is
        # (sigmoid output, so values lie in (0, 1)).
        self.order_net = nn.Sequential(
            nn.Linear(d_semantic, 128),
            nn.GELU(),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
        # Energy network — computes system energy (lower = more stable).
        # In ``forward`` its output is only recorded in the diagnostics.
        self.energy_net = nn.Sequential(
            nn.Linear(d_semantic, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        # Perturbation network — proposes the per-particle update.
        self.perturbation_net = nn.Sequential(
            nn.Linear(d_semantic + 1, 128),  # +1 for temperature
            nn.GELU(),
            nn.Linear(128, d_semantic),
            nn.Tanh(),
        )
        # Refinement network. NOTE: ``forward`` in this class never calls it;
        # it is kept so parameter names and the state_dict layout stay the same.
        self.refine_net = nn.Sequential(
            nn.Linear(d_semantic * 2, 256),
            nn.GELU(),
            nn.Linear(256, d_semantic),
        )

    @staticmethod
    def _with_semantic(state: ParticleState, semantic: torch.Tensor) -> ParticleState:
        """Return a new ``ParticleState`` copying *state* but with *semantic* replaced."""
        return ParticleState(
            semantic=semantic,
            position=state.position,
            charge=state.charge,
            mass=state.mass,
            spin=state.spin,
            amplitude=state.amplitude,
            phase=state.phase,
            memory_trace=state.memory_trace,
        )

    def _phase_labels(self, temperature: torch.Tensor) -> list:
        """Map each batch element's temperature to its phase name."""
        # Compare on the tensor (in its own dtype) and transfer the two boolean
        # masks once, instead of one host/device round trip per batch element.
        is_hot = (temperature > self.EXPLORATION_TEMP).tolist()
        is_warm = (temperature > self.CRYSTALLIZATION_TEMP).tolist()
        return [
            'exploration' if hot else ('crystallization' if warm else 'solid')
            for hot, warm in zip(is_hot, is_warm)
        ]

    def forward(
        self,
        particles: ParticleState,
        flow_fn=None,
        bond_fn=None,
    ) -> Tuple[ParticleState, Dict]:
        """Run phase transition computation.

        The system starts hot (exploration), cools down (crystallization),
        and stabilizes (solid). If stuck, it reheats.

        Args:
            particles: Initial particle state. ``particles.semantic`` must be
                3-D; it is unpacked as ``(batch, seq_len, d)``.
            flow_fn: Optional callable invoked each step as ``flow_fn(state)``;
                must return ``(state, diagnostics)``. Only the returned
                state's ``semantic`` is used.
            bond_fn: Optional callable invoked each step as
                ``bond_fn(state, bonds)``; must return
                ``(bonds, bond_output, diagnostics)``. ``bonds`` is ``None`` on
                the first call and is fed back on later calls; ``bond_output``
                is scaled and added to the semantic state.

        Returns:
            A tuple ``(state, diagnostics)``: the particle state after the
            last executed step, and a dict with per-step lists ``steps``,
            ``temperatures``, ``order_parameters``, ``energies`` (batch means
            as Python floats), ``phases`` (per-element names), the per-element
            ``reheats`` counter tensor, and the scalars ``total_steps``,
            ``final_temperature`` and ``final_order``.
        """
        batch, seq_len, _ = particles.semantic.shape
        device = particles.semantic.device

        temperature = torch.ones(batch, device=device) * self.initial_temp
        diagnostics = {
            'steps': [],
            'temperatures': [],
            'order_parameters': [],
            'energies': [],
            'phases': [],
            'reheats': torch.zeros(batch, device=device),
        }

        current = particles
        # NOTE: starts at zero, so the first step's order change is measured
        # against 0 rather than against a previous measurement.
        prev_order = torch.zeros(batch, device=device)
        stuck_count = torch.zeros(batch, device=device)
        bonds = None

        for step in range(self.max_steps):
            # === Measure current state ===
            pooled = current.semantic.mean(dim=1)         # [batch, d]
            order = self.order_net(pooled).squeeze(-1)    # [batch]
            energy = self.energy_net(pooled).squeeze(-1)  # [batch]

            # === Determine phase (label only, from the pre-update temperature) ===
            phase_name = self._phase_labels(temperature)

            # === Refinement step (temperature-dependent) ===
            # The temperature is both an input feature of the perturbation
            # network and a multiplier on its output, so hot systems take
            # larger steps than cold ones.
            temp_col = temperature.unsqueeze(-1).unsqueeze(-1)  # [batch, 1, 1]
            refine_input = torch.cat(
                [current.semantic, temp_col.expand(-1, seq_len, 1)], dim=-1
            )
            refinement = self.perturbation_net(refine_input) * temp_col
            new_semantic = current.semantic + refinement * self.UPDATE_SCALE

            # If flow function provided, run one flow step
            if flow_fn is not None:
                flowed, _ = flow_fn(self._with_semantic(current, new_semantic))
                new_semantic = flowed.semantic

            # If bond function provided, run one bond step
            if bond_fn is not None:
                bonds, bond_output, _ = bond_fn(
                    self._with_semantic(current, new_semantic), bonds
                )
                new_semantic = new_semantic + bond_output * self.UPDATE_SCALE

            # === Update temperature ===
            # The cooling factor shrinks further the more the order parameter
            # moved (in either direction) since the previous step.
            order_delta = (order - prev_order).abs()
            temperature = temperature * (self.cooling_rate * (1 - order_delta))

            # Check for stuck state (local minimum).
            # NOTE: the counter is cumulative; it is only reset by a reheat,
            # not by an intervening non-stuck step.
            is_stuck = order_delta < self.reheat_threshold
            stuck_count = stuck_count + is_stuck.float()

            # Reheat if stuck for too long
            needs_reheat = stuck_count > self.STUCK_LIMIT
            temperature = temperature.masked_fill(
                needs_reheat, self.initial_temp * self.REHEAT_FRACTION
            )
            stuck_count = stuck_count.masked_fill(needs_reheat, 0.0)
            diagnostics['reheats'] += needs_reheat.float()

            # Ensure minimum temperature
            temperature = torch.clamp(temperature, min=self.MIN_TEMP)

            # === Update particle state ===
            current = self._with_semantic(current, new_semantic)
            prev_order = order

            # Record diagnostics
            diagnostics['temperatures'].append(temperature.mean().item())
            diagnostics['order_parameters'].append(order.mean().item())
            diagnostics['energies'].append(energy.mean().item())
            diagnostics['phases'].append(phase_name)
            diagnostics['steps'].append(step)

            # === Check stability ===
            # Exit once every batch element is stable AND the minimum number
            # of steps has been reached. ``order`` here was measured at the
            # start of this step, before the update above.
            if step >= self.min_steps and (order > self.stability_threshold).all():
                break

        diagnostics['total_steps'] = len(diagnostics['steps'])
        diagnostics['final_temperature'] = temperature.mean().item()
        diagnostics['final_order'] = prev_order.mean().item()
        return current, diagnostics
class PhaseTransitionLoss(nn.Module):
    """Loss function that encourages proper phase transition behavior.

    The system should:

    1. Start with high entropy (exploration)
    2. Progressively reduce entropy (crystallization)
    3. Reach low energy state (solid)
    4. Not get stuck in local minima

    The loss is the unweighted mean of up to four terms computed from a
    diagnostics dict: increases in temperature, decreases in the order
    parameter, increases in energy, and the mean reheat count scaled by
    ``REHEAT_WEIGHT``.
    """

    # Weight applied to the mean reheat count.
    REHEAT_WEIGHT = 0.1

    def __init__(self):
        super().__init__()

    @staticmethod
    def _as_series(values, device):
        """Convert a per-step history to a 1-D float tensor.

        Returns ``None`` when the history is missing or has fewer than two
        entries, because no step-to-step difference can be formed.
        """
        if values is None or len(values) < 2:
            return None
        if any(torch.is_tensor(v) for v in values):
            # Stack rather than rebuild from Python numbers so that any
            # autograd history attached to the entries is preserved.
            return torch.stack([
                torch.as_tensor(v, dtype=torch.float32, device=device).reshape(())
                for v in values
            ])
        return torch.tensor(values, dtype=torch.float32, device=device)

    @classmethod
    def _trend_penalty(cls, values, device, sign):
        """Mean positive part of ``sign * (x[t+1] - x[t])``, or ``None``.

        ``sign=+1`` penalizes increases, ``sign=-1`` penalizes decreases.
        """
        series = cls._as_series(values, device)
        if series is None:
            return None
        steps = series[1:] - series[:-1]
        return F.relu(sign * steps).mean()

    def forward(self, diagnostics: Dict) -> torch.Tensor:
        """Compute phase transition regularization loss.

        Args:
            diagnostics: Mapping that may contain the per-step histories
                ``'temperatures'``, ``'order_parameters'`` and ``'energies'``
                (sequences of floats or 0-dim tensors) and a ``'reheats'``
                tensor. Missing keys and histories shorter than two steps
                are skipped.

        Returns:
            A 0-dim tensor: the mean of the terms that could be computed, or
            ``0.0`` when none could. When the histories hold plain floats
            (as produced via ``.item()``), the trend terms are constants and
            carry no gradient; pass tensors to make them differentiable.
        """
        reheats = diagnostics.get('reheats')
        # Build the history tensors on the same device as the reheat counter.
        device = reheats.device if torch.is_tensor(reheats) else None

        candidates = [
            # 1. Monotonic cooling — penalize reheating.
            self._trend_penalty(diagnostics.get('temperatures'), device, 1.0),
            # 2. Order should increase — penalize decreasing order.
            self._trend_penalty(diagnostics.get('order_parameters'), device, -1.0),
            # 3. Energy should decrease — penalize increasing energy.
            self._trend_penalty(diagnostics.get('energies'), device, 1.0),
        ]
        losses = [term for term in candidates if term is not None]

        # 4. Penalize too many reheats.
        if reheats is not None:
            reheats = torch.as_tensor(reheats, dtype=torch.float32, device=device)
            losses.append(reheats.mean() * self.REHEAT_WEIGHT)

        if not losses:
            return torch.tensor(0.0)
        return sum(losses) / len(losses)