File size: 2,401 Bytes
2730fd2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""Tests for inference utilities."""

from unittest.mock import MagicMock, patch

import pytest
import torch

from namer.inference import predict_number_name
from namer.models import NamerTransformer
from namer.utils import VOCABULARY, read_digits, int_to_digits


class TestPredictNumberName:
    """Tests for predict_number_name function."""

    # Sequence length shared by the mocked model (max_output_len) and the
    # fake logits, so the two cannot drift apart.
    _MAX_OUTPUT_LEN = 20

    @classmethod
    def _make_logits(cls, *token_ids: int) -> torch.Tensor:
        """Build fake logits of shape (1, _MAX_OUTPUT_LEN, len(VOCABULARY)).

        Position ``i`` gets a logit of 10.0 on ``token_ids[i]``; every other
        entry stays 0.0, so an argmax at each listed position selects the
        requested token.
        """
        logits = torch.zeros(1, cls._MAX_OUTPUT_LEN, len(VOCABULARY))
        for position, token_id in enumerate(token_ids):
            logits[0, position, token_id] = 10.0
        return logits

    @pytest.fixture
    def mock_model(self) -> MagicMock:
        """Return a NamerTransformer mock whose parameters live on the CPU."""
        model = MagicMock(spec=NamerTransformer)
        model.max_output_len = self._MAX_OUTPUT_LEN
        model.vocab_size = len(VOCABULARY)

        # Mock the device property. NOTE(review): presumably the device is
        # read via next(model.parameters()); a fresh iterator is returned on
        # every call so a second parameters() call is not already exhausted.
        param = MagicMock()
        param.device = torch.device("cpu")
        model.parameters.side_effect = lambda *args, **kwargs: iter([param])

        return model

    def test_basic_prediction(self, mock_model: MagicMock) -> None:
        # Favour "one" at position 0 and <EOS> at position 1. Indices are
        # looked up rather than hard-coded so vocabulary changes can't
        # silently make this test probe a different token.
        mock_model.return_value = self._make_logits(
            VOCABULARY.index("one"), VOCABULARY.index("<EOS>")
        )

        with patch("namer.inference.torch.no_grad"):
            result = predict_number_name(mock_model, 1)

        # Should decode to "one". NOTE(review): the special-token fallback
        # ("<...") is kept from the original assertion; tighten it once the
        # decoding strategy of predict_number_name is pinned down.
        assert "one" in result.lower() or result.startswith("<")

    def test_eos_stops_generation(self, mock_model: MagicMock) -> None:
        # Logits that predict EOS immediately.
        mock_model.return_value = self._make_logits(VOCABULARY.index("<EOS>"))

        with patch("namer.inference.torch.no_grad"):
            result = predict_number_name(mock_model, 0)

        # Empty result when EOS is first (special-token fallback kept as in
        # the original assertion).
        assert result == "" or result.startswith("<")

    def test_device_override(self, mock_model: MagicMock) -> None:
        mock_model.return_value = self._make_logits(
            VOCABULARY.index("one"), VOCABULARY.index("<EOS>")
        )

        with patch("namer.inference.torch.no_grad"):
            # Should not raise when device is specified.
            result = predict_number_name(mock_model, 1, device="cpu")

        assert isinstance(result, str)