Spaces: Sleeping
Commit · 8345416
Parent(s): 32ba924
Initial commit
- .gitignore +58 -0
- .python-version +1 -0
- README.md +8 -0
- Smollm_135.ipynb +1002 -0
- app.py +160 -0
- check_cuda.py +9 -0
- checkpoint_info.txt +2 -0
- inference.py +126 -0
- inspect_checkpoint.py +19 -0
- main.py +6 -0
- model.py +239 -0
- model/smollm_135_checkpoint.pth +3 -0
- profile_app.py +33 -0
- pyproject.toml +12 -0
- requirements.txt +4 -0
- test_inference.py +25 -0
- test_kv_cache.py +45 -0
- test_model.py +22 -0
- test_tiktoken.py +16 -0
.gitignore
ADDED
@@ -0,0 +1,58 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+*.so
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Virtual environments
+venv/
+env/
+ENV/
+.venv
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+
+
+# Data files
+input.txt
+*.csv
+*.json
+
+# Jupyter Notebook
+.ipynb_checkpoints/
+*.ipynb_checkpoints/
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+logs/
+
+# Hugging Face cache
+.cache/
+hf_cache/
+
.python-version
ADDED
@@ -0,0 +1 @@
+3.11
README.md
CHANGED
@@ -12,3 +12,11 @@ short_description: Smollm-135 base model trained with dummy data with speedups
 ---

 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+## Running Locally
+
+To run the application locally, use the following command:
+
+```bash
+uv run app.py
+```
Smollm_135.ipynb
ADDED
@@ -0,0 +1,1002 @@
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "d8ad4585",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"name": "stdout",
|
| 11 |
+
"output_type": "stream",
|
| 12 |
+
"text": [
|
| 13 |
+
"using device: cuda\n",
|
| 14 |
+
"Model parameters: 135.18M\n",
|
| 15 |
+
"loaded 338025 tokens\n",
|
| 16 |
+
"1 epoch = 165 batches\n",
|
| 17 |
+
"Starting training...\n",
|
| 18 |
+
"step 0 | loss: 10.8725 | dt: 1274.74ms | tok/sec: 1606.61\n",
|
| 19 |
+
"step 10 | loss: 7.9941 | dt: 2626.35ms | tok/sec: 779.79\n",
|
| 20 |
+
"step 20 | loss: 6.9462 | dt: 2627.22ms | tok/sec: 779.53\n",
|
| 21 |
+
"step 30 | loss: 6.7740 | dt: 3246.48ms | tok/sec: 630.84\n",
|
| 22 |
+
"step 40 | loss: 6.9956 | dt: 2638.15ms | tok/sec: 776.30\n",
|
| 23 |
+
"step 50 | loss: 6.8735 | dt: 2615.98ms | tok/sec: 782.88\n",
|
| 24 |
+
"step 60 | loss: 6.5163 | dt: 2640.97ms | tok/sec: 775.47\n",
|
| 25 |
+
"step 70 | loss: 6.5162 | dt: 2637.36ms | tok/sec: 776.53\n",
|
| 26 |
+
"step 80 | loss: 6.4836 | dt: 2661.23ms | tok/sec: 769.57\n",
|
| 27 |
+
"step 90 | loss: 6.5255 | dt: 2643.36ms | tok/sec: 774.77\n",
|
| 28 |
+
"step 100 | loss: 6.2876 | dt: 2677.88ms | tok/sec: 764.78\n",
|
| 29 |
+
"step 110 | loss: 6.5272 | dt: 2666.03ms | tok/sec: 768.18\n",
|
| 30 |
+
"step 120 | loss: 6.1898 | dt: 2824.14ms | tok/sec: 725.18\n",
|
| 31 |
+
"step 130 | loss: 6.1057 | dt: 2646.03ms | tok/sec: 773.99\n",
|
| 32 |
+
"step 140 | loss: 5.9576 | dt: 2632.82ms | tok/sec: 777.87\n",
|
| 33 |
+
"step 150 | loss: 6.4212 | dt: 2637.88ms | tok/sec: 776.38\n",
|
| 34 |
+
"step 160 | loss: 6.4857 | dt: 2641.72ms | tok/sec: 775.25\n",
|
| 35 |
+
"step 170 | loss: 6.2381 | dt: 2624.55ms | tok/sec: 780.32\n",
|
| 36 |
+
"step 180 | loss: 5.6074 | dt: 2650.12ms | tok/sec: 772.80\n",
|
| 37 |
+
"step 190 | loss: 6.0601 | dt: 2657.61ms | tok/sec: 770.62\n",
|
| 38 |
+
"step 200 | loss: 5.5407 | dt: 2856.74ms | tok/sec: 716.90\n",
|
| 39 |
+
"step 210 | loss: 5.8250 | dt: 2647.57ms | tok/sec: 773.54\n",
|
| 40 |
+
"step 220 | loss: 6.0356 | dt: 2635.08ms | tok/sec: 777.21\n",
|
| 41 |
+
"step 230 | loss: 5.7742 | dt: 2637.31ms | tok/sec: 776.55\n",
|
| 42 |
+
"step 240 | loss: 5.8564 | dt: 2645.37ms | tok/sec: 774.18\n",
|
| 43 |
+
"step 250 | loss: 5.4802 | dt: 2660.91ms | tok/sec: 769.66\n",
|
| 44 |
+
"step 260 | loss: 5.6751 | dt: 2632.88ms | tok/sec: 777.85\n",
|
| 45 |
+
"step 270 | loss: 5.9273 | dt: 2733.56ms | tok/sec: 749.21\n",
|
| 46 |
+
"step 280 | loss: 5.9138 | dt: 2626.31ms | tok/sec: 779.80\n",
|
| 47 |
+
"step 290 | loss: 5.6861 | dt: 2638.67ms | tok/sec: 776.15\n",
|
| 48 |
+
"step 300 | loss: 5.2012 | dt: 2642.88ms | tok/sec: 774.91\n",
|
| 49 |
+
"step 310 | loss: 5.6114 | dt: 2649.90ms | tok/sec: 772.86\n",
|
| 50 |
+
"step 320 | loss: 5.0033 | dt: 2688.08ms | tok/sec: 761.88\n",
|
| 51 |
+
"step 330 | loss: 5.6259 | dt: 2682.14ms | tok/sec: 763.57\n",
|
| 52 |
+
"step 340 | loss: 5.1127 | dt: 2650.79ms | tok/sec: 772.60\n",
|
| 53 |
+
"step 350 | loss: 5.3045 | dt: 2678.00ms | tok/sec: 764.75\n",
|
| 54 |
+
"step 360 | loss: 5.2118 | dt: 2666.44ms | tok/sec: 768.06\n",
|
| 55 |
+
"step 370 | loss: 5.4723 | dt: 2639.73ms | tok/sec: 775.84\n",
|
| 56 |
+
"step 380 | loss: 5.4257 | dt: 2653.49ms | tok/sec: 771.81\n",
|
| 57 |
+
"step 390 | loss: 5.0813 | dt: 2623.92ms | tok/sec: 780.51\n",
|
| 58 |
+
"step 400 | loss: 5.0538 | dt: 2637.88ms | tok/sec: 776.38\n",
|
| 59 |
+
"step 410 | loss: 5.0351 | dt: 2708.26ms | tok/sec: 756.20\n",
|
| 60 |
+
"step 420 | loss: 5.0659 | dt: 2661.11ms | tok/sec: 769.60\n",
|
| 61 |
+
"step 430 | loss: 4.9364 | dt: 2685.40ms | tok/sec: 762.64\n",
|
| 62 |
+
"step 440 | loss: 5.3093 | dt: 2632.10ms | tok/sec: 778.08\n",
|
| 63 |
+
"step 450 | loss: 5.0675 | dt: 2653.00ms | tok/sec: 771.96\n",
|
| 64 |
+
"step 460 | loss: 4.9092 | dt: 2668.73ms | tok/sec: 767.41\n",
|
| 65 |
+
"step 470 | loss: 4.5785 | dt: 2700.87ms | tok/sec: 758.28\n",
|
| 66 |
+
"step 480 | loss: 5.1269 | dt: 2705.93ms | tok/sec: 756.86\n",
|
| 67 |
+
"step 490 | loss: 5.4856 | dt: 2653.12ms | tok/sec: 771.92\n",
|
| 68 |
+
"step 500 | loss: 5.2739 | dt: 2679.78ms | tok/sec: 764.24\n",
|
| 69 |
+
"\n",
|
| 70 |
+
"--- Generating text at step 500 ---\n",
|
| 71 |
+
"!ua any cons mocking\n",
|
| 72 |
+
"I they have we be passion so his all\n",
|
| 73 |
+
"Redpt' grip:' the grave,An thy work!\n",
|
| 74 |
+
"For embr, old Capitol, goodout than coming,\n",
|
| 75 |
+
"When thet vir Rome daughter to the chance\n",
|
| 76 |
+
"-----------------------------------\n",
|
| 77 |
+
"\n",
|
| 78 |
+
"step 510 | loss: 4.6435 | dt: 2697.87ms | tok/sec: 759.12\n",
|
| 79 |
+
"step 520 | loss: 4.9901 | dt: 2688.89ms | tok/sec: 761.65\n",
|
| 80 |
+
"step 530 | loss: 4.5844 | dt: 2703.14ms | tok/sec: 757.64\n",
|
| 81 |
+
"step 540 | loss: 4.9779 | dt: 2663.58ms | tok/sec: 768.89\n",
|
| 82 |
+
"step 550 | loss: 5.1966 | dt: 2714.69ms | tok/sec: 754.42\n",
|
| 83 |
+
"step 560 | loss: 4.9851 | dt: 2695.66ms | tok/sec: 759.74\n",
|
| 84 |
+
"step 570 | loss: 5.2129 | dt: 2698.44ms | tok/sec: 758.96\n",
|
| 85 |
+
"step 580 | loss: 4.7397 | dt: 2714.69ms | tok/sec: 754.41\n",
|
| 86 |
+
"step 590 | loss: 4.9552 | dt: 2692.72ms | tok/sec: 760.57\n",
|
| 87 |
+
"step 600 | loss: 5.2714 | dt: 2713.71ms | tok/sec: 754.69\n",
|
| 88 |
+
"step 610 | loss: 5.2913 | dt: 2653.37ms | tok/sec: 771.85\n",
|
| 89 |
+
"step 620 | loss: 5.0200 | dt: 2701.37ms | tok/sec: 758.13\n",
|
| 90 |
+
"step 630 | loss: 4.3609 | dt: 2689.80ms | tok/sec: 761.39\n",
|
| 91 |
+
"step 640 | loss: 4.9107 | dt: 2769.70ms | tok/sec: 739.43\n",
|
| 92 |
+
"step 650 | loss: 4.3624 | dt: 2707.07ms | tok/sec: 756.54\n",
|
| 93 |
+
"step 660 | loss: 5.0022 | dt: 2720.41ms | tok/sec: 752.83\n",
|
| 94 |
+
"step 670 | loss: 4.4913 | dt: 2680.64ms | tok/sec: 764.00\n",
|
| 95 |
+
"step 680 | loss: 4.7648 | dt: 2673.28ms | tok/sec: 766.10\n",
|
| 96 |
+
"step 690 | loss: 4.6267 | dt: 2671.13ms | tok/sec: 766.72\n",
|
| 97 |
+
"step 700 | loss: 4.8468 | dt: 2683.00ms | tok/sec: 763.33\n",
|
| 98 |
+
"step 710 | loss: 4.8544 | dt: 2678.93ms | tok/sec: 764.48\n",
|
| 99 |
+
"step 720 | loss: 4.5148 | dt: 2698.22ms | tok/sec: 759.02\n",
|
| 100 |
+
"step 730 | loss: 4.4280 | dt: 2694.01ms | tok/sec: 760.21\n",
|
| 101 |
+
"step 740 | loss: 4.4265 | dt: 2681.62ms | tok/sec: 763.72\n",
|
| 102 |
+
"step 750 | loss: 4.4757 | dt: 2671.03ms | tok/sec: 766.75\n",
|
| 103 |
+
"step 760 | loss: 4.3867 | dt: 2712.16ms | tok/sec: 755.12\n",
|
| 104 |
+
"step 770 | loss: 4.8252 | dt: 2681.40ms | tok/sec: 763.78\n",
|
| 105 |
+
"step 780 | loss: 4.6916 | dt: 2861.17ms | tok/sec: 715.79\n",
|
| 106 |
+
"step 790 | loss: 4.3555 | dt: 2682.62ms | tok/sec: 763.43\n",
|
| 107 |
+
"step 800 | loss: 4.0581 | dt: 2695.10ms | tok/sec: 759.90\n",
|
| 108 |
+
"step 810 | loss: 4.5024 | dt: 2718.02ms | tok/sec: 753.49\n",
|
| 109 |
+
"step 820 | loss: 4.9491 | dt: 2688.01ms | tok/sec: 761.90\n",
|
| 110 |
+
"step 830 | loss: 4.7404 | dt: 2688.48ms | tok/sec: 761.77\n",
|
| 111 |
+
"step 840 | loss: 4.1571 | dt: 2683.11ms | tok/sec: 763.29\n",
|
| 112 |
+
"step 850 | loss: 4.2970 | dt: 2673.24ms | tok/sec: 766.11\n",
|
| 113 |
+
"step 860 | loss: 4.1351 | dt: 2673.66ms | tok/sec: 765.99\n",
|
| 114 |
+
"step 870 | loss: 4.5339 | dt: 2723.79ms | tok/sec: 751.89\n",
|
| 115 |
+
"step 880 | loss: 4.7270 | dt: 2655.06ms | tok/sec: 771.36\n",
|
| 116 |
+
"step 890 | loss: 4.5174 | dt: 2654.60ms | tok/sec: 771.49\n",
|
| 117 |
+
"step 900 | loss: 4.7254 | dt: 2671.42ms | tok/sec: 766.63\n",
|
| 118 |
+
"step 910 | loss: 4.2173 | dt: 2706.20ms | tok/sec: 756.78\n",
|
| 119 |
+
"step 920 | loss: 4.4660 | dt: 2687.33ms | tok/sec: 762.09\n",
|
| 120 |
+
"step 930 | loss: 4.8292 | dt: 2668.57ms | tok/sec: 767.45\n",
|
| 121 |
+
"step 940 | loss: 4.8714 | dt: 2681.79ms | tok/sec: 763.67\n",
|
| 122 |
+
"step 950 | loss: 4.5977 | dt: 2689.44ms | tok/sec: 761.50\n",
|
| 123 |
+
"step 960 | loss: 3.9432 | dt: 2693.87ms | tok/sec: 760.24\n",
|
| 124 |
+
"step 970 | loss: 4.4277 | dt: 2685.85ms | tok/sec: 762.51\n",
|
| 125 |
+
"step 980 | loss: 3.9359 | dt: 2665.88ms | tok/sec: 768.23\n",
|
| 126 |
+
"step 990 | loss: 4.6091 | dt: 2693.13ms | tok/sec: 760.45\n",
|
| 127 |
+
"step 1000 | loss: 4.0868 | dt: 2702.22ms | tok/sec: 757.90\n",
|
| 128 |
+
"\n",
|
| 129 |
+
"--- Generating text at step 1000 ---\n",
|
| 130 |
+
"! then he did you\n",
|
| 131 |
+
"And pray them have be devour carries meet\n",
|
| 132 |
+
"And dare man, general's true.\n",
|
| 133 |
+
"\n",
|
| 134 |
+
"LidINIUS:\n",
|
| 135 |
+
"'I can I hope with you were access with call me.\n",
|
| 136 |
+
"\n",
|
| 137 |
+
"HONAN\n",
|
| 138 |
+
"-----------------------------------\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"step 1010 | loss: 4.4185 | dt: 2691.11ms | tok/sec: 761.03\n",
|
| 141 |
+
"step 1020 | loss: 4.2212 | dt: 2687.50ms | tok/sec: 762.05\n",
|
| 142 |
+
"step 1030 | loss: 4.4616 | dt: 2670.47ms | tok/sec: 766.91\n",
|
| 143 |
+
"step 1040 | loss: 4.4004 | dt: 2662.88ms | tok/sec: 769.09\n",
|
| 144 |
+
"step 1050 | loss: 4.1682 | dt: 2705.36ms | tok/sec: 757.02\n",
|
| 145 |
+
"step 1060 | loss: 4.0242 | dt: 2722.96ms | tok/sec: 752.12\n",
|
| 146 |
+
"step 1070 | loss: 4.0509 | dt: 2717.44ms | tok/sec: 753.65\n",
|
| 147 |
+
"step 1080 | loss: 4.0544 | dt: 2667.92ms | tok/sec: 767.64\n",
|
| 148 |
+
"step 1090 | loss: 4.0126 | dt: 2740.84ms | tok/sec: 747.22\n",
|
| 149 |
+
"step 1100 | loss: 4.4749 | dt: 2762.05ms | tok/sec: 741.48\n",
|
| 150 |
+
"step 1110 | loss: 4.3578 | dt: 2634.24ms | tok/sec: 777.45\n",
|
| 151 |
+
"step 1120 | loss: 4.0779 | dt: 2677.47ms | tok/sec: 764.90\n",
|
| 152 |
+
"step 1130 | loss: 3.7411 | dt: 2685.44ms | tok/sec: 762.63\n",
|
| 153 |
+
"step 1140 | loss: 4.1268 | dt: 2690.60ms | tok/sec: 761.17\n",
|
| 154 |
+
"step 1150 | loss: 4.5661 | dt: 2719.96ms | tok/sec: 752.95\n",
|
| 155 |
+
"step 1160 | loss: 4.3289 | dt: 2666.21ms | tok/sec: 768.13\n",
|
| 156 |
+
"step 1170 | loss: 3.8129 | dt: 2671.28ms | tok/sec: 766.67\n",
|
| 157 |
+
"step 1180 | loss: 3.8706 | dt: 2717.71ms | tok/sec: 753.58\n",
|
| 158 |
+
"step 1190 | loss: 3.8226 | dt: 2669.59ms | tok/sec: 767.16\n",
|
| 159 |
+
"step 1200 | loss: 4.1762 | dt: 2666.23ms | tok/sec: 768.13\n",
|
| 160 |
+
"step 1210 | loss: 4.3757 | dt: 2708.58ms | tok/sec: 756.12\n",
|
| 161 |
+
"step 1220 | loss: 4.1351 | dt: 2685.38ms | tok/sec: 762.65\n",
|
| 162 |
+
"step 1230 | loss: 4.3564 | dt: 2687.16ms | tok/sec: 762.14\n",
|
| 163 |
+
"step 1240 | loss: 3.7988 | dt: 2675.93ms | tok/sec: 765.34\n",
|
| 164 |
+
"step 1250 | loss: 4.1403 | dt: 2686.86ms | tok/sec: 762.23\n",
|
| 165 |
+
"step 1260 | loss: 4.4226 | dt: 2667.90ms | tok/sec: 767.64\n",
|
| 166 |
+
"step 1270 | loss: 4.4887 | dt: 2739.30ms | tok/sec: 747.64\n",
|
| 167 |
+
"step 1280 | loss: 4.2127 | dt: 2691.33ms | tok/sec: 760.96\n",
|
| 168 |
+
"step 1290 | loss: 3.6873 | dt: 2672.71ms | tok/sec: 766.26\n",
|
| 169 |
+
"step 1300 | loss: 4.0008 | dt: 2696.75ms | tok/sec: 759.43\n",
|
| 170 |
+
"step 1310 | loss: 3.6128 | dt: 2662.96ms | tok/sec: 769.07\n",
|
| 171 |
+
"step 1320 | loss: 4.2900 | dt: 2656.27ms | tok/sec: 771.01\n",
|
| 172 |
+
"step 1330 | loss: 3.8012 | dt: 2660.07ms | tok/sec: 769.91\n",
|
| 173 |
+
"step 1340 | loss: 4.1523 | dt: 2642.03ms | tok/sec: 775.16\n",
|
| 174 |
+
"step 1350 | loss: 3.9300 | dt: 2683.45ms | tok/sec: 763.20\n",
|
| 175 |
+
"step 1360 | loss: 4.1542 | dt: 2639.64ms | tok/sec: 775.86\n",
|
| 176 |
+
"step 1370 | loss: 4.0666 | dt: 2660.70ms | tok/sec: 769.72\n",
|
| 177 |
+
"step 1380 | loss: 3.8962 | dt: 2655.70ms | tok/sec: 771.17\n",
|
| 178 |
+
"step 1390 | loss: 3.7736 | dt: 2637.29ms | tok/sec: 776.55\n",
|
| 179 |
+
"step 1400 | loss: 3.7997 | dt: 2647.38ms | tok/sec: 773.59\n",
|
| 180 |
+
"step 1410 | loss: 3.7525 | dt: 2629.39ms | tok/sec: 778.89\n",
|
| 181 |
+
"step 1420 | loss: 3.6936 | dt: 2688.40ms | tok/sec: 761.79\n",
|
| 182 |
+
"step 1430 | loss: 4.1674 | dt: 2638.92ms | tok/sec: 776.07\n",
|
| 183 |
+
"step 1440 | loss: 4.1046 | dt: 2648.60ms | tok/sec: 773.24\n",
|
| 184 |
+
"step 1450 | loss: 3.8131 | dt: 2636.89ms | tok/sec: 776.67\n",
|
| 185 |
+
"step 1460 | loss: 3.4866 | dt: 2659.35ms | tok/sec: 770.11\n",
|
| 186 |
+
"step 1470 | loss: 3.7860 | dt: 2638.77ms | tok/sec: 776.12\n",
|
| 187 |
+
"step 1480 | loss: 4.1924 | dt: 2632.46ms | tok/sec: 777.98\n",
|
| 188 |
+
"step 1490 | loss: 4.0126 | dt: 2635.47ms | tok/sec: 777.09\n",
|
| 189 |
+
"step 1500 | loss: 3.4828 | dt: 2633.86ms | tok/sec: 777.57\n",
|
| 190 |
+
"\n",
|
| 191 |
+
"--- Generating text at step 1500 ---\n",
|
| 192 |
+
"! we rid,\n",
|
| 193 |
+
"my thing was so done.\n",
|
| 194 |
+
"\n",
|
| 195 |
+
"GLUUS:\n",
|
| 196 |
+
"F mercy in mocked itUS;\n",
|
| 197 |
+
"But rest at Rome,iances for him: but\n",
|
| 198 |
+
"Masters are worse, proud to the people,\n",
|
| 199 |
+
"And for your\n",
|
| 200 |
+
"-----------------------------------\n",
|
| 201 |
+
"\n",
|
| 202 |
+
"step 1510 | loss: 3.6085 | dt: 2699.83ms | tok/sec: 758.57\n",
|
| 203 |
+
"step 1520 | loss: 3.5473 | dt: 2703.50ms | tok/sec: 757.54\n",
|
| 204 |
+
"step 1530 | loss: 3.8369 | dt: 2704.69ms | tok/sec: 757.20\n",
|
| 205 |
+
"step 1540 | loss: 4.0431 | dt: 2699.75ms | tok/sec: 758.59\n",
|
| 206 |
+
"step 1550 | loss: 3.8126 | dt: 2697.98ms | tok/sec: 759.09\n",
|
| 207 |
+
"step 1560 | loss: 4.0581 | dt: 2708.62ms | tok/sec: 756.10\n",
|
| 208 |
+
"step 1570 | loss: 3.4292 | dt: 2701.01ms | tok/sec: 758.23\n",
|
| 209 |
+
"step 1580 | loss: 3.8859 | dt: 2701.13ms | tok/sec: 758.20\n",
|
| 210 |
+
"step 1590 | loss: 4.0481 | dt: 2708.01ms | tok/sec: 756.27\n",
|
| 211 |
+
"step 1600 | loss: 4.1384 | dt: 2723.99ms | tok/sec: 751.84\n",
|
| 212 |
+
"step 1610 | loss: 3.9161 | dt: 2718.54ms | tok/sec: 753.35\n",
|
| 213 |
+
"step 1620 | loss: 3.3759 | dt: 2728.75ms | tok/sec: 750.53\n",
|
| 214 |
+
"step 1630 | loss: 3.6518 | dt: 2700.32ms | tok/sec: 758.43\n",
|
| 215 |
+
"step 1640 | loss: 3.3208 | dt: 2696.71ms | tok/sec: 759.44\n",
|
| 216 |
+
"step 1650 | loss: 3.9156 | dt: 2707.93ms | tok/sec: 756.30\n",
|
| 217 |
+
"step 1660 | loss: 3.5364 | dt: 2725.21ms | tok/sec: 751.50\n",
|
| 218 |
+
"step 1670 | loss: 3.8675 | dt: 2708.14ms | tok/sec: 756.24\n",
|
| 219 |
+
"step 1680 | loss: 3.6225 | dt: 2702.13ms | tok/sec: 757.92\n",
|
| 220 |
+
"step 1690 | loss: 3.8710 | dt: 2694.93ms | tok/sec: 759.95\n",
|
| 221 |
+
"step 1700 | loss: 3.7588 | dt: 2696.90ms | tok/sec: 759.39\n",
|
| 222 |
+
"step 1710 | loss: 3.6354 | dt: 2727.15ms | tok/sec: 750.97\n",
|
| 223 |
+
"step 1720 | loss: 3.5004 | dt: 2710.80ms | tok/sec: 755.50\n",
|
| 224 |
+
"step 1730 | loss: 3.5569 | dt: 2736.50ms | tok/sec: 748.40\n",
|
| 225 |
+
"step 1740 | loss: 3.4937 | dt: 2700.68ms | tok/sec: 758.33\n",
|
| 226 |
+
"step 1750 | loss: 3.4585 | dt: 2704.23ms | tok/sec: 757.33\n",
|
| 227 |
+
"step 1760 | loss: 3.8912 | dt: 2718.99ms | tok/sec: 753.22\n",
|
| 228 |
+
"step 1770 | loss: 3.9121 | dt: 2759.54ms | tok/sec: 742.15\n",
|
| 229 |
+
"step 1780 | loss: 3.5978 | dt: 2701.65ms | tok/sec: 758.06\n",
|
| 230 |
+
"step 1790 | loss: 3.2438 | dt: 2705.15ms | tok/sec: 757.08\n",
|
| 231 |
+
"step 1800 | loss: 3.4297 | dt: 2697.91ms | tok/sec: 759.11\n",
|
| 232 |
+
"step 1810 | loss: 3.8631 | dt: 2723.87ms | tok/sec: 751.87\n",
|
| 233 |
+
"step 1820 | loss: 3.7029 | dt: 2717.25ms | tok/sec: 753.70\n",
|
| 234 |
+
"step 1830 | loss: 3.2425 | dt: 2727.89ms | tok/sec: 750.76\n",
|
| 235 |
+
"step 1840 | loss: 3.3572 | dt: 2722.87ms | tok/sec: 752.15\n",
|
| 236 |
+
"step 1850 | loss: 3.2716 | dt: 2707.57ms | tok/sec: 756.40\n",
|
| 237 |
+
"step 1860 | loss: 3.5134 | dt: 2699.27ms | tok/sec: 758.73\n",
|
| 238 |
+
"step 1870 | loss: 3.7097 | dt: 2723.70ms | tok/sec: 751.92\n",
|
| 239 |
+
"step 1880 | loss: 3.5368 | dt: 2721.65ms | tok/sec: 752.48\n",
|
| 240 |
+
"step 1890 | loss: 3.7578 | dt: 2698.47ms | tok/sec: 758.95\n",
|
| 241 |
+
"step 1900 | loss: 3.2360 | dt: 2703.24ms | tok/sec: 757.61\n",
|
| 242 |
+
"step 1910 | loss: 3.5294 | dt: 2695.89ms | tok/sec: 759.68\n",
|
| 243 |
+
"step 1920 | loss: 3.7359 | dt: 2723.54ms | tok/sec: 751.96\n",
|
| 244 |
+
"step 1930 | loss: 3.8138 | dt: 2696.43ms | tok/sec: 759.52\n",
|
| 245 |
+
"step 1940 | loss: 3.6550 | dt: 2704.04ms | tok/sec: 757.38\n",
|
| 246 |
+
"step 1950 | loss: 3.0794 | dt: 2704.37ms | tok/sec: 757.29\n",
|
| 247 |
+
"step 1960 | loss: 3.3290 | dt: 2702.23ms | tok/sec: 757.89\n",
|
| 248 |
+
"step 1970 | loss: 3.0204 | dt: 2734.40ms | tok/sec: 748.98\n",
|
| 249 |
+
"step 1980 | loss: 3.5516 | dt: 2703.91ms | tok/sec: 757.42\n",
|
| 250 |
+
"step 1990 | loss: 3.3070 | dt: 2706.01ms | tok/sec: 756.83\n",
|
| 251 |
+
"step 2000 | loss: 3.5758 | dt: 2726.00ms | tok/sec: 751.28\n",
|
| 252 |
+
"\n",
|
| 253 |
+
"--- Generating text at step 2000 ---\n",
|
| 254 |
+
"!,\n",
|
| 255 |
+
"if I put thee due by! see I do not weep\n",
|
| 256 |
+
"Thy thing in this heels hits them,\n",
|
| 257 |
+
"That neverYour valour should be in pride,\n",
|
| 258 |
+
"Call him more deity tears;Signiorio\n",
|
| 259 |
+
"In time of tears\n",
|
| 260 |
+
"-----------------------------------\n",
|
| 261 |
+
"\n",
|
| 262 |
+
"step 2010 | loss: 3.3916 | dt: 2700.63ms | tok/sec: 758.34\n",
|
| 263 |
+
"step 2020 | loss: 3.5364 | dt: 2709.72ms | tok/sec: 755.80\n",
|
| 264 |
+
"step 2030 | loss: 3.4518 | dt: 2708.71ms | tok/sec: 756.08\n",
|
| 265 |
+
"step 2040 | loss: 3.2968 | dt: 2699.77ms | tok/sec: 758.58\n",
|
| 266 |
+
"step 2050 | loss: 3.1897 | dt: 2705.14ms | tok/sec: 757.08\n",
|
| 267 |
+
"step 2060 | loss: 3.2544 | dt: 2695.59ms | tok/sec: 759.76\n",
|
| 268 |
+
"step 2070 | loss: 3.1843 | dt: 2699.45ms | tok/sec: 758.67\n",
|
| 269 |
+
"step 2080 | loss: 3.1527 | dt: 2697.34ms | tok/sec: 759.27\n",
|
| 270 |
+
"step 2090 | loss: 3.5361 | dt: 2698.70ms | tok/sec: 758.88\n",
|
| 271 |
+
"step 2100 | loss: 3.6072 | dt: 2698.23ms | tok/sec: 759.02\n",
|
| 272 |
+
"step 2110 | loss: 3.2871 | dt: 2697.04ms | tok/sec: 759.35\n",
|
| 273 |
+
"step 2120 | loss: 2.9393 | dt: 2715.87ms | tok/sec: 754.09\n",
|
| 274 |
+
"step 2130 | loss: 3.0962 | dt: 2706.66ms | tok/sec: 756.65\n",
|
| 275 |
+
"step 2140 | loss: 3.4452 | dt: 2705.67ms | tok/sec: 756.93\n",
|
| 276 |
+
"step 2150 | loss: 3.2900 | dt: 2707.68ms | tok/sec: 756.37\n",
|
| 277 |
+
"step 2160 | loss: 2.9626 | dt: 2707.08ms | tok/sec: 756.54\n",
|
| 278 |
+
"step 2170 | loss: 3.0681 | dt: 2702.33ms | tok/sec: 757.86\n",
|
| 279 |
+
"step 2180 | loss: 2.9747 | dt: 2702.98ms | tok/sec: 757.68\n",
|
| 280 |
+
"step 2190 | loss: 3.1063 | dt: 2716.67ms | tok/sec: 753.87\n",
|
| 281 |
+
"step 2200 | loss: 3.3207 | dt: 2703.48ms | tok/sec: 757.54\n",
|
| 282 |
+
"step 2210 | loss: 3.1918 | dt: 2701.67ms | tok/sec: 758.05\n",
|
| 283 |
+
"step 2220 | loss: 3.3162 | dt: 2703.80ms | tok/sec: 757.45\n",
|
| 284 |
+
"step 2230 | loss: 2.8865 | dt: 2705.55ms | tok/sec: 756.96\n",
|
| 285 |
+
"step 2240 | loss: 3.1759 | dt: 2717.99ms | tok/sec: 753.50\n",
|
| 286 |
+
"step 2250 | loss: 3.3846 | dt: 2703.64ms | tok/sec: 757.50\n",
|
| 287 |
+
"step 2260 | loss: 3.3697 | dt: 2732.82ms | tok/sec: 749.41\n",
|
| 288 |
+
"step 2270 | loss: 3.2227 | dt: 2703.69ms | tok/sec: 757.48\n",
|
| 289 |
+
"step 2280 | loss: 2.8169 | dt: 2706.62ms | tok/sec: 756.66\n",
|
| 290 |
+
"step 2290 | loss: 2.9860 | dt: 2700.74ms | tok/sec: 758.31\n",
|
| 291 |
+
"step 2300 | loss: 2.7406 | dt: 2711.47ms | tok/sec: 755.31\n",
|
| 292 |
+
"step 2310 | loss: 3.1365 | dt: 2706.79ms | tok/sec: 756.62\n",
|
| 293 |
+
"step 2320 | loss: 2.9496 | dt: 2699.63ms | tok/sec: 758.62\n",
|
| 294 |
+
"step 2330 | loss: 3.2225 | dt: 2702.71ms | tok/sec: 757.76\n",
|
| 295 |
+
"step 2340 | loss: 3.0330 | dt: 2710.74ms | tok/sec: 755.51\n",
|
| 296 |
+
"step 2350 | loss: 3.1792 | dt: 2704.27ms | tok/sec: 757.32\n",
|
| 297 |
+
"step 2360 | loss: 3.0794 | dt: 2709.94ms | tok/sec: 755.74\n",
|
| 298 |
+
"step 2370 | loss: 2.9420 | dt: 2697.96ms | tok/sec: 759.09\n",
|
| 299 |
+
"step 2380 | loss: 2.8587 | dt: 2695.43ms | tok/sec: 759.80\n",
|
| 300 |
+
"step 2390 | loss: 2.8747 | dt: 2695.93ms | tok/sec: 759.66\n",
|
| 301 |
+
"step 2400 | loss: 2.8319 | dt: 2711.99ms | tok/sec: 755.16\n",
|
| 302 |
+
"step 2410 | loss: 2.8368 | dt: 2736.55ms | tok/sec: 748.39\n",
|
| 303 |
+
"step 2420 | loss: 3.1382 | dt: 2693.63ms | tok/sec: 760.31\n",
|
| 304 |
+
"step 2430 | loss: 3.2540 | dt: 2703.97ms | tok/sec: 757.41\n",
|
| 305 |
+
"step 2440 | loss: 2.8659 | dt: 2702.61ms | tok/sec: 757.79\n",
|
| 306 |
+
"step 2450 | loss: 2.6254 | dt: 2700.30ms | tok/sec: 758.43\n",
|
| 307 |
+
"step 2460 | loss: 2.7329 | dt: 2706.12ms | tok/sec: 756.80\n",
|
| 308 |
+
"step 2470 | loss: 3.0561 | dt: 2704.39ms | tok/sec: 757.29\n",
|
| 309 |
+
"step 2480 | loss: 2.8807 | dt: 2705.39ms | tok/sec: 757.01\n",
|
| 310 |
+
"step 2490 | loss: 2.6715 | dt: 2699.49ms | tok/sec: 758.66\n",
|
| 311 |
+
"step 2500 | loss: 2.6661 | dt: 2701.31ms | tok/sec: 758.15\n",
|
| 312 |
+
"\n",
|
| 313 |
+
"--- Generating text at step 2500 ---\n",
|
| 314 |
+
"! wasBoy\n",
|
| 315 |
+
"'s one heinous to purchase with us and fold'd\n",
|
| 316 |
+
"Than is of lips for your time his son,\n",
|
| 317 |
+
"Which my conscience did forsworn,\n",
|
| 318 |
+
"It like a friendly that owe'd their go\n",
|
| 319 |
+
"But, afterced his\n",
|
| 320 |
+
"-----------------------------------\n",
|
| 321 |
+
"\n",
|
| 322 |
+
"step 2510 | loss: 2.6807 | dt: 2698.79ms | tok/sec: 758.86\n",
|
| 323 |
+
"step 2520 | loss: 2.8045 | dt: 2696.69ms | tok/sec: 759.45\n",
|
| 324 |
+
"step 2530 | loss: 2.9238 | dt: 2701.84ms | tok/sec: 758.00\n",
|
| 325 |
+
"step 2540 | loss: 2.7829 | dt: 2696.82ms | tok/sec: 759.41\n",
|
| 326 |
+
"step 2550 | loss: 2.8688 | dt: 2713.57ms | tok/sec: 754.72\n",
|
| 327 |
+
"step 2560 | loss: 2.5317 | dt: 2700.27ms | tok/sec: 758.44\n",
|
| 328 |
+
"step 2570 | loss: 2.8389 | dt: 2694.92ms | tok/sec: 759.95\n",
|
| 329 |
+
"step 2580 | loss: 2.8973 | dt: 2700.89ms | tok/sec: 758.27\n",
|
| 330 |
+
"step 2590 | loss: 2.9376 | dt: 2701.52ms | tok/sec: 758.09\n",
|
| 331 |
+
"step 2600 | loss: 2.8337 | dt: 2734.45ms | tok/sec: 748.96\n",
|
| 332 |
+
"step 2610 | loss: 2.3944 | dt: 2696.95ms | tok/sec: 759.38\n",
|
| 333 |
+
"step 2620 | loss: 2.5720 | dt: 2705.58ms | tok/sec: 756.95\n",
|
| 334 |
+
"step 2630 | loss: 2.3286 | dt: 2726.76ms | tok/sec: 751.07\n",
|
| 335 |
+
"step 2640 | loss: 2.6583 | dt: 2712.49ms | tok/sec: 755.03\n",
|
| 336 |
+
"step 2650 | loss: 2.5653 | dt: 2706.73ms | tok/sec: 756.63\n",
|
| 337 |
+
"step 2660 | loss: 2.8053 | dt: 2709.10ms | tok/sec: 755.97\n",
|
| 338 |
+
"step 2670 | loss: 2.6431 | dt: 2696.00ms | tok/sec: 759.64\n",
|
| 339 |
+
"step 2680 | loss: 2.6953 | dt: 2746.24ms | tok/sec: 745.75\n",
|
| 340 |
+
"step 2690 | loss: 2.6361 | dt: 2716.29ms | tok/sec: 753.97\n",
|
| 341 |
+
"step 2700 | loss: 2.4590 | dt: 2774.52ms | tok/sec: 738.15\n",
|
| 342 |
+
"step 2710 | loss: 2.3923 | dt: 2749.37ms | tok/sec: 744.90\n",
|
| 343 |
+
"step 2720 | loss: 2.4924 | dt: 2701.07ms | tok/sec: 758.22\n",
|
| 344 |
+
"step 2730 | loss: 2.4631 | dt: 2711.96ms | tok/sec: 755.17\n",
|
| 345 |
+
"step 2740 | loss: 2.4050 | dt: 2718.98ms | tok/sec: 753.22\n",
|
| 346 |
+
"step 2750 | loss: 2.6518 | dt: 2697.21ms | tok/sec: 759.30\n",
|
| 347 |
+
"step 2760 | loss: 2.7260 | dt: 2876.82ms | tok/sec: 711.90\n",
|
| 348 |
+
"step 2770 | loss: 2.3956 | dt: 2746.34ms | tok/sec: 745.72\n",
|
| 349 |
+
"step 2780 | loss: 2.1536 | dt: 2684.97ms | tok/sec: 762.76\n",
|
| 350 |
+
"step 2790 | loss: 2.2741 | dt: 2954.08ms | tok/sec: 693.28\n",
|
| 351 |
+
"step 2800 | loss: 2.5598 | dt: 2906.71ms | tok/sec: 704.58\n",
|
| 352 |
+
"step 2810 | loss: 2.2911 | dt: 2680.05ms | tok/sec: 764.16\n",
|
| 353 |
+
"step 2820 | loss: 2.1884 | dt: 2682.93ms | tok/sec: 763.34\n",
|
| 354 |
+
"step 2830 | loss: 2.2343 | dt: 2675.76ms | tok/sec: 765.39\n",
|
| 355 |
+
"step 2840 | loss: 2.2411 | dt: 2695.51ms | tok/sec: 759.78\n",
|
| 356 |
+
"step 2850 | loss: 2.2290 | dt: 2864.33ms | tok/sec: 715.00\n",
|
| 357 |
+
"step 2860 | loss: 2.2985 | dt: 2659.26ms | tok/sec: 770.14\n",
|
| 358 |
+
"step 2870 | loss: 2.2090 | dt: 2667.58ms | tok/sec: 767.74\n",
|
| 359 |
+
"step 2880 | loss: 2.2442 | dt: 2657.60ms | tok/sec: 770.62\n",
|
| 360 |
+
"step 2890 | loss: 2.0802 | dt: 2673.51ms | tok/sec: 766.03\n",
|
| 361 |
+
"step 2900 | loss: 2.2653 | dt: 2668.15ms | tok/sec: 767.57\n",
|
| 362 |
+
"step 2910 | loss: 2.2835 | dt: 2662.19ms | tok/sec: 769.29\n",
|
| 363 |
+
"step 2920 | loss: 2.1685 | dt: 2660.40ms | tok/sec: 769.81\n",
|
| 364 |
+
"step 2930 | loss: 2.1162 | dt: 2660.61ms | tok/sec: 769.75\n",
|
| 365 |
+
"step 2940 | loss: 1.8274 | dt: 2661.90ms | tok/sec: 769.38\n",
|
| 366 |
+
"step 2950 | loss: 1.8564 | dt: 2703.16ms | tok/sec: 757.63\n",
|
| 367 |
+
"step 2960 | loss: 1.7792 | dt: 2658.62ms | tok/sec: 770.32\n",
|
| 368 |
+
"step 2970 | loss: 2.0663 | dt: 2662.07ms | tok/sec: 769.33\n",
|
| 369 |
+
"step 2980 | loss: 1.9452 | dt: 2660.27ms | tok/sec: 769.85\n",
|
| 370 |
+
"step 2990 | loss: 2.1067 | dt: 2664.71ms | tok/sec: 768.56\n",
|
| 371 |
+
"step 3000 | loss: 1.9999 | dt: 2669.69ms | tok/sec: 767.13\n",
|
| 372 |
+
"\n",
|
| 373 |
+
"--- Generating text at step 3000 ---\n",
|
| 374 |
+
"! thrive lives, the rage d\n",
|
| 375 |
+
"Endiciansed in the wind O very danger!\n",
|
| 376 |
+
"\n",
|
| 377 |
+
"First Murderer:\n",
|
| 378 |
+
"What commend thee well that are this!\n",
|
| 379 |
+
"\n",
|
| 380 |
+
"First Murderer:\n",
|
| 381 |
+
"On my poor lord,\n",
|
| 382 |
+
"I'll follow thee better\n",
|
| 383 |
+
"-----------------------------------\n",
|
| 384 |
+
"\n",
|
| 385 |
+
"step 3010 | loss: 1.9782 | dt: 2669.18ms | tok/sec: 767.28\n",
|
| 386 |
+
"step 3020 | loss: 1.9162 | dt: 2816.11ms | tok/sec: 727.24\n",
|
| 387 |
+
"step 3030 | loss: 1.7600 | dt: 2748.60ms | tok/sec: 745.11\n",
|
| 388 |
+
"step 3040 | loss: 1.7559 | dt: 3330.35ms | tok/sec: 614.95\n",
|
| 389 |
+
"step 3050 | loss: 1.8405 | dt: 2766.42ms | tok/sec: 740.31\n",
|
| 390 |
+
"step 3060 | loss: 1.8245 | dt: 2675.78ms | tok/sec: 765.38\n",
|
| 391 |
+
"step 3070 | loss: 1.7706 | dt: 2834.19ms | tok/sec: 722.61\n",
|
| 392 |
+
"step 3080 | loss: 1.8807 | dt: 2701.07ms | tok/sec: 758.22\n",
|
| 393 |
+
"step 3090 | loss: 1.9564 | dt: 2774.91ms | tok/sec: 738.04\n",
|
| 394 |
+
"step 3100 | loss: 1.6477 | dt: 2787.13ms | tok/sec: 734.81\n",
|
| 395 |
+
"step 3110 | loss: 1.5303 | dt: 2768.86ms | tok/sec: 739.65\n",
|
| 396 |
+
"step 3120 | loss: 1.5144 | dt: 2668.25ms | tok/sec: 767.54\n",
|
| 397 |
+
"step 3130 | loss: 1.6874 | dt: 2684.58ms | tok/sec: 762.87\n",
|
| 398 |
+
"step 3140 | loss: 1.5967 | dt: 2672.88ms | tok/sec: 766.22\n",
|
| 399 |
+
"step 3150 | loss: 1.4894 | dt: 2685.02ms | tok/sec: 762.75\n",
|
| 400 |
+
"step 3160 | loss: 1.5369 | dt: 2780.46ms | tok/sec: 736.57\n",
|
| 401 |
+
"step 3170 | loss: 1.5521 | dt: 2766.64ms | tok/sec: 740.25\n",
|
| 402 |
+
"step 3180 | loss: 1.5144 | dt: 2710.92ms | tok/sec: 755.46\n",
|
| 403 |
+
"step 3190 | loss: 1.5083 | dt: 2749.40ms | tok/sec: 744.89\n",
|
| 404 |
+
"step 3200 | loss: 1.4799 | dt: 2678.04ms | tok/sec: 764.74\n",
|
| 405 |
+
"step 3210 | loss: 1.5245 | dt: 2707.56ms | tok/sec: 756.40\n",
|
| 406 |
+
"step 3220 | loss: 1.3773 | dt: 2688.40ms | tok/sec: 761.79\n",
|
| 407 |
+
"step 3230 | loss: 1.5153 | dt: 2684.65ms | tok/sec: 762.86\n",
|
| 408 |
+
"step 3240 | loss: 1.4807 | dt: 2753.05ms | tok/sec: 743.90\n",
|
| 409 |
+
"step 3250 | loss: 1.4735 | dt: 2755.05ms | tok/sec: 743.36\n",
|
| 410 |
+
"step 3260 | loss: 1.3735 | dt: 2719.84ms | tok/sec: 752.98\n",
|
| 411 |
+
"step 3270 | loss: 1.2075 | dt: 2843.99ms | tok/sec: 720.12\n",
|
| 412 |
+
"step 3280 | loss: 1.2305 | dt: 2703.84ms | tok/sec: 757.44\n",
|
| 413 |
+
"step 3290 | loss: 1.1408 | dt: 2739.76ms | tok/sec: 747.51\n",
|
| 414 |
+
"step 3300 | loss: 1.2520 | dt: 2892.35ms | tok/sec: 708.07\n",
|
| 415 |
+
"step 3310 | loss: 1.3056 | dt: 2716.91ms | tok/sec: 753.80\n",
|
| 416 |
+
"step 3320 | loss: 1.2803 | dt: 2681.74ms | tok/sec: 763.68\n",
|
| 417 |
+
"step 3330 | loss: 1.2850 | dt: 2725.97ms | tok/sec: 751.29\n",
|
| 418 |
+
"step 3340 | loss: 1.2028 | dt: 2692.16ms | tok/sec: 760.73\n",
|
| 419 |
+
"step 3350 | loss: 1.1584 | dt: 2698.27ms | tok/sec: 759.00\n",
|
| 420 |
+
"step 3360 | loss: 1.0396 | dt: 2998.41ms | tok/sec: 683.03\n",
|
| 421 |
+
"step 3370 | loss: 1.0607 | dt: 2903.31ms | tok/sec: 705.40\n",
|
| 422 |
+
"step 3380 | loss: 1.1164 | dt: 3033.53ms | tok/sec: 675.12\n",
|
| 423 |
+
"step 3390 | loss: 1.1108 | dt: 2886.95ms | tok/sec: 709.40\n",
|
| 424 |
+
"step 3400 | loss: 1.0315 | dt: 2752.63ms | tok/sec: 744.02\n",
|
| 425 |
+
"step 3410 | loss: 1.1562 | dt: 2726.22ms | tok/sec: 751.22\n",
|
| 426 |
+
"step 3420 | loss: 1.1321 | dt: 2708.28ms | tok/sec: 756.20\n",
|
| 427 |
+
"step 3430 | loss: 0.9760 | dt: 2681.26ms | tok/sec: 763.82\n",
|
| 428 |
+
"step 3440 | loss: 0.8877 | dt: 2717.40ms | tok/sec: 753.66\n",
|
| 429 |
+
"step 3450 | loss: 0.9401 | dt: 2842.15ms | tok/sec: 720.58\n",
|
| 430 |
+
"step 3460 | loss: 0.9696 | dt: 2802.88ms | tok/sec: 730.68\n",
|
| 431 |
+
"step 3470 | loss: 0.8294 | dt: 2881.42ms | tok/sec: 710.76\n",
|
| 432 |
+
"step 3480 | loss: 0.9166 | dt: 2741.99ms | tok/sec: 746.90\n",
|
| 433 |
+
"step 3490 | loss: 0.8187 | dt: 2768.42ms | tok/sec: 739.77\n",
|
| 434 |
+
"step 3500 | loss: 0.8012 | dt: 2875.77ms | tok/sec: 712.16\n",
|
| 435 |
+
"\n",
|
| 436 |
+
"--- Generating text at step 3500 ---\n",
|
| 437 |
+
"! who child is Rivers; all,saving thy father knows promise wish' moan'occnard's ease that Alban comes blaced boys.\n",
|
| 438 |
+
"\n",
|
| 439 |
+
"BUCKINGHAM:\n",
|
| 440 |
+
"A greater prince, a royal cousin doth make us\n",
|
| 441 |
+
"down; and\n",
|
| 442 |
+
"-----------------------------------\n",
|
| 443 |
+
"\n",
|
| 444 |
+
"step 3510 | loss: 0.8895 | dt: 2911.00ms | tok/sec: 703.54\n",
|
| 445 |
+
"step 3520 | loss: 0.8510 | dt: 2742.43ms | tok/sec: 746.78\n",
|
| 446 |
+
"step 3530 | loss: 0.7573 | dt: 2752.54ms | tok/sec: 744.04\n",
|
| 447 |
+
"step 3540 | loss: 0.7299 | dt: 2756.80ms | tok/sec: 742.89\n",
|
| 448 |
+
"step 3550 | loss: 0.7541 | dt: 2748.92ms | tok/sec: 745.02\n",
|
| 449 |
+
"step 3560 | loss: 0.7650 | dt: 2749.40ms | tok/sec: 744.89\n",
|
| 450 |
+
"step 3570 | loss: 0.6896 | dt: 2862.06ms | tok/sec: 715.57\n",
|
| 451 |
+
"step 3580 | loss: 0.6993 | dt: 2753.84ms | tok/sec: 743.69\n",
|
| 452 |
+
"step 3590 | loss: 0.6519 | dt: 2998.55ms | tok/sec: 683.00\n",
|
| 453 |
+
"step 3600 | loss: 0.6217 | dt: 2740.47ms | tok/sec: 747.32\n",
|
| 454 |
+
"step 3610 | loss: 0.6385 | dt: 2790.61ms | tok/sec: 733.89\n",
|
| 455 |
+
"step 3620 | loss: 0.5826 | dt: 2770.31ms | tok/sec: 739.27\n",
|
| 456 |
+
"step 3630 | loss: 0.6594 | dt: 2783.17ms | tok/sec: 735.85\n",
|
| 457 |
+
"step 3640 | loss: 0.5835 | dt: 2924.64ms | tok/sec: 700.26\n",
|
| 458 |
+
"step 3650 | loss: 0.6141 | dt: 2800.75ms | tok/sec: 731.23\n",
|
| 459 |
+
"step 3660 | loss: 0.6132 | dt: 2762.03ms | tok/sec: 741.48\n",
|
| 460 |
+
"step 3670 | loss: 0.5363 | dt: 2849.28ms | tok/sec: 718.78\n",
|
| 461 |
+
"step 3680 | loss: 0.5985 | dt: 2842.88ms | tok/sec: 720.40\n",
|
| 462 |
+
"step 3690 | loss: 0.4882 | dt: 2755.73ms | tok/sec: 743.18\n",
|
| 463 |
+
"step 3700 | loss: 0.4431 | dt: 2730.10ms | tok/sec: 750.16\n",
|
| 464 |
+
"step 3710 | loss: 0.4325 | dt: 2700.16ms | tok/sec: 758.47\n",
|
| 465 |
+
"step 3720 | loss: 0.4599 | dt: 2738.14ms | tok/sec: 747.95\n",
|
| 466 |
+
"step 3730 | loss: 0.4503 | dt: 2748.69ms | tok/sec: 745.08\n",
|
| 467 |
+
"step 3740 | loss: 0.4781 | dt: 2740.93ms | tok/sec: 747.19\n",
|
| 468 |
+
"step 3750 | loss: 0.5382 | dt: 2849.36ms | tok/sec: 718.76\n",
|
| 469 |
+
"step 3760 | loss: 0.3955 | dt: 2729.07ms | tok/sec: 750.44\n",
|
| 470 |
+
"step 3770 | loss: 0.3948 | dt: 2822.13ms | tok/sec: 725.69\n",
|
| 471 |
+
"step 3780 | loss: 0.4099 | dt: 2917.25ms | tok/sec: 702.03\n",
|
| 472 |
+
"step 3790 | loss: 0.4257 | dt: 2775.18ms | tok/sec: 737.97\n",
|
| 473 |
+
"step 3800 | loss: 0.3690 | dt: 2777.39ms | tok/sec: 737.38\n",
|
| 474 |
+
"step 3810 | loss: 0.3799 | dt: 2850.00ms | tok/sec: 718.60\n",
|
| 475 |
+
"step 3820 | loss: 0.3161 | dt: 2853.62ms | tok/sec: 717.69\n",
|
| 476 |
+
"step 3830 | loss: 0.3599 | dt: 2750.53ms | tok/sec: 744.58\n",
|
| 477 |
+
"step 3840 | loss: 0.3355 | dt: 2805.98ms | tok/sec: 729.87\n",
|
| 478 |
+
"step 3850 | loss: 0.3302 | dt: 3000.64ms | tok/sec: 682.52\n",
|
| 479 |
+
"step 3860 | loss: 0.3285 | dt: 2791.06ms | tok/sec: 733.77\n",
|
| 480 |
+
"step 3870 | loss: 0.2618 | dt: 2751.07ms | tok/sec: 744.44\n",
|
| 481 |
+
"step 3880 | loss: 0.3333 | dt: 2798.18ms | tok/sec: 731.90\n",
|
| 482 |
+
"step 3890 | loss: 0.2814 | dt: 2870.91ms | tok/sec: 713.36\n",
|
| 483 |
+
"step 3900 | loss: 0.2948 | dt: 2761.24ms | tok/sec: 741.70\n",
|
| 484 |
+
"step 3910 | loss: 0.2399 | dt: 2869.73ms | tok/sec: 713.66\n",
|
| 485 |
+
"step 3920 | loss: 0.2674 | dt: 2980.56ms | tok/sec: 687.12\n",
|
| 486 |
+
"step 3930 | loss: 0.2109 | dt: 2954.62ms | tok/sec: 693.15\n",
|
| 487 |
+
"step 3940 | loss: 0.2220 | dt: 2755.38ms | tok/sec: 743.27\n",
|
| 488 |
+
"step 3950 | loss: 0.2295 | dt: 2780.64ms | tok/sec: 736.52\n",
|
| 489 |
+
"step 3960 | loss: 0.2387 | dt: 2795.49ms | tok/sec: 732.61\n",
|
| 490 |
+
"step 3970 | loss: 0.2579 | dt: 3017.64ms | tok/sec: 678.68\n",
|
| 491 |
+
"step 3980 | loss: 0.2046 | dt: 2761.85ms | tok/sec: 741.53\n",
|
| 492 |
+
"step 3990 | loss: 0.2144 | dt: 2744.11ms | tok/sec: 746.33\n",
|
| 493 |
+
"step 4000 | loss: 0.2070 | dt: 2722.34ms | tok/sec: 752.29\n",
|
| 494 |
+
"\n",
|
| 495 |
+
"--- Generating text at step 4000 ---\n",
|
| 496 |
+
"!\n",
|
| 497 |
+
"If presently you will take the one\n",
|
| 498 |
+
"To what we will o\n",
|
| 499 |
+
"\n",
|
| 500 |
+
"'ll theretove him to known your lord.\n",
|
| 501 |
+
"\n",
|
| 502 |
+
"PRINCEE:\n",
|
| 503 |
+
"Indeed, my lord, safe still well we deserve\n",
|
| 504 |
+
"In your will.\n",
|
| 505 |
+
"\n",
|
| 506 |
+
"\n",
|
| 507 |
+
"-----------------------------------\n",
|
| 508 |
+
"\n",
|
| 509 |
+
"step 4010 | loss: 0.2144 | dt: 2902.75ms | tok/sec: 705.54\n",
|
| 510 |
+
"step 4020 | loss: 0.2045 | dt: 2945.05ms | tok/sec: 695.40\n",
|
| 511 |
+
"step 4030 | loss: 0.1790 | dt: 3056.03ms | tok/sec: 670.15\n",
|
| 512 |
+
"step 4040 | loss: 0.1516 | dt: 2747.85ms | tok/sec: 745.31\n",
|
| 513 |
+
"step 4050 | loss: 0.1615 | dt: 2805.18ms | tok/sec: 730.08\n",
|
| 514 |
+
"step 4060 | loss: 0.1421 | dt: 2749.10ms | tok/sec: 744.97\n",
|
| 515 |
+
"step 4070 | loss: 0.1674 | dt: 2859.66ms | tok/sec: 716.17\n",
|
| 516 |
+
"step 4080 | loss: 0.1633 | dt: 2847.74ms | tok/sec: 719.17\n",
|
| 517 |
+
"step 4090 | loss: 0.1570 | dt: 2916.81ms | tok/sec: 702.14\n",
|
| 518 |
+
"step 4100 | loss: 0.1592 | dt: 3016.92ms | tok/sec: 678.84\n",
|
| 519 |
+
"step 4110 | loss: 0.1363 | dt: 2727.48ms | tok/sec: 750.88\n",
|
| 520 |
+
"step 4120 | loss: 0.1797 | dt: 2730.33ms | tok/sec: 750.09\n",
|
| 521 |
+
"step 4130 | loss: 0.1210 | dt: 2699.00ms | tok/sec: 758.80\n",
|
| 522 |
+
"step 4140 | loss: 0.1253 | dt: 2867.20ms | tok/sec: 714.29\n",
|
| 523 |
+
"step 4150 | loss: 0.1016 | dt: 2856.48ms | tok/sec: 716.97\n",
|
| 524 |
+
"step 4160 | loss: 0.1162 | dt: 2750.59ms | tok/sec: 744.57\n",
|
| 525 |
+
"step 4170 | loss: 0.1599 | dt: 2886.92ms | tok/sec: 709.41\n",
|
| 526 |
+
"step 4180 | loss: 0.1116 | dt: 2701.22ms | tok/sec: 758.18\n",
|
| 527 |
+
"step 4190 | loss: 0.1093 | dt: 2674.94ms | tok/sec: 765.63\n",
|
| 528 |
+
"step 4200 | loss: 0.1099 | dt: 2875.45ms | tok/sec: 712.24\n",
|
| 529 |
+
"step 4210 | loss: 0.1319 | dt: 2775.36ms | tok/sec: 737.92\n",
|
| 530 |
+
"step 4220 | loss: 0.0932 | dt: 2740.37ms | tok/sec: 747.34\n",
|
| 531 |
+
"step 4230 | loss: 0.1047 | dt: 2697.50ms | tok/sec: 759.22\n",
|
| 532 |
+
"step 4240 | loss: 0.1159 | dt: 2656.54ms | tok/sec: 770.93\n",
|
| 533 |
+
"step 4250 | loss: 0.0943 | dt: 2632.27ms | tok/sec: 778.04\n",
|
| 534 |
+
"step 4260 | loss: 0.0935 | dt: 2733.61ms | tok/sec: 749.19\n",
|
| 535 |
+
"step 4270 | loss: 0.1044 | dt: 2700.76ms | tok/sec: 758.31\n",
|
| 536 |
+
"step 4280 | loss: 0.0909 | dt: 2690.66ms | tok/sec: 761.15\n",
|
| 537 |
+
"step 4290 | loss: 0.0993 | dt: 2689.09ms | tok/sec: 761.60\n",
|
| 538 |
+
"step 4300 | loss: 0.0931 | dt: 2710.26ms | tok/sec: 755.65\n",
|
| 539 |
+
"step 4310 | loss: 0.0841 | dt: 2795.26ms | tok/sec: 732.67\n",
|
| 540 |
+
"step 4320 | loss: 0.0729 | dt: 2704.03ms | tok/sec: 757.39\n",
|
| 541 |
+
"step 4330 | loss: 0.0633 | dt: 2676.90ms | tok/sec: 765.07\n",
|
| 542 |
+
"step 4340 | loss: 0.0859 | dt: 2761.68ms | tok/sec: 741.58\n",
|
| 543 |
+
"step 4350 | loss: 0.0625 | dt: 2778.00ms | tok/sec: 737.22\n",
|
| 544 |
+
"step 4360 | loss: 0.0638 | dt: 3028.06ms | tok/sec: 676.34\n",
|
| 545 |
+
"step 4370 | loss: 0.0841 | dt: 2895.60ms | tok/sec: 707.28\n",
|
| 546 |
+
"step 4380 | loss: 0.0731 | dt: 2678.41ms | tok/sec: 764.63\n",
|
| 547 |
+
"step 4390 | loss: 0.0676 | dt: 2632.73ms | tok/sec: 777.90\n",
|
| 548 |
+
"step 4400 | loss: 0.0789 | dt: 2633.57ms | tok/sec: 777.65\n",
|
| 549 |
+
"step 4410 | loss: 0.0870 | dt: 2633.24ms | tok/sec: 777.75\n",
|
| 550 |
+
"step 4420 | loss: 0.0564 | dt: 2666.58ms | tok/sec: 768.03\n",
|
| 551 |
+
"step 4430 | loss: 0.0510 | dt: 2663.78ms | tok/sec: 768.83\n",
|
| 552 |
+
"step 4440 | loss: 0.0844 | dt: 2637.76ms | tok/sec: 776.42\n",
|
| 553 |
+
"step 4450 | loss: 0.0574 | dt: 2649.37ms | tok/sec: 773.02\n",
|
| 554 |
+
"step 4460 | loss: 0.0670 | dt: 2642.94ms | tok/sec: 774.89\n",
|
| 555 |
+
"step 4470 | loss: 0.0818 | dt: 2635.25ms | tok/sec: 777.15\n",
|
| 556 |
+
"step 4480 | loss: 0.0611 | dt: 2641.49ms | tok/sec: 775.32\n",
|
| 557 |
+
"step 4490 | loss: 0.0571 | dt: 2718.30ms | tok/sec: 753.41\n",
|
| 558 |
+
"step 4500 | loss: 0.0524 | dt: 2714.53ms | tok/sec: 754.46\n",
|
| 559 |
+
"\n",
|
| 560 |
+
"--- Generating text at step 4500 ---\n",
|
| 561 |
+
"! to Tyenting in their hate.\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"BUCKINGHAM:\n",
|
| 564 |
+
"My lord, I'll bear unto your head\n",
|
| 565 |
+
"Which best lives hath had but your breath or\n",
|
| 566 |
+
"Plured with the hawnes with yields that\n",
|
| 567 |
+
"There did touch his\n",
|
| 568 |
+
"-----------------------------------\n",
|
| 569 |
+
"\n",
|
| 570 |
+
"step 4510 | loss: 0.0513 | dt: 2759.85ms | tok/sec: 742.07\n",
|
| 571 |
+
"step 4520 | loss: 0.0582 | dt: 2780.63ms | tok/sec: 736.52\n",
|
| 572 |
+
"step 4530 | loss: 0.0607 | dt: 2788.91ms | tok/sec: 734.34\n",
|
| 573 |
+
"step 4540 | loss: 0.0537 | dt: 2705.52ms | tok/sec: 756.97\n",
|
| 574 |
+
"step 4550 | loss: 0.0469 | dt: 2797.45ms | tok/sec: 732.10\n",
|
| 575 |
+
"step 4560 | loss: 0.0513 | dt: 2715.12ms | tok/sec: 754.29\n",
|
| 576 |
+
"step 4570 | loss: 0.0537 | dt: 2701.81ms | tok/sec: 758.01\n",
|
| 577 |
+
"step 4580 | loss: 0.0555 | dt: 2703.04ms | tok/sec: 757.66\n",
|
| 578 |
+
"step 4590 | loss: 0.0617 | dt: 2702.77ms | tok/sec: 757.74\n",
|
| 579 |
+
"step 4600 | loss: 0.0601 | dt: 2703.01ms | tok/sec: 757.67\n",
|
| 580 |
+
"step 4610 | loss: 0.0386 | dt: 2713.98ms | tok/sec: 754.61\n",
|
| 581 |
+
"step 4620 | loss: 0.0469 | dt: 2702.42ms | tok/sec: 757.84\n",
|
| 582 |
+
"step 4630 | loss: 0.0429 | dt: 2701.25ms | tok/sec: 758.17\n",
|
| 583 |
+
"step 4640 | loss: 0.0436 | dt: 2700.08ms | tok/sec: 758.50\n",
|
| 584 |
+
"step 4650 | loss: 0.0552 | dt: 2707.12ms | tok/sec: 756.52\n",
|
| 585 |
+
"step 4660 | loss: 0.0478 | dt: 2704.64ms | tok/sec: 757.22\n",
|
| 586 |
+
"step 4670 | loss: 0.0503 | dt: 2700.75ms | tok/sec: 758.31\n",
|
| 587 |
+
"step 4680 | loss: 0.0370 | dt: 2703.38ms | tok/sec: 757.57\n",
|
| 588 |
+
"step 4690 | loss: 0.0488 | dt: 2690.58ms | tok/sec: 761.18\n",
|
| 589 |
+
"step 4700 | loss: 0.0395 | dt: 2711.10ms | tok/sec: 755.41\n",
|
| 590 |
+
"step 4710 | loss: 0.0384 | dt: 2695.39ms | tok/sec: 759.82\n",
|
| 591 |
+
"step 4720 | loss: 0.0320 | dt: 2698.22ms | tok/sec: 759.02\n",
|
| 592 |
+
"step 4730 | loss: 0.0389 | dt: 2693.92ms | tok/sec: 760.23\n",
|
| 593 |
+
"step 4740 | loss: 0.0417 | dt: 2715.99ms | tok/sec: 754.05\n",
|
| 594 |
+
"step 4750 | loss: 0.0327 | dt: 2707.00ms | tok/sec: 756.56\n",
|
| 595 |
+
"step 4760 | loss: 0.0383 | dt: 2702.06ms | tok/sec: 757.94\n",
|
| 596 |
+
"step 4770 | loss: 0.0482 | dt: 2720.02ms | tok/sec: 752.93\n",
|
| 597 |
+
"step 4780 | loss: 0.0285 | dt: 2695.32ms | tok/sec: 759.84\n",
|
| 598 |
+
"step 4790 | loss: 0.0466 | dt: 2699.53ms | tok/sec: 758.65\n",
|
| 599 |
+
"step 4800 | loss: 0.0293 | dt: 2747.82ms | tok/sec: 745.32\n",
|
| 600 |
+
"step 4810 | loss: 0.0475 | dt: 2694.82ms | tok/sec: 759.98\n",
|
| 601 |
+
"step 4820 | loss: 0.0331 | dt: 2697.06ms | tok/sec: 759.34\n",
|
| 602 |
+
"step 4830 | loss: 0.0424 | dt: 2695.18ms | tok/sec: 759.88\n",
|
| 603 |
+
"step 4840 | loss: 0.0388 | dt: 2699.61ms | tok/sec: 758.63\n",
|
| 604 |
+
"step 4850 | loss: 0.0278 | dt: 2697.17ms | tok/sec: 759.31\n",
|
| 605 |
+
"step 4860 | loss: 0.0352 | dt: 2712.98ms | tok/sec: 754.89\n",
|
| 606 |
+
"step 4870 | loss: 0.0231 | dt: 2708.67ms | tok/sec: 756.09\n",
|
| 607 |
+
"step 4880 | loss: 0.0355 | dt: 2697.11ms | tok/sec: 759.33\n",
|
| 608 |
+
"step 4890 | loss: 0.0379 | dt: 2696.41ms | tok/sec: 759.53\n",
|
| 609 |
+
"step 4900 | loss: 0.0251 | dt: 2692.32ms | tok/sec: 760.68\n",
|
| 610 |
+
"step 4910 | loss: 0.0263 | dt: 2695.45ms | tok/sec: 759.80\n",
|
| 611 |
+
"step 4920 | loss: 0.0279 | dt: 2703.49ms | tok/sec: 757.54\n",
|
| 612 |
+
"step 4930 | loss: 0.0253 | dt: 2696.46ms | tok/sec: 759.51\n",
|
| 613 |
+
"step 4940 | loss: 0.0279 | dt: 2719.05ms | tok/sec: 753.21\n",
|
| 614 |
+
"step 4950 | loss: 0.0303 | dt: 2698.82ms | tok/sec: 758.85\n",
|
| 615 |
+
"step 4960 | loss: 0.0277 | dt: 2702.86ms | tok/sec: 757.72\n",
|
| 616 |
+
"step 4970 | loss: 0.0230 | dt: 2708.92ms | tok/sec: 756.02\n",
|
| 617 |
+
"step 4980 | loss: 0.0397 | dt: 2693.32ms | tok/sec: 760.40\n",
|
| 618 |
+
"step 4990 | loss: 0.0244 | dt: 2699.65ms | tok/sec: 758.62\n",
|
| 619 |
+
"Saving model to smollm_135_checkpoint.pth\n",
|
| 620 |
+
"\n",
|
| 621 |
+
"--- Resuming training from checkpoint ---\n",
|
| 622 |
+
"Checkpoint loaded successfully.\n",
|
| 623 |
+
"Resume step 0 | loss: 0.0218\n",
|
| 624 |
+
"Resume step 10 | loss: 0.0335\n",
|
| 625 |
+
"Resume step 20 | loss: 0.0360\n",
|
| 626 |
+
"Resume step 30 | loss: 0.0450\n",
|
| 627 |
+
"Resume step 40 | loss: 0.0396\n",
|
| 628 |
+
"Resumed training completed.\n"
|
| 629 |
+
]
|
| 630 |
+
}
|
| 631 |
+
],
|
| 632 |
+
"source": [
|
| 633 |
+
"# SmolLM-135M Implementation (Llama Architecture)\n",
|
| 634 |
+
"# Based on: https://huggingface.co/HuggingFaceTB/SmolLM-135M\n",
|
| 635 |
+
"\n",
|
| 636 |
+
"import math\n",
|
| 637 |
+
"import inspect\n",
|
| 638 |
+
"from dataclasses import dataclass\n",
|
| 639 |
+
"from typing import Optional, Tuple\n",
|
| 640 |
+
"\n",
|
| 641 |
+
"import torch\n",
|
| 642 |
+
"import torch.nn as nn\n",
|
| 643 |
+
"from torch.nn import functional as F\n",
|
| 644 |
+
"\n",
|
| 645 |
+
"# Configuration for SmolLM-135M\n",
|
| 646 |
+
"@dataclass\n",
|
| 647 |
+
"class SmolLMConfig:\n",
|
| 648 |
+
" block_size: int = 512 # Reduced to 512 for 4GB GPU training\n",
|
| 649 |
+
" vocab_size: int = 50304 # Aligned to 50304 for tiktoken compatibility (SmolLM native is 49152)\n",
|
| 650 |
+
" n_layer: int = 30\n",
|
| 651 |
+
" n_head: int = 9\n",
|
| 652 |
+
" n_kv_head: int = 3 # Grouped Query Attention (GQA)\n",
|
| 653 |
+
" n_embd: int = 576\n",
|
| 654 |
+
" intermediate_size: int = 1536 # SwiGLU intermediate size\n",
|
| 655 |
+
" rms_norm_eps: float = 1e-5\n",
|
| 656 |
+
" rope_theta: float = 10000.0\n",
|
| 657 |
+
" dropout: float = 0.0\n",
|
| 658 |
+
" bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\n",
|
| 659 |
+
"\n",
|
| 660 |
+
"class RMSNorm(nn.Module):\n",
|
| 661 |
+
" def __init__(self, dim: int, eps: float = 1e-6):\n",
|
| 662 |
+
" super().__init__()\n",
|
| 663 |
+
" self.eps = eps\n",
|
| 664 |
+
" self.weight = nn.Parameter(torch.ones(dim))\n",
|
| 665 |
+
"\n",
|
| 666 |
+
" def _norm(self, x):\n",
|
| 667 |
+
" return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
|
| 668 |
+
"\n",
|
| 669 |
+
" def forward(self, x):\n",
|
| 670 |
+
" output = self._norm(x.float()).type_as(x)\n",
|
| 671 |
+
" return output * self.weight\n",
|
| 672 |
+
"\n",
|
| 673 |
+
"def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):\n",
|
| 674 |
+
" freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n",
|
| 675 |
+
" t = torch.arange(end, device=freqs.device, dtype=torch.float32)\n",
|
| 676 |
+
" freqs = torch.outer(t, freqs)\n",
|
| 677 |
+
" freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64\n",
|
| 678 |
+
" return freqs_cis\n",
|
| 679 |
+
"\n",
|
| 680 |
+
"def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n",
|
| 681 |
+
" ndim = x.ndim\n",
|
| 682 |
+
" assert 0 <= 1 < ndim\n",
|
| 683 |
+
" assert freqs_cis.shape == (x.shape[1], x.shape[-1])\n",
|
| 684 |
+
" shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]\n",
|
| 685 |
+
" return freqs_cis.view(*shape)\n",
|
| 686 |
+
"\n",
|
| 687 |
+
"def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):\n",
|
| 688 |
+
" xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n",
|
| 689 |
+
" xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))\n",
|
| 690 |
+
" freqs_cis = reshape_for_broadcast(freqs_cis, xq_)\n",
|
| 691 |
+
" xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n",
|
| 692 |
+
" xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n",
|
| 693 |
+
" return xq_out.type_as(xq), xk_out.type_as(xk)\n",
|
| 694 |
+
"\n",
|
| 695 |
+
"class CausalSelfAttention(nn.Module):\n",
|
| 696 |
+
" def __init__(self, config: SmolLMConfig):\n",
|
| 697 |
+
" super().__init__()\n",
|
| 698 |
+
" self.n_head = config.n_head\n",
|
| 699 |
+
" self.n_kv_head = config.n_kv_head\n",
|
| 700 |
+
" self.n_embd = config.n_embd\n",
|
| 701 |
+
" self.head_dim = config.n_embd // config.n_head\n",
|
| 702 |
+
" self.n_rep = self.n_head // self.n_kv_head\n",
|
| 703 |
+
"\n",
|
| 704 |
+
" self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=config.bias)\n",
|
| 705 |
+
" self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)\n",
|
| 706 |
+
" self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)\n",
|
| 707 |
+
" self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=config.bias)\n",
|
| 708 |
+
"\n",
|
| 709 |
+
" self.dropout = config.dropout\n",
|
| 710 |
+
" self.resid_dropout = nn.Dropout(config.dropout)\n",
|
| 711 |
+
"\n",
|
| 712 |
+
" def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):\n",
|
| 713 |
+
" B, T, C = x.shape\n",
|
| 714 |
+
"\n",
|
| 715 |
+
" xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)\n",
|
| 716 |
+
" xq = xq.view(B, T, self.n_head, self.head_dim)\n",
|
| 717 |
+
" xk = xk.view(B, T, self.n_kv_head, self.head_dim)\n",
|
| 718 |
+
" xv = xv.view(B, T, self.n_kv_head, self.head_dim)\n",
|
| 719 |
+
"\n",
|
| 720 |
+
" xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)\n",
|
| 721 |
+
"\n",
|
| 722 |
+
" # Grouped Query Attention: repeat k/v heads to match q heads\n",
|
| 723 |
+
" xk = torch.repeat_interleave(xk, dim=2, repeats=self.n_rep)\n",
|
| 724 |
+
" xv = torch.repeat_interleave(xv, dim=2, repeats=self.n_rep)\n",
|
| 725 |
+
"\n",
|
| 726 |
+
" # Make heads batch dimension\n",
|
| 727 |
+
" xq = xq.transpose(1, 2) # (B, n_head, T, head_dim)\n",
|
| 728 |
+
" xk = xk.transpose(1, 2) # (B, n_head, T, head_dim)\n",
|
| 729 |
+
" xv = xv.transpose(1, 2) # (B, n_head, T, head_dim)\n",
|
| 730 |
+
"\n",
|
| 731 |
+
" # Flash Attention\n",
|
| 732 |
+
" output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)\n",
|
| 733 |
+
"\n",
|
| 734 |
+
" output = output.transpose(1, 2).contiguous().view(B, T, C)\n",
|
| 735 |
+
" return self.resid_dropout(self.wo(output))\n",
|
| 736 |
+
"\n",
|
| 737 |
+
"class SwiGLU(nn.Module):\n",
|
| 738 |
+
" def __init__(self, config: SmolLMConfig):\n",
|
| 739 |
+
" super().__init__()\n",
|
| 740 |
+
" self.w1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) # Gate\n",
|
| 741 |
+
" self.w3 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) # Value\n",
|
| 742 |
+
" self.w2 = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) # Output\n",
|
| 743 |
+
" self.dropout = nn.Dropout(config.dropout)\n",
|
| 744 |
+
"\n",
|
| 745 |
+
" def forward(self, x):\n",
|
| 746 |
+
" return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))\n",
|
| 747 |
+
"\n",
|
| 748 |
+
"class Block(nn.Module):\n",
|
| 749 |
+
" def __init__(self, config: SmolLMConfig):\n",
|
| 750 |
+
" super().__init__()\n",
|
| 751 |
+
" self.attention_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)\n",
|
| 752 |
+
" self.attention = CausalSelfAttention(config)\n",
|
| 753 |
+
" self.ffn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)\n",
|
| 754 |
+
" self.feed_forward = SwiGLU(config)\n",
|
| 755 |
+
"\n",
|
| 756 |
+
" def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):\n",
|
| 757 |
+
" h = x + self.attention(self.attention_norm(x), freqs_cis)\n",
|
| 758 |
+
" out = h + self.feed_forward(self.ffn_norm(h))\n",
|
| 759 |
+
" return out\n",
|
| 760 |
+
"\n",
|
| 761 |
+
"class SmolLM(nn.Module):\n",
|
| 762 |
+
" def __init__(self, config: SmolLMConfig):\n",
|
| 763 |
+
" super().__init__()\n",
|
| 764 |
+
" self.config = config\n",
|
| 765 |
+
" self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)\n",
|
| 766 |
+
" self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])\n",
|
| 767 |
+
" self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)\n",
|
| 768 |
+
" self.output = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
|
| 769 |
+
"\n",
|
| 770 |
+
" # Weight sharing\n",
|
| 771 |
+
" self.tok_embeddings.weight = self.output.weight\n",
|
| 772 |
+
"\n",
|
| 773 |
+
" # Precompute RoPE frequencies\n",
|
| 774 |
+
" self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size * 2, config.rope_theta)\n",
|
| 775 |
+
"\n",
|
| 776 |
+
" self.apply(self._init_weights)\n",
|
| 777 |
+
"\n",
|
| 778 |
+
" def _init_weights(self, module):\n",
|
| 779 |
+
" if isinstance(module, nn.Linear):\n",
|
| 780 |
+
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
|
| 781 |
+
" if module.bias is not None:\n",
|
| 782 |
+
" torch.nn.init.zeros_(module.bias)\n",
|
| 783 |
+
" elif isinstance(module, nn.Embedding):\n",
|
| 784 |
+
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
|
| 785 |
+
"\n",
|
| 786 |
+
" def forward(self, idx, targets=None):\n",
|
| 787 |
+
" B, T = idx.shape\n",
|
| 788 |
+
" x = self.tok_embeddings(idx)\n",
|
| 789 |
+
" \n",
|
| 790 |
+
" # Ensure freqs_cis is on the correct device\n",
|
| 791 |
+
" if self.freqs_cis.device != x.device:\n",
|
| 792 |
+
" self.freqs_cis = self.freqs_cis.to(x.device)\n",
|
| 793 |
+
" freqs_cis = self.freqs_cis[:T]\n",
|
| 794 |
+
"\n",
|
| 795 |
+
" for layer in self.layers:\n",
|
| 796 |
+
" x = layer(x, freqs_cis)\n",
|
| 797 |
+
" \n",
|
| 798 |
+
" x = self.norm(x)\n",
|
| 799 |
+
" logits = self.output(x)\n",
|
| 800 |
+
"\n",
|
| 801 |
+
" loss = None\n",
|
| 802 |
+
" if targets is not None:\n",
|
| 803 |
+
" loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))\n",
|
| 804 |
+
" \n",
|
| 805 |
+
" return logits, loss\n",
|
| 806 |
+
"\n",
|
| 807 |
+
"# Device selection\n",
|
| 808 |
+
"device = 'cpu'\n",
|
| 809 |
+
"if torch.cuda.is_available():\n",
|
| 810 |
+
" device = 'cuda'\n",
|
| 811 |
+
"elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n",
|
| 812 |
+
" device = \"mps\"\n",
|
| 813 |
+
"print(f\"using device: {device}\")\n",
|
| 814 |
+
"\n",
|
| 815 |
+
"# Data Loading\n",
|
| 816 |
+
"import tiktoken\n",
|
| 817 |
+
"\n",
|
| 818 |
+
"class DataLoaderLite:\n",
|
| 819 |
+
" def __init__(self, B, T):\n",
|
| 820 |
+
" self.B = B\n",
|
| 821 |
+
" self.T = T\n",
|
| 822 |
+
"\n",
|
| 823 |
+
" # Load tokens from disk\n",
|
| 824 |
+
" try:\n",
|
| 825 |
+
" with open('input.txt', 'r', encoding='utf-8') as f:\n",
|
| 826 |
+
" text = f.read()\n",
|
| 827 |
+
" except FileNotFoundError:\n",
|
| 828 |
+
" print(\"Error: input.txt not found. Please ensure the file exists.\")\n",
|
| 829 |
+
" text = \"Hello world \" * 1000 # Fallback for testing if file missing\n",
|
| 830 |
+
" \n",
|
| 831 |
+
" enc = tiktoken.get_encoding('gpt2') \n",
|
| 832 |
+
" tokens = enc.encode(text)\n",
|
| 833 |
+
" self.tokens = torch.tensor(tokens)\n",
|
| 834 |
+
" print(f'loaded {len(self.tokens)} tokens')\n",
|
| 835 |
+
" print(f'1 epoch = {len(self.tokens) // (B * T)} batches')\n",
|
| 836 |
+
"\n",
|
| 837 |
+
" self.current_position = 0\n",
|
| 838 |
+
" \n",
|
| 839 |
+
" def next_batch(self):\n",
|
| 840 |
+
" B, T = self.B, self.T\n",
|
| 841 |
+
" buf = self.tokens[self.current_position: self.current_position + B * T + 1]\n",
|
| 842 |
+
" x = (buf[:-1]).view(B, T) # inputs\n",
|
| 843 |
+
" y = (buf[1:]).view(B, T) # targets\n",
|
| 844 |
+
" self.current_position += B*T\n",
|
| 845 |
+
" if self.current_position + (B * T + 1) > len(self.tokens):\n",
|
| 846 |
+
" self.current_position = 0\n",
|
| 847 |
+
" return x, y\n",
|
| 848 |
+
"\n",
|
| 849 |
+
"# Training Setup\n",
|
| 850 |
+
"torch.manual_seed(1337)\n",
|
| 851 |
+
"if torch.cuda.is_available():\n",
|
| 852 |
+
" torch.cuda.manual_seed(1337)\n",
|
| 853 |
+
"\n",
|
| 854 |
+
"torch.set_float32_matmul_precision('high')\n",
|
| 855 |
+
"\n",
|
| 856 |
+
"config = SmolLMConfig()\n",
|
| 857 |
+
"model = SmolLM(config)\n",
|
| 858 |
+
"model.to(device)\n",
|
| 859 |
+
"\n",
|
| 860 |
+
"print(f\"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M\")\n",
|
| 861 |
+
"\n",
|
| 862 |
+
"# Generation Function\n",
|
| 863 |
+
"@torch.no_grad()\n",
|
| 864 |
+
"def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
|
| 865 |
+
" \"\"\"\n",
|
| 866 |
+
" Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\n",
|
| 867 |
+
" the sequence max_new_tokens times, feeding the predictions back into the model each time.\n",
|
| 868 |
+
" \"\"\"\n",
|
| 869 |
+
" for _ in range(max_new_tokens):\n",
|
| 870 |
+
" # if the sequence context is growing too long we must crop it at block_size\n",
|
| 871 |
+
" idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]\n",
|
| 872 |
+
" # forward the model to get the logits for the index in the sequence\n",
|
| 873 |
+
" logits, _ = model(idx_cond)\n",
|
| 874 |
+
" # pluck the logits at the final step and scale by desired temperature\n",
|
| 875 |
+
" logits = logits[:, -1, :] / temperature\n",
|
| 876 |
+
" # optionally crop the logits to only the top k options\n",
|
| 877 |
+
" if top_k is not None:\n",
|
| 878 |
+
" v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
|
| 879 |
+
" logits[logits < v[:, [-1]]] = -float('Inf')\n",
|
| 880 |
+
" # apply softmax to convert logits to (normalized) probabilities\n",
|
| 881 |
+
" probs = F.softmax(logits, dim=-1)\n",
|
| 882 |
+
" # sample from the distribution\n",
|
| 883 |
+
" idx_next = torch.multinomial(probs, num_samples=1)\n",
|
| 884 |
+
" # append sampled index to the running sequence and continue\n",
|
| 885 |
+
" idx = torch.cat((idx, idx_next), dim=1)\n",
|
| 886 |
+
"\n",
|
| 887 |
+
" return idx\n",
|
| 888 |
+
"\n",
|
| 889 |
+
"# Training Loop\n",
|
| 890 |
+
"train_loader = DataLoaderLite(B = 4, T = 512) # Reduced batch size and context for 4GB GPU\n",
|
| 891 |
+
"optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)\n",
|
| 892 |
+
"\n",
|
| 893 |
+
"import time\n",
|
| 894 |
+
"import os\n",
|
| 895 |
+
"\n",
|
| 896 |
+
"max_steps = 5000\n",
|
| 897 |
+
"eval_interval = 500\n",
|
| 898 |
+
"save_path = \"smollm_135_checkpoint.pth\"\n",
|
| 899 |
+
"\n",
|
| 900 |
+
"print(\"Starting training...\")\n",
|
| 901 |
+
"for i in range(max_steps):\n",
|
| 902 |
+
" t0 = time.time()\n",
|
| 903 |
+
" x, y = train_loader.next_batch()\n",
|
| 904 |
+
" x, y = x.to(device), y.to(device)\n",
|
| 905 |
+
" optimizer.zero_grad()\n",
|
| 906 |
+
" \n",
|
| 907 |
+
" # Mixed precision training\n",
|
| 908 |
+
" with torch.autocast(device_type=device, dtype=torch.bfloat16 if device=='cuda' else torch.float32):\n",
|
| 909 |
+
" logits, loss = model(x, y) \n",
|
| 910 |
+
" \n",
|
| 911 |
+
" loss.backward()\n",
|
| 912 |
+
" optimizer.step()\n",
|
| 913 |
+
" \n",
|
| 914 |
+
" if device == 'cuda':\n",
|
| 915 |
+
" torch.cuda.synchronize() \n",
|
| 916 |
+
" \n",
|
| 917 |
+
" t1 = time.time()\n",
|
| 918 |
+
" dt = (t1 - t0) * 1000\n",
|
| 919 |
+
" tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)\n",
|
| 920 |
+
" \n",
|
| 921 |
+
" if i % 10 == 0:\n",
|
| 922 |
+
" print(f'step {i} | loss: {loss.item():.4f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f}')\n",
|
| 923 |
+
" \n",
|
| 924 |
+
" # Generate output every 500 steps\n",
|
| 925 |
+
" if i > 0 and i % eval_interval == 0:\n",
|
| 926 |
+
" print(f\"\\n--- Generating text at step {i} ---\")\n",
|
| 927 |
+
" context = torch.zeros((1, 1), dtype=torch.long, device=device) # Start with token 0 (usually valid)\n",
|
| 928 |
+
" generated = generate(model, context, max_new_tokens=50)\n",
|
| 929 |
+
" # Decode using tiktoken (gpt2 encoding as used in DataLoader)\n",
|
| 930 |
+
" enc = tiktoken.get_encoding('gpt2')\n",
|
| 931 |
+
" decoded = enc.decode(generated[0].tolist())\n",
|
| 932 |
+
" # Force ASCII for Windows console compatibility\n",
|
| 933 |
+
" print(decoded.encode('ascii', errors='ignore').decode('ascii'))\n",
|
| 934 |
+
" print(\"-----------------------------------\\n\")\n",
|
| 935 |
+
"\n",
|
| 936 |
+
"# Save checkpoint\n",
|
| 937 |
+
"print(f\"Saving model to {save_path}\")\n",
|
| 938 |
+
"torch.save(model.state_dict(), save_path)\n",
|
| 939 |
+
"\n",
|
| 940 |
+
"# Resume training demonstration\n",
|
| 941 |
+
"print(\"\\n--- Resuming training from checkpoint ---\")\n",
|
| 942 |
+
"# Re-initialize model to prove loading works\n",
|
| 943 |
+
"model_new = SmolLM(config)\n",
|
| 944 |
+
"model_new.to(device)\n",
|
| 945 |
+
"model_new.load_state_dict(torch.load(save_path))\n",
|
| 946 |
+
"print(\"Checkpoint loaded successfully.\")\n",
|
| 947 |
+
"\n",
|
| 948 |
+
"optimizer_new = torch.optim.AdamW(model_new.parameters(), lr = 3e-4)\n",
|
| 949 |
+
"\n",
|
| 950 |
+
"# Train for another 50 steps\n",
|
| 951 |
+
"for i in range(50):\n",
|
| 952 |
+
" t0 = time.time()\n",
|
| 953 |
+
" x, y = train_loader.next_batch()\n",
|
| 954 |
+
" x, y = x.to(device), y.to(device)\n",
|
| 955 |
+
" optimizer_new.zero_grad()\n",
|
| 956 |
+
" \n",
|
| 957 |
+
" with torch.autocast(device_type=device, dtype=torch.bfloat16 if device=='cuda' else torch.float32):\n",
|
| 958 |
+
" logits, loss = model_new(x, y) \n",
|
| 959 |
+
" \n",
|
| 960 |
+
" loss.backward()\n",
|
| 961 |
+
" optimizer_new.step()\n",
|
| 962 |
+
" \n",
|
| 963 |
+
" if device == 'cuda':\n",
|
| 964 |
+
" torch.cuda.synchronize()\n",
|
| 965 |
+
" \n",
|
| 966 |
+
" if i % 10 == 0:\n",
|
| 967 |
+
" print(f'Resume step {i} | loss: {loss.item():.4f}')\n",
|
| 968 |
+
"\n",
|
| 969 |
+
"print(\"Resumed training completed.\")\n"
|
| 970 |
+
]
|
| 971 |
+
},
|
| 972 |
+
{
|
| 973 |
+
"cell_type": "code",
|
| 974 |
+
"execution_count": null,
|
| 975 |
+
"id": "a720457f",
|
| 976 |
+
"metadata": {},
|
| 977 |
+
"outputs": [],
|
| 978 |
+
"source": []
|
| 979 |
+
}
|
| 980 |
+
],
|
| 981 |
+
"metadata": {
|
| 982 |
+
"kernelspec": {
|
| 983 |
+
"display_name": "Python 3",
|
| 984 |
+
"language": "python",
|
| 985 |
+
"name": "python3"
|
| 986 |
+
},
|
| 987 |
+
"language_info": {
|
| 988 |
+
"codemirror_mode": {
|
| 989 |
+
"name": "ipython",
|
| 990 |
+
"version": 3
|
| 991 |
+
},
|
| 992 |
+
"file_extension": ".py",
|
| 993 |
+
"mimetype": "text/x-python",
|
| 994 |
+
"name": "python",
|
| 995 |
+
"nbconvert_exporter": "python",
|
| 996 |
+
"pygments_lexer": "ipython3",
|
| 997 |
+
"version": "3.13.7"
|
| 998 |
+
}
|
| 999 |
+
},
|
| 1000 |
+
"nbformat": 4,
|
| 1001 |
+
"nbformat_minor": 5
|
| 1002 |
+
}
|
app.py
ADDED
@@ -0,0 +1,160 @@
"""
Gradio App for Sentence Completion
Main entry point for Hugging Face Spaces
"""

import gradio as gr
import torch
from inference import load_model, generate_text, get_device


# Global model variable
model = None
device = None


def initialize_model(model_path=None):
    """Initialize the model on startup"""
    global model, device
    try:
        model, device = load_model(model_path=model_path)
        #model.eval()  # Set to eval mode once
        return f"Model loaded successfully on device: {device}"
    except Exception as e:
        return f"Error loading model: {str(e)}"


def complete_sentence(prompt, max_tokens, top_k, temperature):
    """Generate sentence completion based on prompt"""
    global model, device

    if model is None:
        return "Error: Model not loaded. Please restart the app."

    if not prompt.strip():
        return "Please enter a prompt to complete."

    try:
        # Generate completion
        print(prompt)
        generated_text = generate_text(
            prompt=prompt,
            model=model,
            max_tokens=int(max_tokens),
            top_k=int(top_k),                 # forward the UI sampling controls
            temperature=float(temperature),
            device=device
        )
        print(device)

        return generated_text
    except Exception as e:
        return f"Error generating text: {str(e)}"


def create_interface():
    """Create and return the Gradio interface"""

    # Initialize model on startup
    # Try to load from common checkpoint paths
    checkpoint_paths = [
        './model/smollm_135_checkpoint.pth',
        './model/model.pth',
        'model.pt',
        'checkpoint.pth',
    ]

    model_path = None
    for path in checkpoint_paths:
        import os
        if os.path.exists(path):
            model_path = path
            print(f"Model found at {path}")
            break
        else:
            print(f"Model not found at {path}")

    status = initialize_model(model_path=model_path)
    print(status)

    # Create Gradio interface
    with gr.Blocks(title="Sentence Completion with SmolLM-135M") as demo:
        gr.Markdown(
            """
            # Sentence Completion with SmolLM-135M

            Enter a prompt and the model will complete the sentence for you.
            Adjust the parameters to control the generation behavior.
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                prompt_input = gr.Textbox(
                    label="Prompt",
                    placeholder="Enter your prompt here...",
                    lines=3,
                    value="The future of artificial intelligence is"
                )

                with gr.Row():
                    max_tokens_slider = gr.Slider(
                        minimum=10,
                        maximum=200,
                        value=50,
                        step=10,
                        label="Max Tokens"
                    )

                    top_k_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=50,
                        step=1,
                        label="Top-K"
                    )

                    temperature_slider = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Temperature"
                    )

                generate_btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=2):
                output_text = gr.Textbox(
                    label="Generated Text",
                    lines=10,
                    interactive=False
                )

                gr.Markdown(
                    """
                    ### Parameters:
                    - **Max Tokens**: Maximum number of tokens to generate
                    - **Top-K**: Sample from top K most likely tokens (lower = more focused)
                    - **Temperature**: Controls randomness (lower = more deterministic, higher = more creative)
                    """
                )

        # Set up the generate function
        generate_btn.click(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )

        # Also generate on Enter key press
        prompt_input.submit(
            fn=complete_sentence,
            inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
            outputs=output_text
        )

    return demo


if __name__ == "__main__":
    demo = create_interface()
    demo.launch(share=False)

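For local debugging outside Spaces, the interface above can also be driven programmatically. A minimal sketch, assuming gradio and the checkpoint are installed locally (the host and port values are illustrative, not part of this commit):

from app import create_interface

demo = create_interface()  # loads the checkpoint and builds the Blocks UI
demo.launch(server_name="127.0.0.1", server_port=7860)  # assumed host/port; the defaults also work
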
check_cuda.py
ADDED
@@ -0,0 +1,9 @@
import torch

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA not available")

checkpoint_info.txt
ADDED
@@ -0,0 +1,2 @@
Checkpoint type: <class 'collections.OrderedDict'>
Keys: ['tok_embeddings.weight', 'layers.0.attention_norm.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wv.weight', 'layers.0.attention.wo.weight', 'layers.0.ffn_norm.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.feed_forward.w2.weight', 'layers.1.attention_norm.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wv.weight', 'layers.1.attention.wo.weight', 'layers.1.ffn_norm.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.feed_forward.w2.weight', 'layers.2.attention_norm.weight', 'layers.2.attention.wq.weight', 'layers.2.attention.wk.weight', 'layers.2.attention.wv.weight', 'layers.2.attention.wo.weight', 'layers.2.ffn_norm.weight', 'layers.2.feed_forward.w1.weight', 'layers.2.feed_forward.w3.weight', 'layers.2.feed_forward.w2.weight', 'layers.3.attention_norm.weight', 'layers.3.attention.wq.weight', 'layers.3.attention.wk.weight', 'layers.3.attention.wv.weight', 'layers.3.attention.wo.weight', 'layers.3.ffn_norm.weight', 'layers.3.feed_forward.w1.weight', 'layers.3.feed_forward.w3.weight', 'layers.3.feed_forward.w2.weight', 'layers.4.attention_norm.weight', 'layers.4.attention.wq.weight', 'layers.4.attention.wk.weight', 'layers.4.attention.wv.weight', 'layers.4.attention.wo.weight', 'layers.4.ffn_norm.weight', 'layers.4.feed_forward.w1.weight', 'layers.4.feed_forward.w3.weight', 'layers.4.feed_forward.w2.weight', 'layers.5.attention_norm.weight', 'layers.5.attention.wq.weight', 'layers.5.attention.wk.weight', 'layers.5.attention.wv.weight', 'layers.5.attention.wo.weight', 'layers.5.ffn_norm.weight', 'layers.5.feed_forward.w1.weight', 'layers.5.feed_forward.w3.weight', 'layers.5.feed_forward.w2.weight', 'layers.6.attention_norm.weight', 'layers.6.attention.wq.weight', 'layers.6.attention.wk.weight', 'layers.6.attention.wv.weight', 'layers.6.attention.wo.weight', 'layers.6.ffn_norm.weight', 'layers.6.feed_forward.w1.weight', 'layers.6.feed_forward.w3.weight', 'layers.6.feed_forward.w2.weight', 'layers.7.attention_norm.weight', 'layers.7.attention.wq.weight', 'layers.7.attention.wk.weight', 'layers.7.attention.wv.weight', 'layers.7.attention.wo.weight', 'layers.7.ffn_norm.weight', 'layers.7.feed_forward.w1.weight', 'layers.7.feed_forward.w3.weight', 'layers.7.feed_forward.w2.weight', 'layers.8.attention_norm.weight', 'layers.8.attention.wq.weight', 'layers.8.attention.wk.weight', 'layers.8.attention.wv.weight', 'layers.8.attention.wo.weight', 'layers.8.ffn_norm.weight', 'layers.8.feed_forward.w1.weight', 'layers.8.feed_forward.w3.weight', 'layers.8.feed_forward.w2.weight', 'layers.9.attention_norm.weight', 'layers.9.attention.wq.weight', 'layers.9.attention.wk.weight', 'layers.9.attention.wv.weight', 'layers.9.attention.wo.weight', 'layers.9.ffn_norm.weight', 'layers.9.feed_forward.w1.weight', 'layers.9.feed_forward.w3.weight', 'layers.9.feed_forward.w2.weight', 'layers.10.attention_norm.weight', 'layers.10.attention.wq.weight', 'layers.10.attention.wk.weight', 'layers.10.attention.wv.weight', 'layers.10.attention.wo.weight', 'layers.10.ffn_norm.weight', 'layers.10.feed_forward.w1.weight', 'layers.10.feed_forward.w3.weight', 'layers.10.feed_forward.w2.weight', 'layers.11.attention_norm.weight', 'layers.11.attention.wq.weight', 'layers.11.attention.wk.weight', 'layers.11.attention.wv.weight', 'layers.11.attention.wo.weight', 'layers.11.ffn_norm.weight', 'layers.11.feed_forward.w1.weight', 'layers.11.feed_forward.w3.weight', 
'layers.11.feed_forward.w2.weight', 'layers.12.attention_norm.weight', 'layers.12.attention.wq.weight', 'layers.12.attention.wk.weight', 'layers.12.attention.wv.weight', 'layers.12.attention.wo.weight', 'layers.12.ffn_norm.weight', 'layers.12.feed_forward.w1.weight', 'layers.12.feed_forward.w3.weight', 'layers.12.feed_forward.w2.weight', 'layers.13.attention_norm.weight', 'layers.13.attention.wq.weight', 'layers.13.attention.wk.weight', 'layers.13.attention.wv.weight', 'layers.13.attention.wo.weight', 'layers.13.ffn_norm.weight', 'layers.13.feed_forward.w1.weight', 'layers.13.feed_forward.w3.weight', 'layers.13.feed_forward.w2.weight', 'layers.14.attention_norm.weight', 'layers.14.attention.wq.weight', 'layers.14.attention.wk.weight', 'layers.14.attention.wv.weight', 'layers.14.attention.wo.weight', 'layers.14.ffn_norm.weight', 'layers.14.feed_forward.w1.weight', 'layers.14.feed_forward.w3.weight', 'layers.14.feed_forward.w2.weight', 'layers.15.attention_norm.weight', 'layers.15.attention.wq.weight', 'layers.15.attention.wk.weight', 'layers.15.attention.wv.weight', 'layers.15.attention.wo.weight', 'layers.15.ffn_norm.weight', 'layers.15.feed_forward.w1.weight', 'layers.15.feed_forward.w3.weight', 'layers.15.feed_forward.w2.weight', 'layers.16.attention_norm.weight', 'layers.16.attention.wq.weight', 'layers.16.attention.wk.weight', 'layers.16.attention.wv.weight', 'layers.16.attention.wo.weight', 'layers.16.ffn_norm.weight', 'layers.16.feed_forward.w1.weight', 'layers.16.feed_forward.w3.weight', 'layers.16.feed_forward.w2.weight', 'layers.17.attention_norm.weight', 'layers.17.attention.wq.weight', 'layers.17.attention.wk.weight', 'layers.17.attention.wv.weight', 'layers.17.attention.wo.weight', 'layers.17.ffn_norm.weight', 'layers.17.feed_forward.w1.weight', 'layers.17.feed_forward.w3.weight', 'layers.17.feed_forward.w2.weight', 'layers.18.attention_norm.weight', 'layers.18.attention.wq.weight', 'layers.18.attention.wk.weight', 'layers.18.attention.wv.weight', 'layers.18.attention.wo.weight', 'layers.18.ffn_norm.weight', 'layers.18.feed_forward.w1.weight', 'layers.18.feed_forward.w3.weight', 'layers.18.feed_forward.w2.weight', 'layers.19.attention_norm.weight', 'layers.19.attention.wq.weight', 'layers.19.attention.wk.weight', 'layers.19.attention.wv.weight', 'layers.19.attention.wo.weight', 'layers.19.ffn_norm.weight', 'layers.19.feed_forward.w1.weight', 'layers.19.feed_forward.w3.weight', 'layers.19.feed_forward.w2.weight', 'layers.20.attention_norm.weight', 'layers.20.attention.wq.weight', 'layers.20.attention.wk.weight', 'layers.20.attention.wv.weight', 'layers.20.attention.wo.weight', 'layers.20.ffn_norm.weight', 'layers.20.feed_forward.w1.weight', 'layers.20.feed_forward.w3.weight', 'layers.20.feed_forward.w2.weight', 'layers.21.attention_norm.weight', 'layers.21.attention.wq.weight', 'layers.21.attention.wk.weight', 'layers.21.attention.wv.weight', 'layers.21.attention.wo.weight', 'layers.21.ffn_norm.weight', 'layers.21.feed_forward.w1.weight', 'layers.21.feed_forward.w3.weight', 'layers.21.feed_forward.w2.weight', 'layers.22.attention_norm.weight', 'layers.22.attention.wq.weight', 'layers.22.attention.wk.weight', 'layers.22.attention.wv.weight', 'layers.22.attention.wo.weight', 'layers.22.ffn_norm.weight', 'layers.22.feed_forward.w1.weight', 'layers.22.feed_forward.w3.weight', 'layers.22.feed_forward.w2.weight', 'layers.23.attention_norm.weight', 'layers.23.attention.wq.weight', 'layers.23.attention.wk.weight', 'layers.23.attention.wv.weight', 'layers.23.attention.wo.weight', 
'layers.23.ffn_norm.weight', 'layers.23.feed_forward.w1.weight', 'layers.23.feed_forward.w3.weight', 'layers.23.feed_forward.w2.weight', 'layers.24.attention_norm.weight', 'layers.24.attention.wq.weight', 'layers.24.attention.wk.weight', 'layers.24.attention.wv.weight', 'layers.24.attention.wo.weight', 'layers.24.ffn_norm.weight', 'layers.24.feed_forward.w1.weight', 'layers.24.feed_forward.w3.weight', 'layers.24.feed_forward.w2.weight', 'layers.25.attention_norm.weight', 'layers.25.attention.wq.weight', 'layers.25.attention.wk.weight', 'layers.25.attention.wv.weight', 'layers.25.attention.wo.weight', 'layers.25.ffn_norm.weight', 'layers.25.feed_forward.w1.weight', 'layers.25.feed_forward.w3.weight', 'layers.25.feed_forward.w2.weight', 'layers.26.attention_norm.weight', 'layers.26.attention.wq.weight', 'layers.26.attention.wk.weight', 'layers.26.attention.wv.weight', 'layers.26.attention.wo.weight', 'layers.26.ffn_norm.weight', 'layers.26.feed_forward.w1.weight', 'layers.26.feed_forward.w3.weight', 'layers.26.feed_forward.w2.weight', 'layers.27.attention_norm.weight', 'layers.27.attention.wq.weight', 'layers.27.attention.wk.weight', 'layers.27.attention.wv.weight', 'layers.27.attention.wo.weight', 'layers.27.ffn_norm.weight', 'layers.27.feed_forward.w1.weight', 'layers.27.feed_forward.w3.weight', 'layers.27.feed_forward.w2.weight', 'layers.28.attention_norm.weight', 'layers.28.attention.wq.weight', 'layers.28.attention.wk.weight', 'layers.28.attention.wv.weight', 'layers.28.attention.wo.weight', 'layers.28.ffn_norm.weight', 'layers.28.feed_forward.w1.weight', 'layers.28.feed_forward.w3.weight', 'layers.28.feed_forward.w2.weight', 'layers.29.attention_norm.weight', 'layers.29.attention.wq.weight', 'layers.29.attention.wk.weight', 'layers.29.attention.wv.weight', 'layers.29.attention.wo.weight', 'layers.29.ffn_norm.weight', 'layers.29.feed_forward.w1.weight', 'layers.29.feed_forward.w3.weight', 'layers.29.feed_forward.w2.weight', 'norm.weight', 'output.weight']
inference.py
ADDED
@@ -0,0 +1,126 @@
"""
Inference and Model Loading Utilities
"""

import os
import torch
from torch.nn import functional as F
import tiktoken
from model import SmolLM, SmolLMConfig


def get_device():
    """Auto-detect and return the best available device"""
    print(f"[DEBUG] Checking device availability...")
    print(f"[DEBUG] torch.cuda.is_available(): {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"[DEBUG] CUDA device: {torch.cuda.get_device_name(0)}")
        return 'cuda'
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        print(f"[DEBUG] Using MPS device")
        return "mps"
    else:
        print(f"[DEBUG] Falling back to CPU")
        return 'cpu'


def load_model(model_path=None, device=None):
    """
    Load SmolLM model from checkpoint.

    Args:
        model_path: Path to saved model checkpoint (.pth or .pt file)
        device: Device to load model on (auto-detected if None)

    Returns:
        Loaded model and device
    """
    if device is None:
        device = get_device()

    # Try to load saved checkpoint first
    if model_path and os.path.exists(model_path):
        try:
            print(f"Loading saved model from {model_path}...")
            model = SmolLM.load_checkpoint(model_path, device=device)
            return model, device
        except Exception as e:
            print(f"Failed to load saved model: {e}")

    # Fallback to untrained model
    print("Creating model with default config (untrained)...")
    config = SmolLMConfig()
    model = SmolLM(config)
    model.to(device)
    return model, device


def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the SmolLM model.

    Args:
        prompt: Input text prompt
        model: SmolLM model instance
        max_tokens: Maximum number of tokens to generate
        top_k: Top-k sampling parameter (None for no top-k filtering)
        temperature: Temperature for sampling (higher = more random)
        device: Device to run inference on

    Returns:
        Generated text string (including original prompt)
    """
    # (stub signature only; superseded by the cached-tokenizer version defined below)

# Global tokenizer cache
_TOKENIZER = None

def _get_tokenizer():
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = tiktoken.get_encoding("gpt2")
    return _TOKENIZER

def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the SmolLM model.

    Args:
        prompt: Input text prompt
        model: SmolLM model instance
        max_tokens: Maximum number of tokens to generate
        top_k: Top-k sampling parameter (None for no top-k filtering)
        temperature: Temperature for sampling (higher = more random)
        device: Device to run inference on

    Returns:
        Generated text string (including original prompt)
    """
    enc = _get_tokenizer()
    model.eval()

    with torch.no_grad():
        # tokenize prompt
        input_ids = enc.encode(prompt)
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)
        print(max_tokens)
        past_kv = None
        generated_ids = list(input_ids)

        for _ in range(max_tokens):
            print(x.shape)
            logits, _ = model(x)
            print(logits.shape)
            logits = logits[:, -1, :] / temperature
            print(logits.shape)
            if top_k is not None:
                topk = torch.topk(logits, top_k, dim=-1)
                mask = logits < topk.values[:, -1].unsqueeze(-1)
                logits = logits.masked_fill(mask, -float("inf"))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)

        generated_ids = x[0].tolist()

    # generated_ids already contains the prompt and generated tokens
    return enc.decode(generated_ids)

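A minimal usage sketch for these utilities, assuming the checkpoint path used elsewhere in this repo (the prompt string is illustrative):

from inference import load_model, generate_text

model, device = load_model(model_path='./model/smollm_135_checkpoint.pth')
completion = generate_text("The future of artificial intelligence is",
                           model, max_tokens=30, top_k=50, temperature=0.8,
                           device=device)
print(completion)  # prompt plus the sampled continuation
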
inspect_checkpoint.py
ADDED
@@ -0,0 +1,19 @@
import torch
import sys

def inspect_checkpoint():
    path = './model/smollm_135_checkpoint.pth'
    output_file = 'checkpoint_info.txt'
    with open(output_file, 'w') as f:
        try:
            checkpoint = torch.load(path, map_location='cpu')
            f.write(f"Checkpoint type: {type(checkpoint)}\n")
            if isinstance(checkpoint, dict):
                f.write(f"Keys: {list(checkpoint.keys())}\n")
            else:
                f.write("Checkpoint is not a dictionary.\n")
        except Exception as e:
            f.write(f"Error loading checkpoint: {e}\n")

if __name__ == "__main__":
    inspect_checkpoint()

main.py
ADDED
@@ -0,0 +1,6 @@
def main():
    print("Hello from smollm-135!")


if __name__ == "__main__":
    main()

model.py
ADDED
@@ -0,0 +1,239 @@
"""
SmolLM-135M Implementation (Llama Architecture)
Based on: https://huggingface.co/HuggingFaceTB/SmolLM-135M
"""

import math
import inspect
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

# Configuration for SmolLM-135M
@dataclass
class SmolLMConfig:
    block_size: int = 512  # Reduced to 512 for 4GB GPU training
    vocab_size: int = 50304  # Aligned to 50304 for tiktoken compatibility (SmolLM native is 49152)
    n_layer: int = 30
    n_head: int = 9
    n_kv_head: int = 3  # Grouped Query Attention (GQA)
    n_embd: int = 576
    intermediate_size: int = 1536  # SwiGLU intermediate size
    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    dropout: float = 0.0
    bias: bool = False  # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class CausalSelfAttention(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        self.n_rep = self.n_head // self.n_kv_head

        self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=config.bias)
        self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)
        self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)
        self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=config.bias)

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        B, T, C = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(B, T, self.n_head, self.head_dim)
        xk = xk.view(B, T, self.n_kv_head, self.head_dim)
        xv = xv.view(B, T, self.n_kv_head, self.head_dim)

        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Grouped Query Attention: repeat k/v heads to match q heads
        xk = torch.repeat_interleave(xk, dim=2, repeats=self.n_rep)
        xv = torch.repeat_interleave(xv, dim=2, repeats=self.n_rep)

        # Make heads batch dimension
        xq = xq.transpose(1, 2)  # (B, n_head, T, head_dim)
        xk = xk.transpose(1, 2)  # (B, n_head, T, head_dim)
        xv = xv.transpose(1, 2)  # (B, n_head, T, head_dim)

        if past_kv is not None:
            k_cache, v_cache = past_kv
            xk = torch.cat([k_cache, xk], dim=2)
            xv = torch.cat([v_cache, xv], dim=2)

        current_kv = (xk, xv)

        # Flash Attention
        if past_kv is not None and T == 1:
            # Optimization: no causal mask needed for the last token attending to all previous
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=False)
        else:
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

        output = output.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.wo(output)), current_kv

class SwiGLU(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.w1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)  # Gate
        self.w3 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)  # Value
        self.w2 = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)  # Output
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class Block(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.attention_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.attention = CausalSelfAttention(config)
        self.ffn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.feed_forward = SwiGLU(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        attn_out, layer_kv = self.attention(self.attention_norm(x), freqs_cis, past_kv)
        h = x + attn_out
        out = h + self.feed_forward(self.ffn_norm(h))
        return out, layer_kv

class SmolLM(nn.Module):
    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.output = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing
        self.tok_embeddings.weight = self.output.weight

        # Precompute RoPE frequencies
        self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size * 2, config.rope_theta)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, past_kv=None):
        B, T = idx.shape
        x = self.tok_embeddings(idx)

        # Determine starting position for RoPE
        start_pos = 0

        # Ensure freqs_cis is on the correct device
        if self.freqs_cis.device != x.device:
            self.freqs_cis = self.freqs_cis.to(x.device)

        # Select freqs_cis for the current positions
        freqs_cis = self.freqs_cis[start_pos : start_pos + T]

        new_past_kv = []
        for i, layer in enumerate(self.layers):
            layer_past_kv = past_kv[i] if past_kv is not None else None
            x, layer_kv = layer(x, freqs_cis, past_kv=layer_past_kv)
            new_past_kv.append(layer_kv)

        x = self.norm(x)
        logits = self.output(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        # Note: new_past_kv is collected above but not returned by this forward
        return logits, loss

    def save_checkpoint(self, filepath):
        """Save model checkpoint with config"""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': {
                'block_size': self.config.block_size,
                'vocab_size': self.config.vocab_size,
                'n_layer': self.config.n_layer,
                'n_head': self.config.n_head,
                'n_kv_head': self.config.n_kv_head,
                'n_embd': self.config.n_embd,
                'intermediate_size': self.config.intermediate_size,
                'rms_norm_eps': self.config.rms_norm_eps,
                'rope_theta': self.config.rope_theta,
                'dropout': self.config.dropout,
                'bias': self.config.bias
            }
        }
        torch.save(checkpoint, filepath)
        print(f"Model saved to {filepath}")

    @classmethod
    def load_checkpoint(cls, filepath, device='cpu'):
        """Load model from checkpoint file"""
        checkpoint = torch.load(filepath, map_location=device)

        if isinstance(checkpoint, dict) and 'config' in checkpoint:
            # Checkpoint contains config and state dict
            config_dict = checkpoint['config']
            config = SmolLMConfig(**config_dict)
            state_dict = checkpoint['model_state_dict']
        else:
            # Checkpoint is likely just the state dict
            print("Warning: Checkpoint does not contain config. Using default SmolLMConfig.")
            config = SmolLMConfig()
            state_dict = checkpoint

        model = cls(config)
        model.load_state_dict(state_dict)
        model.to(device)
        print(f"Model loaded from {filepath}")
        return model

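A quick smoke-test sketch for the module above (random token ids, shapes following the SmolLMConfig defaults; the dummy targets exist only to exercise the loss path):

import torch
from model import SmolLM, SmolLMConfig

config = SmolLMConfig()
model = SmolLM(config)

idx = torch.randint(0, config.vocab_size, (2, 16))  # (batch, sequence)
logits, loss = model(idx, targets=idx)              # dummy targets, not shifted next-token labels
print(logits.shape)                                 # torch.Size([2, 16, 50304])
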
model/smollm_135_checkpoint.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:691cbc7dafd54bde47470447eb33f27ab2f6ac5e77cd45b71942aceeac678b85
size 540826259

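The checkpoint itself is stored via Git LFS; the three lines above are the LFS pointer committed in its place. Loading it goes through the classmethod defined in model.py; a minimal sketch:

from model import SmolLM

# raw state_dict checkpoint (no embedded config), so load_checkpoint falls back to the default SmolLMConfig
model = SmolLM.load_checkpoint('./model/smollm_135_checkpoint.pth', device='cpu')
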
profile_app.py
ADDED
@@ -0,0 +1,33 @@
import time
import torch
from inference import load_model, generate_text, get_device

def profile_app():
    print("Profiling app performance...")

    # 1. Measure Model Loading
    start_load = time.time()
    checkpoint_path = './model/smollm_135_checkpoint.pth'
    model, device = load_model(model_path=checkpoint_path)
    print(device)
    end_load = time.time()
    print(f"Model loading time: {end_load - start_load:.4f}s")
    print(f"Device: {device}")

    # 2. Measure Generation
    prompt = "The future of AI is"
    max_tokens = 50

    start_gen = time.time()
    output = generate_text(prompt, model, max_tokens=max_tokens, device=device)
    end_gen = time.time()

    duration = end_gen - start_gen
    tokens_per_sec = max_tokens / duration

    print(f"Generation time: {duration:.4f}s")
    print(f"Tokens per second: {tokens_per_sec:.2f}")
    print(f"Output: {output}")

if __name__ == "__main__":
    profile_app()

pyproject.toml
ADDED
@@ -0,0 +1,12 @@
[project]
name = "smollm-135"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
    "gradio>=6.0.1",
    "tiktoken>=0.12.0",
    "torch>=2.9.1",
    "transformers>=4.57.3",
]

requirements.txt
ADDED
@@ -0,0 +1,4 @@
gradio>=4.0.0
torch==2.0.0
transformers>=4.30.0
tiktoken>=0.5.0

test_inference.py
ADDED
@@ -0,0 +1,25 @@
from inference import load_model, generate_text
import os

def test_inference():
    print("Testing inference...")

    checkpoint_path = './model/smollm_135_checkpoint.pth'
    if not os.path.exists(checkpoint_path):
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        checkpoint_path = None

    # Load model
    model, device = load_model(model_path=checkpoint_path)
    print(f"Model loaded on {device}")

    # Generate text
    prompt = "Hello, world"
    print(f"Generating text for prompt: '{prompt}'")
    generated = generate_text(prompt, model, max_tokens=10, device=device)
    print(f"Generated text: {generated}")

    print("Inference test passed.")

if __name__ == "__main__":
    test_inference()

test_kv_cache.py
ADDED
@@ -0,0 +1,45 @@
import torch
import time
from model import SmolLM, SmolLMConfig
from inference import load_model

def test_kv_performance():
    print("Testing KV cache performance...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Load model (untrained is fine for performance test)
    config = SmolLMConfig()
    model = SmolLM(config).to(device)
    model.eval()

    prompt_len = 10
    gen_len = 50
    input_ids = torch.randint(0, config.vocab_size, (1, prompt_len)).to(device)

    # Measure generation with KV cache
    start_time = time.time()
    past_kv = None
    x = input_ids

    with torch.no_grad():
        # Prefill
        # (the three-value unpacking expects forward() to return (logits, loss, past_kv);
        #  the forward in model.py above returns only (logits, loss))
        _, _, past_kv = model(x)

        # Generate
        for _ in range(gen_len):
            model_input = x[:, -1:]
            _, _, past_kv = model(model_input, past_kv=past_kv)
            # Dummy token selection
            next_token = torch.tensor([[0]], device=device)
            x = torch.cat([x, next_token], dim=1)

    end_time = time.time()
    duration = end_time - start_time
    tokens_per_sec = gen_len / duration

    print(f"Generated {gen_len} tokens in {duration:.4f}s")
    print(f"Speed: {tokens_per_sec:.2f} tokens/sec")

if __name__ == "__main__":
    test_kv_performance()

test_model.py
ADDED
@@ -0,0 +1,22 @@
import torch
from model import SmolLM, SmolLMConfig

def test_model():
    print("Testing SmolLM model...")
    config = SmolLMConfig()
    model = SmolLM(config)
    print("Model instantiated successfully.")

    # Create dummy input
    idx = torch.randint(0, config.vocab_size, (1, 128))
    print(f"Input shape: {idx.shape}")

    # Forward pass
    logits, loss, _ = model(idx)
    print(f"Logits shape: {logits.shape}")

    assert logits.shape == (1, 128, config.vocab_size)
    print("Forward pass successful.")

if __name__ == "__main__":
    test_model()

test_tiktoken.py
ADDED
@@ -0,0 +1,16 @@
import time
import tiktoken

def test_tiktoken():
    start = time.time()
    enc = tiktoken.get_encoding("gpt2")
    end = time.time()
    print(f"tiktoken load time: {end - start:.4f}s")

    start = time.time()
    enc.encode("Hello world")
    end = time.time()
    print(f"encoding time: {end - start:.4f}s")

if __name__ == "__main__":
    test_tiktoken()