"""
Wave Processor — Meaning Resolution via Interference.

Instead of softmax attention, ambiguity is resolved through wave
interference patterns. Constructive interference amplifies the "correct"
interpretation; destructive interference suppresses alternatives.

This is inspired by how holograms work — information is encoded in
interference patterns, not in individual beams.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional

from .particle import ParticleState
from ..utils.math_utils import complex_wave, phase_coherence, resonance_detect


class WaveProcessor(nn.Module):
    """Process information via wave interference.

    Key insight: when multiple particles are "in phase" (semantically
    aligned), their waves constructively interfere, amplifying their shared
    meaning. When "out of phase," they destructively interfere, suppressing
    noise.

    This replaces the softmax(Q·K^T)·V computation in Transformers.

    Args:
        d_semantic: Width of the semantic vectors read from the particles
            and of the returned representation.
        n_wave_heads: Number of independent wave heads.
        d_wave: Number of wave channels per head.
    """

    def __init__(
        self,
        d_semantic: int = 256,
        n_wave_heads: int = 8,
        d_wave: int = 32,
    ):
        super().__init__()
        self.d_semantic = d_semantic
        self.n_heads = n_wave_heads
        self.d_wave = d_wave

        # Wave generation — particles emit waves based on their state.
        # The factor 2 holds one amplitude and one phase logit per channel.
        self.wave_generator = nn.Sequential(
            nn.Linear(d_semantic, n_wave_heads * d_wave * 2),
            nn.GELU(),
        )

        # Wave reception — particles decode received waves
        self.wave_receiver = nn.Sequential(
            nn.Linear(n_wave_heads * d_wave, d_semantic),
            nn.GELU(),
        )

        # Resonance detector — one score in (0, 1) per (position, head)
        self.resonance_net = nn.Sequential(
            nn.Linear(d_wave, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Interference decoder — consumes the concatenated real and
        # imaginary parts of the received field (hence d_wave * 2 inputs).
        self.interference_decoder = nn.Sequential(
            nn.Linear(d_wave * 2, d_wave),
            nn.GELU(),
            nn.Linear(d_wave, d_wave),
        )

    def forward(self, particles: ParticleState) -> Tuple[torch.Tensor, dict]:
        """Process particles via wave interference.

        Args:
            particles: Particle state. ``particles.semantic`` is read as
                [batch, seq, d_semantic]; ``particles.amplitude`` is assumed
                to be [batch, seq] (TODO confirm against ParticleState);
                ``particles.phase`` is only passed to ``phase_coherence``.

        Returns:
            output: Processed semantic representation [batch, seq, d_semantic]
            diagnostics: Wave processing diagnostics with the float entries
                'coherence', 'mean_amplitude' and 'resonance_strength'.
        """
        batch, seq_len, _ = particles.semantic.shape

        # === Step 1: Generate waves from particles ===
        wave_params = self.wave_generator(particles.semantic)
        wave_params = wave_params.reshape(
            batch, seq_len, self.n_heads, self.d_wave, 2
        )
        amplitudes = torch.sigmoid(wave_params[..., 0])  # [batch, seq, heads, d_wave]
        phases = wave_params[..., 1] * np.pi             # [batch, seq, heads, d_wave]

        # Modulate amplitude by particle's own amplitude
        amplitudes = amplitudes * particles.amplitude.unsqueeze(-1).unsqueeze(-1)

        # === Step 2: Create complex waves ===
        # w = a * exp(i * phase), kept as explicit real/imaginary components
        # so the nn.Linear layers below only ever see real tensors.
        wave_real = amplitudes * torch.cos(phases)  # [batch, seq, heads, d_wave]
        wave_imag = amplitudes * torch.sin(phases)  # [batch, seq, heads, d_wave]

        # === Step 3: Compute interference patterns ===
        # The field received at position i is the sum over all source
        # particles j (self included) of the pairwise interference
        #     a_i * a_j * exp(i * (phase_j - phase_i)) = conj(w_i) * w_j.
        # Summing over j factorises to conj(w_i) * sum_j w_j, so the
        # [batch, seq, seq, heads, d_wave] pairwise tensor is never built.
        field_real = wave_real.sum(dim=1, keepdim=True)  # [batch, 1, heads, d_wave]
        field_imag = wave_imag.sum(dim=1, keepdim=True)  # [batch, 1, heads, d_wave]

        # Real part == sum_j a_i * a_j * cos(phase_i - phase_j):
        # +1 contributions are constructive, -1 destructive.
        received_real = wave_real * field_real + wave_imag * field_imag
        # Imag part == sum_j a_i * a_j * sin(phase_j - phase_i).
        received_imag = wave_real * field_imag - wave_imag * field_real

        # === Step 4: Detect resonances ===
        # Resonances are standing wave patterns — stable meanings.
        # Scored from the real (cosine) interference; the trailing singleton
        # dim is kept so the score broadcasts over the wave channels.
        resonance_scores = self.resonance_net(received_real)  # [batch, seq, heads, 1]

        # Weight received waves by resonance
        resonant_real = received_real * resonance_scores
        resonant_imag = received_imag * resonance_scores

        # === Step 5: Decode interference patterns ===
        decoded = self.interference_decoder(
            torch.cat([resonant_real, resonant_imag], dim=-1)
        )  # [batch, seq, heads, d_wave]

        # Flatten heads
        decoded = decoded.reshape(batch, seq_len, self.n_heads * self.d_wave)

        # Final projection
        output = self.wave_receiver(decoded)

        # === Diagnostics ===
        coherence = phase_coherence(particles.phase)
        diagnostics = {
            'coherence': coherence.mean().item(),
            'mean_amplitude': amplitudes.mean().item(),
            'resonance_strength': resonance_scores.mean().item(),
        }

        return output, diagnostics


class WaveAttention(nn.Module):
    """Wave-based alternative to standard attention.

    Instead of: softmax(Q·K^T/√d)·V
    We compute: softmax(interference_pattern(waves)/√d)·V

    Key differences from dot-product attention:
    1. Pair scores come from phase alignment — the cosine of the phase
       difference, weighted by both amplitudes — not from Q·K products.
    2. A learned per-head wavelength adds a distance-dependent phase shift,
       so relative position enters the score directly.

    Note: the pairwise interference is materialised as a
    [batch, seq, seq, heads, d_head] tensor, so memory grows quadratically
    with sequence length, and the scores are normalised with softmax.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of wave heads.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, d_model: int = 256, n_heads: int = 8):
        super().__init__()
        # Fail at construction instead of at the first reshape in forward().
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Wave generation (replaces Q, K projections): amplitude + phase
        self.emit_proj = nn.Linear(d_model, n_heads * self.d_head * 2)

        # Value projection (same as Transformer — values are values)
        self.v_proj = nn.Linear(d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        # Learned wavelength (how far waves travel), one per head
        self.wavelength = nn.Parameter(torch.ones(n_heads) * 2.0)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Wave-based attention.

        Args:
            x: [batch, seq, d_model]
            mask: [batch, seq, seq] optional attention mask (a [seq, seq]
                mask also broadcasts). Entries equal to 0 are excluded.
                A row that is masked everywhere yields NaN after softmax.

        Returns:
            output: [batch, seq, d_model]
        """
        batch, seq, _ = x.shape

        # Generate waves
        wave_params = self.emit_proj(x)
        wave_params = wave_params.reshape(batch, seq, self.n_heads, self.d_head, 2)
        amplitudes = torch.sigmoid(wave_params[..., 0])  # [batch, seq, heads, d_head]
        phases = wave_params[..., 1] * np.pi             # [batch, seq, heads, d_head]

        # Values
        V = self.v_proj(x).reshape(batch, seq, self.n_heads, self.d_head)

        # Receiver (i) and emitter (j) phases, broadcast against each other
        p_i = phases.unsqueeze(2)  # [batch, seq_i, 1, heads, d_head]
        p_j = phases.unsqueeze(1)  # [batch, 1, seq_j, heads, d_head]

        # Distance-based phase shift (waves travel and phase-shift).
        # Built in the activation dtype so half/bfloat16 models do not get
        # silently promoted to float32 and then fail against V.
        positions = torch.arange(seq, device=x.device, dtype=phases.dtype)
        dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()  # [seq, seq]

        # Phase shift = 2π * distance / wavelength -> [1, seq, seq, heads, 1]
        wavelength = self.wavelength.view(1, 1, 1, self.n_heads, 1)
        phase_shift = 2 * np.pi * dist.unsqueeze(-1).unsqueeze(-1) / wavelength

        # Total phase difference (emitter phase + travel phase - receiver phase)
        total_phase_diff = p_j + phase_shift - p_i

        # Interference factor: +1 constructive ... -1 destructive
        interference = torch.cos(total_phase_diff)  # [batch, seq, seq, heads, d_head]

        # Amplitude-weighted, summed over wave channels
        amp_i = amplitudes.unsqueeze(2)
        amp_j = amplitudes.unsqueeze(1)
        weights = (interference * amp_i * amp_j).sum(dim=-1)  # [batch, seq_i, seq_j, heads]
        weights = weights / (self.d_head ** 0.5)

        if mask is not None:
            weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))

        # Normalize over the emitter axis
        weights = F.softmax(weights, dim=2)

        # Apply to values
        V_expanded = V.transpose(1, 2)           # [batch, heads, seq_j, d_head]
        weights_t = weights.permute(0, 3, 1, 2)  # [batch, heads, seq_i, seq_j]
        # torch.bmm only accepts 3-D inputs; matmul batches over the
        # leading [batch, heads] dimensions.
        output = torch.matmul(weights_t, V_expanded)  # [batch, heads, seq_i, d_head]

        # Reshape and project
        output = output.transpose(1, 2).reshape(batch, seq, self.d_model)
        return self.out_proj(output)