cgn-mnist / scripts /compare_resolution.py
Leon Cynn
Upload scripts/compare_resolution.py with huggingface_hub
252a794 verified
"""
CNN vs CGN resolution comparison.
Shows what each architecture effectively "sees":
- CGN: full 28x28 original image
- CNN: information remaining after conv+pool stages (5x5 upsampled back)
"""
import os, pickle
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
def load_mnist(path=None):
    """Load the held-out rows of the MNIST arrays stored in an ``.npz`` archive.

    Args:
        path: Optional path to the archive. When omitted, the file
            ``../data/mnist.npz`` relative to this script is used.

    Returns:
        A ``(X, y)`` tuple containing the rows from index 60000 onward:
        ``X`` converted to ``float64`` and ``y`` as stored in the archive.
        NOTE(review): presumably this is the MNIST test split (first 60000
        rows being the training set) -- confirm against how the archive
        was built.
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), '..', 'data', 'mnist.npz')

    first_held_out = 60000
    # An .npz archive is read lazily and keeps its file handle open until it
    # is closed, so read what we need inside a context manager.
    with np.load(path) as archive:
        # Slice before converting so only the returned rows are copied to
        # float64 instead of the whole dataset.
        X = archive['X'][first_held_out:].astype(np.float64)
        y = archive['y'][first_held_out:]
    return X, y
def simulate_cnn_loss(img_28):
    """Imitate the spatial shrinkage of a conv/pool stack: 28->26->13->11->5.

    Each conv stage is imitated by cropping a one-pixel border and each pool
    stage by a bilinear resize; no actual convolution is performed.

    Args:
        img_28: Image array; it is cast to ``uint8`` before processing.
            The crop boxes are hard-coded for a 28x28 input.

    Returns:
        A ``(small, enlarged)`` tuple of arrays: the 5x5 result and the same
        result blown back up to 28x28 with nearest-neighbour resampling.
    """
    from PIL import Image

    def trim_border(picture, side):
        # Drop one pixel on every edge of a side x side picture.
        return picture.crop((1, 1, side - 1, side - 1))

    source = Image.fromarray(img_28.astype(np.uint8))

    after_conv1 = trim_border(source, 28)                       # 28 -> 26
    after_pool1 = after_conv1.resize((13, 13), Image.BILINEAR)  # 26 -> 13
    after_conv2 = trim_border(after_pool1, 13)                  # 13 -> 11
    after_pool2 = after_conv2.resize((5, 5), Image.BILINEAR)    # 11 -> 5

    # Nearest-neighbour keeps the 25 cells visible as flat blocks when the
    # result is displayed at the original size.
    enlarged = after_pool2.resize((28, 28), Image.NEAREST)
    return np.array(after_pool2), np.array(enlarged)
def _hide_ticks_and_frame(ax):
    """Remove ticks and spines from *ax* while leaving its axis labels drawable."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def make_comparison(out_dir):
    """Render the CGN-vs-CNN resolution figure and save it as a PNG.

    The figure has one column per digit (0-9) and three rows: the original
    28x28 image, the 5x5 result of ``simulate_cnn_loss`` upscaled to 28x28,
    and the raw 5x5 result.

    Args:
        out_dir: Existing directory that receives
            ``cgn_vs_cnn_resolution.png``.

    Raises:
        IndexError: If some digit has no example in the loaded labels.
    """
    X_te, y_te = load_mnist()

    # Pick one example per digit (the first occurrence of each label)
    examples = []
    for d in range(10):
        idx = np.where(y_te == d)[0][0]
        examples.append((d, X_te[idx].reshape(28, 28)))

    row_labels = (
        'CGN\n28x28\n(100%)',            # row 0: original (what CGN sees)
        'CNN\n5x5→28x28\n(3%)',          # row 1: CNN's effective view
        'CNN actual\n5x5\n(25 pixels)',  # row 2: CNN's actual 5x5
    )

    fig, axes = plt.subplots(3, 10, figsize=(15, 5.5))
    for col, (digit, img) in enumerate(examples):
        img5, img_back = simulate_cnn_loss(img)
        panels = (img, img_back, img5)
        for row, (panel, label) in enumerate(zip(panels, row_labels)):
            ax = axes[row, col]
            ax.imshow(panel, cmap='gray', vmin=0, vmax=255)
            # ax.axis('off') would also suppress the y-label, so the row
            # captions would never be drawn; hide only ticks and spines.
            _hide_ticks_and_frame(ax)
            if col == 0:
                ax.set_ylabel(label, fontsize=9,
                              rotation=0, labelpad=50, va='center')
        axes[0, col].set_title(f'{digit}', fontsize=11, fontweight='bold')

    fig.suptitle('What each architecture sees\n'
                 'CGN: full resolution input CNN: after conv→pool→conv→pool',
                 fontsize=12, fontweight='bold', y=1.02)
    fig.tight_layout()

    path = os.path.join(out_dir, 'cgn_vs_cnn_resolution.png')
    fig.savefig(path, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f'Saved: {path}')
if __name__ == '__main__':
    # Write the figure into the "figures" directory that sits next to the
    # directory containing this script, creating it on first run.
    script_dir = os.path.dirname(__file__)
    out_dir = os.path.join(script_dir, '..', 'figures')
    os.makedirs(out_dir, exist_ok=True)
    make_comparison(out_dir)