Edwin Jose Palathinkal committed on
Commit
2730fd2
·
0 Parent(s):

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ *.egg-info/
8
+ dist/
9
+ build/
10
+ *.egg
11
+
12
+ # Virtual environments
13
+ .venv/
14
+ venv/
15
+ ENV/
16
+ env/
17
+
18
+ # PyTorch
19
+ *.pt
20
+ *.pth
21
+
22
+ # Testing
23
+ .pytest_cache/
24
+ .coverage
25
+ htmlcov/
26
+ .tox/
27
+
28
+ # IDE
29
+ .vscode/
30
+ .idea/
31
+ *.swp
32
+ *.swo
33
+ *~
34
+
35
+ # OS
36
+ .DS_Store
37
+ Thumbs.db
38
+
39
+ # Project specific
40
+ namer_model.pt
41
+ .pip-tmp/
Makefile ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Every target below except $(VENV) names a command, not a file. Declare all
# of them phony so a stray file called e.g. `check` or `typecheck` cannot make
# the target look up to date and silently skip its recipe.
.PHONY: help install dev train infer test test-cov lint lint-fix format \
        typecheck check clean distclean

# Running plain `make` prints the help text.
.DEFAULT_GOAL := help

# Python environment
PYTHON := python3
VENV := .venv
VENV_PYTHON := $(VENV)/bin/python

# Default target: list the available commands
help:
	@echo "Namer - Number to Name Transformer"
	@echo ""
	@echo "Available targets:"
	@echo " make install - Install package in development mode"
	@echo " make dev - Install with dev dependencies"
	@echo " make train - Train the model"
	@echo " make infer - Run interactive inference"
	@echo " make test - Run test suite"
	@echo " make test-cov - Run test suite with coverage report"
	@echo " make lint - Run linting (ruff)"
	@echo " make lint-fix - Fix auto-fixable lint issues (ruff)"
	@echo " make format - Format code (ruff)"
	@echo " make typecheck - Run type checking (mypy)"
	@echo " make check - Run lint, typecheck and test"
	@echo " make clean - Remove generated files and caches"
	@echo " make distclean - Deep clean including venv"
	@echo ""

# Create virtual environment and install
# (real directory target: the recipe only runs while $(VENV) does not exist)
$(VENV):
	@echo "Creating virtual environment..."
	$(PYTHON) -m venv $(VENV)
	$(VENV_PYTHON) -m pip install --upgrade pip

# Install package
install: $(VENV)
	@echo "Installing package..."
	$(VENV_PYTHON) -m pip install -e .

# Install with dev dependencies
dev: $(VENV)
	@echo "Installing with dev dependencies..."
	$(VENV_PYTHON) -m pip install -e ".[dev]"

# Run training
train: $(VENV)
	@echo "Starting training..."
	$(VENV_PYTHON) -m namer train

# Run interactive inference
infer: $(VENV)
	@echo "Starting inference..."
	$(VENV_PYTHON) -m namer infer

# Run tests
test: $(VENV)
	@echo "Running tests..."
	$(VENV_PYTHON) -m pytest -v

# Run tests with coverage
test-cov: $(VENV)
	@echo "Running tests with coverage..."
	$(VENV_PYTHON) -m pytest --cov=namer --cov-report=html --cov-report=term

# Run linting
lint: $(VENV)
	@echo "Running ruff linter..."
	$(VENV_PYTHON) -m ruff check namer tests

# Fix linting issues
lint-fix: $(VENV)
	@echo "Fixing linting issues..."
	$(VENV_PYTHON) -m ruff check --fix namer tests

# Format code
format: $(VENV)
	@echo "Formatting code..."
	$(VENV_PYTHON) -m ruff format namer tests

# Run type checking
typecheck: $(VENV)
	@echo "Running mypy..."
	$(VENV_PYTHON) -m mypy namer

# Run all checks
check: lint typecheck test
	@echo "All checks passed!"

# Clean generated files
# The `|| true` on the directory sweeps is deliberate: find reports errors for
# directories that rm removed underneath it, and that must not fail the target.
clean:
	@echo "Cleaning generated files..."
	rm -f namer_model.pt
	rm -rf htmlcov .pytest_cache .coverage
	find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
	find . -type f -name "*.pyc" -delete
	find . -type f -name "*.pyo" -delete
	find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
	@echo "Clean complete!"

# Deep clean
distclean: clean
	@echo "Removing virtual environment..."
	rm -rf $(VENV)
	@echo "All clean! Run 'make dev' to start fresh."
