"""Tests for dataset classes."""

import torch
from torch.utils.data import DataLoader

from namer.data import InfiniteNamerDataset, NamerDataset
from namer.utils import EOS_IDX, VOCABULARY


class TestNamerDataset:
    """Tests for NamerDataset class."""

    def test_length(self) -> None:
        dataset = NamerDataset(num_samples=50, seed=42)
        assert len(dataset) == 50

    def test_sample_shape(self) -> None:
        dataset = NamerDataset(num_samples=10, max_seq_len=20, seed=42)
        digits, encoded = dataset[0]
        assert digits.shape == (20,)
        assert encoded.shape == (20,)
        assert digits.dtype == torch.long
        assert encoded.dtype == torch.long

    def test_padding_value(self) -> None:
        dataset = NamerDataset(num_samples=10, max_seq_len=20, seed=42)
        digits, _ = dataset[0]
        # Digit sequences are padded to max_seq_len with the value 10; with at
        # most 6 real digits in 20 slots, padding must be present. (The original
        # fallback clause, "at most 6 non-padding values", already implies this.)
        assert (digits == 10).any()

    def test_eos_present(self) -> None:
        dataset = NamerDataset(num_samples=10, seed=42)
        _, encoded = dataset[0]
        # EOS token should be present
        assert EOS_IDX in encoded.tolist()
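
    def test_indexing_deterministic(self) -> None:
        # Hedged sketch, not from the source: assumes dataset[i] is a pure
        # function of (seed, i), mirroring the reproducibility guarantee
        # asserted for InfiniteNamerDataset below. Drop this if NamerDataset
        # advances a shared RNG on each access.
        dataset = NamerDataset(num_samples=10, seed=42)
        d1, e1 = dataset[3]
        d2, e2 = dataset[3]
        assert torch.equal(d1, d2)
        assert torch.equal(e1, e2)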


class TestInfiniteNamerDataset:
    """Tests for InfiniteNamerDataset class."""

    def test_iteration(self) -> None:
        dataset = InfiniteNamerDataset(seed=42)
        iterator = iter(dataset)
        # Can get multiple samples
        for _ in range(10):
            digits, encoded = next(iterator)
            assert digits.shape == (20,)
            assert encoded.shape == (20,)
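
    # Note (assumption, not confirmed by the source): the hard-coded (20,)
    # shapes above presume InfiniteNamerDataset defaults to max_seq_len=20,
    # matching the explicit max_seq_len=20 passed to NamerDataset above.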

    def test_data_loader(self) -> None:
        dataset = InfiniteNamerDataset(seed=42)
        loader = DataLoader(dataset, batch_size=4, num_workers=0)
        iterator = iter(loader)
        digits_batch, encoded_batch = next(iterator)
        assert digits_batch.shape == (4, 20)
        assert encoded_batch.shape == (4, 20)
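
    # Note (assumption about namer.data, not from the source): num_workers=0
    # keeps loading in the main process. With num_workers > 0, each DataLoader
    # worker receives its own copy of an IterableDataset, so the seeded stream
    # would repeat across workers unless the dataset shards it per worker,
    # e.g. via torch.utils.data.get_worker_info().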

    def test_reproducibility(self) -> None:
        dataset1 = InfiniteNamerDataset(seed=42)
        dataset2 = InfiniteNamerDataset(seed=42)
        iter1 = iter(dataset1)
        iter2 = iter(dataset2)
        # Two datasets with the same seed should yield identical streams
        for _ in range(5):
            d1, e1 = next(iter1)
            d2, e2 = next(iter2)
            assert torch.equal(d1, d2)
            assert torch.equal(e1, e2)

    def test_vocab_range(self) -> None:
        dataset = InfiniteNamerDataset(seed=42)
        iterator = iter(dataset)
        for _ in range(20):
            _, encoded = next(iterator)
            # Valid tokens should be within vocab range (excluding -1 padding)
            valid_tokens = encoded[encoded != -1]
            assert (valid_tokens >= 0).all()
            assert (valid_tokens < len(VOCABULARY)).all()
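
    def test_digit_range(self) -> None:
        # Hedged sketch, an assumption not confirmed by the source: digit
        # tensors are presumed to hold the digits 0-9 plus the padding value
        # 10 exercised in TestNamerDataset.test_padding_value above.
        dataset = InfiniteNamerDataset(seed=42)
        iterator = iter(dataset)
        for _ in range(20):
            digits, _ = next(iterator)
            assert (digits >= 0).all()
            assert (digits <= 10).all()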