File size: 2,691 Bytes
2730fd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""Tests for dataset classes."""

import torch
from torch.utils.data import DataLoader

from namer.data import InfiniteNamerDataset, NamerDataset
from namer.utils import EOS_IDX, VOCABULARY


class TestNamerDataset:
    """Tests for NamerDataset class."""

    def test_length(self) -> None:
        """len() reflects the num_samples constructor argument."""
        dataset = NamerDataset(num_samples=50, seed=42)
        assert len(dataset) == 50

    def test_sample_shape(self) -> None:
        """Each sample is a pair of long tensors of shape (max_seq_len,)."""
        dataset = NamerDataset(num_samples=10, max_seq_len=20, seed=42)
        digits, encoded = dataset[0]

        assert digits.shape == (20,)
        assert encoded.shape == (20,)
        assert digits.dtype == torch.long
        assert encoded.dtype == torch.long

    def test_padding_value(self) -> None:
        """Digit sequences are padded with the value 10."""
        dataset = NamerDataset(num_samples=10, max_seq_len=20, seed=42)
        digits, _ = dataset[0]

        # The original assertion was
        #   (digits == 10).any() or len([d for d in digits if d != 10]) <= 6
        # but the second clause was dead: "at most 6 of the 20 entries
        # differ from 10" already implies at least 14 entries equal 10,
        # so B implies A and `A or B` reduces to A. Keep the equivalent,
        # clearer single assertion.
        assert (digits == 10).any()

    def test_eos_present(self) -> None:
        """Encoded targets contain the EOS token."""
        dataset = NamerDataset(num_samples=10, seed=42)
        _, encoded = dataset[0]

        # EOS token should be present
        assert EOS_IDX in encoded.tolist()


class TestInfiniteNamerDataset:
    """Tests for InfiniteNamerDataset class."""

    def test_iteration(self) -> None:
        """The dataset yields an unbounded stream of fixed-shape samples."""
        stream = iter(InfiniteNamerDataset(seed=42))

        # Drawing several consecutive samples must succeed.
        for _ in range(10):
            digits, encoded = next(stream)
            assert digits.shape == (20,)
            assert encoded.shape == (20,)

    def test_data_loader(self) -> None:
        """The dataset composes with DataLoader batching."""
        loader = DataLoader(
            InfiniteNamerDataset(seed=42), batch_size=4, num_workers=0
        )

        digits_batch, encoded_batch = next(iter(loader))

        assert digits_batch.shape == (4, 20)
        assert encoded_batch.shape == (4, 20)

    def test_reproducibility(self) -> None:
        """Two datasets built with the same seed emit identical streams."""
        first = iter(InfiniteNamerDataset(seed=42))
        second = iter(InfiniteNamerDataset(seed=42))

        for _ in range(5):
            sample_a = next(first)
            sample_b = next(second)
            assert torch.equal(sample_a[0], sample_b[0])
            assert torch.equal(sample_a[1], sample_b[1])

    def test_vocab_range(self) -> None:
        """Non-padding target tokens fall inside [0, len(VOCABULARY))."""
        stream = iter(InfiniteNamerDataset(seed=42))

        for _ in range(20):
            _, encoded = next(stream)
            # -1 marks padding positions in the targets; ignore those.
            real_tokens = encoded[encoded != -1]
            assert (real_tokens >= 0).all()
            assert (real_tokens < len(VOCABULARY)).all()