Krishnakanth1993 commited on
Commit
8345416
·
1 Parent(s): 32ba924

Initial commit

Browse files
.gitignore ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual environments
24
+ venv/
25
+ env/
26
+ ENV/
27
+ .venv
28
+
29
+ # IDE
30
+ .vscode/
31
+ .idea/
32
+ *.swp
33
+ *.swo
34
+ *~
35
+
36
+
37
+
38
+ # Data files
39
+ input.txt
40
+ *.csv
41
+ *.json
42
+
43
+ # Jupyter Notebook
44
+ .ipynb_checkpoints/
45
+ *.ipynb_checkpoints/
46
+
47
+ # OS
48
+ .DS_Store
49
+ Thumbs.db
50
+
51
+ # Logs
52
+ *.log
53
+ logs/
54
+
55
+ # Hugging Face cache
56
+ .cache/
57
+ hf_cache/
58
+
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
README.md CHANGED
@@ -12,3 +12,11 @@ short_description: Smollm-135 base model trained with dummy data with speedups
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+ ## Running Locally
17
+
18
+ To run the application locally, use the following command:
19
+
20
+ ```bash
21
+ uv run app.py
22
+ ```
Smollm_135.ipynb ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "d8ad4585",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "using device: cuda\n",
14
+ "Model parameters: 135.18M\n",
15
+ "loaded 338025 tokens\n",
16
+ "1 epoch = 165 batches\n",
17
+ "Starting training...\n",
18
+ "step 0 | loss: 10.8725 | dt: 1274.74ms | tok/sec: 1606.61\n",
19
+ "step 10 | loss: 7.9941 | dt: 2626.35ms | tok/sec: 779.79\n",
20
+ "step 20 | loss: 6.9462 | dt: 2627.22ms | tok/sec: 779.53\n",
21
+ "step 30 | loss: 6.7740 | dt: 3246.48ms | tok/sec: 630.84\n",
22
+ "step 40 | loss: 6.9956 | dt: 2638.15ms | tok/sec: 776.30\n",
23
+ "step 50 | loss: 6.8735 | dt: 2615.98ms | tok/sec: 782.88\n",
24
+ "step 60 | loss: 6.5163 | dt: 2640.97ms | tok/sec: 775.47\n",
25
+ "step 70 | loss: 6.5162 | dt: 2637.36ms | tok/sec: 776.53\n",
26
+ "step 80 | loss: 6.4836 | dt: 2661.23ms | tok/sec: 769.57\n",
27
+ "step 90 | loss: 6.5255 | dt: 2643.36ms | tok/sec: 774.77\n",
28
+ "step 100 | loss: 6.2876 | dt: 2677.88ms | tok/sec: 764.78\n",
29
+ "step 110 | loss: 6.5272 | dt: 2666.03ms | tok/sec: 768.18\n",
30
+ "step 120 | loss: 6.1898 | dt: 2824.14ms | tok/sec: 725.18\n",
31
+ "step 130 | loss: 6.1057 | dt: 2646.03ms | tok/sec: 773.99\n",
32
+ "step 140 | loss: 5.9576 | dt: 2632.82ms | tok/sec: 777.87\n",
33
+ "step 150 | loss: 6.4212 | dt: 2637.88ms | tok/sec: 776.38\n",
34
+ "step 160 | loss: 6.4857 | dt: 2641.72ms | tok/sec: 775.25\n",
35
+ "step 170 | loss: 6.2381 | dt: 2624.55ms | tok/sec: 780.32\n",
36
+ "step 180 | loss: 5.6074 | dt: 2650.12ms | tok/sec: 772.80\n",
37
+ "step 190 | loss: 6.0601 | dt: 2657.61ms | tok/sec: 770.62\n",
38
+ "step 200 | loss: 5.5407 | dt: 2856.74ms | tok/sec: 716.90\n",
39
+ "step 210 | loss: 5.8250 | dt: 2647.57ms | tok/sec: 773.54\n",
40
+ "step 220 | loss: 6.0356 | dt: 2635.08ms | tok/sec: 777.21\n",
41
+ "step 230 | loss: 5.7742 | dt: 2637.31ms | tok/sec: 776.55\n",
42
+ "step 240 | loss: 5.8564 | dt: 2645.37ms | tok/sec: 774.18\n",
43
+ "step 250 | loss: 5.4802 | dt: 2660.91ms | tok/sec: 769.66\n",
44
+ "step 260 | loss: 5.6751 | dt: 2632.88ms | tok/sec: 777.85\n",
45
+ "step 270 | loss: 5.9273 | dt: 2733.56ms | tok/sec: 749.21\n",
46
+ "step 280 | loss: 5.9138 | dt: 2626.31ms | tok/sec: 779.80\n",
47
+ "step 290 | loss: 5.6861 | dt: 2638.67ms | tok/sec: 776.15\n",
48
+ "step 300 | loss: 5.2012 | dt: 2642.88ms | tok/sec: 774.91\n",
49
+ "step 310 | loss: 5.6114 | dt: 2649.90ms | tok/sec: 772.86\n",
50
+ "step 320 | loss: 5.0033 | dt: 2688.08ms | tok/sec: 761.88\n",
51
+ "step 330 | loss: 5.6259 | dt: 2682.14ms | tok/sec: 763.57\n",
52
+ "step 340 | loss: 5.1127 | dt: 2650.79ms | tok/sec: 772.60\n",
53
+ "step 350 | loss: 5.3045 | dt: 2678.00ms | tok/sec: 764.75\n",
54
+ "step 360 | loss: 5.2118 | dt: 2666.44ms | tok/sec: 768.06\n",
55
+ "step 370 | loss: 5.4723 | dt: 2639.73ms | tok/sec: 775.84\n",
56
+ "step 380 | loss: 5.4257 | dt: 2653.49ms | tok/sec: 771.81\n",
57
+ "step 390 | loss: 5.0813 | dt: 2623.92ms | tok/sec: 780.51\n",
58
+ "step 400 | loss: 5.0538 | dt: 2637.88ms | tok/sec: 776.38\n",
59
+ "step 410 | loss: 5.0351 | dt: 2708.26ms | tok/sec: 756.20\n",
60
+ "step 420 | loss: 5.0659 | dt: 2661.11ms | tok/sec: 769.60\n",
61
+ "step 430 | loss: 4.9364 | dt: 2685.40ms | tok/sec: 762.64\n",
62
+ "step 440 | loss: 5.3093 | dt: 2632.10ms | tok/sec: 778.08\n",
63
+ "step 450 | loss: 5.0675 | dt: 2653.00ms | tok/sec: 771.96\n",
64
+ "step 460 | loss: 4.9092 | dt: 2668.73ms | tok/sec: 767.41\n",
65
+ "step 470 | loss: 4.5785 | dt: 2700.87ms | tok/sec: 758.28\n",
66
+ "step 480 | loss: 5.1269 | dt: 2705.93ms | tok/sec: 756.86\n",
67
+ "step 490 | loss: 5.4856 | dt: 2653.12ms | tok/sec: 771.92\n",
68
+ "step 500 | loss: 5.2739 | dt: 2679.78ms | tok/sec: 764.24\n",
69
+ "\n",
70
+ "--- Generating text at step 500 ---\n",
71
+ "!ua any cons mocking\n",
72
+ "I they have we be passion so his all\n",
73
+ "Redpt' grip:' the grave,An thy work!\n",
74
+ "For embr, old Capitol, goodout than coming,\n",
75
+ "When thet vir Rome daughter to the chance\n",
76
+ "-----------------------------------\n",
77
+ "\n",
78
+ "step 510 | loss: 4.6435 | dt: 2697.87ms | tok/sec: 759.12\n",
79
+ "step 520 | loss: 4.9901 | dt: 2688.89ms | tok/sec: 761.65\n",
80
+ "step 530 | loss: 4.5844 | dt: 2703.14ms | tok/sec: 757.64\n",
81
+ "step 540 | loss: 4.9779 | dt: 2663.58ms | tok/sec: 768.89\n",
82
+ "step 550 | loss: 5.1966 | dt: 2714.69ms | tok/sec: 754.42\n",
83
+ "step 560 | loss: 4.9851 | dt: 2695.66ms | tok/sec: 759.74\n",
84
+ "step 570 | loss: 5.2129 | dt: 2698.44ms | tok/sec: 758.96\n",
85
+ "step 580 | loss: 4.7397 | dt: 2714.69ms | tok/sec: 754.41\n",
86
+ "step 590 | loss: 4.9552 | dt: 2692.72ms | tok/sec: 760.57\n",
87
+ "step 600 | loss: 5.2714 | dt: 2713.71ms | tok/sec: 754.69\n",
88
+ "step 610 | loss: 5.2913 | dt: 2653.37ms | tok/sec: 771.85\n",
89
+ "step 620 | loss: 5.0200 | dt: 2701.37ms | tok/sec: 758.13\n",
90
+ "step 630 | loss: 4.3609 | dt: 2689.80ms | tok/sec: 761.39\n",
91
+ "step 640 | loss: 4.9107 | dt: 2769.70ms | tok/sec: 739.43\n",
92
+ "step 650 | loss: 4.3624 | dt: 2707.07ms | tok/sec: 756.54\n",
93
+ "step 660 | loss: 5.0022 | dt: 2720.41ms | tok/sec: 752.83\n",
94
+ "step 670 | loss: 4.4913 | dt: 2680.64ms | tok/sec: 764.00\n",
95
+ "step 680 | loss: 4.7648 | dt: 2673.28ms | tok/sec: 766.10\n",
96
+ "step 690 | loss: 4.6267 | dt: 2671.13ms | tok/sec: 766.72\n",
97
+ "step 700 | loss: 4.8468 | dt: 2683.00ms | tok/sec: 763.33\n",
98
+ "step 710 | loss: 4.8544 | dt: 2678.93ms | tok/sec: 764.48\n",
99
+ "step 720 | loss: 4.5148 | dt: 2698.22ms | tok/sec: 759.02\n",
100
+ "step 730 | loss: 4.4280 | dt: 2694.01ms | tok/sec: 760.21\n",
101
+ "step 740 | loss: 4.4265 | dt: 2681.62ms | tok/sec: 763.72\n",
102
+ "step 750 | loss: 4.4757 | dt: 2671.03ms | tok/sec: 766.75\n",
103
+ "step 760 | loss: 4.3867 | dt: 2712.16ms | tok/sec: 755.12\n",
104
+ "step 770 | loss: 4.8252 | dt: 2681.40ms | tok/sec: 763.78\n",
105
+ "step 780 | loss: 4.6916 | dt: 2861.17ms | tok/sec: 715.79\n",
106
+ "step 790 | loss: 4.3555 | dt: 2682.62ms | tok/sec: 763.43\n",
107
+ "step 800 | loss: 4.0581 | dt: 2695.10ms | tok/sec: 759.90\n",
108
+ "step 810 | loss: 4.5024 | dt: 2718.02ms | tok/sec: 753.49\n",
109
+ "step 820 | loss: 4.9491 | dt: 2688.01ms | tok/sec: 761.90\n",
110
+ "step 830 | loss: 4.7404 | dt: 2688.48ms | tok/sec: 761.77\n",
111
+ "step 840 | loss: 4.1571 | dt: 2683.11ms | tok/sec: 763.29\n",
112
+ "step 850 | loss: 4.2970 | dt: 2673.24ms | tok/sec: 766.11\n",
113
+ "step 860 | loss: 4.1351 | dt: 2673.66ms | tok/sec: 765.99\n",
114
+ "step 870 | loss: 4.5339 | dt: 2723.79ms | tok/sec: 751.89\n",
115
+ "step 880 | loss: 4.7270 | dt: 2655.06ms | tok/sec: 771.36\n",
116
+ "step 890 | loss: 4.5174 | dt: 2654.60ms | tok/sec: 771.49\n",
117
+ "step 900 | loss: 4.7254 | dt: 2671.42ms | tok/sec: 766.63\n",
118
+ "step 910 | loss: 4.2173 | dt: 2706.20ms | tok/sec: 756.78\n",
119
+ "step 920 | loss: 4.4660 | dt: 2687.33ms | tok/sec: 762.09\n",
120
+ "step 930 | loss: 4.8292 | dt: 2668.57ms | tok/sec: 767.45\n",
121
+ "step 940 | loss: 4.8714 | dt: 2681.79ms | tok/sec: 763.67\n",
122
+ "step 950 | loss: 4.5977 | dt: 2689.44ms | tok/sec: 761.50\n",
123
+ "step 960 | loss: 3.9432 | dt: 2693.87ms | tok/sec: 760.24\n",
124
+ "step 970 | loss: 4.4277 | dt: 2685.85ms | tok/sec: 762.51\n",
125
+ "step 980 | loss: 3.9359 | dt: 2665.88ms | tok/sec: 768.23\n",
126
+ "step 990 | loss: 4.6091 | dt: 2693.13ms | tok/sec: 760.45\n",
127
+ "step 1000 | loss: 4.0868 | dt: 2702.22ms | tok/sec: 757.90\n",
128
+ "\n",
129
+ "--- Generating text at step 1000 ---\n",
130
+ "! then he did you\n",
131
+ "And pray them have be devour carries meet\n",
132
+ "And dare man, general's true.\n",
133
+ "\n",
134
+ "LidINIUS:\n",
135
+ "'I can I hope with you were access with call me.\n",
136
+ "\n",
137
+ "HONAN\n",
138
+ "-----------------------------------\n",
139
+ "\n",
140
+ "step 1010 | loss: 4.4185 | dt: 2691.11ms | tok/sec: 761.03\n",
141
+ "step 1020 | loss: 4.2212 | dt: 2687.50ms | tok/sec: 762.05\n",
142
+ "step 1030 | loss: 4.4616 | dt: 2670.47ms | tok/sec: 766.91\n",
143
+ "step 1040 | loss: 4.4004 | dt: 2662.88ms | tok/sec: 769.09\n",
144
+ "step 1050 | loss: 4.1682 | dt: 2705.36ms | tok/sec: 757.02\n",
145
+ "step 1060 | loss: 4.0242 | dt: 2722.96ms | tok/sec: 752.12\n",
146
+ "step 1070 | loss: 4.0509 | dt: 2717.44ms | tok/sec: 753.65\n",
147
+ "step 1080 | loss: 4.0544 | dt: 2667.92ms | tok/sec: 767.64\n",
148
+ "step 1090 | loss: 4.0126 | dt: 2740.84ms | tok/sec: 747.22\n",
149
+ "step 1100 | loss: 4.4749 | dt: 2762.05ms | tok/sec: 741.48\n",
150
+ "step 1110 | loss: 4.3578 | dt: 2634.24ms | tok/sec: 777.45\n",
151
+ "step 1120 | loss: 4.0779 | dt: 2677.47ms | tok/sec: 764.90\n",
152
+ "step 1130 | loss: 3.7411 | dt: 2685.44ms | tok/sec: 762.63\n",
153
+ "step 1140 | loss: 4.1268 | dt: 2690.60ms | tok/sec: 761.17\n",
154
+ "step 1150 | loss: 4.5661 | dt: 2719.96ms | tok/sec: 752.95\n",
155
+ "step 1160 | loss: 4.3289 | dt: 2666.21ms | tok/sec: 768.13\n",
156
+ "step 1170 | loss: 3.8129 | dt: 2671.28ms | tok/sec: 766.67\n",
157
+ "step 1180 | loss: 3.8706 | dt: 2717.71ms | tok/sec: 753.58\n",
158
+ "step 1190 | loss: 3.8226 | dt: 2669.59ms | tok/sec: 767.16\n",
159
+ "step 1200 | loss: 4.1762 | dt: 2666.23ms | tok/sec: 768.13\n",
160
+ "step 1210 | loss: 4.3757 | dt: 2708.58ms | tok/sec: 756.12\n",
161
+ "step 1220 | loss: 4.1351 | dt: 2685.38ms | tok/sec: 762.65\n",
162
+ "step 1230 | loss: 4.3564 | dt: 2687.16ms | tok/sec: 762.14\n",
163
+ "step 1240 | loss: 3.7988 | dt: 2675.93ms | tok/sec: 765.34\n",
164
+ "step 1250 | loss: 4.1403 | dt: 2686.86ms | tok/sec: 762.23\n",
165
+ "step 1260 | loss: 4.4226 | dt: 2667.90ms | tok/sec: 767.64\n",
166
+ "step 1270 | loss: 4.4887 | dt: 2739.30ms | tok/sec: 747.64\n",
167
+ "step 1280 | loss: 4.2127 | dt: 2691.33ms | tok/sec: 760.96\n",
168
+ "step 1290 | loss: 3.6873 | dt: 2672.71ms | tok/sec: 766.26\n",
169
+ "step 1300 | loss: 4.0008 | dt: 2696.75ms | tok/sec: 759.43\n",
170
+ "step 1310 | loss: 3.6128 | dt: 2662.96ms | tok/sec: 769.07\n",
171
+ "step 1320 | loss: 4.2900 | dt: 2656.27ms | tok/sec: 771.01\n",
172
+ "step 1330 | loss: 3.8012 | dt: 2660.07ms | tok/sec: 769.91\n",
173
+ "step 1340 | loss: 4.1523 | dt: 2642.03ms | tok/sec: 775.16\n",
174
+ "step 1350 | loss: 3.9300 | dt: 2683.45ms | tok/sec: 763.20\n",
175
+ "step 1360 | loss: 4.1542 | dt: 2639.64ms | tok/sec: 775.86\n",
176
+ "step 1370 | loss: 4.0666 | dt: 2660.70ms | tok/sec: 769.72\n",
177
+ "step 1380 | loss: 3.8962 | dt: 2655.70ms | tok/sec: 771.17\n",
178
+ "step 1390 | loss: 3.7736 | dt: 2637.29ms | tok/sec: 776.55\n",
179
+ "step 1400 | loss: 3.7997 | dt: 2647.38ms | tok/sec: 773.59\n",
180
+ "step 1410 | loss: 3.7525 | dt: 2629.39ms | tok/sec: 778.89\n",
181
+ "step 1420 | loss: 3.6936 | dt: 2688.40ms | tok/sec: 761.79\n",
182
+ "step 1430 | loss: 4.1674 | dt: 2638.92ms | tok/sec: 776.07\n",
183
+ "step 1440 | loss: 4.1046 | dt: 2648.60ms | tok/sec: 773.24\n",
184
+ "step 1450 | loss: 3.8131 | dt: 2636.89ms | tok/sec: 776.67\n",
185
+ "step 1460 | loss: 3.4866 | dt: 2659.35ms | tok/sec: 770.11\n",
186
+ "step 1470 | loss: 3.7860 | dt: 2638.77ms | tok/sec: 776.12\n",
187
+ "step 1480 | loss: 4.1924 | dt: 2632.46ms | tok/sec: 777.98\n",
188
+ "step 1490 | loss: 4.0126 | dt: 2635.47ms | tok/sec: 777.09\n",
189
+ "step 1500 | loss: 3.4828 | dt: 2633.86ms | tok/sec: 777.57\n",
190
+ "\n",
191
+ "--- Generating text at step 1500 ---\n",
192
+ "! we rid,\n",
193
+ "my thing was so done.\n",
194
+ "\n",
195
+ "GLUUS:\n",
196
+ "F mercy in mocked itUS;\n",
197
+ "But rest at Rome,iances for him: but\n",
198
+ "Masters are worse, proud to the people,\n",
199
+ "And for your\n",
200
+ "-----------------------------------\n",
201
+ "\n",
202
+ "step 1510 | loss: 3.6085 | dt: 2699.83ms | tok/sec: 758.57\n",
203
+ "step 1520 | loss: 3.5473 | dt: 2703.50ms | tok/sec: 757.54\n",
204
+ "step 1530 | loss: 3.8369 | dt: 2704.69ms | tok/sec: 757.20\n",
205
+ "step 1540 | loss: 4.0431 | dt: 2699.75ms | tok/sec: 758.59\n",
206
+ "step 1550 | loss: 3.8126 | dt: 2697.98ms | tok/sec: 759.09\n",
207
+ "step 1560 | loss: 4.0581 | dt: 2708.62ms | tok/sec: 756.10\n",
208
+ "step 1570 | loss: 3.4292 | dt: 2701.01ms | tok/sec: 758.23\n",
209
+ "step 1580 | loss: 3.8859 | dt: 2701.13ms | tok/sec: 758.20\n",
210
+ "step 1590 | loss: 4.0481 | dt: 2708.01ms | tok/sec: 756.27\n",
211
+ "step 1600 | loss: 4.1384 | dt: 2723.99ms | tok/sec: 751.84\n",
212
+ "step 1610 | loss: 3.9161 | dt: 2718.54ms | tok/sec: 753.35\n",
213
+ "step 1620 | loss: 3.3759 | dt: 2728.75ms | tok/sec: 750.53\n",
214
+ "step 1630 | loss: 3.6518 | dt: 2700.32ms | tok/sec: 758.43\n",
215
+ "step 1640 | loss: 3.3208 | dt: 2696.71ms | tok/sec: 759.44\n",
216
+ "step 1650 | loss: 3.9156 | dt: 2707.93ms | tok/sec: 756.30\n",
217
+ "step 1660 | loss: 3.5364 | dt: 2725.21ms | tok/sec: 751.50\n",
218
+ "step 1670 | loss: 3.8675 | dt: 2708.14ms | tok/sec: 756.24\n",
219
+ "step 1680 | loss: 3.6225 | dt: 2702.13ms | tok/sec: 757.92\n",
220
+ "step 1690 | loss: 3.8710 | dt: 2694.93ms | tok/sec: 759.95\n",
221
+ "step 1700 | loss: 3.7588 | dt: 2696.90ms | tok/sec: 759.39\n",
222
+ "step 1710 | loss: 3.6354 | dt: 2727.15ms | tok/sec: 750.97\n",
223
+ "step 1720 | loss: 3.5004 | dt: 2710.80ms | tok/sec: 755.50\n",
224
+ "step 1730 | loss: 3.5569 | dt: 2736.50ms | tok/sec: 748.40\n",
225
+ "step 1740 | loss: 3.4937 | dt: 2700.68ms | tok/sec: 758.33\n",
226
+ "step 1750 | loss: 3.4585 | dt: 2704.23ms | tok/sec: 757.33\n",
227
+ "step 1760 | loss: 3.8912 | dt: 2718.99ms | tok/sec: 753.22\n",
228
+ "step 1770 | loss: 3.9121 | dt: 2759.54ms | tok/sec: 742.15\n",
229
+ "step 1780 | loss: 3.5978 | dt: 2701.65ms | tok/sec: 758.06\n",
230
+ "step 1790 | loss: 3.2438 | dt: 2705.15ms | tok/sec: 757.08\n",
231
+ "step 1800 | loss: 3.4297 | dt: 2697.91ms | tok/sec: 759.11\n",
232
+ "step 1810 | loss: 3.8631 | dt: 2723.87ms | tok/sec: 751.87\n",
233
+ "step 1820 | loss: 3.7029 | dt: 2717.25ms | tok/sec: 753.70\n",
234
+ "step 1830 | loss: 3.2425 | dt: 2727.89ms | tok/sec: 750.76\n",
235
+ "step 1840 | loss: 3.3572 | dt: 2722.87ms | tok/sec: 752.15\n",
236
+ "step 1850 | loss: 3.2716 | dt: 2707.57ms | tok/sec: 756.40\n",
237
+ "step 1860 | loss: 3.5134 | dt: 2699.27ms | tok/sec: 758.73\n",
238
+ "step 1870 | loss: 3.7097 | dt: 2723.70ms | tok/sec: 751.92\n",
239
+ "step 1880 | loss: 3.5368 | dt: 2721.65ms | tok/sec: 752.48\n",
240
+ "step 1890 | loss: 3.7578 | dt: 2698.47ms | tok/sec: 758.95\n",
241
+ "step 1900 | loss: 3.2360 | dt: 2703.24ms | tok/sec: 757.61\n",
242
+ "step 1910 | loss: 3.5294 | dt: 2695.89ms | tok/sec: 759.68\n",
243
+ "step 1920 | loss: 3.7359 | dt: 2723.54ms | tok/sec: 751.96\n",
244
+ "step 1930 | loss: 3.8138 | dt: 2696.43ms | tok/sec: 759.52\n",
245
+ "step 1940 | loss: 3.6550 | dt: 2704.04ms | tok/sec: 757.38\n",
246
+ "step 1950 | loss: 3.0794 | dt: 2704.37ms | tok/sec: 757.29\n",
247
+ "step 1960 | loss: 3.3290 | dt: 2702.23ms | tok/sec: 757.89\n",
248
+ "step 1970 | loss: 3.0204 | dt: 2734.40ms | tok/sec: 748.98\n",
249
+ "step 1980 | loss: 3.5516 | dt: 2703.91ms | tok/sec: 757.42\n",
250
+ "step 1990 | loss: 3.3070 | dt: 2706.01ms | tok/sec: 756.83\n",
251
+ "step 2000 | loss: 3.5758 | dt: 2726.00ms | tok/sec: 751.28\n",
252
+ "\n",
253
+ "--- Generating text at step 2000 ---\n",
254
+ "!,\n",
255
+ "if I put thee due by! see I do not weep\n",
256
+ "Thy thing in this heels hits them,\n",
257
+ "That neverYour valour should be in pride,\n",
258
+ "Call him more deity tears;Signiorio\n",
259
+ "In time of tears\n",
260
+ "-----------------------------------\n",
261
+ "\n",
262
+ "step 2010 | loss: 3.3916 | dt: 2700.63ms | tok/sec: 758.34\n",
263
+ "step 2020 | loss: 3.5364 | dt: 2709.72ms | tok/sec: 755.80\n",
264
+ "step 2030 | loss: 3.4518 | dt: 2708.71ms | tok/sec: 756.08\n",
265
+ "step 2040 | loss: 3.2968 | dt: 2699.77ms | tok/sec: 758.58\n",
266
+ "step 2050 | loss: 3.1897 | dt: 2705.14ms | tok/sec: 757.08\n",
267
+ "step 2060 | loss: 3.2544 | dt: 2695.59ms | tok/sec: 759.76\n",
268
+ "step 2070 | loss: 3.1843 | dt: 2699.45ms | tok/sec: 758.67\n",
269
+ "step 2080 | loss: 3.1527 | dt: 2697.34ms | tok/sec: 759.27\n",
270
+ "step 2090 | loss: 3.5361 | dt: 2698.70ms | tok/sec: 758.88\n",
271
+ "step 2100 | loss: 3.6072 | dt: 2698.23ms | tok/sec: 759.02\n",
272
+ "step 2110 | loss: 3.2871 | dt: 2697.04ms | tok/sec: 759.35\n",
273
+ "step 2120 | loss: 2.9393 | dt: 2715.87ms | tok/sec: 754.09\n",
274
+ "step 2130 | loss: 3.0962 | dt: 2706.66ms | tok/sec: 756.65\n",
275
+ "step 2140 | loss: 3.4452 | dt: 2705.67ms | tok/sec: 756.93\n",
276
+ "step 2150 | loss: 3.2900 | dt: 2707.68ms | tok/sec: 756.37\n",
277
+ "step 2160 | loss: 2.9626 | dt: 2707.08ms | tok/sec: 756.54\n",
278
+ "step 2170 | loss: 3.0681 | dt: 2702.33ms | tok/sec: 757.86\n",
279
+ "step 2180 | loss: 2.9747 | dt: 2702.98ms | tok/sec: 757.68\n",
280
+ "step 2190 | loss: 3.1063 | dt: 2716.67ms | tok/sec: 753.87\n",
281
+ "step 2200 | loss: 3.3207 | dt: 2703.48ms | tok/sec: 757.54\n",
282
+ "step 2210 | loss: 3.1918 | dt: 2701.67ms | tok/sec: 758.05\n",
283
+ "step 2220 | loss: 3.3162 | dt: 2703.80ms | tok/sec: 757.45\n",
284
+ "step 2230 | loss: 2.8865 | dt: 2705.55ms | tok/sec: 756.96\n",
285
+ "step 2240 | loss: 3.1759 | dt: 2717.99ms | tok/sec: 753.50\n",
286
+ "step 2250 | loss: 3.3846 | dt: 2703.64ms | tok/sec: 757.50\n",
287
+ "step 2260 | loss: 3.3697 | dt: 2732.82ms | tok/sec: 749.41\n",
288
+ "step 2270 | loss: 3.2227 | dt: 2703.69ms | tok/sec: 757.48\n",
289
+ "step 2280 | loss: 2.8169 | dt: 2706.62ms | tok/sec: 756.66\n",
290
+ "step 2290 | loss: 2.9860 | dt: 2700.74ms | tok/sec: 758.31\n",
291
+ "step 2300 | loss: 2.7406 | dt: 2711.47ms | tok/sec: 755.31\n",
292
+ "step 2310 | loss: 3.1365 | dt: 2706.79ms | tok/sec: 756.62\n",
293
+ "step 2320 | loss: 2.9496 | dt: 2699.63ms | tok/sec: 758.62\n",
294
+ "step 2330 | loss: 3.2225 | dt: 2702.71ms | tok/sec: 757.76\n",
295
+ "step 2340 | loss: 3.0330 | dt: 2710.74ms | tok/sec: 755.51\n",
296
+ "step 2350 | loss: 3.1792 | dt: 2704.27ms | tok/sec: 757.32\n",
297
+ "step 2360 | loss: 3.0794 | dt: 2709.94ms | tok/sec: 755.74\n",
298
+ "step 2370 | loss: 2.9420 | dt: 2697.96ms | tok/sec: 759.09\n",
299
+ "step 2380 | loss: 2.8587 | dt: 2695.43ms | tok/sec: 759.80\n",
300
+ "step 2390 | loss: 2.8747 | dt: 2695.93ms | tok/sec: 759.66\n",
301
+ "step 2400 | loss: 2.8319 | dt: 2711.99ms | tok/sec: 755.16\n",
302
+ "step 2410 | loss: 2.8368 | dt: 2736.55ms | tok/sec: 748.39\n",
303
+ "step 2420 | loss: 3.1382 | dt: 2693.63ms | tok/sec: 760.31\n",
304
+ "step 2430 | loss: 3.2540 | dt: 2703.97ms | tok/sec: 757.41\n",
305
+ "step 2440 | loss: 2.8659 | dt: 2702.61ms | tok/sec: 757.79\n",
306
+ "step 2450 | loss: 2.6254 | dt: 2700.30ms | tok/sec: 758.43\n",
307
+ "step 2460 | loss: 2.7329 | dt: 2706.12ms | tok/sec: 756.80\n",
308
+ "step 2470 | loss: 3.0561 | dt: 2704.39ms | tok/sec: 757.29\n",
309
+ "step 2480 | loss: 2.8807 | dt: 2705.39ms | tok/sec: 757.01\n",
310
+ "step 2490 | loss: 2.6715 | dt: 2699.49ms | tok/sec: 758.66\n",
311
+ "step 2500 | loss: 2.6661 | dt: 2701.31ms | tok/sec: 758.15\n",
312
+ "\n",
313
+ "--- Generating text at step 2500 ---\n",
314
+ "! wasBoy\n",
315
+ "'s one heinous to purchase with us and fold'd\n",
316
+ "Than is of lips for your time his son,\n",
317
+ "Which my conscience did forsworn,\n",
318
+ "It like a friendly that owe'd their go\n",
319
+ "But, afterced his\n",
320
+ "-----------------------------------\n",
321
+ "\n",
322
+ "step 2510 | loss: 2.6807 | dt: 2698.79ms | tok/sec: 758.86\n",
323
+ "step 2520 | loss: 2.8045 | dt: 2696.69ms | tok/sec: 759.45\n",
324
+ "step 2530 | loss: 2.9238 | dt: 2701.84ms | tok/sec: 758.00\n",
325
+ "step 2540 | loss: 2.7829 | dt: 2696.82ms | tok/sec: 759.41\n",
326
+ "step 2550 | loss: 2.8688 | dt: 2713.57ms | tok/sec: 754.72\n",
327
+ "step 2560 | loss: 2.5317 | dt: 2700.27ms | tok/sec: 758.44\n",
328
+ "step 2570 | loss: 2.8389 | dt: 2694.92ms | tok/sec: 759.95\n",
329
+ "step 2580 | loss: 2.8973 | dt: 2700.89ms | tok/sec: 758.27\n",
330
+ "step 2590 | loss: 2.9376 | dt: 2701.52ms | tok/sec: 758.09\n",
331
+ "step 2600 | loss: 2.8337 | dt: 2734.45ms | tok/sec: 748.96\n",
332
+ "step 2610 | loss: 2.3944 | dt: 2696.95ms | tok/sec: 759.38\n",
333
+ "step 2620 | loss: 2.5720 | dt: 2705.58ms | tok/sec: 756.95\n",
334
+ "step 2630 | loss: 2.3286 | dt: 2726.76ms | tok/sec: 751.07\n",
335
+ "step 2640 | loss: 2.6583 | dt: 2712.49ms | tok/sec: 755.03\n",
336
+ "step 2650 | loss: 2.5653 | dt: 2706.73ms | tok/sec: 756.63\n",
337
+ "step 2660 | loss: 2.8053 | dt: 2709.10ms | tok/sec: 755.97\n",
338
+ "step 2670 | loss: 2.6431 | dt: 2696.00ms | tok/sec: 759.64\n",
339
+ "step 2680 | loss: 2.6953 | dt: 2746.24ms | tok/sec: 745.75\n",
340
+ "step 2690 | loss: 2.6361 | dt: 2716.29ms | tok/sec: 753.97\n",
341
+ "step 2700 | loss: 2.4590 | dt: 2774.52ms | tok/sec: 738.15\n",
342
+ "step 2710 | loss: 2.3923 | dt: 2749.37ms | tok/sec: 744.90\n",
343
+ "step 2720 | loss: 2.4924 | dt: 2701.07ms | tok/sec: 758.22\n",
344
+ "step 2730 | loss: 2.4631 | dt: 2711.96ms | tok/sec: 755.17\n",
345
+ "step 2740 | loss: 2.4050 | dt: 2718.98ms | tok/sec: 753.22\n",
346
+ "step 2750 | loss: 2.6518 | dt: 2697.21ms | tok/sec: 759.30\n",
347
+ "step 2760 | loss: 2.7260 | dt: 2876.82ms | tok/sec: 711.90\n",
348
+ "step 2770 | loss: 2.3956 | dt: 2746.34ms | tok/sec: 745.72\n",
349
+ "step 2780 | loss: 2.1536 | dt: 2684.97ms | tok/sec: 762.76\n",
350
+ "step 2790 | loss: 2.2741 | dt: 2954.08ms | tok/sec: 693.28\n",
351
+ "step 2800 | loss: 2.5598 | dt: 2906.71ms | tok/sec: 704.58\n",
352
+ "step 2810 | loss: 2.2911 | dt: 2680.05ms | tok/sec: 764.16\n",
353
+ "step 2820 | loss: 2.1884 | dt: 2682.93ms | tok/sec: 763.34\n",
354
+ "step 2830 | loss: 2.2343 | dt: 2675.76ms | tok/sec: 765.39\n",
355
+ "step 2840 | loss: 2.2411 | dt: 2695.51ms | tok/sec: 759.78\n",
356
+ "step 2850 | loss: 2.2290 | dt: 2864.33ms | tok/sec: 715.00\n",
357
+ "step 2860 | loss: 2.2985 | dt: 2659.26ms | tok/sec: 770.14\n",
358
+ "step 2870 | loss: 2.2090 | dt: 2667.58ms | tok/sec: 767.74\n",
359
+ "step 2880 | loss: 2.2442 | dt: 2657.60ms | tok/sec: 770.62\n",
360
+ "step 2890 | loss: 2.0802 | dt: 2673.51ms | tok/sec: 766.03\n",
361
+ "step 2900 | loss: 2.2653 | dt: 2668.15ms | tok/sec: 767.57\n",
362
+ "step 2910 | loss: 2.2835 | dt: 2662.19ms | tok/sec: 769.29\n",
363
+ "step 2920 | loss: 2.1685 | dt: 2660.40ms | tok/sec: 769.81\n",
364
+ "step 2930 | loss: 2.1162 | dt: 2660.61ms | tok/sec: 769.75\n",
365
+ "step 2940 | loss: 1.8274 | dt: 2661.90ms | tok/sec: 769.38\n",
366
+ "step 2950 | loss: 1.8564 | dt: 2703.16ms | tok/sec: 757.63\n",
367
+ "step 2960 | loss: 1.7792 | dt: 2658.62ms | tok/sec: 770.32\n",
368
+ "step 2970 | loss: 2.0663 | dt: 2662.07ms | tok/sec: 769.33\n",
369
+ "step 2980 | loss: 1.9452 | dt: 2660.27ms | tok/sec: 769.85\n",
370
+ "step 2990 | loss: 2.1067 | dt: 2664.71ms | tok/sec: 768.56\n",
371
+ "step 3000 | loss: 1.9999 | dt: 2669.69ms | tok/sec: 767.13\n",
372
+ "\n",
373
+ "--- Generating text at step 3000 ---\n",
374
+ "! thrive lives, the rage d\n",
375
+ "Endiciansed in the wind O very danger!\n",
376
+ "\n",
377
+ "First Murderer:\n",
378
+ "What commend thee well that are this!\n",
379
+ "\n",
380
+ "First Murderer:\n",
381
+ "On my poor lord,\n",
382
+ "I'll follow thee better\n",
383
+ "-----------------------------------\n",
384
+ "\n",
385
+ "step 3010 | loss: 1.9782 | dt: 2669.18ms | tok/sec: 767.28\n",
386
+ "step 3020 | loss: 1.9162 | dt: 2816.11ms | tok/sec: 727.24\n",
387
+ "step 3030 | loss: 1.7600 | dt: 2748.60ms | tok/sec: 745.11\n",
388
+ "step 3040 | loss: 1.7559 | dt: 3330.35ms | tok/sec: 614.95\n",
389
+ "step 3050 | loss: 1.8405 | dt: 2766.42ms | tok/sec: 740.31\n",
390
+ "step 3060 | loss: 1.8245 | dt: 2675.78ms | tok/sec: 765.38\n",
391
+ "step 3070 | loss: 1.7706 | dt: 2834.19ms | tok/sec: 722.61\n",
392
+ "step 3080 | loss: 1.8807 | dt: 2701.07ms | tok/sec: 758.22\n",
393
+ "step 3090 | loss: 1.9564 | dt: 2774.91ms | tok/sec: 738.04\n",
394
+ "step 3100 | loss: 1.6477 | dt: 2787.13ms | tok/sec: 734.81\n",
395
+ "step 3110 | loss: 1.5303 | dt: 2768.86ms | tok/sec: 739.65\n",
396
+ "step 3120 | loss: 1.5144 | dt: 2668.25ms | tok/sec: 767.54\n",
397
+ "step 3130 | loss: 1.6874 | dt: 2684.58ms | tok/sec: 762.87\n",
398
+ "step 3140 | loss: 1.5967 | dt: 2672.88ms | tok/sec: 766.22\n",
399
+ "step 3150 | loss: 1.4894 | dt: 2685.02ms | tok/sec: 762.75\n",
400
+ "step 3160 | loss: 1.5369 | dt: 2780.46ms | tok/sec: 736.57\n",
401
+ "step 3170 | loss: 1.5521 | dt: 2766.64ms | tok/sec: 740.25\n",
402
+ "step 3180 | loss: 1.5144 | dt: 2710.92ms | tok/sec: 755.46\n",
403
+ "step 3190 | loss: 1.5083 | dt: 2749.40ms | tok/sec: 744.89\n",
404
+ "step 3200 | loss: 1.4799 | dt: 2678.04ms | tok/sec: 764.74\n",
405
+ "step 3210 | loss: 1.5245 | dt: 2707.56ms | tok/sec: 756.40\n",
406
+ "step 3220 | loss: 1.3773 | dt: 2688.40ms | tok/sec: 761.79\n",
407
+ "step 3230 | loss: 1.5153 | dt: 2684.65ms | tok/sec: 762.86\n",
408
+ "step 3240 | loss: 1.4807 | dt: 2753.05ms | tok/sec: 743.90\n",
409
+ "step 3250 | loss: 1.4735 | dt: 2755.05ms | tok/sec: 743.36\n",
410
+ "step 3260 | loss: 1.3735 | dt: 2719.84ms | tok/sec: 752.98\n",
411
+ "step 3270 | loss: 1.2075 | dt: 2843.99ms | tok/sec: 720.12\n",
412
+ "step 3280 | loss: 1.2305 | dt: 2703.84ms | tok/sec: 757.44\n",
413
+ "step 3290 | loss: 1.1408 | dt: 2739.76ms | tok/sec: 747.51\n",
414
+ "step 3300 | loss: 1.2520 | dt: 2892.35ms | tok/sec: 708.07\n",
415
+ "step 3310 | loss: 1.3056 | dt: 2716.91ms | tok/sec: 753.80\n",
416
+ "step 3320 | loss: 1.2803 | dt: 2681.74ms | tok/sec: 763.68\n",
417
+ "step 3330 | loss: 1.2850 | dt: 2725.97ms | tok/sec: 751.29\n",
418
+ "step 3340 | loss: 1.2028 | dt: 2692.16ms | tok/sec: 760.73\n",
419
+ "step 3350 | loss: 1.1584 | dt: 2698.27ms | tok/sec: 759.00\n",
420
+ "step 3360 | loss: 1.0396 | dt: 2998.41ms | tok/sec: 683.03\n",
421
+ "step 3370 | loss: 1.0607 | dt: 2903.31ms | tok/sec: 705.40\n",
422
+ "step 3380 | loss: 1.1164 | dt: 3033.53ms | tok/sec: 675.12\n",
423
+ "step 3390 | loss: 1.1108 | dt: 2886.95ms | tok/sec: 709.40\n",
424
+ "step 3400 | loss: 1.0315 | dt: 2752.63ms | tok/sec: 744.02\n",
425
+ "step 3410 | loss: 1.1562 | dt: 2726.22ms | tok/sec: 751.22\n",
426
+ "step 3420 | loss: 1.1321 | dt: 2708.28ms | tok/sec: 756.20\n",
427
+ "step 3430 | loss: 0.9760 | dt: 2681.26ms | tok/sec: 763.82\n",
428
+ "step 3440 | loss: 0.8877 | dt: 2717.40ms | tok/sec: 753.66\n",
429
+ "step 3450 | loss: 0.9401 | dt: 2842.15ms | tok/sec: 720.58\n",
430
+ "step 3460 | loss: 0.9696 | dt: 2802.88ms | tok/sec: 730.68\n",
431
+ "step 3470 | loss: 0.8294 | dt: 2881.42ms | tok/sec: 710.76\n",
432
+ "step 3480 | loss: 0.9166 | dt: 2741.99ms | tok/sec: 746.90\n",
433
+ "step 3490 | loss: 0.8187 | dt: 2768.42ms | tok/sec: 739.77\n",
434
+ "step 3500 | loss: 0.8012 | dt: 2875.77ms | tok/sec: 712.16\n",
435
+ "\n",
436
+ "--- Generating text at step 3500 ---\n",
437
+ "! who child is Rivers; all,saving thy father knows promise wish' moan'occnard's ease that Alban comes blaced boys.\n",
438
+ "\n",
439
+ "BUCKINGHAM:\n",
440
+ "A greater prince, a royal cousin doth make us\n",
441
+ "down; and\n",
442
+ "-----------------------------------\n",
443
+ "\n",
444
+ "step 3510 | loss: 0.8895 | dt: 2911.00ms | tok/sec: 703.54\n",
445
+ "step 3520 | loss: 0.8510 | dt: 2742.43ms | tok/sec: 746.78\n",
446
+ "step 3530 | loss: 0.7573 | dt: 2752.54ms | tok/sec: 744.04\n",
447
+ "step 3540 | loss: 0.7299 | dt: 2756.80ms | tok/sec: 742.89\n",
448
+ "step 3550 | loss: 0.7541 | dt: 2748.92ms | tok/sec: 745.02\n",
449
+ "step 3560 | loss: 0.7650 | dt: 2749.40ms | tok/sec: 744.89\n",
450
+ "step 3570 | loss: 0.6896 | dt: 2862.06ms | tok/sec: 715.57\n",
451
+ "step 3580 | loss: 0.6993 | dt: 2753.84ms | tok/sec: 743.69\n",
452
+ "step 3590 | loss: 0.6519 | dt: 2998.55ms | tok/sec: 683.00\n",
453
+ "step 3600 | loss: 0.6217 | dt: 2740.47ms | tok/sec: 747.32\n",
454
+ "step 3610 | loss: 0.6385 | dt: 2790.61ms | tok/sec: 733.89\n",
455
+ "step 3620 | loss: 0.5826 | dt: 2770.31ms | tok/sec: 739.27\n",
456
+ "step 3630 | loss: 0.6594 | dt: 2783.17ms | tok/sec: 735.85\n",
457
+ "step 3640 | loss: 0.5835 | dt: 2924.64ms | tok/sec: 700.26\n",
458
+ "step 3650 | loss: 0.6141 | dt: 2800.75ms | tok/sec: 731.23\n",
459
+ "step 3660 | loss: 0.6132 | dt: 2762.03ms | tok/sec: 741.48\n",
460
+ "step 3670 | loss: 0.5363 | dt: 2849.28ms | tok/sec: 718.78\n",
461
+ "step 3680 | loss: 0.5985 | dt: 2842.88ms | tok/sec: 720.40\n",
462
+ "step 3690 | loss: 0.4882 | dt: 2755.73ms | tok/sec: 743.18\n",
463
+ "step 3700 | loss: 0.4431 | dt: 2730.10ms | tok/sec: 750.16\n",
464
+ "step 3710 | loss: 0.4325 | dt: 2700.16ms | tok/sec: 758.47\n",
465
+ "step 3720 | loss: 0.4599 | dt: 2738.14ms | tok/sec: 747.95\n",
466
+ "step 3730 | loss: 0.4503 | dt: 2748.69ms | tok/sec: 745.08\n",
467
+ "step 3740 | loss: 0.4781 | dt: 2740.93ms | tok/sec: 747.19\n",
468
+ "step 3750 | loss: 0.5382 | dt: 2849.36ms | tok/sec: 718.76\n",
469
+ "step 3760 | loss: 0.3955 | dt: 2729.07ms | tok/sec: 750.44\n",
470
+ "step 3770 | loss: 0.3948 | dt: 2822.13ms | tok/sec: 725.69\n",
471
+ "step 3780 | loss: 0.4099 | dt: 2917.25ms | tok/sec: 702.03\n",
472
+ "step 3790 | loss: 0.4257 | dt: 2775.18ms | tok/sec: 737.97\n",
473
+ "step 3800 | loss: 0.3690 | dt: 2777.39ms | tok/sec: 737.38\n",
474
+ "step 3810 | loss: 0.3799 | dt: 2850.00ms | tok/sec: 718.60\n",
475
+ "step 3820 | loss: 0.3161 | dt: 2853.62ms | tok/sec: 717.69\n",
476
+ "step 3830 | loss: 0.3599 | dt: 2750.53ms | tok/sec: 744.58\n",
477
+ "step 3840 | loss: 0.3355 | dt: 2805.98ms | tok/sec: 729.87\n",
478
+ "step 3850 | loss: 0.3302 | dt: 3000.64ms | tok/sec: 682.52\n",
479
+ "step 3860 | loss: 0.3285 | dt: 2791.06ms | tok/sec: 733.77\n",
480
+ "step 3870 | loss: 0.2618 | dt: 2751.07ms | tok/sec: 744.44\n",
481
+ "step 3880 | loss: 0.3333 | dt: 2798.18ms | tok/sec: 731.90\n",
482
+ "step 3890 | loss: 0.2814 | dt: 2870.91ms | tok/sec: 713.36\n",
483
+ "step 3900 | loss: 0.2948 | dt: 2761.24ms | tok/sec: 741.70\n",
484
+ "step 3910 | loss: 0.2399 | dt: 2869.73ms | tok/sec: 713.66\n",
485
+ "step 3920 | loss: 0.2674 | dt: 2980.56ms | tok/sec: 687.12\n",
486
+ "step 3930 | loss: 0.2109 | dt: 2954.62ms | tok/sec: 693.15\n",
487
+ "step 3940 | loss: 0.2220 | dt: 2755.38ms | tok/sec: 743.27\n",
488
+ "step 3950 | loss: 0.2295 | dt: 2780.64ms | tok/sec: 736.52\n",
489
+ "step 3960 | loss: 0.2387 | dt: 2795.49ms | tok/sec: 732.61\n",
490
+ "step 3970 | loss: 0.2579 | dt: 3017.64ms | tok/sec: 678.68\n",
491
+ "step 3980 | loss: 0.2046 | dt: 2761.85ms | tok/sec: 741.53\n",
492
+ "step 3990 | loss: 0.2144 | dt: 2744.11ms | tok/sec: 746.33\n",
493
+ "step 4000 | loss: 0.2070 | dt: 2722.34ms | tok/sec: 752.29\n",
494
+ "\n",
495
+ "--- Generating text at step 4000 ---\n",
496
+ "!\n",
497
+ "If presently you will take the one\n",
498
+ "To what we will o\n",
499
+ "\n",
500
+ "'ll theretove him to known your lord.\n",
501
+ "\n",
502
+ "PRINCEE:\n",
503
+ "Indeed, my lord, safe still well we deserve\n",
504
+ "In your will.\n",
505
+ "\n",
506
+ "\n",
507
+ "-----------------------------------\n",
508
+ "\n",
509
+ "step 4010 | loss: 0.2144 | dt: 2902.75ms | tok/sec: 705.54\n",
510
+ "step 4020 | loss: 0.2045 | dt: 2945.05ms | tok/sec: 695.40\n",
511
+ "step 4030 | loss: 0.1790 | dt: 3056.03ms | tok/sec: 670.15\n",
512
+ "step 4040 | loss: 0.1516 | dt: 2747.85ms | tok/sec: 745.31\n",
513
+ "step 4050 | loss: 0.1615 | dt: 2805.18ms | tok/sec: 730.08\n",
514
+ "step 4060 | loss: 0.1421 | dt: 2749.10ms | tok/sec: 744.97\n",
515
+ "step 4070 | loss: 0.1674 | dt: 2859.66ms | tok/sec: 716.17\n",
516
+ "step 4080 | loss: 0.1633 | dt: 2847.74ms | tok/sec: 719.17\n",
517
+ "step 4090 | loss: 0.1570 | dt: 2916.81ms | tok/sec: 702.14\n",
518
+ "step 4100 | loss: 0.1592 | dt: 3016.92ms | tok/sec: 678.84\n",
519
+ "step 4110 | loss: 0.1363 | dt: 2727.48ms | tok/sec: 750.88\n",
520
+ "step 4120 | loss: 0.1797 | dt: 2730.33ms | tok/sec: 750.09\n",
521
+ "step 4130 | loss: 0.1210 | dt: 2699.00ms | tok/sec: 758.80\n",
522
+ "step 4140 | loss: 0.1253 | dt: 2867.20ms | tok/sec: 714.29\n",
523
+ "step 4150 | loss: 0.1016 | dt: 2856.48ms | tok/sec: 716.97\n",
524
+ "step 4160 | loss: 0.1162 | dt: 2750.59ms | tok/sec: 744.57\n",
525
+ "step 4170 | loss: 0.1599 | dt: 2886.92ms | tok/sec: 709.41\n",
526
+ "step 4180 | loss: 0.1116 | dt: 2701.22ms | tok/sec: 758.18\n",
527
+ "step 4190 | loss: 0.1093 | dt: 2674.94ms | tok/sec: 765.63\n",
528
+ "step 4200 | loss: 0.1099 | dt: 2875.45ms | tok/sec: 712.24\n",
529
+ "step 4210 | loss: 0.1319 | dt: 2775.36ms | tok/sec: 737.92\n",
530
+ "step 4220 | loss: 0.0932 | dt: 2740.37ms | tok/sec: 747.34\n",
531
+ "step 4230 | loss: 0.1047 | dt: 2697.50ms | tok/sec: 759.22\n",
532
+ "step 4240 | loss: 0.1159 | dt: 2656.54ms | tok/sec: 770.93\n",
533
+ "step 4250 | loss: 0.0943 | dt: 2632.27ms | tok/sec: 778.04\n",
534
+ "step 4260 | loss: 0.0935 | dt: 2733.61ms | tok/sec: 749.19\n",
535
+ "step 4270 | loss: 0.1044 | dt: 2700.76ms | tok/sec: 758.31\n",
536
+ "step 4280 | loss: 0.0909 | dt: 2690.66ms | tok/sec: 761.15\n",
537
+ "step 4290 | loss: 0.0993 | dt: 2689.09ms | tok/sec: 761.60\n",
538
+ "step 4300 | loss: 0.0931 | dt: 2710.26ms | tok/sec: 755.65\n",
539
+ "step 4310 | loss: 0.0841 | dt: 2795.26ms | tok/sec: 732.67\n",
540
+ "step 4320 | loss: 0.0729 | dt: 2704.03ms | tok/sec: 757.39\n",
541
+ "step 4330 | loss: 0.0633 | dt: 2676.90ms | tok/sec: 765.07\n",
542
+ "step 4340 | loss: 0.0859 | dt: 2761.68ms | tok/sec: 741.58\n",
543
+ "step 4350 | loss: 0.0625 | dt: 2778.00ms | tok/sec: 737.22\n",
544
+ "step 4360 | loss: 0.0638 | dt: 3028.06ms | tok/sec: 676.34\n",
545
+ "step 4370 | loss: 0.0841 | dt: 2895.60ms | tok/sec: 707.28\n",
546
+ "step 4380 | loss: 0.0731 | dt: 2678.41ms | tok/sec: 764.63\n",
547
+ "step 4390 | loss: 0.0676 | dt: 2632.73ms | tok/sec: 777.90\n",
548
+ "step 4400 | loss: 0.0789 | dt: 2633.57ms | tok/sec: 777.65\n",
549
+ "step 4410 | loss: 0.0870 | dt: 2633.24ms | tok/sec: 777.75\n",
550
+ "step 4420 | loss: 0.0564 | dt: 2666.58ms | tok/sec: 768.03\n",
551
+ "step 4430 | loss: 0.0510 | dt: 2663.78ms | tok/sec: 768.83\n",
552
+ "step 4440 | loss: 0.0844 | dt: 2637.76ms | tok/sec: 776.42\n",
553
+ "step 4450 | loss: 0.0574 | dt: 2649.37ms | tok/sec: 773.02\n",
554
+ "step 4460 | loss: 0.0670 | dt: 2642.94ms | tok/sec: 774.89\n",
555
+ "step 4470 | loss: 0.0818 | dt: 2635.25ms | tok/sec: 777.15\n",
556
+ "step 4480 | loss: 0.0611 | dt: 2641.49ms | tok/sec: 775.32\n",
557
+ "step 4490 | loss: 0.0571 | dt: 2718.30ms | tok/sec: 753.41\n",
558
+ "step 4500 | loss: 0.0524 | dt: 2714.53ms | tok/sec: 754.46\n",
559
+ "\n",
560
+ "--- Generating text at step 4500 ---\n",
561
+ "! to Tyenting in their hate.\n",
562
+ "\n",
563
+ "BUCKINGHAM:\n",
564
+ "My lord, I'll bear unto your head\n",
565
+ "Which best lives hath had but your breath or\n",
566
+ "Plured with the hawnes with yields that\n",
567
+ "There did touch his\n",
568
+ "-----------------------------------\n",
569
+ "\n",
570
+ "step 4510 | loss: 0.0513 | dt: 2759.85ms | tok/sec: 742.07\n",
571
+ "step 4520 | loss: 0.0582 | dt: 2780.63ms | tok/sec: 736.52\n",
572
+ "step 4530 | loss: 0.0607 | dt: 2788.91ms | tok/sec: 734.34\n",
573
+ "step 4540 | loss: 0.0537 | dt: 2705.52ms | tok/sec: 756.97\n",
574
+ "step 4550 | loss: 0.0469 | dt: 2797.45ms | tok/sec: 732.10\n",
575
+ "step 4560 | loss: 0.0513 | dt: 2715.12ms | tok/sec: 754.29\n",
576
+ "step 4570 | loss: 0.0537 | dt: 2701.81ms | tok/sec: 758.01\n",
577
+ "step 4580 | loss: 0.0555 | dt: 2703.04ms | tok/sec: 757.66\n",
578
+ "step 4590 | loss: 0.0617 | dt: 2702.77ms | tok/sec: 757.74\n",
579
+ "step 4600 | loss: 0.0601 | dt: 2703.01ms | tok/sec: 757.67\n",
580
+ "step 4610 | loss: 0.0386 | dt: 2713.98ms | tok/sec: 754.61\n",
581
+ "step 4620 | loss: 0.0469 | dt: 2702.42ms | tok/sec: 757.84\n",
582
+ "step 4630 | loss: 0.0429 | dt: 2701.25ms | tok/sec: 758.17\n",
583
+ "step 4640 | loss: 0.0436 | dt: 2700.08ms | tok/sec: 758.50\n",
584
+ "step 4650 | loss: 0.0552 | dt: 2707.12ms | tok/sec: 756.52\n",
585
+ "step 4660 | loss: 0.0478 | dt: 2704.64ms | tok/sec: 757.22\n",
586
+ "step 4670 | loss: 0.0503 | dt: 2700.75ms | tok/sec: 758.31\n",
587
+ "step 4680 | loss: 0.0370 | dt: 2703.38ms | tok/sec: 757.57\n",
588
+ "step 4690 | loss: 0.0488 | dt: 2690.58ms | tok/sec: 761.18\n",
589
+ "step 4700 | loss: 0.0395 | dt: 2711.10ms | tok/sec: 755.41\n",
590
+ "step 4710 | loss: 0.0384 | dt: 2695.39ms | tok/sec: 759.82\n",
591
+ "step 4720 | loss: 0.0320 | dt: 2698.22ms | tok/sec: 759.02\n",
592
+ "step 4730 | loss: 0.0389 | dt: 2693.92ms | tok/sec: 760.23\n",
593
+ "step 4740 | loss: 0.0417 | dt: 2715.99ms | tok/sec: 754.05\n",
594
+ "step 4750 | loss: 0.0327 | dt: 2707.00ms | tok/sec: 756.56\n",
595
+ "step 4760 | loss: 0.0383 | dt: 2702.06ms | tok/sec: 757.94\n",
596
+ "step 4770 | loss: 0.0482 | dt: 2720.02ms | tok/sec: 752.93\n",
597
+ "step 4780 | loss: 0.0285 | dt: 2695.32ms | tok/sec: 759.84\n",
598
+ "step 4790 | loss: 0.0466 | dt: 2699.53ms | tok/sec: 758.65\n",
599
+ "step 4800 | loss: 0.0293 | dt: 2747.82ms | tok/sec: 745.32\n",
600
+ "step 4810 | loss: 0.0475 | dt: 2694.82ms | tok/sec: 759.98\n",
601
+ "step 4820 | loss: 0.0331 | dt: 2697.06ms | tok/sec: 759.34\n",
602
+ "step 4830 | loss: 0.0424 | dt: 2695.18ms | tok/sec: 759.88\n",
603
+ "step 4840 | loss: 0.0388 | dt: 2699.61ms | tok/sec: 758.63\n",
604
+ "step 4850 | loss: 0.0278 | dt: 2697.17ms | tok/sec: 759.31\n",
605
+ "step 4860 | loss: 0.0352 | dt: 2712.98ms | tok/sec: 754.89\n",
606
+ "step 4870 | loss: 0.0231 | dt: 2708.67ms | tok/sec: 756.09\n",
607
+ "step 4880 | loss: 0.0355 | dt: 2697.11ms | tok/sec: 759.33\n",
608
+ "step 4890 | loss: 0.0379 | dt: 2696.41ms | tok/sec: 759.53\n",
609
+ "step 4900 | loss: 0.0251 | dt: 2692.32ms | tok/sec: 760.68\n",
610
+ "step 4910 | loss: 0.0263 | dt: 2695.45ms | tok/sec: 759.80\n",
611
+ "step 4920 | loss: 0.0279 | dt: 2703.49ms | tok/sec: 757.54\n",
612
+ "step 4930 | loss: 0.0253 | dt: 2696.46ms | tok/sec: 759.51\n",
613
+ "step 4940 | loss: 0.0279 | dt: 2719.05ms | tok/sec: 753.21\n",
614
+ "step 4950 | loss: 0.0303 | dt: 2698.82ms | tok/sec: 758.85\n",
615
+ "step 4960 | loss: 0.0277 | dt: 2702.86ms | tok/sec: 757.72\n",
616
+ "step 4970 | loss: 0.0230 | dt: 2708.92ms | tok/sec: 756.02\n",
617
+ "step 4980 | loss: 0.0397 | dt: 2693.32ms | tok/sec: 760.40\n",
618
+ "step 4990 | loss: 0.0244 | dt: 2699.65ms | tok/sec: 758.62\n",
619
+ "Saving model to smollm_135_checkpoint.pth\n",
620
+ "\n",
621
+ "--- Resuming training from checkpoint ---\n",
622
+ "Checkpoint loaded successfully.\n",
623
+ "Resume step 0 | loss: 0.0218\n",
624
+ "Resume step 10 | loss: 0.0335\n",
625
+ "Resume step 20 | loss: 0.0360\n",
626
+ "Resume step 30 | loss: 0.0450\n",
627
+ "Resume step 40 | loss: 0.0396\n",
628
+ "Resumed training completed.\n"
629
+ ]
630
+ }
631
+ ],
632
+ "source": [
633
+ "# SmolLM-135M Implementation (Llama Architecture)\n",
634
+ "# Based on: https://huggingface.co/HuggingFaceTB/SmolLM-135M\n",
635
+ "\n",
636
+ "import math\n",
637
+ "import inspect\n",
638
+ "from dataclasses import dataclass\n",
639
+ "from typing import Optional, Tuple\n",
640
+ "\n",
641
+ "import torch\n",
642
+ "import torch.nn as nn\n",
643
+ "from torch.nn import functional as F\n",
644
+ "\n",
645
+ "# Configuration for SmolLM-135M\n",
646
+ "@dataclass\n",
647
+ "class SmolLMConfig:\n",
648
+ " block_size: int = 512 # Reduced to 512 for 4GB GPU training\n",
649
+ " vocab_size: int = 50304 # Aligned to 50304 for tiktoken compatibility (SmolLM native is 49152)\n",
650
+ " n_layer: int = 30\n",
651
+ " n_head: int = 9\n",
652
+ " n_kv_head: int = 3 # Grouped Query Attention (GQA)\n",
653
+ " n_embd: int = 576\n",
654
+ " intermediate_size: int = 1536 # SwiGLU intermediate size\n",
655
+ " rms_norm_eps: float = 1e-5\n",
656
+ " rope_theta: float = 10000.0\n",
657
+ " dropout: float = 0.0\n",
658
+ " bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster\n",
659
+ "\n",
660
+ "class RMSNorm(nn.Module):\n",
661
+ " def __init__(self, dim: int, eps: float = 1e-6):\n",
662
+ " super().__init__()\n",
663
+ " self.eps = eps\n",
664
+ " self.weight = nn.Parameter(torch.ones(dim))\n",
665
+ "\n",
666
+ " def _norm(self, x):\n",
667
+ " return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
668
+ "\n",
669
+ " def forward(self, x):\n",
670
+ " output = self._norm(x.float()).type_as(x)\n",
671
+ " return output * self.weight\n",
672
+ "\n",
673
+ "def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):\n",
674
+ " freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))\n",
675
+ " t = torch.arange(end, device=freqs.device, dtype=torch.float32)\n",
676
+ " freqs = torch.outer(t, freqs)\n",
677
+ " freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64\n",
678
+ " return freqs_cis\n",
679
+ "\n",
680
+ "def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):\n",
681
+ " ndim = x.ndim\n",
682
+ " assert 0 <= 1 < ndim\n",
683
+ " assert freqs_cis.shape == (x.shape[1], x.shape[-1])\n",
684
+ " shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]\n",
685
+ " return freqs_cis.view(*shape)\n",
686
+ "\n",
687
+ "def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):\n",
688
+ " xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))\n",
689
+ " xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))\n",
690
+ " freqs_cis = reshape_for_broadcast(freqs_cis, xq_)\n",
691
+ " xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)\n",
692
+ " xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)\n",
693
+ " return xq_out.type_as(xq), xk_out.type_as(xk)\n",
694
+ "\n",
695
+ "class CausalSelfAttention(nn.Module):\n",
696
+ " def __init__(self, config: SmolLMConfig):\n",
697
+ " super().__init__()\n",
698
+ " self.n_head = config.n_head\n",
699
+ " self.n_kv_head = config.n_kv_head\n",
700
+ " self.n_embd = config.n_embd\n",
701
+ " self.head_dim = config.n_embd // config.n_head\n",
702
+ " self.n_rep = self.n_head // self.n_kv_head\n",
703
+ "\n",
704
+ " self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=config.bias)\n",
705
+ " self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)\n",
706
+ " self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)\n",
707
+ " self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=config.bias)\n",
708
+ "\n",
709
+ " self.dropout = config.dropout\n",
710
+ " self.resid_dropout = nn.Dropout(config.dropout)\n",
711
+ "\n",
712
+ " def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):\n",
713
+ " B, T, C = x.shape\n",
714
+ "\n",
715
+ " xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)\n",
716
+ " xq = xq.view(B, T, self.n_head, self.head_dim)\n",
717
+ " xk = xk.view(B, T, self.n_kv_head, self.head_dim)\n",
718
+ " xv = xv.view(B, T, self.n_kv_head, self.head_dim)\n",
719
+ "\n",
720
+ " xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)\n",
721
+ "\n",
722
+ " # Grouped Query Attention: repeat k/v heads to match q heads\n",
723
+ " xk = torch.repeat_interleave(xk, dim=2, repeats=self.n_rep)\n",
724
+ " xv = torch.repeat_interleave(xv, dim=2, repeats=self.n_rep)\n",
725
+ "\n",
726
+ " # Make heads batch dimension\n",
727
+ " xq = xq.transpose(1, 2) # (B, n_head, T, head_dim)\n",
728
+ " xk = xk.transpose(1, 2) # (B, n_head, T, head_dim)\n",
729
+ " xv = xv.transpose(1, 2) # (B, n_head, T, head_dim)\n",
730
+ "\n",
731
+ " # Flash Attention\n",
732
+ " output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)\n",
733
+ "\n",
734
+ " output = output.transpose(1, 2).contiguous().view(B, T, C)\n",
735
+ " return self.resid_dropout(self.wo(output))\n",
736
+ "\n",
737
+ "class SwiGLU(nn.Module):\n",
738
+ " def __init__(self, config: SmolLMConfig):\n",
739
+ " super().__init__()\n",
740
+ " self.w1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) # Gate\n",
741
+ " self.w3 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) # Value\n",
742
+ " self.w2 = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) # Output\n",
743
+ " self.dropout = nn.Dropout(config.dropout)\n",
744
+ "\n",
745
+ " def forward(self, x):\n",
746
+ " return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))\n",
747
+ "\n",
748
+ "class Block(nn.Module):\n",
749
+ " def __init__(self, config: SmolLMConfig):\n",
750
+ " super().__init__()\n",
751
+ " self.attention_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)\n",
752
+ " self.attention = CausalSelfAttention(config)\n",
753
+ " self.ffn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)\n",
754
+ " self.feed_forward = SwiGLU(config)\n",
755
+ "\n",
756
+ " def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor):\n",
757
+ " h = x + self.attention(self.attention_norm(x), freqs_cis)\n",
758
+ " out = h + self.feed_forward(self.ffn_norm(h))\n",
759
+ " return out\n",
760
+ "\n",
761
+ "class SmolLM(nn.Module):\n",
762
+ " def __init__(self, config: SmolLMConfig):\n",
763
+ " super().__init__()\n",
764
+ " self.config = config\n",
765
+ " self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)\n",
766
+ " self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])\n",
767
+ " self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)\n",
768
+ " self.output = nn.Linear(config.n_embd, config.vocab_size, bias=False)\n",
769
+ "\n",
770
+ " # Weight sharing\n",
771
+ " self.tok_embeddings.weight = self.output.weight\n",
772
+ "\n",
773
+ " # Precompute RoPE frequencies\n",
774
+ " self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size * 2, config.rope_theta)\n",
775
+ "\n",
776
+ " self.apply(self._init_weights)\n",
777
+ "\n",
778
+ " def _init_weights(self, module):\n",
779
+ " if isinstance(module, nn.Linear):\n",
780
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
781
+ " if module.bias is not None:\n",
782
+ " torch.nn.init.zeros_(module.bias)\n",
783
+ " elif isinstance(module, nn.Embedding):\n",
784
+ " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
785
+ "\n",
786
+ " def forward(self, idx, targets=None):\n",
787
+ " B, T = idx.shape\n",
788
+ " x = self.tok_embeddings(idx)\n",
789
+ " \n",
790
+ " # Ensure freqs_cis is on the correct device\n",
791
+ " if self.freqs_cis.device != x.device:\n",
792
+ " self.freqs_cis = self.freqs_cis.to(x.device)\n",
793
+ " freqs_cis = self.freqs_cis[:T]\n",
794
+ "\n",
795
+ " for layer in self.layers:\n",
796
+ " x = layer(x, freqs_cis)\n",
797
+ " \n",
798
+ " x = self.norm(x)\n",
799
+ " logits = self.output(x)\n",
800
+ "\n",
801
+ " loss = None\n",
802
+ " if targets is not None:\n",
803
+ " loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))\n",
804
+ " \n",
805
+ " return logits, loss\n",
806
+ "\n",
807
+ "# Device selection\n",
808
+ "device = 'cpu'\n",
809
+ "if torch.cuda.is_available():\n",
810
+ " device = 'cuda'\n",
811
+ "elif hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n",
812
+ " device = \"mps\"\n",
813
+ "print(f\"using device: {device}\")\n",
814
+ "\n",
815
+ "# Data Loading\n",
816
+ "import tiktoken\n",
817
+ "\n",
818
+ "class DataLoaderLite:\n",
819
+ " def __init__(self, B, T):\n",
820
+ " self.B = B\n",
821
+ " self.T = T\n",
822
+ "\n",
823
+ " # Load tokens from disk\n",
824
+ " try:\n",
825
+ " with open('input.txt', 'r', encoding='utf-8') as f:\n",
826
+ " text = f.read()\n",
827
+ " except FileNotFoundError:\n",
828
+ " print(\"Error: input.txt not found. Please ensure the file exists.\")\n",
829
+ " text = \"Hello world \" * 1000 # Fallback for testing if file missing\n",
830
+ " \n",
831
+ " enc = tiktoken.get_encoding('gpt2') \n",
832
+ " tokens = enc.encode(text)\n",
833
+ " self.tokens = torch.tensor(tokens)\n",
834
+ " print(f'loaded {len(self.tokens)} tokens')\n",
835
+ " print(f'1 epoch = {len(self.tokens) // (B * T)} batches')\n",
836
+ "\n",
837
+ " self.current_position = 0\n",
838
+ " \n",
839
+ " def next_batch(self):\n",
840
+ " B, T = self.B, self.T\n",
841
+ " buf = self.tokens[self.current_position: self.current_position + B * T + 1]\n",
842
+ " x = (buf[:-1]).view(B, T) # inputs\n",
843
+ " y = (buf[1:]).view(B, T) # targets\n",
844
+ " self.current_position += B*T\n",
845
+ " if self.current_position + (B * T + 1) > len(self.tokens):\n",
846
+ " self.current_position = 0\n",
847
+ " return x, y\n",
848
+ "\n",
849
+ "# Training Setup\n",
850
+ "torch.manual_seed(1337)\n",
851
+ "if torch.cuda.is_available():\n",
852
+ " torch.cuda.manual_seed(1337)\n",
853
+ "\n",
854
+ "torch.set_float32_matmul_precision('high')\n",
855
+ "\n",
856
+ "config = SmolLMConfig()\n",
857
+ "model = SmolLM(config)\n",
858
+ "model.to(device)\n",
859
+ "\n",
860
+ "print(f\"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M\")\n",
861
+ "\n",
862
+ "# Generation Function\n",
863
+ "@torch.no_grad()\n",
864
+ "def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):\n",
865
+ " \"\"\"\n",
866
+ " Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete\n",
867
+ " the sequence max_new_tokens times, feeding the predictions back into the model each time.\n",
868
+ " \"\"\"\n",
869
+ " for _ in range(max_new_tokens):\n",
870
+ " # if the sequence context is growing too long we must crop it at block_size\n",
871
+ " idx_cond = idx if idx.size(1) <= model.config.block_size else idx[:, -model.config.block_size:]\n",
872
+ " # forward the model to get the logits for the index in the sequence\n",
873
+ " logits, _ = model(idx_cond)\n",
874
+ " # pluck the logits at the final step and scale by desired temperature\n",
875
+ " logits = logits[:, -1, :] / temperature\n",
876
+ " # optionally crop the logits to only the top k options\n",
877
+ " if top_k is not None:\n",
878
+ " v, _ = torch.topk(logits, min(top_k, logits.size(-1)))\n",
879
+ " logits[logits < v[:, [-1]]] = -float('Inf')\n",
880
+ " # apply softmax to convert logits to (normalized) probabilities\n",
881
+ " probs = F.softmax(logits, dim=-1)\n",
882
+ " # sample from the distribution\n",
883
+ " idx_next = torch.multinomial(probs, num_samples=1)\n",
884
+ " # append sampled index to the running sequence and continue\n",
885
+ " idx = torch.cat((idx, idx_next), dim=1)\n",
886
+ "\n",
887
+ " return idx\n",
888
+ "\n",
889
+ "# Training Loop\n",
890
+ "train_loader = DataLoaderLite(B = 4, T = 512) # Reduced batch size and context for 4GB GPU\n",
891
+ "optimizer = torch.optim.AdamW(model.parameters(), lr = 3e-4)\n",
892
+ "\n",
893
+ "import time\n",
894
+ "import os\n",
895
+ "\n",
896
+ "max_steps = 5000\n",
897
+ "eval_interval = 500\n",
898
+ "save_path = \"smollm_135_checkpoint.pth\"\n",
899
+ "\n",
900
+ "print(\"Starting training...\")\n",
901
+ "for i in range(max_steps):\n",
902
+ " t0 = time.time()\n",
903
+ " x, y = train_loader.next_batch()\n",
904
+ " x, y = x.to(device), y.to(device)\n",
905
+ " optimizer.zero_grad()\n",
906
+ " \n",
907
+ " # Mixed precision training\n",
908
+ " with torch.autocast(device_type=device, dtype=torch.bfloat16 if device=='cuda' else torch.float32):\n",
909
+ " logits, loss = model(x, y) \n",
910
+ " \n",
911
+ " loss.backward()\n",
912
+ " optimizer.step()\n",
913
+ " \n",
914
+ " if device == 'cuda':\n",
915
+ " torch.cuda.synchronize() \n",
916
+ " \n",
917
+ " t1 = time.time()\n",
918
+ " dt = (t1 - t0) * 1000\n",
919
+ " tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)\n",
920
+ " \n",
921
+ " if i % 10 == 0:\n",
922
+ " print(f'step {i} | loss: {loss.item():.4f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec:.2f}')\n",
923
+ " \n",
924
+ " # Generate output every 500 steps\n",
925
+ " if i > 0 and i % eval_interval == 0:\n",
926
+ " print(f\"\\n--- Generating text at step {i} ---\")\n",
927
+ " context = torch.zeros((1, 1), dtype=torch.long, device=device) # Start with token 0 (usually valid)\n",
928
+ " generated = generate(model, context, max_new_tokens=50)\n",
929
+ " # Decode using tiktoken (gpt2 encoding as used in DataLoader)\n",
930
+ " enc = tiktoken.get_encoding('gpt2')\n",
931
+ " decoded = enc.decode(generated[0].tolist())\n",
932
+ " # Force ASCII for Windows console compatibility\n",
933
+ " print(decoded.encode('ascii', errors='ignore').decode('ascii'))\n",
934
+ " print(\"-----------------------------------\\n\")\n",
935
+ "\n",
936
+ "# Save checkpoint\n",
937
+ "print(f\"Saving model to {save_path}\")\n",
938
+ "torch.save(model.state_dict(), save_path)\n",
939
+ "\n",
940
+ "# Resume training demonstration\n",
941
+ "print(\"\\n--- Resuming training from checkpoint ---\")\n",
942
+ "# Re-initialize model to prove loading works\n",
943
+ "model_new = SmolLM(config)\n",
944
+ "model_new.to(device)\n",
945
+ "model_new.load_state_dict(torch.load(save_path))\n",
946
+ "print(\"Checkpoint loaded successfully.\")\n",
947
+ "\n",
948
+ "optimizer_new = torch.optim.AdamW(model_new.parameters(), lr = 3e-4)\n",
949
+ "\n",
950
+ "# Train for another 50 steps\n",
951
+ "for i in range(50):\n",
952
+ " t0 = time.time()\n",
953
+ " x, y = train_loader.next_batch()\n",
954
+ " x, y = x.to(device), y.to(device)\n",
955
+ " optimizer_new.zero_grad()\n",
956
+ " \n",
957
+ " with torch.autocast(device_type=device, dtype=torch.bfloat16 if device=='cuda' else torch.float32):\n",
958
+ " logits, loss = model_new(x, y) \n",
959
+ " \n",
960
+ " loss.backward()\n",
961
+ " optimizer_new.step()\n",
962
+ " \n",
963
+ " if device == 'cuda':\n",
964
+ " torch.cuda.synchronize()\n",
965
+ " \n",
966
+ " if i % 10 == 0:\n",
967
+ " print(f'Resume step {i} | loss: {loss.item():.4f}')\n",
968
+ "\n",
969
+ "print(\"Resumed training completed.\")\n"
970
+ ]
971
+ },
972
+ {
973
+ "cell_type": "code",
974
+ "execution_count": null,
975
+ "id": "a720457f",
976
+ "metadata": {},
977
+ "outputs": [],
978
+ "source": []
979
+ }
980
+ ],
981
+ "metadata": {
982
+ "kernelspec": {
983
+ "display_name": "Python 3",
984
+ "language": "python",
985
+ "name": "python3"
986
+ },
987
+ "language_info": {
988
+ "codemirror_mode": {
989
+ "name": "ipython",
990
+ "version": 3
991
+ },
992
+ "file_extension": ".py",
993
+ "mimetype": "text/x-python",
994
+ "name": "python",
995
+ "nbconvert_exporter": "python",
996
+ "pygments_lexer": "ipython3",
997
+ "version": "3.13.7"
998
+ }
999
+ },
1000
+ "nbformat": 4,
1001
+ "nbformat_minor": 5
1002
+ }
app.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio App for Sentence Completion
3
+ Main entry point for Hugging Face Spaces
4
+ """
5
+
6
+ import gradio as gr
7
+ import torch
8
+ from inference import load_model, generate_text, get_device
9
+
10
+
11
+ # Global model variable
12
+ model = None
13
+ device = None
14
+
15
+
16
+ def initialize_model(model_path=None):
17
+ """Initialize the model on startup"""
18
+ global model, device
19
+ try:
20
+ model, device = load_model(model_path=model_path)
21
+ #model.eval() # Set to eval mode once
22
+ return f"Model loaded successfully on device: {device}"
23
+ except Exception as e:
24
+ return f"Error loading model: {str(e)}"
25
+
26
+
27
+ def complete_sentence(prompt, max_tokens, top_k, temperature):
28
+ """Generate sentence completion based on prompt"""
29
+ global model, device
30
+
31
+ if model is None:
32
+ return "Error: Model not loaded. Please restart the app."
33
+
34
+ if not prompt.strip():
35
+ return "Please enter a prompt to complete."
36
+
37
+ try:
38
+ # Generate completion
39
+ print(prompt)
40
+ generated_text = generate_text(
41
+ prompt=prompt,
42
+ model=model,
43
+ max_tokens=int(max_tokens),
44
+ device=device
45
+ )
46
+ print(device)
47
+
48
+ return generated_text
49
+ except Exception as e:
50
+ return f"Error generating text: {str(e)}"
51
+
52
+
53
+ def create_interface():
54
+ """Create and return the Gradio interface"""
55
+
56
+ # Initialize model on startup
57
+ # Try to load from common checkpoint paths
58
+ checkpoint_paths = [
59
+ './model/smollm_135_checkpoint.pth',
60
+ './model/model.pth',
61
+ 'model.pt',
62
+ 'checkpoint.pth',
63
+ ]
64
+
65
+ model_path = None
66
+ for path in checkpoint_paths:
67
+ import os
68
+ if os.path.exists(path):
69
+ model_path = path
70
+ print(f"Model found at {path}")
71
+ break
72
+ else:
73
+ print(f"Model not found at {path}")
74
+
75
+ status = initialize_model(model_path=model_path)
76
+ print(status)
77
+
78
+ # Create Gradio interface
79
+ with gr.Blocks(title="Sentence Completion with SmolLM-135M") as demo:
80
+ gr.Markdown(
81
+ """
82
+ # Sentence Completion with SmolLM-135M
83
+
84
+ Enter a prompt and the model will complete the sentence for you.
85
+ Adjust the parameters to control the generation behavior.
86
+ """
87
+ )
88
+
89
+ with gr.Row():
90
+ with gr.Column(scale=2):
91
+ prompt_input = gr.Textbox(
92
+ label="Prompt",
93
+ placeholder="Enter your prompt here...",
94
+ lines=3,
95
+ value="The future of artificial intelligence is"
96
+ )
97
+
98
+ with gr.Row():
99
+ max_tokens_slider = gr.Slider(
100
+ minimum=10,
101
+ maximum=200,
102
+ value=50,
103
+ step=10,
104
+ label="Max Tokens"
105
+ )
106
+
107
+ top_k_slider = gr.Slider(
108
+ minimum=1,
109
+ maximum=100,
110
+ value=50,
111
+ step=1,
112
+ label="Top-K"
113
+ )
114
+
115
+ temperature_slider = gr.Slider(
116
+ minimum=0.1,
117
+ maximum=2.0,
118
+ value=1.0,
119
+ step=0.1,
120
+ label="Temperature"
121
+ )
122
+
123
+ generate_btn = gr.Button("Generate", variant="primary")
124
+
125
+ with gr.Column(scale=2):
126
+ output_text = gr.Textbox(
127
+ label="Generated Text",
128
+ lines=10,
129
+ interactive=False
130
+ )
131
+
132
+ gr.Markdown(
133
+ """
134
+ ### Parameters:
135
+ - **Max Tokens**: Maximum number of tokens to generate
136
+ - **Top-K**: Sample from top K most likely tokens (lower = more focused)
137
+ - **Temperature**: Controls randomness (lower = more deterministic, higher = more creative)
138
+ """
139
+ )
140
+
141
+ # Set up the generate function
142
+ generate_btn.click(
143
+ fn=complete_sentence,
144
+ inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
145
+ outputs=output_text
146
+ )
147
+
148
+ # Also generate on Enter key press
149
+ prompt_input.submit(
150
+ fn=complete_sentence,
151
+ inputs=[prompt_input, max_tokens_slider, top_k_slider, temperature_slider],
152
+ outputs=output_text
153
+ )
154
+
155
+ return demo
156
+
157
+
158
+ if __name__ == "__main__":
159
+ demo = create_interface()
160
+ demo.launch(share=False)
check_cuda.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ print(f"CUDA available: {torch.cuda.is_available()}")
4
+ print(f"CUDA device count: {torch.cuda.device_count()}")
5
+ if torch.cuda.is_available():
6
+ print(f"Current device: {torch.cuda.current_device()}")
7
+ print(f"Device name: {torch.cuda.get_device_name(0)}")
8
+ else:
9
+ print("CUDA not available")
checkpoint_info.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ Checkpoint type: <class 'collections.OrderedDict'>
2
+ Keys: ['tok_embeddings.weight', 'layers.0.attention_norm.weight', 'layers.0.attention.wq.weight', 'layers.0.attention.wk.weight', 'layers.0.attention.wv.weight', 'layers.0.attention.wo.weight', 'layers.0.ffn_norm.weight', 'layers.0.feed_forward.w1.weight', 'layers.0.feed_forward.w3.weight', 'layers.0.feed_forward.w2.weight', 'layers.1.attention_norm.weight', 'layers.1.attention.wq.weight', 'layers.1.attention.wk.weight', 'layers.1.attention.wv.weight', 'layers.1.attention.wo.weight', 'layers.1.ffn_norm.weight', 'layers.1.feed_forward.w1.weight', 'layers.1.feed_forward.w3.weight', 'layers.1.feed_forward.w2.weight', 'layers.2.attention_norm.weight', 'layers.2.attention.wq.weight', 'layers.2.attention.wk.weight', 'layers.2.attention.wv.weight', 'layers.2.attention.wo.weight', 'layers.2.ffn_norm.weight', 'layers.2.feed_forward.w1.weight', 'layers.2.feed_forward.w3.weight', 'layers.2.feed_forward.w2.weight', 'layers.3.attention_norm.weight', 'layers.3.attention.wq.weight', 'layers.3.attention.wk.weight', 'layers.3.attention.wv.weight', 'layers.3.attention.wo.weight', 'layers.3.ffn_norm.weight', 'layers.3.feed_forward.w1.weight', 'layers.3.feed_forward.w3.weight', 'layers.3.feed_forward.w2.weight', 'layers.4.attention_norm.weight', 'layers.4.attention.wq.weight', 'layers.4.attention.wk.weight', 'layers.4.attention.wv.weight', 'layers.4.attention.wo.weight', 'layers.4.ffn_norm.weight', 'layers.4.feed_forward.w1.weight', 'layers.4.feed_forward.w3.weight', 'layers.4.feed_forward.w2.weight', 'layers.5.attention_norm.weight', 'layers.5.attention.wq.weight', 'layers.5.attention.wk.weight', 'layers.5.attention.wv.weight', 'layers.5.attention.wo.weight', 'layers.5.ffn_norm.weight', 'layers.5.feed_forward.w1.weight', 'layers.5.feed_forward.w3.weight', 'layers.5.feed_forward.w2.weight', 'layers.6.attention_norm.weight', 'layers.6.attention.wq.weight', 'layers.6.attention.wk.weight', 'layers.6.attention.wv.weight', 'layers.6.attention.wo.weight', 'layers.6.ffn_norm.weight', 
'layers.6.feed_forward.w1.weight', 'layers.6.feed_forward.w3.weight', 'layers.6.feed_forward.w2.weight', 'layers.7.attention_norm.weight', 'layers.7.attention.wq.weight', 'layers.7.attention.wk.weight', 'layers.7.attention.wv.weight', 'layers.7.attention.wo.weight', 'layers.7.ffn_norm.weight', 'layers.7.feed_forward.w1.weight', 'layers.7.feed_forward.w3.weight', 'layers.7.feed_forward.w2.weight', 'layers.8.attention_norm.weight', 'layers.8.attention.wq.weight', 'layers.8.attention.wk.weight', 'layers.8.attention.wv.weight', 'layers.8.attention.wo.weight', 'layers.8.ffn_norm.weight', 'layers.8.feed_forward.w1.weight', 'layers.8.feed_forward.w3.weight', 'layers.8.feed_forward.w2.weight', 'layers.9.attention_norm.weight', 'layers.9.attention.wq.weight', 'layers.9.attention.wk.weight', 'layers.9.attention.wv.weight', 'layers.9.attention.wo.weight', 'layers.9.ffn_norm.weight', 'layers.9.feed_forward.w1.weight', 'layers.9.feed_forward.w3.weight', 'layers.9.feed_forward.w2.weight', 'layers.10.attention_norm.weight', 'layers.10.attention.wq.weight', 'layers.10.attention.wk.weight', 'layers.10.attention.wv.weight', 'layers.10.attention.wo.weight', 'layers.10.ffn_norm.weight', 'layers.10.feed_forward.w1.weight', 'layers.10.feed_forward.w3.weight', 'layers.10.feed_forward.w2.weight', 'layers.11.attention_norm.weight', 'layers.11.attention.wq.weight', 'layers.11.attention.wk.weight', 'layers.11.attention.wv.weight', 'layers.11.attention.wo.weight', 'layers.11.ffn_norm.weight', 'layers.11.feed_forward.w1.weight', 'layers.11.feed_forward.w3.weight', 'layers.11.feed_forward.w2.weight', 'layers.12.attention_norm.weight', 'layers.12.attention.wq.weight', 'layers.12.attention.wk.weight', 'layers.12.attention.wv.weight', 'layers.12.attention.wo.weight', 'layers.12.ffn_norm.weight', 'layers.12.feed_forward.w1.weight', 'layers.12.feed_forward.w3.weight', 'layers.12.feed_forward.w2.weight', 'layers.13.attention_norm.weight', 'layers.13.attention.wq.weight', 
'layers.13.attention.wk.weight', 'layers.13.attention.wv.weight', 'layers.13.attention.wo.weight', 'layers.13.ffn_norm.weight', 'layers.13.feed_forward.w1.weight', 'layers.13.feed_forward.w3.weight', 'layers.13.feed_forward.w2.weight', 'layers.14.attention_norm.weight', 'layers.14.attention.wq.weight', 'layers.14.attention.wk.weight', 'layers.14.attention.wv.weight', 'layers.14.attention.wo.weight', 'layers.14.ffn_norm.weight', 'layers.14.feed_forward.w1.weight', 'layers.14.feed_forward.w3.weight', 'layers.14.feed_forward.w2.weight', 'layers.15.attention_norm.weight', 'layers.15.attention.wq.weight', 'layers.15.attention.wk.weight', 'layers.15.attention.wv.weight', 'layers.15.attention.wo.weight', 'layers.15.ffn_norm.weight', 'layers.15.feed_forward.w1.weight', 'layers.15.feed_forward.w3.weight', 'layers.15.feed_forward.w2.weight', 'layers.16.attention_norm.weight', 'layers.16.attention.wq.weight', 'layers.16.attention.wk.weight', 'layers.16.attention.wv.weight', 'layers.16.attention.wo.weight', 'layers.16.ffn_norm.weight', 'layers.16.feed_forward.w1.weight', 'layers.16.feed_forward.w3.weight', 'layers.16.feed_forward.w2.weight', 'layers.17.attention_norm.weight', 'layers.17.attention.wq.weight', 'layers.17.attention.wk.weight', 'layers.17.attention.wv.weight', 'layers.17.attention.wo.weight', 'layers.17.ffn_norm.weight', 'layers.17.feed_forward.w1.weight', 'layers.17.feed_forward.w3.weight', 'layers.17.feed_forward.w2.weight', 'layers.18.attention_norm.weight', 'layers.18.attention.wq.weight', 'layers.18.attention.wk.weight', 'layers.18.attention.wv.weight', 'layers.18.attention.wo.weight', 'layers.18.ffn_norm.weight', 'layers.18.feed_forward.w1.weight', 'layers.18.feed_forward.w3.weight', 'layers.18.feed_forward.w2.weight', 'layers.19.attention_norm.weight', 'layers.19.attention.wq.weight', 'layers.19.attention.wk.weight', 'layers.19.attention.wv.weight', 'layers.19.attention.wo.weight', 'layers.19.ffn_norm.weight', 'layers.19.feed_forward.w1.weight', 
'layers.19.feed_forward.w3.weight', 'layers.19.feed_forward.w2.weight', 'layers.20.attention_norm.weight', 'layers.20.attention.wq.weight', 'layers.20.attention.wk.weight', 'layers.20.attention.wv.weight', 'layers.20.attention.wo.weight', 'layers.20.ffn_norm.weight', 'layers.20.feed_forward.w1.weight', 'layers.20.feed_forward.w3.weight', 'layers.20.feed_forward.w2.weight', 'layers.21.attention_norm.weight', 'layers.21.attention.wq.weight', 'layers.21.attention.wk.weight', 'layers.21.attention.wv.weight', 'layers.21.attention.wo.weight', 'layers.21.ffn_norm.weight', 'layers.21.feed_forward.w1.weight', 'layers.21.feed_forward.w3.weight', 'layers.21.feed_forward.w2.weight', 'layers.22.attention_norm.weight', 'layers.22.attention.wq.weight', 'layers.22.attention.wk.weight', 'layers.22.attention.wv.weight', 'layers.22.attention.wo.weight', 'layers.22.ffn_norm.weight', 'layers.22.feed_forward.w1.weight', 'layers.22.feed_forward.w3.weight', 'layers.22.feed_forward.w2.weight', 'layers.23.attention_norm.weight', 'layers.23.attention.wq.weight', 'layers.23.attention.wk.weight', 'layers.23.attention.wv.weight', 'layers.23.attention.wo.weight', 'layers.23.ffn_norm.weight', 'layers.23.feed_forward.w1.weight', 'layers.23.feed_forward.w3.weight', 'layers.23.feed_forward.w2.weight', 'layers.24.attention_norm.weight', 'layers.24.attention.wq.weight', 'layers.24.attention.wk.weight', 'layers.24.attention.wv.weight', 'layers.24.attention.wo.weight', 'layers.24.ffn_norm.weight', 'layers.24.feed_forward.w1.weight', 'layers.24.feed_forward.w3.weight', 'layers.24.feed_forward.w2.weight', 'layers.25.attention_norm.weight', 'layers.25.attention.wq.weight', 'layers.25.attention.wk.weight', 'layers.25.attention.wv.weight', 'layers.25.attention.wo.weight', 'layers.25.ffn_norm.weight', 'layers.25.feed_forward.w1.weight', 'layers.25.feed_forward.w3.weight', 'layers.25.feed_forward.w2.weight', 'layers.26.attention_norm.weight', 'layers.26.attention.wq.weight', 'layers.26.attention.wk.weight', 
'layers.26.attention.wv.weight', 'layers.26.attention.wo.weight', 'layers.26.ffn_norm.weight', 'layers.26.feed_forward.w1.weight', 'layers.26.feed_forward.w3.weight', 'layers.26.feed_forward.w2.weight', 'layers.27.attention_norm.weight', 'layers.27.attention.wq.weight', 'layers.27.attention.wk.weight', 'layers.27.attention.wv.weight', 'layers.27.attention.wo.weight', 'layers.27.ffn_norm.weight', 'layers.27.feed_forward.w1.weight', 'layers.27.feed_forward.w3.weight', 'layers.27.feed_forward.w2.weight', 'layers.28.attention_norm.weight', 'layers.28.attention.wq.weight', 'layers.28.attention.wk.weight', 'layers.28.attention.wv.weight', 'layers.28.attention.wo.weight', 'layers.28.ffn_norm.weight', 'layers.28.feed_forward.w1.weight', 'layers.28.feed_forward.w3.weight', 'layers.28.feed_forward.w2.weight', 'layers.29.attention_norm.weight', 'layers.29.attention.wq.weight', 'layers.29.attention.wk.weight', 'layers.29.attention.wv.weight', 'layers.29.attention.wo.weight', 'layers.29.ffn_norm.weight', 'layers.29.feed_forward.w1.weight', 'layers.29.feed_forward.w3.weight', 'layers.29.feed_forward.w2.weight', 'norm.weight', 'output.weight']
inference.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Inference and Model Loading Utilities
3
+ """
4
+
5
+ import os
6
+ import torch
7
+ from torch.nn import functional as F
8
+ import tiktoken
9
+ from model import SmolLM, SmolLMConfig
10
+
11
+
12
def get_device():
    """Auto-detect and return the best available torch device string.

    Preference order: CUDA GPU, then Apple MPS, then CPU.

    Returns:
        One of 'cuda', 'mps', or 'cpu'.
    """
    # Fix: the original used f-strings with no placeholders; plain strings here.
    print("[DEBUG] Checking device availability...")
    print(f"[DEBUG] torch.cuda.is_available(): {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"[DEBUG] CUDA device: {torch.cuda.get_device_name(0)}")
        return 'cuda'
    # hasattr guard: torch.backends.mps does not exist on older torch builds
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        print("[DEBUG] Using MPS device")
        return "mps"
    print("[DEBUG] Falling back to CPU")
    return 'cpu'
25
+
26
+
27
def load_model(model_path=None, device=None):
    """
    Load SmolLM model from checkpoint.

    Args:
        model_path: Path to saved model checkpoint (.pth or .pt file)
        device: Device to load model on (auto-detected if None)

    Returns:
        Loaded model and device
    """
    target_device = device if device is not None else get_device()

    # Prefer restoring trained weights when a checkpoint file exists.
    if model_path and os.path.exists(model_path):
        try:
            print(f"Loading saved model from {model_path}...")
            restored = SmolLM.load_checkpoint(model_path, device=target_device)
            return restored, target_device
        except Exception as e:
            # Best-effort: fall through to an untrained model on any failure.
            print(f"Failed to load saved model: {e}")

    # No usable checkpoint: build a fresh model with default hyperparameters.
    print("Creating model with default config (untrained)...")
    fresh = SmolLM(SmolLMConfig())
    fresh.to(target_device)
    return fresh, target_device
56
+
57
+
58
# Module-level tokenizer cache: the GPT-2 BPE tables are built once per
# process instead of on every generation request.
_TOKENIZER = None


def _get_tokenizer():
    """Return the cached GPT-2 tiktoken encoder, creating it on first use."""
    global _TOKENIZER
    if _TOKENIZER is None:
        _TOKENIZER = tiktoken.get_encoding("gpt2")
    return _TOKENIZER


def generate_text(prompt, model, max_tokens=50, top_k=50, temperature=1.0, device="cpu"):
    """
    Generate text completion for a given prompt using the SmolLM model.

    Bug fix: the previous version defined a *nested* generate_text inside
    itself and never called it, so the public function always returned
    None. The sampling loop is now the function body itself; the tokenizer
    cache moved to module level and the per-step debug prints were removed.

    Args:
        prompt: Input text prompt
        model: SmolLM model instance
        max_tokens: Maximum number of tokens to generate
        top_k: Top-k sampling parameter (None for no top-k filtering)
        temperature: Temperature for sampling (higher = more random)
        device: Device to run inference on

    Returns:
        Generated text string (including original prompt)
    """
    enc = _get_tokenizer()
    model.eval()

    with torch.no_grad():
        # Tokenize the prompt into a (1, T) batch.
        input_ids = enc.encode(prompt)
        x = torch.tensor(input_ids, dtype=torch.long, device=device).unsqueeze(0)

        for _ in range(max_tokens):
            logits, _ = model(x)
            # Only the distribution at the last position matters.
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                # Mask out everything below the k-th largest logit.
                topk = torch.topk(logits, top_k, dim=-1)
                mask = logits < topk.values[:, -1].unsqueeze(-1)
                logits = logits.masked_fill(mask, -float("inf"))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)

    # x holds the prompt followed by all sampled tokens.
    return enc.decode(x[0].tolist())
