Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| import torch | |
| from PIL import Image, ImageOps | |
| from torchvision.transforms import Compose, Resize, ToTensor | |
| def test_architecture_params(net): | |
| total_params = sum(params.numel() for params in net.parameters()) | |
| assert total_params == 133578 | |
def test_model_prediction(cfg, device, net):
    """Run the network on a bundled example image and check the predicted class.

    The image at ``examples/example_1.jpg`` is converted to grayscale, resized
    using ``cfg['image_size']``, turned into a tensor, given a leading batch
    dimension and moved to ``device``. The index of the largest output for the
    first batch element must equal 5.

    Args:
        cfg: Mapping providing the ``'image_size'`` entry passed to ``Resize``.
        device: Target handed to ``Tensor.to`` for the input batch.
        net: Callable model under test.
    """
    # The path is relative, so it is resolved against the current working
    # directory of the test run.
    # Open inside a context manager so the file handle is released
    # deterministically; the grayscale conversion yields a separate image
    # that remains usable after the source image is closed.
    with Image.open('examples/example_1.jpg') as source_image:
        image = ImageOps.grayscale(source_image)

    transforms = Compose([Resize(cfg['image_size']), ToTensor()])
    data = transforms(image).unsqueeze(0).to(device)

    # NOTE(review): assumes the `net` fixture is already in eval mode — confirm.
    with torch.no_grad():
        prediction = net(data)

    # Presumably the output is laid out as (batch, classes) — TODO confirm.
    prediction = torch.argmax(prediction[0], dim=0).item()
    assert prediction == 5