File size: 925 Bytes
0bae2fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
import torch
import matplotlib.pyplot as plt
from mymodel import MyCIFAR10Net
# Example: Visualize first conv layer filters
def _filter_to_image(weight):
    """Convert a single conv filter into an image array scaled to [0, 1].

    Args:
        weight: One filter tensor. The expected layout is
            (in_channels, kH, kW), i.e. one slice ``conv.weight[i]`` of an
            ``nn.Conv2d`` weight.

    Returns:
        A NumPy array of shape (kH, kW, 3) for 3-channel filters. For any
        other channel count, returns a (kH, kW) array in which the channels
        are averaged, so the filter can be drawn as a grayscale image.
    """
    w = weight.detach().cpu().float()
    # Normalise the whole filter jointly (not per channel) so the relative
    # channel magnitudes survive. A constant filter has zero range, and
    # dividing by it would yield NaNs, so such filters are drawn mid-grey.
    w_min, w_max = w.min(), w.max()
    value_range = w_max - w_min
    if value_range > 0:
        w = (w - w_min) / value_range
    else:
        w = torch.full_like(w, 0.5)
    if w.shape[0] == 3:
        # (C, H, W) -> (H, W, C), the layout matplotlib expects for RGB data.
        return w.permute(1, 2, 0).numpy()
    # imshow cannot draw (H, W, C) data for most C other than 3, so collapse
    # the channels to a single grayscale plane instead.
    return w.mean(dim=0).numpy()


def visualize_filters(model_path='best_model.pth', save_path='filters.png',
                      show=True):
    """Plot every filter of ``model.conv1`` side by side and save the figure.

    Args:
        model_path: Path to a state dict loadable into
            ``MyCIFAR10Net(num_classes=10)``.
        save_path: File the figure is written to.
        show: When True (the default, matching the original behaviour), also
            display the figure with ``plt.show()``. Pass False for headless
            runs.
    """
    model = MyCIFAR10Net(num_classes=10)
    # map_location='cpu' lets a checkpoint saved during a GPU run load on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    conv1_weights = model.conv1.weight.detach().cpu()
    num_filters = conv1_weights.shape[0]
    # squeeze=False keeps `axes` 2-D even when there is only one filter.
    # The default would return a bare Axes object that cannot be indexed.
    fig, axes = plt.subplots(
        1, num_filters, figsize=(num_filters * 2, 2), squeeze=False
    )
    for ax, weight in zip(axes[0], conv1_weights):
        image = _filter_to_image(weight)
        if image.ndim == 2:
            # Fixed 0..1 limits keep the grayscale contrast comparable
            # across filters.
            ax.imshow(image, cmap='gray', vmin=0.0, vmax=1.0)
        else:
            ax.imshow(image)
        ax.axis('off')
    fig.tight_layout()
    fig.savefig(save_path)
    if show:
        plt.show()
    # Release the figure so that repeated calls do not accumulate open figures.
    plt.close(fig)
# Further visualization functions (e.g., loss landscapes or feature maps) can be added here.
if __name__ == "__main__":
    # Script entry point: load the default checkpoint and write the filter
    # grid to a custom output file.
    output_file = 'my_filters.png'
    visualize_filters(save_path=output_file)
|