# kakeya-orientation-encoder / test_kakeya_encoder.py
# (repository page residue kept for provenance:)
#   DJLougen's picture
#   Upload test_kakeya_encoder.py
#   c9227f9 verified
"""Self-contained test script for Kakeya Orientation Encoder.
Run: python test_kakeya_encoder.py
"""
import sys, subprocess, os
# --- 1. Ensure PyShearlets is available ---------------------------------
# Import FFST if it is already importable; otherwise fall back to a local
# checkout in ./PyShearlets (cloned on first use) and put it on sys.path.
try:
    from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra
except ImportError:
    # A previous run may already have cloned the repository.  `git clone`
    # into an existing non-empty directory fails, and with check=True that
    # would abort the script, so only clone when the checkout is missing.
    if os.path.isdir("PyShearlets"):
        print("PyShearlets not installed. Reusing existing local clone and patching for NumPy 2.0...")
    else:
        print("PyShearlets not found. Cloning and patching for NumPy 2.0...")
        subprocess.run(["git", "clone", "https://github.com/grlee77/PyShearlets.git", "PyShearlets"],
                       check=True)
    sys.path.insert(0, "PyShearlets")
    # Patch numpy 2.0 compatibility: NumPy 2.0 removed the `np.NaN` alias.
    # Restore it on the numpy module itself *before* importing FFST so the
    # alias exists both at FFST import time and when its functions run.
    import numpy as np
    if not hasattr(np, "NaN"):
        np.NaN = np.nan
    from FFST import shearletTransformSpect, inverseShearletTransformSpect, scalesShearsAndSpectra
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# -----------------------------------------------------------------------
# Shearlet Transform wrapper
# -----------------------------------------------------------------------
class ShearletTransform:
    """Batch wrapper that applies the FFST shearlet transform to torch images.

    Each image in the batch is reduced to a single channel, padded to an
    odd-sized square, transformed with ``shearletTransformSpect`` on the CPU
    via NumPy, cropped back to the input size and returned as a float tensor.

    Args:
        num_scales: Number of shearlet scales passed to FFST as ``numOfScales``.
        real_coefficients: Forwarded to FFST as ``realCoefficients``.
    """

    def __init__(self, num_scales=3, real_coefficients=True):
        self.num_scales = num_scales
        self.real_coefficients = real_coefficients
        # Shearlet spectra are cached per configuration so they are only
        # built once for a given image shape.
        self._psi_cache = {}

    def _get_psi(self, shape):
        """Return the (cached) shearlet spectra for a 2-D ``shape``."""
        # Every setting that is forwarded to scalesShearsAndSpectra is part
        # of the key, so changing an attribute after construction can never
        # hand back spectra built for a different configuration.
        key = (tuple(shape), self.num_scales, self.real_coefficients)
        if key not in self._psi_cache:
            self._psi_cache[key] = scalesShearsAndSpectra(
                shape, numOfScales=self.num_scales,
                realCoefficients=self.real_coefficients)
        return self._psi_cache[key]

    def transform(self, image):
        """Compute shearlet coefficients for a batch of images.

        Args:
            image: Tensor of shape ``(B, H, W)`` or ``(B, C, H, W)``.  With
                more than one channel the channels are averaged first.

        Returns:
            Float tensor of shape ``(B, K, H, W)`` on the same device as
            ``image``, where ``K`` is the number of coefficient planes FFST
            returns (expected to equal :attr:`num_shearlets`).

        Raises:
            ValueError: If ``image`` is neither 3-D nor 4-D.
        """
        if image.dim() == 3:
            image = image.unsqueeze(1)
        if image.dim() != 4:
            raise ValueError(
                f"expected a (B, H, W) or (B, C, H, W) tensor, got shape {tuple(image.shape)}")
        B, C, H, W = image.shape
        img_gray = image.mean(dim=1) if C > 1 else image.squeeze(1)

        # Work on an odd-sized square: pad on the bottom/right up to
        # max(H, W), bumped to the next odd number when that is even.
        target_size = max(H, W)
        if target_size % 2 == 0:
            target_size += 1
        pad_h, pad_w = target_size - H, target_size - W
        padded = pad_h > 0 or pad_w > 0
        # NOTE: reflect padding requires each pad amount to be smaller than
        # the dimension being padded, so strongly non-square inputs will be
        # rejected by F.pad.
        img_padded = F.pad(img_gray, (0, pad_w, 0, pad_h), mode='reflect') \
            if padded else img_gray
        psi = self._get_psi((target_size, target_size))

        # One device->host copy for the whole batch.  detach() is required
        # because Tensor.numpy() refuses tensors that require grad.
        batch_np = img_padded.detach().cpu().numpy()
        coeffs_list = []
        for sample in batch_np:
            coeffs, _ = shearletTransformSpect(
                sample, Psi=psi,
                realCoefficients=self.real_coefficients)
            ct = torch.from_numpy(coeffs).float()
            if padded:
                # Drop the rows/columns that were added by the padding.
                ct = ct[:H, :W, :]
            # FFST puts the shearlet index last; move it to the channel axis.
            coeffs_list.append(ct.permute(2, 0, 1))
        # Hand the result back on the caller's device so it can be fed
        # straight into a model living on that device.
        return torch.stack(coeffs_list, dim=0).to(image.device)

    @property
    def num_shearlets(self):
        """Number of shearlet channels: ``1 + 4 * (2**num_scales - 1)``."""
        return 1 + 4 * (2**self.num_scales - 1)
