| """ |
| Kakeya-Inspired Orientation Image Encoder |
| ========================================= |
| |
The Kakeya conjecture concerns the minimum size (Hausdorff/Minkowski
dimension) of sets containing a unit line segment in every direction;
Wang & Zahl (2025) proved it in three dimensions. The loose analogy applied
here is that directional information across all orientations can be
captured efficiently with sparse, anisotropic sampling.
| |
| This encoder uses the shearlet transform as the directional decomposition |
| front-end, capturing multi-scale, multi-orientation information. The shearlet |
| coefficients are then processed by a lightweight neural network to produce a |
| compact latent representation that preserves orientation information. |
| |
| Architecture: |
| 1. Directional Decomposition: Fast Finite Shearlet Transform (FFST) |
| - Anisotropic dilation A_a = diag(a, sqrt(a)) |
| - Shearing S_s = [[1, s], [0, 1]] |
| - Coefficients indexed by (scale, shear/orientation, position) |
| |
| 2. Orientation Encoder: CNN on shearlet coefficient tensor |
| - Processes (scale, orientation) as channels |
| - Learns compact latent representation |
| |
3. Decoder: transposed-convolution network mapping the latent vector back
   to image space (ShearletTransform.inverse is available but is not used
   by the decoder)
| |
| Inspired by: |
| - ShearletX (Kolek et al., 2023) - shearlet domain explanations |
| - SAD (Iinbor et al., 2025) - anisotropic spatial representation |
| - Kakeya Conjecture (Wang & Zahl, 2025) - geometric measure theory |
| """ |
|
|
| import sys |
| sys.path.insert(0, '/app/PyShearlets') |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Tuple, Optional, List |
| from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra |
|
|
|
|
class ShearletTransform:
    """Wrapper around the PyShearlets FFST routines for PyTorch tensors.

    The FFST functions work on single 2D numpy arrays, so this wrapper
    loops over the batch, converts to/from numpy, and caches the shearlet
    spectra (``Psi``) per spatial shape so they are only built once.
    Results are returned on the device of the input tensor.
    """

    def __init__(self, num_scales: int = 4, real_coefficients: bool = True):
        self.num_scales = num_scales
        self.real_coefficients = real_coefficients
        # Maps (shape, num_scales, real_coefficients) -> Psi spectra array.
        self._psi_cache = {}

    def _get_psi(self, shape: Tuple[int, int]) -> np.ndarray:
        """Return (and cache) the shearlet spectra for a given 2D shape."""
        key = (shape, self.num_scales, self.real_coefficients)
        if key not in self._psi_cache:
            self._psi_cache[key] = scalesShearsAndSpectra(
                shape,
                numOfScales=self.num_scales,
                realCoefficients=self.real_coefficients,
            )
        return self._psi_cache[key]

    @staticmethod
    def _target_size(height: int, width: int) -> int:
        """Side of the square grid the FFST runs on: max(H, W), bumped to odd.

        ``transform`` and ``inverse`` must agree on this value so that they
        use the same cached spectra.
        """
        size = max(height, width)
        return size + 1 if size % 2 == 0 else size

    @staticmethod
    def _pad_to_square(img: torch.Tensor, size: int) -> torch.Tensor:
        """Pad a (B, H, W) batch on the bottom/right to (B, size, size).

        Reflect padding is preferred. PyTorch requires each reflect pad to be
        smaller than the corresponding input dimension, so strongly
        non-square inputs fall back to replicate padding instead of raising.
        """
        _, height, width = img.shape
        pad_h = size - height
        pad_w = size - width
        if pad_h <= 0 and pad_w <= 0:
            return img
        mode = 'reflect' if (pad_h < height and pad_w < width) else 'replicate'
        # Add a channel axis so F.pad sees a standard (N, C, H, W) input.
        return F.pad(img.unsqueeze(1), (0, pad_w, 0, pad_h), mode=mode).squeeze(1)

    def transform(self, image: torch.Tensor) -> torch.Tensor:
        """
        Apply the shearlet transform to a batch of images.

        Multi-channel inputs are averaged to a single grayscale plane. Each
        image is padded to a square odd-sized grid, transformed with FFST,
        and the coefficients are cropped back to the original H x W.

        Args:
            image: (B, C, H, W), (B, 1, H, W) or (B, H, W) tensor.

        Returns:
            coeffs: (B, num_shearlets, H, W) float32 tensor on the same
            device as ``image``.
        """
        if image.dim() == 3:
            image = image.unsqueeze(1)

        B, C, H, W = image.shape
        if B == 0:
            # Nothing to transform; torch.stack would fail on an empty list.
            return torch.zeros((0, self.num_shearlets, H, W), device=image.device)

        # The FFST runs in numpy: detach so inputs that require grad can be
        # converted, then reduce to one grayscale plane per image.
        img_gray = image.detach()
        img_gray = img_gray.mean(dim=1) if C > 1 else img_gray.squeeze(1)

        target_size = self._target_size(H, W)
        img_padded = self._pad_to_square(img_gray, target_size)
        psi = self._get_psi((target_size, target_size))

        coeffs_list = []
        for b in range(B):
            coeffs, _ = shearletTransformSpect(
                img_padded[b].cpu().numpy(),
                Psi=psi,
                realCoefficients=self.real_coefficients,
            )
            # Coefficients are indexed (row, col, shearlet): crop away the
            # padding, then move the shearlet axis first.
            coeffs_tensor = torch.from_numpy(coeffs).float()[:H, :W, :]
            coeffs_list.append(coeffs_tensor.permute(2, 0, 1))

        # Return on the caller's device so downstream layers (e.g. CUDA
        # convolutions) receive tensors where they expect them.
        return torch.stack(coeffs_list, dim=0).to(image.device)

    def inverse(self, coeffs: torch.Tensor, original_shape: Tuple[int, int]) -> torch.Tensor:
        """
        Apply the inverse shearlet transform.

        The coefficients are zero-padded to the same square odd-sized grid
        used by ``transform``. The forward pass padded the image (not with
        zeros) before transforming, so reconstruction is presumably only
        approximate when padding was needed.

        Args:
            coeffs: (B, num_shearlets, H, W) tensor.
            original_shape: (H, W) of the output image.

        Returns:
            image: (B, H, W) float32 tensor on the same device as ``coeffs``.
        """
        B, n_sh, H, W = coeffs.shape
        target_size = self._target_size(H, W)
        psi = self._get_psi((target_size, target_size))

        images = []
        for b in range(B):
            # (num_shearlets, H, W) -> (H, W, num_shearlets) for FFST.
            c_np = coeffs[b].detach().permute(1, 2, 0).cpu().numpy()

            if H < target_size or W < target_size:
                padded = np.zeros((target_size, target_size, n_sh))
                padded[:H, :W, :] = c_np
                c_np = padded

            img = torch.from_numpy(inverseShearletTransformSpect(c_np, Psi=psi)).float()
            images.append(img[:original_shape[0], :original_shape[1]])

        return torch.stack(images, dim=0).to(coeffs.device)

    @property
    def num_shearlets(self) -> int:
        """Number of shearlet coefficient channels per pixel.

        One low-pass channel plus 2**(j + 2) channels for each scale
        j = 0 .. num_scales - 1, i.e. 1 + 4 * (2**num_scales - 1).
        """
        return 1 + 4 * (2**self.num_scales - 1)
