pragadeeshv23 committed
Commit ffc0c0c · verified · 1 Parent(s): 5ad5c9b

Upload folder using huggingface_hub

Files changed (13)
  1. .gitignore +6 -0
  2. LICENSE +661 -0
  3. README.md +279 -0
  4. data/meta.txt +6 -0
  5. deepspeed.md +199 -0
  6. ds_config.active.json +58 -0
  7. ds_config.json +65 -0
  8. main.py +677 -0
  9. main_deepspeed.py +696 -0
  10. prepare_data.py +153 -0
  11. push_to_hf.py +95 -0
  12. run.py +590 -0
  13. train_deepspeed.sh +160 -0
.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.bin
+ *.pt
+ nvme_offload/
+ __pycache__/
+ .ruff_cache/
+ checkpoints/
LICENSE ADDED
@@ -0,0 +1,661 @@
+ GNU AFFERO GENERAL PUBLIC LICENSE
+ Version 3, 19 November 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU Affero General Public License is a free, copyleft license for
+ software and other kinds of works, specifically designed to ensure
+ cooperation with the community in the case of network server software.
+
+ The licenses for most software and other practical works are designed
+ to take away your freedom to share and change the works. By contrast,
+ our General Public Licenses are intended to guarantee your freedom to
+ share and change all versions of a program--to make sure it remains free
+ software for all its users.
+
+ When we speak of free software, we are referring to freedom, not
+ price. Our General Public Licenses are designed to make sure that you
+ have the freedom to distribute copies of free software (and charge for
+ them if you wish), that you receive source code or can get it if you
+ want it, that you can change the software or use pieces of it in new
+ free programs, and that you know you can do these things.
+
+ Developers that use our General Public Licenses protect your rights
+ with two steps: (1) assert copyright on the software, and (2) offer
+ you this License which gives you legal permission to copy, distribute
+ and/or modify the software.
+
+ A secondary benefit of defending all users' freedom is that
+ improvements made in alternate versions of the program, if they
+ receive widespread use, become available for other developers to
+ incorporate. Many developers of free software are heartened and
+ encouraged by the resulting cooperation. However, in the case of
+ software used on network servers, this result may fail to come about.
+ The GNU General Public License permits making a modified version and
+ letting the public access it on a server without ever releasing its
+ source code to the public.
+
+ The GNU Affero General Public License is designed specifically to
+ ensure that, in such cases, the modified source code becomes available
+ to the community. It requires the operator of a network server to
+ provide the source code of the modified version running there to the
+ users of that server. Therefore, public use of a modified version, on
+ a publicly accessible server, gives the public access to the source
+ code of the modified version.
+
+ An older license, called the Affero General Public License and
+ published by Affero, was designed to accomplish similar goals. This is
+ a different license, not a version of the Affero GPL, but Affero has
+ released a new version of the Affero GPL which permits relicensing under
+ this license.
+
+ The precise terms and conditions for copying, distribution and
+ modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU Affero General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+ works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+ License. Each licensee is addressed as "you". "Licensees" and
+ "recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+ in a fashion requiring copyright permission, other than the making of an
+ exact copy. The resulting work is called a "modified version" of the
+ earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+ on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+ permission, would make you directly or secondarily liable for
+ infringement under applicable copyright law, except executing it on a
+ computer or modifying a private copy. Propagation includes copying,
+ distribution (with or without modification), making available to the
+ public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+ parties to make or receive copies. Mere interaction with a user through
+ a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+ to the extent that it includes a convenient and prominently visible
+ feature that (1) displays an appropriate copyright notice, and (2)
+ tells the user that there is no warranty for the work (except to the
+ extent that warranties are provided), that licensees may convey the
+ work under this License, and how to view a copy of this License. If
+ the interface presents a list of user commands or options, such as a
+ menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+ for making modifications to it. "Object code" means any non-source
+ form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+ standard defined by a recognized standards body, or, in the case of
+ interfaces specified for a particular programming language, one that
+ is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+ than the work as a whole, that (a) is included in the normal form of
+ packaging a Major Component, but which is not part of that Major
+ Component, and (b) serves only to enable use of the work with that
+ Major Component, or to implement a Standard Interface for which an
+ implementation is available to the public in source code form. A
+ "Major Component", in this context, means a major essential component
+ (kernel, window system, and so on) of the specific operating system
+ (if any) on which the executable work runs, or a compiler used to
+ produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+ the source code needed to generate, install, and (for an executable
+ work) run the object code and to modify the work, including scripts to
+ control those activities. However, it does not include the work's
+ System Libraries, or general-purpose tools or generally available free
+ programs which are used unmodified in performing those activities but
+ which are not part of the work. For example, Corresponding Source
+ includes interface definition files associated with source files for
+ the work, and the source code for shared libraries and dynamically
+ linked subprograms that the work is specifically designed to require,
+ such as by intimate data communication or control flow between those
+ subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+ can regenerate automatically from other parts of the Corresponding
+ Source.
+
+ The Corresponding Source for a work in source code form is that
+ same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+ copyright on the Program, and are irrevocable provided the stated
+ conditions are met. This License explicitly affirms your unlimited
+ permission to run the unmodified Program. The output from running a
+ covered work is covered by this License only if the output, given its
+ content, constitutes a covered work. This License acknowledges your
+ rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+ convey, without conditions so long as your license otherwise remains
+ in force. You may convey covered works to others for the sole purpose
+ of having them make modifications exclusively for you, or provide you
+ with facilities for running those works, provided that you comply with
+ the terms of this License in conveying all material for which you do
+ not control copyright. Those thus making or running the covered works
+ for you must do so exclusively on your behalf, under your direction
+ and control, on terms that prohibit them from making any copies of
+ your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+ the conditions stated below. Sublicensing is not allowed; section 10
+ makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+ measure under any applicable law fulfilling obligations under article
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
+ similar laws prohibiting or restricting circumvention of such
+ measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+ circumvention of technological measures to the extent such circumvention
+ is effected by exercising rights under this License with respect to
+ the covered work, and you disclaim any intention to limit operation or
+ modification of the work as a means of enforcing, against the work's
+ users, your or third parties' legal rights to forbid circumvention of
+ technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+ receive it, in any medium, provided that you conspicuously and
+ appropriately publish on each copy an appropriate copyright notice;
+ keep intact all notices stating that this License and any
+ non-permissive terms added in accord with section 7 apply to the code;
+ keep intact all notices of the absence of any warranty; and give all
+ recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+ and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+ produce it from the Program, in the form of source code under the
+ terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+ works, which are not by their nature extensions of the covered work,
+ and which are not combined with it such as to form a larger program,
+ in or on a volume of a storage or distribution medium, is called an
+ "aggregate" if the compilation and its resulting copyright are not
+ used to limit the access or legal rights of the compilation's users
+ beyond what the individual works permit. Inclusion of a covered work
+ in an aggregate does not cause this License to apply to the other
+ parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+ of sections 4 and 5, provided that you also convey the
+ machine-readable Corresponding Source under the terms of this License,
+ in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+ from the Corresponding Source as a System Library, need not be
+ included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+ tangible personal property which is normally used for personal, family,
+ or household purposes, or (2) anything designed or sold for incorporation
+ into a dwelling. In determining whether a product is a consumer product,
+ doubtful cases shall be resolved in favor of coverage. For a particular
+ product received by a particular user, "normally used" refers to a
+ typical or common use of that class of product, regardless of the status
+ of the particular user or of the way in which the particular user
+ actually uses, or expects or is expected to use, the product. A product
+ is a consumer product regardless of whether the product has substantial
+ commercial, industrial or non-consumer uses, unless such uses represent
+ the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+ procedures, authorization keys, or other information required to install
+ and execute modified versions of a covered work in that User Product from
+ a modified version of its Corresponding Source. The information must
+ suffice to ensure that the continued functioning of the modified object
+ code is in no case prevented or interfered with solely because
+ modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+ specifically for use in, a User Product, and the conveying occurs as
+ part of a transaction in which the right of possession and use of the
+ User Product is transferred to the recipient in perpetuity or for a
+ fixed term (regardless of how the transaction is characterized), the
+ Corresponding Source conveyed under this section must be accompanied
+ by the Installation Information. But this requirement does not apply
+ if neither you nor any third party retains the ability to install
+ modified object code on the User Product (for example, the work has
+ been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+ requirement to continue to provide support service, warranty, or updates
+ for a work that has been modified or installed by the recipient, or for
+ the User Product in which it has been modified or installed. Access to a
+ network may be denied when the modification itself materially and
+ adversely affects the operation of the network or violates the rules and
+ protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+ in accord with this section must be in a format that is publicly
+ documented (and with an implementation available to the public in
+ source code form), and must require no special password or key for
+ unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+ License by making exceptions from one or more of its conditions.
+ Additional permissions that are applicable to the entire Program shall
+ be treated as though they were included in this License, to the extent
+ that they are valid under applicable law. If additional permissions
+ apply only to part of the Program, that part may be used separately
+ under those permissions, but the entire Program remains governed by
+ this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+ remove any additional permissions from that copy, or from any part of
+ it. (Additional permissions may be written to require their own
+ removal in certain cases when you modify the work.) You may place
+ additional permissions on material, added by you to a covered work,
+ for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+ add to a covered work, you may (if authorized by the copyright holders of
+ that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+ restrictions" within the meaning of section 10. If the Program as you
+ received it, or any part of it, contains a notice stating that it is
+ governed by this License along with a term that is a further
+ restriction, you may remove that term. If a license document contains
+ a further restriction but permits relicensing or conveying under this
+ License, you may add to a covered work material governed by the terms
+ of that license document, provided that the further restriction does
+ not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+ must place, in the relevant source files, a statement of the
+ additional terms that apply to those files, or a notice indicating
+ where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+ form of a separately written license, or stated as exceptions;
+ the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+ provided under this License. Any attempt otherwise to propagate or
+ modify it is void, and will automatically terminate your rights under
+ this License (including any patent licenses granted under the third
+ paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+ license from a particular copyright holder is reinstated (a)
+ provisionally, unless and until the copyright holder explicitly and
+ finally terminates your license, and (b) permanently, if the copyright
+ holder fails to notify you of the violation by some reasonable means
+ prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+ reinstated permanently if the copyright holder notifies you of the
+ violation by some reasonable means, this is the first time you have
+ received notice of violation of this License (for any work) from that
+ copyright holder, and you cure the violation prior to 30 days after
+ your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+ licenses of parties who have received copies or rights from you under
+ this License. If your rights have been terminated and not permanently
+ reinstated, you do not qualify to receive new licenses for the same
+ material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+ run a copy of the Program. Ancillary propagation of a covered work
+ occurring solely as a consequence of using peer-to-peer transmission
+ to receive a copy likewise does not require acceptance. However,
+ nothing other than this License grants you permission to propagate or
+ modify any covered work. These actions infringe copyright if you do
+ not accept this License. Therefore, by modifying or propagating a
+ covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+ receives a license from the original licensors, to run, modify and
+ propagate that work, subject to this License. You are not responsible
+ for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+ organization, or substantially all assets of one, or subdividing an
+ organization, or merging organizations. If propagation of a covered
+ work results from an entity transaction, each party to that
+ transaction who receives a copy of the work also receives whatever
+ licenses to the work the party's predecessor in interest had or could
+ give under the previous paragraph, plus a right to possession of the
+ Corresponding Source of the work from the predecessor in interest, if
+ the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+ rights granted or affirmed under this License. For example, you may
+ not impose a license fee, royalty, or other charge for exercise of
+ rights granted under this License, and you may not initiate litigation
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
+ any patent claim is infringed by making, using, selling, offering for
+ sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+ License of the Program or a work on which the Program is based. The
+ work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+ owned or controlled by the contributor, whether already acquired or
+ hereafter acquired, that would be infringed by some manner, permitted
+ by this License, of making, using, or selling its contributor version,
+ but do not include claims that would be infringed only as a
+ consequence of further modification of the contributor version. For
+ purposes of this definition, "control" includes the right to grant
+ patent sublicenses in a manner consistent with the requirements of
+ this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+ patent license under the contributor's essential patent claims, to
+ make, use, sell, offer for sale, import and otherwise run, modify and
+ propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+ agreement or commitment, however denominated, not to enforce a patent
+ (such as an express permission to practice a patent or covenant not to
+ sue for patent infringement). To "grant" such a patent license to a
+ party means to make such an agreement or commitment not to enforce a
+ patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+ and the Corresponding Source of the work is not available for anyone
+ to copy, free of charge and under the terms of this License, through a
+ publicly available network server or other readily accessible means,
+ then you must either (1) cause the Corresponding Source to be so
+ available, or (2) arrange to deprive yourself of the benefit of the
+ patent license for this particular work, or (3) arrange, in a manner
+ consistent with the requirements of this License, to extend the patent
+ license to downstream recipients. "Knowingly relying" means you have
+ actual knowledge that, but for the patent license, your conveying the
+ covered work in a country, or your recipient's use of the covered work
+ in a country, would infringe one or more identifiable patents in that
+ country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+ arrangement, you convey, or propagate by procuring conveyance of, a
+ covered work, and grant a patent license to some of the parties
+ receiving the covered work authorizing them to use, propagate, modify
+ or convey a specific copy of the covered work, then the patent license
+ you grant is automatically extended to all recipients of the covered
+ work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+ the scope of its coverage, prohibits the exercise of, or is
+ conditioned on the non-exercise of one or more of the rights that are
+ specifically granted under this License. You may not convey a covered
+ work if you are a party to an arrangement with a third party that is
+ in the business of distributing software, under which you make payment
+ to the third party based on the extent of your activity of conveying
+ the work, and under which the third party grants, to any of the
+ parties who would receive the covered work from you, a discriminatory
+ patent license (a) in connection with copies of the covered work
+ conveyed by you (or copies made from those copies), or (b) primarily
+ for and in connection with specific products or compilations that
+ contain the covered work, unless you entered into that arrangement,
+ or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+ any implied license or other defenses to infringement that may
+ otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+ otherwise) that contradict the conditions of this License, they do not
+ excuse you from the conditions of this License. If you cannot convey a
+ covered work so as to satisfy simultaneously your obligations under this
+ License and any other pertinent obligations, then as a consequence you may
+ not convey it at all. For example, if you agree to terms that obligate you
+ to collect a royalty for further conveying from those to whom you convey
+ the Program, the only way you could satisfy both those terms and this
+ License would be to refrain entirely from conveying the Program.
+
+ 13. Remote Network Interaction; Use with the GNU General Public License.
+
+ Notwithstanding any other provision of this License, if you modify the
+ Program, your modified version must prominently offer all users
+ interacting with it remotely through a computer network (if your version
+ supports such interaction) an opportunity to receive the Corresponding
+ Source of your version by providing access to the Corresponding Source
+ from a network server at no charge, through some standard or customary
+ means of facilitating copying of software. This Corresponding Source
+ shall include the Corresponding Source for any work covered by version 3
+ of the GNU General Public License that is incorporated pursuant to the
+ following paragraph.
+
+ Notwithstanding any other provision of this License, you have
+ permission to link or combine any covered work with a work licensed
+ under version 3 of the GNU General Public License into a single
+ combined work, and to convey the resulting work. The terms of this
+ License will continue to apply to the part which is the covered work,
+ but the work with which it is combined will remain governed by version
+ 3 of the GNU General Public License.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
564
+ the GNU Affero General Public License from time to time. Such new versions
565
+ will be similar in spirit to the present version, but may differ in detail to
566
+ address new problems or concerns.
567
+
568
+ Each version is given a distinguishing version number. If the
569
+ Program specifies that a certain numbered version of the GNU Affero General
570
+ Public License "or any later version" applies to it, you have the
571
+ option of following the terms and conditions either of that numbered
572
+ version or of any later version published by the Free Software
573
+ Foundation. If the Program does not specify a version number of the
574
+ GNU Affero General Public License, you may choose any version ever published
575
+ by the Free Software Foundation.
576
+
577
+ If the Program specifies that a proxy can decide which future
578
+ versions of the GNU Affero General Public License can be used, that proxy's
579
+ public statement of acceptance of a version permanently authorizes you
580
+ to choose that version for the Program.
581
+
582
+ Later license versions may give you additional or different
583
+ permissions. However, no additional obligations are imposed on any
584
+ author or copyright holder as a result of your choosing to follow a
585
+ later version.
586
+
587
+ 15. Disclaimer of Warranty.
588
+
589
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
590
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
591
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
592
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
593
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
594
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
595
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
596
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
597
+
598
+ 16. Limitation of Liability.
599
+
600
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
601
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
602
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
603
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
604
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
605
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
606
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
607
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
608
+ SUCH DAMAGES.
609
+
610
+ 17. Interpretation of Sections 15 and 16.
611
+
612
+ If the disclaimer of warranty and limitation of liability provided
613
+ above cannot be given local legal effect according to their terms,
614
+ reviewing courts shall apply local law that most closely approximates
615
+ an absolute waiver of all civil liability in connection with the
616
+ Program, unless a warranty or assumption of liability accompanies a
617
+ copy of the Program in return for a fee.
618
+
619
+ END OF TERMS AND CONDITIONS
620
+
621
+ How to Apply These Terms to Your New Programs
622
+
623
+ If you develop a new program, and you want it to be of the greatest
624
+ possible use to the public, the best way to achieve this is to make it
625
+ free software which everyone can redistribute and change under these terms.
626
+
627
+ To do so, attach the following notices to the program. It is safest
628
+ to attach them to the start of each source file to most effectively
629
+ state the exclusion of warranty; and each file should have at least
630
+ the "copyright" line and a pointer to where the full notice is found.
631
+
632
+ <one line to give the program's name and a brief idea of what it does.>
633
+ Copyright (C) <year> <name of author>
634
+
635
+ This program is free software: you can redistribute it and/or modify
636
+ it under the terms of the GNU Affero General Public License as published
637
+ by the Free Software Foundation, either version 3 of the License, or
638
+ (at your option) any later version.
639
+
640
+ This program is distributed in the hope that it will be useful,
641
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
642
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
643
+ GNU Affero General Public License for more details.
644
+
645
+ You should have received a copy of the GNU Affero General Public License
646
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
647
+
648
+ Also add information on how to contact you by electronic and paper mail.
649
+
650
+ If your software can interact with users remotely through a computer
651
+ network, you should also make sure that it provides a way for users to
652
+ get its source. For example, if your program is a web application, its
653
+ interface could display a "Source" link that leads users to an archive
654
+ of the code. There are many ways you could offer source, and different
655
+ solutions will be better for different programs; see section 13 for the
656
+ specific requirements.
657
+
658
+ You should also get your employer (if you work as a programmer) or school,
659
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
660
+ For more information on this, and how to apply and follow the GNU AGPL, see
661
+ <https://www.gnu.org/licenses/>.
README.md ADDED
@@ -0,0 +1,279 @@
+ # Tiny-GPT: 0.5B MoE Language Model
+
+ A clean, efficient implementation of a **Mixture-of-Experts GPT** that fits on modest GPUs (4 GB VRAM) while training on large datasets.
+
+ ## 🎯 Main Goal
+ **Generate proper English text** - not gibberish!
+
+ ## 📊 Quick Stats
+
+ | Metric | Value |
+ |--------|-------|
+ | **Model Size** | 0.5B parameters (520M) |
+ | **Active per Token** | 180M parameters (via MoE routing) |
+ | **Architecture** | 12 Transformer layers, 8 experts/layer, top-2 routing |
+ | **Training Data** | WikiText-103 (103M tokens, ~500MB) |
+ | **GPU Memory** | 0.97 GiB (model weights only) |
+ | **Training Time** | ~10-20 hours on RTX 2050 (10k steps) |
+ | **Tokenizer** | GPT-2 BPE (50,257 vocab via tiktoken) |
+
+ ## 🚀 Quick Start
+
+ ### 1. Prepare Dataset
+ ```bash
+ python prepare_data.py
+ ```
+ Downloads WikiText-103 and tokenizes it into memory-mapped binary files (~500MB).
+ This is a one-time operation that takes **10-30 minutes**.
+
+ ### 2. Train
+ ```bash
+ python main.py
+ ```
+ Starts training from scratch with:
+ - **Learning rate**: 1.5e-4 (lowered for stability)
+ - **Warmup**: 500 steps (better convergence)
+ - **Total steps**: 10,000 (more thorough training)
+ - **Batch size**: 16 effective (micro-batch 2 × 8 gradient-accumulation steps)
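The effective batch of 16 comes from gradient accumulation. A minimal sketch of the pattern (a toy `nn.Linear` stands in for the real model; `main.py` follows the same shape with `MICRO_BATCH = 2`, `GRAD_ACCUM = 8`):

```python
import torch
import torch.nn as nn

MICRO_BATCH, GRAD_ACCUM = 2, 8   # effective batch = 2 * 8 = 16

model = nn.Linear(10, 1)         # toy stand-in for the real model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.5e-4)

optimizer.zero_grad()
for micro_step in range(GRAD_ACCUM):
    x = torch.randn(MICRO_BATCH, 10)
    y = torch.randn(MICRO_BATCH, 1)
    loss = ((model(x) - y) ** 2).mean() / GRAD_ACCUM  # scale so grads average
    loss.backward()              # gradients accumulate across micro-steps
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()                 # one optimizer step per 16 samples
```

Dividing each micro-loss by `GRAD_ACCUM` makes the accumulated gradient equal the average over the full 16-sample batch.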
+
+ Training progress is shown in real time via a `rich` progress bar.
+
+ ### 3. Generate Text
+ ```bash
+ python run.py
+ ```
+
+ ## 🤗 Use Hugging Face Hub (instead of local/GitHub checkpoints)
+
+ ### 1. Upload checkpoints to HF Hub
+ ```bash
+ pip install huggingface_hub
+ export HF_TOKEN=your_hf_token
+ python push_to_hf.py --repo-id yourname/Tiny-GPT
+ ```
+
+ This uploads:
+ - `checkpoints/best.pt` → `best.pt`
+ - `checkpoints/latest.pt` → `latest.pt` (if present)
+
+ ### 2. Run inference directly from HF Hub
+ ```bash
+ python run.py --hf-repo yourname/Tiny-GPT --prompt "The future of AI is"
+ ```
+
+ Optional flags:
+ - `--hf-filename best.pt`
+ - `--hf-revision main`
+ - `--hf-token <token>` (or use `HF_TOKEN` env var)
+
+ ## 📁 File Structure
+
+ ```
+ Tiny-GPT/
+ ├── main.py              # Training script
+ ├── run.py               # Inference script (NEW)
+ ├── prepare_data.py      # Dataset preparation
+ ├── mini_gpt.py          # Deprecated v1 (reference only)
+ ├── reset_training.sh    # Clean old checkpoints
+ ├── wait_for_dataset.sh  # Monitor data preparation
+
+ ├── data/
+ │   ├── train.bin        # ~1.8M examples → ~80M tokens
+ │   ├── val.bin          # ~3.7k examples → ~1.7M tokens
+ │   ├── test.bin         # ~4.3k examples → ~2.0M tokens
+ │   └── meta.txt         # Metadata
+
+ └── checkpoints/
+     ├── latest.pt        # Most recent checkpoint
+     └── best.pt          # Best validation loss checkpoint
+ ```
+
+ ## 🔧 Configuration
+
+ All hyperparameters are defined in `main.py`:
+
+ ```python
+ BLOCK_SIZE = 128     # Context window
+ EMBED_DIM = 768      # Model width
+ NUM_LAYERS = 12      # Transformer blocks
+ NUM_EXPERTS = 8      # Experts per MoE layer
+ TOP_K = 2            # Experts used per token
+ LR = 1.5e-4          # Learning rate (adjusted)
+ WARMUP_STEPS = 500   # Warmup schedule
+ MAX_ITERS = 10000    # Total training steps
+ GRAD_CLIP = 1.0      # Gradient clipping
+ ```
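How `NUM_EXPERTS` and `TOP_K` interact in the router can be sketched in plain NumPy (illustrative only — the real routing lives in `MoELayer` inside `main.py`):

```python
import numpy as np

np.random.seed(0)
NUM_EXPERTS, TOP_K = 8, 2

def route(logits):
    """Softmax over experts, keep the TOP_K best, renormalize their weights."""
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)            # softmax over 8 experts
    top_i = np.argsort(probs, axis=-1)[:, -TOP_K:]   # indices of top-2 experts
    top_w = np.take_along_axis(probs, top_i, axis=-1)
    top_w /= top_w.sum(-1, keepdims=True)            # per-token weights sum to 1
    return top_i, top_w

logits = np.random.randn(4, NUM_EXPERTS)             # router logits for 4 tokens
idx, w = route(logits)
print(idx.shape, w.shape)  # (4, 2) (4, 2)
```

Each token then runs through only its two selected expert FFNs, and their outputs are mixed with the renormalized weights.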
+
+ ## 📈 Expected Training Progress
+
+ **With fixed hyperparameters (new):**
+ - **Step 1**: Loss ~8.0
+ - **Step 500**: Loss ~6.5-7.0
+ - **Step 2500**: Loss ~4.5-5.0
+ - **Step 5000**: Loss ~3.8-4.2
+ - **Step 10000**: Loss ~3.5-3.8
+
+ **Quality indicator:** Model starts generating coherent English by step 2000+
+
+ ## 💡 What Changed?
+
+ ### Before (Broken)
+ ```
+ Learning Rate: 3e-4 (too high)
+ Warmup: 200 steps (insufficient)
+ Auto-resume: Enabled (got stuck in NaN)
+ Training Loss: DIVERGES TO NAN
+ Output: "hi defencesaternal Thirty shows allowanceBad Leh..." ❌
+ ```
+
+ ### After (Fixed)
+ ```
+ Learning Rate: 1.5e-4 (stable)
+ Warmup: 500 steps (better convergence)
+ Auto-resume: Disabled (start fresh)
+ Training Loss: SMOOTH CONVERGENCE
+ Output: "The history of the universe began with the Big Bang..." ✓
+ ```
+
+ ## 🧠 Model Architecture
+
+ ```
+ Input Tokens
+
+ Embedding + Positional Encoding (768-dim)
+
+ [x12 Transformer Blocks]
+ ├─ Multi-Head Attention (12 heads)
+ │  └─ Output: 768-dim
+ └─ Mixture-of-Experts Layer
+    ├─ 8 Expert FFNs (768→3072→768)
+    ├─ Router: Selects top-2 experts per token
+    └─ Load-balancing auxiliary loss
+
+ Layer Norm
+
+ Output Linear → Logits (50,257)
+
+ Cross-Entropy Loss
+ ```
+
+ **Memory Trick:** The CPUOffloadAdamW optimizer keeps fp32 master weights + momentum/variance in CPU RAM to save GPU VRAM:
+ - GPU: fp16 model weights + fp16 gradients (~1 GB)
+ - CPU: fp32 master weights + fp32 m/v (~4 GB)
+
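The split can be illustrated with a single-tensor Adam step — a simplified sketch of the idea (bias correction omitted), not the actual `CPUOffloadAdamW` class:

```python
import torch

# fp16 "compute" copy (would live on the GPU) and fp32 master copy (CPU RAM)
w16 = torch.randn(4, dtype=torch.float16)
master = w16.float().clone()            # fp32 master weights
m = torch.zeros_like(master)            # Adam first moment (fp32, CPU)
v = torch.zeros_like(master)            # Adam second moment (fp32, CPU)

lr, b1, b2, eps = 1.5e-4, 0.9, 0.999, 1e-8
grad16 = torch.randn(4, dtype=torch.float16)  # gradient computed in fp16

g = grad16.float()                      # upcast gradient for the update
m = b1 * m + (1 - b1) * g
v = b2 * v + (1 - b2) * g * g
master -= lr * m / (v.sqrt() + eps)     # precise fp32 update on CPU
w16.copy_(master.half())                # copy back down to fp16 for compute
```

The GPU only ever holds the fp16 tensors; the three fp32 state tensors (master, `m`, `v`) account for the ~4 GB of CPU RAM.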
+ ## 🎮 Using `run.py`
+
+ ### Interactive Mode (Default)
+ ```bash
+ python run.py
+ ```
+ Type prompts and press Enter. Commands:
+ - `/temp 0.8` - Set temperature (higher = more random)
+ - `/len 150` - Set max tokens
+ - `/topk 40` - Enable top-k sampling
+ - `/topp 0.9` - Set nucleus sampling threshold
+ - `quit` - Exit
+
+ ### Single Prompt
+ ```bash
+ python run.py --prompt "The future of AI is"
+ ```
+
+ ### Batch from File
+ ```bash
+ python run.py --prompts prompts.txt  # One prompt per line
+ ```
+
+ ### Custom Checkpoint
+ ```bash
+ python run.py --checkpoint checkpoints/best.pt
+ ```
+
+ ### Full Options
+ ```bash
+ python run.py --help
+ ```
+
+ ## 🔍 Monitoring Training
+
+ The training loop shows:
+ ```
+ Step 5000 │ Train 4.23 │ Val 4.45 │ LR 0.000097
+ ```
+
+ **Healthy indicators:**
+ - ✓ Train loss smoothly decreases
+ - ✓ Val loss follows the train-loss trend
+ - ✓ No NaN values
+ - ✓ Learning rate follows its schedule
+ - ✓ No gradient clipping (or occasional, < 10% of steps)
+
+ **Red flags:**
+ - ❌ Loss jumps/oscillates wildly
+ - ❌ NaN values appear
+ - ❌ Val loss stops improving (need more data or different hyperparameters)
+ - ❌ Constant gradient clipping (reduce LR)
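To interpret these loss values, note that cross-entropy loss maps to perplexity via `exp(loss)`:

```python
import math

for loss in (8.0, 4.23, 3.5):
    print(f"loss {loss:.2f} -> perplexity {math.exp(loss):,.1f}")
# loss 3.50 -> perplexity 33.1
```

A healthy run drives perplexity from ~3,000 at step 1 down into the 30-40 range by step 10,000.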
+
+ ## 📊 Checkpointing
+
+ Saved automatically every 500 steps:
+ - **`latest.pt`**: Most recent checkpoint (always usable)
+ - **`best.pt`**: Best validation loss (for inference)
+
+ Load in Python:
+ ```python
+ checkpoint = torch.load("checkpoints/best.pt", map_location="cpu")
+ model.load_state_dict(checkpoint["model"])
+ optimizer.load_state_dict(checkpoint["optimizer"])
+ step = checkpoint["step"]
+ ```
+
+ ## 🛑 Troubleshooting
+
+ ### Dataset not preparing
+ ```bash
+ # Monitor progress
+ ./wait_for_dataset.sh
+
+ # Check manually
+ ls -lh data/
+ ```
+
+ ### Training produces NaN
+ ✓ **Fixed**: Lowered learning rate to 1.5e-4 and increased warmup
+
+ ### Model outputs gibberish
+ ✓ **Fixed**: Trained on larger dataset (WikiText-103 vs WikiText-2)
+
+ ### Out of memory
+ - Reduce `MICRO_BATCH` to 1 (slower but less VRAM)
+ - Reduce `BLOCK_SIZE` to 64
+ - Keep gradient checkpointing enabled (recomputes activations to save VRAM)
+
+ ### GPU not detected
+ ```python
+ # Check in Python
+ import torch
+ print(torch.cuda.is_available())      # Should be True
+ print(torch.cuda.get_device_name(0))  # GPU name
+ ```
+
+ ## 📚 References
+
+ - **Mixture of Experts**: [Switch Transformers](https://arxiv.org/abs/2101.03961)
+ - **GPT Architecture**: [Language Models are Unsupervised Multitask Learners](https://d4mucfpkswtq.cloudfront.net/better-language-models/language-models.pdf)
+ - **Memory Optimization**: [Reducing Activation Recomputation in Large Transformer Models](https://arxiv.org/abs/2205.05198)
+ - **Tokenization**: [tiktoken](https://github.com/openai/tiktoken)
+
+ ## 📝 License
+
+ AGPL-3.0 - See LICENSE file
+
+ ---
+
+ **Status**: ✅ Ready for training!
+
+ Next steps:
+ 1. ⏳ Wait for dataset preparation (`prepare_data.py`)
+ 2. ▶️ Run training (`python main.py`)
+ 3. 🎉 Generate text (`python run.py`)
data/meta.txt ADDED
@@ -0,0 +1,6 @@
+ dataset : fineweb
+ tokeniser : gpt2
+ vocab_size : 50257
+ train_tokens : 9753076758
+ val_tokens : 100388902
+ test_tokens : 100523684
deepspeed.md ADDED
@@ -0,0 +1,199 @@
+
+ Alright, let’s get practical. DeepSpeed is basically a **booster pack for training large models with PyTorch**. It handles memory, distributes work across GPUs, and keeps things from melting your hardware. 🧠⚡
+
+ We’ll walk through the **minimal working setup** so the idea clicks.
+
+ ---
+
+ ## 1. Install DeepSpeed
+
+ First step: install it.
+
+ ```bash
+ pip install deepspeed
+ ```
+
+ If CUDA and PyTorch are already installed, DeepSpeed builds its CUDA kernels automatically (JIT-compiled the first time an op is used).
+
+ Check it works:
+
+ ```bash
+ deepspeed --version
+ ```
+
+ ---
+
+ ## 2. Simple PyTorch Training Script
+
+ A normal PyTorch training script might look like this:
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ model = nn.Linear(10, 1)  # toy model: 10 features -> 1 output
+
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+ for step in range(100):
+     x = torch.randn(32, 10)  # fake batch of 32 samples
+     y = torch.randn(32, 1)
+
+     output = model(x)
+     loss = ((output - y) ** 2).mean()  # mean squared error
+
+     optimizer.zero_grad()
+     loss.backward()
+     optimizer.step()
+ ```
+
+ DeepSpeed wraps this training loop so it can **handle distributed training and memory optimization**.
+
+ ---
+
+ ## 3. Add DeepSpeed to the Script
+
+ Modify the script like this:
+
+ ```python
+ import torch
+ import torch.nn as nn
+ import deepspeed
+
+ model = nn.Linear(10, 1)
+
+ parameters = filter(lambda p: p.requires_grad, model.parameters())
+
+ model_engine, optimizer, _, _ = deepspeed.initialize(
+     model=model,
+     model_parameters=parameters,
+     config="ds_config.json"
+ )
+
+ for step in range(100):
+     x = torch.randn(32, 10).to(model_engine.local_rank)
+     y = torch.randn(32, 1).to(model_engine.local_rank)
+
+     output = model_engine(x)
+     loss = ((output - y) ** 2).mean()
+
+     model_engine.backward(loss)
+     model_engine.step()
+ ```
+
+ Notice the difference:
+
+ Instead of
+ `loss.backward()`
+ you use
+
+ ```python
+ model_engine.backward(loss)
+ ```
+
+ DeepSpeed now manages **gradient sync, memory, and distributed GPUs**.
+
+ ---
+
+ ## 4. Create the DeepSpeed Config
+
+ DeepSpeed uses a JSON config file.
+
+ `ds_config.json`
+
+ ```json
+ {
+   "train_batch_size": 32,
+   "fp16": {
+     "enabled": true
+   },
+   "zero_optimization": {
+     "stage": 2
+   }
+ }
+ ```
+
+ Key parts:
+
+ **fp16**
+ Stores weights, activations, and gradients in half precision to save memory.
+
+ **zero_optimization**
+
+ * stage 1 → optimizer states partitioned across GPUs
+ * stage 2 → optimizer states + gradients partitioned
+ * stage 3 → optimizer states + gradients + parameters all partitioned
+
+ Stage 3 is the **big guns for massive models**.
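Back-of-the-envelope math shows what each stage saves, using the ZeRO paper's accounting of ~16 bytes of model state per parameter (fp16 weights + fp16 grads + fp32 master/momentum/variance). A rough sketch, not measured numbers:

```python
def model_state_gib(params, gpus, stage):
    """Approximate per-GPU bytes for weights + grads + optimizer states under ZeRO."""
    weights, grads, optim = 2 * params, 2 * params, 12 * params  # bytes
    if stage >= 1:
        optim /= gpus        # stage 1: shard optimizer states
    if stage >= 2:
        grads /= gpus        # stage 2: also shard gradients
    if stage >= 3:
        weights /= gpus      # stage 3: also shard parameters
    return (weights + grads + optim) / 2**30

for stage in (0, 1, 2, 3):
    print(f"ZeRO-{stage}: {model_state_gib(1.5e9, 4, stage):.1f} GiB/GPU")
```

For a 1.5B-parameter model on 4 GPUs this drops from ~22 GiB to ~6 GiB per GPU, before activations are even counted.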
+
+ ---
+
+ ## 5. Run the Training
+
+ Instead of running Python directly, you launch with DeepSpeed:
+
+ ```bash
+ deepspeed train.py
+ ```
+
+ Multi-GPU example:
+
+ ```bash
+ deepspeed --num_gpus=4 train.py
+ ```
+
+ Now your model trains **distributed across GPUs automatically**.
+
+ ---
+
+ ## 6. Example with Hugging Face Transformers
+
+ DeepSpeed is heavily used with **Hugging Face Transformers**.
+
+ Example training command:
+
+ ```bash
+ deepspeed run_clm.py \
+   --model_name_or_path gpt2 \
+   --deepspeed ds_config.json
+ ```
+
+ This is how people train **LLMs efficiently**.
+
+ ---
+
+ ## What DeepSpeed Is Really Doing Behind the Scenes
+
+ When you enable ZeRO optimization, it:
+
+ * splits model parameters across GPUs
+ * shards gradients
+ * shards optimizer states
+ * optionally offloads memory to CPU or NVMe
+
+ So your GPU doesn’t have to hold **the entire model at once**.
+
+ That’s how researchers train **100B+ parameter models** without needing a ridiculous cluster.
+
+ ---
+
+ ## A Mental Model
+
+ Think of training a huge neural network like moving a giant couch up a staircase.
+
+ Without DeepSpeed:
+
+ one person tries to carry the couch alone. 💀
+
+ With DeepSpeed:
+
+ four people lift different corners and move together. 🛋️
+
+ Same couch.
+ Way less suffering.
+
+ ---
+
+ One spicy insight: DeepSpeed becomes **really powerful when models hit GPU memory limits**. That’s why it’s everywhere in LLM training.
+
+ The next level trick is combining it with **LoRA fine-tuning**, which lets you train huge models on **a single GPU or even a laptop**. That combo is what a lot of modern AI hackers use. 🔬
ds_config.active.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "train_batch_size": 16,
+   "train_micro_batch_size_per_gpu": 2,
+   "gradient_accumulation_steps": 8,
+   "optimizer": {
+     "type": "AdamW",
+     "params": {
+       "lr": 0.00015,
+       "betas": [
+         0.9,
+         0.999
+       ],
+       "eps": 1e-08,
+       "weight_decay": 0.01,
+       "torch_adam": true
+     }
+   },
+   "scheduler": {
+     "type": "WarmupLR",
+     "params": {
+       "warmup_min_lr": 0,
+       "warmup_max_lr": 0.00015,
+       "warmup_num_steps": 500
+     }
+   },
+   "zero_optimization": {
+     "stage": 2,
+     "offload_optimizer": {
+       "device": "cpu",
+       "pin_memory": true
+     },
+     "overlap_comm": true,
+     "contiguous_gradients": true,
+     "reduce_bucket_size": 2000000.0,
+     "gather_16bit_weights_on_model_save": false
+   },
+   "bf16": {
+     "enabled": true
+   },
+   "gradient_clipping": 1.0,
+   "activation_checkpointing": {
+     "partition_activations": true,
+     "contiguous_memory_optimization": true,
+     "number_checkpoints": 12,
+     "synchronize_checkpoint_boundary": false,
+     "cpu_checkpointing": false
+   },
+   "wall_clock_breakdown": false,
+   "steps_per_print": 100,
+   "fp16": {
+     "enabled": false
+   },
+   "amp": {
+     "enabled": false,
+     "amp_master_weights": false,
+     "loss_scale_window": 1000
+   }
+ }
ds_config.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "train_batch_size": 4,
+   "train_micro_batch_size_per_gpu": 1,
+   "gradient_accumulation_steps": 4,
+
+   "optimizer": {
+     "type": "AdamW",
+     "params": {
+       "lr": 1.5e-4,
+       "betas": [0.9, 0.999],
+       "eps": 1e-8,
+       "weight_decay": 0.01,
+       "torch_adam": true
+     }
+   },
+
+   "scheduler": {
+     "type": "WarmupLR",
+     "params": {
+       "warmup_min_lr": 0,
+       "warmup_max_lr": 1.5e-4,
+       "warmup_num_steps": 500
+     }
+   },
+
+   "zero_optimization": {
+     "stage": 2,
+     "offload_optimizer": {
+       "device": "cpu",
+       "pin_memory": false
+     },
+     "overlap_comm": true,
+     "contiguous_gradients": true,
+     "reduce_bucket_size": 1e6,
+     "gather_16bit_weights_on_model_save": false
+   },
+
+   "bf16": {
+     "enabled": true
+   },
+
+   "gradient_clipping": 1.0,
+
+   "activation_checkpointing": {
+     "partition_activations": true,
+     "contiguous_memory_optimization": true,
+     "number_checkpoints": 12,
+     "synchronize_checkpoint_boundary": false,
+     "cpu_checkpointing": true
+   },
+
+   "wall_clock_breakdown": false,
+
+   "steps_per_print": 100,
+
+   "fp16": {
+     "enabled": false
+   },
+
+   "amp": {
+     "enabled": false,
+     "amp_master_weights": false,
+     "loss_scale_window": 1000
+   }
+ }
main.py ADDED
@@ -0,0 +1,677 @@
1
+ """
2
+ MoE GPT – 0.5 Billion Parameter Language Model
3
+ ================================================
4
+ Mixture-of-Experts GPT trained on WikiText-2.
5
+ Fits on GPUs with as little as 4 GB VRAM via:
6
+ - FP16 model weights on GPU (~1 GB)
7
+ - CPU-offloaded AdamW (optimizer states on RAM, not VRAM)
8
+ - Gradient checkpointing (recompute activations to save memory)
9
+
10
+ Architecture
11
+ 12 Transformer layers × (12-head attention + MoE FFN)
12
+ 8 expert FFNs per layer, top-2 routing
13
+ Total params ≈ 521 M | Active per token ≈ 180 M
14
+
15
+ Run order:
16
+ pip install torch tiktoken numpy datasets
17
+ python prepare_data.py # once — downloads WikiText-2
18
+ python main.py # train + generate
19
+ """
20
+
21
+ import os
22
+ import math
23
+ import gc
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch.utils.checkpoint import checkpoint as grad_checkpoint
29
+ import tiktoken
30
+ from rich.progress import (
31
+ Progress, BarColumn, TextColumn, TimeRemainingColumn, TimeElapsedColumn,
32
+ SpinnerColumn, MofNCompleteColumn,
33
+ )
34
+ from rich.console import Console
35
+ from rich.table import Table
36
+ from rich import print as rprint
37
+
38
+ console = Console()
39
+
40
+ # ═════════════════════════════════════════════════════════════════════════════
41
+ # 1. LOAD DATA (memory-mapped .bin files from prepare_data.py)
42
+ # ═════════════════════════════════════════════════════════════════════════════
43
+
44
+ DATA_DIR = "data"
45
+ for split in ("train", "val", "test"):
46
+ path = os.path.join(DATA_DIR, f"{split}.bin")
47
+ if not os.path.exists(path):
48
+ raise FileNotFoundError(
49
+ f"\n[ERROR] '{path}' not found.\n"
50
+ "Run python prepare_data.py first."
51
+ )
52
+
53
+ train_data = np.memmap(os.path.join(DATA_DIR, "train.bin"), dtype=np.uint16, mode="r")
54
+ val_data = np.memmap(os.path.join(DATA_DIR, "val.bin"), dtype=np.uint16, mode="r")
55
+ test_data = np.memmap(os.path.join(DATA_DIR, "test.bin"), dtype=np.uint16, mode="r")
56
+
57
+ print("Dataset loaded (memory-mapped)")
58
+ print(f" Train : {len(train_data):>12,} tokens")
59
+ print(f" Val : {len(val_data):>12,} tokens")
60
+ print(f" Test : {len(test_data):>12,} tokens")
61
+ print()
62
+
63
+ # ═════════════════════════════════════════════════════════════════════════════
64
+ # 2. TOKENISER – GPT-2 BPE (matches prepare_data.py)
65
+ # ═════════════════════════════════════════════════════════════════════════════
66
+
67
+ enc = tiktoken.get_encoding("gpt2")
68
+ vocab_size = enc.n_vocab # 50 257
69
+
70
+ def encode(text: str) -> list:
71
+ return enc.encode_ordinary(text)
72
+
73
+ def decode(ids: list) -> str:
74
+ return enc.decode(ids)
75
+
76
+ print(f"Tokeniser : GPT-2 BPE (vocab {vocab_size:,})")
77
+ print()
78
+
79
+ # ═════════════════════════════════════════════════════════════════════════════
80
+ # 3. HYPERPARAMETERS
81
+ # ═════════════════════════════════════════════════════════════════════════════
82
+
83
+ BLOCK_SIZE = 128 # context window (tokens)
84
+ MICRO_BATCH = 2 # samples per GPU forward pass (tiny for VRAM)
85
+ GRAD_ACCUM = 8 # accumulate before optimizer step → eff. batch 16
86
+ EMBED_DIM = 768 # model width
87
+ NUM_HEADS = 12 # attention heads
88
+ NUM_LAYERS = 12 # transformer blocks
89
+ NUM_EXPERTS = 8 # expert FFNs per MoE layer
90
+ TOP_K = 2 # experts activated per token
91
+ FFN_DIM = EMBED_DIM * 4 # 3 072 (expert hidden dim)
92
+ DROPOUT = 0.1
93
+ LR = 1.5e-4 # peak learning rate (reduced from 3e-4 to prevent NaN)
94
+ WARMUP_STEPS = 500 # increased warmup for stability
95
+ MAX_ITERS = 10000 # extended training to 10k steps
96
+ EVAL_EVERY = 500
97
+ EVAL_ITERS = 50
98
+ AUX_LOSS_W = 0.01 # load-balancing auxiliary loss weight
99
+ GRAD_CLIP = 1.0
100
+ CHECKPOINT_DIR = "checkpoints" # directory for saving checkpoints
101
+
102
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
103
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
104
+ # bfloat16: same exponent range as fp32 — no overflow/NaN, no GradScaler needed.
105
+ # float16 caused NaN because it overflows at 65504.
106
+
107
+ print(f"Device : {DEVICE.upper()}")
108
+ print(f"Precision : {'BF16 + CPU-offload optimizer' if DTYPE == torch.bfloat16 else 'FP32'}")
109
+ print(f"Effective batch : {MICRO_BATCH * GRAD_ACCUM}")
110
+ print()
111
+
112
+ # ═════════════════════════════════════════════════════════════════════════════
113
+ # 4. DATA LOADER
114
+ # ═════════════════════════════════════════════════════════════════════════════
115
+
116
+ def get_batch(split="train"):
117
+ data = {"train": train_data, "val": val_data, "test": test_data}[split]
118
+ ix = np.random.randint(0, len(data) - BLOCK_SIZE, size=(MICRO_BATCH,))
119
+ x = np.stack([data[i : i + BLOCK_SIZE ].astype(np.int64) for i in ix])
120
+ y = np.stack([data[i+1 : i + BLOCK_SIZE + 1].astype(np.int64) for i in ix])
121
+ return torch.from_numpy(x).to(DEVICE), torch.from_numpy(y).to(DEVICE)
122
+
123
+ # ═════════════════════════════════════════════════════════════════════════════
124
+ # 5. MODEL — Mixture-of-Experts GPT (~0.5 B params)
125
+ # ═════════════════════════════════════════════════════════════════════════════
126
+
127
+ class CausalSelfAttention(nn.Module):
128
+ """Multi-head causal self-attention with fused QKV projection."""
129
+
130
+ def __init__(self):
131
+ super().__init__()
132
+ self.n_heads = NUM_HEADS
133
+ self.head_dim = EMBED_DIM // NUM_HEADS
134
+ self.qkv = nn.Linear(EMBED_DIM, 3 * EMBED_DIM, bias=False)
135
+ self.proj = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
136
+ self.attn_drop = nn.Dropout(DROPOUT)
137
+ self.proj_drop = nn.Dropout(DROPOUT)
138
+ self.register_buffer(
139
+ "mask",
140
+ torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE))
141
+ .view(1, 1, BLOCK_SIZE, BLOCK_SIZE),
142
+ )
143
+
144
+ def forward(self, x):
145
+ B, T, C = x.shape
146
+ qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
147
+ q, k, v = qkv.permute(2, 0, 3, 1, 4) # each (B, H, T, D)
148
+
149
+ att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
150
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
151
+ att = F.softmax(att.float(), dim=-1).to(x.dtype) # softmax in fp32
152
+ att = self.attn_drop(att)
153
+
154
+ out = (att @ v).transpose(1, 2).reshape(B, T, C)
155
+ return self.proj_drop(self.proj(out))
156
+
157
+
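The `tril` mask plus `-inf` fill above is what enforces causality; a minimal standalone check (toy `T = 4`, zero attention scores so each row is uniform over its visible positions):

```python
import torch
import torch.nn.functional as F

T = 4
mask = torch.tril(torch.ones(T, T))                       # 1 = may attend
att = torch.zeros(T, T).masked_fill(mask == 0, float("-inf"))
w = F.softmax(att, dim=-1)                                # rows still sum to 1
```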
158
+ class ExpertFFN(nn.Module):
159
+ """Single expert: two-layer FFN with GELU."""
160
+
161
+ def __init__(self):
162
+ super().__init__()
163
+ self.w1 = nn.Linear(EMBED_DIM, FFN_DIM)
164
+ self.w2 = nn.Linear(FFN_DIM, EMBED_DIM)
165
+ self.act = nn.GELU()
166
+ self.drop = nn.Dropout(DROPOUT)
167
+
168
+ def forward(self, x):
169
+ return self.drop(self.w2(self.act(self.w1(x))))
170
+
171
+
172
+ class MoELayer(nn.Module):
173
+ """
174
+ Mixture-of-Experts: routes each token to TOP_K of NUM_EXPERTS FFNs.
175
+ Includes Switch-Transformer-style load-balancing auxiliary loss.
176
+ """
177
+
178
+ def __init__(self):
179
+ super().__init__()
180
+ self.router = nn.Linear(EMBED_DIM, NUM_EXPERTS, bias=False)
181
+ self.experts = nn.ModuleList([ExpertFFN() for _ in range(NUM_EXPERTS)])
182
+
183
+ def forward(self, x):
184
+ B, T, C = x.shape
185
+ flat = x.reshape(-1, C) # (N, C)
186
+ N = flat.shape[0]
187
+
188
+ # ── routing ──
189
+ logits = self.router(flat) # (N, E)
190
+ probs = F.softmax(logits.float(), dim=-1) # fp32 for stability
191
+
192
+ top_w, top_i = torch.topk(probs, TOP_K, dim=-1) # (N, K)
193
+ top_w = (top_w / top_w.sum(dim=-1, keepdim=True)).to(x.dtype)
194
+
195
+ # ── load-balancing loss ──
196
+ one_hot = F.one_hot(top_i, NUM_EXPERTS).float().sum(dim=1) # (N, E)
197
+ f = one_hot.mean(dim=0)
198
+ P = probs.mean(dim=0)
199
+ aux_loss = NUM_EXPERTS * (f * P).sum()
200
+
201
+ # ── dispatch to experts ──
202
+ out = torch.zeros_like(flat)
203
+ for i, expert in enumerate(self.experts):
204
+ mask = (top_i == i).any(dim=-1) # (N,)
205
+ if not mask.any():
206
+ continue
207
+ tokens = flat[mask] # (n_i, C)
208
+ e_out = expert(tokens) # (n_i, C)
209
+ match = (top_i[mask] == i).to(x.dtype) # (n_i, K)
210
+ weights = (top_w[mask] * match).sum(-1, keepdim=True)
211
+ out[mask] += weights * e_out
212
+
213
+ return out.reshape(B, T, C), aux_loss
214
+
215
+
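The routing and load-balancing arithmetic in `MoELayer.forward` can be exercised on toy tensors (standalone sketch; `N=4, E=4, K=2` are assumed toy sizes, and random logits stand in for the router):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, E, K = 4, 4, 2                       # tokens, experts, top-k (toy sizes)
logits = torch.randn(N, E)              # stand-in for the router output
probs = F.softmax(logits, dim=-1)

top_w, top_i = torch.topk(probs, K, dim=-1)
top_w = top_w / top_w.sum(dim=-1, keepdim=True)   # renormalise over the K picks

# Switch-style balancing term: routed fraction (f) x mean router prob (P)
one_hot = F.one_hot(top_i, E).float().sum(dim=1)  # (N, E)
f = one_hot.mean(dim=0)
P = probs.mean(dim=0)
aux = E * (f * P).sum()
```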
216
+ class TransformerBlock(nn.Module):
217
+ """Pre-norm Transformer block: Attention + MoE, with residuals."""
218
+
219
+ def __init__(self):
220
+ super().__init__()
221
+ self.ln1 = nn.LayerNorm(EMBED_DIM)
222
+ self.attn = CausalSelfAttention()
223
+ self.ln2 = nn.LayerNorm(EMBED_DIM)
224
+ self.moe = MoELayer()
225
+
226
+ def forward(self, x):
227
+ x = x + self.attn(self.ln1(x))
228
+ moe_out, aux = self.moe(self.ln2(x))
229
+ x = x + moe_out
230
+ return x, aux
231
+
232
+
233
+ class MoEGPT(nn.Module):
234
+ """
235
+ Full MoE-GPT model (~521 M parameters, ~180 M active per token).
236
+
237
+ 1. Token + positional embeddings
238
+ 2. NUM_LAYERS × Transformer blocks (self-attention + MoE FFN)
239
+ 3. Final layer-norm → linear head (weight-tied with token embedding)
240
+ """
241
+
242
+ def __init__(self):
243
+ super().__init__()
244
+ self.tok_emb = nn.Embedding(vocab_size, EMBED_DIM)
245
+ self.pos_emb = nn.Embedding(BLOCK_SIZE, EMBED_DIM)
246
+ self.drop = nn.Dropout(DROPOUT)
247
+ self.blocks = nn.ModuleList([TransformerBlock() for _ in range(NUM_LAYERS)])
248
+ self.ln_f = nn.LayerNorm(EMBED_DIM)
249
+ self.head = nn.Linear(EMBED_DIM, vocab_size, bias=False)
250
+
251
+ # Weight tying saves ~38 M params and improves training
252
+ self.head.weight = self.tok_emb.weight
253
+ self._init_weights()
254
+
255
+ def _init_weights(self):
256
+ """GPT-2-style init with scaled residual projections."""
257
+ for name, p in self.named_parameters():
258
+ if p.dim() >= 2:
259
+ nn.init.normal_(p, mean=0.0, std=0.02)
260
+ elif "bias" in name:
261
+ nn.init.zeros_(p)
262
+ scale = (2 * NUM_LAYERS) ** -0.5
263
+ for block in self.blocks:
264
+ nn.init.normal_(block.attn.proj.weight, mean=0.0, std=0.02 * scale)
265
+ for expert in block.moe.experts:
266
+ nn.init.normal_(expert.w2.weight, mean=0.0, std=0.02 * scale)
267
+
268
+ def forward(self, idx, targets=None):
269
+ B, T = idx.shape
270
+ x = self.drop(
271
+ self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
272
+ )
273
+
274
+ total_aux = 0.0
275
+ for block in self.blocks:
276
+ if self.training:
277
+ x, aux = grad_checkpoint(block, x, use_reentrant=False)
278
+ else:
279
+ x, aux = block(x)
280
+ total_aux = total_aux + aux
281
+
282
+ logits = self.head(self.ln_f(x))
283
+
284
+ loss = None
285
+ if targets is not None:
286
+ ce = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
287
+ loss = ce + AUX_LOSS_W * total_aux
288
+ return logits, loss
289
+
290
+ @torch.no_grad()
291
+ def generate(self, prompt: str, max_new_tokens=200, temperature=0.8):
292
+ self.eval()
293
+ ids = encode(prompt)
294
+ idx = torch.tensor([ids], dtype=torch.long, device=DEVICE)
295
+
296
+ for _ in range(max_new_tokens):
297
+ ctx = idx[:, -BLOCK_SIZE:]
298
+ logits, _ = self(ctx)
299
+ logits = logits[:, -1, :].float() / temperature
300
+ probs = F.softmax(logits, dim=-1)
301
+ nxt = torch.multinomial(probs, 1)
302
+ idx = torch.cat([idx, nxt], dim=1)
303
+
304
+ self.train()
305
+ return decode(idx[0].tolist())
306
+
307
+ # ═════════════════════════════════════════════════════════════════════════════
308
+ # 6. CPU-OFFLOAD OPTIMIZER
309
+ # Hand-rolled AdamW with fp32 master weights but fp16 momentum/variance.
310
+ # Saves ~2 GB CPU RAM compared to using torch.optim.AdamW (all fp32).
311
+ # GPU VRAM cost ≈ 1 GB (only fp16 model weights + grads).
312
+ #
313
+ # Memory breakdown for 520 M params:
314
+ # fp32 master weights : ~2.0 GB
315
+ # fp16 momentum : ~1.0 GB
316
+ # fp16 variance : ~1.0 GB
317
+ # Total CPU RAM : ~4.1 GB (was ~6.2 GB with all-fp32 AdamW)
318
+ # ═════════════════════════════════════════════════════════════════════════════
319
+
320
+ class CPUOffloadAdamW:
321
+ """
322
+ AdamW with ALL state (master weights, momentum, variance) in fp32 on CPU.
323
+ fp16 m/v was the culprit for NaN — Adam variance accumulates squared
324
+ gradients that easily exceed fp16 max (65504) → overflow → NaN.
325
+ GPU holds only fp16 model weights + fp16 gradients (~1 GB VRAM).
326
+ CPU RAM: fp32 master(2 GB) + fp32 m(2 GB) + fp32 v(2 GB) ≈ 6.2 GB.
327
+ Expose param_groups so torch.amp.GradScaler.unscale_() works correctly.
328
+ """
329
+
330
+ def __init__(self, gpu_params, lr=3e-4, betas=(0.9, 0.999),
331
+ eps=1e-8, weight_decay=0.01):
332
+ self.gpu_params = list(gpu_params)
333
+ self.lr = lr
334
+ self.beta1, self.beta2 = betas
335
+ self.eps = eps
336
+ self.wd = weight_decay
337
+ self.t = 0
338
+
339
+ # fp32 master copies + fp32 momentum/variance on CPU
340
+ self.master = [p.data.float().cpu() for p in self.gpu_params]
341
+ self.m = [torch.zeros_like(mp) for mp in self.master] # fp32
342
+ self.v = [torch.zeros_like(mp) for mp in self.master] # fp32
343
+
344
+ # GradScaler compatibility: unscale_() iterates param_groups
345
+ self.param_groups = [{"params": self.gpu_params}]
346
+
347
+ def step(self):
348
+ self.t += 1
349
+ bc1 = 1.0 - self.beta1 ** self.t
350
+ bc2 = 1.0 - self.beta2 ** self.t
351
+
352
+ for i, gp in enumerate(self.gpu_params):
353
+ if gp.grad is None:
354
+ continue
355
+ g = gp.grad.data.float().cpu() # fp16 grad → fp32
356
+
357
+ # Decoupled weight decay
358
+ self.master[i].mul_(1.0 - self.lr * self.wd)
359
+
360
+ # Adam moments (all fp32 — no overflow risk)
361
+ self.m[i].mul_(self.beta1).add_(g, alpha=1.0 - self.beta1)
362
+ self.v[i].mul_(self.beta2).addcmul_(g, g, value=1.0 - self.beta2)
363
+
364
+ # Bias-corrected parameter update
365
+ self.master[i].addcdiv_(
366
+ self.m[i] / bc1,
367
+ (self.v[i] / bc2).sqrt_().add_(self.eps),
368
+ value=-self.lr,
369
+ )
370
+
371
+ # Push updated fp32 weights → GPU fp16
372
+ gp.data.copy_(self.master[i])
373
+
374
+ def zero_grad(self):
375
+ for gp in self.gpu_params:
376
+ gp.grad = None
377
+
378
+ def set_lr(self, lr):
379
+ self.lr = lr
380
+
381
+ def state_dict(self):
382
+ return {"t": self.t, "master": self.master, "m": self.m, "v": self.v}
383
+
384
+ def load_state_dict(self, sd):
385
+ self.t = sd["t"]
386
+ self.master = sd["master"]
387
+ self.m = sd["m"]
388
+ self.v = sd["v"]
389
+ for gp, mp in zip(self.gpu_params, self.master):
390
+ gp.data.copy_(mp.data)
391
+
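One step of the hand-rolled update can be checked against `torch.optim.AdamW` on a toy parameter (standalone sketch; the CPU offload and fp16 casts are omitted, only the fp32 math is compared):

```python
import torch

torch.manual_seed(0)
lr, b1, b2, eps, wd = 1e-3, 0.9, 0.999, 1e-8, 0.01
p_ref = torch.randn(5, requires_grad=True)
p_man = p_ref.detach().clone()          # "master" copy updated by hand
g = torch.randn(5)

# Reference: one step of torch.optim.AdamW
opt = torch.optim.AdamW([p_ref], lr=lr, betas=(b1, b2), eps=eps, weight_decay=wd)
p_ref.grad = g.clone()
opt.step()

# Hand-rolled step (t = 1), mirroring CPUOffloadAdamW.step
m, v, t = torch.zeros(5), torch.zeros(5), 1
p_man.mul_(1.0 - lr * wd)                               # decoupled weight decay
m.mul_(b1).add_(g, alpha=1.0 - b1)
v.mul_(b2).addcmul_(g, g, value=1.0 - b2)
p_man.addcdiv_(m / (1.0 - b1 ** t),
               (v / (1.0 - b2 ** t)).sqrt().add_(eps), value=-lr)
```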
392
+ # ═════════════════════════════════════════════════════════════════════════════
393
+ # 7. CHECKPOINT HELPERS
394
+ # ═════════════════════════════════════════════════════════════════════════════
395
+
396
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
397
+
398
+ def save_checkpoint(step, model, optimizer, train_loss, val_loss, path):
399
+ """Save model + optimizer + training state to disk."""
400
+ torch.save({
401
+ "step": step,
402
+ "model": model.state_dict(),
403
+ "optimizer": optimizer.state_dict(),
404
+ "train_loss": train_loss,
405
+ "val_loss": val_loss,
406
+ }, path)
407
+
408
+ def load_checkpoint(path, model, optimizer):
409
+ """Load checkpoint and return the step to resume from."""
410
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
411
+ model.load_state_dict(ckpt["model"])
412
+ optimizer.load_state_dict(ckpt["optimizer"])
413
+ print(f" Resumed from step {ckpt['step']} "
414
+ f"(train {ckpt['train_loss']:.4f}, val {ckpt['val_loss']:.4f})")
415
+ return ckpt["step"], ckpt["val_loss"]
416
+
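A quick round-trip through the same `torch.save` / `torch.load` pattern the helpers use (standalone sketch with a dummy state dict written to a temp directory):

```python
import os
import tempfile
import torch

ckpt = {"step": 3, "train_loss": 2.51, "val_loss": 2.73,
        "model": {"w": torch.ones(2)}}
path = os.path.join(tempfile.mkdtemp(), "latest.pt")
torch.save(ckpt, path)
back = torch.load(path, map_location="cpu", weights_only=False)
```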
417
+ # ═════════════════════════════════════════════════════════════════════════════
418
+ # 8. LEARNING-RATE SCHEDULE (linear warmup → cosine decay to 10 %)
419
+ # ═════════════════════════════════════════════════════════════════════════════
420
+
421
+ def get_lr(step):
422
+ if step < WARMUP_STEPS:
423
+ return LR * step / WARMUP_STEPS
424
+ progress = (step - WARMUP_STEPS) / max(1, MAX_ITERS - WARMUP_STEPS)
425
+ return LR * 0.1 + 0.5 * LR * 0.9 * (1 + math.cos(math.pi * progress))
426
+
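The schedule's endpoints (0 at step 0, peak `LR` at the end of warmup, 10 % of `LR` at the final step) can be verified in isolation (standalone sketch; the constant values are assumed stand-ins for the config above):

```python
import math

# Assumed stand-in values; the script defines the real LR / WARMUP_STEPS / MAX_ITERS.
LR, WARMUP_STEPS, MAX_ITERS = 3e-4, 500, 10_000

def get_lr(step):
    if step < WARMUP_STEPS:
        return LR * step / WARMUP_STEPS                     # linear warmup
    progress = (step - WARMUP_STEPS) / max(1, MAX_ITERS - WARMUP_STEPS)
    return LR * 0.1 + 0.5 * LR * 0.9 * (1 + math.cos(math.pi * progress))
```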
427
+ # ═════════════════════════════════════════════════════════════════════════════
428
+ # 9. LOSS ESTIMATION
429
+ # ═════════════════════════════════════════════════════════════════════════════
430
+
431
+ @torch.no_grad()
432
+ def estimate_loss():
433
+ model.eval()
434
+ out = {}
435
+ for split in ("train", "val"):
436
+ losses = []
437
+ for _ in range(EVAL_ITERS):
438
+ x, y = get_batch(split)
439
+ _, loss = model(x, y)
440
+ losses.append(loss.item())
441
+ out[split] = sum(losses) / len(losses)
442
+ model.train()
443
+ return out
444
+
445
+ # ═════════════════════════════════════════════════════════════════════════════
446
+ # 10. INSTANTIATE MODEL + OPTIMIZER
447
+ # ═════════════════════════════════════════════════════════════════════════════
448
+
449
+ if DEVICE == "cuda":
450
+ torch.cuda.empty_cache()
451
+
452
+ # ── Delete any NaN-poisoned checkpoints before loading ──
453
+ _nan_guard = os.path.join(CHECKPOINT_DIR, "latest.pt")
454
+ if os.path.exists(_nan_guard):
455
+ try:
456
+ _c = torch.load(_nan_guard, map_location="cpu", weights_only=False)
457
+ if _c.get("val_loss") != _c.get("val_loss"): # nan != nan
458
+ os.remove(_nan_guard)
459
+ _best = os.path.join(CHECKPOINT_DIR, "best.pt")
460
+ if os.path.exists(_best):
461
+ os.remove(_best)
462
+ console.print("[yellow]NaN checkpoint detected and removed — starting fresh.[/yellow]")
463
+ except Exception:
464
+ pass
465
+
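The `val_loss != val_loss` guard above relies on NaN being the only float value unequal to itself; a one-line standalone check:

```python
import math

nan = float("nan")
poisoned = nan != nan      # True only for NaN; any normal float equals itself
```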
466
+ model = MoEGPT()
467
+ n_total = sum(p.numel() for p in model.parameters())
468
+ _expert1 = sum(p.numel() for p in model.blocks[0].moe.experts[0].parameters())
469
+ n_active = n_total - _expert1 * (NUM_EXPERTS - TOP_K) * NUM_LAYERS
470
+
471
+ # Move to GPU in fp16 (or stay fp32 on CPU)
472
+ model = model.to(dtype=DTYPE, device=DEVICE)
473
+ gc.collect()
474
+ if DEVICE == "cuda":
475
+ torch.cuda.empty_cache()
476
+ vram_used = torch.cuda.memory_allocated() / 1024**3
477
+ print(f"GPU VRAM used : {vram_used:.2f} GiB (model weights)")
478
+
479
+ if DEVICE == "cuda":
480
+ # Create the optimizer after the model is moved to the GPU so the fp32 master copies snapshot the final weights
481
+ optimizer = CPUOffloadAdamW(model.parameters(), lr=LR)
482
+ gc.collect()
483
+ opt_gb = n_total * 4 * 3 / 1024**3 # fp32 master + fp32 m + fp32 v
484
+ print(f"CPU RAM for opt : ~{opt_gb:.1f} GiB (fp32 master + fp32 m + fp32 v)")
485
+ else:
486
+ _inner = torch.optim.AdamW(model.parameters(), lr=LR)
487
+ class _Wrap:
488
+ def __init__(self, o): self.opt = o
489
+ def step(self): self.opt.step()
490
+ def zero_grad(self): self.opt.zero_grad(set_to_none=True)
491
+ def set_lr(self, lr):
492
+ for pg in self.opt.param_groups: pg["lr"] = lr
493
+ def state_dict(self): return self.opt.state_dict()
494
+ def load_state_dict(self, sd): self.opt.load_state_dict(sd)
495
+ optimizer = _Wrap(_inner)
496
+
497
+ print(f"Total parameters : {n_total:>14,}")
498
+ print(f"Active per token : {n_active:>14,}")
499
+ print()
500
+
501
+ # ── Auto-resume from latest checkpoint ──
502
+ RESUME = True # Automatically resume from latest.pt if it exists
503
+ start_step = 0
504
+ best_val = float("inf")
505
+ latest_ckpt = os.path.join(CHECKPOINT_DIR, "latest.pt")
506
+ if RESUME and os.path.exists(latest_ckpt):
507
+ try:
508
+ _c = torch.load(latest_ckpt, map_location="cpu", weights_only=False)
509
+ # Skip NaN-poisoned checkpoints
510
+ if _c.get("val_loss") != _c.get("val_loss") or _c.get("train_loss") != _c.get("train_loss"):
511
+ print("Checkpoint has NaN losses — deleting and starting fresh")
512
+ os.remove(latest_ckpt)
513
+ else:
514
+ print("Checkpoint found — resuming …")
515
+ start_step, best_val = load_checkpoint(latest_ckpt, model, optimizer)
516
+ print()
517
+ except Exception as e:
518
+ print(f"Checkpoint corrupted ({e}) — starting fresh")
519
+ else:
520
+ if RESUME:
521
+ print("No checkpoint found — starting fresh training")
522
+ print()
523
+
524
+ # ═════════════════════════════════════════════════════════════════════════════
525
+ # 11. TRAINING LOOP
526
+ # ═════════════════════════════════════════════════════════════════════════════
527
+
528
+ console.rule("[bold green]Training started")
529
+ print()
530
+
531
+ with Progress(
532
+ SpinnerColumn(),
533
+ TextColumn("[bold blue]{task.description}"),
534
+ BarColumn(bar_width=30),
535
+ MofNCompleteColumn(),
536
+ TextColumn("•"),
537
+ TimeElapsedColumn(),
538
+ TextColumn("•"),
539
+ TimeRemainingColumn(),
540
+ TextColumn("•"),
541
+ TextColumn("[yellow]loss {task.fields[train_loss]}"),
542
+ TextColumn("[cyan]val {task.fields[val_loss]}"),
543
+ TextColumn("[magenta]lr {task.fields[lr]}"),
544
+ console=console,
545
+ refresh_per_second=4,
546
+ ) as progress:
547
+ total_steps = MAX_ITERS - start_step
548
+ task = progress.add_task(
549
+ "Training", total=total_steps,
550
+ train_loss="--.----", val_loss="--.----", lr="--.------",
551
+ )
552
+
553
+ for step in range(start_step + 1, MAX_ITERS + 1):
554
+
555
+ lr = get_lr(step)
556
+ optimizer.set_lr(lr)
557
+
558
+ optimizer.zero_grad()
559
+ accum_loss = 0.0
560
+
561
+ for _ in range(GRAD_ACCUM):
562
+ x, y = get_batch("train")
563
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16,
564
+ enabled=(DTYPE == torch.bfloat16)):
565
+ _, loss = model(x, y)
566
+ (loss / GRAD_ACCUM).backward()
567
+ accum_loss += loss.item() / GRAD_ACCUM
568
+
569
+ # Clip the global gradient norm at GRAD_CLIP to guard against gradient spikes
570
+ norm_before = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
571
+ if norm_before > GRAD_CLIP:
572
+ progress.console.print(f" [yellow]Gradient norm clipped: {norm_before:.2f} → {GRAD_CLIP}[/]", style="dim")
573
+ optimizer.step()
574
+
575
+ progress.update(
576
+ task, advance=1,
577
+ train_loss=f"{accum_loss:.4f}", lr=f"{lr:.6f}",
578
+ )
579
+
580
+ if step % EVAL_EVERY == 0 or step == 1:
581
+ losses = estimate_loss()
582
+ progress.update(
583
+ task,
584
+ train_loss=f"{losses['train']:.4f}",
585
+ val_loss=f"{losses['val']:.4f}",
586
+ lr=f"{lr:.6f}",
587
+ )
588
+ progress.console.print(
589
+ f" [bold]Step {step:>5}[/] │ "
590
+ f"[yellow]Train {losses['train']:.4f}[/] │ "
591
+ f"[cyan]Val {losses['val']:.4f}[/] │ "
592
+ f"[magenta]LR {lr:.6f}[/]"
593
+ )
594
+
595
+ # ── Save checkpoints ──
596
+ save_checkpoint(
597
+ step, model, optimizer,
598
+ losses["train"], losses["val"],
599
+ os.path.join(CHECKPOINT_DIR, "latest.pt"),
600
+ )
601
+ if losses["val"] < best_val:
602
+ best_val = losses["val"]
603
+ save_checkpoint(
604
+ step, model, optimizer,
605
+ losses["train"], losses["val"],
606
+ os.path.join(CHECKPOINT_DIR, "best.pt"),
607
+ )
608
+ progress.console.print(
609
+ f" [bold green]★ New best val loss: {best_val:.4f} (saved best.pt)[/]"
610
+ )
611
+
612
+ print()
613
+ console.rule("[bold green]Training complete")
614
+ print()
615
+
616
+ # ── Load best checkpoint for final evaluation ──
617
+ best_ckpt = os.path.join(CHECKPOINT_DIR, "best.pt")
618
+ if os.path.exists(best_ckpt):
619
+ print("Loading best checkpoint for evaluation …")
620
+ load_checkpoint(best_ckpt, model, optimizer)
621
+ print()
622
+
623
+ # ═════════════════════════════════════════════════════════════════════════════
624
+ # 12. TEST EVALUATION
625
+ # ═════════════════════════════════════════════════════════════════════════════
626
+
627
+ model.eval()
628
+ test_losses = []
629
+ with torch.no_grad():
630
+ for _ in range(EVAL_ITERS):
631
+ x, y = get_batch("test")
632
+ _, loss = model(x, y)
633
+ test_losses.append(loss.item())
634
+ test_loss = sum(test_losses) / len(test_losses)
635
+ print(f"Test loss : {test_loss:.4f}")
636
+ print()
637
+ model.train()
638
+
639
+ # ═════════════════════════════════════════════════════════════════════════════
640
+ # 13. TEXT GENERATION SAMPLES
641
+ # ═════════════════════════════════════════════════════════════════════════════
642
+
643
+ prompts = [
644
+ "The history of",
645
+ "Scientists have discovered",
646
+ "In the early twentieth century",
647
+ ]
648
+
649
+ print("=" * 60)
650
+ print("Generated Text Samples")
651
+ print("=" * 60)
652
+
653
+ for prompt in prompts:
654
+ output = model.generate(prompt, max_new_tokens=120, temperature=0.7)
655
+ print(f"\nPrompt : \"{prompt}\"")
656
+ print(f"Output : {output.strip()}")
657
+ print()
658
+
659
+ # ═════════════════════════════════════════════════════════════════════════════
660
+ # 14. INTERACTIVE MODE
661
+ # ═════════════════════════════════════════════════════════════════════════════
662
+
663
+ print("=" * 60)
664
+ print("Interactive Mode (type 'quit' to exit)")
665
+ print("=" * 60)
666
+
667
+ while True:
668
+ try:
669
+ prompt = input("\nEnter a prompt: ").strip()
670
+ except (EOFError, KeyboardInterrupt):
671
+ break
672
+ if not prompt or prompt.lower() == "quit":
673
+ break
674
+ output = model.generate(prompt, max_new_tokens=150, temperature=0.8)
675
+ print(f"\n{output.strip()}")
676
+
677
+ print("\nGoodbye!")
main_deepspeed.py ADDED
@@ -0,0 +1,696 @@
1
+ """
2
+ MoE GPT – 0.5 Billion Parameter Language Model (DeepSpeed ZeRO-3)
3
+ ==================================================================
4
+ Mixture-of-Experts GPT trained on FineWeb-Edu with DeepSpeed ZeRO-Infinity:
5
+ - ZeRO Stage 3: All states partitioned across GPUs + CPU RAM offload
6
+ - CPU Offloading: Parameters & optimizer states in CPU RAM
7
+ - Memory efficient: Fits massive models on limited VRAM
8
+ - Automatic gradient checkpointing & mixed precision (bfloat16)
9
+
10
+ Architecture
11
+ 8 Transformer layers × (8-head attention + MoE FFN)
12
+ 4 expert FFNs per layer, top-2 routing
13
+ Parameter totals follow from the section-3 hyperparameters and are printed at startup
14
+
15
+ Run order:
16
+ pip install torch tiktoken numpy datasets deepspeed
17
+ python prepare_data.py # once — downloads FineWeb-Edu
18
+ deepspeed --num_gpus 1 main_deepspeed.py # train with DeepSpeed
19
+ python run.py # generate
20
+ """
21
+
22
+ import os
23
+ import sys
24
+ import math
25
+ import gc
26
+ import json
27
+ from datetime import datetime, timedelta
28
+ from zoneinfo import ZoneInfo
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn as nn
32
+ import torch.nn.functional as F
33
+ from torch.utils.checkpoint import checkpoint as grad_checkpoint
34
+ import tiktoken
35
+ import deepspeed
36
+ from rich.progress import (
37
+ Progress, BarColumn, TextColumn, TimeRemainingColumn, TimeElapsedColumn,
38
+ SpinnerColumn, MofNCompleteColumn,
39
+ )
40
+ from rich.console import Console
41
+
42
+ console = Console()
43
+ IST = ZoneInfo("Asia/Kolkata")
44
+
45
+ # ═════════════════════════════════════════════════════════════════════════════
46
+ # 1. LOAD DATA (memory-mapped .bin files from prepare_data.py)
47
+ # ═════════════════════════════════════════════════════════════════════════════
48
+
49
+ DATA_DIR = "data"
50
+ for split in ("train", "val", "test"):
51
+ path = os.path.join(DATA_DIR, f"{split}.bin")
52
+ if not os.path.exists(path):
53
+ raise FileNotFoundError(
54
+ f"\n[ERROR] '{path}' not found.\n"
55
+ "Run python prepare_data.py first."
56
+ )
57
+
58
+ train_data = np.memmap(os.path.join(DATA_DIR, "train.bin"), dtype=np.uint16, mode="r")
59
+ val_data = np.memmap(os.path.join(DATA_DIR, "val.bin"), dtype=np.uint16, mode="r")
60
+ test_data = np.memmap(os.path.join(DATA_DIR, "test.bin"), dtype=np.uint16, mode="r")
61
+
62
+ print("Dataset loaded (memory-mapped)")
63
+ print(f" Train : {len(train_data):>12,} tokens")
64
+ print(f" Val : {len(val_data):>12,} tokens")
65
+ print(f" Test : {len(test_data):>12,} tokens")
66
+ print()
67
+
68
+ # ═════════════════════════════════════════════════════════════════════════════
69
+ # 2. TOKENISER – GPT-2 BPE (matches prepare_data.py)
70
+ # ═════════════════════════════════════════════════════════════════════════════
71
+
72
+ enc = tiktoken.get_encoding("gpt2")
73
+ vocab_size = enc.n_vocab # 50 257
74
+
75
+ def encode(text: str) -> list:
76
+ return enc.encode_ordinary(text)
77
+
78
+ def decode(ids: list) -> str:
79
+ return enc.decode(ids)
80
+
81
+ print(f"Tokeniser : GPT-2 BPE (vocab {vocab_size:,})")
82
+ print()
83
+
84
+ # ═════════════════════════════════════════════════════════════════════════════
85
+ # 3. HYPERPARAMETERS
86
+ # ═════════════════════════════════════════════════════════════════════════════
87
+
88
+ BLOCK_SIZE = 64 # context window (tokens)
89
+ MICRO_BATCH = 2 # samples per GPU forward pass (managed by DeepSpeed)
90
+ GRAD_ACCUM = 8 # accumulate before optimizer step → eff. batch 16
91
+ EMBED_DIM = 512 # model width
92
+ NUM_HEADS = 8 # attention heads
93
+ NUM_LAYERS = 8 # transformer blocks
94
+ NUM_EXPERTS = 4 # expert FFNs per MoE layer
95
+ TOP_K = 2 # experts activated per token
96
+ FFN_DIM = EMBED_DIM * 4 # 2 048 (expert hidden dim)
97
+ DROPOUT = 0.1
98
+ LR = 1.5e-4 # peak learning rate
99
+ WARMUP_STEPS = 500 # linear warmup
100
+ MAX_ITERS = 50000 # total optimiser steps
101
+ EVAL_EVERY = 100
102
+ EVAL_ITERS = 50
103
+ AUX_LOSS_W = 0.01 # load-balancing auxiliary loss weight
104
+ GRAD_CLIP = 1.0
105
+ CHECKPOINT_DIR = "checkpoints" # directory for saving checkpoints
106
+
107
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
108
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
109
+
110
+ ALLOW_TF32 = os.environ.get("ALLOW_TF32", "1") == "1"
111
+ USE_TORCH_COMPILE = os.environ.get("USE_TORCH_COMPILE", "0") == "1"
112
+ USE_ACTIVATION_CHECKPOINT = os.environ.get("USE_ACTIVATION_CHECKPOINT", "1") == "1"
113
+
114
+ if DEVICE == "cuda":
115
+ # Throughput-oriented CUDA settings.
116
+ torch.backends.cudnn.benchmark = True
117
+ torch.backends.cuda.matmul.allow_tf32 = ALLOW_TF32
118
+ torch.backends.cudnn.allow_tf32 = ALLOW_TF32
119
+ torch.set_float32_matmul_precision("high")
120
+
121
+ print(f"Device : {DEVICE.upper()}")
122
+ print(f"Precision : {'BF16 (DeepSpeed)' if DTYPE == torch.bfloat16 else 'FP32'}")
123
+ print(f"Effective batch (default) : {MICRO_BATCH * GRAD_ACCUM}")
124
+ print(f"TF32 : {'ON' if (DEVICE == 'cuda' and ALLOW_TF32) else 'OFF'}")
125
+ print(f"Torch compile : {'ON' if USE_TORCH_COMPILE else 'OFF'}")
126
+ print(f"Act checkpoint : {'ON' if USE_ACTIVATION_CHECKPOINT else 'OFF'}")
127
+ print()
128
+
129
+ # ═════════════════════════════════════════════════════════════════════════════
130
+ # 4. DATA LOADER
131
+ # ═════════════════════════════════════════════════════════════════════════════
132
+
133
+ def get_batch(split="train"):
134
+ data = {"train": train_data, "val": val_data, "test": test_data}[split]
135
+ ix = np.random.randint(0, len(data) - BLOCK_SIZE, size=(MICRO_BATCH,))
136
+ x = np.stack([data[i : i + BLOCK_SIZE ].astype(np.int64) for i in ix])
137
+ y = np.stack([data[i+1 : i + BLOCK_SIZE + 1].astype(np.int64) for i in ix])
138
+ return torch.from_numpy(x).to(DEVICE), torch.from_numpy(y).to(DEVICE)
139
+
140
+ # ═════════════════════════════════════════════════════════════════════════════
141
+ # 5. MODEL — Mixture-of-Experts GPT (~0.5 B params)
142
+ # ═════════════════════════════════════════════════════════════════════════════
143
+
144
+ class CausalSelfAttention(nn.Module):
145
+ """Multi-head causal self-attention with fused QKV projection."""
146
+
147
+ def __init__(self):
148
+ super().__init__()
149
+ self.n_heads = NUM_HEADS
150
+ self.head_dim = EMBED_DIM // NUM_HEADS
151
+ self.qkv = nn.Linear(EMBED_DIM, 3 * EMBED_DIM, bias=False)
152
+ self.proj = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
153
+ self.attn_drop = nn.Dropout(DROPOUT)
154
+ self.proj_drop = nn.Dropout(DROPOUT)
155
+ self.register_buffer(
156
+ "mask",
157
+ torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE))
158
+ .view(1, 1, BLOCK_SIZE, BLOCK_SIZE),
159
+ )
160
+
161
+ def forward(self, x):
162
+ B, T, C = x.shape
163
+ qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
164
+ q, k, v = qkv.permute(2, 0, 3, 1, 4) # each (B, H, T, D)
165
+
166
+ att = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
167
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
168
+ att = F.softmax(att.float(), dim=-1).to(x.dtype) # softmax in fp32
169
+ att = self.attn_drop(att)
170
+
171
+ out = (att @ v).transpose(1, 2).reshape(B, T, C)
172
+ return self.proj_drop(self.proj(out))
173
+
174
+
175
+ class ExpertFFN(nn.Module):
176
+ """Single expert: two-layer FFN with GELU."""
177
+
178
+ def __init__(self):
179
+ super().__init__()
180
+ self.w1 = nn.Linear(EMBED_DIM, FFN_DIM)
181
+ self.w2 = nn.Linear(FFN_DIM, EMBED_DIM)
182
+ self.act = nn.GELU()
183
+ self.drop = nn.Dropout(DROPOUT)
184
+
185
+ def forward(self, x):
186
+ return self.drop(self.w2(self.act(self.w1(x))))
187
+
188
+
189
+ class MoELayer(nn.Module):
190
+ """
191
+ Mixture-of-Experts: routes each token to TOP_K of NUM_EXPERTS FFNs.
192
+ Includes Switch-Transformer-style load-balancing auxiliary loss.
193
+ """
194
+
195
+ def __init__(self):
196
+ super().__init__()
197
+ self.router = nn.Linear(EMBED_DIM, NUM_EXPERTS, bias=False)
198
+ self.experts = nn.ModuleList([ExpertFFN() for _ in range(NUM_EXPERTS)])
199
+
200
+ def forward(self, x):
201
+ B, T, C = x.shape
202
+ flat = x.reshape(-1, C) # (N, C)
203
+ N = flat.shape[0]
204
+
205
+ # ── routing ──
206
+ logits = self.router(flat) # (N, E)
207
+ probs = F.softmax(logits.float(), dim=-1) # fp32 for stability
208
+
209
+ top_w, top_i = torch.topk(probs, TOP_K, dim=-1) # (N, K)
210
+ top_w = (top_w / top_w.sum(dim=-1, keepdim=True)).to(x.dtype)
211
+
212
+ # ── load-balancing loss ──
213
+ one_hot = F.one_hot(top_i, NUM_EXPERTS).float().sum(dim=1) # (N, E)
214
+ f = one_hot.mean(dim=0)
215
+ P = probs.mean(dim=0)
216
+ aux_loss = NUM_EXPERTS * (f * P).sum()
217
+
218
+ # ── dispatch to experts ──
219
+ out = torch.zeros_like(flat)
220
+ for i, expert in enumerate(self.experts):
221
+ mask = (top_i == i).any(dim=-1) # (N,)
222
+ if not mask.any():
223
+ continue
224
+ tokens = flat[mask] # (n_i, C)
225
+ e_out = expert(tokens) # (n_i, C)
226
+ match = (top_i[mask] == i).to(x.dtype) # (n_i, K)
227
+ weights = (top_w[mask] * match).sum(-1, keepdim=True)
228
+ out[mask] += weights * e_out
229
+
230
+ return out.reshape(B, T, C), aux_loss
231
+
232
+
233
+ class TransformerBlock(nn.Module):
234
+ """Pre-norm Transformer block: Attention + MoE, with residuals."""
235
+
236
+ def __init__(self):
237
+ super().__init__()
238
+ self.ln1 = nn.LayerNorm(EMBED_DIM)
239
+ self.attn = CausalSelfAttention()
240
+ self.ln2 = nn.LayerNorm(EMBED_DIM)
241
+ self.moe = MoELayer()
242
+
243
+ def forward(self, x):
244
+ x = x + self.attn(self.ln1(x))
245
+ moe_out, aux = self.moe(self.ln2(x))
246
+ x = x + moe_out
247
+ return x, aux
248
+
249
+
250
+ class MoEGPT(nn.Module):
251
+ """
252
+ Full MoE-GPT model (parameter totals depend on the config constants above).
253
+
254
+ 1. Token + positional embeddings
255
+ 2. NUM_LAYERS × Transformer blocks (self-attention + MoE FFN)
256
+ 3. Final layer-norm → linear head (weight-tied with token embedding)
257
+ """
258
+
259
+ def __init__(self):
260
+ super().__init__()
261
+ self.tok_emb = nn.Embedding(vocab_size, EMBED_DIM)
262
+ self.pos_emb = nn.Embedding(BLOCK_SIZE, EMBED_DIM)
263
+ self.drop = nn.Dropout(DROPOUT)
264
+ self.blocks = nn.ModuleList([TransformerBlock() for _ in range(NUM_LAYERS)])
265
+ self.ln_f = nn.LayerNorm(EMBED_DIM)
266
+ self.head = nn.Linear(EMBED_DIM, vocab_size, bias=False)
267
+
268
+ # Weight tying saves ~38 M params and improves training
269
+ self.head.weight = self.tok_emb.weight
270
+ self._init_weights()
271
+
272
+ def _init_weights(self):
273
+ """GPT-2-style init with scaled residual projections."""
274
+ for name, p in self.named_parameters():
275
+ if p.dim() >= 2:
276
+ nn.init.normal_(p, mean=0.0, std=0.02)
277
+ elif "bias" in name:
278
+ nn.init.zeros_(p)
279
+ scale = (2 * NUM_LAYERS) ** -0.5
280
+ for block in self.blocks:
281
+ nn.init.normal_(block.attn.proj.weight, mean=0.0, std=0.02 * scale)
282
+ for expert in block.moe.experts:
283
+ nn.init.normal_(expert.w2.weight, mean=0.0, std=0.02 * scale)
284
+
285
+ def forward(self, idx, targets=None):
286
+ B, T = idx.shape
287
+ x = self.drop(
288
+ self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
289
+ )
290
+
291
+ total_aux = 0.0
292
+ for block in self.blocks:
293
+ if self.training and USE_ACTIVATION_CHECKPOINT:
294
+ x, aux = grad_checkpoint(block, x, use_reentrant=False)
295
+ else:
296
+ x, aux = block(x)
297
+ total_aux = total_aux + aux
298
+
299
+ logits = self.head(self.ln_f(x))
300
+
301
+ loss = None
302
+ if targets is not None:
303
+ ce = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
304
+ loss = ce + AUX_LOSS_W * total_aux
305
+ return logits, loss
306
+
307
+ @torch.no_grad()
308
+ def generate(self, prompt: str, max_new_tokens=200, temperature=0.8):
309
+ self.eval()
310
+ ids = encode(prompt)
311
+ idx = torch.tensor([ids], dtype=torch.long, device=DEVICE)
312
+
313
+ for _ in range(max_new_tokens):
314
+ ctx = idx[:, -BLOCK_SIZE:]
315
+ logits, _ = self(ctx)
316
+ logits = logits[:, -1, :].float() / temperature
317
+ probs = F.softmax(logits, dim=-1)
318
+ nxt = torch.multinomial(probs, 1)
319
+ idx = torch.cat([idx, nxt], dim=1)
320
+
321
+ self.train()
322
+ return decode(idx[0].tolist())
323
+
324
+ # ═════════════════════════════════════════════════════════════════════════════
325
+ # 6. CHECKPOINT HELPERS
326
+ # ═════════════════════════════════════════════════════════════════════════════
327
+
328
+ os.makedirs(CHECKPOINT_DIR, exist_ok=True)
329
+
330
+
331
+ def _strip_orig_mod_prefix(state_dict):
332
+ out = {}
333
+ for k, v in state_dict.items():
334
+ if k.startswith("_orig_mod."):
335
+ out[k[len("_orig_mod."):]] = v
336
+ else:
337
+ out[k] = v
338
+ return out
339
+
340
+
341
+ def _add_orig_mod_prefix(state_dict):
342
+ out = {}
343
+ for k, v in state_dict.items():
344
+ if k.startswith("_orig_mod."):
345
+ out[k] = v
346
+ else:
347
+ out[f"_orig_mod.{k}"] = v
348
+ return out
349
+
350
+
351
+ def _align_state_dict_for_model(state_dict, model):
352
+ """Align checkpoint keys with model keys (compiled vs non-compiled)."""
353
+ model_keys = list(model.state_dict().keys())
354
+ if not model_keys:
355
+ return state_dict
356
+
357
+ model_has_orig = model_keys[0].startswith("_orig_mod.")
358
+ ckpt_keys = list(state_dict.keys())
359
+ ckpt_has_orig = bool(ckpt_keys) and ckpt_keys[0].startswith("_orig_mod.")
360
+
361
+ if model_has_orig and not ckpt_has_orig:
362
+ return _add_orig_mod_prefix(state_dict)
363
+ if not model_has_orig and ckpt_has_orig:
364
+ return _strip_orig_mod_prefix(state_dict)
365
+ return state_dict
366
+
367
+ def save_checkpoint(step, model, train_loss, val_loss, path):
368
+ """Save model and training state to disk."""
369
+ # DeepSpeed handles checkpointing, but we also save basic metadata
370
+ model_state = model.state_dict() if hasattr(model, "state_dict") else None
371
+ if model_state is not None:
372
+ # Store canonical keys so checkpoints are reusable across compile modes.
373
+ model_state = _strip_orig_mod_prefix(model_state)
374
+
375
+ checkpoint = {
376
+ "step": step,
377
+ "train_loss": train_loss,
378
+ "val_loss": val_loss,
379
+ "model_state": model_state,
380
+ }
381
+ torch.save(checkpoint, path)
382
+
383
+ def load_checkpoint(path, model):
384
+ """Load checkpoint and return the step to resume from."""
385
+ ckpt = torch.load(path, map_location="cpu", weights_only=False)
386
+ if ckpt.get("model_state"):
387
+ model_state = _align_state_dict_for_model(ckpt["model_state"], model)
388
+ model.load_state_dict(model_state)
389
+ print(f" Resumed from step {ckpt['step']} "
390
+ f"(train {ckpt['train_loss']:.4f}, val {ckpt['val_loss']:.4f})")
391
+ return ckpt["step"], ckpt["val_loss"]
392
+
393
+ # ═════════════════════════════════════════════════════════════════════════════
394
+ # 7. LEARNING-RATE SCHEDULE (linear warmup → cosine decay to 10 %)
395
+ # ═════════════════════════════════════════════════════════════════════════════
396
+
397
+ def get_lr(step):
398
+ if step < WARMUP_STEPS:
399
+ return LR * step / WARMUP_STEPS
400
+ progress = (step - WARMUP_STEPS) / max(1, MAX_ITERS - WARMUP_STEPS)
401
+ return LR * 0.1 + 0.5 * LR * 0.9 * (1 + math.cos(math.pi * progress))
402
+
403
+
404
+ def get_eta_clock(progress, task_id):
405
+ """Return estimated finish time in IST as HH:MM."""
406
+ remaining = progress.tasks[task_id].time_remaining
407
+ if remaining is None:
408
+ return "--:--"
409
+ end_at = datetime.now(IST) + timedelta(seconds=max(0.0, remaining))
410
+ return end_at.strftime("%H:%M")
411
+
412
+ # ═════════════════════════════════════════════════════════════════════════════
413
+ # 8. LOSS ESTIMATION
414
+ # ═════════════════════════════════════════════════════════════════════════════
415
+
416
+ @torch.no_grad()
417
+ def estimate_loss(model):
418
+ model.eval()
419
+ out = {}
420
+ for split in ("train", "val"):
421
+ losses = []
422
+ for _ in range(EVAL_ITERS):
423
+ x, y = get_batch(split)
424
+ with torch.amp.autocast("cuda", dtype=torch.bfloat16, enabled=(DEVICE == "cuda" and DTYPE == torch.bfloat16)):
425
+ _, loss = model(x, y)
426
+ losses.append(loss.item())
427
+ out[split] = sum(losses) / len(losses)
428
+ model.train()
429
+ return out
430
+
431
+ # ═════════════════════════════════════════════════════════════════════════════
432
+ # 9. DEEPSPEED INITIALIZATION & TRAINING
433
+ # ═════════════════════════════════════════════════════════════════════════════
434
+
435
+ if __name__ == "__main__":
436
+ # Clear cache
437
+ if DEVICE == "cuda":
438
+ torch.cuda.empty_cache()
439
+
440
+ # Initialize model
441
+ model = MoEGPT()
442
+ if DEVICE == "cuda" and USE_TORCH_COMPILE:
443
+ # Full-graph mode is too brittle for dynamic shapes; max-autotune is a good speed/compat compromise.
444
+ model = torch.compile(model, mode="max-autotune", fullgraph=False)
445
+
446
+ n_total = sum(p.numel() for p in model.parameters())
447
+ _expert1 = sum(p.numel() for p in model.blocks[0].moe.experts[0].parameters())
448
+ n_active = n_total - _expert1 * (NUM_EXPERTS - TOP_K) * NUM_LAYERS
449
+
450
+ print(f"Total parameters : {n_total:>14,}")
451
+ print(f"Active per token : {n_active:>14,}")
452
+ print()
453
+
454
+ # Ensure LOCAL_RANK is set so DeepSpeed's sanity checks pass when running
455
+ # this script directly with `python main_deepspeed.py` (single-GPU).
456
+ if "LOCAL_RANK" not in os.environ:
457
+ os.environ["LOCAL_RANK"] = "0"
458
+
459
+ # Initialize torch.distributed if not already initialized (single-process setup).
460
+ if not torch.distributed.is_initialized():
461
+ init_kwargs = {
462
+ "backend": "nccl" if DEVICE == "cuda" else "gloo",
463
+ "init_method": "tcp://127.0.0.1:29500",
464
+ "rank": 0,
465
+ "world_size": 1,
466
+ }
467
+ if DEVICE == "cuda":
468
+ init_kwargs["device_id"] = torch.device("cuda", int(os.environ.get("LOCAL_RANK", 0)))
469
+ torch.distributed.init_process_group(**init_kwargs)
470
+
471
+ # Load DeepSpeed config (launcher may provide a mode-specific config path).
472
+ ds_config_path = os.environ.get("DS_CONFIG_PATH", "ds_config.json")
473
+ with open(ds_config_path) as f:
474
+ ds_config = json.load(f)
475
+
476
+ # Keep runtime batch settings aligned with DeepSpeed config.
477
+ MICRO_BATCH = int(ds_config.get("train_micro_batch_size_per_gpu", MICRO_BATCH))
478
+ GRAD_ACCUM = int(ds_config.get("gradient_accumulation_steps", GRAD_ACCUM))
479
+
480
+ print(f"DeepSpeed micro-batch : {MICRO_BATCH}")
481
+ print(f"DeepSpeed grad accum : {GRAD_ACCUM}")
482
+ print(f"DeepSpeed eff batch : {MICRO_BATCH * GRAD_ACCUM}")
483
+ print()
484
+
485
+ # Initialize DeepSpeed engine
486
+ model_engine, optimizer, _, lr_scheduler = deepspeed.initialize(
487
+ args=type("args", (), {"local_rank": int(os.environ.get("LOCAL_RANK", 0))})(),
488
+ model=model,
489
+ model_parameters=model.parameters(),
490
+ config=ds_config,
491
+ dist_init_required=False,
492
+ )
493
+
494
+ print(f"[DeepSpeed] Initialized with ZeRO Stage {ds_config['zero_optimization']['stage']}")
495
+ print(f"[DeepSpeed] Device: {model_engine.device}")
496
+ print()
497
+
498
+ # Training state
499
+ start_step = 0
500
+ best_val = float("inf")
501
+ prev_val = None
502
+
503
+ # Auto-resume from latest checkpoint (with NaN guard)
504
+ latest_ckpt = os.path.join(CHECKPOINT_DIR, "latest.pt")
505
+ if os.path.exists(latest_ckpt):
506
+ try:
507
+ _c = torch.load(latest_ckpt, map_location="cpu", weights_only=False)
508
+ if _c.get("val_loss") != _c.get("val_loss") or _c.get("train_loss") != _c.get("train_loss"):
509
+ print("Checkpoint has NaN losses — deleting and starting fresh")
510
+ os.remove(latest_ckpt)
511
+ else:
512
+ print("Checkpoint found — resuming …")
513
+ start_step, best_val = load_checkpoint(latest_ckpt, model)
514
+ print()
515
+ except Exception as e:
516
+ print(f"Checkpoint corrupted ({e}) — starting fresh")
517
+ else:
518
+ print("No checkpoint found — starting fresh training")
519
+ print()
520
+
521
+ # ─────────────────────────────────────────────────────────────────────────
522
+ # TRAINING LOOP
523
+ # ─────────────────────────────────────────────────────────────────────────
524
+
525
+ console.rule("[bold green]Training started (DeepSpeed)")
526
+ print()
527
+
528
+ with Progress(
529
+ SpinnerColumn(),
530
+ TextColumn("[bold blue]{task.description}"),
531
+ BarColumn(bar_width=30),
532
+ MofNCompleteColumn(),
533
+ TextColumn("•"),
534
+ TimeElapsedColumn(),
535
+ TextColumn("•"),
536
+ TimeRemainingColumn(),
537
+ TextColumn("•"),
538
+ TextColumn("[green]ETA {task.fields[end_clock]} IST"),
539
+ TextColumn("•"),
540
+ TextColumn("[yellow]loss {task.fields[train_loss]}"),
541
+ TextColumn("[cyan]val {task.fields[val_loss]}"),
542
+ TextColumn("[magenta]lr {task.fields[lr]}"),
543
+ console=console,
544
+ refresh_per_second=4,
545
+ ) as progress:
546
+ total_steps = MAX_ITERS - start_step
547
+ task = progress.add_task(
548
+ "Training", total=total_steps,
549
+ train_loss="--.----", val_loss="--.----", lr="--.------", end_clock="--:--",
550
+ )
551
+
552
+ step = start_step
553
+ micro_loss_sum = 0.0
554
+ micro_loss_count = 0
555
+ total_micro_steps = MAX_ITERS * GRAD_ACCUM
556
+
557
+ for micro_step in range(start_step * GRAD_ACCUM + 1, total_micro_steps + 1):
558
+
559
+ lr = get_lr(step + 1)
560
+ for param_group in optimizer.param_groups:
561
+ param_group["lr"] = lr
562
+
563
+ x, y = get_batch("train")
564
+ _, loss = model_engine(x, y)
565
+
566
+ model_engine.backward(loss)
567
+ is_boundary = model_engine.is_gradient_accumulation_boundary()
568
+ model_engine.step()
569
+
570
+ micro_loss_sum += loss.item()
571
+ micro_loss_count += 1
572
+
573
+ if not is_boundary:
574
+ continue
575
+
576
+ step += 1
577
+ accum_loss = micro_loss_sum / max(1, micro_loss_count)
578
+ micro_loss_sum = 0.0
579
+ micro_loss_count = 0
580
+
581
+ progress.update(
582
+ task,
583
+ advance=1,
584
+ train_loss=f"{accum_loss:.4f}",
585
+ lr=f"{lr:.6f}",
586
+ end_clock=get_eta_clock(progress, task),
587
+ )
588
+
589
+ if step % EVAL_EVERY == 0:
590
+ losses = estimate_loss(model_engine.module if hasattr(model_engine, "module") else model_engine)
591
+ if prev_val is None:
592
+ trend = "init"
593
+ delta = 0.0
594
+ else:
595
+ delta = losses["val"] - prev_val
596
+ if delta < -1e-6:
597
+ trend = "improving"
598
+ elif delta > 1e-6:
599
+ trend = "worse"
600
+ else:
601
+ trend = "flat"
602
+ prev_val = losses["val"]
603
+
604
+ progress.update(
605
+ task,
606
+ train_loss=f"{losses['train']:.4f}",
607
+ val_loss=f"{losses['val']:.4f}",
608
+ lr=f"{lr:.6f}",
609
+ end_clock=get_eta_clock(progress, task),
610
+ )
611
+ progress.console.print(
612
+ f" [bold]Step {step:>5}[/] │ "
613
+ f"[yellow]Train {losses['train']:.4f}[/] │ "
614
+ f"[cyan]Val {losses['val']:.4f} ({trend}, Δ {delta:+.4f})[/] │ "
615
+ f"[magenta]LR {lr:.6f}[/]"
616
+ )
617
+
618
+ # Save checkpoints
619
+ save_checkpoint(
620
+ step, model_engine.module if hasattr(model_engine, 'module') else model_engine,
621
+ losses["train"], losses["val"],
622
+ os.path.join(CHECKPOINT_DIR, "latest.pt"),
623
+ )
624
+ if losses["val"] < best_val:
625
+ best_val = losses["val"]
626
+ save_checkpoint(
627
+ step, model_engine.module if hasattr(model_engine, 'module') else model_engine,
628
+ losses["train"], losses["val"],
629
+ os.path.join(CHECKPOINT_DIR, "best.pt"),
630
+ )
631
+ progress.console.print(
632
+ f" [bold green]★ New best val loss: {best_val:.4f} (saved best.pt)[/]"
633
+ )
634
+
635
+ if step >= MAX_ITERS:
636
+ break
637
+
638
+ print()
639
+ console.rule("[bold green]Training complete")
640
+ print()
641
+
642
+ # ─────────────────────────────────────────────────────────────────────────
643
+ # TEST EVALUATION
644
+ # ─────────────────────────────────────────────────────────────────────────
645
+
646
+ model_eval = model_engine.module if hasattr(model_engine, 'module') else model_engine
647
+ model_eval.eval()
648
+ test_losses = []
649
+ with torch.no_grad():
650
+ for _ in range(EVAL_ITERS):
651
+ x, y = get_batch("test")
652
+ _, loss = model_eval(x, y)
653
+ test_losses.append(loss.item())
654
+ test_loss = sum(test_losses) / len(test_losses)
655
+ print(f"Test loss : {test_loss:.4f}")
656
+ print()
657
+
658
+ # ─────────────────────────────────────────────────────────────────────────
659
+ # TEXT GENERATION
660
+ # ─────────────────────────────────────────────────────────────────────────
661
+
662
+ prompts = [
663
+ "The history of",
664
+ "Scientists have discovered",
665
+ "In the early twentieth century",
666
+ ]
667
+
668
+ print("=" * 60)
669
+ print("Generated Text Samples")
670
+ print("=" * 60)
671
+
672
+ for prompt in prompts:
673
+ output = model_eval.generate(prompt, max_new_tokens=120, temperature=0.7)
674
+ print(f"\nPrompt : \"{prompt}\"")
675
+ print(f"Output : {output.strip()}")
676
+ print()
677
+
678
+ # ─────────────────────────────────────────────────────────────────────────
679
+ # INTERACTIVE MODE
680
+ # ─────────────────────────────────────────────────────────────────────────
681
+
682
+ print("=" * 60)
683
+ print("Interactive Mode (type 'quit' to exit)")
684
+ print("=" * 60)
685
+
686
+ while True:
687
+ try:
688
+ prompt = input("\nEnter a prompt: ").strip()
689
+ except (EOFError, KeyboardInterrupt):
690
+ break
691
+ if not prompt or prompt.lower() == "quit":
692
+ break
693
+ output = model_eval.generate(prompt, max_new_tokens=150, temperature=0.8)
694
+ print(f"\n{output.strip()}")
695
+
696
+ print("\nGoodbye!")
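The warmup→cosine schedule implemented by `get_lr` in main_deepspeed.py is easy to sanity-check in isolation. Below is a minimal sketch of the same formula; the constants are placeholder values for illustration only (the training script takes `LR`, `WARMUP_STEPS`, and `MAX_ITERS` from its own config):

```python
import math

# Placeholder hyperparameters, not the values the training script uses.
LR = 3e-4
WARMUP_STEPS = 100
MAX_ITERS = 1000


def get_lr(step):
    # Linear warmup from 0 up to LR over the first WARMUP_STEPS steps...
    if step < WARMUP_STEPS:
        return LR * step / WARMUP_STEPS
    # ...then cosine decay from LR down to 10% of LR at MAX_ITERS.
    progress = (step - WARMUP_STEPS) / max(1, MAX_ITERS - WARMUP_STEPS)
    return LR * 0.1 + 0.5 * LR * 0.9 * (1 + math.cos(math.pi * progress))
```

At `step == WARMUP_STEPS` the cosine term is at its maximum, so the two pieces meet exactly at the peak `LR`; at `step == MAX_ITERS` the schedule bottoms out at `0.1 * LR` rather than zero.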
prepare_data.py ADDED
@@ -0,0 +1,153 @@
+ #!/usr/bin/env python3
+ """
+ prepare_data.py
+ ===============
+ Build tokenized binary files for training from Cosmopedia using streaming.
+
+ Outputs:
+     data/train.bin
+     data/val.bin
+     data/test.bin
+
+ Dataset:
+     HuggingFaceTB/cosmopedia (streaming)
+ """
+
+ import os
+ from pathlib import Path
+
+ import numpy as np
+ import tiktoken
+ from datasets import load_dataset
+ from tqdm.auto import tqdm
+
+ # Local project cache for reproducibility and resume behavior.
+ os.environ.setdefault("HF_HOME", "./hf_cache")
+ os.environ.setdefault("HF_DATASETS_CACHE", "./hf_cache/datasets")
+ os.environ.setdefault("HF_HUB_CACHE", "./hf_cache/hub")
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+ DATA_DIR = Path("data")
+ DATA_DIR.mkdir(parents=True, exist_ok=True)
+
+ CACHE_DIR = "./hf_cache"
+ DATASET_NAME = "HuggingFaceTB/cosmopedia"
+ DATASET_CONFIG = os.environ.get("DATASET_CONFIG", "stories")
+
+ # Stream only the first N rows by default.
+ MAX_EXAMPLES = int(os.environ.get("MAX_EXAMPLES", "1000000"))
+
+ # Deterministic split from one stream: 98% train, 1% val, 1% test.
+ TRAIN_FRAC = float(os.environ.get("TRAIN_FRAC", "0.98"))
+ VAL_FRAC = float(os.environ.get("VAL_FRAC", "0.01"))
+
+ # Flush chunks to disk to keep RAM bounded.
+ FLUSH_TOKENS = int(os.environ.get("FLUSH_TOKENS", "2000000"))
+
+ enc = tiktoken.get_encoding("gpt2")
+ EOT = enc.eot_token
+
+
+ def extract_text(row: dict) -> str:
+     """Extract a usable text field across possible Cosmopedia schemas."""
+     if "text" in row and isinstance(row["text"], str):
+         return row["text"].strip()
+     if "content" in row and isinstance(row["content"], str):
+         return row["content"].strip()
+
+     parts = []
+     for key in ("prompt", "question", "instruction", "input", "answer", "response", "output"):
+         val = row.get(key)
+         if isinstance(val, str) and val.strip():
+             parts.append(val.strip())
+
+     return "\n\n".join(parts).strip()
+
+
+ def encode_text(text: str):
+     ids = enc.encode_ordinary(text)
+     ids.append(EOT)
+     return ids
+
+
+ def flush_tokens(fp, buffer_tokens):
+     if not buffer_tokens:
+         return 0
+     arr = np.asarray(buffer_tokens, dtype=np.uint16)
+     arr.tofile(fp)
+     n = int(arr.size)
+     buffer_tokens.clear()
+     return n
+
+
+ def pick_split(i: int, total: int) -> str:
+     train_cut = int(total * TRAIN_FRAC)
+     val_cut = train_cut + int(total * VAL_FRAC)
+     if i < train_cut:
+         return "train"
+     if i < val_cut:
+         return "val"
+     return "test"
+
+
+ if __name__ == "__main__":
+     print("Loading Cosmopedia (streaming)...")
+
+     dataset = load_dataset(
+         DATASET_NAME,
+         DATASET_CONFIG,
+         split="train",
+         streaming=True,
+         cache_dir=CACHE_DIR,
+     )
+
+     out_paths = {
+         "train": DATA_DIR / "train.bin",
+         "val": DATA_DIR / "val.bin",
+         "test": DATA_DIR / "test.bin",
+     }
+
+     for p in out_paths.values():
+         if p.exists():
+             p.unlink()
+
+     buffers = {"train": [], "val": [], "test": []}
+     counts_examples = {"train": 0, "val": 0, "test": 0}
+     counts_tokens = {"train": 0, "val": 0, "test": 0}
+
+     with open(out_paths["train"], "ab") as f_train, open(out_paths["val"], "ab") as f_val, open(out_paths["test"], "ab") as f_test:
+         fps = {"train": f_train, "val": f_val, "test": f_test}
+
+         progress = tqdm(total=MAX_EXAMPLES, desc="Streaming+Encoding", unit="doc")
+         for i, row in enumerate(dataset):
+             if i >= MAX_EXAMPLES:
+                 break
+
+             text = extract_text(row)
+             if not text:
+                 progress.update(1)
+                 continue
+
+             split = pick_split(i, MAX_EXAMPLES)
+             toks = encode_text(text)
+             buffers[split].extend(toks)
+             counts_examples[split] += 1
+
+             # Flush all splits together so val/test are written even if their
+             # individual buffers never reach FLUSH_TOKENS (they're only 1% each).
+             if len(buffers["train"]) >= FLUSH_TOKENS:
+                 for s in ("train", "val", "test"):
+                     counts_tokens[s] += flush_tokens(fps[s], buffers[s])
+
+             progress.update(1)
+
+         progress.close()
+
+         for split in ("train", "val", "test"):
+             counts_tokens[split] += flush_tokens(fps[split], buffers[split])
+
+     print("\nDone.")
+     for split in ("train", "val", "test"):
+         print(f"{split:>5}: {counts_examples[split]:>10,} docs -> {counts_tokens[split]:>12,} tokens")
+     print(f"Saved files in: {DATA_DIR.resolve()}")
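The deterministic split above can be verified standalone: because `pick_split` only depends on the row index and the total, the first 98% of streamed rows land in train, the next 1% in val, and the remainder in test. A self-contained sketch of the same cutoff logic (fractions hard-coded here for illustration):

```python
TRAIN_FRAC = 0.98
VAL_FRAC = 0.01


def pick_split(i: int, total: int) -> str:
    # Index-based cutoffs: [0, train_cut) -> train, [train_cut, val_cut) -> val,
    # everything after val_cut -> test.
    train_cut = int(total * TRAIN_FRAC)
    val_cut = train_cut + int(total * VAL_FRAC)
    if i < train_cut:
        return "train"
    if i < val_cut:
        return "val"
    return "test"
```

Note that the split is positional, not shuffled, so for 1,000 streamed rows this yields exactly 980/10/10 documents per split.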
push_to_hf.py ADDED
@@ -0,0 +1,95 @@
+ #!/usr/bin/env python3
+ """
+ Upload Tiny-GPT checkpoints to Hugging Face Hub.
+
+ Usage:
+     python push_to_hf.py --repo-id yourname/Tiny-GPT
+     python push_to_hf.py --repo-id yourname/Tiny-GPT --checkpoint checkpoints/best.pt
+
+ Auth:
+     Set HF_TOKEN env var or run: huggingface-cli login
+ """
+
+ import argparse
+ import os
+ import sys
+ from pathlib import Path
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Upload Tiny-GPT checkpoints to HF Hub")
+     parser.add_argument("--repo-id", required=True, help="HF repo id, e.g. yourname/Tiny-GPT")
+     parser.add_argument(
+         "--checkpoint",
+         default="checkpoints/best.pt",
+         help="Primary checkpoint path to upload (default: checkpoints/best.pt)",
+     )
+     parser.add_argument(
+         "--latest-checkpoint",
+         default="checkpoints/latest.pt",
+         help="Optional latest checkpoint path to upload (default: checkpoints/latest.pt)",
+     )
+     parser.add_argument(
+         "--private",
+         action="store_true",
+         help="Create private repo instead of public",
+     )
+     parser.add_argument(
+         "--message",
+         default="Upload Tiny-GPT checkpoints",
+         help="Commit message for HF Hub",
+     )
+     parser.add_argument(
+         "--token",
+         default=None,
+         help="HF token (or set HF_TOKEN env var)",
+     )
+     args = parser.parse_args()
+
+     token = args.token or os.environ.get("HF_TOKEN")
+
+     try:
+         from huggingface_hub import HfApi, upload_file
+     except ImportError:
+         print("[ERROR] Missing dependency: huggingface_hub")
+         print("[ERROR] Install with: pip install huggingface_hub")
+         sys.exit(1)
+
+     checkpoint = Path(args.checkpoint)
+     latest_checkpoint = Path(args.latest_checkpoint)
+
+     if not checkpoint.exists():
+         print(f"[ERROR] Checkpoint not found: {checkpoint}")
+         sys.exit(1)
+
+     api = HfApi(token=token)
+
+     # Create repo if it does not exist yet.
+     api.create_repo(repo_id=args.repo_id, repo_type="model", private=args.private, exist_ok=True)
+
+     print(f"Uploading {checkpoint} -> {args.repo_id}/best.pt")
+     upload_file(
+         path_or_fileobj=str(checkpoint),
+         path_in_repo="best.pt",
+         repo_id=args.repo_id,
+         repo_type="model",
+         token=token,
+         commit_message=args.message,
+     )
+
+     if latest_checkpoint.exists():
+         print(f"Uploading {latest_checkpoint} -> {args.repo_id}/latest.pt")
+         upload_file(
+             path_or_fileobj=str(latest_checkpoint),
+             path_in_repo="latest.pt",
+             repo_id=args.repo_id,
+             repo_type="model",
+             token=token,
+             commit_message=args.message,
+         )
+
+     print("Done. Model checkpoints are now on Hugging Face Hub.")
+
+
+ if __name__ == "__main__":
+     main()
run.py ADDED
@@ -0,0 +1,590 @@
1
+ """
2
+ run.py – Inference script for MoE-GPT
3
+ ========================================
4
+ Run the trained model anytime to generate text.
5
+
6
+ Usage:
7
+ python run.py # Interactive mode
8
+ python run.py --prompt "text" # Generate from prompt
9
+ python run.py --file data.txt # Generate continuations from file
10
+
11
+ No training — just inference from the best checkpoint.
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ import argparse
17
+ from pathlib import Path
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import tiktoken
23
+
24
+ # ═════════════════════════════════════════════════════════════════════════════
25
+ # CONFIGURATION (must match main.py)
26
+ # ═════════════════════════════════════════════════════════════════════════════
27
+
28
+ BLOCK_SIZE = 128
29
+ EMBED_DIM = 768
30
+ NUM_HEADS = 12
31
+ NUM_LAYERS = 12
32
+ NUM_EXPERTS = 8
33
+ TOP_K = 2
34
+ FFN_DIM = EMBED_DIM * 4
35
+ DROPOUT = 0.1
36
+ CHECKPOINT_DIR = "checkpoints"
37
+
38
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
39
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
40
+
41
+ # ═════════════════════════════════════════════════════════════════════════════
42
+ # 1. TOKENISER – GPT-2 BPE
43
+ # ═════════════════════════════════════════════════════════════════════════════
44
+
45
+ enc = tiktoken.get_encoding("gpt2")
46
+ vocab_size = enc.n_vocab # 50,257
47
+
48
+
49
+ def encode(text: str) -> list:
50
+ return enc.encode_ordinary(text)
51
+
52
+
53
+ def decode(ids: list) -> str:
54
+ return enc.decode(ids)
55
+
56
+
57
+ def _infer_num_heads(embed_dim: int) -> int:
58
+ """Infer a reasonable attention head count from embedding size."""
59
+ for h in (16, 12, 8, 6, 4, 2, 1):
60
+ if embed_dim % h == 0:
61
+ return h
62
+ return 1
63
+
64
+
65
+ def apply_model_config_from_state_dict(state_dict: dict):
66
+ """Update global model hyperparameters to match checkpoint tensors."""
67
+ global BLOCK_SIZE, EMBED_DIM, NUM_HEADS, NUM_LAYERS, NUM_EXPERTS, FFN_DIM, vocab_size
68
+
69
+ if "tok_emb.weight" not in state_dict or "pos_emb.weight" not in state_dict:
70
+ return
71
+
72
+ vocab_size = state_dict["tok_emb.weight"].shape[0]
73
+ EMBED_DIM = state_dict["tok_emb.weight"].shape[1]
74
+ BLOCK_SIZE = state_dict["pos_emb.weight"].shape[0]
75
+
76
+ layer_ids = []
77
+ for k in state_dict.keys():
78
+ if k.startswith("blocks."):
79
+ parts = k.split(".")
80
+ if len(parts) > 1 and parts[1].isdigit():
81
+ layer_ids.append(int(parts[1]))
82
+ if layer_ids:
83
+ NUM_LAYERS = max(layer_ids) + 1
84
+
85
+ router_key = "blocks.0.moe.router.weight"
86
+ if router_key in state_dict:
87
+ NUM_EXPERTS = state_dict[router_key].shape[0]
88
+
89
+ ffn_key = "blocks.0.moe.experts.0.w1.weight"
90
+ if ffn_key in state_dict:
91
+ FFN_DIM = state_dict[ffn_key].shape[0]
92
+ else:
93
+ FFN_DIM = EMBED_DIM * 4
94
+
95
+ NUM_HEADS = _infer_num_heads(EMBED_DIM)
96
+
97
+
98
+ def _get_model_state_from_checkpoint(ckpt: dict) -> dict:
99
+ """Support both training checkpoint formats used in this repo."""
100
+ if "model_state" in ckpt:
101
+ return ckpt["model_state"]
102
+ if "model" in ckpt:
103
+ return ckpt["model"]
104
+ raise KeyError("Checkpoint does not contain 'model_state' or 'model'")
105
+
106
+
107
+ def resolve_checkpoint_path(
108
+ checkpoint_path=None,
109
+ hf_repo=None,
110
+ hf_filename="best.pt",
111
+ hf_revision=None,
112
+ hf_token=None,
113
+ ):
114
+ """Resolve a local checkpoint path, optionally downloading from HF Hub."""
115
+ if hf_repo:
116
+ try:
117
+ from huggingface_hub import hf_hub_download
118
+ except ImportError:
119
+ print("[ERROR] huggingface_hub is required for --hf-repo")
120
+ print("[ERROR] Install it with: pip install huggingface_hub")
121
+ sys.exit(1)
122
+
123
+ cache_dir = Path("hf_cache") / "hub"
124
+ cache_dir.mkdir(parents=True, exist_ok=True)
125
+ return hf_hub_download(
126
+ repo_id=hf_repo,
127
+ filename=hf_filename,
128
+ revision=hf_revision,
129
+ token=hf_token,
130
+ cache_dir=str(cache_dir),
131
+ )
132
+
133
+ if checkpoint_path is None:
134
+ checkpoint_path = os.path.join(CHECKPOINT_DIR, "best.pt")
135
+ return checkpoint_path
136
+
137
+
138
+ # ═════════════════════════════════════════════════════════════════════════════
139
+ # 2. MODEL ARCHITECTURE (minimal — see main.py for full details)
140
+ # ══════════════��══════════════════════════════════════════════════════════════
141
+
142
+
143
+ class CausalSelfAttention(nn.Module):
144
+ def __init__(self):
145
+ super().__init__()
146
+ self.n_heads = NUM_HEADS
147
+ self.head_dim = EMBED_DIM // NUM_HEADS
148
+ self.qkv = nn.Linear(EMBED_DIM, 3 * EMBED_DIM, bias=False)
149
+ self.proj = nn.Linear(EMBED_DIM, EMBED_DIM, bias=False)
150
+ self.attn_drop = nn.Dropout(DROPOUT)
151
+ self.proj_drop = nn.Dropout(DROPOUT)
152
+ self.register_buffer(
153
+ "mask",
154
+ torch.tril(torch.ones(BLOCK_SIZE, BLOCK_SIZE)).view(
155
+ 1, 1, BLOCK_SIZE, BLOCK_SIZE
156
+ ),
157
+ )
158
+
159
+ def forward(self, x):
160
+ B, T, C = x.shape
161
+ qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim)
162
+ q, k, v = qkv.permute(2, 0, 3, 1, 4)
163
+
164
+ att = (q @ k.transpose(-2, -1)) * (self.head_dim**-0.5)
165
+ att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
166
+ att = F.softmax(att.float(), dim=-1).to(x.dtype)
167
+ att = self.attn_drop(att)
168
+
169
+ out = (att @ v).transpose(1, 2).reshape(B, T, C)
170
+ return self.proj_drop(self.proj(out))
171
+
172
+
173
+ class ExpertFFN(nn.Module):
174
+ def __init__(self):
175
+ super().__init__()
176
+ self.w1 = nn.Linear(EMBED_DIM, FFN_DIM)
177
+ self.w2 = nn.Linear(FFN_DIM, EMBED_DIM)
178
+ self.act = nn.GELU()
179
+ self.drop = nn.Dropout(DROPOUT)
180
+
181
+ def forward(self, x):
182
+ return self.drop(self.w2(self.act(self.w1(x))))
183
+
184
+
185
+ class MoELayer(nn.Module):
186
+ def __init__(self):
187
+ super().__init__()
188
+ self.router = nn.Linear(EMBED_DIM, NUM_EXPERTS, bias=False)
189
+ self.experts = nn.ModuleList([ExpertFFN() for _ in range(NUM_EXPERTS)])
190
+
191
+ def forward(self, x):
192
+ B, T, C = x.shape
193
+ flat = x.reshape(-1, C)
194
+ N = flat.shape[0]
195
+
196
+ logits = self.router(flat)
197
+ probs = F.softmax(logits.float(), dim=-1)
198
+
199
+ top_w, top_i = torch.topk(probs, TOP_K, dim=-1)
200
+ top_w = (top_w / top_w.sum(dim=-1, keepdim=True)).to(x.dtype)
201
+
202
+         out = torch.zeros_like(flat)
+         for i, expert in enumerate(self.experts):
+             mask = (top_i == i).any(dim=-1)
+             if not mask.any():
+                 continue
+             tokens = flat[mask]
+             e_out = expert(tokens)
+             match = (top_i[mask] == i).to(x.dtype)
+             weights = (top_w[mask] * match).sum(-1, keepdim=True)
+             out[mask] += weights * e_out
+
+         return out.reshape(B, T, C)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(EMBED_DIM)
+         self.attn = CausalSelfAttention()
+         self.ln2 = nn.LayerNorm(EMBED_DIM)
+         self.moe = MoELayer()
+
+     def forward(self, x):
+         x = x + self.attn(self.ln1(x))
+         x = x + self.moe(self.ln2(x))
+         return x
+
+
+ class MoEGPT(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.tok_emb = nn.Embedding(vocab_size, EMBED_DIM)
+         self.pos_emb = nn.Embedding(BLOCK_SIZE, EMBED_DIM)
+         self.drop = nn.Dropout(DROPOUT)
+         self.blocks = nn.ModuleList([TransformerBlock() for _ in range(NUM_LAYERS)])
+         self.ln_f = nn.LayerNorm(EMBED_DIM)
+         self.head = nn.Linear(EMBED_DIM, vocab_size, bias=False)
+         self.head.weight = self.tok_emb.weight
+         self._init_weights()
+
+     def _init_weights(self):
+         for name, p in self.named_parameters():
+             if p.dim() >= 2:
+                 nn.init.normal_(p, mean=0.0, std=0.02)
+             elif "bias" in name:
+                 nn.init.zeros_(p)
+         scale = (2 * NUM_LAYERS) ** -0.5
+         for block in self.blocks:
+             nn.init.normal_(block.attn.proj.weight, mean=0.0, std=0.02 * scale)
+             for expert in block.moe.experts:
+                 nn.init.normal_(expert.w2.weight, mean=0.0, std=0.02 * scale)
+
+     def forward(self, idx, targets=None):
+         B, T = idx.shape
+         x = self.drop(
+             self.tok_emb(idx) + self.pos_emb(torch.arange(T, device=idx.device))
+         )
+
+         for block in self.blocks:
+             x = block(x)
+
+         logits = self.head(self.ln_f(x))
+
+         loss = None
+         if targets is not None:
+             loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
+         return logits, loss
+
270
+ @torch.no_grad()
271
+ def generate(
272
+ self,
273
+ prompt: str,
274
+ max_new_tokens=200,
275
+ temperature=0.8,
276
+ top_k=None,
277
+ top_p=0.9,
278
+ ):
279
+ """
280
+ Generate text from a prompt.
281
+
282
+ Args:
283
+ prompt: Starting text
284
+ max_new_tokens: How many tokens to generate
285
+ temperature: Higher = more random (0.5-1.5 typical)
286
+ top_k: Keep only top-k most likely tokens (None = disabled)
287
+ top_p: Nucleus sampling threshold (0.9 typical)
288
+ """
289
+ self.eval()
290
+ ids = torch.tensor([encode(prompt)], dtype=torch.long, device=DEVICE)
291
+
292
+ for _ in range(max_new_tokens):
293
+ ctx = ids[:, -BLOCK_SIZE:]
294
+ with torch.amp.autocast(
295
+ "cuda", dtype=torch.bfloat16, enabled=(DTYPE == torch.bfloat16)
296
+ ):
297
+ logits, _ = self(ctx)
298
+ logits = logits[:, -1, :].float() / temperature
299
+
300
+ # Top-K filtering
301
+ if top_k is not None:
302
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
303
+ logits[indices_to_remove] = float("-inf")
304
+
305
+ # Top-P (nucleus) filtering
306
+ if top_p < 1.0:
307
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
308
+ cumsum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
309
+ sorted_indices_to_remove = cumsum_probs > top_p
310
+ sorted_indices_to_remove[..., 0] = False
311
+ indices_to_remove = sorted_indices[sorted_indices_to_remove]
312
+ logits[:, indices_to_remove] = float("-inf")
313
+
314
+ probs = F.softmax(logits, dim=-1)
315
+ nxt = torch.multinomial(probs, 1)
316
+ ids = torch.cat([ids, nxt], dim=1)
317
+
318
+ self.train()
319
+ return decode(ids[0].tolist())
320
+
+
+ # ═════════════════════════════════════════════════════════════════════════════
+ # 3. LOAD MODEL FROM CHECKPOINT
+ # ═════════════════════════════════════════════════════════════════════════════
+
+
+ def load_model(
+     checkpoint_path=None,
+     hf_repo=None,
+     hf_filename="best.pt",
+     hf_revision=None,
+     hf_token=None,
+ ):
+     """Load the trained model from checkpoint."""
+     checkpoint_path = resolve_checkpoint_path(
+         checkpoint_path=checkpoint_path,
+         hf_repo=hf_repo,
+         hf_filename=hf_filename,
+         hf_revision=hf_revision,
+         hf_token=hf_token,
+     )
+
+     if not os.path.exists(checkpoint_path):
+         print(f"[ERROR] Checkpoint not found at: {checkpoint_path}")
+         print("[ERROR] Have you run 'python main.py' yet?")
+         sys.exit(1)
+
+     print(f"Loading model from {checkpoint_path} ...", end=" ", flush=True)
+     ckpt = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
+     model_state = _get_model_state_from_checkpoint(ckpt)
+     apply_model_config_from_state_dict(model_state)
+
+     model = MoEGPT()
+     model = model.to(dtype=DTYPE, device=DEVICE)
+     model.load_state_dict(model_state)
+     model.eval()
+
+     print("✓")
+     print(f"  Device: {DEVICE.upper()}")
+     print(f"  Dtype:  {DTYPE}")
+     print(
+         f"  Model: block={BLOCK_SIZE}, emb={EMBED_DIM}, heads={NUM_HEADS}, "
+         f"layers={NUM_LAYERS}, experts={NUM_EXPERTS}, ffn={FFN_DIM}"
+     )
+     print()
+
+     return model
+
+
+ # ═════════════════════════════════════════════════════════════════════════════
+ # 4. INTERACTIVE & BATCH INFERENCE
+ # ═════════════════════════════════════════════════════════════════════════════
+
+
+ def interactive_mode(model):
+     """Interactive text generation."""
+     print("=" * 70)
+     print("Interactive Mode – Type 'quit' to exit")
+     print("=" * 70)
+     print()
+     print("Commands:")
+     print("  quit       – Exit")
+     print("  /temp 0.7  – Set temperature (default 0.8)")
+     print("  /len 100   – Set max tokens (default 200)")
+     print("  /topk 40   – Set top-k (default None = disabled)")
+     print("  /topp 0.9  – Set top-p (default 0.9)")
+     print()
+
+     temperature = 0.8
+     max_tokens = 200
+     top_k = None
+     top_p = 0.9
+
+     while True:
+         try:
+             user_input = input("Prompt > ").strip()
+         except (EOFError, KeyboardInterrupt):
+             break
+
+         if not user_input:
+             continue
+
+         if user_input.lower() == "quit":
+             break
+
+         # Handle commands
+         if user_input.startswith("/"):
+             parts = user_input.split()
+             if len(parts) == 2:
+                 cmd, val = parts[0][1:], parts[1]
+                 try:
+                     if cmd == "temp":
+                         temperature = float(val)
+                         print(f"Temperature set to {temperature}")
+                     elif cmd == "len":
+                         max_tokens = int(val)
+                         print(f"Max tokens set to {max_tokens}")
+                     elif cmd == "topk":
+                         top_k = int(val)
+                         print(f"Top-k set to {top_k}")
+                     elif cmd == "topp":
+                         top_p = float(val)
+                         print(f"Top-p set to {top_p}")
+                 except ValueError:
+                     print(f"Invalid value for {cmd}")
+             continue
+
+         print()
+         with torch.no_grad():
+             output = model.generate(
+                 user_input,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+                 top_k=top_k,
+                 top_p=top_p,
+             )
+         print(output)
+         print()
+
+     print("\nGoodbye!")
+
+
+ def batch_generation(model, prompts, max_tokens=200, temperature=0.8):
+     """Generate from a list of prompts."""
+     print("=" * 70)
+     print("Batch Generation")
+     print("=" * 70)
+     print()
+
+     with torch.no_grad():
+         for i, prompt in enumerate(prompts, 1):
+             print(f"[{i}/{len(prompts)}] Prompt: {prompt}")
+             output = model.generate(
+                 prompt,
+                 max_new_tokens=max_tokens,
+                 temperature=temperature,
+             )
+             print(f"Output: {output}\n")
+
+ # ═════════════════════════════════════════════════════════════════════════════
+ # 5. MAIN
+ # ═════════════════════════════════════════════════════════════════════════════
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description="Generate text using trained MoE-GPT model",
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+         epilog="""
+ Examples:
+   python run.py                           # Interactive mode
+   python run.py --prompt "Hello world"    # Generate from prompt
+   python run.py --prompts file.txt        # Batch from file (one per line)
+   python run.py --checkpoint custom.pt    # Use custom checkpoint
+   python run.py --hf-repo user/Tiny-GPT   # Load from Hugging Face Hub
+ """,
+     )
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         help="Single prompt to generate from",
+     )
+     parser.add_argument(
+         "--prompts",
+         type=str,
+         help="File with prompts (one per line) for batch generation",
+     )
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         default=None,
+         help="Path to checkpoint (default: checkpoints/best.pt)",
+     )
+     parser.add_argument(
+         "--hf-repo",
+         type=str,
+         default=None,
+         help="Hugging Face repo id (e.g. user/Tiny-GPT). If set, download checkpoint from HF Hub.",
+     )
+     parser.add_argument(
+         "--hf-filename",
+         type=str,
+         default="best.pt",
+         help="Filename inside HF repo (default: best.pt)",
+     )
+     parser.add_argument(
+         "--hf-revision",
+         type=str,
+         default=None,
+         help="HF branch/tag/commit to download from",
+     )
+     parser.add_argument(
+         "--hf-token",
+         type=str,
+         default=None,
+         help="HF token for private repos (or use HF_TOKEN env var)",
+     )
+     parser.add_argument(
+         "--max-tokens",
+         type=int,
+         default=200,
+         help="Max tokens to generate (default: 200)",
+     )
+     parser.add_argument(
+         "--temperature",
+         type=float,
+         default=0.8,
+         help="Sampling temperature (default: 0.8)",
+     )
+     parser.add_argument(
+         "--top-k",
+         type=int,
+         default=None,
+         help="Top-k sampling (default: disabled)",
+     )
+     parser.add_argument(
+         "--top-p",
+         type=float,
+         default=0.9,
+         help="Top-p/nucleus sampling (default: 0.9)",
+     )
+
+     args = parser.parse_args()
+
+     if args.hf_repo and args.checkpoint:
+         print("[ERROR] Use either --checkpoint or --hf-repo, not both.")
+         sys.exit(1)
+
+     hf_token = args.hf_token or os.environ.get("HF_TOKEN")
+
+     # Load model
+     model = load_model(
+         checkpoint_path=args.checkpoint,
+         hf_repo=args.hf_repo,
+         hf_filename=args.hf_filename,
+         hf_revision=args.hf_revision,
+         hf_token=hf_token,
+     )
+
+     # Dispatch to the appropriate mode
+     if args.prompt:
+         # Single prompt
+         print(f"Prompt: {args.prompt}\n")
+         with torch.no_grad():
+             output = model.generate(
+                 args.prompt,
+                 max_new_tokens=args.max_tokens,
+                 temperature=args.temperature,
+                 top_k=args.top_k,
+                 top_p=args.top_p,
+             )
+         print(output)
+
+     elif args.prompts:
+         # Batch from file
+         if not os.path.exists(args.prompts):
+             print(f"[ERROR] File not found: {args.prompts}")
+             sys.exit(1)
+         with open(args.prompts) as f:
+             prompts = [line.strip() for line in f if line.strip()]
+         batch_generation(model, prompts, args.max_tokens, args.temperature)
+
+     else:
+         # Interactive mode
+         interactive_mode(model)
+
+
+ if __name__ == "__main__":
+     main()
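For reference, the top-p filtering inside `generate` can be sketched in isolation. This standalone helper (`top_p_filter` is a name introduced here, not part of the repo) keeps the smallest prefix of probability-sorted tokens whose cumulative mass reaches `top_p`, using the shifted-mask convention so the crossing token survives and the top token is never masked:

```python
import torch
import torch.nn.functional as F


def top_p_filter(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Mask (set to -inf) all tokens outside the nucleus of 1-D logits."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumsum = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumsum > top_p
    # Shift right: keep the token that crosses the threshold,
    # and never remove the most likely token.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    filtered = logits.clone()
    filtered[sorted_indices[remove]] = float("-inf")
    return filtered


# Toy example: probs ≈ [0.64, 0.24, 0.09, 0.03]; with top_p=0.9
# the first three tokens are kept and the last is masked.
out = top_p_filter(torch.tensor([2.0, 1.0, 0.0, -1.0]), top_p=0.9)
```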
train_deepspeed.sh ADDED
@@ -0,0 +1,160 @@
+ #!/bin/bash
+ # train_deepspeed.sh - Launch training with DeepSpeed ZeRO-Infinity
+
+ set -e
+
+ echo "╔════════════════════════════════════════════════════════════════╗"
+ echo "║ Tiny-GPT: DeepSpeed ZeRO-3 Training (CPU/NVMe Offloading) ║"
+ echo "╚════════════════════════════════════════════════════════════════╝"
+ echo
+
+ # Colors
+ GREEN='\033[0;32m'
+ BLUE='\033[0;34m'
+ YELLOW='\033[1;33m'
+ NC='\033[0m'
+
+ # Training mode: passive (default) or aggressive
+ TRAIN_MODE=$(echo "${TRAIN_MODE:-passive}" | tr '[:upper:]' '[:lower:]')
+ export TRAIN_MODE
+ if [ "$TRAIN_MODE" != "passive" ] && [ "$TRAIN_MODE" != "aggressive" ]; then
+     echo -e "${YELLOW}!${NC} Invalid TRAIN_MODE='$TRAIN_MODE'. Use 'passive' or 'aggressive'."
+     exit 1
+ fi
+
+ # Check DeepSpeed installation
+ echo -e "${BLUE}[1/4]${NC} Checking dependencies..."
+ python -c "import deepspeed" 2>/dev/null && echo -e "  ${GREEN}✓${NC} DeepSpeed installed" || {
+     echo -e "  Installing DeepSpeed..."
+     pip install deepspeed -q
+     echo -e "  ${GREEN}✓${NC} DeepSpeed installed"
+ }
+ echo -e "  ${GREEN}✓${NC} All dependencies ready"
+ echo
+
+ # Checkpoint handling (keep by default for auto-resume)
+ echo -e "${BLUE}[2/4]${NC} Checkpoint handling..."
+ mkdir -p checkpoints
+ if [ "${RESET_CHECKPOINTS:-0}" = "1" ]; then
+     rm -f checkpoints/*.pt
+     echo -e "  ${GREEN}✓${NC} Checkpoints cleared (RESET_CHECKPOINTS=1)"
+ else
+     if ls checkpoints/*.pt >/dev/null 2>&1; then
+         echo -e "  ${GREEN}✓${NC} Existing checkpoints found (auto-resume enabled)"
+     else
+         echo -e "  ${GREEN}✓${NC} No existing checkpoints (fresh run)"
+     fi
+ fi
+ echo
+
+ # Verify dataset
+ echo -e "${BLUE}[3/4]${NC} Verifying dataset..."
+ if [ -f "data/train.bin" ] && [ -f "data/val.bin" ] && [ -f "data/test.bin" ]; then
+     echo -e "  ${GREEN}✓${NC} Dataset ready"
+ else
+     echo -e "  ${YELLOW}!${NC} Dataset not found. Run: python prepare_data.py"
+     exit 1
+ fi
+ echo
+
+ # Get number of GPUs
+ NUM_GPUS=$(nvidia-smi --list-gpus 2>/dev/null | wc -l)
+ if [ -z "$NUM_GPUS" ] || [ "$NUM_GPUS" -eq 0 ]; then
+     NUM_GPUS=1
+ fi
+ export NUM_GPUS
+
+ # Build active DeepSpeed config based on TRAIN_MODE
+ python - <<'PY'
+ import json
+ import os
+
+ mode = os.environ.get("TRAIN_MODE", "passive").lower()
+ num_gpus = int(os.environ.get("NUM_GPUS", "1"))
+
+ with open("ds_config.json") as f:
+     cfg = json.load(f)
+
+ zero = cfg.setdefault("zero_optimization", {})
+ off_opt = zero.setdefault("offload_optimizer", {"device": "cpu"})
+ act_ckpt = cfg.setdefault("activation_checkpointing", {})
+
+ if mode == "aggressive":
+     # Higher-throughput profile: larger batches and no CPU checkpointing.
+     cfg["train_micro_batch_size_per_gpu"] = 2
+     cfg["gradient_accumulation_steps"] = 8
+     cfg["train_batch_size"] = (
+         cfg["train_micro_batch_size_per_gpu"]
+         * cfg["gradient_accumulation_steps"]
+         * max(1, num_gpus)
+     )
+     off_opt["pin_memory"] = True
+     zero["reduce_bucket_size"] = 2e6
+     act_ckpt["cpu_checkpointing"] = False
+ else:
+     # Low-resource profile (current stable baseline).
+     cfg["train_micro_batch_size_per_gpu"] = 1
+     cfg["gradient_accumulation_steps"] = 4
+     cfg["train_batch_size"] = (
+         cfg["train_micro_batch_size_per_gpu"]
+         * cfg["gradient_accumulation_steps"]
+         * max(1, num_gpus)
+     )
+     off_opt["pin_memory"] = False
+     zero["reduce_bucket_size"] = 1e6
+     act_ckpt["cpu_checkpointing"] = True
+
+ with open("ds_config.active.json", "w") as f:
+     json.dump(cfg, f, indent=2)
+ PY
+
+ # Show configuration
+ echo -e "${BLUE}[4/4]${NC} Launching DeepSpeed training..."
+ python -c "
+ import json
+ with open('ds_config.active.json') as f:
+     cfg = json.load(f)
+ print('  DeepSpeed Configuration:')
+ print(f\"    • Mode:              ${TRAIN_MODE}\")
+ print(f\"    • ZeRO Stage:        {cfg['zero_optimization']['stage']}\")
+ print(f\"    • Optimizer Offload: {cfg['zero_optimization']['offload_optimizer']['device']}\")
+ param_offload = cfg['zero_optimization'].get('offload_param', {}).get('device', 'none')
+ print(f\"    • Parameter Offload: {param_offload}\")
+ print(f\"    • Mixed Precision:   {'bfloat16' if cfg.get('bf16', {}).get('enabled') else 'float32'}\")
+ print(f\"    • Micro Batch:       {cfg['train_micro_batch_size_per_gpu']}\")
+ print(f\"    • Grad Accum:        {cfg['gradient_accumulation_steps']}\")
+ print(f\"    • Batch Size:        {cfg['train_batch_size']}\")
+ print()
+ "
+
+ # Launch training with DeepSpeed
+ echo -e "${YELLOW}Starting DeepSpeed training...${NC}"
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+ echo
+
+ # Skip CUDA version mismatch check (system CUDA >= PyTorch CUDA is fine)
+ export DS_SKIP_CUDA_CHECK=1
+ if [ "$TRAIN_MODE" = "aggressive" ]; then
+     CPU_THREADS=$(nproc)
+     export MAX_JOBS=$CPU_THREADS
+     export OMP_NUM_THREADS=$CPU_THREADS
+     export MKL_NUM_THREADS=$CPU_THREADS
+     export ALLOW_TF32=1
+     export USE_TORCH_COMPILE=${USE_TORCH_COMPILE:-0}
+     export USE_ACTIVATION_CHECKPOINT=0
+ else
+     export MAX_JOBS=1
+     export OMP_NUM_THREADS=1
+     export MKL_NUM_THREADS=1
+     export ALLOW_TF32=1
+     export USE_TORCH_COMPILE=${USE_TORCH_COMPILE:-0}
+     export USE_ACTIVATION_CHECKPOINT=1
+ fi
+
+ # The main script reads this active config path.
+ export DS_CONFIG_PATH="ds_config.active.json"
+
+ # Launch with deepspeed
+ deepspeed --num_gpus "$NUM_GPUS" main_deepspeed.py
+
+ echo
+ echo "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
+ echo -e "${GREEN}Training complete!${NC}"
+ echo
+ echo "Check results:"
+ echo "  • Checkpoints: ls -lh checkpoints/"
+ echo "  • Generate:    python run.py"
+ echo "  • Best model:  checkpoints/best.pt"
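The heredoc in this script derives `train_batch_size` from the micro-batch size, gradient-accumulation steps, and GPU count, matching DeepSpeed's invariant `train_batch_size = micro_batch × grad_accum × world_size`. A minimal sketch of that arithmetic (the helper name `effective_batch` is introduced here for illustration; the profile values are the script's own):

```python
def effective_batch(micro: int, grad_accum: int, num_gpus: int) -> int:
    # DeepSpeed validates that train_batch_size equals
    # micro_batch_per_gpu * gradient_accumulation_steps * world_size.
    return micro * grad_accum * max(1, num_gpus)


# The script's two profiles on a single GPU:
passive = effective_batch(micro=1, grad_accum=4, num_gpus=1)      # → 4
aggressive = effective_batch(micro=2, grad_accum=8, num_gpus=1)   # → 16
```

Keeping the product consistent is why the heredoc recomputes `train_batch_size` whenever the mode changes; writing the three values independently would make DeepSpeed reject the config at startup.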