# -----------------------------------------------------------------------
# Encoder / Decoder
# -----------------------------------------------------------------------
class KakeyaOrientationEncoder(nn.Module):
    """Convolutional encoder over shearlet coefficient maps.

    Produces a Gaussian posterior (``mu``, ``logvar``) from four
    conv/batch-norm stages followed by global average pooling, plus an
    orientation embedding derived from the normalized per-channel energy of
    the input coefficients.
    """

    def __init__(self, num_shearlets, latent_dim=256, base_channels=32):
        super().__init__()
        width2 = base_channels * 2
        width4 = base_channels * 4
        width8 = base_channels * 8

        # Convolutional trunk: one full-resolution stage, then three
        # stride-2 stages.
        self.enc_conv1 = nn.Conv2d(num_shearlets, width2, 3, padding=1)
        self.enc_bn1 = nn.BatchNorm2d(width2)
        self.enc_conv2 = nn.Conv2d(width2, width4, 3, stride=2, padding=1)
        self.enc_bn2 = nn.BatchNorm2d(width4)
        self.enc_conv3 = nn.Conv2d(width4, width8, 3, stride=2, padding=1)
        self.enc_bn3 = nn.BatchNorm2d(width8)
        self.enc_conv4 = nn.Conv2d(width8, width8, 3, stride=2, padding=1)
        self.enc_bn4 = nn.BatchNorm2d(width8)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

        # Posterior heads.
        self.fc_mu = nn.Linear(width8, latent_dim)
        self.fc_logvar = nn.Linear(width8, latent_dim)

        # Orientation branch: per-channel energy -> 32-d -> latent_dim.
        self.orientation_fc = nn.Sequential(
            nn.Linear(num_shearlets, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )
        self.orientation_proj = nn.Linear(32, latent_dim)

    def forward(self, shearlet_coeffs):
        """Return ``(mu, logvar, orient)`` for a ``(B, K, H, W)`` input."""
        trunk = (
            (self.enc_conv1, self.enc_bn1),
            (self.enc_conv2, self.enc_bn2),
            (self.enc_conv3, self.enc_bn3),
            (self.enc_conv4, self.enc_bn4),
        )
        feats = shearlet_coeffs
        for conv, norm in trunk:
            feats = F.relu(norm(conv(feats)))

        pooled = self.global_pool(feats)
        pooled = pooled.view(pooled.size(0), -1)
        mu = self.fc_mu(pooled)
        logvar = self.fc_logvar(pooled)

        # Mean squared coefficient per channel, normalized to sum to ~1
        # across channels (the epsilon guards against an all-zero input).
        band_energy = shearlet_coeffs.pow(2).mean(dim=[2, 3])
        total_energy = band_energy.sum(dim=1, keepdim=True) + 1e-8
        band_energy = band_energy / total_energy
        orient = self.orientation_proj(self.orientation_fc(band_energy))
        return mu, logvar, orient

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, shearlet_coeffs):
        """Return a sampled latent shifted by 0.1 times the orientation embedding."""
        mu, logvar, orient = self.forward(shearlet_coeffs)
        sample = self.reparameterize(mu, logvar)
        return sample + 0.1 * orient
