File size: 925 Bytes
0bae2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import matplotlib.pyplot as plt
from mymodel import MyCIFAR10Net

# Example: Visualize first conv layer filters

def visualize_filters(model_path='best_model.pth', save_path='filters.png'):
    """Plot every filter of the model's first conv layer in a single row.

    A ``MyCIFAR10Net(num_classes=10)`` is built, the state dict stored at
    ``model_path`` is loaded into it, and each filter of ``model.conv1`` is
    min-max scaled to [0, 1] and drawn in its own subplot. The figure is
    written to ``save_path`` and then shown.

    Args:
        model_path: Path to a state dict saved with ``torch.save``.
        save_path: Destination file for the rendered figure.
    """
    model = MyCIFAR10Net(num_classes=10)
    # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the weights are moved to the CPU for plotting anyway.
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)

    conv1_weights = model.conv1.weight.detach().cpu()
    num_filters = conv1_weights.shape[0]

    # squeeze=False always returns a 2-D array of Axes, so indexing also
    # works when the layer has a single filter.
    fig, axes = plt.subplots(
        1, num_filters, figsize=(num_filters * 2, 2), squeeze=False
    )
    for ax, w in zip(axes[0], conv1_weights):
        # Normalize to [0,1] for visualization
        w_min = w.min()
        span = w.max() - w_min
        if span > 0:
            w = (w - w_min) / span
        else:
            # A constant filter would divide by zero and yield NaNs; draw it
            # as uniform mid-gray instead.
            w = torch.full_like(w, 0.5)
        # Move the leading axis last for imshow; presumably conv1 is an
        # nn.Conv2d with 3 input channels, giving (H, W, 3) -- TODO confirm.
        ax.imshow(w.permute(1, 2, 0).numpy())
        ax.axis('off')
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()

# Additional visualization functions (e.g., loss landscape, feature maps) can be added here.

# Script entry point: render the filters using the default checkpoint path
# and write the figure to a custom file name.
if __name__ == "__main__":
    output_file = 'my_filters.png'
    visualize_filters(save_path=output_file)