| """ |
| CNN vs CGN resolution comparison. |
| |
| Shows what each architecture effectively "sees": |
| - CGN: full 28x28 original image |
| - CNN: information remaining after conv+pool stages (5x5 upsampled back) |
| """ |
| import os, pickle |
| import numpy as np |
| import matplotlib |
| matplotlib.use('Agg') |
| import matplotlib.pyplot as plt |
|
|
|
|
def load_mnist(path=None, n_train=60000):
    """Load the held-out MNIST rows from an ``.npz`` archive.

    The archive must contain an ``'X'`` array (one flattened image per row)
    and a ``'y'`` label array. Rows ``[n_train:]`` are returned; callers in
    this script treat them as the test split.

    Args:
        path: Path to the ``.npz`` file. Defaults to ``../data/mnist.npz``
            relative to this script's directory.
        n_train: Number of leading rows to skip (default 60000).

    Returns:
        Tuple ``(X, y)`` where ``X`` is ``archive['X'][n_train:]`` cast to
        float64 and ``y`` is ``archive['y'][n_train:]``.
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), '..', 'data', 'mnist.npz')
    # np.load on an .npz returns an NpzFile that keeps the file open;
    # the context manager guarantees the handle is released.
    with np.load(path) as archive:
        # Slice before casting so only the returned rows are converted to float64.
        X = archive['X'][n_train:].astype(np.float64)
        y = archive['y'][n_train:]
    return X, y
|
|
|
|
def simulate_cnn_loss(img_28):
    """Approximate the spatial shrinkage of a conv->pool->conv->pool CNN.

    The 28x28 image goes through 28 -> 26 -> 13 -> 11 -> 5. Each conv stage
    is modelled as a 1-pixel border crop, and each pool stage as a bilinear
    resize to half size (rounded down).

    Args:
        img_28: 28x28 array of pixel intensities; converted with
            ``astype(np.uint8)`` before processing.

    Returns:
        Tuple ``(small, upsampled)``. ``small`` is the final 5x5 uint8 array.
        ``upsampled`` is that 5x5 map enlarged to 28x28 with nearest-neighbour
        sampling, for side-by-side display.
    """
    from PIL import Image

    # (operation, argument) pairs, applied in order.
    pipeline = (
        ('crop', (1, 1, 27, 27)),   # conv stand-in: 28 -> 26
        ('resize', (13, 13)),       # pool stand-in: 26 -> 13
        ('crop', (1, 1, 12, 12)),   # conv stand-in: 13 -> 11
        ('resize', (5, 5)),         # pool stand-in: 11 -> 5
    )

    stage = Image.fromarray(img_28.astype(np.uint8))
    for op, arg in pipeline:
        if op == 'crop':
            stage = stage.crop(arg)
        else:
            stage = stage.resize(arg, Image.BILINEAR)

    upsampled = stage.resize((28, 28), Image.NEAREST)
    return np.array(stage), np.array(upsampled)
|
|
|
|
def _strip_ticks_and_frame(ax):
    """Hide ticks and spines of *ax* while keeping its axis labels visible."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def make_comparison(out_dir):
    """Render the CGN-vs-CNN resolution grid and save it as a PNG.

    Builds a 3x10 grid with one column per digit 0-9, using the first
    example of that digit in the data returned by ``load_mnist``. Rows:

      0. the original 28x28 image (what the CGN receives);
      1. the 5x5 CNN-simulated view upsampled back to 28x28;
      2. the raw 5x5 CNN-simulated view.

    Args:
        out_dir: Existing directory; ``cgn_vs_cnn_resolution.png`` is
            written there.
    """
    X_te, y_te = load_mnist()

    # First example of each digit class, reshaped to 28x28.
    examples = []
    for digit in range(10):
        idx = np.where(y_te == digit)[0][0]
        examples.append((digit, X_te[idx].reshape(28, 28)))

    row_labels = (
        'CGN\n28x28\n(100%)',
        'CNN\n5x5→28x28\n(3%)',
        'CNN actual\n5x5\n(25 pixels)',
    )

    fig, axes = plt.subplots(3, 10, figsize=(15, 5.5))
    for col, (digit, img) in enumerate(examples):
        img5, img_back = simulate_cnn_loss(img)
        for row, shown in enumerate((img, img_back, img5)):
            ax = axes[row, col]
            ax.imshow(shown, cmap='gray', vmin=0, vmax=255)
            # ax.axis('off') would also suppress the y-label used as the row
            # caption, so only ticks and spines are hidden.
            _strip_ticks_and_frame(ax)
            if col == 0:
                ax.set_ylabel(row_labels[row], fontsize=9,
                              rotation=0, labelpad=50, va='center')
        axes[0, col].set_title(f'{digit}', fontsize=11, fontweight='bold')

    fig.suptitle('What each architecture sees\n'
                 'CGN: full resolution input CNN: after conv→pool→conv→pool',
                 fontsize=12, fontweight='bold', y=1.02)
    fig.tight_layout()

    path = os.path.join(out_dir, 'cgn_vs_cnn_resolution.png')
    try:
        fig.savefig(path, dpi=150, bbox_inches='tight', facecolor='white')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f'Saved: {path}')
|
|
|
|
if __name__ == '__main__':
    # Output goes to ../figures relative to this script, created on demand.
    figures_dir = os.path.join(os.path.dirname(__file__), '..', 'figures')
    os.makedirs(figures_dir, exist_ok=True)
    make_comparison(figures_dir)
|
|