inspect_checkpoint.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import sys
3
+
4
def inspect_checkpoint(path='./model/smollm_135_checkpoint.pth',
                       output_file='checkpoint_info.txt'):
    """Write a short summary of a torch checkpoint to a text file.

    Generalized: the checkpoint path and report destination are now
    parameters (defaults preserve the original behavior).

    Args:
        path: Checkpoint file to inspect.
        output_file: Destination text file for the summary.
    """
    with open(output_file, 'w') as f:
        try:
            # map_location='cpu' so inspection works without a GPU.
            checkpoint = torch.load(path, map_location='cpu')
            f.write(f"Checkpoint type: {type(checkpoint)}\n")
            if isinstance(checkpoint, dict):
                f.write(f"Keys: {list(checkpoint.keys())}\n")
            else:
                f.write("Checkpoint is not a dictionary.\n")
        except Exception as e:
            # Record the failure in the report instead of crashing the script.
            f.write(f"Error loading checkpoint: {e}\n")


if __name__ == "__main__":
    inspect_checkpoint()
main.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
def main():
    """Entry point: print a greeting identifying the project."""
    greeting = "Hello from smollm-135!"
    print(greeting)


if __name__ == "__main__":
    main()
model.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ SmolLM-135M Implementation (Llama Architecture)
3
+ Based on: https://huggingface.co/HuggingFaceTB/SmolLM-135M
4
+ """
5
+
6
+ import math
7
+ import inspect
8
+ from dataclasses import dataclass
9
+ from typing import Optional, Tuple
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+ # Configuration for SmolLM-135M
16
+ @dataclass
17
+ class SmolLMConfig:
18
+ block_size: int = 512 # Reduced to 512 for 4GB GPU training
19
+ vocab_size: int = 50304 # Aligned to 50304 for tiktoken compatibility (SmolLM native is 49152)
20
+ n_layer: int = 30
21
+ n_head: int = 9
22
+ n_kv_head: int = 3 # Grouped Query Attention (GQA)
23
+ n_embd: int = 576
24
+ intermediate_size: int = 1536 # SwiGLU intermediate size
25
+ rms_norm_eps: float = 1e-5
26
+ rope_theta: float = 10000.0
27
+ dropout: float = 0.0
28
+ bias: bool = False # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
29
+
30
+ class RMSNorm(nn.Module):
31
+ def __init__(self, dim: int, eps: float = 1e-6):
32
+ super().__init__()
33
+ self.eps = eps
34
+ self.weight = nn.Parameter(torch.ones(dim))
35
+
36
+ def _norm(self, x):
37
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
38
+
39
+ def forward(self, x):
40
+ output = self._norm(x.float()).type_as(x)
41
+ return output * self.weight
42
+
43
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute complex RoPE rotation factors for positions [0, end).

    Returns a (end, dim // 2) complex64 tensor whose (t, j) entry is a
    unit-magnitude complex number encoding the rotation angle
    t * theta^(-2j/dim).
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    # polar(1, angle) = cos(angle) + i*sin(angle)
    return torch.polar(torch.ones_like(angles), angles)
49
+
50
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert singleton dims into freqs_cis so it broadcasts against x.

    freqs_cis must match x along the sequence axis (dim 1) and the last
    axis; every other axis of the result is size 1.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    target_shape = [
        size if axis in (1, ndim - 1) else 1
        for axis, size in enumerate(x.shape)
    ]
    return freqs_cis.view(*target_shape)
56
+
57
def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query/key vectors by the given complex RoPE factors.

    Adjacent channel pairs are interpreted as complex numbers, multiplied
    by the unit rotations in freqs_cis, then converted back to real pairs.
    Outputs keep the dtypes of the corresponding inputs.
    """
    def as_complex(t):
        # (..., d) real -> (..., d/2) complex, computed in float32
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    rot = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rot).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rot).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
64
+
65
class CausalSelfAttention(nn.Module):
    """Causal self-attention with Grouped Query Attention (GQA) and RoPE.

    Uses n_head query heads but only n_kv_head key/value heads; k/v heads
    are repeated n_rep (= n_head // n_kv_head) times at runtime to match
    the query heads. An optional (k, v) cache supports incremental decoding.
    """

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.n_head = config.n_head
        self.n_kv_head = config.n_kv_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # How many query heads share each k/v head.
        self.n_rep = self.n_head // self.n_kv_head

        # Projection names (wq/wk/wv/wo) match the checkpoint's state-dict keys.
        self.wq = nn.Linear(config.n_embd, config.n_head * self.head_dim, bias=config.bias)
        self.wk = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)
        self.wv = nn.Linear(config.n_embd, config.n_kv_head * self.head_dim, bias=config.bias)
        self.wo = nn.Linear(config.n_head * self.head_dim, config.n_embd, bias=config.bias)

        self.dropout = config.dropout
        self.resid_dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        """Attend over x of shape (B, T, C).

        freqs_cis must cover exactly the T positions of x (the caller is
        responsible for offsetting it when a cache is in use).

        Returns:
            (output, (k, v)) where output is (B, T, C) and the k/v tensors
            include any cached prefix, ready to be passed back as past_kv.
        """
        B, T, C = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(B, T, self.n_head, self.head_dim)
        xk = xk.view(B, T, self.n_kv_head, self.head_dim)
        xv = xv.view(B, T, self.n_kv_head, self.head_dim)

        # Rotary position embedding on queries and keys only (not values).
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        # Grouped Query Attention: repeat k/v heads to match q heads
        xk = torch.repeat_interleave(xk, dim=2, repeats=self.n_rep)
        xv = torch.repeat_interleave(xv, dim=2, repeats=self.n_rep)

        # Make heads batch dimension
        xq = xq.transpose(1, 2)  # (B, n_head, T, head_dim)
        xk = xk.transpose(1, 2)  # (B, n_head, T, head_dim)
        xv = xv.transpose(1, 2)  # (B, n_head, T, head_dim)

        # Prepend cached keys/values along the time axis (dim 2).
        if past_kv is not None:
            k_cache, v_cache = past_kv
            xk = torch.cat([k_cache, xk], dim=2)
            xv = torch.cat([v_cache, xv], dim=2)

        current_kv = (xk, xv)

        # Flash Attention
        # NOTE(review): with a non-empty cache and T > 1, is_causal=True masks
        # relative to the concatenated k/v length, which may not align queries
        # with their true absolute positions — confirm callers only pass T == 1
        # when past_kv is set.
        if past_kv is not None and T == 1:
            # Optimization: no causal mask needed for the last token attending to all previous
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=False)
        else:
            output = F.scaled_dot_product_attention(xq, xk, xv, is_causal=True)

        output = output.transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.wo(output)), current_kv
