""" Kakeya-Inspired Orientation Image Encoder ========================================= The Kakeya conjecture (proven by Wang & Zahl, 2025) concerns the minimum size of sets containing unit line segments in every direction. The key insight applied here is that directional information across all orientations can be captured efficiently with sparse, anisotropic sampling. This encoder uses the shearlet transform as the directional decomposition front-end, capturing multi-scale, multi-orientation information. The shearlet coefficients are then processed by a lightweight neural network to produce a compact latent representation that preserves orientation information. Architecture: 1. Directional Decomposition: Fast Finite Shearlet Transform (FFST) - Anisotropic dilation A_a = diag(a, sqrt(a)) - Shearing S_s = [[1, s], [0, 1]] - Coefficients indexed by (scale, shear/orientation, position) 2. Orientation Encoder: CNN on shearlet coefficient tensor - Processes (scale, orientation) as channels - Learns compact latent representation 3. Decoder: Inverse shearlet transform + optional refinement network Inspired by: - ShearletX (Kolek et al., 2023) - shearlet domain explanations - SAD (Iinbor et al., 2025) - anisotropic spatial representation - Kakeya Conjecture (Wang & Zahl, 2025) - geometric measure theory """ import sys sys.path.insert(0, '/app/PyShearlets') import torch import torch.nn as nn import torch.nn.functional as F import numpy as np from typing import Tuple, Optional, List from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra class ShearletTransform: """Wrapper for PyShearlets FFST to work with PyTorch tensors.""" def __init__(self, num_scales: int = 4, real_coefficients: bool = True): self.num_scales = num_scales self.real_coefficients = real_coefficients self._psi_cache = {} def _get_psi(self, shape: Tuple[int, int]) -> np.ndarray: """Cache shearlet spectra for given shape.""" key = (shape, self.num_scales, self.real_coefficients) if key not in self._psi_cache: self._psi_cache[key] = scalesShearsAndSpectra( shape, numOfScales=self.num_scales, realCoefficients=self.real_coefficients, ) return self._psi_cache[key] def transform(self, image: torch.Tensor) -> torch.Tensor: """ Apply shearlet transform to batch of images. Args: image: (B, C, H, W) or (B, 1, H, W) or (B, H, W) Returns: coeffs: (B, num_shearlets, H, W) tensor """ # Handle different input shapes if image.dim() == 3: image = image.unsqueeze(1) # (B, 1, H, W) B, C, H, W = image.shape # For multi-channel images, average to grayscale for shearlet transform if C > 1: img_gray = image.mean(dim=1) # (B, H, W) else: img_gray = image.squeeze(1) # (B, H, W) # Ensure square and odd-sized for best results with FFST target_size = max(H, W) if target_size % 2 == 0: target_size += 1 # Pad if needed pad_h = target_size - H pad_w = target_size - W if pad_h > 0 or pad_w > 0: img_padded = F.pad(img_gray, (0, pad_w, 0, pad_h), mode='reflect') else: img_padded = img_gray # Get Psi for this shape psi = self._get_psi((target_size, target_size)) # Process each image in batch coeffs_list = [] for b in range(B): img_np = img_padded[b].cpu().numpy() coeffs, _ = shearletTransformSpect( img_np, Psi=psi, realCoefficients=self.real_coefficients, ) # coeffs shape: (H, W, num_shearlets) coeffs_tensor = torch.from_numpy(coeffs).float() # Crop back to original size if padded if pad_h > 0 or pad_w > 0: coeffs_tensor = coeffs_tensor[:H, :W, :] # Transpose to (num_shearlets, H, W) coeffs_tensor = coeffs_tensor.permute(2, 0, 1) coeffs_list.append(coeffs_tensor) return torch.stack(coeffs_list, dim=0) # (B, num_shearlets, H, W) def inverse(self, coeffs: torch.Tensor, original_shape: Tuple[int, int]) -> torch.Tensor: """ Apply inverse shearlet transform. Args: coeffs: (B, num_shearlets, H, W) tensor original_shape: (H, W) of output image Returns: image: (B, H, W) tensor """ B, _, H, W = coeffs.shape target_size = max(H, W) if target_size % 2 == 0: target_size += 1 psi = self._get_psi((target_size, target_size)) images = [] for b in range(B): # Permute to (H, W, num_shearlets) c_np = coeffs[b].permute(1, 2, 0).cpu().numpy() # Pad if needed to match psi shape if H < target_size or W < target_size: padded = np.zeros((target_size, target_size, c_np.shape[2])) padded[:H, :W, :] = c_np c_np = padded img = inverseShearletTransformSpect(c_np, Psi=psi) img = torch.from_numpy(img).float() # Crop back img = img[:original_shape[0], :original_shape[1]] images.append(img) return torch.stack(images, dim=0) # (B, H, W) @property def num_shearlets(self) -> int: """Number of shearlet coefficients per pixel.""" # Formula: 1 (lowpass) + sum(2^(j+2) for j in range(num_scales)) # = 1 + 4 + 8 + ... + 2^(num_scales+1) # = 1 + 4*(2^num_scales - 1) return 1 + 4 * (2**self.num_scales - 1) class KakeyaOrientationEncoder(nn.Module): """ Kakeya-inspired orientation encoder using shearlet transform. The encoder maps an image to a compact latent representation that preserves directional/edge information via shearlet coefficients. """ def __init__( self, input_channels: int = 1, latent_dim: int = 256, num_scales: int = 3, base_channels: int = 32, ): super().__init__() self.input_channels = input_channels self.latent_dim = latent_dim self.num_scales = num_scales self.base_channels = base_channels # Shearlet transform (non-differentiable, used as feature extractor) self.shearlet = ShearletTransform(num_scales=num_scales) # Number of shearlet channels n_sh = self.shearlet.num_shearlets # Encoder: process shearlet coefficients # Input: (B, n_sh, H, W) self.enc_conv1 = nn.Conv2d(n_sh, base_channels * 2, 3, padding=1) self.enc_bn1 = nn.BatchNorm2d(base_channels * 2) self.enc_conv2 = nn.Conv2d(base_channels * 2, base_channels * 4, 3, stride=2, padding=1) self.enc_bn2 = nn.BatchNorm2d(base_channels * 4) self.enc_conv3 = nn.Conv2d(base_channels * 4, base_channels * 8, 3, stride=2, padding=1) self.enc_bn3 = nn.BatchNorm2d(base_channels * 8) self.enc_conv4 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1) self.enc_bn4 = nn.BatchNorm2d(base_channels * 8) # Adaptive pooling to fixed size self.global_pool = nn.AdaptiveAvgPool2d(1) # Latent projection self.fc_mu = nn.Linear(base_channels * 8, latent_dim) self.fc_logvar = nn.Linear(base_channels * 8, latent_dim) # Orientation-specific branch # Aggregate shearlet energy per orientation direction self.orientation_fc = nn.Sequential( nn.Linear(n_sh, 64), nn.ReLU(), nn.Linear(64, 32), ) self.orientation_proj = nn.Linear(32, latent_dim) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) def extract_orientation_features(self, shearlet_coeffs: torch.Tensor) -> torch.Tensor: """ Extract orientation histogram from shearlet coefficients. Args: shearlet_coeffs: (B, n_sh, H, W) Returns: orientation_features: (B, latent_dim) - directional energy profile """ B, n_sh, H, W = shearlet_coeffs.shape # Compute energy per shearlet channel (average over spatial dimensions) energy = shearlet_coeffs.pow(2).mean(dim=[2, 3]) # (B, n_sh) # Normalize energy = energy / (energy.sum(dim=1, keepdim=True) + 1e-8) # Process through orientation MLP orient_feat = self.orientation_fc(energy) # (B, 32) orient_feat = self.orientation_proj(orient_feat) # (B, latent_dim) return orient_feat def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Encode image to latent representation. Args: x: (B, C, H, W) input image Returns: mu: (B, latent_dim) mean of latent distribution logvar: (B, latent_dim) log variance of latent distribution orientation_features: (B, latent_dim) orientation-preserving features """ # Shearlet transform (non-differentiable preprocessing) with torch.no_grad(): shearlet_coeffs = self.shearlet.transform(x) # Spatial encoding path h = F.relu(self.enc_bn1(self.enc_conv1(shearlet_coeffs))) h = F.relu(self.enc_bn2(self.enc_conv2(h))) h = F.relu(self.enc_bn3(self.enc_conv3(h))) h = F.relu(self.enc_bn4(self.enc_conv4(h))) h = self.global_pool(h).view(h.size(0), -1) # (B, base_channels*8) mu = self.fc_mu(h) logvar = self.fc_logvar(h) # Orientation features orientation_features = self.extract_orientation_features(shearlet_coeffs) return mu, logvar, orientation_features def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: """Reparameterization trick for VAE sampling.""" std = torch.exp(0.5 * logvar) eps = torch.randn_like(std) return mu + eps * std def encode(self, x: torch.Tensor) -> torch.Tensor: """Get deterministic latent representation.""" mu, logvar, orient = self.forward(x) # Combine spatial and orientation features z = self.reparameterize(mu, logvar) + 0.1 * orient return z class KakeyaOrientationDecoder(nn.Module): """Decoder that reconstructs image from latent representation.""" def __init__( self, latent_dim: int = 256, output_channels: int = 1, base_channels: int = 32, output_size: int = 128, ): super().__init__() self.latent_dim = latent_dim self.output_size = output_size # Calculate bottleneck size based on output size # After 3 stride-2 convs in encoder: output_size // 8 self.bottleneck_size = output_size // 8 if self.bottleneck_size < 4: self.bottleneck_size = 4 self.fc = nn.Linear(latent_dim, base_channels * 8 * self.bottleneck_size * self.bottleneck_size) # Transposed convolutions to upsample self.dec_conv1 = nn.ConvTranspose2d(base_channels * 8, base_channels * 8, 4, stride=2, padding=1) self.dec_bn1 = nn.BatchNorm2d(base_channels * 8) self.dec_conv2 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, 4, stride=2, padding=1) self.dec_bn2 = nn.BatchNorm2d(base_channels * 4) self.dec_conv3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, stride=2, padding=1) self.dec_bn3 = nn.BatchNorm2d(base_channels * 2) self.dec_conv4 = nn.Conv2d(base_channels * 2, output_channels, 3, padding=1) self._init_weights() def _init_weights(self): for m in self.modules(): if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.normal_(m.weight, 0, 0.01) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, z: torch.Tensor) -> torch.Tensor: """ Decode latent vector to image. Args: z: (B, latent_dim) Returns: image: (B, output_channels, output_size, output_size) """ h = self.fc(z) h = h.view(h.size(0), -1, self.bottleneck_size, self.bottleneck_size) h = F.relu(self.dec_bn1(self.dec_conv1(h))) h = F.relu(self.dec_bn2(self.dec_conv2(h))) h = F.relu(self.dec_bn3(self.dec_conv3(h))) # Resize to exact output size if needed if h.shape[2] != self.output_size or h.shape[3] != self.output_size: h = F.interpolate(h, size=(self.output_size, self.output_size), mode='bilinear', align_corners=False) x_recon = torch.sigmoid(self.dec_conv4(h)) return x_recon class KakeyaAutoencoder(nn.Module): """Complete Kakeya-inspired autoencoder for orientation-preserving encoding.""" def __init__( self, input_channels: int = 1, latent_dim: int = 256, num_scales: int = 3, base_channels: int = 32, output_size: int = 128, ): super().__init__() self.encoder = KakeyaOrientationEncoder( input_channels=input_channels, latent_dim=latent_dim, num_scales=num_scales, base_channels=base_channels, ) self.decoder = KakeyaOrientationDecoder( latent_dim=latent_dim, output_channels=input_channels, base_channels=base_channels, output_size=output_size, ) self.latent_dim = latent_dim def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Full forward pass. Returns: x_recon: reconstructed image mu: latent mean logvar: latent log variance orient_features: orientation features """ mu, logvar, orient_features = self.encoder(x) z = self.encoder.reparameterize(mu, logvar) + 0.1 * orient_features x_recon = self.decoder(z) return x_recon, mu, logvar, orient_features def encode(self, x: torch.Tensor) -> torch.Tensor: """Get latent representation.""" return self.encoder.encode(x) def decode(self, z: torch.Tensor) -> torch.Tensor: """Reconstruct from latent.""" return self.decoder(z) def kakeya_loss( x_recon: torch.Tensor, x_target: torch.Tensor, mu: torch.Tensor, logvar: torch.Tensor, orient_features: torch.Tensor, shearlet_coeffs_target: Optional[torch.Tensor] = None, recon_weight: float = 1.0, kl_weight: float = 0.001, orient_weight: float = 0.1, shearlet_weight: float = 0.5, ) -> Tuple[torch.Tensor, dict]: """ Combined loss for Kakeya autoencoder. Components: 1. Reconstruction loss (MSE) 2. KL divergence (VAE regularization) 3. Orientation preservation loss (encourage latent to capture directions) 4. Shearlet coefficient matching (optional, for orientation fidelity) """ # Reconstruction loss recon_loss = F.mse_loss(x_recon, x_target, reduction='mean') # KL divergence kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean() # Orientation preservation: encourage orientation features to be non-zero # and well-distributed (information content) orient_entropy = -torch.sum( F.softmax(orient_features, dim=1) * F.log_softmax(orient_features, dim=1), dim=1 ).mean() orient_loss = -orient_entropy # Maximize entropy = spread information # Shearlet coefficient matching (if provided) shearlet_loss = torch.tensor(0.0, device=x_recon.device) if shearlet_coeffs_target is not None: # Compute shearlet coeffs of reconstruction from FFST import shearletTransformSpect # This is expensive, only do it occasionally # Instead: use gradient-based approximation via edge detection pass total_loss = ( recon_weight * recon_loss + kl_weight * kl_loss + orient_weight * orient_loss + shearlet_weight * shearlet_loss ) metrics = { 'recon_loss': recon_loss.item(), 'kl_loss': kl_loss.item(), 'orient_loss': orient_loss.item(), 'total_loss': total_loss.item(), } return total_loss, metrics if __name__ == "__main__": # Quick test model = KakeyaAutoencoder(input_channels=1, latent_dim=256, num_scales=3, output_size=128) # Test with random input x = torch.randn(2, 1, 128, 128) x_recon, mu, logvar, orient = model(x) print(f"Input shape: {x.shape}") print(f"Reconstructed shape: {x_recon.shape}") print(f"Latent mu shape: {mu.shape}") print(f"Orientation features shape: {orient.shape}") # Test encode/decode z = model.encode(x) print(f"Encoded latent shape: {z.shape}") x_decoded = model.decode(z) print(f"Decoded shape: {x_decoded.shape}") print("\nModel test passed!")