| from sudoku.helper import compute_loss | |
| import torch | |
def test_compute_loss():
    """Smoke-test ``compute_loss`` on a tiny hand-built batch.

    All four tensors have shape ``(3, 2, 729)``. The target ``y`` is 1 at
    ``[:, 0, 0]`` for every sample. The prediction ``output`` is positive at
    that same position for samples 0 and 1, but at ``[2, 0, 1]`` for
    sample 2. So at ``[:, 0, :]`` only sample 2 disagrees with the target.
    ``x`` is all zeros.

    NOTE(review): the 729 width presumably encodes 81 cells x 9 digits as a
    one-hot vector; confirm against ``sudoku.helper``.
    """
    x = torch.zeros((3, 2, 729))
    y = torch.zeros((3, 2, 729))
    output = torch.zeros((3, 2, 729))
    y[:, 0, 0] = 1
    # Samples 0 and 1 agree with the target; sample 2 picks a different index.
    output[0, 0, 0] = 0.1
    output[1, 0, 0] = 0.1
    output[2, 0, 1] = 0.1
    # Binarize the prediction. ``.float()`` is the idiomatic replacement for
    # ``.type("torch.FloatTensor")`` and does not force the tensor onto the CPU.
    new_x = (output > 0).float()
    loss_error, loss_no_improve, n_error, n_no_improve = compute_loss(
        x, y, output, new_x
    )
    # TODO(review): this only checks that compute_loss runs and returns four
    # values. Add assertions on loss_error / loss_no_improve / n_error /
    # n_no_improve once their expected values are pinned down from
    # compute_loss's contract.