README.md ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Namer
2
+
3
+ A PyTorch transformer model that converts numbers to their English names.
4
+
5
+ ## Features
6
+
7
+ - **Transformer architecture** with cross-attention mechanism
8
+ - **Infinite dataset** training with early stopping
9
+ - **Modular design** following Python best practices
10
+ - **Type hints** throughout for better IDE support
11
+ - **Comprehensive test suite** with pytest
12
+ - **Modern tooling**: ruff (linting/formatting), mypy (type checking)
13
+
14
+ ## Installation
15
+
16
+ ```bash
17
+ # Clone the repository
18
+ git clone https://github.com/example/namer.git
19
+ cd namer
20
+
21
+ # Create virtual environment
22
+ python -m venv .venv
23
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
24
+
25
+ # Install in development mode
26
+ pip install -e ".[dev]"
27
+ ```
28
+
29
+ ## Usage
30
+
31
+ ### Command Line Interface
32
+
33
+ ```bash
34
+ # Show help
35
+ namer --help
36
+
37
+ # Run demonstrations
38
+ namer demo
39
+
40
+ # Train the model
41
+ namer train
42
+
43
+ # Train with custom settings
44
+ namer train --epochs 50 --steps 2000 --batch-size 64 --lr 0.0005
45
+
46
+ # Run interactive inference
47
+ namer infer
48
+
49
+ # Run quick test
50
+ namer test
51
+ ```
52
+
53
+ ### Python API
54
+
55
+ ```python
56
+ from namer import NamerTransformer, load_namer_model, predict_number_name
57
+
58
+ # Load a trained model
59
+ model = load_namer_model("namer_model.pt")
60
+
61
+ # Predict number names
62
+ name = predict_number_name(model, 123456)
63
+ print(name) # "one hundred twenty three thousand four hundred fifty six"
64
+ ```
65
+
66
+ ## Project Structure
67
+
68
+ ```
69
+ namer/
70
+ ├── namer/ # Main package
71
+ │ ├── __init__.py # Package exports
72
+ │ ├── main.py # CLI entry point
73
+ │ ├── models.py # Transformer model definitions
74
+ │ ├── data.py # Dataset classes
75
+ │ ├── training.py # Training loop
76
+ │ ├── inference.py # Inference utilities
77
+ │ └── utils.py # Number-to-name conversion utilities
78
+ ├── tests/ # Test suite
79
+ │ ├── test_utils.py
80
+ │ ├── test_models.py
81
+ │ ├── test_data.py
82
+ │ └── test_inference.py
83
+ ├── pyproject.toml # Project configuration
84
+ ├── README.md
85
+ └── Makefile # Convenience commands
86
+ ```
87
+
88
+ ## Development
89
+
90
+ ### Running Tests
91
+
92
+ ```bash
93
+ # Run all tests
94
+ pytest
95
+
96
+ # Run with coverage
97
+ pytest --cov=namer --cov-report=html
98
+
99
+ # Run specific test file
100
+ pytest tests/test_utils.py
101
+ ```
102
+
103
+ ### Linting and Formatting
104
+
105
+ ```bash
106
+ # Check code style
107
+ ruff check .
108
+
109
+ # Fix auto-fixable issues
110
+ ruff check --fix .
111
+
112
+ # Format code
113
+ ruff format .
114
+
115
+ # Type checking
116
+ mypy namer
117
+ ```
118
+
119
+ ### Makefile Commands
120
+
121
+ ```bash
122
+ make help # Show available commands
123
+ make install # Install dependencies
124
+ make train # Train the model
125
+ make infer # Run interactive inference
126
+ make test # Run tests
127
+ make clean # Clean generated files
128
+ make distclean # Deep clean including venv
129
+ ```
130
+
131
+ ## Model Architecture
132
+
133
+ The `NamerTransformer` uses an encoder-only architecture:
134
+
135
+ 1. **Digit Embedding** - Embeds digits 0-9 (plus padding token)
136
+ 2. **Positional Encoding** - Sinusoidal positional embeddings
137
+ 3. **Transformer Encoder** - Multi-layer encoder with self-attention
138
+ 4. **Cross-Attention** - Learned output queries attend to encoded digits
139
+ 5. **Output Projection** - Projects to vocabulary for each output position
140
+
141
+ ## Training
142
+
143
+ The model trains on an infinite dataset that generates random number-to-name mappings on-the-fly:
144
+
145
+ - Numbers up to 999,999 (configurable)
146
+ - Early stopping with patience (default: 10 epochs)
147
+ - Cross-entropy loss with -1 padding ignored
148
+ - Adam optimizer with configurable learning rate
149
+
150
+ ## Requirements
151
+
152
+ - Python 3.10+
153
+ - PyTorch 2.0+
154
+ - CUDA-capable GPU (optional, falls back to CPU)
155
+
156
+ ## License
157
+
158
+ MIT License - see LICENSE file for details.
namer/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Namer - A PyTorch transformer model for converting numbers to English names."""
2
+
3
+ __version__ = "0.2.0"
4
+
5
+ from namer.models import NamerTransformer, load_namer_model
6
+ from namer.inference import predict_number_name
7
+ from namer.utils import (
8
+ VOCABULARY,
9
+ encode,
10
+ decode,
11
+ int_to_digits,
12
+ digits_to_int,
13
+ read_digits,
14
+ read_triplet,
15
+ read_double,
16
+ )
17
+
18
+ __all__ = [
19
+ "NamerTransformer",
20
+ "load_namer_model",
21
+ "predict_number_name",
22
+ "VOCABULARY",
23
+ "encode",
24
+ "decode",
25
+ "int_to_digits",
26
+ "digits_to_int",
27
+ "read_digits",
28
+ "read_triplet",
29
+ "read_double",
30
+ ]
namer/__main__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """Make namer package executable."""
2
+
3
+ import sys
4
+
5
+ from namer.main import main
6
+
7
+ if __name__ == "__main__":
8
+ sys.exit(main())
namer/data.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Dataset classes for Namer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import random
6
+
7
+ import torch
8
+ from torch.utils.data import IterableDataset, TensorDataset
9
+
10
+ from namer.utils import EOS_IDX, encode, int_to_digits, read_digits
11
+
12
+
13
class NamerDataset(TensorDataset):
    """Fixed-size dataset of (padded digits, padded encoded name) tensor pairs."""

    def __init__(
        self,
        num_samples: int = 1000,
        max_int: int = 999999,
        max_seq_len: int = 20,
        seed: int = 42,
    ) -> None:
        """Draw ``num_samples`` random integers and store them as two long tensors.

        Args:
            num_samples: How many (digits, name) pairs to generate
            max_int: Inclusive upper bound for the random integers (lower bound is 0)
            max_seq_len: Length every digit row and every target row is padded
                or truncated to
            seed: Seed for the private ``random.Random`` generator
        """
        generator = random.Random(seed)

        digit_rows: list[list[int]] = []
        target_rows: list[list[int]] = []

        for _ in range(num_samples):
            value_digits = int_to_digits(generator.randint(0, max_int))
            # The target is the encoded name followed by the EOS marker.
            tokens = encode(read_digits(value_digits)) + [EOS_IDX]

            # Inputs are right-padded with 10, the digit padding index.
            digit_fill = [10] * (max_seq_len - len(value_digits))
            digit_rows.append((value_digits + digit_fill)[:max_seq_len])

            # Targets are right-padded with -1.
            token_fill = [-1] * (max_seq_len - len(tokens))
            target_rows.append((tokens + token_fill)[:max_seq_len])

        super().__init__(
            torch.tensor(digit_rows, dtype=torch.long),
            torch.tensor(target_rows, dtype=torch.long),
        )
67
+
68
+
69
class InfiniteNamerDataset(IterableDataset):
    """Infinite dataset that generates random number-to-name mappings on-the-fly.

    The dataset is its own iterator: ``__iter__`` returns ``self`` and every
    ``__next__`` call builds one new sample from ``self.rng``. It never raises
    ``StopIteration``, so the consumer decides how many samples to draw.
    """

    def __init__(
        self,
        max_int: int = 999999,
        max_seq_len: int = 20,
        seed: int | None = None,
    ) -> None:
        """Initialize the infinite dataset.

        Args:
            max_int: Maximum random integer value (inclusive; minimum is 0)
            max_seq_len: Length both tensors of a sample are padded/truncated to
            seed: Random seed (optional, for reproducibility). ``0`` is a valid
                seed; only ``None`` means "unseeded".
        """
        self.max_int = max_int
        self.max_seq_len = max_seq_len
        self.seed = seed
        self.rng = random.Random(seed)

    def _generate_sample(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate a single (digits, encoded_name) sample as two long tensors."""
        n = self.rng.randint(0, self.max_int)
        digits = int_to_digits(n)
        name = read_digits(digits)
        encoded = encode(name)

        # Pad digits with 10 (padding index)
        digits_padded = digits + [10] * (self.max_seq_len - len(digits))
        digits_padded = digits_padded[: self.max_seq_len]

        # Append EOS and pad with -1
        encoded_with_eos = encoded + [EOS_IDX]
        encoded_padded = encoded_with_eos + [-1] * (self.max_seq_len - len(encoded_with_eos))
        encoded_padded = encoded_padded[: self.max_seq_len]

        return (
            torch.tensor(digits_padded, dtype=torch.long),
            torch.tensor(encoded_padded, dtype=torch.long),
        )

    def __iter__(self) -> InfiniteNamerDataset:
        """Return the dataset itself as the sample iterator.

        Single-process loading keeps drawing from the generator created in
        ``__init__``. Re-seeding here with the same seed (as was done before)
        makes every new iteration over the dataset replay the exact same
        samples, so each epoch would train on identical batches and a
        validation loader sharing this dataset would replay the training data.
        The stream stays reproducible for a given seed because the order of
        draws is deterministic.

        In a DataLoader worker process the generator is replaced by one seeded
        from the base seed plus ``worker_id * 1000`` so that workers do not
        all produce the same stream.
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            # `is not None` rather than truthiness so that seed=0 is honoured.
            base_seed = self.seed if self.seed is not None else random.randint(0, 2**32)
            self.rng = random.Random(base_seed + worker_info.id * 1000)

        return self

    def __next__(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Generate the next sample."""
        return self._generate_sample()
