Spaces:
Runtime error
Runtime error
| from matplotlib import pyplot as plt | |
| import numpy as np | |
| import torchvision | |
def imshow(
    dataloader,
    title=None,
    *,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
):
    """Display the first batch of a dataloader as a single image grid.

    The batch is tiled with ``torchvision.utils.make_grid``. The per-channel
    normalization is undone as ``std * x + mean``. The result is clipped to
    [0, 1] and shown with matplotlib.

    Args:
        dataloader: Iterable whose items unpack as ``(inputs, labels)``.
            ``inputs`` is assumed to be a batch of image tensors that
            ``make_grid`` accepts (presumably (B, C, H, W) -- confirm).
        title: Optional title for the figure.
        mean: Per-channel mean used when the images were normalized.
            Defaults to the standard ImageNet values.
        std: Per-channel standard deviation used when the images were
            normalized. Defaults to the standard ImageNet values.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    try:
        inputs, _ = next(iter(dataloader))
    except StopIteration:
        # A StopIteration escaping a plain function is confusing, and
        # becomes RuntimeError if the caller is a generator (PEP 479).
        raise ValueError("dataloader is empty; nothing to display") from None

    grid = torchvision.utils.make_grid(inputs)
    # make_grid yields a channel-first (C, H, W) tensor, but matplotlib wants
    # (H, W, C). detach()/cpu() let .numpy() work for tensors that require
    # grad or live on a GPU.
    img = grid.detach().cpu().numpy().transpose((1, 2, 0))

    # Undo the normalization; mean/std broadcast over the trailing channel axis.
    img = np.asarray(std) * img + np.asarray(mean)
    img = np.clip(img, 0, 1)

    plt.imshow(img)
    if title is not None:
        plt.title(title)
    plt.show()
    plt.pause(0.001)  # pause a bit so that plots are updated