Edwin Jose Palathinkal commited on
Commit ·
2730fd2
0
Parent(s):
Initial commit
Browse files- .gitignore +41 -0
- Makefile +100 -0
- README.md +158 -0
- namer/__init__.py +30 -0
- namer/__main__.py +8 -0
- namer/data.py +137 -0
- namer/inference.py +97 -0
- namer/main.py +211 -0
- namer/models.py +169 -0
- namer/training.py +192 -0
- namer/utils.py +267 -0
- pyproject.toml +100 -0
- tests/__init__.py +1 -0
- tests/test_data.py +86 -0
- tests/test_inference.py +71 -0
- tests/test_models.py +91 -0
- tests/test_utils.py +193 -0
.gitignore
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
*.so
|
| 6 |
+
.Python
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
*.egg
|
| 11 |
+
|
| 12 |
+
# Virtual environments
|
| 13 |
+
.venv/
|
| 14 |
+
venv/
|
| 15 |
+
ENV/
|
| 16 |
+
env/
|
| 17 |
+
|
| 18 |
+
# PyTorch
|
| 19 |
+
*.pt
|
| 20 |
+
*.pth
|
| 21 |
+
|
| 22 |
+
# Testing
|
| 23 |
+
.pytest_cache/
|
| 24 |
+
.coverage
|
| 25 |
+
htmlcov/
|
| 26 |
+
.tox/
|
| 27 |
+
|
| 28 |
+
# IDE
|
| 29 |
+
.vscode/
|
| 30 |
+
.idea/
|
| 31 |
+
*.swp
|
| 32 |
+
*.swo
|
| 33 |
+
*~
|
| 34 |
+
|
| 35 |
+
# OS
|
| 36 |
+
.DS_Store
|
| 37 |
+
Thumbs.db
|
| 38 |
+
|
| 39 |
+
# Project specific
|
| 40 |
+
namer_model.pt
|
| 41 |
+
.pip-tmp/
|
Makefile
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: help install dev train infer test lint format clean distclean
|
| 2 |
+
|
| 3 |
+
# Python environment
|
| 4 |
+
PYTHON := python3
|
| 5 |
+
VENV := .venv
|
| 6 |
+
VENV_PYTHON := $(VENV)/bin/python
|
| 7 |
+
|
| 8 |
+
# Default target
|
| 9 |
+
help:
|
| 10 |
+
@echo "Namer - Number to Name Transformer"
|
| 11 |
+
@echo ""
|
| 12 |
+
@echo "Available targets:"
|
| 13 |
+
@echo " make install - Install package in development mode"
|
| 14 |
+
@echo " make dev - Install with dev dependencies"
|
| 15 |
+
@echo " make train - Train the model"
|
| 16 |
+
@echo " make infer - Run interactive inference"
|
| 17 |
+
@echo " make test - Run test suite"
|
| 18 |
+
@echo " make lint - Run linting (ruff)"
|
| 19 |
+
@echo " make format - Format code (ruff)"
|
| 20 |
+
@echo " make typecheck - Run type checking (mypy)"
|
| 21 |
+
@echo " make clean - Remove generated files and caches"
|
| 22 |
+
@echo " make distclean - Deep clean including venv"
|
| 23 |
+
@echo ""
|
| 24 |
+
|
| 25 |
+
# Create virtual environment and install
|
| 26 |
+
$(VENV):
|
| 27 |
+
@echo "Creating virtual environment..."
|
| 28 |
+
$(PYTHON) -m venv $(VENV)
|
| 29 |
+
$(VENV_PYTHON) -m pip install --upgrade pip
|
| 30 |
+
|
| 31 |
+
# Install package
|
| 32 |
+
install: $(VENV)
|
| 33 |
+
@echo "Installing package..."
|
| 34 |
+
$(VENV_PYTHON) -m pip install -e .
|
| 35 |
+
|
| 36 |
+
# Install with dev dependencies
|
| 37 |
+
dev: $(VENV)
|
| 38 |
+
@echo "Installing with dev dependencies..."
|
| 39 |
+
$(VENV_PYTHON) -m pip install -e ".[dev]"
|
| 40 |
+
|
| 41 |
+
# Run training
|
| 42 |
+
train: $(VENV)
|
| 43 |
+
@echo "Starting training..."
|
| 44 |
+
$(VENV_PYTHON) -m namer train
|
| 45 |
+
|
| 46 |
+
# Run interactive inference
|
| 47 |
+
infer: $(VENV)
|
| 48 |
+
@echo "Starting inference..."
|
| 49 |
+
$(VENV_PYTHON) -m namer infer
|
| 50 |
+
|
| 51 |
+
# Run tests
|
| 52 |
+
test: $(VENV)
|
| 53 |
+
@echo "Running tests..."
|
| 54 |
+
$(VENV_PYTHON) -m pytest -v
|
| 55 |
+
|
| 56 |
+
# Run tests with coverage
|
| 57 |
+
test-cov: $(VENV)
|
| 58 |
+
@echo "Running tests with coverage..."
|
| 59 |
+
$(VENV_PYTHON) -m pytest --cov=namer --cov-report=html --cov-report=term
|
| 60 |
+
|
| 61 |
+
# Run linting
|
| 62 |
+
lint: $(VENV)
|
| 63 |
+
@echo "Running ruff linter..."
|
| 64 |
+
$(VENV_PYTHON) -m ruff check namer tests
|
| 65 |
+
|
| 66 |
+
# Fix linting issues
|
| 67 |
+
lint-fix: $(VENV)
|
| 68 |
+
@echo "Fixing linting issues..."
|
| 69 |
+
$(VENV_PYTHON) -m ruff check --fix namer tests
|
| 70 |
+
|
| 71 |
+
# Format code
|
| 72 |
+
format: $(VENV)
|
| 73 |
+
@echo "Formatting code..."
|
| 74 |
+
$(VENV_PYTHON) -m ruff format namer tests
|
| 75 |
+
|
| 76 |
+
# Run type checking
|
| 77 |
+
typecheck: $(VENV)
|
| 78 |
+
@echo "Running mypy..."
|
| 79 |
+
$(VENV_PYTHON) -m mypy namer
|
| 80 |
+
|
| 81 |
+
# Run all checks
|
| 82 |
+
check: lint typecheck test
|
| 83 |
+
@echo "All checks passed!"
|
| 84 |
+
|
| 85 |
+
# Clean generated files
|
| 86 |
+
clean:
|
| 87 |
+
@echo "Cleaning generated files..."
|
| 88 |
+
rm -f namer_model.pt
|
| 89 |
+
rm -rf htmlcov .pytest_cache .coverage
|
| 90 |
+
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
| 91 |
+
find . -type f -name "*.pyc" -delete
|
| 92 |
+
find . -type f -name "*.pyo" -delete
|
| 93 |
+
find . -type d -name "*.egg-info" -exec rm -rf {} + 2>/dev/null || true
|
| 94 |
+
@echo "Clean complete!"
|
| 95 |
+
|
| 96 |
+
# Deep clean
|
| 97 |
+
distclean: clean
|
| 98 |
+
@echo "Removing virtual environment..."
|
| 99 |
+
rm -rf $(VENV)
|
| 100 |
+
@echo "All clean! Run 'make dev' to start fresh."
|
README.md
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Namer
|
| 2 |
+
|
| 3 |
+
A PyTorch transformer model that converts numbers to their English names.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Transformer architecture** with cross-attention mechanism
|
| 8 |
+
- **Infinite dataset** training with early stopping
|
| 9 |
+
- **Modular design** following Python best practices
|
| 10 |
+
- **Type hints** throughout for better IDE support
|
| 11 |
+
- **Comprehensive test suite** with pytest
|
| 12 |
+
- **Modern tooling**: ruff (linting/formatting), mypy (type checking)
|
| 13 |
+
|
| 14 |
+
## Installation
|
| 15 |
+
|
| 16 |
+
```bash
|
| 17 |
+
# Clone the repository
|
| 18 |
+
git clone https://github.com/example/namer.git
|
| 19 |
+
cd namer
|
| 20 |
+
|
| 21 |
+
# Create virtual environment
|
| 22 |
+
python -m venv .venv
|
| 23 |
+
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
| 24 |
+
|
| 25 |
+
# Install in development mode
|
| 26 |
+
pip install -e ".[dev]"
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Usage
|
| 30 |
+
|
| 31 |
+
### Command Line Interface
|
| 32 |
+
|
| 33 |
+
```bash
|
| 34 |
+
# Show help
|
| 35 |
+
namer --help
|
| 36 |
+
|
| 37 |
+
# Run demonstrations
|
| 38 |
+
namer demo
|
| 39 |
+
|
| 40 |
+
# Train the model
|
| 41 |
+
namer train
|
| 42 |
+
|
| 43 |
+
# Train with custom settings
|
| 44 |
+
namer train --epochs 50 --steps 2000 --batch-size 64 --lr 0.0005
|
| 45 |
+
|
| 46 |
+
# Run interactive inference
|
| 47 |
+
namer infer
|
| 48 |
+
|
| 49 |
+
# Run quick test
|
| 50 |
+
namer test
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Python API
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
from namer import NamerTransformer, load_namer_model, predict_number_name
|
| 57 |
+
|
| 58 |
+
# Load a trained model
|
| 59 |
+
model = load_namer_model("namer_model.pt")
|
| 60 |
+
|
| 61 |
+
# Predict number names
|
| 62 |
+
name = predict_number_name(model, 123456)
|
| 63 |
+
print(name) # "one hundred twenty three thousand four hundred fifty six"
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
## Project Structure
|
| 67 |
+
|
| 68 |
+
```
|
| 69 |
+
namer/
|
| 70 |
+
├── namer/ # Main package
|
| 71 |
+
│ ├── __init__.py # Package exports
|
| 72 |
+
│ ├── main.py # CLI entry point
|
| 73 |
+
│ ├── models.py # Transformer model definitions
|
| 74 |
+
│ ├── data.py # Dataset classes
|
| 75 |
+
│ ├── training.py # Training loop
|
| 76 |
+
│ ├── inference.py # Inference utilities
|
| 77 |
+
│ └── utils.py # Number-to-name conversion utilities
|
| 78 |
+
├── tests/ # Test suite
|
| 79 |
+
│ ├── test_utils.py
|
| 80 |
+
│ ├── test_models.py
|
| 81 |
+
│ ├── test_data.py
|
| 82 |
+
│ └── test_inference.py
|
| 83 |
+
├── pyproject.toml # Project configuration
|
| 84 |
+
├── README.md
|
| 85 |
+
└── Makefile # Convenience commands
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
## Development
|
| 89 |
+
|
| 90 |
+
### Running Tests
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
# Run all tests
|
| 94 |
+
pytest
|
| 95 |
+
|
| 96 |
+
# Run with coverage
|
| 97 |
+
pytest --cov=namer --cov-report=html
|
| 98 |
+
|
| 99 |
+
# Run specific test file
|
| 100 |
+
pytest tests/test_utils.py
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
### Linting and Formatting
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
# Check code style
|
| 107 |
+
ruff check .
|
| 108 |
+
|
| 109 |
+
# Fix auto-fixable issues
|
| 110 |
+
ruff check --fix .
|
| 111 |
+
|
| 112 |
+
# Format code
|
| 113 |
+
ruff format .
|
| 114 |
+
|
| 115 |
+
# Type checking
|
| 116 |
+
mypy namer
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### Makefile Commands
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
make help # Show available commands
|
| 123 |
+
make install # Install dependencies
|
| 124 |
+
make train # Train the model
|
| 125 |
+
make inference # Run interactive inference
|
| 126 |
+
make test # Run tests
|
| 127 |
+
make clean # Clean generated files
|
| 128 |
+
make distclean # Deep clean including venv
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
## Model Architecture
|
| 132 |
+
|
| 133 |
+
The `NamerTransformer` uses an encoder-only architecture:
|
| 134 |
+
|
| 135 |
+
1. **Digit Embedding** - Embeds digits 0-9 (plus padding token)
|
| 136 |
+
2. **Positional Encoding** - Sinusoidal positional embeddings
|
| 137 |
+
3. **Transformer Encoder** - Multi-layer encoder with self-attention
|
| 138 |
+
4. **Cross-Attention** - Learned output queries attend to encoded digits
|
| 139 |
+
5. **Output Projection** - Projects to vocabulary for each output position
|
| 140 |
+
|
| 141 |
+
## Training
|
| 142 |
+
|
| 143 |
+
The model trains on an infinite dataset that generates random number-to-name mappings on-the-fly:
|
| 144 |
+
|
| 145 |
+
- Numbers up to 999,999 (configurable)
|
| 146 |
+
- Early stopping with patience (default: 10 epochs)
|
| 147 |
+
- Cross-entropy loss with -1 padding ignored
|
| 148 |
+
- Adam optimizer with configurable learning rate
|
| 149 |
+
|
| 150 |
+
## Requirements
|
| 151 |
+
|
| 152 |
+
- Python 3.10+
|
| 153 |
+
- PyTorch 2.0+
|
| 154 |
+
- CUDA-capable GPU (optional, falls back to CPU)
|
| 155 |
+
|
| 156 |
+
## License
|
| 157 |
+
|
| 158 |
+
MIT License - see LICENSE file for details.
|
namer/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Namer - A PyTorch transformer model for converting numbers to English names."""
|
| 2 |
+
|
| 3 |
+
__version__ = "0.2.0"
|
| 4 |
+
|
| 5 |
+
from namer.models import NamerTransformer, load_namer_model
|
| 6 |
+
from namer.inference import predict_number_name
|
| 7 |
+
from namer.utils import (
|
| 8 |
+
VOCABULARY,
|
| 9 |
+
encode,
|
| 10 |
+
decode,
|
| 11 |
+
int_to_digits,
|
| 12 |
+
digits_to_int,
|
| 13 |
+
read_digits,
|
| 14 |
+
read_triplet,
|
| 15 |
+
read_double,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
__all__ = [
|
| 19 |
+
"NamerTransformer",
|
| 20 |
+
"load_namer_model",
|
| 21 |
+
"predict_number_name",
|
| 22 |
+
"VOCABULARY",
|
| 23 |
+
"encode",
|
| 24 |
+
"decode",
|
| 25 |
+
"int_to_digits",
|
| 26 |
+
"digits_to_int",
|
| 27 |
+
"read_digits",
|
| 28 |
+
"read_triplet",
|
| 29 |
+
"read_double",
|
| 30 |
+
]
|
namer/__main__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Make namer package executable."""
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
from namer.main import main
|
| 6 |
+
|
| 7 |
+
if __name__ == "__main__":
|
| 8 |
+
sys.exit(main())
|
namer/data.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Dataset classes for Namer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch.utils.data import IterableDataset, TensorDataset
|
| 9 |
+
|
| 10 |
+
from namer.utils import EOS_IDX, encode, int_to_digits, read_digits
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class NamerDataset(TensorDataset):
|
| 14 |
+
"""Finite dataset mapping random integers to encoded number names."""
|
| 15 |
+
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
num_samples: int = 1000,
|
| 19 |
+
max_int: int = 999999,
|
| 20 |
+
max_seq_len: int = 20,
|
| 21 |
+
seed: int = 42,
|
| 22 |
+
) -> None:
|
| 23 |
+
"""Create a PyTorch TensorDataset mapping random integers to encoded number names.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
num_samples: Number of samples to generate
|
| 27 |
+
max_int: Maximum random integer value
|
| 28 |
+
max_seq_len: Maximum sequence length for padding
|
| 29 |
+
seed: Random seed for reproducibility
|
| 30 |
+
"""
|
| 31 |
+
rng = random.Random(seed)
|
| 32 |
+
|
| 33 |
+
digit_sequences: list[list[int]] = []
|
| 34 |
+
encoded_names: list[list[int]] = []
|
| 35 |
+
|
| 36 |
+
for _ in range(num_samples):
|
| 37 |
+
n = rng.randint(0, max_int)
|
| 38 |
+
digits = int_to_digits(n)
|
| 39 |
+
name = read_digits(digits)
|
| 40 |
+
encoded = encode(name)
|
| 41 |
+
|
| 42 |
+
digit_sequences.append(digits)
|
| 43 |
+
encoded_names.append(encoded)
|
| 44 |
+
|
| 45 |
+
# Pad sequences
|
| 46 |
+
padded_digits: list[list[int]] = []
|
| 47 |
+
padded_encoded: list[list[int]] = []
|
| 48 |
+
|
| 49 |
+
for digits, encoded in zip(digit_sequences, encoded_names):
|
| 50 |
+
# Pad digits with 10 to indicate padding
|
| 51 |
+
digits_padded = digits + [10] * (max_seq_len - len(digits))
|
| 52 |
+
digits_padded = digits_padded[:max_seq_len]
|
| 53 |
+
|
| 54 |
+
# Append EOS token to encoded, then pad with -1
|
| 55 |
+
encoded_with_eos = encoded + [EOS_IDX]
|
| 56 |
+
encoded_padded = encoded_with_eos + [-1] * (max_seq_len - len(encoded_with_eos))
|
| 57 |
+
encoded_padded = encoded_padded[:max_seq_len]
|
| 58 |
+
|
| 59 |
+
padded_digits.append(digits_padded)
|
| 60 |
+
padded_encoded.append(encoded_padded)
|
| 61 |
+
|
| 62 |
+
# Convert to tensors
|
| 63 |
+
digits_tensor = torch.tensor(padded_digits, dtype=torch.long)
|
| 64 |
+
encoded_tensor = torch.tensor(padded_encoded, dtype=torch.long)
|
| 65 |
+
|
| 66 |
+
super().__init__(digits_tensor, encoded_tensor)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class InfiniteNamerDataset(IterableDataset):
|
| 70 |
+
"""Infinite dataset that generates random number-to-name mappings on-the-fly.
|
| 71 |
+
|
| 72 |
+
Uses Python generators to produce an endless stream of training samples.
|
| 73 |
+
Each iteration yields fresh random samples.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
max_int: int = 999999,
|
| 79 |
+
max_seq_len: int = 20,
|
| 80 |
+
seed: int | None = None,
|
| 81 |
+
) -> None:
|
| 82 |
+
"""Initialize the infinite dataset.
|
| 83 |
+
|
| 84 |
+
Args:
|
| 85 |
+
max_int: Maximum random integer value
|
| 86 |
+
max_seq_len: Maximum sequence length for padding
|
| 87 |
+
seed: Random seed (optional, for reproducibility)
|
| 88 |
+
"""
|
| 89 |
+
self.max_int = max_int
|
| 90 |
+
self.max_seq_len = max_seq_len
|
| 91 |
+
self.seed = seed
|
| 92 |
+
self.rng = random.Random(seed)
|
| 93 |
+
|
| 94 |
+
def _generate_sample(self) -> tuple[torch.Tensor, torch.Tensor]:
|
| 95 |
+
"""Generate a single (digits, encoded_name) sample."""
|
| 96 |
+
n = self.rng.randint(0, self.max_int)
|
| 97 |
+
digits = int_to_digits(n)
|
| 98 |
+
name = read_digits(digits)
|
| 99 |
+
encoded = encode(name)
|
| 100 |
+
|
| 101 |
+
# Pad digits with 10 (padding index)
|
| 102 |
+
digits_padded = digits + [10] * (self.max_seq_len - len(digits))
|
| 103 |
+
digits_padded = digits_padded[: self.max_seq_len]
|
| 104 |
+
|
| 105 |
+
# Append EOS and pad with -1
|
| 106 |
+
encoded_with_eos = encoded + [EOS_IDX]
|
| 107 |
+
encoded_padded = encoded_with_eos + [-1] * (self.max_seq_len - len(encoded_with_eos))
|
| 108 |
+
encoded_padded = encoded_padded[: self.max_seq_len]
|
| 109 |
+
|
| 110 |
+
return (
|
| 111 |
+
torch.tensor(digits_padded, dtype=torch.long),
|
| 112 |
+
torch.tensor(encoded_padded, dtype=torch.long),
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
def __iter__(self) -> InfiniteNamerDataset:
|
| 116 |
+
"""Yield samples infinitely.
|
| 117 |
+
|
| 118 |
+
Each worker in multi-worker DataLoader gets its own iterator
|
| 119 |
+
with a unique seed based on worker_id.
|
| 120 |
+
"""
|
| 121 |
+
worker_info = torch.utils.data.get_worker_info()
|
| 122 |
+
|
| 123 |
+
if worker_info is None:
|
| 124 |
+
# Single-process loading
|
| 125 |
+
rng_seed = self.seed if self.seed else random.randint(0, 2**32)
|
| 126 |
+
self.rng = random.Random(rng_seed)
|
| 127 |
+
else:
|
| 128 |
+
# Multi-worker: each worker gets unique seed
|
| 129 |
+
worker_id = worker_info.id
|
| 130 |
+
base_seed = self.seed if self.seed else random.randint(0, 2**32)
|
| 131 |
+
self.rng = random.Random(base_seed + worker_id * 1000)
|
| 132 |
+
|
| 133 |
+
return self
|
| 134 |
+
|
| 135 |
+
def __next__(self) -> tuple[torch.Tensor, torch.Tensor]:
|
| 136 |
+
"""Generate the next sample."""
|
| 137 |
+
return self._generate_sample()
|
namer/inference.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Inference utilities for Namer models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
from namer.models import NamerTransformer, load_namer_model
|
| 8 |
+
from namer.utils import EOS_IDX, decode, int_to_digits
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def predict_number_name(
|
| 12 |
+
model: NamerTransformer,
|
| 13 |
+
n: int,
|
| 14 |
+
device: str | torch.device | None = None,
|
| 15 |
+
) -> str:
|
| 16 |
+
"""Predict the English name of a number using the trained model.
|
| 17 |
+
|
| 18 |
+
Stops generation when <EOS> token is predicted.
|
| 19 |
+
|
| 20 |
+
Args:
|
| 21 |
+
model: Trained model
|
| 22 |
+
n: Integer to convert to name
|
| 23 |
+
device: Device to run inference on (auto-detected if None)
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
Predicted English name of the number
|
| 27 |
+
"""
|
| 28 |
+
if device is None:
|
| 29 |
+
device = next(model.parameters()).device
|
| 30 |
+
|
| 31 |
+
model.eval()
|
| 32 |
+
|
| 33 |
+
with torch.no_grad():
|
| 34 |
+
digits = int_to_digits(n)
|
| 35 |
+
padded = digits + [10] * (model.max_output_len - len(digits))
|
| 36 |
+
input_tensor = torch.tensor([padded], dtype=torch.long).to(device)
|
| 37 |
+
|
| 38 |
+
logits = model(input_tensor)
|
| 39 |
+
predictions = logits.argmax(dim=-1)[0].cpu().tolist()
|
| 40 |
+
|
| 41 |
+
# Collect tokens until EOS is predicted or max length reached
|
| 42 |
+
pred_indices: list[int] = []
|
| 43 |
+
for idx in predictions:
|
| 44 |
+
if idx == EOS_IDX:
|
| 45 |
+
break
|
| 46 |
+
pred_indices.append(idx)
|
| 47 |
+
|
| 48 |
+
# Try to decode
|
| 49 |
+
try:
|
| 50 |
+
return decode(pred_indices)
|
| 51 |
+
except ValueError:
|
| 52 |
+
# If decoding fails, try progressively shorter sequences
|
| 53 |
+
for length in range(len(pred_indices), 0, -1):
|
| 54 |
+
try:
|
| 55 |
+
return decode(pred_indices[:length])
|
| 56 |
+
except ValueError:
|
| 57 |
+
continue
|
| 58 |
+
return f"<decode error: {pred_indices}>"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def interactive_inference(model_path: str = "namer_model.pt") -> None:
|
| 62 |
+
"""Run interactive inference session.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
model_path: Path to the saved model file
|
| 66 |
+
"""
|
| 67 |
+
import sys
|
| 68 |
+
|
| 69 |
+
print("Loading model...")
|
| 70 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 71 |
+
print(f"Using device: {device}")
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
model = load_namer_model(model_path, device)
|
| 75 |
+
print("Model loaded successfully!\n")
|
| 76 |
+
except FileNotFoundError:
|
| 77 |
+
print(f"Error: Model file '{model_path}' not found.")
|
| 78 |
+
print("Please run training first: python -m namer train")
|
| 79 |
+
sys.exit(1)
|
| 80 |
+
|
| 81 |
+
print("Enter a number to convert (or 'quit' to exit):")
|
| 82 |
+
while True:
|
| 83 |
+
try:
|
| 84 |
+
user_input = input("> ").strip()
|
| 85 |
+
|
| 86 |
+
if user_input.lower() in ("quit", "exit", "q"):
|
| 87 |
+
break
|
| 88 |
+
|
| 89 |
+
n = int(user_input)
|
| 90 |
+
name = predict_number_name(model, n, device)
|
| 91 |
+
print(f" {n} -> '{name}'\n")
|
| 92 |
+
|
| 93 |
+
except ValueError:
|
| 94 |
+
print(" Please enter a valid integer\n")
|
| 95 |
+
except KeyboardInterrupt:
|
| 96 |
+
print("\nGoodbye!")
|
| 97 |
+
break
|
namer/main.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Main entry point for namer CLI."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from namer.data import InfiniteNamerDataset
|
| 11 |
+
from namer.inference import interactive_inference, predict_number_name
|
| 12 |
+
from namer.models import NamerTransformer, load_namer_model
|
| 13 |
+
from namer.training import save_model, train_namer_model
|
| 14 |
+
from namer.utils import VOCABULARY, encode, int_to_digits, read_digits
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def demo_command(args: argparse.Namespace) -> None:
|
| 18 |
+
"""Run number name demonstration."""
|
| 19 |
+
print("--- Number Names Demo ---")
|
| 20 |
+
print("\nread_double (two digits):")
|
| 21 |
+
double_cases = [(0, 7), (1, 1), (2, 3), (3, 0), (0, 0), (5, 9)]
|
| 22 |
+
for a, b in double_cases:
|
| 23 |
+
from namer.utils import read_double
|
| 24 |
+
|
| 25 |
+
print(f" read_double({a}, {b}) = '{read_double(a, b)}'")
|
| 26 |
+
|
| 27 |
+
print("\nread_triplet (three digits):")
|
| 28 |
+
triplet_cases = [(1, 0, 6), (0, 0, 0), (9, 1, 9), (2, 0, 0), (0, 5, 5), (4, 2, 0)]
|
| 29 |
+
for a, b, c in triplet_cases:
|
| 30 |
+
from namer.utils import read_triplet
|
| 31 |
+
|
| 32 |
+
print(f" read_triplet({a}, {b}, {c}) = '{read_triplet(a, b, c)}'")
|
| 33 |
+
|
| 34 |
+
print(f"\nVOCABULARY ({len(VOCABULARY)} words):")
|
| 35 |
+
print(f" {VOCABULARY}")
|
| 36 |
+
|
| 37 |
+
print("\nencode (text to vocabulary indices):")
|
| 38 |
+
encode_cases = [
|
| 39 |
+
"one million",
|
| 40 |
+
"twenty three",
|
| 41 |
+
"one hundred twenty three",
|
| 42 |
+
"nine hundred nineteen",
|
| 43 |
+
"zero",
|
| 44 |
+
]
|
| 45 |
+
for text in encode_cases:
|
| 46 |
+
print(f" encode('{text}') = {encode(text)}")
|
| 47 |
+
|
| 48 |
+
print("\nencode/decode round-trip:")
|
| 49 |
+
for text in ["one million", "twenty three", "zero"]:
|
| 50 |
+
encoded = encode(text)
|
| 51 |
+
from namer.utils import decode
|
| 52 |
+
|
| 53 |
+
decoded = decode(encoded)
|
| 54 |
+
print(f" '{text}' -> {encoded} -> '{decoded}'")
|
| 55 |
+
|
| 56 |
+
print("\nint_to_digits (integer to digit list):")
|
| 57 |
+
int_cases = [0, 7, 123, -456, 1002003, 9876543210]
|
| 58 |
+
for n in int_cases:
|
| 59 |
+
print(f" int_to_digits({n}) = {int_to_digits(n)}")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def train_command(
|
| 63 |
+
num_epochs: int = 30,
|
| 64 |
+
steps_per_epoch: int = 1000,
|
| 65 |
+
batch_size: int = 128,
|
| 66 |
+
learning_rate: float = 0.001,
|
| 67 |
+
) -> None:
|
| 68 |
+
"""Train the Namer model.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
num_epochs: Number of training epochs
|
| 72 |
+
steps_per_epoch: Number of steps per epoch
|
| 73 |
+
batch_size: Batch size for training
|
| 74 |
+
learning_rate: Learning rate for optimizer
|
| 75 |
+
"""
|
| 76 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 77 |
+
if device.type == "cuda":
|
| 78 |
+
print(f"Using GPU: {torch.cuda.get_device_name(device)}")
|
| 79 |
+
else:
|
| 80 |
+
print("Warning: CUDA not available, using CPU")
|
| 81 |
+
|
| 82 |
+
# Create infinite dataset for training
|
| 83 |
+
infinite_dataset = InfiniteNamerDataset(
|
| 84 |
+
max_int=999999,
|
| 85 |
+
max_seq_len=20,
|
| 86 |
+
seed=42,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
# Create model
|
| 90 |
+
model = NamerTransformer(
|
| 91 |
+
vocab_size=len(VOCABULARY),
|
| 92 |
+
max_output_len=20,
|
| 93 |
+
d_model=128,
|
| 94 |
+
nhead=4,
|
| 95 |
+
num_encoder_layers=4,
|
| 96 |
+
dim_feedforward=512,
|
| 97 |
+
dropout=0.1,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
print(f"\nTransformer Model parameters: {sum(p.numel() for p in model.parameters()):,}")
|
| 101 |
+
|
| 102 |
+
# Train model
|
| 103 |
+
trained_model = train_namer_model(
|
| 104 |
+
model=model,
|
| 105 |
+
infinite_dataset=infinite_dataset,
|
| 106 |
+
num_epochs=num_epochs,
|
| 107 |
+
steps_per_epoch=steps_per_epoch,
|
| 108 |
+
val_steps=100,
|
| 109 |
+
batch_size=batch_size,
|
| 110 |
+
learning_rate=learning_rate,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Save model
|
| 114 |
+
save_model(trained_model)
|
| 115 |
+
|
| 116 |
+
# Test predictions
|
| 117 |
+
print("\n--- Model Predictions ---")
|
| 118 |
+
trained_model.eval()
|
| 119 |
+
|
| 120 |
+
test_numbers = [123, 4567, 89012, 555555, 999999, 42, 0, 1000]
|
| 121 |
+
device_obj = next(trained_model.parameters()).device
|
| 122 |
+
|
| 123 |
+
with torch.no_grad():
|
| 124 |
+
for n in test_numbers:
|
| 125 |
+
pred = predict_number_name(trained_model, n, device_obj)
|
| 126 |
+
actual = read_digits(int_to_digits(n))
|
| 127 |
+
match = "✓" if pred == actual else "✗"
|
| 128 |
+
print(f" {n}: pred='{pred}', actual='{actual}' {match}")
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def test_command() -> None:
|
| 132 |
+
"""Run quick inference test on saved model."""
|
| 133 |
+
try:
|
| 134 |
+
model = load_namer_model("namer_model.pt")
|
| 135 |
+
except FileNotFoundError:
|
| 136 |
+
print("Error: Model file 'namer_model.pt' not found.")
|
| 137 |
+
print("Please train the model first: python -m namer train")
|
| 138 |
+
sys.exit(1)
|
| 139 |
+
|
| 140 |
+
print("Running inference on loaded model:")
|
| 141 |
+
test_nums = [42, 123, 1000, 999999]
|
| 142 |
+
for n in test_nums:
|
| 143 |
+
pred = predict_number_name(model, n)
|
| 144 |
+
actual = read_digits(int_to_digits(n))
|
| 145 |
+
match = "✓" if pred == actual else "✗"
|
| 146 |
+
print(f" {n} -> '{pred}' (actual: '{actual}') {match}")
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def main(argv: list[str] | None = None) -> int:
|
| 150 |
+
"""Main CLI entry point.
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
argv: Command line arguments (defaults to sys.argv)
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Exit code
|
| 157 |
+
"""
|
| 158 |
+
parser = argparse.ArgumentParser(
|
| 159 |
+
prog="namer",
|
| 160 |
+
description="A PyTorch transformer model for converting numbers to their English names.",
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
subparsers = parser.add_subparsers(dest="command", help="Available commands")
|
| 164 |
+
|
| 165 |
+
# Demo command
|
| 166 |
+
demo_parser = subparsers.add_parser("demo", help="Run number name demonstrations")
|
| 167 |
+
demo_parser.set_defaults(func=demo_command)
|
| 168 |
+
|
| 169 |
+
# Train command
|
| 170 |
+
train_parser = subparsers.add_parser("train", help="Train the model")
|
| 171 |
+
train_parser.add_argument(
|
| 172 |
+
"--epochs", type=int, default=30, help="Number of training epochs (default: 30)"
|
| 173 |
+
)
|
| 174 |
+
train_parser.add_argument(
|
| 175 |
+
"--steps", type=int, default=1000, help="Steps per epoch (default: 1000)"
|
| 176 |
+
)
|
| 177 |
+
train_parser.add_argument(
|
| 178 |
+
"--batch-size", type=int, default=128, help="Batch size (default: 128)"
|
| 179 |
+
)
|
| 180 |
+
train_parser.add_argument(
|
| 181 |
+
"--lr", type=float, default=0.001, help="Learning rate (default: 0.001)"
|
| 182 |
+
)
|
| 183 |
+
train_parser.set_defaults(
|
| 184 |
+
func=lambda args: train_command(
|
| 185 |
+
num_epochs=args.epochs,
|
| 186 |
+
steps_per_epoch=args.steps,
|
| 187 |
+
batch_size=args.batch_size,
|
| 188 |
+
learning_rate=args.lr,
|
| 189 |
+
)
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
# Inference command
|
| 193 |
+
infer_parser = subparsers.add_parser("infer", help="Run interactive inference")
|
| 194 |
+
infer_parser.set_defaults(func=lambda args: interactive_inference())
|
| 195 |
+
|
| 196 |
+
# Test command
|
| 197 |
+
test_parser = subparsers.add_parser("test", help="Run quick inference test")
|
| 198 |
+
test_parser.set_defaults(func=lambda args: test_command())
|
| 199 |
+
|
| 200 |
+
args = parser.parse_args(argv)
|
| 201 |
+
|
| 202 |
+
if args.command is None:
|
| 203 |
+
parser.print_help()
|
| 204 |
+
return 0
|
| 205 |
+
|
| 206 |
+
args.func(args)
|
| 207 |
+
return 0
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
|
| 211 |
+
sys.exit(main())
|
namer/models.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Model definitions for Namer."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PositionalEncoding(nn.Module):
|
| 10 |
+
"""Sinusoidal positional encoding for transformer."""
|
| 11 |
+
|
| 12 |
+
def __init__(self, d_model: int, max_len: int = 5000) -> None:
|
| 13 |
+
super().__init__()
|
| 14 |
+
|
| 15 |
+
pe = torch.zeros(max_len, d_model)
|
| 16 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
| 17 |
+
div_term = torch.exp(
|
| 18 |
+
torch.arange(0, d_model, 2).float()
|
| 19 |
+
* (-torch.log(torch.tensor(10000.0)) / d_model)
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 23 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 24 |
+
|
| 25 |
+
self.register_buffer("pe", pe)
|
| 26 |
+
|
| 27 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 28 |
+
"""Add positional encoding to input.
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
x: (batch_size, seq_len, d_model)
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
Tensor with positional encoding added
|
| 35 |
+
"""
|
| 36 |
+
return x + self.pe[: x.size(1)]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class NamerTransformer(nn.Module):
|
| 40 |
+
"""Transformer model for mapping digit sequences to number name tokens.
|
| 41 |
+
|
| 42 |
+
Architecture:
|
| 43 |
+
- Embedding layer for digits (11 values: 0-9 + padding)
|
| 44 |
+
- Positional encoding
|
| 45 |
+
- Transformer encoder layers
|
| 46 |
+
- Output projection to vocabulary for each position
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
vocab_size: int = 40,
|
| 52 |
+
max_output_len: int = 20,
|
| 53 |
+
d_model: int = 128,
|
| 54 |
+
nhead: int = 4,
|
| 55 |
+
num_encoder_layers: int = 4,
|
| 56 |
+
dim_feedforward: int = 512,
|
| 57 |
+
dropout: float = 0.1,
|
| 58 |
+
) -> None:
|
| 59 |
+
super().__init__()
|
| 60 |
+
self.vocab_size = vocab_size
|
| 61 |
+
self.max_output_len = max_output_len
|
| 62 |
+
self.d_model = d_model
|
| 63 |
+
|
| 64 |
+
# Digit embedding (10 digits + 1 padding token = 11)
|
| 65 |
+
self.digit_embedding = nn.Embedding(11, d_model, padding_idx=10)
|
| 66 |
+
|
| 67 |
+
# Positional encoding
|
| 68 |
+
self.pos_encoder = PositionalEncoding(d_model, max_len=100)
|
| 69 |
+
|
| 70 |
+
# Transformer encoder
|
| 71 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 72 |
+
d_model=d_model,
|
| 73 |
+
nhead=nhead,
|
| 74 |
+
dim_feedforward=dim_feedforward,
|
| 75 |
+
dropout=dropout,
|
| 76 |
+
batch_first=True,
|
| 77 |
+
)
|
| 78 |
+
self.transformer_encoder = nn.TransformerEncoder(
|
| 79 |
+
encoder_layer, num_layers=num_encoder_layers
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Output projection
|
| 83 |
+
self.output_projection = nn.Linear(d_model, vocab_size)
|
| 84 |
+
|
| 85 |
+
# Learned queries for each output position
|
| 86 |
+
self.output_queries = nn.Parameter(torch.randn(max_output_len, d_model))
|
| 87 |
+
|
| 88 |
+
# Cross-attention from output positions to encoded input
|
| 89 |
+
self.cross_attention = nn.MultiheadAttention(
|
| 90 |
+
d_model, nhead, dropout=dropout, batch_first=True
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# Final output layers
|
| 94 |
+
self.output_norm = nn.LayerNorm(d_model)
|
| 95 |
+
|
| 96 |
+
def forward(self, digits: torch.Tensor) -> torch.Tensor:
|
| 97 |
+
"""Forward pass.
|
| 98 |
+
|
| 99 |
+
Args:
|
| 100 |
+
digits: (batch_size, seq_len) tensor of digit indices (0-9), padding=10
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
(batch_size, max_output_len, vocab_size) logits
|
| 104 |
+
"""
|
| 105 |
+
batch_size, seq_len = digits.shape
|
| 106 |
+
|
| 107 |
+
# Handle padding: convert -1 padding to 10 (our padding index)
|
| 108 |
+
digits = digits.clone()
|
| 109 |
+
digits[digits == -1] = 10
|
| 110 |
+
|
| 111 |
+
# Create padding mask for transformer (True = padding)
|
| 112 |
+
src_key_padding_mask = digits == 10
|
| 113 |
+
|
| 114 |
+
# Embed digits: (batch, seq_len, d_model)
|
| 115 |
+
embedded = self.digit_embedding(digits)
|
| 116 |
+
|
| 117 |
+
# Add positional encoding
|
| 118 |
+
embedded = self.pos_encoder(embedded)
|
| 119 |
+
|
| 120 |
+
# Transformer encoder: (batch, seq_len, d_model)
|
| 121 |
+
memory = self.transformer_encoder(
|
| 122 |
+
embedded, src_key_padding_mask=src_key_padding_mask
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Expand queries for batch: (batch, max_output_len, d_model)
|
| 126 |
+
queries = self.output_queries.unsqueeze(0).expand(batch_size, -1, -1)
|
| 127 |
+
|
| 128 |
+
# Cross-attention from queries to encoded input
|
| 129 |
+
attn_output, _ = self.cross_attention(
|
| 130 |
+
queries, memory, memory, key_padding_mask=src_key_padding_mask
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Normalize and project to vocab
|
| 134 |
+
output = self.output_norm(attn_output)
|
| 135 |
+
logits = self.output_projection(output)
|
| 136 |
+
|
| 137 |
+
return logits
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def load_namer_model(
|
| 141 |
+
model_path: str = "namer_model.pt",
|
| 142 |
+
device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
| 143 |
+
) -> NamerTransformer:
|
| 144 |
+
"""Load a trained Namer model for inference.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
model_path: Path to the saved model file
|
| 148 |
+
device: Device to load the model on
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
Loaded model in eval mode
|
| 152 |
+
"""
|
| 153 |
+
checkpoint = torch.load(model_path, map_location=device)
|
| 154 |
+
|
| 155 |
+
model = NamerTransformer(
|
| 156 |
+
vocab_size=checkpoint["vocab_size"],
|
| 157 |
+
max_output_len=checkpoint["max_output_len"],
|
| 158 |
+
d_model=checkpoint.get("d_model", 128),
|
| 159 |
+
nhead=4,
|
| 160 |
+
num_encoder_layers=4,
|
| 161 |
+
dim_feedforward=512,
|
| 162 |
+
dropout=0.0, # No dropout for inference
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
model.load_state_dict(checkpoint["model_state_dict"])
|
| 166 |
+
model.to(device)
|
| 167 |
+
model.eval()
|
| 168 |
+
|
| 169 |
+
return model
|
namer/training.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Training utilities for Namer models."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 9 |
+
|
| 10 |
+
from namer.models import NamerTransformer
|
| 11 |
+
from namer.data import InfiniteNamerDataset
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def train_namer_model(
|
| 15 |
+
model: NamerTransformer,
|
| 16 |
+
dataset: TensorDataset | None = None,
|
| 17 |
+
infinite_dataset: InfiniteNamerDataset | None = None,
|
| 18 |
+
num_epochs: int = 50,
|
| 19 |
+
steps_per_epoch: int = 1000,
|
| 20 |
+
val_steps: int = 100,
|
| 21 |
+
batch_size: int = 64,
|
| 22 |
+
learning_rate: float = 0.001,
|
| 23 |
+
patience: int = 10,
|
| 24 |
+
device: str | torch.device = "cuda" if torch.cuda.is_available() else "cpu",
|
| 25 |
+
) -> NamerTransformer:
|
| 26 |
+
"""Train the model on a finite dataset or infinite iterable dataset.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
model: The model to train
|
| 30 |
+
dataset: Finite TensorDataset with (digits, encoded_names) pairs
|
| 31 |
+
infinite_dataset: Infinite IterableDataset for infinite training
|
| 32 |
+
num_epochs: Number of training epochs
|
| 33 |
+
steps_per_epoch: Number of steps per epoch (for infinite dataset)
|
| 34 |
+
val_steps: Number of validation steps per epoch
|
| 35 |
+
batch_size: Batch size for training
|
| 36 |
+
learning_rate: Learning rate for optimizer
|
| 37 |
+
patience: Early stopping patience
|
| 38 |
+
device: Device to train on ('cuda' or 'cpu')
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
Trained model
|
| 42 |
+
"""
|
| 43 |
+
model = model.to(device)
|
| 44 |
+
|
| 45 |
+
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
|
| 46 |
+
criterion = nn.CrossEntropyLoss(ignore_index=-1)
|
| 47 |
+
|
| 48 |
+
print(f"Training on {device}")
|
| 49 |
+
print(f"Early stopping patience: {patience} epochs")
|
| 50 |
+
|
| 51 |
+
# Setup data loaders
|
| 52 |
+
if infinite_dataset is not None:
|
| 53 |
+
print(f"Using INFINITE dataset (max_int={infinite_dataset.max_int})")
|
| 54 |
+
print(f"Steps per epoch: {steps_per_epoch}, Val steps: {val_steps}")
|
| 55 |
+
|
| 56 |
+
train_loader = DataLoader(
|
| 57 |
+
infinite_dataset,
|
| 58 |
+
batch_size=batch_size,
|
| 59 |
+
num_workers=0,
|
| 60 |
+
)
|
| 61 |
+
val_loader = DataLoader(
|
| 62 |
+
infinite_dataset,
|
| 63 |
+
batch_size=batch_size,
|
| 64 |
+
num_workers=0,
|
| 65 |
+
)
|
| 66 |
+
else:
|
| 67 |
+
if dataset is None:
|
| 68 |
+
raise ValueError("Either dataset or infinite_dataset must be provided")
|
| 69 |
+
|
| 70 |
+
train_size = int(0.9 * len(dataset))
|
| 71 |
+
val_size = len(dataset) - train_size
|
| 72 |
+
train_dataset, val_dataset = torch.utils.data.random_split(
|
| 73 |
+
dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42)
|
| 74 |
+
)
|
| 75 |
+
train_loader = DataLoader(
|
| 76 |
+
train_dataset, batch_size=batch_size, shuffle=True
|
| 77 |
+
)
|
| 78 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
|
| 79 |
+
steps_per_epoch = len(train_loader)
|
| 80 |
+
val_steps = len(val_loader)
|
| 81 |
+
print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")
|
| 82 |
+
|
| 83 |
+
best_val_loss = float("inf")
|
| 84 |
+
epochs_without_improvement = 0
|
| 85 |
+
best_model_state: dict | None = None
|
| 86 |
+
|
| 87 |
+
for epoch in range(num_epochs):
|
| 88 |
+
# Training
|
| 89 |
+
model.train()
|
| 90 |
+
train_loss = 0.0
|
| 91 |
+
train_correct = 0
|
| 92 |
+
train_total = 0
|
| 93 |
+
|
| 94 |
+
train_iter = iter(train_loader)
|
| 95 |
+
for _ in range(steps_per_epoch):
|
| 96 |
+
digits_batch, target_batch = next(train_iter)
|
| 97 |
+
digits_batch = digits_batch.to(device)
|
| 98 |
+
target_batch = target_batch.to(device)
|
| 99 |
+
|
| 100 |
+
optimizer.zero_grad()
|
| 101 |
+
|
| 102 |
+
logits = model(digits_batch)
|
| 103 |
+
loss = criterion(
|
| 104 |
+
logits.view(-1, model.vocab_size), target_batch.view(-1)
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
loss.backward()
|
| 108 |
+
optimizer.step()
|
| 109 |
+
|
| 110 |
+
train_loss += loss.item()
|
| 111 |
+
|
| 112 |
+
mask = target_batch != -1
|
| 113 |
+
predictions = logits.argmax(dim=-1)
|
| 114 |
+
train_correct += ((predictions == target_batch) & mask).sum().item()
|
| 115 |
+
train_total += mask.sum().item()
|
| 116 |
+
|
| 117 |
+
train_loss /= steps_per_epoch
|
| 118 |
+
train_acc = train_correct / train_total if train_total > 0 else 0
|
| 119 |
+
|
| 120 |
+
# Validation
|
| 121 |
+
model.eval()
|
| 122 |
+
val_loss = 0.0
|
| 123 |
+
val_correct = 0
|
| 124 |
+
val_total = 0
|
| 125 |
+
|
| 126 |
+
with torch.no_grad():
|
| 127 |
+
val_iter = iter(val_loader)
|
| 128 |
+
for _ in range(val_steps):
|
| 129 |
+
digits_batch, target_batch = next(val_iter)
|
| 130 |
+
digits_batch = digits_batch.to(device)
|
| 131 |
+
target_batch = target_batch.to(device)
|
| 132 |
+
|
| 133 |
+
logits = model(digits_batch)
|
| 134 |
+
loss = criterion(
|
| 135 |
+
logits.view(-1, model.vocab_size), target_batch.view(-1)
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
val_loss += loss.item()
|
| 139 |
+
|
| 140 |
+
mask = target_batch != -1
|
| 141 |
+
predictions = logits.argmax(dim=-1)
|
| 142 |
+
val_correct += ((predictions == target_batch) & mask).sum().item()
|
| 143 |
+
val_total += mask.sum().item()
|
| 144 |
+
|
| 145 |
+
val_loss /= val_steps
|
| 146 |
+
val_acc = val_correct / val_total if val_total > 0 else 0
|
| 147 |
+
|
| 148 |
+
if val_loss < best_val_loss:
|
| 149 |
+
best_val_loss = val_loss
|
| 150 |
+
epochs_without_improvement = 0
|
| 151 |
+
best_model_state = model.state_dict().copy()
|
| 152 |
+
else:
|
| 153 |
+
epochs_without_improvement += 1
|
| 154 |
+
|
| 155 |
+
if (epoch + 1) % 10 == 0 or epoch == 0:
|
| 156 |
+
print(
|
| 157 |
+
f"Epoch {epoch+1}/{num_epochs}: "
|
| 158 |
+
f"train_loss={train_loss:.4f}, train_acc={train_acc:.4f}, "
|
| 159 |
+
f"val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, "
|
| 160 |
+
f"patience={epochs_without_improvement}/{patience}"
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if epochs_without_improvement >= patience:
|
| 164 |
+
print(f"\nEarly stopping triggered! No improvement for {patience} epochs.")
|
| 165 |
+
break
|
| 166 |
+
|
| 167 |
+
print(f"\nBest validation loss: {best_val_loss:.4f}")
|
| 168 |
+
|
| 169 |
+
if best_model_state is not None:
|
| 170 |
+
model.load_state_dict(best_model_state)
|
| 171 |
+
print("Restored best model from checkpoint.")
|
| 172 |
+
|
| 173 |
+
return model
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def save_model(model: NamerTransformer, model_path: str = "namer_model.pt") -> None:
|
| 177 |
+
"""Save a trained model to disk.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
model: The model to save
|
| 181 |
+
model_path: Path where to save the model
|
| 182 |
+
"""
|
| 183 |
+
checkpoint = {
|
| 184 |
+
"model_type": "transformer",
|
| 185 |
+
"model_state_dict": model.state_dict(),
|
| 186 |
+
"vocab_size": model.vocab_size,
|
| 187 |
+
"max_output_len": model.max_output_len,
|
| 188 |
+
"d_model": model.d_model,
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
torch.save(checkpoint, model_path)
|
| 192 |
+
print(f"Model saved to {model_path}")
|
namer/utils.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Utility functions for number-to-name conversion."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
# Global constants for number names
|
| 6 |
+
ONES: tuple[str, ...] = (
|
| 7 |
+
"zero", "one", "two", "three", "four",
|
| 8 |
+
"five", "six", "seven", "eight", "nine"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
TEENS: tuple[str, ...] = (
|
| 12 |
+
"ten", "eleven", "twelve", "thirteen", "fourteen",
|
| 13 |
+
"fifteen", "sixteen", "seventeen", "eighteen", "nineteen"
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
TENS: tuple[str, ...] = (
|
| 17 |
+
"", "", "twenty", "thirty", "forty",
|
| 18 |
+
"fifty", "sixty", "seventy", "eighty", "ninety"
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
# Scale words for powers of 1000
|
| 22 |
+
SCALES: tuple[str, ...] = (
|
| 23 |
+
"", "thousand", "million", "billion", "trillion",
|
| 24 |
+
"quadrillion", "quintillion", "sextillion", "septillion",
|
| 25 |
+
"octillion", "nonillion", "decillion"
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
# Combined vocabulary of all number words
|
| 29 |
+
VOCABULARY: list[str] = []
|
| 30 |
+
VOCABULARY.extend(ONES)
|
| 31 |
+
VOCABULARY.extend(TEENS)
|
| 32 |
+
VOCABULARY.extend([t for t in TENS if t]) # Exclude empty strings
|
| 33 |
+
VOCABULARY.append("hundred")
|
| 34 |
+
VOCABULARY.extend([s for s in SCALES if s]) # Exclude empty string
|
| 35 |
+
VOCABULARY.append("<EOS>") # End of sequence token
|
| 36 |
+
|
| 37 |
+
# Create a word-to-index lookup for efficient encoding
|
| 38 |
+
WORD_TO_INDEX: dict[str, int] = {word: idx for idx, word in enumerate(VOCABULARY)}
|
| 39 |
+
|
| 40 |
+
# Special token indices
|
| 41 |
+
EOS_IDX: int = VOCABULARY.index("<EOS>")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def int_to_digits(n: int) -> list[int]:
|
| 45 |
+
"""Convert an integer to a list of its decimal digits.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
n: An integer (can be any size, positive, negative, or zero)
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
List of digits (0-9). Returns [0] for zero.
|
| 52 |
+
Negative numbers return digits without the sign.
|
| 53 |
+
|
| 54 |
+
Example:
|
| 55 |
+
>>> int_to_digits(123)
|
| 56 |
+
[1, 2, 3]
|
| 57 |
+
>>> int_to_digits(0)
|
| 58 |
+
[0]
|
| 59 |
+
>>> int_to_digits(-456)
|
| 60 |
+
[4, 5, 6]
|
| 61 |
+
"""
|
| 62 |
+
if n == 0:
|
| 63 |
+
return [0]
|
| 64 |
+
|
| 65 |
+
n = abs(n)
|
| 66 |
+
|
| 67 |
+
digits: list[int] = []
|
| 68 |
+
while n > 0:
|
| 69 |
+
digits.append(n % 10)
|
| 70 |
+
n //= 10
|
| 71 |
+
|
| 72 |
+
return digits[::-1]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def digits_to_int(digits: list[int]) -> int:
|
| 76 |
+
"""Convert a list of decimal digits to an integer.
|
| 77 |
+
|
| 78 |
+
This is the inverse of int_to_digits().
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
digits: List of digits (0-9)
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
The integer value represented by the digits
|
| 85 |
+
|
| 86 |
+
Raises:
|
| 87 |
+
ValueError: If any digit is not 0-9
|
| 88 |
+
|
| 89 |
+
Example:
|
| 90 |
+
>>> digits_to_int([1, 2, 3])
|
| 91 |
+
123
|
| 92 |
+
>>> digits_to_int([0])
|
| 93 |
+
0
|
| 94 |
+
"""
|
| 95 |
+
if not digits:
|
| 96 |
+
return 0
|
| 97 |
+
|
| 98 |
+
result = 0
|
| 99 |
+
for d in digits:
|
| 100 |
+
if not (0 <= d <= 9):
|
| 101 |
+
raise ValueError(f"Invalid digit {d}, must be 0-9")
|
| 102 |
+
result = result * 10 + d
|
| 103 |
+
|
| 104 |
+
return result
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def encode(text: str) -> list[int]:
|
| 108 |
+
"""Encode a string of number words into a list of vocabulary indices.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
text: String containing space-separated number words (e.g., "one million")
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
List of indices corresponding to each word in VOCABULARY
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
ValueError: If a word is not found in VOCABULARY
|
| 118 |
+
|
| 119 |
+
Example:
|
| 120 |
+
>>> encode("one million")
|
| 121 |
+
[1, 29]
|
| 122 |
+
"""
|
| 123 |
+
if not text or not text.strip():
|
| 124 |
+
return []
|
| 125 |
+
|
| 126 |
+
words = text.strip().lower().split()
|
| 127 |
+
indices: list[int] = []
|
| 128 |
+
|
| 129 |
+
for word in words:
|
| 130 |
+
if word not in WORD_TO_INDEX:
|
| 131 |
+
raise ValueError(f"Unknown word '{word}' not in VOCABULARY")
|
| 132 |
+
indices.append(WORD_TO_INDEX[word])
|
| 133 |
+
|
| 134 |
+
return indices
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def decode(indices: list[int]) -> str:
|
| 138 |
+
"""Decode a list of vocabulary indices into a string of number words.
|
| 139 |
+
|
| 140 |
+
This is the inverse of encode(). <EOS> tokens are ignored.
|
| 141 |
+
|
| 142 |
+
Args:
|
| 143 |
+
indices: List of indices into VOCABULARY (e.g., [1, 30])
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
String of space-separated number words (e.g., "one million")
|
| 147 |
+
|
| 148 |
+
Raises:
|
| 149 |
+
ValueError: If an index is out of range
|
| 150 |
+
|
| 151 |
+
Example:
|
| 152 |
+
>>> decode([1, 30])
|
| 153 |
+
'one million'
|
| 154 |
+
"""
|
| 155 |
+
if not indices:
|
| 156 |
+
return ""
|
| 157 |
+
|
| 158 |
+
words: list[str] = []
|
| 159 |
+
for idx in indices:
|
| 160 |
+
if not (0 <= idx < len(VOCABULARY)):
|
| 161 |
+
raise ValueError(f"Index {idx} out of range for VOCABULARY (size {len(VOCABULARY)})")
|
| 162 |
+
word = VOCABULARY[idx]
|
| 163 |
+
if word != "<EOS>":
|
| 164 |
+
words.append(word)
|
| 165 |
+
|
| 166 |
+
return " ".join(words)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def read_double(a: int, b: int) -> str:
|
| 170 |
+
"""Convert two digits (a, b) into the English name of the number they form.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
a: Tens digit (0-9)
|
| 174 |
+
b: Ones digit (0-9)
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
English name of the number (e.g., "twenty three", "eleven", "seven")
|
| 178 |
+
"""
|
| 179 |
+
if not (0 <= a <= 9 and 0 <= b <= 9):
|
| 180 |
+
raise ValueError("Digits must be between 0 and 9")
|
| 181 |
+
|
| 182 |
+
number = a * 10 + b
|
| 183 |
+
|
| 184 |
+
if number < 10:
|
| 185 |
+
return ONES[number]
|
| 186 |
+
elif number < 20:
|
| 187 |
+
return TEENS[number - 10]
|
| 188 |
+
elif b == 0:
|
| 189 |
+
return TENS[a]
|
| 190 |
+
else:
|
| 191 |
+
return f"{TENS[a]} {ONES[b]}"
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def read_triplet(a: int, b: int, c: int) -> str:
|
| 195 |
+
"""Convert three digits (a, b, c) into the English name of the number they form.
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
a: Hundreds digit (0-9)
|
| 199 |
+
b: Tens digit (0-9)
|
| 200 |
+
c: Ones digit (0-9)
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
English name of the number (e.g., "one hundred six", "zero", "nine hundred nineteen")
|
| 204 |
+
"""
|
| 205 |
+
if not (0 <= a <= 9 and 0 <= b <= 9 and 0 <= c <= 9):
|
| 206 |
+
raise ValueError("Digits must be between 0 and 9")
|
| 207 |
+
|
| 208 |
+
if a == 0:
|
| 209 |
+
return read_double(b, c)
|
| 210 |
+
|
| 211 |
+
remainder = read_double(b, c)
|
| 212 |
+
|
| 213 |
+
if b == 0 and c == 0:
|
| 214 |
+
return f"{ONES[a]} hundred"
|
| 215 |
+
else:
|
| 216 |
+
return f"{ONES[a]} hundred {remainder}"
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def read_digits(lst: list[int]) -> str:
|
| 220 |
+
"""Convert a list of digits into the English name of the number they form.
|
| 221 |
+
|
| 222 |
+
Groups digits into triplets and combines with scale words (thousand, million, etc.)
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
lst: List of digits (0-9)
|
| 226 |
+
|
| 227 |
+
Returns:
|
| 228 |
+
English name of the number
|
| 229 |
+
"""
|
| 230 |
+
if not lst:
|
| 231 |
+
return "zero"
|
| 232 |
+
|
| 233 |
+
for d in lst:
|
| 234 |
+
if not (0 <= d <= 9):
|
| 235 |
+
raise ValueError("All elements must be digits between 0 and 9")
|
| 236 |
+
|
| 237 |
+
if all(d == 0 for d in lst):
|
| 238 |
+
return "zero"
|
| 239 |
+
|
| 240 |
+
# Pad with leading zeros to make length a multiple of 3
|
| 241 |
+
padded = lst[:]
|
| 242 |
+
while len(padded) % 3 != 0:
|
| 243 |
+
padded = [0] + padded
|
| 244 |
+
|
| 245 |
+
# Group into triplets
|
| 246 |
+
triplets: list[tuple[int, int, int]] = []
|
| 247 |
+
for i in range(0, len(padded), 3):
|
| 248 |
+
triplets.append((padded[i], padded[i+1], padded[i+2]))
|
| 249 |
+
|
| 250 |
+
# Build the result by processing each triplet with its scale
|
| 251 |
+
parts: list[str] = []
|
| 252 |
+
num_triplets = len(triplets)
|
| 253 |
+
|
| 254 |
+
for i, (a, b, c) in enumerate(triplets):
|
| 255 |
+
if a == 0 and b == 0 and c == 0:
|
| 256 |
+
continue
|
| 257 |
+
|
| 258 |
+
triplet_name = read_triplet(a, b, c)
|
| 259 |
+
scale_index = num_triplets - 1 - i
|
| 260 |
+
scale = SCALES[scale_index] if scale_index < len(SCALES) else ""
|
| 261 |
+
|
| 262 |
+
if scale:
|
| 263 |
+
parts.append(f"{triplet_name} {scale}")
|
| 264 |
+
else:
|
| 265 |
+
parts.append(triplet_name)
|
| 266 |
+
|
| 267 |
+
return " ".join(parts)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["hatchling"]
|
| 3 |
+
build-backend = "hatchling.build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "namer"
|
| 7 |
+
version = "0.2.0"
|
| 8 |
+
description = "A PyTorch transformer model for converting numbers to English names"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
license = {text = "MIT"}
|
| 11 |
+
requires-python = ">=3.10"
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Developer", email = "dev@example.com"}
|
| 14 |
+
]
|
| 15 |
+
keywords = ["pytorch", "transformer", "nlp", "numbers", "ml"]
|
| 16 |
+
classifiers = [
|
| 17 |
+
"Development Status :: 3 - Alpha",
|
| 18 |
+
"Intended Audience :: Developers",
|
| 19 |
+
"License :: OSI Approved :: MIT License",
|
| 20 |
+
"Programming Language :: Python :: 3",
|
| 21 |
+
"Programming Language :: Python :: 3.10",
|
| 22 |
+
"Programming Language :: Python :: 3.11",
|
| 23 |
+
"Programming Language :: Python :: 3.12",
|
| 24 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
dependencies = [
|
| 28 |
+
"torch>=2.0.0",
|
| 29 |
+
]
|
| 30 |
+
|
| 31 |
+
[project.optional-dependencies]
|
| 32 |
+
dev = [
|
| 33 |
+
"pytest>=7.0",
|
| 34 |
+
"pytest-cov>=4.0",
|
| 35 |
+
"ruff>=0.1.0",
|
| 36 |
+
"mypy>=1.0",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
[project.scripts]
|
| 40 |
+
namer = "namer.main:main"
|
| 41 |
+
|
| 42 |
+
[project.urls]
|
| 43 |
+
Homepage = "https://github.com/example/namer"
|
| 44 |
+
Repository = "https://github.com/example/namer"
|
| 45 |
+
Issues = "https://github.com/example/namer/issues"
|
| 46 |
+
|
| 47 |
+
[tool.ruff]
|
| 48 |
+
target-version = "py310"
|
| 49 |
+
line-length = 88
|
| 50 |
+
|
| 51 |
+
[tool.ruff.lint]
|
| 52 |
+
select = [
|
| 53 |
+
"E", # pycodestyle errors
|
| 54 |
+
"W", # pycodestyle warnings
|
| 55 |
+
"F", # Pyflakes
|
| 56 |
+
"I", # isort
|
| 57 |
+
"N", # pep8-naming
|
| 58 |
+
"D", # pydocstyle
|
| 59 |
+
"UP", # pyupgrade
|
| 60 |
+
"B", # flake8-bugbear
|
| 61 |
+
"C4", # flake8-comprehensions
|
| 62 |
+
"SIM", # flake8-simplify
|
| 63 |
+
]
|
| 64 |
+
ignore = ["D100", "D104"] # Missing docstrings in public packages/modules
|
| 65 |
+
|
| 66 |
+
[tool.ruff.lint.pydocstyle]
|
| 67 |
+
convention = "google"
|
| 68 |
+
|
| 69 |
+
[tool.mypy]
|
| 70 |
+
python_version = "3.10"
|
| 71 |
+
warn_return_any = true
|
| 72 |
+
warn_unused_configs = true
|
| 73 |
+
disallow_untyped_defs = true
|
| 74 |
+
disallow_incomplete_defs = true
|
| 75 |
+
check_untyped_defs = true
|
| 76 |
+
warn_redundant_casts = true
|
| 77 |
+
warn_unused_ignores = true
|
| 78 |
+
show_error_codes = true
|
| 79 |
+
|
| 80 |
+
[[tool.mypy.overrides]]
|
| 81 |
+
module = ["torch.*"]
|
| 82 |
+
ignore_missing_imports = true
|
| 83 |
+
|
| 84 |
+
[tool.pytest.ini_options]
|
| 85 |
+
testpaths = ["tests"]
|
| 86 |
+
python_files = ["test_*.py"]
|
| 87 |
+
python_functions = ["test_*"]
|
| 88 |
+
addopts = "-v --tb=short"
|
| 89 |
+
|
| 90 |
+
[tool.coverage.run]
|
| 91 |
+
source = ["namer"]
|
| 92 |
+
omit = ["*/tests/*"]
|
| 93 |
+
|
| 94 |
+
[tool.coverage.report]
|
| 95 |
+
exclude_lines = [
|
| 96 |
+
"pragma: no cover",
|
| 97 |
+
"def __repr__",
|
| 98 |
+
"raise AssertionError",
|
| 99 |
+
"raise NotImplementedError",
|
| 100 |
+
]
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Tests for Namer package."""
|
tests/test_data.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for dataset classes."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
|
| 6 |
+
from namer.data import InfiniteNamerDataset, NamerDataset
|
| 7 |
+
from namer.utils import EOS_IDX, VOCABULARY
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestNamerDataset:
|
| 11 |
+
"""Tests for NamerDataset class."""
|
| 12 |
+
|
| 13 |
+
def test_length(self) -> None:
|
| 14 |
+
dataset = NamerDataset(num_samples=50, seed=42)
|
| 15 |
+
assert len(dataset) == 50
|
| 16 |
+
|
| 17 |
+
def test_sample_shape(self) -> None:
|
| 18 |
+
dataset = NamerDataset(num_samples=10, max_seq_len=20, seed=42)
|
| 19 |
+
digits, encoded = dataset[0]
|
| 20 |
+
|
| 21 |
+
assert digits.shape == (20,)
|
| 22 |
+
assert encoded.shape == (20,)
|
| 23 |
+
assert digits.dtype == torch.long
|
| 24 |
+
assert encoded.dtype == torch.long
|
| 25 |
+
|
| 26 |
+
def test_padding_value(self) -> None:
|
| 27 |
+
dataset = NamerDataset(num_samples=10, max_seq_len=20, seed=42)
|
| 28 |
+
digits, _ = dataset[0]
|
| 29 |
+
|
| 30 |
+
# Padding should be 10
|
| 31 |
+
assert (digits == 10).any() or len([d for d in digits if d != 10]) <= 6
|
| 32 |
+
|
| 33 |
+
def test_eos_present(self) -> None:
|
| 34 |
+
dataset = NamerDataset(num_samples=10, seed=42)
|
| 35 |
+
_, encoded = dataset[0]
|
| 36 |
+
|
| 37 |
+
# EOS token should be present
|
| 38 |
+
assert EOS_IDX in encoded.tolist()
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestInfiniteNamerDataset:
|
| 42 |
+
"""Tests for InfiniteNamerDataset class."""
|
| 43 |
+
|
| 44 |
+
def test_iteration(self) -> None:
|
| 45 |
+
dataset = InfiniteNamerDataset(seed=42)
|
| 46 |
+
iterator = iter(dataset)
|
| 47 |
+
|
| 48 |
+
# Can get multiple samples
|
| 49 |
+
for _ in range(10):
|
| 50 |
+
digits, encoded = next(iterator)
|
| 51 |
+
assert digits.shape == (20,)
|
| 52 |
+
assert encoded.shape == (20,)
|
| 53 |
+
|
| 54 |
+
def test_data_loader(self) -> None:
|
| 55 |
+
dataset = InfiniteNamerDataset(seed=42)
|
| 56 |
+
loader = DataLoader(dataset, batch_size=4, num_workers=0)
|
| 57 |
+
|
| 58 |
+
iterator = iter(loader)
|
| 59 |
+
digits_batch, encoded_batch = next(iterator)
|
| 60 |
+
|
| 61 |
+
assert digits_batch.shape == (4, 20)
|
| 62 |
+
assert encoded_batch.shape == (4, 20)
|
| 63 |
+
|
| 64 |
+
def test_reproducibility(self) -> None:
|
| 65 |
+
dataset1 = InfiniteNamerDataset(seed=42)
|
| 66 |
+
dataset2 = InfiniteNamerDataset(seed=42)
|
| 67 |
+
|
| 68 |
+
iter1 = iter(dataset1)
|
| 69 |
+
iter2 = iter(dataset2)
|
| 70 |
+
|
| 71 |
+
for _ in range(5):
|
| 72 |
+
d1, e1 = next(iter1)
|
| 73 |
+
d2, e2 = next(iter2)
|
| 74 |
+
assert torch.equal(d1, d2)
|
| 75 |
+
assert torch.equal(e1, e2)
|
| 76 |
+
|
| 77 |
+
def test_vocab_range(self) -> None:
|
| 78 |
+
dataset = InfiniteNamerDataset(seed=42)
|
| 79 |
+
iterator = iter(dataset)
|
| 80 |
+
|
| 81 |
+
for _ in range(20):
|
| 82 |
+
_, encoded = next(iterator)
|
| 83 |
+
# Valid tokens should be within vocab range (excluding -1 padding)
|
| 84 |
+
valid_tokens = encoded[encoded != -1]
|
| 85 |
+
assert (valid_tokens >= 0).all()
|
| 86 |
+
assert (valid_tokens < len(VOCABULARY)).all()
|
tests/test_inference.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for inference utilities."""
|
| 2 |
+
|
| 3 |
+
from unittest.mock import MagicMock, patch
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from namer.inference import predict_number_name
|
| 9 |
+
from namer.models import NamerTransformer
|
| 10 |
+
from namer.utils import VOCABULARY, read_digits, int_to_digits
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TestPredictNumberName:
|
| 14 |
+
"""Tests for predict_number_name function."""
|
| 15 |
+
|
| 16 |
+
@pytest.fixture
|
| 17 |
+
def mock_model(self) -> MagicMock:
|
| 18 |
+
model = MagicMock(spec=NamerTransformer)
|
| 19 |
+
model.max_output_len = 20
|
| 20 |
+
model.vocab_size = len(VOCABULARY)
|
| 21 |
+
|
| 22 |
+
# Mock the device property
|
| 23 |
+
param = MagicMock()
|
| 24 |
+
param.device = torch.device("cpu")
|
| 25 |
+
model.parameters.return_value = iter([param])
|
| 26 |
+
|
| 27 |
+
return model
|
| 28 |
+
|
| 29 |
+
def test_basic_prediction(self, mock_model: MagicMock) -> None:
|
| 30 |
+
# Create fake logits that will select known tokens
|
| 31 |
+
# "one" is index 1 in VOCABULARY
|
| 32 |
+
fake_logits = torch.zeros(1, 20, len(VOCABULARY))
|
| 33 |
+
fake_logits[0, 0, 1] = 10.0 # "one"
|
| 34 |
+
fake_logits[0, 1, VOCABULARY.index("<EOS>")] = 10.0 # EOS
|
| 35 |
+
|
| 36 |
+
mock_model.return_value = fake_logits
|
| 37 |
+
mock_model.eval = MagicMock()
|
| 38 |
+
|
| 39 |
+
with patch("namer.inference.torch.no_grad"):
|
| 40 |
+
result = predict_number_name(mock_model, 1)
|
| 41 |
+
|
| 42 |
+
# Should decode to "one"
|
| 43 |
+
assert "one" in result.lower() or result.startswith("<")
|
| 44 |
+
|
| 45 |
+
def test_eos_stops_generation(self, mock_model: MagicMock) -> None:
|
| 46 |
+
# Logits that predict EOS immediately
|
| 47 |
+
fake_logits = torch.zeros(1, 20, len(VOCABULARY))
|
| 48 |
+
fake_logits[0, 0, VOCABULARY.index("<EOS>")] = 10.0
|
| 49 |
+
|
| 50 |
+
mock_model.return_value = fake_logits
|
| 51 |
+
mock_model.eval = MagicMock()
|
| 52 |
+
|
| 53 |
+
with patch("namer.inference.torch.no_grad"):
|
| 54 |
+
result = predict_number_name(mock_model, 0)
|
| 55 |
+
|
| 56 |
+
# Empty result when EOS is first
|
| 57 |
+
assert result == "" or result.startswith("<")
|
| 58 |
+
|
| 59 |
+
def test_device_override(self, mock_model: MagicMock) -> None:
|
| 60 |
+
fake_logits = torch.zeros(1, 20, len(VOCABULARY))
|
| 61 |
+
fake_logits[0, 0, 1] = 10.0
|
| 62 |
+
fake_logits[0, 1, VOCABULARY.index("<EOS>")] = 10.0
|
| 63 |
+
|
| 64 |
+
mock_model.return_value = fake_logits
|
| 65 |
+
mock_model.eval = MagicMock()
|
| 66 |
+
|
| 67 |
+
with patch("namer.inference.torch.no_grad"):
|
| 68 |
+
# Should not raise when device is specified
|
| 69 |
+
result = predict_number_name(mock_model, 1, device="cpu")
|
| 70 |
+
|
| 71 |
+
assert isinstance(result, str)
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for model classes."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from namer.models import NamerTransformer, PositionalEncoding
|
| 7 |
+
from namer.utils import VOCABULARY
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestPositionalEncoding:
|
| 11 |
+
"""Tests for PositionalEncoding module."""
|
| 12 |
+
|
| 13 |
+
def test_shape(self) -> None:
|
| 14 |
+
pe = PositionalEncoding(d_model=128)
|
| 15 |
+
x = torch.randn(2, 10, 128) # batch=2, seq=10, dim=128
|
| 16 |
+
out = pe(x)
|
| 17 |
+
assert out.shape == (2, 10, 128)
|
| 18 |
+
|
| 19 |
+
def test_adds_position(self) -> None:
|
| 20 |
+
pe = PositionalEncoding(d_model=64)
|
| 21 |
+
x = torch.zeros(1, 5, 64)
|
| 22 |
+
out = pe(x)
|
| 23 |
+
# Output should be non-zero due to positional encoding
|
| 24 |
+
assert not torch.allclose(out, x)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TestNamerTransformer:
|
| 28 |
+
"""Tests for NamerTransformer model."""
|
| 29 |
+
|
| 30 |
+
@pytest.fixture
|
| 31 |
+
def model(self) -> NamerTransformer:
|
| 32 |
+
return NamerTransformer(
|
| 33 |
+
vocab_size=len(VOCABULARY),
|
| 34 |
+
max_output_len=20,
|
| 35 |
+
d_model=64,
|
| 36 |
+
nhead=4,
|
| 37 |
+
num_encoder_layers=2,
|
| 38 |
+
dim_feedforward=128,
|
| 39 |
+
dropout=0.0,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def test_forward_shape(self, model: NamerTransformer) -> None:
|
| 43 |
+
batch_size = 4
|
| 44 |
+
seq_len = 10
|
| 45 |
+
digits = torch.randint(0, 10, (batch_size, seq_len))
|
| 46 |
+
|
| 47 |
+
logits = model(digits)
|
| 48 |
+
|
| 49 |
+
assert logits.shape == (batch_size, model.max_output_len, model.vocab_size)
|
| 50 |
+
|
| 51 |
+
def test_forward_with_padding(self, model: NamerTransformer) -> None:
|
| 52 |
+
batch_size = 2
|
| 53 |
+
seq_len = 10
|
| 54 |
+
digits = torch.full((batch_size, seq_len), 10) # All padding
|
| 55 |
+
digits[:, :5] = torch.randint(0, 10, (batch_size, 5))
|
| 56 |
+
|
| 57 |
+
logits = model(digits)
|
| 58 |
+
|
| 59 |
+
assert logits.shape == (batch_size, model.max_output_len, model.vocab_size)
|
| 60 |
+
|
| 61 |
+
def test_forward_with_negative_padding(self, model: NamerTransformer) -> None:
|
| 62 |
+
batch_size = 2
|
| 63 |
+
seq_len = 10
|
| 64 |
+
digits = torch.full((batch_size, seq_len), -1) # -1 padding
|
| 65 |
+
digits[:, :5] = torch.randint(0, 10, (batch_size, 5))
|
| 66 |
+
|
| 67 |
+
logits = model(digits)
|
| 68 |
+
|
| 69 |
+
assert logits.shape == (batch_size, model.max_output_len, model.vocab_size)
|
| 70 |
+
|
| 71 |
+
def test_output_is_logits(self, model: NamerTransformer) -> None:
|
| 72 |
+
digits = torch.randint(0, 10, (1, 5))
|
| 73 |
+
logits = model(digits)
|
| 74 |
+
|
| 75 |
+
# Logits should not be probabilities (no softmax applied)
|
| 76 |
+
assert not torch.all((logits >= 0) & (logits <= 1))
|
| 77 |
+
|
| 78 |
+
def test_gradient_flow(self, model: NamerTransformer) -> None:
|
| 79 |
+
digits = torch.randint(0, 10, (2, 5))
|
| 80 |
+
target = torch.randint(0, len(VOCABULARY), (2, model.max_output_len))
|
| 81 |
+
|
| 82 |
+
logits = model(digits)
|
| 83 |
+
loss = torch.nn.functional.cross_entropy(
|
| 84 |
+
logits.view(-1, model.vocab_size),
|
| 85 |
+
target.view(-1)
|
| 86 |
+
)
|
| 87 |
+
loss.backward()
|
| 88 |
+
|
| 89 |
+
# Check that gradients exist
|
| 90 |
+
for param in model.parameters():
|
| 91 |
+
assert param.grad is not None
|
tests/test_utils.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for utility functions."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
|
| 5 |
+
from namer.utils import (
|
| 6 |
+
EOS_IDX,
|
| 7 |
+
VOCABULARY,
|
| 8 |
+
decode,
|
| 9 |
+
digits_to_int,
|
| 10 |
+
encode,
|
| 11 |
+
int_to_digits,
|
| 12 |
+
read_digits,
|
| 13 |
+
read_double,
|
| 14 |
+
read_triplet,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestIntToDigits:
|
| 19 |
+
"""Tests for int_to_digits function."""
|
| 20 |
+
|
| 21 |
+
def test_zero(self) -> None:
|
| 22 |
+
assert int_to_digits(0) == [0]
|
| 23 |
+
|
| 24 |
+
def test_positive(self) -> None:
|
| 25 |
+
assert int_to_digits(123) == [1, 2, 3]
|
| 26 |
+
assert int_to_digits(7) == [7]
|
| 27 |
+
|
| 28 |
+
def test_negative(self) -> None:
|
| 29 |
+
assert int_to_digits(-456) == [4, 5, 6]
|
| 30 |
+
|
| 31 |
+
def test_large_number(self) -> None:
|
| 32 |
+
assert int_to_digits(1002003) == [1, 0, 0, 2, 0, 0, 3]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestDigitsToInt:
|
| 36 |
+
"""Tests for digits_to_int function."""
|
| 37 |
+
|
| 38 |
+
def test_empty(self) -> None:
|
| 39 |
+
assert digits_to_int([]) == 0
|
| 40 |
+
|
| 41 |
+
def test_single_digit(self) -> None:
|
| 42 |
+
assert digits_to_int([5]) == 5
|
| 43 |
+
|
| 44 |
+
def test_multiple_digits(self) -> None:
|
| 45 |
+
assert digits_to_int([1, 2, 3]) == 123
|
| 46 |
+
|
| 47 |
+
def test_with_zeros(self) -> None:
|
| 48 |
+
assert digits_to_int([1, 0, 0, 2]) == 1002
|
| 49 |
+
|
| 50 |
+
def test_invalid_digit(self) -> None:
|
| 51 |
+
with pytest.raises(ValueError, match="Invalid digit"):
|
| 52 |
+
digits_to_int([10])
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TestRoundTrip:
|
| 56 |
+
"""Tests for int_to_digits <-> digits_to_int round-trip."""
|
| 57 |
+
|
| 58 |
+
def test_round_trip(self) -> None:
|
| 59 |
+
for n in [0, 42, 123, 1000, 999999, 1000000]:
|
| 60 |
+
assert digits_to_int(int_to_digits(n)) == abs(n)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class TestReadDouble:
|
| 64 |
+
"""Tests for read_double function."""
|
| 65 |
+
|
| 66 |
+
def test_single_digit(self) -> None:
|
| 67 |
+
assert read_double(0, 7) == "seven"
|
| 68 |
+
assert read_double(0, 0) == "zero"
|
| 69 |
+
|
| 70 |
+
def test_teens(self) -> None:
|
| 71 |
+
assert read_double(1, 1) == "eleven"
|
| 72 |
+
assert read_double(1, 9) == "nineteen"
|
| 73 |
+
|
| 74 |
+
def test_tens(self) -> None:
|
| 75 |
+
assert read_double(3, 0) == "thirty"
|
| 76 |
+
assert read_double(5, 0) == "fifty"
|
| 77 |
+
|
| 78 |
+
def test_tens_and_ones(self) -> None:
|
| 79 |
+
assert read_double(2, 3) == "twenty three"
|
| 80 |
+
assert read_double(5, 9) == "fifty nine"
|
| 81 |
+
|
| 82 |
+
def test_invalid_digits(self) -> None:
|
| 83 |
+
with pytest.raises(ValueError, match="must be between 0 and 9"):
|
| 84 |
+
read_double(10, 5)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class TestReadTriplet:
|
| 88 |
+
"""Tests for read_triplet function."""
|
| 89 |
+
|
| 90 |
+
def test_hundreds(self) -> None:
|
| 91 |
+
assert read_triplet(1, 0, 6) == "one hundred six"
|
| 92 |
+
assert read_triplet(2, 0, 0) == "two hundred"
|
| 93 |
+
|
| 94 |
+
def test_zero_hundreds(self) -> None:
|
| 95 |
+
assert read_triplet(0, 5, 5) == "fifty five"
|
| 96 |
+
|
| 97 |
+
def test_all_zeros(self) -> None:
|
| 98 |
+
assert read_triplet(0, 0, 0) == "zero"
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class TestReadDigits:
|
| 102 |
+
"""Tests for read_digits function."""
|
| 103 |
+
|
| 104 |
+
def test_empty(self) -> None:
|
| 105 |
+
assert read_digits([]) == "zero"
|
| 106 |
+
|
| 107 |
+
def test_zero(self) -> None:
|
| 108 |
+
assert read_digits([0]) == "zero"
|
| 109 |
+
assert read_digits([0, 0, 0]) == "zero"
|
| 110 |
+
|
| 111 |
+
def test_single_digit(self) -> None:
|
| 112 |
+
assert read_digits([5]) == "five"
|
| 113 |
+
|
| 114 |
+
def test_double_digit(self) -> None:
|
| 115 |
+
assert read_digits([4, 2]) == "forty two"
|
| 116 |
+
|
| 117 |
+
def test_triple_digit(self) -> None:
|
| 118 |
+
assert read_digits([1, 2, 3]) == "one hundred twenty three"
|
| 119 |
+
|
| 120 |
+
def test_thousands(self) -> None:
|
| 121 |
+
assert read_digits([1, 0, 0, 0]) == "one thousand"
|
| 122 |
+
assert read_digits([1, 2, 3, 4]) == "one thousand two hundred thirty four"
|
| 123 |
+
|
| 124 |
+
def test_millions(self) -> None:
|
| 125 |
+
assert read_digits([1, 0, 0, 0, 0, 0, 0]) == "one million"
|
| 126 |
+
|
| 127 |
+
def test_complex(self) -> None:
|
| 128 |
+
# 1,234,567
|
| 129 |
+
digits = [1, 2, 3, 4, 5, 6, 7]
|
| 130 |
+
result = read_digits(digits)
|
| 131 |
+
assert "one million" in result
|
| 132 |
+
assert "two hundred thirty four thousand" in result
|
| 133 |
+
assert "five hundred sixty seven" in result
|
| 134 |
+
|
| 135 |
+
def test_invalid_digit(self) -> None:
|
| 136 |
+
with pytest.raises(ValueError, match="must be digits"):
|
| 137 |
+
read_digits([1, 10, 3])
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class TestEncode:
|
| 141 |
+
"""Tests for encode function."""
|
| 142 |
+
|
| 143 |
+
def test_simple(self) -> None:
|
| 144 |
+
indices = encode("one million")
|
| 145 |
+
assert len(indices) == 2
|
| 146 |
+
assert all(0 <= i < len(VOCABULARY) for i in indices)
|
| 147 |
+
|
| 148 |
+
def test_multi_word(self) -> None:
|
| 149 |
+
indices = encode("twenty three")
|
| 150 |
+
assert len(indices) == 2
|
| 151 |
+
|
| 152 |
+
def test_empty(self) -> None:
|
| 153 |
+
assert encode("") == []
|
| 154 |
+
assert encode(" ") == []
|
| 155 |
+
|
| 156 |
+
def test_unknown_word(self) -> None:
|
| 157 |
+
with pytest.raises(ValueError, match="Unknown word"):
|
| 158 |
+
encode("unknown")
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
class TestDecode:
|
| 162 |
+
"""Tests for decode function."""
|
| 163 |
+
|
| 164 |
+
def test_simple(self) -> None:
|
| 165 |
+
encoded = encode("one million")
|
| 166 |
+
assert decode(encoded) == "one million"
|
| 167 |
+
|
| 168 |
+
def test_with_eos(self) -> None:
|
| 169 |
+
encoded = encode("one million") + [EOS_IDX]
|
| 170 |
+
assert decode(encoded) == "one million"
|
| 171 |
+
|
| 172 |
+
def test_empty(self) -> None:
|
| 173 |
+
assert decode([]) == ""
|
| 174 |
+
|
| 175 |
+
def test_invalid_index(self) -> None:
|
| 176 |
+
with pytest.raises(ValueError, match="out of range"):
|
| 177 |
+
decode([9999])
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class TestEncodeDecodeRoundTrip:
|
| 181 |
+
"""Tests for encode/decode round-trip."""
|
| 182 |
+
|
| 183 |
+
def test_round_trip(self) -> None:
|
| 184 |
+
test_cases = [
|
| 185 |
+
"one million",
|
| 186 |
+
"twenty three",
|
| 187 |
+
"one hundred twenty three",
|
| 188 |
+
"zero",
|
| 189 |
+
"nine hundred nineteen",
|
| 190 |
+
]
|
| 191 |
+
for text in test_cases:
|
| 192 |
+
encoded = encode(text)
|
| 193 |
+
assert decode(encoded) == text
|