"""Self-contained test script for the Kakeya Orientation Encoder.

Run:
    python test_kakeya_encoder.py

On import the script makes sure the ``FFST`` package (PyShearlets) can be
imported.  If it cannot, a local ``PyShearlets`` checkout in the current
working directory is used, cloning it from GitHub first when it is missing.
"""
import os
import subprocess
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# -----------------------------------------------------------------------
# 1. Ensure PyShearlets is available
# -----------------------------------------------------------------------
_PYSHEARLETS_DIR = "PyShearlets"
_PYSHEARLETS_URL = "https://github.com/grlee77/PyShearlets.git"

# NumPy 2.0 compatibility patch: FFST is given an ``np.NaN`` attribute, which
# NumPy 2.0 no longer provides.  The alias is installed up front so that an
# already-installed FFST is patched as well, not only a freshly cloned one.
if not hasattr(np, "NaN"):
    np.NaN = np.nan

try:
    from FFST import (
        inverseShearletTransformSpect,
        scalesShearsAndSpectra,
        shearletTransformSpect,
    )
except ImportError:
    # Clone only when no checkout exists yet: ``git clone`` refuses to write
    # into an existing non-empty directory, which would make every run after
    # the first one fail with CalledProcessError.
    if not os.path.isdir(_PYSHEARLETS_DIR):
        print("PyShearlets not found. Cloning and patching for NumPy 2.0...")
        subprocess.run(
            ["git", "clone", _PYSHEARLETS_URL, _PYSHEARLETS_DIR],
            check=True,
        )
    sys.path.insert(0, _PYSHEARLETS_DIR)
    from FFST import (
        inverseShearletTransformSpect,
        scalesShearsAndSpectra,
        shearletTransformSpect,
    )

# Weight of the orientation embedding when it is added to the sampled latent.
_ORIENTATION_WEIGHT = 0.1


# -----------------------------------------------------------------------
# Shearlet Transform wrapper
# -----------------------------------------------------------------------
class ShearletTransform:
    """Batch wrapper around FFST's ``shearletTransformSpect``.

    Args:
        num_scales: Value forwarded to FFST as ``numOfScales``.
        real_coefficients: Value forwarded to FFST as ``realCoefficients``.
    """

    def __init__(self, num_scales: int = 3, real_coefficients: bool = True) -> None:
        self.num_scales = num_scales
        self.real_coefficients = real_coefficients
        # Shearlet spectra are reused across calls, keyed by (shape, num_scales).
        self._psi_cache = {}

    def _get_psi(self, shape):
        """Return the cached shearlet spectra for ``shape``, building them once."""
        key = (shape, self.num_scales)
        if key not in self._psi_cache:
            self._psi_cache[key] = scalesShearsAndSpectra(
                shape,
                numOfScales=self.num_scales,
                realCoefficients=self.real_coefficients,
            )
        return self._psi_cache[key]

    def transform(self, image: torch.Tensor) -> torch.Tensor:
        """Compute shearlet coefficients for a batch of images.

        Args:
            image: Tensor of shape ``(B, C, H, W)`` or ``(B, H, W)``.  Inputs
                with more than one channel are averaged over the channel
                dimension before the transform.

        Returns:
            A float32 CPU tensor of shape ``(B, N, H, W)`` where ``N`` is the
            size of the last axis of the FFST coefficient array.
        """
        if image.dim() == 3:
            image = image.unsqueeze(1)
        B, C, H, W = image.shape
        img_gray = image.mean(dim=1) if C > 1 else image.squeeze(1)

        # The transform is run on a square of odd side length.
        # NOTE(review): presumably a constraint of the FFST spectra - confirm.
        target_size = max(H, W)
        if target_size % 2 == 0:
            target_size += 1
        pad_h, pad_w = target_size - H, target_size - W
        needs_pad = pad_h > 0 or pad_w > 0

        # NOTE: reflect padding needs pad_h < H and pad_w < W, so strongly
        # non-square inputs are rejected by F.pad.
        img_padded = (
            F.pad(img_gray, (0, pad_w, 0, pad_h), mode='reflect')
            if needs_pad
            else img_gray
        )

        psi = self._get_psi((target_size, target_size))

        coeffs_list = []
        for b in range(B):
            # detach() lets inputs that require grad be converted to NumPy.
            coeffs, _ = shearletTransformSpect(
                img_padded[b].detach().cpu().numpy(),
                Psi=psi,
                realCoefficients=self.real_coefficients,
            )
            ct = torch.from_numpy(coeffs).float()
            if needs_pad:
                # Crop the padding away again: back to the input's H x W.
                ct = ct[:H, :W, :]
            # (H, W, N) -> (N, H, W)
            coeffs_list.append(ct.permute(2, 0, 1))
        return torch.stack(coeffs_list, dim=0)

    @property
    def num_shearlets(self) -> int:
        """Channel count used to size the encoder: ``1 + 4 * (2**num_scales - 1)``."""
        return 1 + 4 * (2**self.num_scales - 1)


