# kakeya-orientation-encoder / kakeya_orientation_encoder.py
# DJLougen's picture
# Upload kakeya_orientation_encoder.py
# 7da6a5f verified
"""
Kakeya-Inspired Orientation Image Encoder
=========================================
The Kakeya conjecture (proven in three dimensions by Wang & Zahl, 2025)
concerns how small a set containing a unit line segment in every direction
can be. The idea borrowed here is that directional information across all
orientations can be captured efficiently with sparse, anisotropic sampling.
This encoder uses the shearlet transform as the directional decomposition
front-end, capturing multi-scale, multi-orientation information. The shearlet
coefficients are then processed by a lightweight neural network to produce a
compact latent representation that preserves orientation information.
Architecture:
1. Directional Decomposition: Fast Finite Shearlet Transform (FFST)
- Anisotropic dilation A_a = diag(a, sqrt(a))
- Shearing S_s = [[1, s], [0, 1]]
- Coefficients indexed by (scale, shear/orientation, position)
2. Orientation Encoder: CNN on shearlet coefficient tensor
- Processes (scale, orientation) as channels
- Learns compact latent representation
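3. Decoder: transposed-convolution network from the latent vector back to an image (the inverse shearlet transform is exposed by ShearletTransform.inverse but is not used by the decoder)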
Inspired by:
- ShearletX (Kolek et al., 2023) - shearlet domain explanations
- SAD (Iinbor et al., 2025) - anisotropic spatial representation
- Kakeya Conjecture (Wang & Zahl, 2025) - geometric measure theory
"""
import sys
sys.path.insert(0, '/app/PyShearlets')
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional, List
from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra
class ShearletTransform:
    """Wrapper for PyShearlets FFST to work with PyTorch tensors.

    The FFST routines operate on NumPy arrays, so every call round-trips
    through host memory and is not differentiable. Results are moved back to
    the device of the input tensor so they can be fed directly into modules
    living on that device.
    """

    def __init__(self, num_scales: int = 4, real_coefficients: bool = True):
        self.num_scales = num_scales
        self.real_coefficients = real_coefficients
        # Shearlet spectra keyed by (shape, num_scales, real_coefficients).
        self._psi_cache = {}

    @staticmethod
    def _padded_size(height: int, width: int) -> int:
        """Return the odd side length of the square both transforms work on."""
        size = max(height, width)
        return size if size % 2 else size + 1

    def _get_psi(self, shape: Tuple[int, int]) -> np.ndarray:
        """Return the shearlet spectra for ``shape``, computing them once."""
        key = (shape, self.num_scales, self.real_coefficients)
        if key not in self._psi_cache:
            self._psi_cache[key] = scalesShearsAndSpectra(
                shape,
                numOfScales=self.num_scales,
                realCoefficients=self.real_coefficients,
            )
        return self._psi_cache[key]

    def transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Apply shearlet transform to batch of images.

        Multi-channel inputs are averaged over the channel axis before the
        transform. The image is reflect-padded on the bottom/right to an odd
        square, transformed, and the coefficients are cropped back to the
        input height and width.

        Args:
            image: (B, C, H, W) or (B, 1, H, W) or (B, H, W)

        Returns:
            coeffs: (B, num_shearlets, H, W) float32 tensor on ``image.device``
        """
        if image.dim() == 3:
            image = image.unsqueeze(1)  # (B, 1, H, W)
        B, C, H, W = image.shape

        # FFST works on single-channel images: collapse channels to grayscale.
        img_gray = image.mean(dim=1) if C > 1 else image.squeeze(1)  # (B, H, W)

        target_size = self._padded_size(H, W)
        pad_h = target_size - H
        pad_w = target_size - W
        if pad_h > 0 or pad_w > 0:
            # NOTE: reflect padding requires each pad to be smaller than the
            # corresponding input dimension, so strongly non-square inputs
            # are rejected by F.pad.
            img_gray = F.pad(img_gray, (0, pad_w, 0, pad_h), mode='reflect')

        psi = self._get_psi((target_size, target_size))

        # One host transfer for the whole batch. detach() is required because
        # Tensor.numpy() refuses tensors that are part of an autograd graph.
        batch_np = img_gray.detach().cpu().numpy()

        coeffs_list = []
        for img_np in batch_np:
            coeffs, _ = shearletTransformSpect(
                img_np,
                Psi=psi,
                realCoefficients=self.real_coefficients,
            )
            # coeffs layout: (H, W, num_shearlets).
            # NOTE(review): .float() assumes real-valued coefficients; with
            # real_coefficients=False an imaginary part would be dropped here
            # -- TODO confirm whether that mode is meant to be supported.
            coeffs_tensor = torch.from_numpy(coeffs).float()
            # Crop away the padding, then move the shearlet axis to the front.
            coeffs_list.append(coeffs_tensor[:H, :W, :].permute(2, 0, 1))

        # Return on the caller's device so GPU models do not hit a mismatch.
        return torch.stack(coeffs_list, dim=0).to(image.device)

    def inverse(self, coeffs: torch.Tensor, original_shape: Tuple[int, int]) -> torch.Tensor:
        """
        Apply inverse shearlet transform.

        Coefficients are zero-padded on the bottom/right to the odd square
        size used by ``transform`` before being handed to FFST.

        Args:
            coeffs: (B, num_shearlets, H, W) tensor
            original_shape: (H, W) of output image

        Returns:
            image: (B, H, W) float32 tensor on ``coeffs.device``
        """
        B, _, H, W = coeffs.shape
        target_size = self._padded_size(H, W)
        psi = self._get_psi((target_size, target_size))

        # (B, H, W, num_shearlets) on the host; detach() for the same reason
        # as in transform().
        coeffs_np = coeffs.detach().permute(0, 2, 3, 1).cpu().numpy()
        if H < target_size or W < target_size:
            padded = np.zeros((B, target_size, target_size, coeffs_np.shape[3]))
            padded[:, :H, :W, :] = coeffs_np
            coeffs_np = padded

        out_h, out_w = original_shape[0], original_shape[1]
        images = []
        for c_np in coeffs_np:
            img = inverseShearletTransformSpect(c_np, Psi=psi)
            img = torch.from_numpy(img).float()
            images.append(img[:out_h, :out_w])

        return torch.stack(images, dim=0).to(coeffs.device)  # (B, H, W)

    @property
    def num_shearlets(self) -> int:
        """Number of shearlet coefficients per pixel.

        Computed as 1 (lowpass) + sum(2^(j+2) for j in range(num_scales)),
        i.e. 1 + 4 * (2^num_scales - 1). This is expected to equal the
        channel count FFST produces for ``num_scales`` scales; it is not
        cross-checked against the actual transform output here.
        """
        return 1 + 4 * (2**self.num_scales - 1)
class KakeyaOrientationEncoder(nn.Module):
    """
    Kakeya-inspired orientation encoder using shearlet transform.

    The encoder maps an image to a compact latent representation that
    preserves directional/edge information via shearlet coefficients. Two
    paths consume the coefficients: a strided CNN producing the VAE
    parameters (mu, logvar), and an MLP over the normalized per-channel
    energy producing orientation features.
    """

    def __init__(
        self,
        input_channels: int = 1,
        latent_dim: int = 256,
        num_scales: int = 3,
        base_channels: int = 32,
    ):
        super().__init__()
        # input_channels is recorded only; the shearlet front-end decides how
        # channels are handled.
        self.input_channels = input_channels
        self.latent_dim = latent_dim
        self.num_scales = num_scales
        self.base_channels = base_channels

        # Shearlet transform (non-differentiable, used as feature extractor)
        self.shearlet = ShearletTransform(num_scales=num_scales)
        n_sh = self.shearlet.num_shearlets

        # Spatial path over the (B, n_sh, H, W) coefficient tensor: one
        # resolution-preserving conv followed by three stride-2 convs.
        self.enc_conv1 = nn.Conv2d(n_sh, base_channels * 2, 3, padding=1)
        self.enc_bn1 = nn.BatchNorm2d(base_channels * 2)
        self.enc_conv2 = nn.Conv2d(base_channels * 2, base_channels * 4, 3, stride=2, padding=1)
        self.enc_bn2 = nn.BatchNorm2d(base_channels * 4)
        self.enc_conv3 = nn.Conv2d(base_channels * 4, base_channels * 8, 3, stride=2, padding=1)
        self.enc_bn3 = nn.BatchNorm2d(base_channels * 8)
        self.enc_conv4 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.enc_bn4 = nn.BatchNorm2d(base_channels * 8)

        # Global average pooling makes the head independent of image size.
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Latent projection
        self.fc_mu = nn.Linear(base_channels * 8, latent_dim)
        self.fc_logvar = nn.Linear(base_channels * 8, latent_dim)

        # Orientation branch: MLP over per-channel shearlet energy.
        self.orientation_fc = nn.Sequential(
            nn.Linear(n_sh, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.orientation_proj = nn.Linear(32, latent_dim)

        self._init_weights()

    def _init_weights(self):
        """Kaiming init for convs, N(0, 0.01) for linears, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def extract_orientation_features(self, shearlet_coeffs: torch.Tensor) -> torch.Tensor:
        """
        Extract orientation histogram from shearlet coefficients.

        Args:
            shearlet_coeffs: (B, n_sh, H, W)

        Returns:
            orientation_features: (B, latent_dim) - directional energy profile
        """
        # Mean squared coefficient per shearlet channel.
        energy = shearlet_coeffs.pow(2).mean(dim=[2, 3])  # (B, n_sh)
        # Normalize to a distribution over channels; the epsilon guards
        # against division by zero for an all-zero input.
        energy = energy / (energy.sum(dim=1, keepdim=True) + 1e-8)
        orient_feat = self.orientation_fc(energy)  # (B, 32)
        return self.orientation_proj(orient_feat)  # (B, latent_dim)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode image to latent representation.

        Args:
            x: (B, C, H, W) input image

        Returns:
            mu: (B, latent_dim) mean of latent distribution
            logvar: (B, latent_dim) log variance of latent distribution
            orientation_features: (B, latent_dim) orientation-preserving features
        """
        # Shearlet transform (non-differentiable preprocessing)
        with torch.no_grad():
            shearlet_coeffs = self.shearlet.transform(x)
        # The transform is computed with NumPy on the host, so place the
        # result on the input's device and match the conv weights' dtype
        # before running the network on it.
        shearlet_coeffs = shearlet_coeffs.to(
            device=x.device, dtype=self.enc_conv1.weight.dtype
        )

        # Spatial encoding path
        h = F.relu(self.enc_bn1(self.enc_conv1(shearlet_coeffs)))
        h = F.relu(self.enc_bn2(self.enc_conv2(h)))
        h = F.relu(self.enc_bn3(self.enc_conv3(h)))
        h = F.relu(self.enc_bn4(self.enc_conv4(h)))
        h = self.global_pool(h).view(h.size(0), -1)  # (B, base_channels*8)

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        orientation_features = self.extract_orientation_features(shearlet_coeffs)
        return mu, logvar, orientation_features

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick for VAE sampling: mu + eps * exp(logvar / 2)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Get deterministic latent representation.

        Uses the posterior mean instead of a random sample so that repeated
        calls on the same input (with the module in eval mode) return the
        same latent. The orientation features are mixed in with the same 0.1
        weight that KakeyaAutoencoder.forward applies.
        """
        mu, _logvar, orient = self.forward(x)
        return mu + 0.1 * orient
class KakeyaOrientationDecoder(nn.Module):
    """Transposed-convolution decoder that maps a latent vector to an image."""

    def __init__(
        self,
        latent_dim: int = 256,
        output_channels: int = 1,
        base_channels: int = 32,
        output_size: int = 128,
    ):
        super().__init__()
        self.latent_dim = latent_dim
        self.output_size = output_size

        # Three stride-2 upsampling stages follow, so the seed feature map is
        # 1/8 of the target side length, but never smaller than 4.
        self.bottleneck_size = max(output_size // 8, 4)

        wide = base_channels * 8
        mid = base_channels * 4
        narrow = base_channels * 2
        seed_area = self.bottleneck_size * self.bottleneck_size

        # Latent vector -> flattened (wide, bottleneck, bottleneck) feature map.
        self.fc = nn.Linear(latent_dim, wide * seed_area)

        # Each transposed conv (kernel 4, stride 2, padding 1) doubles H and W.
        self.dec_conv1 = nn.ConvTranspose2d(wide, wide, 4, stride=2, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(wide)
        self.dec_conv2 = nn.ConvTranspose2d(wide, mid, 4, stride=2, padding=1)
        self.dec_bn2 = nn.BatchNorm2d(mid)
        self.dec_conv3 = nn.ConvTranspose2d(mid, narrow, 4, stride=2, padding=1)
        self.dec_bn3 = nn.BatchNorm2d(narrow)

        # Final projection to the requested number of image channels.
        self.dec_conv4 = nn.Conv2d(narrow, output_channels, 3, padding=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming init for (transposed) convs, N(0, 0.01) for linears, identity BatchNorm."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                nn.init.constant_(layer.weight, 1)
                nn.init.constant_(layer.bias, 0)
                continue

            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(layer, nn.Linear):
                nn.init.normal_(layer.weight, 0, 0.01)
            else:
                continue

            # Conv and linear layers share the zero-bias rule.
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode latent vector to image.

        Args:
            z: (B, latent_dim)

        Returns:
            image: (B, output_channels, output_size, output_size), values in
                (0, 1) from the final sigmoid
        """
        side = self.bottleneck_size
        feat = self.fc(z)
        feat = feat.view(feat.size(0), -1, side, side)

        stages = (
            (self.dec_conv1, self.dec_bn1),
            (self.dec_conv2, self.dec_bn2),
            (self.dec_conv3, self.dec_bn3),
        )
        for upsample, norm in stages:
            feat = F.relu(norm(upsample(feat)))

        # The bottleneck clamp / integer division can leave the map at a size
        # other than output_size; resample to the exact target in that case.
        target = (self.output_size, self.output_size)
        if tuple(feat.shape[2:]) != target:
            feat = F.interpolate(feat, size=target, mode='bilinear', align_corners=False)

        return torch.sigmoid(self.dec_conv4(feat))
class KakeyaAutoencoder(nn.Module):
    """Encoder/decoder pair forming the Kakeya-inspired orientation-preserving autoencoder."""

    def __init__(
        self,
        input_channels: int = 1,
        latent_dim: int = 256,
        num_scales: int = 3,
        base_channels: int = 32,
        output_size: int = 128,
    ):
        super().__init__()
        # Settings the two halves must agree on.
        shared = dict(latent_dim=latent_dim, base_channels=base_channels)
        self.encoder = KakeyaOrientationEncoder(
            input_channels=input_channels,
            num_scales=num_scales,
            **shared,
        )
        # The decoder emits as many channels as the encoder was told to expect.
        self.decoder = KakeyaOrientationDecoder(
            output_channels=input_channels,
            output_size=output_size,
            **shared,
        )
        self.latent_dim = latent_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run the encoder, sample a latent, and decode it.

        Returns:
            x_recon: reconstructed image
            mu: latent mean
            logvar: latent log variance
            orient_features: orientation features
        """
        mu, logvar, orient_features = self.encoder(x)
        # Sampled latent, shifted by the orientation features at weight 0.1.
        sampled = self.encoder.reparameterize(mu, logvar)
        z = sampled + 0.1 * orient_features
        return self.decoder(z), mu, logvar, orient_features

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Return the encoder's latent for ``x`` (delegates to the encoder's ``encode``)."""
        latent = self.encoder.encode(x)
        return latent

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return the decoder's reconstruction for latent ``z``."""
        image = self.decoder(z)
        return image
def kakeya_loss(
    x_recon: torch.Tensor,
    x_target: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    orient_features: torch.Tensor,
    shearlet_coeffs_target: Optional[torch.Tensor] = None,
    recon_weight: float = 1.0,
    kl_weight: float = 0.001,
    orient_weight: float = 0.1,
    shearlet_weight: float = 0.5,
) -> Tuple[torch.Tensor, dict]:
    """
    Combined loss for Kakeya autoencoder.

    Components:
    1. Reconstruction loss (MSE)
    2. KL divergence (VAE regularization)
    3. Orientation loss: negative entropy of softmax(orient_features), so
       minimizing it maximizes that entropy
    4. Shearlet coefficient matching: NOT implemented. The term is always 0;
       passing ``shearlet_coeffs_target`` only triggers a warning.

    Args:
        x_recon: reconstructed image batch
        x_target: target image batch, same shape as ``x_recon``
        mu: (B, latent_dim) latent means
        logvar: (B, latent_dim) latent log variances
        orient_features: (B, D) orientation features; softmax is taken over dim 1
        shearlet_coeffs_target: accepted for interface compatibility, unused
        recon_weight, kl_weight, orient_weight, shearlet_weight: term weights

    Returns:
        total_loss: scalar tensor, the weighted sum of the four terms
        metrics: dict of Python floats with keys 'recon_loss', 'kl_loss',
            'orient_loss', 'shearlet_loss' and 'total_loss'
    """
    # Reconstruction loss
    recon_loss = F.mse_loss(x_recon, x_target, reduction='mean')

    # KL divergence to a standard normal: summed over latent dims, averaged
    # over the batch.
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_loss = kl_per_sample.mean()

    # Entropy of the softmax distribution over orientation features.
    log_probs = F.log_softmax(orient_features, dim=1)
    orient_entropy = -torch.sum(F.softmax(orient_features, dim=1) * log_probs, dim=1).mean()
    orient_loss = -orient_entropy  # Maximize entropy = spread information

    # Placeholder for the shearlet matching term. The previous code imported
    # FFST here and then did nothing; the import is dropped so the function
    # does not depend on FFST being importable, and the caller is told that
    # the argument has no effect instead of it being silently ignored.
    shearlet_loss = torch.tensor(0.0, device=x_recon.device)
    if shearlet_coeffs_target is not None:
        import warnings

        warnings.warn(
            "kakeya_loss: shearlet_coeffs_target was provided but the shearlet "
            "matching term is not implemented; it contributes 0 to the loss.",
            RuntimeWarning,
            stacklevel=2,
        )

    total_loss = (
        recon_weight * recon_loss
        + kl_weight * kl_loss
        + orient_weight * orient_loss
        + shearlet_weight * shearlet_loss
    )

    metrics = {
        'recon_loss': recon_loss.item(),
        'kl_loss': kl_loss.item(),
        'orient_loss': orient_loss.item(),
        'shearlet_loss': shearlet_loss.item(),
        'total_loss': total_loss.item(),
    }
    return total_loss, metrics
if __name__ == "__main__":
    # Smoke test: push a random batch through the full autoencoder, then
    # through the separate encode/decode entry points, and report shapes.
    autoencoder = KakeyaAutoencoder(input_channels=1, latent_dim=256, num_scales=3, output_size=128)

    x = torch.randn(2, 1, 128, 128)
    x_recon, mu, _logvar, orient = autoencoder(x)
    for report_line in (
        f"Input shape: {x.shape}",
        f"Reconstructed shape: {x_recon.shape}",
        f"Latent mu shape: {mu.shape}",
        f"Orientation features shape: {orient.shape}",
    ):
        print(report_line)

    # Round trip through the standalone encode/decode API.
    z = autoencoder.encode(x)
    print(f"Encoded latent shape: {z.shape}")
    x_decoded = autoencoder.decode(z)
    print(f"Decoded shape: {x_decoded.shape}")

    print("\nModel test passed!")