File size: 689 Bytes
8ee14ff
 
2192664
 
8ee14ff
 
 
 
 
 
 
 
 
 
 
 
2192664
8ee14ff
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
# -*- coding: utf-8 -*-
import torch
from PIL import Image, ImageOps
from torchvision.transforms import Compose, Resize, ToTensor


def test_architecture_params(net):
    total_params = sum(params.numel() for params in net.parameters())
    assert total_params == 133578


def test_model_prediction(cfg, device, net):
    """Check that the network predicts index 5 for the bundled example image.

    The example image is loaded, converted to grayscale, resized to
    ``cfg['image_size']``, converted to a tensor with a leading batch
    dimension, and moved to ``device`` before a gradient-free forward pass.

    Args:
        cfg: Mapping that provides ``'image_size'`` for the resize step.
        device: Target passed to ``Tensor.to`` for the input batch.
        net: Callable model; the first entry of its output is reduced with
            ``argmax`` along dimension 0 to obtain the predicted index.
    """
    # Open the file in a context manager so its handle is released
    # deterministically, even if the grayscale conversion raises. The
    # grayscale result is a separate, fully loaded image, so it remains
    # usable after the source file is closed.
    with Image.open('examples/example_1.jpg') as source:
        gray = ImageOps.grayscale(source)

    preprocess = Compose([Resize(cfg['image_size']), ToTensor()])
    batch = preprocess(gray).unsqueeze(0)
    batch = batch.to(device)

    # Inference only: no gradients are needed for the check.
    with torch.no_grad():
        output = net(batch)

    predicted = torch.argmax(output[0], dim=0).item()
    assert predicted == 5, f'expected prediction 5, got {predicted}'