# -----------------------------------------------------------------------
# Encoder / Decoder
# -----------------------------------------------------------------------
class KakeyaOrientationEncoder(nn.Module):
    """Variational encoder over shearlet coefficients with an orientation branch.

    A four-layer convolutional trunk produces ``mu`` and ``logvar``; a small
    MLP on the per-channel energy distribution produces an orientation
    embedding of the same size.

    Args:
        num_shearlets: Number of input channels.
        latent_dim: Size of ``mu``, ``logvar`` and the orientation embedding.
        base_channels: Base width of the convolutional trunk.
    """

    def __init__(self, num_shearlets, latent_dim=256, base_channels=32):
        super().__init__()
        self.enc_conv1 = nn.Conv2d(num_shearlets, base_channels*2, 3, padding=1)
        self.enc_bn1 = nn.BatchNorm2d(base_channels*2)
        self.enc_conv2 = nn.Conv2d(base_channels*2, base_channels*4, 3, stride=2, padding=1)
        self.enc_bn2 = nn.BatchNorm2d(base_channels*4)
        self.enc_conv3 = nn.Conv2d(base_channels*4, base_channels*8, 3, stride=2, padding=1)
        self.enc_bn3 = nn.BatchNorm2d(base_channels*8)
        self.enc_conv4 = nn.Conv2d(base_channels*8, base_channels*8, 3, stride=2, padding=1)
        self.enc_bn4 = nn.BatchNorm2d(base_channels*8)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_mu = nn.Linear(base_channels*8, latent_dim)
        self.fc_logvar = nn.Linear(base_channels*8, latent_dim)
        self.orientation_fc = nn.Sequential(
            nn.Linear(num_shearlets, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.orientation_proj = nn.Linear(32, latent_dim)

    def forward(self, shearlet_coeffs):
        """Return ``(mu, logvar, orient)`` for a ``(B, num_shearlets, H, W)`` input."""
        h = F.relu(self.enc_bn1(self.enc_conv1(shearlet_coeffs)))
        h = F.relu(self.enc_bn2(self.enc_conv2(h)))
        h = F.relu(self.enc_bn3(self.enc_conv3(h)))
        h = F.relu(self.enc_bn4(self.enc_conv4(h)))
        h = self.global_pool(h).view(h.size(0), -1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)

        # Mean squared coefficient per channel, normalised to sum to one
        # across channels (epsilon guards an all-zero input).
        energy = shearlet_coeffs.pow(2).mean(dim=[2, 3])
        energy = energy / (energy.sum(dim=1, keepdim=True) + 1e-8)
        orient = self.orientation_proj(self.orientation_fc(energy))
        return mu, logvar, orient

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, shearlet_coeffs):
        """Return a sampled latent with the weighted orientation embedding added."""
        mu, logvar, orient = self.forward(shearlet_coeffs)
        return self.reparameterize(mu, logvar) + _ORIENTATION_WEIGHT * orient


class SimpleDecoder(nn.Module):
    """Transposed-convolution decoder from a latent vector to an image in (0, 1).

    Args:
        latent_dim: Size of the input latent vector.
        output_channels: Number of channels of the decoded image.
        base_channels: Base width of the decoder.
        output_size: Side length of the (square) decoded image.
    """

    def __init__(self, latent_dim=256, output_channels=1, base_channels=32, output_size=64):
        super().__init__()
        self.output_size = output_size
        # Three stride-2 layers upsample by 8; the bottleneck is never below 4.
        self.bottleneck_size = max(4, output_size // 8)
        self.fc = nn.Linear(
            latent_dim,
            base_channels*8*self.bottleneck_size*self.bottleneck_size,
        )
        self.dec_conv1 = nn.ConvTranspose2d(base_channels*8, base_channels*8, 4, stride=2, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(base_channels*8)
        self.dec_conv2 = nn.ConvTranspose2d(base_channels*8, base_channels*4, 4, stride=2, padding=1)
        self.dec_bn2 = nn.BatchNorm2d(base_channels*4)
        self.dec_conv3 = nn.ConvTranspose2d(base_channels*4, base_channels*2, 4, stride=2, padding=1)
        self.dec_bn3 = nn.BatchNorm2d(base_channels*2)
        self.dec_conv4 = nn.Conv2d(base_channels*2, output_channels, 3, padding=1)

    def forward(self, z):
        """Decode ``z`` of shape ``(B, latent_dim)`` to ``(B, output_channels, S, S)``."""
        h = self.fc(z)
        h = h.view(h.size(0), -1, self.bottleneck_size, self.bottleneck_size)
        h = F.relu(self.dec_bn1(self.dec_conv1(h)))
        h = F.relu(self.dec_bn2(self.dec_conv2(h)))
        h = F.relu(self.dec_bn3(self.dec_conv3(h)))
        # Resize when 8 * bottleneck_size does not equal the requested size.
        if h.shape[2] != self.output_size or h.shape[3] != self.output_size:
            h = F.interpolate(
                h,
                size=(self.output_size, self.output_size),
                mode='bilinear',
                align_corners=False,
            )
        return torch.sigmoid(self.dec_conv4(h))


class KakeyaAutoencoder(nn.Module):
    """Autoencoder pairing ``KakeyaOrientationEncoder`` with ``SimpleDecoder``."""

    def __init__(self, num_shearlets, latent_dim=256, output_channels=1,
                 base_channels=32, output_size=64):
        super().__init__()
        self.encoder = KakeyaOrientationEncoder(num_shearlets, latent_dim, base_channels)
        self.decoder = SimpleDecoder(latent_dim, output_channels, base_channels, output_size)
        self.latent_dim = latent_dim

    def forward(self, shearlet_coeffs):
        """Return ``(reconstruction, mu, logvar, orient)``."""
        mu, logvar, orient = self.encoder(shearlet_coeffs)
        z = self.encoder.reparameterize(mu, logvar) + _ORIENTATION_WEIGHT * orient
        return self.decoder(z), mu, logvar, orient

    def encode(self, shearlet_coeffs):
        """Return the sampled latent produced by the encoder."""
        return self.encoder.encode(shearlet_coeffs)

    def decode(self, z):
        """Decode a latent vector into an image."""
        return self.decoder(z)


# -----------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------
def run_tests() -> None:
    """Run the five smoke tests, printing a short report for each.

    Raises:
        AssertionError: If the shearlet round-trip error is not below 1e-10.
    """
    banner = "=" * 60
    print(banner)
    print("KAKEYA ORIENTATION ENCODER — SELF-CONTAINED TEST")
    print(banner)

    # 1. Shearlet round-trip
    print("\n[Test 1] Shearlet Transform Round-trip")
    img = np.random.rand(64, 64)
    ST, Psi = shearletTransformSpect(img)
    recon = inverseShearletTransformSpect(ST, Psi)
    err = np.max(np.abs(img - recon))
    print(f"  Coeffs shape : {ST.shape}")
    print(f"  Recon error  : {err:.2e}")
    assert err < 1e-10
    print("  PASSED")

    # 2. Model forward pass
    print("\n[Test 2] Model Forward Pass")
    shearlet = ShearletTransform(num_scales=3)
    model = KakeyaAutoencoder(
        num_shearlets=shearlet.num_shearlets,
        latent_dim=128,
        output_channels=1,
        base_channels=16,
        output_size=64,
    )
    x = torch.randn(2, 1, 64, 64)
    with torch.no_grad():
        sh = shearlet.transform(x)
        x_recon, mu, logvar, orient = model(sh)
    print(f"  Input        : {x.shape}")
    print(f"  Shearlet     : {sh.shape}")
    print(f"  Reconstructed: {x_recon.shape}")
    print(f"  Latent mu    : {mu.shape}")
    print(f"  Orientation  : {orient.shape}")
    print("  PASSED")

    # 3. Orientation feature extraction
    print("\n[Test 3] Orientation Feature Extraction")
    with torch.no_grad():
        energy = sh.pow(2).mean(dim=[2, 3])
        print(f"  Energy (first 5) : {energy[0][:5].numpy()}")
        # Entropy is only meaningful for a distribution, so normalise the
        # energies the same way the encoder does before taking it.
        prob = energy[0] / (energy[0].sum() + 1e-8)
        ent = -(prob * torch.log(prob + 1e-8)).sum().item()
        print(f"  Entropy          : {ent:.4f}")
    print("  PASSED")

    # 4. Edge preservation (untrained sanity check)
    print("\n[Test 4] Edge Preservation Sanity Check")
    # 3x3 Sobel kernels for horizontal / vertical gradients.
    sx = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                      dtype=torch.float32).view(1, 1, 3, 3)
    sy = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                      dtype=torch.float32).view(1, 1, 3, 3)
    ex = F.conv2d(x, sx, padding=1)
    ey = F.conv2d(x, sy, padding=1)
    e_orig = torch.sqrt(ex**2 + ey**2 + 1e-8)
    exr = F.conv2d(x_recon, sx, padding=1)
    eyr = F.conv2d(x_recon, sy, padding=1)
    e_recon = torch.sqrt(exr**2 + eyr**2 + 1e-8)
    print(f"  Image MSE : {F.mse_loss(x_recon, x).item():.6f}")
    print(f"  Edge  MSE : {F.mse_loss(e_recon, e_orig).item():.6f}")
    print("  PASSED (untrained — moderate values expected)")

    # 5. Parameter count
    print("\n[Test 5] Model Size")
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Parameters : {n:,}")
    print("  PASSED")

    print("\n" + banner)
    print("ALL TESTS PASSED!")
    print(banner)


if __name__ == "__main__":
    run_tests()