""" CNN vs CGN resolution comparison. Shows what each architecture effectively "sees": - CGN: full 28x28 original image - CNN: information remaining after conv+pool stages (5x5 upsampled back) """ import os, pickle import numpy as np import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt def load_mnist(): npz = os.path.join(os.path.dirname(__file__), '..', 'data', 'mnist.npz') d = np.load(npz) X = d['X'].astype(np.float64) y = d['y'] return X[60000:], y[60000:] def simulate_cnn_loss(img_28): """Simulate CNN's resolution loss: 28->26->13->11->5.""" from PIL import Image img = Image.fromarray(img_28.astype(np.uint8)) # conv1: 28->26 (crop border) img26 = img.crop((1, 1, 27, 27)) # pool1: 26->13 img13 = img26.resize((13, 13), Image.BILINEAR) # conv2: 13->11 (crop border) img11 = img13.crop((1, 1, 12, 12)) # pool2: 11->5 img5 = img11.resize((5, 5), Image.BILINEAR) # upscale back to 28x28 for visual comparison img_back = img5.resize((28, 28), Image.NEAREST) return np.array(img5), np.array(img_back) def make_comparison(out_dir): X_te, y_te = load_mnist() # Pick one example per digit examples = [] for d in range(10): idx = np.where(y_te == d)[0][0] examples.append((d, X_te[idx].reshape(28, 28))) fig, axes = plt.subplots(3, 10, figsize=(15, 5.5)) for col, (digit, img) in enumerate(examples): # Row 0: Original (what CGN sees) axes[0, col].imshow(img, cmap='gray', vmin=0, vmax=255) axes[0, col].axis('off') if col == 0: axes[0, col].set_ylabel('CGN\n28x28\n(100%)', fontsize=9, rotation=0, labelpad=50, va='center') # Row 1: CNN's effective view (5x5 upsampled) img5, img_back = simulate_cnn_loss(img) axes[1, col].imshow(img_back, cmap='gray', vmin=0, vmax=255) axes[1, col].axis('off') if col == 0: axes[1, col].set_ylabel('CNN\n5x5→28x28\n(3%)', fontsize=9, rotation=0, labelpad=50, va='center') # Row 2: CNN's actual 5x5 axes[2, col].imshow(img5, cmap='gray', vmin=0, vmax=255) axes[2, col].axis('off') if col == 0: axes[2, col].set_ylabel('CNN actual\n5x5\n(25 pixels)', fontsize=9, rotation=0, labelpad=50, va='center') axes[0, col].set_title(f'{digit}', fontsize=11, fontweight='bold') plt.suptitle('What each architecture sees\n' 'CGN: full resolution input CNN: after conv→pool→conv→pool', fontsize=12, fontweight='bold', y=1.02) plt.tight_layout() path = os.path.join(out_dir, 'cgn_vs_cnn_resolution.png') plt.savefig(path, dpi=150, bbox_inches='tight', facecolor='white') plt.close() print(f'Saved: {path}') if __name__ == '__main__': out_dir = os.path.join(os.path.dirname(__file__), '..', 'figures') os.makedirs(out_dir, exist_ok=True) make_comparison(out_dir)