File size: 194 Bytes
36c95ba
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
import pytest
import torch


@pytest.mark.parametrize("batch_size", [1, 2, 5])
def test_smoke(batch_size):
    """Check that ``torch.rand`` yields a tensor of the requested shape.

    Runs once per parametrized ``batch_size``; on mismatch the assertion
    message reports the tensor's actual shape.
    """
    # Name the target shape once so construction and the check cannot drift.
    expected_shape = (batch_size, 2, 3)
    sample = torch.rand(*expected_shape)
    # torch.Size is a tuple subclass, so it compares equal to a plain tuple.
    assert sample.shape == expected_shape, sample.shape