namer/inference.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Inference utilities for Namer models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+
7
+ from namer.models import NamerTransformer, load_namer_model
8
+ from namer.utils import EOS_IDX, decode, int_to_digits
9
+
10
+
11
def predict_number_name(
    model: NamerTransformer,
    n: int,
    device: str | torch.device | None = None,
) -> str:
    """Run the model on one integer and turn its output into an English name.

    The model is put in eval mode, the digits of ``n`` are right-padded with
    the padding index 10 up to ``model.max_output_len``, and the arg-max token
    of every output position is taken. Tokens from the first ``EOS_IDX``
    onwards are discarded.

    Args:
        model: Trained model
        n: Integer to convert to name
        device: Device to run inference on (taken from the model's first
            parameter if None)

    Returns:
        The decoded name. If ``decode`` raises ``ValueError``, progressively
        shorter prefixes are tried; if none decodes, a
        ``"<decode error: ...>"`` string listing the token indices is returned.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()

    with torch.no_grad():
        digit_list = int_to_digits(n)
        pad_count = model.max_output_len - len(digit_list)
        batch = torch.tensor([digit_list + [10] * pad_count], dtype=torch.long).to(device)
        token_ids = model(batch).argmax(dim=-1)[0].cpu().tolist()

    # Keep only what precedes the first EOS (or everything if there is none).
    cut = token_ids.index(EOS_IDX) if EOS_IDX in token_ids else len(token_ids)
    kept: list[int] = token_ids[:cut]

    # First attempt uses the full sequence; the fallbacks then walk from the
    # full length down to a single token.
    attempt_lengths = [len(kept), *range(len(kept), 0, -1)]
    for size in attempt_lengths:
        try:
            return decode(kept[:size])
        except ValueError:
            continue

    return f"<decode error: {kept}>"
59
+
60
+
61
def interactive_inference(model_path: str = "namer_model.pt") -> None:
    """Run interactive inference session.

    Loads the model, then reads integers from standard input in a loop and
    prints the predicted name for each. The loop ends on ``quit``/``exit``/``q``,
    on Ctrl-C, or when standard input is closed (Ctrl-D or end of piped input).

    Args:
        model_path: Path to the saved model file

    Raises:
        SystemExit: With status 1 if ``model_path`` does not exist.
    """
    import sys

    print("Loading model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    try:
        model = load_namer_model(model_path, device)
        print("Model loaded successfully!\n")
    except FileNotFoundError:
        print(f"Error: Model file '{model_path}' not found.")
        print("Please run training first: python -m namer train")
        sys.exit(1)

    print("Enter a number to convert (or 'quit' to exit):")
    while True:
        try:
            user_input = input("> ").strip()

            if user_input.lower() in ("quit", "exit", "q"):
                break

            n = int(user_input)
            name = predict_number_name(model, n, device)
            print(f" {n} -> '{name}'\n")

        except ValueError:
            print(" Please enter a valid integer\n")
        except (KeyboardInterrupt, EOFError):
            # EOFError: input() hit end of stdin; leave cleanly instead of
            # crashing with a traceback.
            print("\nGoodbye!")
            break
namer/main.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main entry point for namer CLI."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+
8
+ import torch
9
+
10
+ from namer.data import InfiniteNamerDataset
11
+ from namer.inference import interactive_inference, predict_number_name
12
+ from namer.models import NamerTransformer, load_namer_model
13
+ from namer.training import save_model, train_namer_model
14
+ from namer.utils import VOCABULARY, encode, int_to_digits, read_digits
15
+
16
+
17
def demo_command(args: argparse.Namespace) -> None:
    """Print a walkthrough of the number-naming helpers.

    Shows ``read_double``, ``read_triplet``, the vocabulary, ``encode``, an
    encode/decode round-trip and ``int_to_digits`` on fixed example inputs.

    Args:
        args: Parsed CLI namespace (not used by this command)
    """
    # These helpers are not imported at module level; import them once here.
    from namer.utils import decode, read_double, read_triplet

    print("--- Number Names Demo ---")

    print("\nread_double (two digits):")
    for tens, ones in [(0, 7), (1, 1), (2, 3), (3, 0), (0, 0), (5, 9)]:
        print(f" read_double({tens}, {ones}) = '{read_double(tens, ones)}'")

    print("\nread_triplet (three digits):")
    for h, t, o in [(1, 0, 6), (0, 0, 0), (9, 1, 9), (2, 0, 0), (0, 5, 5), (4, 2, 0)]:
        print(f" read_triplet({h}, {t}, {o}) = '{read_triplet(h, t, o)}'")

    print(f"\nVOCABULARY ({len(VOCABULARY)} words):")
    print(f" {VOCABULARY}")

    print("\nencode (text to vocabulary indices):")
    phrases = (
        "one million",
        "twenty three",
        "one hundred twenty three",
        "nine hundred nineteen",
        "zero",
    )
    for phrase in phrases:
        print(f" encode('{phrase}') = {encode(phrase)}")

    print("\nencode/decode round-trip:")
    for phrase in ("one million", "twenty three", "zero"):
        indices = encode(phrase)
        restored = decode(indices)
        print(f" '{phrase}' -> {indices} -> '{restored}'")

    print("\nint_to_digits (integer to digit list):")
    for value in (0, 7, 123, -456, 1002003, 9876543210):
        print(f" int_to_digits({value}) = {int_to_digits(value)}")
60
+
61
+
62
def train_command(
    num_epochs: int = 30,
    steps_per_epoch: int = 1000,
    batch_size: int = 128,
    learning_rate: float = 0.001,
) -> None:
    """Build the dataset and model, train, save, and print sample predictions.

    Args:
        num_epochs: Number of training epochs
        steps_per_epoch: Number of steps per epoch
        batch_size: Batch size for training
        learning_rate: Learning rate for optimizer
    """
    # Report which device is available; the device itself is chosen by
    # train_namer_model's own default.
    if torch.cuda.is_available():
        gpu = torch.device("cuda")
        print(f"Using GPU: {torch.cuda.get_device_name(gpu)}")
    else:
        print("Warning: CUDA not available, using CPU")

    # Endless stream of training samples
    stream = InfiniteNamerDataset(max_int=999999, max_seq_len=20, seed=42)

    network = NamerTransformer(
        vocab_size=len(VOCABULARY),
        max_output_len=20,
        d_model=128,
        nhead=4,
        num_encoder_layers=4,
        dim_feedforward=512,
        dropout=0.1,
    )

    param_count = sum(p.numel() for p in network.parameters())
    print(f"\nTransformer Model parameters: {param_count:,}")

    fitted = train_namer_model(
        model=network,
        infinite_dataset=stream,
        num_epochs=num_epochs,
        steps_per_epoch=steps_per_epoch,
        val_steps=100,
        batch_size=batch_size,
        learning_rate=learning_rate,
    )

    save_model(fitted)

    # Spot-check the trained model against the rule-based reference names.
    print("\n--- Model Predictions ---")
    fitted.eval()
    model_device = next(fitted.parameters()).device

    with torch.no_grad():
        for number in (123, 4567, 89012, 555555, 999999, 42, 0, 1000):
            pred = predict_number_name(fitted, number, model_device)
            actual = read_digits(int_to_digits(number))
            match = "✓" if pred == actual else "✗"
            print(f" {number}: pred='{pred}', actual='{actual}' {match}")
129
+
130
+
131
def test_command() -> None:
    """Load ``namer_model.pt`` and compare a few predictions with the reference names.

    Exits with status 1 (after printing a hint) when the model file is missing.
    """
    try:
        model = load_namer_model("namer_model.pt")
    except FileNotFoundError:
        print("Error: Model file 'namer_model.pt' not found.")
        print("Please train the model first: python -m namer train")
        sys.exit(1)

    print("Running inference on loaded model:")
    for number in (42, 123, 1000, 999999):
        pred = predict_number_name(model, number)
        actual = read_digits(int_to_digits(number))
        if pred == actual:
            match = "✓"
        else:
            match = "✗"
        print(f" {number} -> '{pred}' (actual: '{actual}') {match}")
147
+
148
+
149
def main(argv: list[str] | None = None) -> int:
    """Parse the command line and dispatch to the selected sub-command.

    Args:
        argv: Command line arguments (argparse falls back to sys.argv when None)

    Returns:
        Exit code; always 0 when the function returns normally. With no
        sub-command given, the help text is printed.
    """
    parser = argparse.ArgumentParser(
        prog="namer",
        description="A PyTorch transformer model for converting numbers to their English names.",
    )
    commands = parser.add_subparsers(dest="command", help="Available commands")

    # demo
    demo = commands.add_parser("demo", help="Run number name demonstrations")
    demo.set_defaults(func=demo_command)

    # train
    train = commands.add_parser("train", help="Train the model")
    train_options = (
        ("--epochs", int, 30, "Number of training epochs (default: 30)"),
        ("--steps", int, 1000, "Steps per epoch (default: 1000)"),
        ("--batch-size", int, 128, "Batch size (default: 128)"),
        ("--lr", float, 0.001, "Learning rate (default: 0.001)"),
    )
    for flag, value_type, default, description in train_options:
        train.add_argument(flag, type=value_type, default=default, help=description)

    def run_training(ns: argparse.Namespace) -> None:
        # Map CLI option names onto train_command's keyword arguments.
        return train_command(
            num_epochs=ns.epochs,
            steps_per_epoch=ns.steps,
            batch_size=ns.batch_size,
            learning_rate=ns.lr,
        )

    train.set_defaults(func=run_training)

    # infer
    infer = commands.add_parser("infer", help="Run interactive inference")
    infer.set_defaults(func=lambda args: interactive_inference())

    # test
    quick_test = commands.add_parser("test", help="Run quick inference test")
    quick_test.set_defaults(func=lambda args: test_command())

    args = parser.parse_args(argv)

    if args.command is not None:
        args.func(args)
    else:
        parser.print_help()
    return 0
208
+
209
+
210
+ if __name__ == "__main__":
211
+ sys.exit(main())
namer/models.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Model definitions for Namer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+
9
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for transformer.

    A ``(max_len, d_model)`` table is built once and stored as the
    non-trainable buffer ``pe``: even columns hold sines, odd columns hold
    cosines.
    """

    def __init__(self, d_model: int, max_len: int = 5000) -> None:
        """Precompute the encoding table.

        Args:
            d_model: Width of the embeddings the encoding is added to
                (even or odd)
            max_len: Number of positions to precompute
        """
        super().__init__()

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one cosine column fewer than sine columns,
        # so trim div_term to match; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encoding to input.

        Args:
            x: (batch_size, seq_len, d_model)

        Returns:
            Tensor with positional encoding added
        """
        # The first seq_len rows of the table broadcast over the batch dimension.
        return x + self.pe[: x.size(1)]
37
+
38
+
39
class NamerTransformer(nn.Module):
    """Maps a padded digit sequence to per-position logits over the name vocabulary.

    Pipeline:
    - digit embedding (indices 0-9 plus padding index 10)
    - sinusoidal positional encoding
    - stack of transformer encoder layers
    - one learned query per output position, cross-attending to the encoder output
    - layer norm followed by a linear projection to the vocabulary
    """

    def __init__(
        self,
        vocab_size: int = 40,
        max_output_len: int = 20,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        dim_feedforward: int = 512,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.vocab_size = vocab_size
        self.max_output_len = max_output_len
        self.d_model = d_model

        # NOTE: submodules are created in a fixed order; changing it would
        # change seeded initialisation and the state_dict key order.

        # 11 embedding rows: digits 0-9 and the padding index 10.
        self.digit_embedding = nn.Embedding(11, d_model, padding_idx=10)

        self.pos_encoder = PositionalEncoding(d_model, max_len=100)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_encoder_layers,
        )

        # Maps each attended output vector to vocabulary logits.
        self.output_projection = nn.Linear(d_model, vocab_size)

        # One trainable query vector per output position.
        self.output_queries = nn.Parameter(torch.randn(max_output_len, d_model))

        # Lets the output queries read from the encoded digits.
        self.cross_attention = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=True
        )

        self.output_norm = nn.LayerNorm(d_model)

    def forward(self, digits: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every output position.

        Args:
            digits: (batch_size, seq_len) tensor of digit indices (0-9);
                padding may be given as 10 or as -1

        Returns:
            (batch_size, max_output_len, vocab_size) logits
        """
        batch_size, _ = digits.shape

        # Normalise -1 padding to the embedding's padding index without
        # touching the caller's tensor.
        tokens = digits.masked_fill(digits == -1, 10)

        # True marks padded positions, which attention must ignore.
        pad_mask = tokens == 10

        # (batch, seq_len, d_model)
        encoded_input = self.pos_encoder(self.digit_embedding(tokens))
        memory = self.transformer_encoder(encoded_input, src_key_padding_mask=pad_mask)

        # (batch, max_output_len, d_model)
        queries = self.output_queries[None, :, :].expand(batch_size, -1, -1)
        attended, _ = self.cross_attention(
            queries, memory, memory, key_padding_mask=pad_mask
        )

        return self.output_projection(self.output_norm(attended))
