Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from data_loaders import load_dataset, get_dataset_sizes, get_dataloaders, get_class_names | |
def visualize_model(model, device, batch_size, fig_name="Predictions"):
    """Plot up to ``batch_size`` validation images titled with predicted and true labels.

    Loads the dataset via ``load_dataset``, builds dataloaders with
    ``get_dataloaders(data_set, batch_size)`` and runs ``model`` without
    gradient tracking over the ``"validation"`` loader. Each image is drawn
    in a two-column subplot grid on the figure named ``fig_name``, titled
    with the predicted and ground-truth class names. Plotting stops after
    ``batch_size`` images (or when the validation data runs out), and the
    figure is then shown with ``plt.show()``.

    Args:
        model: Classifier module called as ``model(inputs)``; predicted class
            indices are taken with ``torch.max(outputs, 1)``.
        device: Device the input batch is moved to before inference.
        batch_size: Dataloader batch size, also the maximum number of images
            plotted.
        fig_name: Identifier passed to ``plt.figure``.

    The model is switched to eval mode for inference, and its previous
    train/eval mode is restored before returning (even on error).
    """
    # Two columns; round the row count up so odd batch sizes (including 1)
    # still have a subplot slot for every image.
    nrows = (batch_size + 1) // 2
    images_so_far = 0
    plt.figure(fig_name)

    # Load data
    data_set = load_dataset()
    class_names = get_class_names(data_set)
    dataloaders = get_dataloaders(data_set, batch_size)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, labels in dataloaders["validation"]:
                outputs = model(inputs.to(device))
                _, preds = torch.max(outputs, 1)
                # Copy the batch to the CPU once, not once per plotted image.
                images = inputs.cpu()
                for image, pred_id, label_id in zip(images, preds.tolist(), labels.tolist()):
                    images_so_far += 1
                    ax = plt.subplot(nrows, 2, images_so_far)
                    ax.axis("off")
                    ax.set_title("[Pred: {}]\n[Label: {}]".format(class_names[pred_id], class_names[label_id]))
                    imshow(image)
                    if images_so_far >= batch_size:
                        break
                # Leave the outer loop too, but still fall through to plt.show().
                if images_so_far >= batch_size:
                    break
    finally:
        model.train(was_training)
    plt.show()
def imshow(inp, title=None, *, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized image tensor with ``plt.imshow``.

    The tensor is transposed from (channels, height, width) to
    (height, width, channels), the per-channel normalization is inverted
    (``x * std + mean``), and the result is clipped to ``[0, 1]`` before
    plotting on the current axes.

    Args:
        inp: Image tensor in (channels, height, width) layout with one channel
            per ``mean``/``std`` entry. It is detached and copied to the CPU
            first, so tensors on any device or requiring grad are accepted.
        title: Optional title set with ``plt.title``.
        mean: Per-channel means of the original normalization. NOTE(review):
            the defaults are the commonly used ImageNet statistics; confirm
            they match the dataset's transform.
        std: Per-channel standard deviations of the original normalization.
    """
    image = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Inverse of the initial normalization operation; mean/std broadcast
    # over the trailing channel axis.
    image = np.asarray(std) * image + np.asarray(mean)
    image = np.clip(image, 0, 1)
    plt.imshow(image)
    if title is not None:
        plt.title(title)