import torch
import matplotlib.pyplot as plt

from mymodel import MyCIFAR10Net


# Example: Visualize first conv layer filters
def visualize_filters(model_path: str = 'best_model.pth',
                      save_path: str = 'filters.png') -> None:
    """Plot the filters of the model's first convolution layer.

    Loads a ``MyCIFAR10Net`` state dict from *model_path*, reads
    ``model.conv1.weight``, min-max normalizes every filter independently to
    ``[0, 1]``, draws the filters side by side in one row, writes the figure
    to *save_path* and then displays it.

    Filters with 3 or 4 input channels are drawn as RGB(A) images; filters
    with any other channel count are averaged over channels and drawn in
    grayscale, since ``imshow`` only accepts 2-D, RGB, or RGBA data.

    Args:
        model_path: Path to the checkpoint holding the model's state dict.
        save_path: Path of the image file the figure is written to.
    """
    model = MyCIFAR10Net(num_classes=10)
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only host.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)

    conv1_weights = model.conv1.weight.detach().cpu()
    num_filters = conv1_weights.shape[0]

    # squeeze=False always yields a 2-D array of Axes, so indexing also works
    # when the layer has a single filter.
    fig, axes = plt.subplots(
        1, num_filters, figsize=(num_filters * 2, 2), squeeze=False
    )

    for ax, w in zip(axes[0], conv1_weights):
        # Normalize to [0,1] for visualization
        w_min = w.min()
        span = w.max() - w_min
        if span > 0:
            w = (w - w_min) / span
        else:
            # Constant filter: avoid 0/0 producing NaNs.
            w = torch.zeros_like(w)

        if w.shape[0] in (3, 4):
            # Channels-first -> channels-last, as imshow expects (H, W, C).
            ax.imshow(w.permute(1, 2, 0).numpy())
        else:
            ax.imshow(w.mean(dim=0).numpy(), cmap='gray', vmin=0.0, vmax=1.0)
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(save_path)
    plt.show()
    # Release the figure explicitly; non-interactive backends do not close it.
    plt.close(fig)


# You can add more visualization functions (e.g., loss landscape, feature maps, etc.)

if __name__ == "__main__":
    visualize_filters(save_path='my_filters.png')