138
+
139
+
140
def load_namer_model(
    model_path: str = "namer_model.pt",
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> NamerTransformer:
    """Load a trained Namer model for inference.

    The checkpoint must contain ``vocab_size``, ``max_output_len`` and
    ``model_state_dict``. The remaining architecture settings are read from
    the checkpoint when present and otherwise fall back to the values this
    function previously hard-coded, so older checkpoints keep loading.

    Args:
        model_path: Path to the saved model file
        device: Device to load the model on

    Returns:
        Loaded model in eval mode
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(model_path, map_location=device)

    model = NamerTransformer(
        vocab_size=checkpoint["vocab_size"],
        max_output_len=checkpoint["max_output_len"],
        d_model=checkpoint.get("d_model", 128),
        nhead=checkpoint.get("nhead", 4),
        num_encoder_layers=checkpoint.get("num_encoder_layers", 4),
        dim_feedforward=checkpoint.get("dim_feedforward", 512),
        dropout=0.0,  # No dropout for inference
    )

    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    return model
namer/training.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Training utilities for Namer models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.optim as optim
8
+ from torch.utils.data import DataLoader, TensorDataset
9
+
10
+ from namer.models import NamerTransformer
11
+ from namer.data import InfiniteNamerDataset
12
+
13
+
14
def _run_epoch(
    model: NamerTransformer,
    loader: DataLoader,
    num_steps: int,
    criterion: nn.Module,
    device: str | torch.device,
    optimizer: optim.Optimizer | None = None,
) -> tuple[float, float]:
    """Consume ``num_steps`` batches and return ``(mean loss, token accuracy)``.

    Weights are updated only when ``optimizer`` is given. The caller is
    responsible for ``model.train()`` / ``model.eval()`` and for wrapping
    evaluation passes in ``torch.no_grad()``.
    """
    loss_sum = 0.0
    correct = 0
    counted = 0

    batches = iter(loader)
    for _ in range(num_steps):
        digits_batch, target_batch = next(batches)
        digits_batch = digits_batch.to(device)
        target_batch = target_batch.to(device)

        if optimizer is not None:
            optimizer.zero_grad()

        logits = model(digits_batch)
        loss = criterion(logits.view(-1, model.vocab_size), target_batch.view(-1))

        if optimizer is not None:
            loss.backward()
            optimizer.step()

        loss_sum += loss.item()

        # Targets equal to -1 are excluded from the accuracy, mirroring the
        # ignore_index=-1 used by the loss.
        mask = target_batch != -1
        predictions = logits.argmax(dim=-1)
        correct += ((predictions == target_batch) & mask).sum().item()
        counted += mask.sum().item()

    mean_loss = loss_sum / num_steps
    accuracy = correct / counted if counted > 0 else 0
    return mean_loss, accuracy


def train_namer_model(
    model: NamerTransformer,
    dataset: TensorDataset | None = None,
    infinite_dataset: InfiniteNamerDataset | None = None,
    num_epochs: int = 50,
    steps_per_epoch: int = 1000,
    val_steps: int = 100,
    batch_size: int = 64,
    learning_rate: float = 0.001,
    patience: int = 10,
    device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
) -> NamerTransformer:
    """Train the model on a finite dataset or infinite iterable dataset.

    Training uses Adam and cross-entropy (targets equal to -1 are ignored),
    stops early after ``patience`` epochs without a better validation loss,
    and finally restores the weights of the best epoch.

    Args:
        model: The model to train. Must expose ``vocab_size``.
        dataset: Finite TensorDataset with (digits, encoded_names) pairs. It
            is split 90/10 into train/validation with a fixed seed (42).
        infinite_dataset: Infinite IterableDataset for infinite training.
            Takes precedence over ``dataset`` when both are given. Note that
            validation batches are drawn from this same dataset.
        num_epochs: Number of training epochs
        steps_per_epoch: Number of steps per epoch (for infinite dataset;
            overwritten by the loader length for a finite dataset)
        val_steps: Number of validation steps per epoch (likewise overwritten
            for a finite dataset)
        batch_size: Batch size for training
        learning_rate: Learning rate for optimizer
        patience: Early stopping patience
        device: Device to train on ('cuda' or 'cpu')

    Returns:
        Trained model

    Raises:
        ValueError: If neither dataset is provided, or if the resulting number
            of training or validation steps is not positive.
    """
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=-1)

    print(f"Training on {device}")
    print(f"Early stopping patience: {patience} epochs")

    # Setup data loaders
    if infinite_dataset is not None:
        print(f"Using INFINITE dataset (max_int={infinite_dataset.max_int})")
        print(f"Steps per epoch: {steps_per_epoch}, Val steps: {val_steps}")

        train_loader = DataLoader(
            infinite_dataset,
            batch_size=batch_size,
            num_workers=0,
        )
        val_loader = DataLoader(
            infinite_dataset,
            batch_size=batch_size,
            num_workers=0,
        )
    else:
        if dataset is None:
            raise ValueError("Either dataset or infinite_dataset must be provided")

        train_size = int(0.9 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
        )
        train_loader = DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        steps_per_epoch = len(train_loader)
        val_steps = len(val_loader)
        print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Fail with a clear message instead of a ZeroDivisionError when averaging
    # the loss over zero steps (e.g. a dataset too small to split).
    if steps_per_epoch <= 0 or val_steps <= 0:
        raise ValueError(
            "steps_per_epoch and val_steps must be positive "
            f"(got {steps_per_epoch} and {val_steps})"
        )

    best_val_loss = float("inf")
    epochs_without_improvement = 0
    best_model_state: dict | None = None

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss, train_acc = _run_epoch(
            model, train_loader, steps_per_epoch, criterion, device, optimizer
        )

        # Validation
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = _run_epoch(
                model, val_loader, val_steps, criterion, device
            )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            # state_dict() hands back references to the live tensors, so each
            # one must be cloned; a shallow dict copy would keep tracking the
            # weights as later epochs modify them in place.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
        else:
            epochs_without_improvement += 1

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(
                f"Epoch {epoch+1}/{num_epochs}: "
                f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
                f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, "
                f"patience={epochs_without_improvement}/{patience}"
            )

        if epochs_without_improvement >= patience:
            print(f"\nEarly stopping triggered! No improvement for {patience} epochs.")
            break

    print(f"\nBest validation loss: {best_val_loss:.4f}")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("Restored best model from checkpoint.")

    return model
174
+
175
+
176
def save_model(model: NamerTransformer, model_path: str = "namer_model.pt") -> None:
    """Write a checkpoint for ``model`` to ``model_path``.

    The checkpoint is a dict holding the model's ``state_dict`` together with
    the ``vocab_size``, ``max_output_len`` and ``d_model`` attributes read
    from the model, plus a fixed ``"model_type"`` tag of ``"transformer"``.

    Args:
        model: The model to save
        model_path: Path where to save the model
    """
    payload = dict(
        model_type="transformer",
        model_state_dict=model.state_dict(),
        vocab_size=model.vocab_size,
        max_output_len=model.max_output_len,
        d_model=model.d_model,
    )
    torch.save(payload, model_path)
    print(f"Model saved to {model_path}")
namer/utils.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Utility functions for number-to-name conversion."""
2
+
3
+ from __future__ import annotations
4
+
5
# Word tables used to spell out numbers.
ONES: tuple[str, ...] = (
    "zero", "one", "two", "three", "four",
    "five", "six", "seven", "eight", "nine",
)

TEENS: tuple[str, ...] = (
    "ten", "eleven", "twelve", "thirteen", "fourteen",
    "fifteen", "sixteen", "seventeen", "eighteen", "nineteen",
)

# Indexed by the tens digit; slots 0 and 1 are empty placeholders.
TENS: tuple[str, ...] = (
    "", "", "twenty", "thirty", "forty",
    "fifty", "sixty", "seventy", "eighty", "ninety",
)

# Scale word for each power of 1000; index 0 (the units group) has no word.
SCALES: tuple[str, ...] = (
    "", "thousand", "million", "billion", "trillion",
    "quadrillion", "quintillion", "sextillion", "septillion",
    "octillion", "nonillion", "decillion",
)

# Every word the encoder can emit, in index order: ones, teens, non-empty
# tens, "hundred", non-empty scales, then the end-of-sequence marker.
VOCABULARY: list[str] = [
    *ONES,
    *TEENS,
    *(tens_word for tens_word in TENS if tens_word),
    "hundred",
    *(scale_word for scale_word in SCALES if scale_word),
    "<EOS>",
]

# Reverse lookup from word to its position in VOCABULARY.
WORD_TO_INDEX: dict[str, int] = dict(zip(VOCABULARY, range(len(VOCABULARY))))

# Index of the end-of-sequence token.
EOS_IDX: int = VOCABULARY.index("<EOS>")
42
+
43
+
44
def int_to_digits(n: int) -> list[int]:
    """Split an integer into its decimal digits, most significant first.

    The sign is discarded, so a negative number yields the digits of its
    absolute value. Only integer arithmetic is used, so the size of ``n`` is
    not limited.

    Args:
        n: An integer (positive, negative, or zero)

    Returns:
        List of digits (0-9); ``[0]`` when ``n`` is zero.

    Example:
        >>> int_to_digits(123)
        [1, 2, 3]
        >>> int_to_digits(0)
        [0]
        >>> int_to_digits(-456)
        [4, 5, 6]
    """
    remaining = abs(n)
    if remaining == 0:
        return [0]

    # Peel digits off the low end, then flip to most-significant-first.
    collected: list[int] = []
    while remaining > 0:
        remaining, lowest = divmod(remaining, 10)
        collected.append(lowest)

    collected.reverse()
    return collected
73
+
74
+
75
def digits_to_int(digits: list[int]) -> int:
    """Assemble decimal digits (most significant first) into an integer.

    Inverse of int_to_digits() for non-negative values.

    Args:
        digits: List of digits (0-9); an empty list is treated as zero

    Returns:
        The integer value represented by the digits

    Raises:
        ValueError: If any digit is not 0-9

    Example:
        >>> digits_to_int([1, 2, 3])
        123
        >>> digits_to_int([0])
        0
    """
    total = 0
    for digit in digits or ():
        if digit < 0 or digit > 9:
            raise ValueError(f"Invalid digit {digit}, must be 0-9")
        # Shift what has been read one decimal place left and add the digit.
        total = 10 * total + digit

    return total
105
+
106
+
107
def encode(text: str) -> list[int]:
    """Encode a string of number words into a list of vocabulary indices.

    Words are matched case-insensitively and may be separated by any amount
    of whitespace.

    Args:
        text: String containing space-separated number words (e.g., "one million")

    Returns:
        List of indices corresponding to each word in VOCABULARY; an empty
        list when ``text`` is empty or only whitespace.

    Raises:
        ValueError: If a word is not found in VOCABULARY

    Example:
        >>> encode("one million")
        [1, 30]
    """
    if not text:
        return []

    # split() with no argument already discards leading/trailing whitespace
    # and returns [] for whitespace-only input.
    indices: list[int] = []
    for word in text.lower().split():
        if word not in WORD_TO_INDEX:
            raise ValueError(f"Unknown word '{word}' not in VOCABULARY")
        indices.append(WORD_TO_INDEX[word])

    return indices
135
+
136
+
137
def decode(indices: list[int]) -> str:
    """Turn vocabulary indices back into a space-separated string of words.

    Inverse of encode(). Every index is validated; "<EOS>" entries are
    dropped from the output rather than terminating it.

    Args:
        indices: List of indices into VOCABULARY (e.g., [1, 30])

    Returns:
        String of space-separated number words (e.g., "one million"); the
        empty string for an empty list.

    Raises:
        ValueError: If an index is out of range

    Example:
        >>> decode([1, 30])
        'one million'
    """
    if not indices:
        return ""

    kept: list[str] = []
    for position in indices:
        if position < 0 or position >= len(VOCABULARY):
            raise ValueError(f"Index {position} out of range for VOCABULARY (size {len(VOCABULARY)})")
        token = VOCABULARY[position]
        if token == "<EOS>":
            continue
        kept.append(token)

    return " ".join(kept)
167
+
168
+
169
def read_double(a: int, b: int) -> str:
    """Name the number formed by a tens digit and a ones digit.

    Args:
        a: Tens digit (0-9)
        b: Ones digit (0-9)

    Returns:
        English name of the number (e.g., "twenty three", "eleven", "seven");
        "zero" when both digits are 0.

    Raises:
        ValueError: If either digit is outside 0-9
    """
    if a < 0 or a > 9 or b < 0 or b > 9:
        raise ValueError("Digits must be between 0 and 9")

    # Tens digit 0 -> 0..9, tens digit 1 -> 10..19; both have dedicated words.
    if a == 0:
        return ONES[b]
    if a == 1:
        return TEENS[b]

    tens_word = TENS[a]
    return tens_word if b == 0 else f"{tens_word} {ONES[b]}"
192
+
193
+
194
def read_triplet(a: int, b: int, c: int) -> str:
    """Name the number formed by a hundreds, a tens and a ones digit.

    Args:
        a: Hundreds digit (0-9)
        b: Tens digit (0-9)
        c: Ones digit (0-9)

    Returns:
        English name of the number (e.g., "one hundred six", "zero", "nine hundred nineteen")

    Raises:
        ValueError: If any digit is outside 0-9
    """
    if any(digit < 0 or digit > 9 for digit in (a, b, c)):
        raise ValueError("Digits must be between 0 and 9")

    # Without a hundreds digit the result is just the two-digit name.
    if a == 0:
        return read_double(b, c)

    hundreds = f"{ONES[a]} hundred"
    # An exact multiple of one hundred gets no trailing "zero".
    if b == 0 and c == 0:
        return hundreds
    return f"{hundreds} {read_double(b, c)}"
217
+
218
+
219
def read_digits(lst: list[int]) -> str:
    """Convert a list of digits into the English name of the number they form.

    Groups digits into triplets and combines with scale words (thousand, million, etc.)
    Groups that are entirely zero contribute nothing to the name.

    Args:
        lst: List of digits (0-9), most significant first. Leading zeros are
            allowed.

    Returns:
        English name of the number; "zero" for an empty list or all zeros.

    Raises:
        ValueError: If an element is not a digit 0-9, or if a non-zero group
            would need a scale word beyond the last entry of SCALES (such a
            number was previously named with the scale word silently
            omitted, yielding the name of a different number).
    """
    if not lst:
        return "zero"

    for d in lst:
        if not (0 <= d <= 9):
            raise ValueError("All elements must be digits between 0 and 9")

    if all(d == 0 for d in lst):
        return "zero"

    # Left-pad with zeros so the length is a multiple of 3.
    padded = [0] * (-len(lst) % 3) + list(lst)
    num_triplets = len(padded) // 3

    parts: list[str] = []
    for i in range(num_triplets):
        a, b, c = padded[3 * i:3 * i + 3]
        if a == 0 and b == 0 and c == 0:
            continue

        # The last triplet is the units group (scale index 0).
        scale_index = num_triplets - 1 - i
        if scale_index >= len(SCALES):
            raise ValueError(
                "Number is too large to name: only values below "
                f"10^{3 * len(SCALES)} are supported"
            )

        triplet_name = read_triplet(a, b, c)
        scale = SCALES[scale_index]
        parts.append(f"{triplet_name} {scale}" if scale else triplet_name)

    return " ".join(parts)
pyproject.toml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["hatchling"]
3
+ build-backend = "hatchling.build"
4
+
5
+ [project]
6
+ name = "namer"
7
+ version = "0.2.0"
8
+ description = "A PyTorch transformer model for converting numbers to English names"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ requires-python = ">=3.10"
12
+ authors = [
13
+ {name = "Developer", email = "dev@example.com"}
14
+ ]
15
+ keywords = ["pytorch", "transformer", "nlp", "numbers", "ml"]
16
+ classifiers = [
17
+ "Development Status :: 3 - Alpha",
18
+ "Intended Audience :: Developers",
19
+ "License :: OSI Approved :: MIT License",
20
+ "Programming Language :: Python :: 3",
21
+ "Programming Language :: Python :: 3.10",
22
+ "Programming Language :: Python :: 3.11",
23
+ "Programming Language :: Python :: 3.12",
24
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
25
+ ]
26
+
27
+ dependencies = [
28
+ "torch>=2.0.0",
29
+ ]
30
+
31
+ [project.optional-dependencies]
32
+ dev = [
33
+ "pytest>=7.0",
34
+ "pytest-cov>=4.0",
35
+ "ruff>=0.1.0",
36
+ "mypy>=1.0",
37
+ ]
38
+
39
+ [project.scripts]
40
+ namer = "namer.main:main"
41
+
42
+ [project.urls]
43
+ Homepage = "https://github.com/example/namer"
44
+ Repository = "https://github.com/example/namer"
45
+ Issues = "https://github.com/example/namer/issues"
46
+
47
+ [tool.ruff]
48
+ target-version = "py310"
49
+ line-length = 88
50
+
51
+ [tool.ruff.lint]
52
+ select = [
53
+ "E", # pycodestyle errors
54
+ "W", # pycodestyle warnings
55
+ "F", # Pyflakes
56
+ "I", # isort
57
+ "N", # pep8-naming
58
+ "D", # pydocstyle
59
+ "UP", # pyupgrade
60
+ "B", # flake8-bugbear
61
+ "C4", # flake8-comprehensions
62
+ "SIM", # flake8-simplify
63
+ ]
64
+ ignore = ["D100", "D104"] # Missing docstrings in public packages/modules
65
+
66
+ [tool.ruff.lint.pydocstyle]
67
+ convention = "google"
68
+
69
+ [tool.mypy]
70
+ python_version = "3.10"
71
+ warn_return_any = true
72
+ warn_unused_configs = true
73
+ disallow_untyped_defs = true
74
+ disallow_incomplete_defs = true
75
+ check_untyped_defs = true
76
+ warn_redundant_casts = true
77
+ warn_unused_ignores = true
78
+ show_error_codes = true
79
+
80
+ [[tool.mypy.overrides]]
81
+ module = ["torch.*"]
82
+ ignore_missing_imports = true
83
+
84
+ [tool.pytest.ini_options]
85
+ testpaths = ["tests"]
86
+ python_files = ["test_*.py"]
87
+ python_functions = ["test_*"]
88
+ addopts = "-v --tb=short"
89
+
90
+ [tool.coverage.run]
91
+ source = ["namer"]
92
+ omit = ["*/tests/*"]
93
+
94
+ [tool.coverage.report]
95
+ exclude_lines = [
96
+ "pragma: no cover",
97
+ "def __repr__",
98
+ "raise AssertionError",
99
+ "raise NotImplementedError",
100
+ ]
tests/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Tests for Namer package."""
tests/test_data.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for dataset classes."""
2
+
3
+ import torch
4
+ from torch.utils.data import DataLoader
5
+
6
+ from namer.data import InfiniteNamerDataset, NamerDataset
7
+ from namer.utils import EOS_IDX, VOCABULARY
8
+
9
+
10
class TestNamerDataset:
    """Checks for the finite, indexable NamerDataset."""

    def test_length(self) -> None:
        """len() equals the num_samples given to the constructor."""
        assert len(NamerDataset(num_samples=50, seed=42)) == 50

    def test_sample_shape(self) -> None:
        """Both tensors of a sample are long tensors of length max_seq_len."""
        source, target = NamerDataset(num_samples=10, max_seq_len=20, seed=42)[0]

        assert source.shape == (20,)
        assert target.shape == (20,)
        assert source.dtype == torch.long
        assert target.dtype == torch.long

    def test_padding_value(self) -> None:
        """The digit tensor contains the value 10, or has at most 6 other entries."""
        source, _ = NamerDataset(num_samples=10, max_seq_len=20, seed=42)[0]

        # Padding should be 10
        assert (source == 10).any() or len([d for d in source if d != 10]) <= 6

    def test_eos_present(self) -> None:
        """The encoded target of a sample includes the EOS index."""
        _, target = NamerDataset(num_samples=10, seed=42)[0]

        # EOS token should be present
        assert EOS_IDX in target.tolist()
39
+
40
+
41
class TestInfiniteNamerDataset:
    """Checks for the iterable InfiniteNamerDataset."""

    def test_iteration(self) -> None:
        """Ten consecutive samples can be drawn, each a pair of length-20 tensors."""
        stream = iter(InfiniteNamerDataset(seed=42))

        # Can get multiple samples
        for _ in range(10):
            source, target = next(stream)
            assert source.shape == (20,)
            assert target.shape == (20,)

    def test_data_loader(self) -> None:
        """A DataLoader batches the samples into (batch_size, 20) tensors."""
        loader = DataLoader(InfiniteNamerDataset(seed=42), batch_size=4, num_workers=0)

        source_batch, target_batch = next(iter(loader))

        assert source_batch.shape == (4, 20)
        assert target_batch.shape == (4, 20)

    def test_reproducibility(self) -> None:
        """Two datasets built with the same seed yield identical samples."""
        first_ds, second_ds = InfiniteNamerDataset(seed=42), InfiniteNamerDataset(seed=42)
        first_it, second_it = iter(first_ds), iter(second_ds)

        for _ in range(5):
            src_a, tgt_a = next(first_it)
            src_b, tgt_b = next(second_it)
            assert torch.equal(src_a, src_b)
            assert torch.equal(tgt_a, tgt_b)

    def test_vocab_range(self) -> None:
        """Every non -1 target token is a valid VOCABULARY index."""
        stream = iter(InfiniteNamerDataset(seed=42))
        vocab_size = len(VOCABULARY)

        for _ in range(20):
            _, target = next(stream)
            # Valid tokens should be within vocab range (excluding -1 padding)
            real_tokens = target[target != -1]
            assert (real_tokens >= 0).all()
            assert (real_tokens < vocab_size).all()
tests/test_inference.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for inference utilities."""
2
+
3
+ from unittest.mock import MagicMock, patch
4
+
5
+ import pytest
6
+ import torch
7
+
8
+ from namer.inference import predict_number_name
9
+ from namer.models import NamerTransformer
10
+ from namer.utils import VOCABULARY, read_digits, int_to_digits
11
+
12
+
13
class TestPredictNumberName:
    """Tests for predict_number_name, driven by a mocked NamerTransformer."""

    @staticmethod
    def _peaked_logits(*token_ids: int) -> torch.Tensor:
        """Build (1, 20, vocab) zero logits with a 10.0 peak at token_ids[i] for step i."""
        logits = torch.zeros(1, 20, len(VOCABULARY))
        for step, token_id in enumerate(token_ids):
            logits[0, step, token_id] = 10.0
        return logits

    @pytest.fixture
    def mock_model(self) -> MagicMock:
        """Return a MagicMock specced on NamerTransformer with one CPU parameter."""
        model = MagicMock(spec=NamerTransformer)
        model.max_output_len = 20
        model.vocab_size = len(VOCABULARY)

        # Mock the device property
        cpu_param = MagicMock()
        cpu_param.device = torch.device("cpu")
        model.parameters.return_value = iter([cpu_param])

        return model

    def test_basic_prediction(self, mock_model: MagicMock) -> None:
        """Logits peaking at "one" then EOS give a result containing "one"."""
        # "one" is index 1 in VOCABULARY
        mock_model.return_value = self._peaked_logits(1, VOCABULARY.index("<EOS>"))
        mock_model.eval = MagicMock()

        with patch("namer.inference.torch.no_grad"):
            result = predict_number_name(mock_model, 1)

        # Should decode to "one"
        assert "one" in result.lower() or result.startswith("<")

    def test_eos_stops_generation(self, mock_model: MagicMock) -> None:
        """Logits peaking at EOS in the first step give an empty result."""
        mock_model.return_value = self._peaked_logits(VOCABULARY.index("<EOS>"))
        mock_model.eval = MagicMock()

        with patch("namer.inference.torch.no_grad"):
            result = predict_number_name(mock_model, 0)

        # Empty result when EOS is first
        assert result == "" or result.startswith("<")

    def test_device_override(self, mock_model: MagicMock) -> None:
        """Passing an explicit device still returns a string."""
        mock_model.return_value = self._peaked_logits(1, VOCABULARY.index("<EOS>"))
        mock_model.eval = MagicMock()

        with patch("namer.inference.torch.no_grad"):
            # Should not raise when device is specified
            result = predict_number_name(mock_model, 1, device="cpu")

        assert isinstance(result, str)
tests/test_models.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for model classes."""
2
+
3
+ import pytest
4
+ import torch
5
+
6
+ from namer.models import NamerTransformer, PositionalEncoding
7
+ from namer.utils import VOCABULARY
8
+
9
+
10
class TestPositionalEncoding:
    """Checks for the PositionalEncoding module."""

    def test_shape(self) -> None:
        """The output has the same shape as the input."""
        batch = torch.randn(2, 10, 128)  # batch=2, seq=10, dim=128
        encoded = PositionalEncoding(d_model=128)(batch)
        assert encoded.shape == (2, 10, 128)

    def test_adds_position(self) -> None:
        """An all-zero input does not come back all-zero."""
        blank = torch.zeros(1, 5, 64)
        encoded = PositionalEncoding(d_model=64)(blank)
        # Output should be non-zero due to positional encoding
        assert not torch.allclose(encoded, blank)
25
+
26
+
27
class TestNamerTransformer:
    """Checks for the NamerTransformer model."""

    @pytest.fixture
    def model(self) -> NamerTransformer:
        """Build a small, dropout-free transformer over the full vocabulary."""
        settings = {
            "vocab_size": len(VOCABULARY),
            "max_output_len": 20,
            "d_model": 64,
            "nhead": 4,
            "num_encoder_layers": 2,
            "dim_feedforward": 128,
            "dropout": 0.0,
        }
        return NamerTransformer(**settings)

    def test_forward_shape(self, model: NamerTransformer) -> None:
        """Output is (batch, max_output_len, vocab_size) for plain digit input."""
        rows, cols = 4, 10
        logits = model(torch.randint(0, 10, (rows, cols)))

        assert logits.shape == (rows, model.max_output_len, model.vocab_size)

    def test_forward_with_padding(self, model: NamerTransformer) -> None:
        """Input whose tail is filled with 10 still gives the expected shape."""
        rows, cols = 2, 10
        digits = torch.full((rows, cols), 10)  # All padding
        digits[:, :5] = torch.randint(0, 10, (rows, 5))

        logits = model(digits)

        assert logits.shape == (rows, model.max_output_len, model.vocab_size)

    def test_forward_with_negative_padding(self, model: NamerTransformer) -> None:
        """Input whose tail is filled with -1 still gives the expected shape."""
        rows, cols = 2, 10
        digits = torch.full((rows, cols), -1)  # -1 padding
        digits[:, :5] = torch.randint(0, 10, (rows, 5))

        logits = model(digits)

        assert logits.shape == (rows, model.max_output_len, model.vocab_size)

    def test_output_is_logits(self, model: NamerTransformer) -> None:
        """Not every output value lies in [0, 1]."""
        logits = model(torch.randint(0, 10, (1, 5)))

        # Logits should not be probabilities (no softmax applied)
        inside_unit_interval = (logits >= 0) & (logits <= 1)
        assert not torch.all(inside_unit_interval)

    def test_gradient_flow(self, model: NamerTransformer) -> None:
        """After a backward pass every parameter has a gradient."""
        digits = torch.randint(0, 10, (2, 5))
        target = torch.randint(0, len(VOCABULARY), (2, model.max_output_len))

        flat_logits = model(digits).view(-1, model.vocab_size)
        torch.nn.functional.cross_entropy(flat_logits, target.view(-1)).backward()

        # Check that gradients exist
        for weight in model.parameters():
            assert weight.grad is not None
tests/test_utils.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for utility functions."""
2
+
3
+ import pytest
4
+
5
+ from namer.utils import (
6
+ EOS_IDX,
7
+ VOCABULARY,
8
+ decode,
9
+ digits_to_int,
10
+ encode,
11
+ int_to_digits,
12
+ read_digits,
13
+ read_double,
14
+ read_triplet,
15
+ )
16
+
17
+
18
class TestIntToDigits:
    """Checks for int_to_digits."""

    def test_zero(self) -> None:
        """Zero is a single 0 digit."""
        assert [0] == int_to_digits(0)

    def test_positive(self) -> None:
        """Positive numbers are split most-significant digit first."""
        for number, expected in ((123, [1, 2, 3]), (7, [7])):
            assert int_to_digits(number) == expected

    def test_negative(self) -> None:
        """The sign of a negative number is dropped."""
        assert [4, 5, 6] == int_to_digits(-456)

    def test_large_number(self) -> None:
        """Interior zeros are preserved."""
        assert [1, 0, 0, 2, 0, 0, 3] == int_to_digits(1002003)
33
+
34
+
35
class TestDigitsToInt:
    """Checks for digits_to_int."""

    def test_empty(self) -> None:
        """An empty list is zero."""
        assert 0 == digits_to_int([])

    def test_single_digit(self) -> None:
        """A single digit maps to itself."""
        assert 5 == digits_to_int([5])

    def test_multiple_digits(self) -> None:
        """Digits are read most-significant first."""
        assert 123 == digits_to_int([1, 2, 3])

    def test_with_zeros(self) -> None:
        """Interior zeros keep their place value."""
        assert 1002 == digits_to_int([1, 0, 0, 2])

    def test_invalid_digit(self) -> None:
        """A value outside 0-9 raises ValueError."""
        out_of_range = [10]
        with pytest.raises(ValueError, match="Invalid digit"):
            digits_to_int(out_of_range)
53
+
54
+
55
class TestRoundTrip:
    """Tests for int_to_digits <-> digits_to_int round-trip."""

    def test_round_trip(self) -> None:
        """Splitting then re-assembling a number gives back its absolute value."""
        # Negative inputs are included so that the abs() in the expectation
        # is actually exercised: int_to_digits drops the sign.
        for n in [0, 42, 123, 1000, 999999, 1000000, -7, -456, -1000000]:
            assert digits_to_int(int_to_digits(n)) == abs(n)
61
+
62
+
63
class TestReadDouble:
    """Checks for read_double."""

    def test_single_digit(self) -> None:
        """A zero tens digit yields the ones word, including "zero"."""
        for ones, word in ((7, "seven"), (0, "zero")):
            assert read_double(0, ones) == word

    def test_teens(self) -> None:
        """A tens digit of 1 yields the dedicated teen word."""
        for ones, word in ((1, "eleven"), (9, "nineteen")):
            assert read_double(1, ones) == word

    def test_tens(self) -> None:
        """Round tens have no trailing ones word."""
        for tens, word in ((3, "thirty"), (5, "fifty")):
            assert read_double(tens, 0) == word

    def test_tens_and_ones(self) -> None:
        """Other values are the tens word followed by the ones word."""
        for (tens, ones), words in (((2, 3), "twenty three"), ((5, 9), "fifty nine")):
            assert read_double(tens, ones) == words

    def test_invalid_digits(self) -> None:
        """A digit outside 0-9 raises ValueError."""
        expected_message = "must be between 0 and 9"
        with pytest.raises(ValueError, match=expected_message):
            read_double(10, 5)
85
+
86
+
87
class TestReadTriplet:
    """Checks for read_triplet."""

    def test_hundreds(self) -> None:
        """A non-zero hundreds digit adds "<digit> hundred" in front."""
        expectations = {(1, 0, 6): "one hundred six", (2, 0, 0): "two hundred"}
        for digits, words in expectations.items():
            assert read_triplet(*digits) == words

    def test_zero_hundreds(self) -> None:
        """A zero hundreds digit yields only the two-digit name."""
        assert "fifty five" == read_triplet(0, 5, 5)

    def test_all_zeros(self) -> None:
        """Three zeros read as "zero"."""
        assert "zero" == read_triplet(0, 0, 0)
99
+
100
+
101
class TestReadDigits:
    """Checks for read_digits."""

    def test_empty(self) -> None:
        """An empty list reads as "zero"."""
        assert "zero" == read_digits([])

    def test_zero(self) -> None:
        """One or several zeros read as "zero"."""
        for zeros in ([0], [0, 0, 0]):
            assert read_digits(zeros) == "zero"

    def test_single_digit(self) -> None:
        """A single digit reads as its ones word."""
        assert "five" == read_digits([5])

    def test_double_digit(self) -> None:
        """Two digits read as tens word plus ones word."""
        assert "forty two" == read_digits([4, 2])

    def test_triple_digit(self) -> None:
        """Three digits include the "hundred" word."""
        assert "one hundred twenty three" == read_digits([1, 2, 3])

    def test_thousands(self) -> None:
        """Four digits use the "thousand" scale word."""
        expectations = (
            ([1, 0, 0, 0], "one thousand"),
            ([1, 2, 3, 4], "one thousand two hundred thirty four"),
        )
        for digits, words in expectations:
            assert read_digits(digits) == words

    def test_millions(self) -> None:
        """Seven digits use the "million" scale word; zero groups are omitted."""
        assert "one million" == read_digits([1, 0, 0, 0, 0, 0, 0])

    def test_complex(self) -> None:
        """Each group of 1,234,567 appears in the result."""
        # 1,234,567
        spoken = read_digits([1, 2, 3, 4, 5, 6, 7])
        for fragment in (
            "one million",
            "two hundred thirty four thousand",
            "five hundred sixty seven",
        ):
            assert fragment in spoken

    def test_invalid_digit(self) -> None:
        """An element outside 0-9 raises ValueError."""
        bad_digits = [1, 10, 3]
        with pytest.raises(ValueError, match="must be digits"):
            read_digits(bad_digits)
138
+
139
+
140
class TestEncode:
    """Checks for encode."""

    def test_simple(self) -> None:
        """Two words give two indices, each inside the vocabulary."""
        encoded = encode("one million")
        assert len(encoded) == 2
        vocab_size = len(VOCABULARY)
        assert all(0 <= index < vocab_size for index in encoded)

    def test_multi_word(self) -> None:
        """A tens-plus-ones phrase gives one index per word."""
        assert len(encode("twenty three")) == 2

    def test_empty(self) -> None:
        """Empty and whitespace-only strings encode to an empty list."""
        for blank in ("", " "):
            assert encode(blank) == []

    def test_unknown_word(self) -> None:
        """A word outside the vocabulary raises ValueError."""
        with pytest.raises(ValueError, match="Unknown word"):
            encode("unknown")
159
+
160
+
161
class TestDecode:
    """Checks for decode."""

    def test_simple(self) -> None:
        """Decoding an encoded phrase returns the phrase."""
        assert "one million" == decode(encode("one million"))

    def test_with_eos(self) -> None:
        """A trailing EOS index leaves no trace in the decoded text."""
        with_terminator = [*encode("one million"), EOS_IDX]
        assert "one million" == decode(with_terminator)

    def test_empty(self) -> None:
        """An empty list decodes to the empty string."""
        assert "" == decode([])

    def test_invalid_index(self) -> None:
        """An index past the end of the vocabulary raises ValueError."""
        with pytest.raises(ValueError, match="out of range"):
            decode([9999])
178
+
179
+
180
class TestEncodeDecodeRoundTrip:
    """Checks that decode undoes encode."""

    def test_round_trip(self) -> None:
        """Each sample phrase survives an encode/decode cycle unchanged."""
        phrases = (
            "one million",
            "twenty three",
            "one hundred twenty three",
            "zero",
            "nine hundred nineteen",
        )
        for phrase in phrases:
            assert decode(encode(phrase)) == phrase