| """Self-contained test script for Kakeya Orientation Encoder. |
| |
| Run: python test_kakeya_encoder.py |
| """ |
| import sys, subprocess, os |
|
|
| |
# Import FFST (PyShearlets). If it is not importable, clone it into
# ./PyShearlets (relative to the working directory) and import from there.
_PYSHEARLETS_DIR = "PyShearlets"
_PYSHEARLETS_URL = "https://github.com/grlee77/PyShearlets.git"

import numpy as np

# PyShearlets still uses ``np.NaN``, which NumPy 2.0 removed. All FFST
# submodules share this one ``numpy`` module object, so a single alias set
# *before* FFST is imported covers every submodule -- whether FFST was
# pip-installed or cloned below (previously only the clone path was patched).
if not hasattr(np, "NaN"):
    np.NaN = np.nan

try:
    from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra
except ImportError:
    # Reuse an earlier clone: ``git clone`` refuses to write into an existing,
    # non-empty directory, which made every run after the first one fail.
    if not os.path.isdir(_PYSHEARLETS_DIR):
        print("PyShearlets not found. Cloning from GitHub...")
        subprocess.run(["git", "clone", _PYSHEARLETS_URL, _PYSHEARLETS_DIR],
                       check=True)
    sys.path.insert(0, _PYSHEARLETS_DIR)
    from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
class ShearletTransform:
    """Batched PyTorch front-end for the FFST (PyShearlets) shearlet transform.

    Each image in a batch is reduced to grayscale, reflect-padded to an odd
    square, transformed with FFST, and cropped back to its original size,
    yielding a ``(B, K, H, W)`` float32 tensor suitable for a CNN encoder.

    Args:
        num_scales: Scale count forwarded to FFST as ``numOfScales``.
        real_coefficients: Forwarded to FFST as ``realCoefficients``.
    """

    def __init__(self, num_scales=3, real_coefficients=True):
        self.num_scales = num_scales
        self.real_coefficients = real_coefficients
        # Memoized FFST spectra keyed by (padded shape, num_scales). They are
        # built from the shape and settings only, never from image content.
        self._psi_cache = {}

    def _get_psi(self, shape):
        """Return the FFST spectra for a padded ``shape``, building them once."""
        key = (shape, self.num_scales)
        if key not in self._psi_cache:
            self._psi_cache[key] = scalesShearsAndSpectra(
                shape, numOfScales=self.num_scales,
                realCoefficients=self.real_coefficients)
        return self._psi_cache[key]

    def transform(self, image):
        """Compute shearlet coefficients for a batch of images.

        Args:
            image: Tensor of shape ``(B, H, W)`` or ``(B, C, H, W)``. Input with
                ``C > 1`` is averaged over channels to grayscale first.

        Returns:
            Float32 tensor of shape ``(B, K, H, W)`` on the same device as
            ``image``, where ``K`` is the coefficient count FFST returns.
        """
        if image.dim() == 3:
            image = image.unsqueeze(1)
        _, channels, height, width = image.shape
        gray = image.mean(dim=1) if channels > 1 else image.squeeze(1)

        # FFST is run on an odd-sized square; an even side is bumped by one.
        # NOTE(review): the reason for the odd size is not visible here --
        # presumably an FFST spectra convention; confirm before changing.
        side = max(height, width)
        if side % 2 == 0:
            side += 1
        pad_h, pad_w = side - height, side - width
        psi = self._get_psi((side, side))

        # FFST works on NumPy arrays: detach() so tensors that require grad can
        # be converted (.numpy() refuses them), cpu() so CUDA inputs work too.
        samples = gray.detach().cpu().numpy()
        coeffs_list = []
        for sample in samples:
            if pad_h > 0 or pad_w > 0:
                # np.pad's reflect mode gives the same values as torch's
                # reflect padding when the pad is smaller than the image, and
                # unlike F.pad it also accepts larger pads (very non-square input).
                sample = np.pad(sample, ((0, pad_h), (0, pad_w)), mode="reflect")
            coeffs, _ = shearletTransformSpect(
                sample, Psi=psi, realCoefficients=self.real_coefficients)
            # (side, side, K) -> crop away the padding -> (K, height, width).
            # NOTE(review): if real_coefficients=False yields complex output,
            # .float() would drop the imaginary part -- confirm intended.
            ct = torch.from_numpy(coeffs).float()[:height, :width, :]
            coeffs_list.append(ct.permute(2, 0, 1))
        # Return on the caller's device so the result can feed a GPU model directly.
        return torch.stack(coeffs_list, dim=0).to(image.device)

    @property
    def num_shearlets(self):
        """Expected coefficient-channel count for ``num_scales`` scales.

        Evaluates ``1 + 4 * (2**num_scales - 1)``, i.e. one extra channel plus
        ``4 * 2**j`` per scale ``j``; assumed to equal the ``K`` FFST returns.
        """
        return 1 + 4 * (2 ** self.num_scales - 1)