class SimpleDecoder(nn.Module):
    """Transposed-convolution decoder from a latent vector to an image.

    The latent is projected to a ``bottleneck_size x bottleneck_size`` feature
    map, upsampled by three stride-2 transposed convolutions, resized to
    ``output_size`` if needed, and mapped to ``output_channels`` through a
    sigmoid.
    """

    def __init__(self, latent_dim=256, output_channels=1, base_channels=32, output_size=64):
        super().__init__()
        self.output_size = output_size
        # Three 2x upsamplings follow, so start at 1/8 of the target size,
        # but never below 4x4.
        self.bottleneck_size = max(4, output_size // 8)

        wide = base_channels * 8
        medium = base_channels * 4
        narrow = base_channels * 2

        self.fc = nn.Linear(latent_dim, wide * self.bottleneck_size * self.bottleneck_size)
        self.dec_conv1 = nn.ConvTranspose2d(wide, wide, 4, stride=2, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(wide)
        self.dec_conv2 = nn.ConvTranspose2d(wide, medium, 4, stride=2, padding=1)
        self.dec_bn2 = nn.BatchNorm2d(medium)
        self.dec_conv3 = nn.ConvTranspose2d(medium, narrow, 4, stride=2, padding=1)
        self.dec_bn3 = nn.BatchNorm2d(narrow)
        self.dec_conv4 = nn.Conv2d(narrow, output_channels, 3, padding=1)

    def forward(self, z):
        """Decode latents ``z`` of shape ``(B, latent_dim)`` to values in (0, 1)."""
        side = self.bottleneck_size
        feats = self.fc(z)
        feats = feats.view(feats.size(0), -1, side, side)

        upsamplers = (
            (self.dec_conv1, self.dec_bn1),
            (self.dec_conv2, self.dec_bn2),
            (self.dec_conv3, self.dec_bn3),
        )
        for deconv, norm in upsamplers:
            feats = F.relu(norm(deconv(feats)))

        # 8 * bottleneck_size only equals output_size for some sizes;
        # resample to the requested resolution otherwise.
        height_ok = feats.shape[2] == self.output_size
        width_ok = feats.shape[3] == self.output_size
        if not (height_ok and width_ok):
            feats = F.interpolate(feats, size=(self.output_size, self.output_size),
                                  mode='bilinear', align_corners=False)
        return torch.sigmoid(self.dec_conv4(feats))
class KakeyaAutoencoder(nn.Module):
    """Variational autoencoder pairing the orientation encoder with the decoder.

    The encoder consumes shearlet coefficients; the decoder maps the sampled
    latent (shifted by a scaled orientation embedding) to an image.
    """

    def __init__(self, num_shearlets, latent_dim=256, output_channels=1,
                 base_channels=32, output_size=64):
        super().__init__()
        self.encoder = KakeyaOrientationEncoder(
            num_shearlets, latent_dim, base_channels)
        self.decoder = SimpleDecoder(
            latent_dim, output_channels, base_channels, output_size)
        self.latent_dim = latent_dim

    def forward(self, shearlet_coeffs):
        """Return ``(reconstruction, mu, logvar, orient)`` for the input coefficients."""
        mu, logvar, orient = self.encoder(shearlet_coeffs)
        sampled = self.encoder.reparameterize(mu, logvar)
        # Same 0.1 orientation weighting as the encoder's own encode().
        latent = sampled + 0.1 * orient
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar, orient

    def encode(self, shearlet_coeffs):
        """Return a sampled latent code for the input coefficients."""
        return self.encoder.encode(shearlet_coeffs)

    def decode(self, z):
        """Decode latent codes ``z`` into images."""
        return self.decoder(z)
# -----------------------------------------------------------------------
# Tests
# -----------------------------------------------------------------------
def run_tests():
    """Run the five smoke tests and print a human-readable report.

    The tests cover: (1) FFST forward/inverse round-trip accuracy, (2) a
    forward pass through the untrained autoencoder, (3) per-channel energy
    and its entropy, (4) image/edge MSE between input and reconstruction,
    and (5) the trainable parameter count.

    Raises:
        AssertionError: If any check fails.
    """
    print("="*60)
    print("KAKEYA ORIENTATION ENCODER — SELF-CONTAINED TEST")
    print("="*60)

    # 1. Shearlet round-trip
    print("\n[Test 1] Shearlet Transform Round-trip")
    img = np.random.rand(64, 64)
    ST, Psi = shearletTransformSpect(img)
    recon = inverseShearletTransformSpect(ST, Psi)
    err = np.max(np.abs(img - recon))
    print(f"  Coeffs shape : {ST.shape}")
    print(f"  Recon error  : {err:.2e}")
    assert err < 1e-10
    print("  PASSED")

    # 2. Model forward pass
    print("\n[Test 2] Model Forward Pass")
    batch_size, latent_dim, image_size = 2, 128, 64
    shearlet = ShearletTransform(num_scales=3)
    model = KakeyaAutoencoder(
        num_shearlets=shearlet.num_shearlets,
        latent_dim=latent_dim, output_channels=1,
        base_channels=16, output_size=image_size)
    x = torch.randn(batch_size, 1, image_size, image_size)
    with torch.no_grad():
        sh = shearlet.transform(x)
        # Check the transform output before feeding it to the model so a
        # channel-count mismatch is reported here, not inside a conv layer.
        assert sh.shape == (batch_size, shearlet.num_shearlets, image_size, image_size), \
            f"unexpected shearlet shape {tuple(sh.shape)}"
        x_recon, mu, logvar, orient = model(sh)
    print(f"  Input        : {x.shape}")
    print(f"  Shearlet     : {sh.shape}")
    print(f"  Reconstructed: {x_recon.shape}")
    print(f"  Latent mu    : {mu.shape}")
    print(f"  Orientation  : {orient.shape}")
    # "PASSED" must mean something: verify every output shape.
    assert x_recon.shape == x.shape
    assert mu.shape == (batch_size, latent_dim)
    assert logvar.shape == (batch_size, latent_dim)
    assert orient.shape == (batch_size, latent_dim)
    print("  PASSED")

    # 3. Orientation feature extraction
    print("\n[Test 3] Orientation Feature Extraction")
    with torch.no_grad():
        energy = sh.pow(2).mean(dim=[2, 3])
    print(f"  Energy (first 5) : {energy[0][:5].numpy()}")
    # Entropy is only defined for a probability distribution, so normalize
    # the energies to sum to one first (same normalization as the encoder).
    probs = energy[0] / (energy[0].sum() + 1e-8)
    ent = -(probs * torch.log(probs + 1e-8)).sum().item()
    print(f"  Entropy          : {ent:.4f}")
    assert np.isfinite(ent)
    print("  PASSED")

    # 4. Edge preservation (untrained sanity check)
    print("\n[Test 4] Edge Preservation Sanity Check")
    # Sobel kernels for horizontal and vertical gradients.
    sx = torch.tensor([[-1,0,1],[-2,0,2],[-1,0,1]], dtype=torch.float32).view(1,1,3,3)
    sy = torch.tensor([[-1,-2,-1],[0,0,0],[1,2,1]], dtype=torch.float32).view(1,1,3,3)
    ex = F.conv2d(x, sx, padding=1)
    ey = F.conv2d(x, sy, padding=1)
    e_orig = torch.sqrt(ex**2 + ey**2 + 1e-8)
    exr = F.conv2d(x_recon, sx, padding=1)
    eyr = F.conv2d(x_recon, sy, padding=1)
    e_recon = torch.sqrt(exr**2 + eyr**2 + 1e-8)
    image_mse = F.mse_loss(x_recon, x).item()
    edge_mse = F.mse_loss(e_recon, e_orig).item()
    print(f"  Image MSE : {image_mse:.6f}")
    print(f"  Edge MSE  : {edge_mse:.6f}")
    # The model is untrained, so only require that the metrics are finite.
    assert np.isfinite(image_mse) and np.isfinite(edge_mse)
    print("  PASSED (untrained — moderate values expected)")

    # 5. Parameter count
    print("\n[Test 5] Model Size")
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"  Parameters : {n:,}")
    assert n > 0
    print("  PASSED")

    print("\n" + "="*60)
    print("ALL TESTS PASSED!")
    print("="*60)


if __name__ == "__main__":
    run_tests()