|
|
|
|
class KakeyaOrientationEncoder(nn.Module):
    """
    Kakeya-inspired orientation encoder using the shearlet transform.

    The input is decomposed by a fixed (non-learned) shearlet transform.
    The resulting (B, num_shearlets, H, W) coefficient tensor is fed to a
    small CNN that produces VAE posterior parameters (mu, logvar). A
    normalized per-shearlet energy profile is mapped separately to an
    orientation feature vector.
    """

    def __init__(
        self,
        input_channels: int = 1,
        latent_dim: int = 256,
        num_scales: int = 3,
        base_channels: int = 32,
    ):
        super().__init__()

        # NOTE: input_channels is recorded but does not size any layer; the
        # shearlet front-end averages multi-channel input to grayscale.
        self.input_channels = input_channels
        self.latent_dim = latent_dim
        self.num_scales = num_scales
        self.base_channels = base_channels

        # Fixed directional front-end. It is not an nn.Module, so it holds no
        # parameters and is not moved by .to()/.cuda().
        self.shearlet = ShearletTransform(num_scales=num_scales)
        n_sh = self.shearlet.num_shearlets

        # Conv stack: one full-resolution layer, then three stride-2 layers
        # (roughly 8x spatial downsampling overall).
        self.enc_conv1 = nn.Conv2d(n_sh, base_channels * 2, 3, padding=1)
        self.enc_bn1 = nn.BatchNorm2d(base_channels * 2)

        self.enc_conv2 = nn.Conv2d(base_channels * 2, base_channels * 4, 3, stride=2, padding=1)
        self.enc_bn2 = nn.BatchNorm2d(base_channels * 4)

        self.enc_conv3 = nn.Conv2d(base_channels * 4, base_channels * 8, 3, stride=2, padding=1)
        self.enc_bn3 = nn.BatchNorm2d(base_channels * 8)

        self.enc_conv4 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
        self.enc_bn4 = nn.BatchNorm2d(base_channels * 8)

        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # VAE heads on the pooled feature vector.
        self.fc_mu = nn.Linear(base_channels * 8, latent_dim)
        self.fc_logvar = nn.Linear(base_channels * 8, latent_dim)

        # Orientation head: per-shearlet energy (n_sh) -> 64 -> 32 -> latent_dim.
        self.orientation_fc = nn.Sequential(
            nn.Linear(n_sh, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.orientation_proj = nn.Linear(32, latent_dim)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init convs, unit/zero BatchNorm, small-normal Linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def extract_orientation_features(self, shearlet_coeffs: torch.Tensor) -> torch.Tensor:
        """
        Map the directional energy profile of shearlet coefficients to features.

        Args:
            shearlet_coeffs: (B, n_sh, H, W) coefficient tensor.

        Returns:
            orientation_features: (B, latent_dim) projection of the per-shearlet
            mean energy, normalized to sum to 1 per sample.
        """
        # Mean squared coefficient per shearlet channel -> (B, n_sh).
        energy = shearlet_coeffs.pow(2).mean(dim=[2, 3])
        # Normalize to a distribution over shearlets; epsilon guards all-zero input.
        energy = energy / (energy.sum(dim=1, keepdim=True) + 1e-8)

        orient_feat = self.orientation_fc(energy)
        return self.orientation_proj(orient_feat)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encode an image batch.

        Args:
            x: (B, C, H, W) input image.

        Returns:
            mu: (B, latent_dim) mean of the latent distribution.
            logvar: (B, latent_dim) log variance of the latent distribution.
            orientation_features: (B, latent_dim) orientation features.
        """
        # The shearlet front-end is fixed, so no gradient is tracked through it.
        with torch.no_grad():
            shearlet_coeffs = self.shearlet.transform(x)
        # The transform goes through numpy and may hand back CPU float32
        # tensors; match the conv weights so GPU / non-float32 models work.
        weight = self.enc_conv1.weight
        shearlet_coeffs = shearlet_coeffs.to(device=weight.device, dtype=weight.dtype)

        h = F.relu(self.enc_bn1(self.enc_conv1(shearlet_coeffs)))
        h = F.relu(self.enc_bn2(self.enc_conv2(h)))
        h = F.relu(self.enc_bn3(self.enc_conv3(h)))
        h = F.relu(self.enc_bn4(self.enc_conv4(h)))

        h = self.global_pool(h).view(h.size(0), -1)

        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        orientation_features = self.extract_orientation_features(shearlet_coeffs)

        return mu, logvar, orientation_features

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Reparameterization trick: return mu + eps * exp(0.5 * logvar), eps ~ N(0, I)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x: torch.Tensor, sample: bool = False) -> torch.Tensor:
        """
        Return a latent code ``base + 0.1 * orientation_features`` for ``x``.

        Args:
            x: (B, C, H, W) input image.
            sample: if False (default), ``base`` is the posterior mean ``mu``,
                so the result is deterministic. If True, ``base`` is a
                reparameterized sample from N(mu, exp(logvar)).

        Returns:
            z: (B, latent_dim) latent code.
        """
        mu, logvar, orient = self.forward(x)
        base = self.reparameterize(mu, logvar) if sample else mu
        return base + 0.1 * orient
|
|
|
|
class KakeyaOrientationDecoder(nn.Module):
    """Map a latent vector back to an image with a transposed-conv stack.

    The latent is projected to a ``base_channels * 8`` feature map of side
    ``max(output_size // 8, 4)``. It is upsampled 2x three times, resized to
    ``output_size`` if the sizes do not line up, and passed through a final
    3x3 conv and a sigmoid, so outputs lie in [0, 1].
    """

    def __init__(
        self,
        latent_dim: int = 256,
        output_channels: int = 1,
        base_channels: int = 32,
        output_size: int = 128,
    ):
        super().__init__()

        self.latent_dim = latent_dim
        self.output_size = output_size
        # Three stride-2 upsamplings follow, so start at 1/8 of the target
        # resolution, but never below 4x4.
        self.bottleneck_size = max(output_size // 8, 4)

        widest = base_channels * 8
        side = self.bottleneck_size
        self.fc = nn.Linear(latent_dim, widest * side * side)

        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles H and W.
        self.dec_conv1 = nn.ConvTranspose2d(widest, widest, 4, stride=2, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(widest)
        self.dec_conv2 = nn.ConvTranspose2d(widest, base_channels * 4, 4, stride=2, padding=1)
        self.dec_bn2 = nn.BatchNorm2d(base_channels * 4)
        self.dec_conv3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, stride=2, padding=1)
        self.dec_bn3 = nn.BatchNorm2d(base_channels * 2)

        # Output projection to the requested number of image channels.
        self.dec_conv4 = nn.Conv2d(base_channels * 2, output_channels, 3, padding=1)

        self._init_weights()

    def _init_weights(self):
        """Kaiming-init (transposed) convs, unit/zero BatchNorm, small-normal Linear."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
            else:
                continue
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """
        Decode latent vectors into images.

        Args:
            z: (B, latent_dim) latent batch.

        Returns:
            image: (B, output_channels, output_size, output_size), values in [0, 1].
        """
        side = self.bottleneck_size
        projected = self.fc(z)
        h = projected.view(projected.size(0), -1, side, side)

        stages = (
            (self.dec_conv1, self.dec_bn1),
            (self.dec_conv2, self.dec_bn2),
            (self.dec_conv3, self.dec_bn3),
        )
        for upsample, norm in stages:
            h = F.relu(norm(upsample(h)))

        # Sizes that are not 8 * bottleneck (or hit the 4x4 floor) are resized.
        target = (self.output_size, self.output_size)
        if tuple(h.shape[2:]) != target:
            h = F.interpolate(h, size=target, mode='bilinear', align_corners=False)

        return torch.sigmoid(self.dec_conv4(h))
|
|
|
|
class KakeyaAutoencoder(nn.Module):
    """Shearlet-based encoder and transposed-conv decoder wired together as a VAE.

    The decoder emits ``input_channels`` channels at
    ``output_size`` x ``output_size``; the latent fed to it is a posterior
    sample plus 0.1 times the encoder's orientation features.
    """

    def __init__(
        self,
        input_channels: int = 1,
        latent_dim: int = 256,
        num_scales: int = 3,
        base_channels: int = 32,
        output_size: int = 128,
    ):
        super().__init__()

        shared = dict(latent_dim=latent_dim, base_channels=base_channels)
        self.encoder = KakeyaOrientationEncoder(
            input_channels=input_channels,
            num_scales=num_scales,
            **shared,
        )
        self.decoder = KakeyaOrientationDecoder(
            output_channels=input_channels,
            output_size=output_size,
            **shared,
        )
        self.latent_dim = latent_dim

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Run encoder, sample the latent, and decode.

        Returns:
            x_recon: reconstructed image.
            mu: latent mean.
            logvar: latent log variance.
            orient_features: orientation features from the encoder.
        """
        mu, logvar, orient_features = self.encoder(x)
        sampled = self.encoder.reparameterize(mu, logvar)
        # Blend a small fixed fraction of the orientation features into the latent.
        latent = sampled + 0.1 * orient_features
        x_recon = self.decoder(latent)
        return x_recon, mu, logvar, orient_features

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the encoder's ``encode`` method."""
        return self.encoder.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Run the decoder on a latent batch."""
        return self.decoder(z)
|
|
|
|
def kakeya_loss(
    x_recon: torch.Tensor,
    x_target: torch.Tensor,
    mu: torch.Tensor,
    logvar: torch.Tensor,
    orient_features: torch.Tensor,
    shearlet_coeffs_target: Optional[torch.Tensor] = None,
    recon_weight: float = 1.0,
    kl_weight: float = 0.001,
    orient_weight: float = 0.1,
    shearlet_weight: float = 0.5,
) -> Tuple[torch.Tensor, dict]:
    """
    Combined loss for the Kakeya autoencoder.

    Components:
        1. Reconstruction: mean-squared error between ``x_recon`` and ``x_target``.
        2. KL divergence of N(mu, exp(logvar)) from N(0, I), summed over the
           latent dimension and averaged over the batch.
        3. Orientation: the negative mean entropy of ``softmax(orient_features)``
           over dim 1. Minimizing it pushes the orientation features toward a
           high-entropy (spread-out) softmax distribution.
        4. Shearlet coefficient matching: not implemented. Passing
           ``shearlet_coeffs_target`` emits a warning and the term stays 0.

    Returns:
        total_loss: weighted sum of the components (scalar tensor).
        metrics: dict of Python floats with keys 'recon_loss', 'kl_loss',
            'orient_loss', 'shearlet_loss' and 'total_loss'.
    """
    import warnings

    recon_loss = F.mse_loss(x_recon, x_target, reduction='mean')

    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()

    orient_entropy = -torch.sum(
        F.softmax(orient_features, dim=1) * F.log_softmax(orient_features, dim=1),
        dim=1,
    ).mean()
    orient_loss = -orient_entropy

    shearlet_loss = torch.tensor(0.0, device=x_recon.device)
    if shearlet_coeffs_target is not None:
        # ShearletTransform in this module goes through numpy, so there is no
        # differentiable way to get the reconstruction's coefficients here.
        # Warn rather than silently ignoring the argument.
        warnings.warn(
            "kakeya_loss: shearlet coefficient matching is not implemented; "
            "shearlet_coeffs_target is ignored and shearlet_loss is 0.",
            stacklevel=2,
        )

    total_loss = (
        recon_weight * recon_loss
        + kl_weight * kl_loss
        + orient_weight * orient_loss
        + shearlet_weight * shearlet_loss
    )

    metrics = {
        'recon_loss': recon_loss.item(),
        'kl_loss': kl_loss.item(),
        'orient_loss': orient_loss.item(),
        'shearlet_loss': shearlet_loss.item(),
        'total_loss': total_loss.item(),
    }

    return total_loss, metrics
|
|
|
|
if __name__ == "__main__":
    # Smoke test: build the model, run it end to end, and check output shapes.
    torch.manual_seed(0)
    model = KakeyaAutoencoder(input_channels=1, latent_dim=256, num_scales=3, output_size=128)

    # The decoder ends in a sigmoid, so use inputs in [0, 1] to match its range.
    x = torch.rand(2, 1, 128, 128)
    x_recon, mu, logvar, orient = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Reconstructed shape: {x_recon.shape}")
    print(f"Latent mu shape: {mu.shape}")
    print(f"Orientation features shape: {orient.shape}")

    assert x_recon.shape == x.shape
    assert mu.shape == logvar.shape == orient.shape == (2, 256)

    # Check that the loss is computable and gradients flow through the model.
    loss, metrics = kakeya_loss(x_recon, x, mu, logvar, orient)
    loss.backward()
    print(f"Loss metrics: {metrics}")

    # Inference path: use running BatchNorm stats and skip autograd.
    model.eval()
    with torch.no_grad():
        z = model.encode(x)
        print(f"Encoded latent shape: {z.shape}")

        x_decoded = model.decode(z)
        print(f"Decoded shape: {x_decoded.shape}")

    assert z.shape == (2, 256)
    assert x_decoded.shape == x.shape

    print("\nModel test passed!")
|
|