# flownet/core/wave_processor.py
# Commit d4fff7c — "Add FlowNet: Post-Transformer Architecture" (author: Ashu9675)
"""
Wave Processor — Meaning Resolution via Interference.
Instead of softmax attention, ambiguity is resolved through wave
interference patterns. Constructive interference amplifies the
"correct" interpretation; destructive interference suppresses
alternatives.
This is inspired by how holograms work — information is encoded
in interference patterns, not in individual beams.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional
from .particle import ParticleState
from ..utils.math_utils import complex_wave, phase_coherence, resonance_detect
class WaveProcessor(nn.Module):
    """Process information via wave interference.

    Key insight: when multiple particles are "in phase" (semantically
    aligned), their waves constructively interfere, amplifying their
    shared meaning. When "out of phase," they destructively interfere,
    suppressing noise.

    This is intended as a replacement for the softmax(Q·K^T)·V computation
    in Transformers.

    Args:
        d_semantic: Width of the semantic vector carried by each particle.
        n_wave_heads: Number of independent wave heads.
        d_wave: Number of wave channels per head.
    """

    def __init__(
        self,
        d_semantic: int = 256,
        n_wave_heads: int = 8,
        d_wave: int = 32,
    ):
        super().__init__()
        self.d_semantic = d_semantic
        self.n_heads = n_wave_heads
        self.d_wave = d_wave

        # Wave generation — particles emit waves based on their state.
        # Output is later split into (amplitude, phase) pairs per channel.
        self.wave_generator = nn.Sequential(
            nn.Linear(d_semantic, n_wave_heads * d_wave * 2),  # amplitude + phase
            nn.GELU(),
        )

        # Wave reception — particles decode received waves
        self.wave_receiver = nn.Sequential(
            nn.Linear(n_wave_heads * d_wave, d_semantic),
            nn.GELU(),
        )

        # Resonance detector — maps the d_wave channels of one head to a
        # single gate in (0, 1).
        self.resonance_net = nn.Sequential(
            nn.Linear(d_wave, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

        # Interference decoder — consumes concatenated [real, imag] parts
        # (2 * d_wave) and produces d_wave features.
        self.interference_decoder = nn.Sequential(
            nn.Linear(d_wave * 2, d_wave),
            nn.GELU(),
            nn.Linear(d_wave, d_wave),
        )

    def forward(self, particles: ParticleState) -> Tuple[torch.Tensor, dict]:
        """Process particles via wave interference.

        Args:
            particles: Particle state. ``particles.semantic`` is unpacked as
                [batch, seq, d_semantic]. ``particles.amplitude`` is assumed
                to be [batch, seq] (it is broadcast over heads and channels)
                — TODO confirm against ``ParticleState``. ``particles.phase``
                is only passed to ``phase_coherence`` for diagnostics.

        Returns:
            output: Processed semantic representation [batch, seq, d_semantic]
            diagnostics: Dict of Python floats with keys ``'coherence'``,
                ``'mean_amplitude'`` and ``'resonance_strength'``.
        """
        batch, seq_len, _ = particles.semantic.shape

        # === Step 1: Generate waves from particles ===
        wave_params = self.wave_generator(particles.semantic)
        wave_params = wave_params.reshape(
            batch, seq_len, self.n_heads, self.d_wave, 2
        )
        amplitudes = torch.sigmoid(wave_params[..., 0])  # [batch, seq, heads, d_wave]
        phases = wave_params[..., 1] * np.pi             # [batch, seq, heads, d_wave]

        # Modulate amplitude by particle's own amplitude
        amplitudes = amplitudes * particles.amplitude.unsqueeze(-1).unsqueeze(-1)

        # === Step 2: Create complex waves ===
        waves = amplitudes * torch.exp(1j * phases)  # complex, [batch, seq, heads, d_wave]

        # === Step 3: Compute interference patterns ===
        # The field received at particle i is the superposition of every
        # particle's wave (the self term j == i is included), measured
        # relative to particle i's own wave:
        #
        #   conj(w_i) * sum_j w_j = sum_j a_i a_j * exp(1j * (phi_j - phi_i))
        #
        # Its real part is sum_j a_i a_j cos(phi_i - phi_j), i.e. the
        # amplitude-weighted constructive/destructive interference between
        # particles; its imaginary part is the matching sine (quadrature)
        # term. Summing the waves first avoids materialising a
        # [batch, seq, seq, heads, d_wave] tensor.
        #
        # NOTE: the previous implementation unsqueezed the wrong axes, so it
        # compared heads within one particle rather than particles with each
        # other, and it produced a real tensor whose `.imag` access raised.
        total_field = waves.sum(dim=1, keepdim=True)  # [batch, 1, heads, d_wave]
        received = waves.conj() * total_field         # complex, [batch, seq, heads, d_wave]

        # === Step 4: Detect resonances ===
        # One gate per (particle, head), computed from the in-phase
        # (real) interference across that head's d_wave channels.
        resonance_scores = self.resonance_net(received.real)  # [batch, seq, heads, 1]
        resonance_scores = resonance_scores.squeeze(-1)       # [batch, seq, heads]

        # Weight received waves by resonance
        resonant_waves = received * resonance_scores.unsqueeze(-1)

        # === Step 5: Decode interference patterns ===
        # Split into real and imaginary parts so the real-valued decoder
        # can consume the complex field.
        decoded = self.interference_decoder(
            torch.cat([resonant_waves.real, resonant_waves.imag], dim=-1)
        )  # [batch, seq, heads, d_wave]

        # Flatten heads
        decoded = decoded.reshape(batch, seq_len, self.n_heads * self.d_wave)

        # Final projection
        output = self.wave_receiver(decoded)

        # === Diagnostics ===
        # `phase_coherence` is a project helper; its result is reduced with
        # `.mean().item()`, so it is expected to return a tensor.
        coherence = phase_coherence(particles.phase)
        diagnostics = {
            'coherence': coherence.mean().item(),
            'mean_amplitude': amplitudes.mean().item(),
            'resonance_strength': resonance_scores.mean().item(),
        }
        return output, diagnostics
class WaveAttention(nn.Module):
    """Wave-based alternative to standard attention.

    Instead of: softmax(Q·K^T/√d)·V
    We compute: softmax(interference_pattern(waves)/√d)·V

    How it differs from dot-product attention:
    1. Scores come from phase alignment: each token emits a wave
       (amplitude + phase) per head/channel, and the score between
       receiver i and emitter j is the amplitude-weighted cosine of their
       phase difference, summed over channels.
    2. A learned per-head wavelength adds a distance-dependent phase shift,
       which injects relative position into the scores.
    3. Scores are still normalised with softmax over emitters, and the
       pairwise score tensor is still quadratic in sequence length
       (a [batch, seq, seq, heads, d_head] intermediate is built).

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of wave heads.

    Raises:
        ValueError: If ``n_heads`` is not positive or does not divide
            ``d_model``.
    """

    def __init__(self, d_model: int = 256, n_heads: int = 8):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # inside forward().
        if n_heads <= 0 or d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Wave generation (replaces Q, K projections)
        self.emit_proj = nn.Linear(d_model, n_heads * self.d_head * 2)  # amp + phase

        # Value projection (same as Transformer — values are values)
        self.v_proj = nn.Linear(d_model, d_model)

        # Output projection
        self.out_proj = nn.Linear(d_model, d_model)

        # Learned wavelength per head, initialised to 2.0. It divides the
        # token distance when computing the travel phase shift.
        self.wavelength = nn.Parameter(torch.ones(n_heads) * 2.0)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Wave-based attention.

        Args:
            x: [batch, seq, d_model]
            mask: [batch, seq, seq] optional attention mask indexed as
                [batch, receiver, emitter]; entries equal to 0 are blocked.
                A receiver row that is entirely blocked yields NaN after
                softmax, as in standard attention.

        Returns:
            output: [batch, seq, d_model]
        """
        batch, seq, _ = x.shape

        # Generate waves
        wave_params = self.emit_proj(x)
        wave_params = wave_params.reshape(batch, seq, self.n_heads, self.d_head, 2)
        amplitudes = torch.sigmoid(wave_params[..., 0])  # [batch, seq, heads, d_head]
        phases = wave_params[..., 1] * np.pi             # [batch, seq, heads, d_head]

        # Values
        V = self.v_proj(x).reshape(batch, seq, self.n_heads, self.d_head)

        # Compute interference between all pairs
        p_i = phases.unsqueeze(2)  # receiver: [batch, seq_i, 1, heads, d_head]
        p_j = phases.unsqueeze(1)  # emitter:  [batch, 1, seq_j, heads, d_head]

        # Distance-based phase shift (waves travel and phase-shift)
        positions = torch.arange(seq, device=x.device, dtype=torch.float32)
        dist = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()  # [seq, seq]

        # Phase shift = 2π * distance / wavelength
        wavelength = self.wavelength.view(1, 1, 1, self.n_heads, 1)
        # dist becomes [seq, seq, 1, 1] and broadcasts to [1, seq, seq, heads, 1]
        phase_shift = 2 * np.pi * dist.unsqueeze(-1).unsqueeze(-1) / wavelength

        # Total phase difference (emitter phase + travel phase - receiver phase)
        total_phase_diff = p_j + phase_shift - p_i

        # Interference factor
        interference = torch.cos(total_phase_diff)  # [batch, seq_i, seq_j, heads, d_head]

        # Amplitude-weighted, summed over channels
        amp_i = amplitudes.unsqueeze(2)
        amp_j = amplitudes.unsqueeze(1)
        weights = (interference * amp_i * amp_j).sum(dim=-1)  # [batch, seq_i, seq_j, heads]
        weights = weights / (self.d_head ** 0.5)

        if mask is not None:
            weights = weights.masked_fill(mask.unsqueeze(-1) == 0, float('-inf'))

        # Normalise over emitters (dim 2 is seq_j)
        weights = F.softmax(weights, dim=2)

        # Apply to values
        values = V.transpose(1, 2)               # [batch, heads, seq_j, d_head]
        weights_t = weights.permute(0, 3, 1, 2)  # [batch, heads, seq_i, seq_j]
        # torch.bmm only accepts 3-D operands; matmul batches over the
        # leading (batch, heads) dimensions.
        output = torch.matmul(weights_t, values)  # [batch, heads, seq_i, d_head]

        # Reshape and project
        output = output.transpose(1, 2).reshape(batch, seq, self.d_model)
        return self.out_proj(output)