File size: 3,077 Bytes
252a794
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
CNN vs CGN resolution comparison.

Shows what each architecture effectively "sees":
- CGN: full 28x28 original image
- CNN: the information left after two conv+pool stages (a 5x5 map, upsampled back to 28x28 for display)
"""
import os, pickle
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt


def load_mnist(path=None, split=60000):
    """Load the test portion of the MNIST ``.npz`` archive.

    The archive must contain arrays ``X`` (one image per row) and ``y``
    (labels). Rows from ``split`` onward are returned; the default assumes
    the first 60000 rows are the training set.

    Args:
        path: Path to the ``.npz`` file. Defaults to ``../data/mnist.npz``
            relative to this script.
        split: Index of the first row to return (default 60000).

    Returns:
        ``(X_test, y_test)`` where ``X_test`` is cast to float64 and
        ``y_test`` keeps the stored dtype.
    """
    if path is None:
        path = os.path.join(os.path.dirname(__file__), '..', 'data', 'mnist.npz')
    # np.load returns an NpzFile that holds the archive open; the context
    # manager releases the handle once the arrays have been read.
    with np.load(path) as d:
        # Slice before casting so only the returned rows are converted to
        # float64 (the result no longer keeps the full-size array alive).
        X = d['X'][split:].astype(np.float64)
        y = d['y'][split:]
    return X, y


def simulate_cnn_loss(img_28):
    """Simulate CNN's resolution loss: 28->26->13->11->5.

    Each "valid" convolution is approximated by trimming a one-pixel border
    and each pooling stage by a bilinear downsample to half size.

    Args:
        img_28: 28x28 array of pixel intensities, assumed to be on a 0-255
            scale. Values are rounded and clipped to [0, 255] before the
            uint8 conversion.

    Returns:
        ``(img5, img_back)``: the 5x5 uint8 result, and that result
        upscaled to 28x28 with nearest-neighbour sampling.
    """
    from PIL import Image
    # A bare astype(np.uint8) truncates fractions and wraps out-of-range
    # values (e.g. 300 -> 44); round and saturate instead.
    pixels = np.clip(np.rint(np.asarray(img_28, dtype=np.float64)), 0, 255)
    img = Image.fromarray(pixels.astype(np.uint8))
    # conv1: 28->26 (crop border)
    img26 = img.crop((1, 1, 27, 27))
    # pool1: 26->13
    img13 = img26.resize((13, 13), Image.BILINEAR)
    # conv2: 13->11 (crop border)
    img11 = img13.crop((1, 1, 12, 12))
    # pool2: 11->5
    img5 = img11.resize((5, 5), Image.BILINEAR)
    # upscale back to 28x28 for visual comparison
    img_back = img5.resize((28, 28), Image.NEAREST)
    return np.array(img5), np.array(img_back)


def make_comparison(out_dir):
    """Render a 3x10 grid comparing the full input with the CNN's reduced view.

    For the first test image of each digit 0-9 the grid shows:
      * row 0 - the original 28x28 image (what the CGN sees);
      * row 1 - the simulated 5x5 CNN representation, upscaled to 28x28;
      * row 2 - that 5x5 representation at its native size.

    The figure is saved to ``<out_dir>/cgn_vs_cnn_resolution.png``.

    Args:
        out_dir: Existing directory to write the PNG into.
    """
    X_te, y_te = load_mnist()

    # First test example of each digit class (IndexError if a class is absent).
    examples = [(d, X_te[np.flatnonzero(y_te == d)[0]].reshape(28, 28))
                for d in range(10)]

    row_labels = (
        'CGN\n28x28\n(100%)',
        'CNN\n5x5→28x28\n(3%)',
        'CNN actual\n5x5\n(25 pixels)',
    )

    fig, axes = plt.subplots(3, 10, figsize=(15, 5.5))

    for col, (digit, img) in enumerate(examples):
        img5, img_back = simulate_cnn_loss(img)
        for row, panel in enumerate((img, img_back, img5)):
            ax = axes[row, col]
            ax.imshow(panel, cmap='gray', vmin=0, vmax=255)
            # ax.axis('off') would also hide the y-axis label, so the row
            # labels set below would never be drawn; hide ticks and frame only.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            if col == 0:
                ax.set_ylabel(row_labels[row], fontsize=9,
                              rotation=0, labelpad=50, va='center')
        axes[0, col].set_title(f'{digit}', fontsize=11, fontweight='bold')

    fig.suptitle('What each architecture sees\n'
                 'CGN: full resolution input    CNN: after conv→pool→conv→pool',
                 fontsize=12, fontweight='bold', y=1.02)
    fig.tight_layout()

    path = os.path.join(out_dir, 'cgn_vs_cnn_resolution.png')
    fig.savefig(path, dpi=150, bbox_inches='tight', facecolor='white')
    # Close this specific figure rather than whichever one is "current".
    plt.close(fig)
    print(f'Saved: {path}')


if __name__ == '__main__':
    # Figures go to <script dir>/../figures; create the directory on first run.
    figures_dir = os.path.join(os.path.dirname(__file__), '..', 'figures')
    os.makedirs(figures_dir, exist_ok=True)
    make_comparison(out_dir=figures_dir)