| """ |
| Wave Processor — Meaning Resolution via Interference. |
| |
| Instead of softmax attention, ambiguity is resolved through wave |
| interference patterns. Constructive interference amplifies the |
| "correct" interpretation; destructive interference suppresses |
| alternatives. |
| |
| This is inspired by how holograms work — information is encoded |
| in interference patterns, not in individual beams. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Tuple, Optional |
|
|
| from .particle import ParticleState |
| from ..utils.math_utils import complex_wave, phase_coherence, resonance_detect |
|
|
|
|
class WaveProcessor(nn.Module):
    """Process information via wave interference.

    Each particle's semantic vector is mapped to ``n_wave_heads`` waves of
    ``d_wave`` channels, each with an amplitude and a phase. Waves that are
    "in phase" interfere constructively and reinforce one another; waves that
    are "out of phase" interfere destructively and cancel.

    The interference is decoded from both its in-phase (cosine) and
    quadrature (sine) components, gated per head by a learned resonance
    score, and projected back to ``d_semantic``.

    NOTE(review): this module is meant to replace softmax(Q·K^T)·V, but the
    pairwise interference below runs between the wave heads *of the same
    particle* (head axis), not between different particles (sequence axis).
    As written, no information is exchanged across sequence positions.
    Confirm which axis is intended.
    """

    def __init__(
        self,
        d_semantic: int = 256,
        n_wave_heads: int = 8,
        d_wave: int = 32,
    ):
        super().__init__()
        self.d_semantic = d_semantic
        self.n_heads = n_wave_heads
        self.d_wave = d_wave

        # semantic -> (amplitude logit, phase logit) for every head/channel.
        # NOTE(review): the trailing GELU bounds both logits below by about
        # -0.17, so amplitudes never drop below sigmoid(-0.17) ~= 0.46 and
        # phases never below ~-0.17*pi. Confirm this restriction is intended.
        self.wave_generator = nn.Sequential(
            nn.Linear(d_semantic, n_wave_heads * d_wave * 2),
            nn.GELU(),
        )

        # Flattened decoded waves of all heads -> semantic space.
        self.wave_receiver = nn.Sequential(
            nn.Linear(n_wave_heads * d_wave, d_semantic),
            nn.GELU(),
        )

        # Per-head gate in (0, 1) scoring how strongly a received wave resonates.
        self.resonance_net = nn.Sequential(
            nn.Linear(d_wave, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # (in-phase, quadrature) components concatenated -> decoded wave channels.
        self.interference_decoder = nn.Sequential(
            nn.Linear(d_wave * 2, d_wave),
            nn.GELU(),
            nn.Linear(d_wave, d_wave),
        )

    def forward(self, particles: ParticleState) -> Tuple[torch.Tensor, dict]:
        """Process particles via wave interference.

        Args:
            particles: particle state. This method reads:
                - ``semantic`` (``[batch, seq, d_semantic]``);
                - ``amplitude`` (broadcast as ``[batch, seq, 1, 1]``,
                  presumably ``[batch, seq]``; TODO confirm);
                - ``phase`` (used only for the coherence diagnostic).

        Returns:
            output: Processed semantic representation ``[batch, seq, d_semantic]``.
            diagnostics: Python floats under ``'coherence'``,
                ``'mean_amplitude'`` and ``'resonance_strength'``.
        """
        batch, seq_len, _ = particles.semantic.shape

        # One wave per (particle, head, channel): amplitude in (0, 1), phase
        # as a logit scaled by pi.
        wave_params = self.wave_generator(particles.semantic)
        wave_params = wave_params.reshape(batch, seq_len, self.n_heads, self.d_wave, 2)
        amplitudes = torch.sigmoid(wave_params[..., 0])
        phases = wave_params[..., 1] * np.pi

        # Scale every head/channel by the particle's own amplitude.
        amplitudes = amplitudes * particles.amplitude.unsqueeze(-1).unsqueeze(-1)

        # Pairwise interference between heads h and k of the same particle:
        #   a_h * a_k * exp(i * (phi_k - phi_h)),  shape [B, S, H, H, d_wave]
        # Summing over k gives, for each receiving head h, the total wave
        # expressed relative to h's own phase.
        # - Real part: the classic cos(phase difference) interference term.
        # - Imaginary part: the quadrature (sin) term.
        # Both are kept as real tensors. Calling `.imag` on a real tensor
        # raises in PyTorch.
        phase_diff = phases.unsqueeze(2) - phases.unsqueeze(3)  # [..., h, k, :] = phi_k - phi_h
        amp_pair = amplitudes.unsqueeze(3) * amplitudes.unsqueeze(2)
        received_real = (amp_pair * torch.cos(phase_diff)).sum(dim=3)  # [B, S, H, d_wave]
        received_imag = (amp_pair * torch.sin(phase_diff)).sum(dim=3)  # [B, S, H, d_wave]

        # Gate each head by how strongly its in-phase interference resonates.
        resonance_scores = self.resonance_net(received_real).squeeze(-1)  # [B, S, H]
        gate = resonance_scores.unsqueeze(-1)

        decoded = self.interference_decoder(
            torch.cat([received_real * gate, received_imag * gate], dim=-1)
        )  # [B, S, H, d_wave]

        # Merge heads and project back to semantic space.
        decoded = decoded.reshape(batch, seq_len, self.n_heads * self.d_wave)
        output = self.wave_receiver(decoded)

        coherence = phase_coherence(particles.phase)

        diagnostics = {
            'coherence': coherence.mean().item(),
            'mean_amplitude': amplitudes.mean().item(),
            'resonance_strength': resonance_scores.mean().item(),
        }

        return output, diagnostics
|
|
|
|
class WaveAttention(nn.Module):
    """Wave-based alternative to standard attention.

    Standard attention computes ``softmax(Q·K^T/√d)·V``. Here the score
    between query position ``i`` and key position ``j`` comes from wave
    interference instead of a dot product::

        score[i, j, h] = Σ_c a_i a_j cos(φ_j + 2π|i - j|/λ_h - φ_i) / √d_head

    Symbols:
        - ``a`` and ``φ``: per-channel amplitudes and phases emitted from ``x``.
        - ``λ_h``: a learned per-head wavelength that adds a
          distance-dependent phase shift.

    In-phase channels (constructive interference) raise a score, and
    out-of-phase channels (destructive interference) lower it. Scores are
    then normalized with a softmax over ``j`` and used to average ``V``, as
    in standard attention.

    Note: the pairwise score tensor is ``[batch, seq, seq, n_heads, d_head]``
    before reduction, so time and memory are quadratic in ``seq``.
    """

    def __init__(self, d_model: int = 256, n_heads: int = 8):
        super().__init__()
        if d_model % n_heads != 0:
            # V is reshaped to [.., n_heads, d_model // n_heads]; that only works
            # when the split is exact.
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # x -> (amplitude logit, phase logit) per head and channel.
        self.emit_proj = nn.Linear(d_model, n_heads * self.d_head * 2)

        # Values that get mixed by the interference weights.
        self.v_proj = nn.Linear(d_model, d_model)

        # Final output projection.
        self.out_proj = nn.Linear(d_model, d_model)

        # Learned per-head wavelength, initialized to 2.0.
        # NOTE(review): nothing keeps this away from 0, where the phase shift
        # (which divides by it) becomes non-finite.
        self.wavelength = nn.Parameter(torch.ones(n_heads) * 2.0)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Wave-based attention.

        Args:
            x: [batch, seq, d_model]
            mask: optional attention mask of shape [batch, seq, seq]
                (query, key); key positions where ``mask == 0`` are excluded.
                A query row with every key masked yields NaN, as with a
                standard softmax.

        Returns:
            output: [batch, seq, d_model]
        """
        batch, seq, _ = x.shape

        wave_params = self.emit_proj(x)
        wave_params = wave_params.reshape(batch, seq, self.n_heads, self.d_head, 2)

        amplitudes = torch.sigmoid(wave_params[..., 0])  # [B, S, H, Dh], in (0, 1)
        phases = wave_params[..., 1] * np.pi

        V = self.v_proj(x).reshape(batch, seq, self.n_heads, self.d_head)

        # Pairwise layout [B, S_i, S_j, H, Dh]: query i on dim 1, key j on dim 2.
        p_i = phases.unsqueeze(2)
        p_j = phases.unsqueeze(1)

        # Distance-dependent phase shift 2π|i - j| / λ_h:
        # [S, S, 1, 1] / [1, 1, 1, H, 1] -> [1, S, S, H, 1].
        positions = torch.arange(seq, device=x.device, dtype=torch.float32)
        dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
        wavelength = self.wavelength.view(1, 1, 1, self.n_heads, 1)
        phase_shift = 2 * np.pi * dist.unsqueeze(-1).unsqueeze(-1) / wavelength

        total_phase_diff = p_j + phase_shift - p_i
        interference = torch.cos(total_phase_diff)

        amp_i = amplitudes.unsqueeze(2)
        amp_j = amplitudes.unsqueeze(1)
        weights = (interference * amp_i * amp_j).sum(dim=-1)  # [B, S_i, S_j, H]
        weights = weights / (self.d_head ** 0.5)

        if mask is not None:
            weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))

        # Normalize over keys j.
        weights = F.softmax(weights, dim=2)

        # Per-head weighted sum of values.
        # - torch.bmm only accepts 3-D tensors, so use matmul, which
        #   broadcasts over the leading (batch, head) dims:
        #   [B, H, S_i, S_j] @ [B, H, S_j, Dh] -> [B, H, S_i, Dh].
        # - Cast the weights to V's dtype: the float32 positions promote them
        #   to float32, which would clash with half-precision values.
        weights_t = weights.permute(0, 3, 1, 2).to(V.dtype)
        output = torch.matmul(weights_t, V.transpose(1, 2))

        # Merge heads back: [B, H, S, Dh] -> [B, S, H*Dh].
        output = output.transpose(1, 2).reshape(batch, seq, self.d_model)
        return self.out_proj(output)
|
|