117
+
118
class SwiGLU(nn.Module):
    """Llama-style gated feed-forward network: w2(silu(w1(x)) * w3(x))."""

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        # Names w1/w3/w2 match the checkpoint's state-dict keys.
        self.w1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)  # Gate
        self.w3 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)  # Value
        self.w2 = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)  # Output
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        projected = self.w2(gate * value)
        return self.dropout(projected)
128
+
129
class Block(nn.Module):
    """One pre-norm transformer layer: attention, then SwiGLU feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        # Submodule names match the checkpoint's state-dict keys.
        self.attention_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.attention = CausalSelfAttention(config)
        self.ffn_norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.feed_forward = SwiGLU(config)

    def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor, past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None):
        # Attention sub-layer with residual; the attention also returns its
        # updated (k, v) pair for the caller.
        attn_output, updated_kv = self.attention(self.attention_norm(x), freqs_cis, past_kv)
        residual = x + attn_output
        # Feed-forward sub-layer with residual.
        result = residual + self.feed_forward(self.ffn_norm(residual))
        return result, updated_kv
142
+
143
class SmolLM(nn.Module):
    """Llama-style decoder-only language model (SmolLM-135M configuration).

    Token embedding -> n_layer pre-norm Blocks -> RMSNorm -> tied LM head.
    Submodule names (tok_embeddings/layers/norm/output) match the
    checkpoint's state-dict keys.
    """

    def __init__(self, config: SmolLMConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(config.vocab_size, config.n_embd)
        self.layers = nn.ModuleList([Block(config) for _ in range(config.n_layer)])
        self.norm = RMSNorm(config.n_embd, eps=config.rms_norm_eps)
        self.output = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight sharing between the embedding and the LM head.
        self.tok_embeddings.weight = self.output.weight

        # Precompute RoPE frequencies; 2x block_size gives headroom for
        # decoding past the training context length.
        self.freqs_cis = precompute_freqs_cis(config.n_embd // config.n_head, config.block_size * 2, config.rope_theta)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 style init: N(0, 0.02) weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None, past_kv=None):
        """Run the model on a batch of token ids.

        Args:
            idx: (B, T) int64 token ids.
            targets: optional (B, T) labels; when given, cross-entropy loss
                over the vocabulary is computed.
            past_kv: optional per-layer list of (k, v) caches of shape
                (B, n_head, T_cached, head_dim).

        Returns:
            (logits, loss): logits is (B, T, vocab_size); loss is None
            unless targets is provided.

        NOTE: the per-layer cache is updated internally but not yet part of
        the return value, so callers cannot chain incremental decode steps.
        """
        B, T = idx.shape
        x = self.tok_embeddings(idx)

        # Bug fix: RoPE positions must be offset by the cached length during
        # incremental decoding. Previously start_pos was always 0, which
        # rotated newly decoded tokens as if they sat at the sequence start.
        # Cached keys have shape (B, n_head, T_cached, head_dim).
        start_pos = past_kv[0][0].shape[2] if past_kv is not None else 0

        # Ensure freqs_cis is on the correct device
        if self.freqs_cis.device != x.device:
            self.freqs_cis = self.freqs_cis.to(x.device)

        # Select freqs_cis for the current absolute positions
        freqs_cis = self.freqs_cis[start_pos : start_pos + T]

        new_past_kv = []
        for i, layer in enumerate(self.layers):
            layer_past_kv = past_kv[i] if past_kv is not None else None
            x, layer_kv = layer(x, freqs_cis, past_kv=layer_past_kv)
            new_past_kv.append(layer_kv)

        x = self.norm(x)
        logits = self.output(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    def save_checkpoint(self, filepath):
        """Save model weights together with the config needed to rebuild it."""
        checkpoint = {
            'model_state_dict': self.state_dict(),
            'config': {
                'block_size': self.config.block_size,
                'vocab_size': self.config.vocab_size,
                'n_layer': self.config.n_layer,
                'n_head': self.config.n_head,
                'n_kv_head': self.config.n_kv_head,
                'n_embd': self.config.n_embd,
                'intermediate_size': self.config.intermediate_size,
                'rms_norm_eps': self.config.rms_norm_eps,
                'rope_theta': self.config.rope_theta,
                'dropout': self.config.dropout,
                'bias': self.config.bias
            }
        }
        torch.save(checkpoint, filepath)
        print(f"Model saved to {filepath}")

    @classmethod
    def load_checkpoint(cls, filepath, device='cpu'):
        """Rebuild a SmolLM from a checkpoint file.

        Supports both the dict format written by save_checkpoint and a bare
        state dict (in which case the default config is assumed).
        """
        checkpoint = torch.load(filepath, map_location=device)

        if isinstance(checkpoint, dict) and 'config' in checkpoint:
            # Checkpoint contains config and state dict
            config = SmolLMConfig(**checkpoint['config'])
            state_dict = checkpoint['model_state_dict']
        else:
            # Checkpoint is likely just the state dict
            print("Warning: Checkpoint does not contain config. Using default SmolLMConfig.")
            config = SmolLMConfig()
            state_dict = checkpoint

        model = cls(config)
        model.load_state_dict(state_dict)
        model.to(device)
        print(f"Model loaded from {filepath}")
        return model
model/smollm_135_checkpoint.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:691cbc7dafd54bde47470447eb33f27ab2f6ac5e77cd45b71942aceeac678b85
3
+ size 540826259
profile_app.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import torch
3
+ from inference import load_model, generate_text, get_device
4
+
5
def profile_app():
    """Measure model-load time and raw text-generation throughput."""
    print("Profiling app performance...")

    # 1. Measure Model Loading
    load_started = time.time()
    checkpoint_path = './model/smollm_135_checkpoint.pth'
    model, device = load_model(model_path=checkpoint_path)
    print(device)
    load_finished = time.time()
    print(f"Model loading time: {load_finished - load_started:.4f}s")
    print(f"Device: {device}")

    # 2. Measure Generation
    prompt = "The future of AI is"
    max_tokens = 50

    gen_started = time.time()
    output = generate_text(prompt, model, max_tokens=max_tokens, device=device)
    gen_finished = time.time()

    duration = gen_finished - gen_started
    tokens_per_sec = max_tokens / duration

    print(f"Generation time: {duration:.4f}s")
    print(f"Tokens per second: {tokens_per_sec:.2f}")
    print(f"Output: {output}")


if __name__ == "__main__":
    profile_app()
pyproject.toml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "smollm-135"
3
+ version = "0.1.0"
4
+ description = "SmolLM-135M (Llama-style) base model demo with checkpoint loading and text-generation utilities"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "gradio>=6.0.1",
9
+ "tiktoken>=0.12.0",
10
+ "torch>=2.9.1",
11
+ "transformers>=4.57.3",
12
+ ]
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=4.0.0
2
+ torch>=2.9.1
3
+ transformers>=4.30.0
4
+ tiktoken>=0.5.0
test_inference.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from inference import load_model, generate_text
2
+ import os
3
+
4
def test_inference():
    """Smoke-test: load the model (or an untrained fallback) and generate."""
    print("Testing inference...")

    checkpoint_path = './model/smollm_135_checkpoint.pth'
    if not os.path.exists(checkpoint_path):
        # No bundled weights: load_model will fall back to an untrained model.
        print(f"Warning: Checkpoint not found at {checkpoint_path}")
        checkpoint_path = None

    # Load model
    model, device = load_model(model_path=checkpoint_path)
    print(f"Model loaded on {device}")

    # Generate text
    prompt = "Hello, world"
    print(f"Generating text for prompt: '{prompt}'")
    generated = generate_text(prompt, model, max_tokens=10, device=device)
    print(f"Generated text: {generated}")

    print("Inference test passed.")


if __name__ == "__main__":
    test_inference()
test_kv_cache.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ from model import SmolLM, SmolLMConfig
4
+ from inference import load_model
5
+
6
def test_kv_performance():
    """Benchmark autoregressive decoding throughput.

    Bug fix: the original unpacked three values from model(...), but
    SmolLM.forward returns only (logits, loss) and does not expose the
    updated KV cache, so every call raised ValueError. Until the model
    returns its cache, this benchmark measures full-sequence (uncached)
    decoding speed.
    """
    print("Testing KV cache performance...")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}")

    # Load model (untrained is fine for performance test)
    config = SmolLMConfig()
    model = SmolLM(config).to(device)
    model.eval()

    prompt_len = 10
    gen_len = 50
    x = torch.randint(0, config.vocab_size, (1, prompt_len)).to(device)

    start_time = time.time()
    with torch.no_grad():
        for _ in range(gen_len):
            # Full-sequence forward each step (no cache available to reuse).
            logits, _ = model(x)
            # Dummy token selection (values irrelevant for a speed test)
            next_token = torch.tensor([[0]], device=device)
            x = torch.cat([x, next_token], dim=1)

    end_time = time.time()
    duration = end_time - start_time
    tokens_per_sec = gen_len / duration

    print(f"Generated {gen_len} tokens in {duration:.4f}s")
    print(f"Speed: {tokens_per_sec:.2f} tokens/sec")


if __name__ == "__main__":
    test_kv_performance()
test_model.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model import SmolLM, SmolLMConfig
3
+
4
def test_model():
    """Instantiate SmolLM and verify the forward-pass output shape.

    Bug fix: the original unpacked three values from model(idx), but
    SmolLM.forward returns only (logits, loss), so this script always
    raised ValueError before reaching its assertion.
    """
    print("Testing SmolLM model...")
    config = SmolLMConfig()
    model = SmolLM(config)
    print("Model instantiated successfully.")

    # Create dummy input
    idx = torch.randint(0, config.vocab_size, (1, 128))
    print(f"Input shape: {idx.shape}")

    # Forward pass (no targets supplied -> loss must be None)
    logits, loss = model(idx)
    print(f"Logits shape: {logits.shape}")

    assert logits.shape == (1, 128, config.vocab_size)
    assert loss is None
    print("Forward pass successful.")


if __name__ == "__main__":
    test_model()
test_tiktoken.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import tiktoken
3
+
4
def test_tiktoken():
    """Report how long tokenizer construction and a small encode take."""
    load_started = time.time()
    enc = tiktoken.get_encoding("gpt2")
    load_finished = time.time()
    print(f"tiktoken load time: {load_finished - load_started:.4f}s")

    encode_started = time.time()
    enc.encode("Hello world")
    encode_finished = time.time()
    print(f"encoding time: {encode_finished - encode_started:.4f}s")


if __name__ == "__main__":
    test_tiktoken()