|
|
|
|
| |
| |
| |
class KakeyaOrientationEncoder(nn.Module):
    """Variational encoder over shearlet coefficient maps plus an orientation branch.

    A conv trunk (one full-resolution conv, then three stride-2 convs, each with
    BatchNorm + ReLU) is globally average-pooled and projected to ``mu`` and
    ``logvar``. In parallel, each channel's share of the total squared-coefficient
    energy is fed through a small MLP to give an orientation embedding.

    Args:
        num_shearlets: Number of input coefficient channels.
        latent_dim: Size of the ``mu`` / ``logvar`` / orientation vectors.
        base_channels: Width multiplier for the conv trunk.
    """

    def __init__(self, num_shearlets, latent_dim=256, base_channels=32):
        super().__init__()
        c2, c4, c8 = base_channels * 2, base_channels * 4, base_channels * 8
        # Layers are registered in the original order so parameter
        # initialisation (and state_dict layout) is unchanged.
        self.enc_conv1 = nn.Conv2d(num_shearlets, c2, 3, padding=1)
        self.enc_bn1 = nn.BatchNorm2d(c2)
        self.enc_conv2 = nn.Conv2d(c2, c4, 3, stride=2, padding=1)
        self.enc_bn2 = nn.BatchNorm2d(c4)
        self.enc_conv3 = nn.Conv2d(c4, c8, 3, stride=2, padding=1)
        self.enc_bn3 = nn.BatchNorm2d(c8)
        self.enc_conv4 = nn.Conv2d(c8, c8, 3, stride=2, padding=1)
        self.enc_bn4 = nn.BatchNorm2d(c8)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_mu = nn.Linear(c8, latent_dim)
        self.fc_logvar = nn.Linear(c8, latent_dim)
        self.orientation_fc = nn.Sequential(
            nn.Linear(num_shearlets, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.orientation_proj = nn.Linear(32, latent_dim)

    def forward(self, shearlet_coeffs):
        """Return ``(mu, logvar, orient)``, each of size ``latent_dim`` per sample."""
        features = shearlet_coeffs
        stages = (
            (self.enc_conv1, self.enc_bn1),
            (self.enc_conv2, self.enc_bn2),
            (self.enc_conv3, self.enc_bn3),
            (self.enc_conv4, self.enc_bn4),
        )
        for conv, norm in stages:
            features = F.relu(norm(conv(features)))
        pooled = torch.flatten(self.global_pool(features), start_dim=1)

        mu = self.fc_mu(pooled)
        logvar = self.fc_logvar(pooled)

        # Fraction of the total squared-coefficient energy carried by each channel.
        channel_energy = (shearlet_coeffs ** 2).mean(dim=(2, 3))
        channel_energy = channel_energy / (channel_energy.sum(dim=1, keepdim=True) + 1e-8)
        orient = self.orientation_proj(self.orientation_fc(channel_energy))
        return mu, logvar, orient

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def encode(self, shearlet_coeffs):
        """Draw a latent sample and offset it by 0.1 x the orientation embedding."""
        mu, logvar, orient = self.forward(shearlet_coeffs)
        sample = self.reparameterize(mu, logvar)
        return sample + 0.1 * orient
|
|
|
|
class SimpleDecoder(nn.Module):
    """Decode a latent vector into an ``output_size`` x ``output_size`` image.

    A linear layer expands ``z`` into a ``base_channels * 8``-channel square map
    of side ``max(4, output_size // 8)``. Three stride-2 transposed convolutions
    (BatchNorm + ReLU) upsample it 8x. A bilinear resize fixes any size mismatch,
    and a final 3x3 conv with a sigmoid produces values in (0, 1).

    Args:
        latent_dim: Size of the input latent vector.
        output_channels: Channels of the generated image.
        base_channels: Width multiplier for the deconv stack.
        output_size: Side length of the generated square image.
    """

    def __init__(self, latent_dim=256, output_channels=1, base_channels=32, output_size=64):
        super().__init__()
        wide, mid, narrow = base_channels * 8, base_channels * 4, base_channels * 2
        self.output_size = output_size
        self.bottleneck_size = max(4, output_size // 8)
        side = self.bottleneck_size
        # Registration order matches the original so initialisation is identical.
        self.fc = nn.Linear(latent_dim, wide * side * side)
        self.dec_conv1 = nn.ConvTranspose2d(wide, wide, 4, stride=2, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(wide)
        self.dec_conv2 = nn.ConvTranspose2d(wide, mid, 4, stride=2, padding=1)
        self.dec_bn2 = nn.BatchNorm2d(mid)
        self.dec_conv3 = nn.ConvTranspose2d(mid, narrow, 4, stride=2, padding=1)
        self.dec_bn3 = nn.BatchNorm2d(narrow)
        self.dec_conv4 = nn.Conv2d(narrow, output_channels, 3, padding=1)

    def forward(self, z):
        """Map latent codes ``z`` to images of side ``output_size``."""
        side = self.bottleneck_size
        grid = self.fc(z)
        grid = grid.view(grid.size(0), -1, side, side)
        upsamplers = (
            (self.dec_conv1, self.dec_bn1),
            (self.dec_conv2, self.dec_bn2),
            (self.dec_conv3, self.dec_bn3),
        )
        for deconv, norm in upsamplers:
            grid = F.relu(norm(deconv(grid)))
        target = self.output_size
        # 8 x bottleneck only equals output_size when output_size is a multiple of 8 and >= 32.
        if tuple(grid.shape[2:]) != (target, target):
            grid = F.interpolate(grid, size=(target, target),
                                 mode='bilinear', align_corners=False)
        return torch.sigmoid(self.dec_conv4(grid))
|
|
|
|
class KakeyaAutoencoder(nn.Module):
    """Shearlet coefficients in, reconstructed image out.

    Wraps a ``KakeyaOrientationEncoder`` and a ``SimpleDecoder`` that share
    ``latent_dim`` and ``base_channels``.
    """

    def __init__(self, num_shearlets, latent_dim=256, output_channels=1,
                 base_channels=32, output_size=64):
        super().__init__()
        # Encoder is built before decoder (same order as before) so that
        # seeded parameter initialisation is unchanged.
        self.encoder = KakeyaOrientationEncoder(
            num_shearlets, latent_dim, base_channels)
        self.decoder = SimpleDecoder(
            latent_dim, output_channels, base_channels, output_size)
        self.latent_dim = latent_dim

    def forward(self, shearlet_coeffs):
        """Return ``(reconstruction, mu, logvar, orient)`` for a coefficient batch."""
        mu, logvar, orient = self.encoder(shearlet_coeffs)
        sample = self.encoder.reparameterize(mu, logvar)
        latent = sample + 0.1 * orient
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar, orient

    def encode(self, shearlet_coeffs):
        """Delegate to the encoder's stochastic ``encode``."""
        return self.encoder.encode(shearlet_coeffs)

    def decode(self, z):
        """Run the decoder on latent codes ``z``."""
        return self.decoder(z)
|
|
|
|
| |
| |
| |
def _test_shearlet_roundtrip():
    """[Test 1] FFST forward + inverse transform must reproduce a random image."""
    print("\n[Test 1] Shearlet Transform Round-trip")
    img = np.random.rand(64, 64)
    coeffs, psi = shearletTransformSpect(img)
    recon = inverseShearletTransformSpect(coeffs, psi)
    err = np.max(np.abs(img - recon))
    print(f" Coeffs shape : {coeffs.shape}")
    print(f" Recon error : {err:.2e}")
    assert err < 1e-10, f"round-trip error too large: {err:.2e}"
    print(" PASSED")


def _test_forward_pass():
    """[Test 2] Push shearlet features through the autoencoder and check every shape.

    Returns:
        ``(model, x, sh, x_recon)`` for reuse by the later tests.
    """
    print("\n[Test 2] Model Forward Pass")
    shearlet = ShearletTransform(num_scales=3)
    model = KakeyaAutoencoder(
        num_shearlets=shearlet.num_shearlets,
        latent_dim=128, output_channels=1,
        base_channels=16, output_size=64)
    x = torch.randn(2, 1, 64, 64)
    with torch.no_grad():
        sh = shearlet.transform(x)
        x_recon, mu, logvar, orient = model(sh)
    print(f" Input : {x.shape}")
    print(f" Shearlet : {sh.shape}")
    print(f" Reconstructed: {x_recon.shape}")
    print(f" Latent mu : {mu.shape}")
    print(f" Orientation : {orient.shape}")
    assert sh.shape == (2, shearlet.num_shearlets, 64, 64), f"shearlet: {sh.shape}"
    assert x_recon.shape == x.shape, f"reconstruction: {x_recon.shape}"
    for name, tensor in (("mu", mu), ("logvar", logvar), ("orient", orient)):
        assert tensor.shape == (2, 128), f"{name}: {tensor.shape}"
    print(" PASSED")
    return model, x, sh, x_recon


def _test_orientation_features(sh):
    """[Test 3] Entropy of the per-channel energy distribution must lie in [0, log K]."""
    print("\n[Test 3] Orientation Feature Extraction")
    with torch.no_grad():
        energy = sh.pow(2).mean(dim=[2, 3])
        # Entropy is only defined for a probability distribution, so normalise
        # over channels first -- the same normalisation the encoder applies.
        prob = energy / (energy.sum(dim=1, keepdim=True) + 1e-8)
    print(f" Energy (first 5) : {energy[0][:5].numpy()}")
    ent = -(prob[0] * torch.log(prob[0] + 1e-8)).sum().item()
    max_ent = float(np.log(prob.shape[1]))
    print(f" Entropy : {ent:.4f} (max {max_ent:.4f})")
    assert -1e-6 <= ent <= max_ent + 1e-4, f"entropy {ent} outside [0, {max_ent}]"
    print(" PASSED")


def _test_edge_preservation(x, x_recon):
    """[Test 4] Report image and Sobel-edge MSE between input and reconstruction."""
    print("\n[Test 4] Edge Preservation Sanity Check")
    sobel_x = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]],
                           dtype=torch.float32).view(1, 1, 3, 3)
    sobel_y = torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]],
                           dtype=torch.float32).view(1, 1, 3, 3)

    def edge_magnitude(img):
        gx = F.conv2d(img, sobel_x, padding=1)
        gy = F.conv2d(img, sobel_y, padding=1)
        return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)

    with torch.no_grad():
        image_mse = F.mse_loss(x_recon, x).item()
        edge_mse = F.mse_loss(edge_magnitude(x_recon), edge_magnitude(x)).item()
    print(f" Image MSE : {image_mse:.6f}")
    print(f" Edge MSE : {edge_mse:.6f}")
    # The model is untrained and x is standard-normal while the decoder output
    # is sigmoid-bounded, so only finiteness is checked, not magnitude.
    assert np.isfinite(image_mse) and np.isfinite(edge_mse)
    print(" PASSED (untrained — moderate values expected)")


def _test_model_size(model):
    """[Test 5] Report (and sanity-check) the trainable parameter count."""
    print("\n[Test 5] Model Size")
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Parameters : {n:,}")
    assert n > 0, "model has no trainable parameters"
    print(" PASSED")


def run_tests(seed=0):
    """Run the self-contained test suite; raise ``AssertionError`` on any failure.

    Args:
        seed: Seed applied to NumPy's and PyTorch's global RNGs so inputs and
            model initialisation are reproducible. ``None`` leaves both RNGs
            untouched.
    """
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)

    print("=" * 60)
    print("KAKEYA ORIENTATION ENCODER — SELF-CONTAINED TEST")
    print("=" * 60)

    _test_shearlet_roundtrip()
    model, x, sh, x_recon = _test_forward_pass()
    _test_orientation_features(sh)
    _test_edge_preservation(x, x_recon)
    _test_model_size(model)

    print("\n" + "=" * 60)
    print("ALL TESTS PASSED!")
    print("=" * 60)
|
|
|
|
def _main():
    """Script entry point: run the full self-test suite."""
    run_tests()


if __name__ == "__main__":
    _main()
|
|