Safetensors
English
qwen2
xiao23451 commited on
Commit
c6a012a
·
verified ·
1 Parent(s): 2319f92

trainer_state.json and zero_to_fp32.py

Browse files
Files changed (2) hide show
  1. trainer_state.json +784 -0
  2. zero_to_fp32.py +674 -0
trainer_state.json ADDED
@@ -0,0 +1,784 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": null,
3
+ "best_metric": null,
4
+ "best_model_checkpoint": null,
5
+ "epoch": 0.1718213058419244,
6
+ "eval_steps": 500,
7
+ "global_step": 50,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "clip_ratio": 0.0,
14
+ "completion_length": 3435.541748046875,
15
+ "epoch": 0.003436426116838488,
16
+ "grad_norm": 0.07630682736635208,
17
+ "kl": 0.0,
18
+ "learning_rate": 2e-08,
19
+ "loss": -0.0067,
20
+ "num_tokens": 1402750.0,
21
+ "reward": -0.28836290910840034,
22
+ "reward_std": 0.5228589028120041,
23
+ "rewards/cosine_scaled_reward": -0.19886894896626472,
24
+ "rewards/format_reward": 0.10937500093132257,
25
+ "step": 1
26
+ },
27
+ {
28
+ "clip_ratio": 0.0,
29
+ "completion_length": 3328.8594360351562,
30
+ "epoch": 0.006872852233676976,
31
+ "grad_norm": 0.0795634388923645,
32
+ "kl": 0.0,
33
+ "learning_rate": 4e-08,
34
+ "loss": -0.0212,
35
+ "num_tokens": 2772778.0,
36
+ "reward": -0.24981184303760529,
37
+ "reward_std": 0.5158329159021378,
38
+ "rewards/cosine_scaled_reward": -0.21214550733566284,
39
+ "rewards/format_reward": 0.1744791679084301,
40
+ "step": 2
41
+ },
42
+ {
43
+ "clip_ratio": 0.0,
44
+ "completion_length": 3388.9662475585938,
45
+ "epoch": 0.010309278350515464,
46
+ "grad_norm": 0.0782911404967308,
47
+ "kl": 0.0006399154663085938,
48
+ "learning_rate": 6e-08,
49
+ "loss": -0.0117,
50
+ "num_tokens": 4160535.0,
51
+ "reward": -0.1925769094377756,
52
+ "reward_std": 0.49376169592142105,
53
+ "rewards/cosine_scaled_reward": -0.18483011424541473,
54
+ "rewards/format_reward": 0.1770833358168602,
55
+ "step": 3
56
+ },
57
+ {
58
+ "clip_ratio": 0.0,
59
+ "completion_length": 3274.7423095703125,
60
+ "epoch": 0.013745704467353952,
61
+ "grad_norm": 0.08107814937829971,
62
+ "kl": 0.0006198883056640625,
63
+ "learning_rate": 8e-08,
64
+ "loss": -0.0191,
65
+ "num_tokens": 5507208.0,
66
+ "reward": -0.11736843098333338,
67
+ "reward_std": 0.5767310410737991,
68
+ "rewards/cosine_scaled_reward": -0.1550383809953928,
69
+ "rewards/format_reward": 0.19270833395421505,
70
+ "step": 4
71
+ },
72
+ {
73
+ "clip_ratio": 0.0,
74
+ "completion_length": 3351.6198120117188,
75
+ "epoch": 0.01718213058419244,
76
+ "grad_norm": 0.0697300136089325,
77
+ "kl": 0.000614166259765625,
78
+ "learning_rate": 1e-07,
79
+ "loss": -0.0236,
80
+ "num_tokens": 6873604.0,
81
+ "reward": -0.17346507962793112,
82
+ "reward_std": 0.5954280346632004,
83
+ "rewards/cosine_scaled_reward": -0.1687637884169817,
84
+ "rewards/format_reward": 0.1640625037252903,
85
+ "step": 5
86
+ },
87
+ {
88
+ "clip_ratio": 0.0,
89
+ "completion_length": 3481.205810546875,
90
+ "epoch": 0.020618556701030927,
91
+ "grad_norm": 0.07721535861492157,
92
+ "kl": 0.0006694793701171875,
93
+ "learning_rate": 1.2e-07,
94
+ "loss": -0.0123,
95
+ "num_tokens": 8293145.0,
96
+ "reward": -0.26027682796120644,
97
+ "reward_std": 0.589268833398819,
98
+ "rewards/cosine_scaled_reward": -0.19784674793481827,
99
+ "rewards/format_reward": 0.13541666977107525,
100
+ "step": 6
101
+ },
102
+ {
103
+ "clip_ratio": 0.0,
104
+ "completion_length": 3353.908935546875,
105
+ "epoch": 0.024054982817869417,
106
+ "grad_norm": 0.08021606504917145,
107
+ "kl": 0.00066375732421875,
108
+ "learning_rate": 1.4e-07,
109
+ "loss": -0.0066,
110
+ "num_tokens": 9663540.0,
111
+ "reward": -0.11294916458427906,
112
+ "reward_std": 0.5502363964915276,
113
+ "rewards/cosine_scaled_reward": -0.15413082763552666,
114
+ "rewards/format_reward": 0.19531250186264515,
115
+ "step": 7
116
+ },
117
+ {
118
+ "clip_ratio": 0.0,
119
+ "completion_length": 3286.4375610351562,
120
+ "epoch": 0.027491408934707903,
121
+ "grad_norm": 0.08099574595689774,
122
+ "kl": 0.0006456375122070312,
123
+ "learning_rate": 1.6e-07,
124
+ "loss": -0.0197,
125
+ "num_tokens": 11007492.0,
126
+ "reward": -0.1736841667443514,
127
+ "reward_std": 0.6177337318658829,
128
+ "rewards/cosine_scaled_reward": -0.19231082685291767,
129
+ "rewards/format_reward": 0.21093750186264515,
130
+ "step": 8
131
+ },
132
+ {
133
+ "clip_ratio": 0.0,
134
+ "completion_length": 3281.7579345703125,
135
+ "epoch": 0.030927835051546393,
136
+ "grad_norm": 0.07414574921131134,
137
+ "kl": 0.0006361007690429688,
138
+ "learning_rate": 1.8e-07,
139
+ "loss": -0.0083,
140
+ "num_tokens": 12354957.0,
141
+ "reward": -0.022597413510084152,
142
+ "reward_std": 0.5933112800121307,
143
+ "rewards/cosine_scaled_reward": -0.12718411907553673,
144
+ "rewards/format_reward": 0.2317708283662796,
145
+ "step": 9
146
+ },
147
+ {
148
+ "clip_ratio": 0.0,
149
+ "completion_length": 3356.3021850585938,
150
+ "epoch": 0.03436426116838488,
151
+ "grad_norm": 0.07732414454221725,
152
+ "kl": 0.0006494522094726562,
153
+ "learning_rate": 2e-07,
154
+ "loss": -0.0166,
155
+ "num_tokens": 13727747.0,
156
+ "reward": -0.15971257165074348,
157
+ "reward_std": 0.6636292636394501,
158
+ "rewards/cosine_scaled_reward": -0.17751253210008144,
159
+ "rewards/format_reward": 0.1953124962747097,
160
+ "step": 10
161
+ },
162
+ {
163
+ "clip_ratio": 0.0,
164
+ "completion_length": 3344.6823120117188,
165
+ "epoch": 0.037800687285223365,
166
+ "grad_norm": 0.07448139786720276,
167
+ "kl": 0.0006856918334960938,
168
+ "learning_rate": 2.1999999999999998e-07,
169
+ "loss": -0.0224,
170
+ "num_tokens": 15100095.0,
171
+ "reward": -0.2629380598664284,
172
+ "reward_std": 0.5202662125229836,
173
+ "rewards/cosine_scaled_reward": -0.2148023582994938,
174
+ "rewards/format_reward": 0.16666666977107525,
175
+ "step": 11
176
+ },
177
+ {
178
+ "clip_ratio": 0.0,
179
+ "completion_length": 3299.7682495117188,
180
+ "epoch": 0.041237113402061855,
181
+ "grad_norm": 0.0820096954703331,
182
+ "kl": 0.0006608963012695312,
183
+ "learning_rate": 2.4e-07,
184
+ "loss": -0.0151,
185
+ "num_tokens": 16449742.0,
186
+ "reward": -0.1303296772239264,
187
+ "reward_std": 0.5848766267299652,
188
+ "rewards/cosine_scaled_reward": -0.1654252614825964,
189
+ "rewards/format_reward": 0.2005208320915699,
190
+ "step": 12
191
+ },
192
+ {
193
+ "clip_ratio": 0.0,
194
+ "completion_length": 3410.3333740234375,
195
+ "epoch": 0.044673539518900345,
196
+ "grad_norm": 0.07228722423315048,
197
+ "kl": 0.0006933212280273438,
198
+ "learning_rate": 2.6e-07,
199
+ "loss": -0.0197,
200
+ "num_tokens": 17843298.0,
201
+ "reward": -0.22578875720500946,
202
+ "reward_std": 0.5726261362433434,
203
+ "rewards/cosine_scaled_reward": -0.18711312860250473,
204
+ "rewards/format_reward": 0.1484375037252903,
205
+ "step": 13
206
+ },
207
+ {
208
+ "clip_ratio": 0.0,
209
+ "completion_length": 3346.8073120117188,
210
+ "epoch": 0.048109965635738834,
211
+ "grad_norm": 0.07395637035369873,
212
+ "kl": 0.0006265640258789062,
213
+ "learning_rate": 2.8e-07,
214
+ "loss": -0.0217,
215
+ "num_tokens": 19217242.0,
216
+ "reward": -0.12510624434798956,
217
+ "reward_std": 0.6174614131450653,
218
+ "rewards/cosine_scaled_reward": -0.16281353868544102,
219
+ "rewards/format_reward": 0.2005208283662796,
220
+ "step": 14
221
+ },
222
+ {
223
+ "clip_ratio": 0.0,
224
+ "completion_length": 3459.5391845703125,
225
+ "epoch": 0.05154639175257732,
226
+ "grad_norm": 0.07217426598072052,
227
+ "kl": 0.0006237030029296875,
228
+ "learning_rate": 3e-07,
229
+ "loss": -0.009,
230
+ "num_tokens": 20625409.0,
231
+ "reward": -0.2986975237727165,
232
+ "reward_std": 0.5028247013688087,
233
+ "rewards/cosine_scaled_reward": -0.20403625443577766,
234
+ "rewards/format_reward": 0.10937500139698386,
235
+ "step": 15
236
+ },
237
+ {
238
+ "clip_ratio": 0.0,
239
+ "completion_length": 3453.9298095703125,
240
+ "epoch": 0.054982817869415807,
241
+ "grad_norm": 0.07679347693920135,
242
+ "kl": 0.0006990432739257812,
243
+ "learning_rate": 3.2e-07,
244
+ "loss": -0.0042,
245
+ "num_tokens": 22037578.0,
246
+ "reward": -0.25092077255249023,
247
+ "reward_std": 0.5149872973561287,
248
+ "rewards/cosine_scaled_reward": -0.18014787510037422,
249
+ "rewards/format_reward": 0.109375,
250
+ "step": 16
251
+ },
252
+ {
253
+ "clip_ratio": 0.0,
254
+ "completion_length": 3302.6407470703125,
255
+ "epoch": 0.058419243986254296,
256
+ "grad_norm": 0.08445706218481064,
257
+ "kl": 0.00072479248046875,
258
+ "learning_rate": 3.4000000000000003e-07,
259
+ "loss": -0.0126,
260
+ "num_tokens": 23384524.0,
261
+ "reward": -0.15077882632613182,
262
+ "reward_std": 0.5705722346901894,
263
+ "rewards/cosine_scaled_reward": -0.17044148780405521,
264
+ "rewards/format_reward": 0.19010416232049465,
265
+ "step": 17
266
+ },
267
+ {
268
+ "clip_ratio": 0.0,
269
+ "completion_length": 3237.401123046875,
270
+ "epoch": 0.061855670103092786,
271
+ "grad_norm": 0.07715122401714325,
272
+ "kl": 0.000637054443359375,
273
+ "learning_rate": 3.6e-07,
274
+ "loss": -0.0111,
275
+ "num_tokens": 24714692.0,
276
+ "reward": 0.02988712175283581,
277
+ "reward_std": 0.6443284898996353,
278
+ "rewards/cosine_scaled_reward": -0.10875435825437307,
279
+ "rewards/format_reward": 0.2473958320915699,
280
+ "step": 18
281
+ },
282
+ {
283
+ "clip_ratio": 0.0,
284
+ "completion_length": 3372.8281860351562,
285
+ "epoch": 0.06529209621993128,
286
+ "grad_norm": 0.08387548476457596,
287
+ "kl": 0.000675201416015625,
288
+ "learning_rate": 3.7999999999999996e-07,
289
+ "loss": -0.0098,
290
+ "num_tokens": 26092178.0,
291
+ "reward": -0.1666894219815731,
292
+ "reward_std": 0.5761949121952057,
293
+ "rewards/cosine_scaled_reward": -0.1588655449450016,
294
+ "rewards/format_reward": 0.15104166697710752,
295
+ "step": 19
296
+ },
297
+ {
298
+ "clip_ratio": 0.0,
299
+ "completion_length": 3322.8881225585938,
300
+ "epoch": 0.06872852233676977,
301
+ "grad_norm": 0.08063298463821411,
302
+ "kl": 0.0006866455078125,
303
+ "learning_rate": 4e-07,
304
+ "loss": -0.0153,
305
+ "num_tokens": 27456091.0,
306
+ "reward": -0.23877026326954365,
307
+ "reward_std": 0.4924458712339401,
308
+ "rewards/cosine_scaled_reward": -0.1896976288408041,
309
+ "rewards/format_reward": 0.14062499906867743,
310
+ "step": 20
311
+ },
312
+ {
313
+ "clip_ratio": 0.0,
314
+ "completion_length": 3381.4323120117188,
315
+ "epoch": 0.07216494845360824,
316
+ "grad_norm": 0.07568925619125366,
317
+ "kl": 0.000701904296875,
318
+ "learning_rate": 4.1999999999999995e-07,
319
+ "loss": -0.0125,
320
+ "num_tokens": 28838129.0,
321
+ "reward": -0.11927792057394981,
322
+ "reward_std": 0.5444926768541336,
323
+ "rewards/cosine_scaled_reward": -0.1585972849279642,
324
+ "rewards/format_reward": 0.19791666604578495,
325
+ "step": 21
326
+ },
327
+ {
328
+ "clip_ratio": 0.0,
329
+ "completion_length": 3260.4401245117188,
330
+ "epoch": 0.07560137457044673,
331
+ "grad_norm": 0.08385315537452698,
332
+ "kl": 0.0006837844848632812,
333
+ "learning_rate": 4.3999999999999997e-07,
334
+ "loss": -0.0108,
335
+ "num_tokens": 30171120.0,
336
+ "reward": -0.10095808655023575,
337
+ "reward_std": 0.587816633284092,
338
+ "rewards/cosine_scaled_reward": -0.15985404793173075,
339
+ "rewards/format_reward": 0.21875,
340
+ "step": 22
341
+ },
342
+ {
343
+ "clip_ratio": 0.0,
344
+ "completion_length": 3373.6173095703125,
345
+ "epoch": 0.07903780068728522,
346
+ "grad_norm": 0.08223242312669754,
347
+ "kl": 0.0006856918334960938,
348
+ "learning_rate": 4.6e-07,
349
+ "loss": -0.0174,
350
+ "num_tokens": 31549449.0,
351
+ "reward": -0.25981973111629486,
352
+ "reward_std": 0.5918650180101395,
353
+ "rewards/cosine_scaled_reward": -0.21194110810756683,
354
+ "rewards/format_reward": 0.16406250186264515,
355
+ "step": 23
356
+ },
357
+ {
358
+ "clip_ratio": 0.0,
359
+ "completion_length": 2977.151123046875,
360
+ "epoch": 0.08247422680412371,
361
+ "grad_norm": 0.09303563088178635,
362
+ "kl": 0.0006265640258789062,
363
+ "learning_rate": 4.8e-07,
364
+ "loss": -0.0219,
365
+ "num_tokens": 32781049.0,
366
+ "reward": 0.14391778409481049,
367
+ "reward_std": 0.6896847039461136,
368
+ "rewards/cosine_scaled_reward": -0.1090306956321001,
369
+ "rewards/format_reward": 0.3619791641831398,
370
+ "step": 24
371
+ },
372
+ {
373
+ "clip_ratio": 0.0,
374
+ "completion_length": 3351.9610595703125,
375
+ "epoch": 0.0859106529209622,
376
+ "grad_norm": 0.08072555065155029,
377
+ "kl": 0.000690460205078125,
378
+ "learning_rate": 5e-07,
379
+ "loss": -0.0183,
380
+ "num_tokens": 34155622.0,
381
+ "reward": -0.14797978568822145,
382
+ "reward_std": 0.5801117792725563,
383
+ "rewards/cosine_scaled_reward": -0.1703440584242344,
384
+ "rewards/format_reward": 0.1927083320915699,
385
+ "step": 25
386
+ },
387
+ {
388
+ "clip_ratio": 0.0,
389
+ "completion_length": 3287.3984985351562,
390
+ "epoch": 0.08934707903780069,
391
+ "grad_norm": 0.0794869065284729,
392
+ "kl": 0.0006580352783203125,
393
+ "learning_rate": 5.2e-07,
394
+ "loss": -0.0171,
395
+ "num_tokens": 35505685.0,
396
+ "reward": -0.08633977361023426,
397
+ "reward_std": 0.5518800467252731,
398
+ "rewards/cosine_scaled_reward": -0.15905530750751495,
399
+ "rewards/format_reward": 0.2317708320915699,
400
+ "step": 26
401
+ },
402
+ {
403
+ "clip_ratio": 0.0,
404
+ "completion_length": 3477.8203125,
405
+ "epoch": 0.09278350515463918,
406
+ "grad_norm": 0.07011042535305023,
407
+ "kl": 0.0006265640258789062,
408
+ "learning_rate": 5.4e-07,
409
+ "loss": -0.0108,
410
+ "num_tokens": 36918598.0,
411
+ "reward": -0.2969396822154522,
412
+ "reward_std": 0.5096745043992996,
413
+ "rewards/cosine_scaled_reward": -0.19534482806921005,
414
+ "rewards/format_reward": 0.09375,
415
+ "step": 27
416
+ },
417
+ {
418
+ "clip_ratio": 0.0,
419
+ "completion_length": 3348.5912475585938,
420
+ "epoch": 0.09621993127147767,
421
+ "grad_norm": 0.0749465823173523,
422
+ "kl": 0.000682830810546875,
423
+ "learning_rate": 5.6e-07,
424
+ "loss": -0.0155,
425
+ "num_tokens": 38287383.0,
426
+ "reward": -0.15528945997357368,
427
+ "reward_std": 0.6058137118816376,
428
+ "rewards/cosine_scaled_reward": -0.1753009781241417,
429
+ "rewards/format_reward": 0.1953125,
430
+ "step": 28
431
+ },
432
+ {
433
+ "clip_ratio": 0.0,
434
+ "completion_length": 3372.822998046875,
435
+ "epoch": 0.09965635738831616,
436
+ "grad_norm": 0.07214676588773727,
437
+ "kl": 0.0006723403930664062,
438
+ "learning_rate": 5.8e-07,
439
+ "loss": -0.0182,
440
+ "num_tokens": 39663997.0,
441
+ "reward": -0.17148404195904732,
442
+ "reward_std": 0.6370180249214172,
443
+ "rewards/cosine_scaled_reward": -0.1847003474831581,
444
+ "rewards/format_reward": 0.1979166679084301,
445
+ "step": 29
446
+ },
447
+ {
448
+ "clip_ratio": 0.0,
449
+ "completion_length": 3380.385498046875,
450
+ "epoch": 0.10309278350515463,
451
+ "grad_norm": 0.07193299382925034,
452
+ "kl": 0.0006971359252929688,
453
+ "learning_rate": 6e-07,
454
+ "loss": -0.0178,
455
+ "num_tokens": 41042879.0,
456
+ "reward": -0.2691922076046467,
457
+ "reward_std": 0.4672084078192711,
458
+ "rewards/cosine_scaled_reward": -0.2114190198481083,
459
+ "rewards/format_reward": 0.15364583767950535,
460
+ "step": 30
461
+ },
462
+ {
463
+ "clip_ratio": 0.0,
464
+ "completion_length": 3371.7891845703125,
465
+ "epoch": 0.10652920962199312,
466
+ "grad_norm": 0.07917183637619019,
467
+ "kl": 0.0007505416870117188,
468
+ "learning_rate": 6.2e-07,
469
+ "loss": -0.0137,
470
+ "num_tokens": 42425156.0,
471
+ "reward": -0.2727040550671518,
472
+ "reward_std": 0.4990428015589714,
473
+ "rewards/cosine_scaled_reward": -0.21057077683508396,
474
+ "rewards/format_reward": 0.1484375,
475
+ "step": 31
476
+ },
477
+ {
478
+ "clip_ratio": 0.0,
479
+ "completion_length": 3415.7943725585938,
480
+ "epoch": 0.10996563573883161,
481
+ "grad_norm": 0.07662034034729004,
482
+ "kl": 0.0006952285766601562,
483
+ "learning_rate": 6.4e-07,
484
+ "loss": -0.019,
485
+ "num_tokens": 43823815.0,
486
+ "reward": -0.28142979741096497,
487
+ "reward_std": 0.5432849302887917,
488
+ "rewards/cosine_scaled_reward": -0.21363156288862228,
489
+ "rewards/format_reward": 0.1458333320915699,
490
+ "step": 32
491
+ },
492
+ {
493
+ "clip_ratio": 0.0,
494
+ "completion_length": 3361.9037475585938,
495
+ "epoch": 0.1134020618556701,
496
+ "grad_norm": 0.07624170184135437,
497
+ "kl": 0.0006971359252929688,
498
+ "learning_rate": 6.6e-07,
499
+ "loss": -0.0144,
500
+ "num_tokens": 45193236.0,
501
+ "reward": -0.1651035211980343,
502
+ "reward_std": 0.6060075983405113,
503
+ "rewards/cosine_scaled_reward": -0.17239550687372684,
504
+ "rewards/format_reward": 0.1796875037252903,
505
+ "step": 33
506
+ },
507
+ {
508
+ "clip_ratio": 0.0,
509
+ "completion_length": 3295.6173095703125,
510
+ "epoch": 0.11683848797250859,
511
+ "grad_norm": 0.07951829582452774,
512
+ "kl": 0.0006999969482421875,
513
+ "learning_rate": 6.800000000000001e-07,
514
+ "loss": -0.0126,
515
+ "num_tokens": 46544973.0,
516
+ "reward": -0.1435023844242096,
517
+ "reward_std": 0.6106936782598495,
518
+ "rewards/cosine_scaled_reward": -0.1889386922121048,
519
+ "rewards/format_reward": 0.23437499813735485,
520
+ "step": 34
521
+ },
522
+ {
523
+ "clip_ratio": 0.0,
524
+ "completion_length": 3305.9193115234375,
525
+ "epoch": 0.12027491408934708,
526
+ "grad_norm": 0.08214527368545532,
527
+ "kl": 0.0006923675537109375,
528
+ "learning_rate": 7e-07,
529
+ "loss": -0.0116,
530
+ "num_tokens": 47892434.0,
531
+ "reward": -0.08280465751886368,
532
+ "reward_std": 0.6393849849700928,
533
+ "rewards/cosine_scaled_reward": -0.1559856589883566,
534
+ "rewards/format_reward": 0.2291666641831398,
535
+ "step": 35
536
+ },
537
+ {
538
+ "clip_ratio": 0.0,
539
+ "completion_length": 3307.6328735351562,
540
+ "epoch": 0.12371134020618557,
541
+ "grad_norm": 0.0808284804224968,
542
+ "kl": 0.000701904296875,
543
+ "learning_rate": 7.2e-07,
544
+ "loss": -0.0228,
545
+ "num_tokens": 49246847.0,
546
+ "reward": -0.16467856615781784,
547
+ "reward_std": 0.639389768242836,
548
+ "rewards/cosine_scaled_reward": -0.18129761889576912,
549
+ "rewards/format_reward": 0.1979166716337204,
550
+ "step": 36
551
+ },
552
+ {
553
+ "clip_ratio": 0.0,
554
+ "completion_length": 3419.182373046875,
555
+ "epoch": 0.12714776632302405,
556
+ "grad_norm": 0.07309407740831375,
557
+ "kl": 0.000705718994140625,
558
+ "learning_rate": 7.4e-07,
559
+ "loss": -0.0113,
560
+ "num_tokens": 50640711.0,
561
+ "reward": -0.16528937965631485,
562
+ "reward_std": 0.6092499941587448,
563
+ "rewards/cosine_scaled_reward": -0.16858218982815742,
564
+ "rewards/format_reward": 0.17187499813735485,
565
+ "step": 37
566
+ },
567
+ {
568
+ "clip_ratio": 0.0,
569
+ "completion_length": 3183.1563720703125,
570
+ "epoch": 0.13058419243986255,
571
+ "grad_norm": 0.0878312960267067,
572
+ "kl": 0.000705718994140625,
573
+ "learning_rate": 7.599999999999999e-07,
574
+ "loss": -0.017,
575
+ "num_tokens": 51947181.0,
576
+ "reward": -0.08579694479703903,
577
+ "reward_std": 0.5650007948279381,
578
+ "rewards/cosine_scaled_reward": -0.17571097612380981,
579
+ "rewards/format_reward": 0.265625,
580
+ "step": 38
581
+ },
582
+ {
583
+ "clip_ratio": 0.0,
584
+ "completion_length": 3337.8516235351562,
585
+ "epoch": 0.13402061855670103,
586
+ "grad_norm": 0.07768604904413223,
587
+ "kl": 0.0007076263427734375,
588
+ "learning_rate": 7.799999999999999e-07,
589
+ "loss": -0.0154,
590
+ "num_tokens": 53317764.0,
591
+ "reward": -0.20015084743499756,
592
+ "reward_std": 0.5532346814870834,
593
+ "rewards/cosine_scaled_reward": -0.18080458976328373,
594
+ "rewards/format_reward": 0.16145833395421505,
595
+ "step": 39
596
+ },
597
+ {
598
+ "clip_ratio": 0.0,
599
+ "completion_length": 3270.3021240234375,
600
+ "epoch": 0.13745704467353953,
601
+ "grad_norm": 0.07652192562818527,
602
+ "kl": 0.0006961822509765625,
603
+ "learning_rate": 8e-07,
604
+ "loss": -0.023,
605
+ "num_tokens": 54659654.0,
606
+ "reward": -0.056294072419404984,
607
+ "reward_std": 0.5959418416023254,
608
+ "rewards/cosine_scaled_reward": -0.1466366145759821,
609
+ "rewards/format_reward": 0.2369791679084301,
610
+ "step": 40
611
+ },
612
+ {
613
+ "clip_ratio": 0.0,
614
+ "completion_length": 3348.822998046875,
615
+ "epoch": 0.140893470790378,
616
+ "grad_norm": 0.07810583710670471,
617
+ "kl": 0.0006856918334960938,
618
+ "learning_rate": 8.199999999999999e-07,
619
+ "loss": -0.0102,
620
+ "num_tokens": 56025612.0,
621
+ "reward": -0.19650722108781338,
622
+ "reward_std": 0.5950964242219925,
623
+ "rewards/cosine_scaled_reward": -0.18419111147522926,
624
+ "rewards/format_reward": 0.17187499813735485,
625
+ "step": 41
626
+ },
627
+ {
628
+ "clip_ratio": 0.0,
629
+ "completion_length": 3268.635498046875,
630
+ "epoch": 0.14432989690721648,
631
+ "grad_norm": 0.07616201788187027,
632
+ "kl": 0.0006771087646484375,
633
+ "learning_rate": 8.399999999999999e-07,
634
+ "loss": -0.0138,
635
+ "num_tokens": 57364348.0,
636
+ "reward": 0.024999878369271755,
637
+ "reward_std": 0.635323241353035,
638
+ "rewards/cosine_scaled_reward": -0.11380214802920818,
639
+ "rewards/format_reward": 0.2526041679084301,
640
+ "step": 42
641
+ },
642
+ {
643
+ "clip_ratio": 0.0,
644
+ "completion_length": 3218.4974975585938,
645
+ "epoch": 0.14776632302405499,
646
+ "grad_norm": 0.0802433118224144,
647
+ "kl": 0.0007371902465820312,
648
+ "learning_rate": 8.599999999999999e-07,
649
+ "loss": -0.021,
650
+ "num_tokens": 58678671.0,
651
+ "reward": -0.09253586642444134,
652
+ "reward_std": 0.595914788544178,
653
+ "rewards/cosine_scaled_reward": -0.1699658501893282,
654
+ "rewards/format_reward": 0.2473958283662796,
655
+ "step": 43
656
+ },
657
+ {
658
+ "clip_ratio": 0.0,
659
+ "completion_length": 3297.1121215820312,
660
+ "epoch": 0.15120274914089346,
661
+ "grad_norm": 0.0863797515630722,
662
+ "kl": 0.0007600784301757812,
663
+ "learning_rate": 8.799999999999999e-07,
664
+ "loss": -0.0196,
665
+ "num_tokens": 60033334.0,
666
+ "reward": -0.04163103736937046,
667
+ "reward_std": 0.5838516503572464,
668
+ "rewards/cosine_scaled_reward": -0.1275863479822874,
669
+ "rewards/format_reward": 0.2135416641831398,
670
+ "step": 44
671
+ },
672
+ {
673
+ "clip_ratio": 0.0,
674
+ "completion_length": 3309.9922485351562,
675
+ "epoch": 0.15463917525773196,
676
+ "grad_norm": 0.08027694374322891,
677
+ "kl": 0.0007181167602539062,
678
+ "learning_rate": 9e-07,
679
+ "loss": -0.0224,
680
+ "num_tokens": 61382983.0,
681
+ "reward": -0.10378132946789265,
682
+ "reward_std": 0.6359454840421677,
683
+ "rewards/cosine_scaled_reward": -0.15996357426047325,
684
+ "rewards/format_reward": 0.2161458320915699,
685
+ "step": 45
686
+ },
687
+ {
688
+ "clip_ratio": 0.0,
689
+ "completion_length": 3202.8724975585938,
690
+ "epoch": 0.15807560137457044,
691
+ "grad_norm": 0.0834212377667427,
692
+ "kl": 0.000774383544921875,
693
+ "learning_rate": 9.2e-07,
694
+ "loss": -0.0124,
695
+ "num_tokens": 62689362.0,
696
+ "reward": -0.0990382581949234,
697
+ "reward_std": 0.5990613698959351,
698
+ "rewards/cosine_scaled_reward": -0.16670663096010685,
699
+ "rewards/format_reward": 0.23437499720603228,
700
+ "step": 46
701
+ },
702
+ {
703
+ "clip_ratio": 0.0,
704
+ "completion_length": 3431.0208740234375,
705
+ "epoch": 0.16151202749140894,
706
+ "grad_norm": 0.07041551172733307,
707
+ "kl": 0.0007162094116210938,
708
+ "learning_rate": 9.399999999999999e-07,
709
+ "loss": -0.0228,
710
+ "num_tokens": 64087406.0,
711
+ "reward": -0.2570202387869358,
712
+ "reward_std": 0.5948278307914734,
713
+ "rewards/cosine_scaled_reward": -0.20142677798867226,
714
+ "rewards/format_reward": 0.1458333320915699,
715
+ "step": 47
716
+ },
717
+ {
718
+ "clip_ratio": 0.0,
719
+ "completion_length": 3309.7761840820312,
720
+ "epoch": 0.16494845360824742,
721
+ "grad_norm": 0.07270783185958862,
722
+ "kl": 0.0007419586181640625,
723
+ "learning_rate": 9.6e-07,
724
+ "loss": -0.0212,
725
+ "num_tokens": 65438982.0,
726
+ "reward": -0.06799130514264107,
727
+ "reward_std": 0.6631468534469604,
728
+ "rewards/cosine_scaled_reward": -0.13165189698338509,
729
+ "rewards/format_reward": 0.1953125037252903,
730
+ "step": 48
731
+ },
732
+ {
733
+ "clip_ratio": 0.0,
734
+ "completion_length": 3239.0000610351562,
735
+ "epoch": 0.16838487972508592,
736
+ "grad_norm": 0.07522521167993546,
737
+ "kl": 0.0007476806640625,
738
+ "learning_rate": 9.8e-07,
739
+ "loss": -0.0196,
740
+ "num_tokens": 66769818.0,
741
+ "reward": 0.0649011842906475,
742
+ "reward_std": 0.6760600805282593,
743
+ "rewards/cosine_scaled_reward": -0.09645565785467625,
744
+ "rewards/format_reward": 0.2578125,
745
+ "step": 49
746
+ },
747
+ {
748
+ "clip_ratio": 0.0,
749
+ "completion_length": 3254.1094360351562,
750
+ "epoch": 0.1718213058419244,
751
+ "grad_norm": 0.07710978388786316,
752
+ "kl": 0.000732421875,
753
+ "learning_rate": 1e-06,
754
+ "loss": -0.0147,
755
+ "num_tokens": 68100300.0,
756
+ "reward": -0.058488317765295506,
757
+ "reward_std": 0.6240686923265457,
758
+ "rewards/cosine_scaled_reward": -0.14512957073748112,
759
+ "rewards/format_reward": 0.2317708358168602,
760
+ "step": 50
761
+ }
762
+ ],
763
+ "logging_steps": 1,
764
+ "max_steps": 500,
765
+ "num_input_tokens_seen": 0,
766
+ "num_train_epochs": 2,
767
+ "save_steps": 50,
768
+ "stateful_callbacks": {
769
+ "TrainerControl": {
770
+ "args": {
771
+ "should_epoch_stop": false,
772
+ "should_evaluate": false,
773
+ "should_log": false,
774
+ "should_save": true,
775
+ "should_training_stop": false
776
+ },
777
+ "attributes": {}
778
+ }
779
+ },
780
+ "total_flos": 0.0,
781
+ "train_batch_size": 6,
782
+ "trial_name": null,
783
+ "trial_params": null
784
+ }
zero_to_fp32.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 1, 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example:
14
+ # python zero_to_fp32.py . output_dir/
15
+ # or
16
+ # python zero_to_fp32.py . output_dir/ --safe_serialization
17
+
18
+ import argparse
19
+ import torch
20
+ import glob
21
+ import math
22
+ import os
23
+ import re
24
+ import json
25
+ from tqdm import tqdm
26
+ from collections import OrderedDict
27
+ from dataclasses import dataclass
28
+
29
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
30
+ # DeepSpeed data structures it has to be available in the current python environment.
31
+ from deepspeed.utils import logger
32
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
33
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
34
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
35
+
36
+
37
+ @dataclass
38
+ class zero_model_state:
39
+ buffers: dict()
40
+ param_shapes: dict()
41
+ shared_params: list
42
+ ds_version: int
43
+ frozen_param_shapes: dict()
44
+ frozen_param_fragments: dict()
45
+
46
+
47
+ debug = 0
48
+
49
+ # load to cpu
50
+ device = torch.device('cpu')
51
+
52
+
53
+ def atoi(text):
54
+ return int(text) if text.isdigit() else text
55
+
56
+
57
+ def natural_keys(text):
58
+ '''
59
+ alist.sort(key=natural_keys) sorts in human order
60
+ http://nedbatchelder.com/blog/200712/human_sorting.html
61
+ (See Toothy's implementation in the comments)
62
+ '''
63
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
64
+
65
+
66
+ def get_model_state_file(checkpoint_dir, zero_stage):
67
+ if not os.path.isdir(checkpoint_dir):
68
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
69
+
70
+ # there should be only one file
71
+ if zero_stage <= 2:
72
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
73
+ elif zero_stage == 3:
74
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
75
+
76
+ if not os.path.exists(file):
77
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
78
+
79
+ return file
80
+
81
+
82
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
83
+ # XXX: need to test that this simple glob rule works for multi-node setup too
84
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
85
+
86
+ if len(ckpt_files) == 0:
87
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
88
+
89
+ return ckpt_files
90
+
91
+
92
+ def get_optim_files(checkpoint_dir):
93
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
94
+
95
+
96
+ def get_model_state_files(checkpoint_dir):
97
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
98
+
99
+
100
+ def parse_model_states(files):
101
+ zero_model_states = []
102
+ for file in files:
103
+ state_dict = torch.load(file, map_location=device)
104
+
105
+ if BUFFER_NAMES not in state_dict:
106
+ raise ValueError(f"{file} is not a model state checkpoint")
107
+ buffer_names = state_dict[BUFFER_NAMES]
108
+ if debug:
109
+ print("Found buffers:", buffer_names)
110
+
111
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
112
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
113
+ param_shapes = state_dict[PARAM_SHAPES]
114
+
115
+ # collect parameters that are included in param_shapes
116
+ param_names = []
117
+ for s in param_shapes:
118
+ for name in s.keys():
119
+ param_names.append(name)
120
+
121
+ # update with frozen parameters
122
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
123
+ if frozen_param_shapes is not None:
124
+ if debug:
125
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
126
+ param_names += list(frozen_param_shapes.keys())
127
+
128
+ # handle shared params
129
+ shared_params = [[k, v] for k, v in state_dict["shared_params"].items()]
130
+
131
+ ds_version = state_dict.get(DS_VERSION, None)
132
+
133
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
134
+
135
+ z_model_state = zero_model_state(buffers=buffers,
136
+ param_shapes=param_shapes,
137
+ shared_params=shared_params,
138
+ ds_version=ds_version,
139
+ frozen_param_shapes=frozen_param_shapes,
140
+ frozen_param_fragments=frozen_param_fragments)
141
+ zero_model_states.append(z_model_state)
142
+
143
+ return zero_model_states
144
+
145
+
146
+ def parse_optim_states(files, ds_checkpoint_dir):
147
+ total_files = len(files)
148
+ state_dicts = []
149
+ for f in files:
150
+ state_dict = torch.load(f, map_location=device)
151
+ # immediately discard the potentially huge 2 optimizer states as we only care for fp32 master weights
152
+ # and also handle the case where it was already removed by another helper script
153
+ state_dict["optimizer_state_dict"].pop("optimizer_state_dict", None)
154
+ state_dicts.append(state_dict)
155
+
156
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
157
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
158
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
159
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
160
+
161
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
162
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
163
+ # use the max of the partition_count to get the dp world_size.
164
+
165
+ if type(world_size) is list:
166
+ world_size = max(world_size)
167
+
168
+ if world_size != total_files:
169
+ raise ValueError(
170
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
171
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
172
+ )
173
+
174
+ # the groups are named differently in each stage
175
+ if zero_stage <= 2:
176
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
177
+ elif zero_stage == 3:
178
+ fp32_groups_key = FP32_FLAT_GROUPS
179
+ else:
180
+ raise ValueError(f"unknown zero stage {zero_stage}")
181
+
182
+ if zero_stage <= 2:
183
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
184
+ elif zero_stage == 3:
185
+ # if there is more than one param group, there will be multiple flattened tensors - one
186
+ # flattened tensor per group - for simplicity merge them into a single tensor
187
+ #
188
+ # XXX: could make the script more memory efficient for when there are multiple groups - it
189
+ # will require matching the sub-lists of param_shapes for each param group flattened tensor
190
+
191
+ fp32_flat_groups = [
192
+ torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
193
+ ]
194
+
195
+ return zero_stage, world_size, fp32_flat_groups
196
+
197
+
198
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters):
199
+ """
200
+ Returns fp32 state_dict reconstructed from ds checkpoint
201
+
202
+ Args:
203
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
204
+
205
+ """
206
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
207
+
208
+ optim_files = get_optim_files(ds_checkpoint_dir)
209
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
210
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
211
+
212
+ model_files = get_model_state_files(ds_checkpoint_dir)
213
+
214
+ zero_model_states = parse_model_states(model_files)
215
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
216
+
217
+ if zero_stage <= 2:
218
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
219
+ exclude_frozen_parameters)
220
+ elif zero_stage == 3:
221
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
222
+ exclude_frozen_parameters)
223
+
224
+
225
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
226
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
227
+ return
228
+
229
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
230
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
231
+
232
+ if debug:
233
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
234
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
235
+
236
+ wanted_params = len(frozen_param_shapes)
237
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
238
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
239
+ print(f'Frozen params: Have {avail_numel} numels to process.')
240
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
241
+
242
+ total_params = 0
243
+ total_numel = 0
244
+ for name, shape in frozen_param_shapes.items():
245
+ total_params += 1
246
+ unpartitioned_numel = shape.numel()
247
+ total_numel += unpartitioned_numel
248
+
249
+ state_dict[name] = frozen_param_fragments[name]
250
+
251
+ if debug:
252
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
253
+
254
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
255
+
256
+
257
+ def _has_callable(obj, fn):
258
+ attr = getattr(obj, fn, None)
259
+ return callable(attr)
260
+
261
+
262
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
263
+ param_shapes = zero_model_states[0].param_shapes
264
+
265
+ # Reconstruction protocol:
266
+ #
267
+ # XXX: document this
268
+
269
+ if debug:
270
+ for i in range(world_size):
271
+ for j in range(len(fp32_flat_groups[0])):
272
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
273
+
274
+ # XXX: memory usage doubles here (zero2)
275
+ num_param_groups = len(fp32_flat_groups[0])
276
+ merged_single_partition_of_fp32_groups = []
277
+ for i in range(num_param_groups):
278
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
279
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
280
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
281
+ avail_numel = sum(
282
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
283
+
284
+ if debug:
285
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
286
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
287
+ # not asserting if there is a mismatch due to possible padding
288
+ print(f"Have {avail_numel} numels to process.")
289
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
290
+
291
+ # params
292
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
293
+ # out-of-core computing solution
294
+ total_numel = 0
295
+ total_params = 0
296
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
297
+ offset = 0
298
+ avail_numel = full_single_fp32_vector.numel()
299
+ for name, shape in shapes.items():
300
+
301
+ unpartitioned_numel = shape.numel() if _has_callable(shape, 'numel') else math.prod(shape)
302
+ total_numel += unpartitioned_numel
303
+ total_params += 1
304
+
305
+ if debug:
306
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
307
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
308
+ offset += unpartitioned_numel
309
+
310
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
311
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
312
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
313
+ # live optimizer object, so we are checking that the numbers are within the right range
314
+ align_to = 2 * world_size
315
+
316
+ def zero2_align(x):
317
+ return align_to * math.ceil(x / align_to)
318
+
319
+ if debug:
320
+ print(f"original offset={offset}, avail_numel={avail_numel}")
321
+
322
+ offset = zero2_align(offset)
323
+ avail_numel = zero2_align(avail_numel)
324
+
325
+ if debug:
326
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
327
+
328
+ # Sanity check
329
+ if offset != avail_numel:
330
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
331
+
332
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
333
+
334
+
335
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states,
336
+ exclude_frozen_parameters):
337
+ state_dict = OrderedDict()
338
+
339
+ # buffers
340
+ buffers = zero_model_states[0].buffers
341
+ state_dict.update(buffers)
342
+ if debug:
343
+ print(f"added {len(buffers)} buffers")
344
+
345
+ if not exclude_frozen_parameters:
346
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
347
+
348
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
349
+
350
+ # recover shared parameters
351
+ for pair in zero_model_states[0].shared_params:
352
+ if pair[1] in state_dict:
353
+ state_dict[pair[0]] = state_dict[pair[1]]
354
+
355
+ return state_dict
356
+
357
+
358
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
359
+ remainder = unpartitioned_numel % world_size
360
+ padding_numel = (world_size - remainder) if remainder else 0
361
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
362
+ return partitioned_numel, padding_numel
363
+
364
+
365
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
366
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
367
+ return
368
+
369
+ if debug:
370
+ for i in range(world_size):
371
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
372
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
373
+
374
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
375
+ wanted_params = len(frozen_param_shapes)
376
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
377
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
378
+ print(f'Frozen params: Have {avail_numel} numels to process.')
379
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
380
+
381
+ total_params = 0
382
+ total_numel = 0
383
+ for name, shape in zero_model_states[0].frozen_param_shapes.items():
384
+ total_params += 1
385
+ unpartitioned_numel = shape.numel()
386
+ total_numel += unpartitioned_numel
387
+
388
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
389
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
390
+
391
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
392
+
393
+ if debug:
394
+ print(
395
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
396
+ )
397
+
398
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
399
+
400
+
401
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
402
+ param_shapes = zero_model_states[0].param_shapes
403
+ avail_numel = fp32_flat_groups[0].numel() * world_size
404
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
405
+ # param, re-consolidating each param, while dealing with padding if any
406
+
407
+ # merge list of dicts, preserving order
408
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
409
+
410
+ if debug:
411
+ for i in range(world_size):
412
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
413
+
414
+ wanted_params = len(param_shapes)
415
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
416
+ # not asserting if there is a mismatch due to possible padding
417
+ avail_numel = fp32_flat_groups[0].numel() * world_size
418
+ print(f"Trainable params: Have {avail_numel} numels to process.")
419
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
420
+
421
+ # params
422
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
423
+ # out-of-core computing solution
424
+ offset = 0
425
+ total_numel = 0
426
+ total_params = 0
427
+ for name, shape in tqdm(param_shapes.items(), desc='Gathering Sharded Weights'):
428
+ unpartitioned_numel = shape.numel()
429
+ total_numel += unpartitioned_numel
430
+ total_params += 1
431
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
432
+
433
+ if debug:
434
+ print(
435
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
436
+ )
437
+
438
+ # XXX: memory usage doubles here
439
+ state_dict[name] = torch.cat(
440
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
441
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
442
+ offset += partitioned_numel
443
+
444
+ offset *= world_size
445
+
446
+ # Sanity check
447
+ if offset != avail_numel:
448
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
449
+
450
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
451
+
452
+
453
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states,
454
+ exclude_frozen_parameters):
455
+ state_dict = OrderedDict()
456
+
457
+ # buffers
458
+ buffers = zero_model_states[0].buffers
459
+ state_dict.update(buffers)
460
+ if debug:
461
+ print(f"added {len(buffers)} buffers")
462
+
463
+ if not exclude_frozen_parameters:
464
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
465
+
466
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
467
+
468
+ # recover shared parameters
469
+ for pair in zero_model_states[0].shared_params:
470
+ if pair[1] in state_dict:
471
+ state_dict[pair[0]] = state_dict[pair[1]]
472
+
473
+ return state_dict
474
+
475
+
476
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None, exclude_frozen_parameters=False):
477
+ """
478
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
479
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
480
+ via a model hub.
481
+
482
+ Args:
483
+ - ``checkpoint_dir``: path to the desired checkpoint folder
484
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
485
+ - ``exclude_frozen_parameters``: exclude frozen parameters
486
+
487
+ Returns:
488
+ - pytorch ``state_dict``
489
+
490
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
491
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
492
+ the checkpoint.
493
+
494
+ A typical usage might be ::
495
+
496
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
497
+ # do the training and checkpoint saving
498
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
499
+ model = model.cpu() # move to cpu
500
+ model.load_state_dict(state_dict)
501
+ # submit to model hub or save the model to share with others
502
+
503
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
504
+ application. i.e. you will need to re-initialize the deepspeed engine, since
505
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
506
+
507
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
508
+
509
+ """
510
+ if tag is None:
511
+ latest_path = os.path.join(checkpoint_dir, 'latest')
512
+ if os.path.isfile(latest_path):
513
+ with open(latest_path, 'r') as fd:
514
+ tag = fd.read().strip()
515
+ else:
516
+ raise ValueError(f"Unable to find 'latest' file at {latest_path}")
517
+
518
+ ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
519
+
520
+ if not os.path.isdir(ds_checkpoint_dir):
521
+ raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
522
+
523
+ return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir, exclude_frozen_parameters)
524
+
525
+
526
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir,
527
+ output_dir,
528
+ max_shard_size="5GB",
529
+ safe_serialization=False,
530
+ tag=None,
531
+ exclude_frozen_parameters=False):
532
+ """
533
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
534
+ loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
535
+
536
+ Args:
537
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
538
+ - ``output_dir``: directory to the pytorch fp32 state_dict output files
539
+ - ``max_shard_size``: the maximum size for a checkpoint before being sharded, default value is 5GB
540
+ - ``safe_serialization``: whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
541
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
542
+ - ``exclude_frozen_parameters``: exclude frozen parameters
543
+ """
544
+ # Dependency pre-check
545
+ if safe_serialization:
546
+ try:
547
+ from safetensors.torch import save_file
548
+ except ImportError:
549
+ print('If you want to use `safe_serialization`, please `pip install safetensors`')
550
+ raise
551
+ if max_shard_size is not None:
552
+ try:
553
+ from huggingface_hub import split_torch_state_dict_into_shards
554
+ except ImportError:
555
+ print('If you want to use `max_shard_size`, please `pip install huggingface_hub`')
556
+ raise
557
+
558
+ # Convert zero checkpoint to state_dict
559
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag, exclude_frozen_parameters)
560
+
561
+ # Shard the model if it is too big.
562
+ weights_name = "model.safetensors" if safe_serialization else "pytorch_model.bin"
563
+ if max_shard_size is not None:
564
+ filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
565
+ state_dict_split = split_torch_state_dict_into_shards(state_dict,
566
+ filename_pattern=filename_pattern,
567
+ max_shard_size=max_shard_size)
568
+ else:
569
+ from collections import namedtuple
570
+ StateDictSplit = namedtuple("StateDictSplit", ["is_sharded", "filename_to_tensors"])
571
+ state_dict_split = StateDictSplit(is_sharded=False,
572
+ filename_to_tensors={weights_name: list(state_dict.keys())})
573
+
574
+ # Save the model
575
+ filename_to_tensors = state_dict_split.filename_to_tensors.items()
576
+ for shard_file, tensors in tqdm(filename_to_tensors, desc="Saving checkpoint shards"):
577
+ shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
578
+ output_path = os.path.join(output_dir, shard_file)
579
+ if safe_serialization:
580
+ save_file(shard, output_path, metadata={"format": "pt"})
581
+ else:
582
+ torch.save(shard, output_path)
583
+
584
+ # Save index if sharded
585
+ if state_dict_split.is_sharded:
586
+ index = {
587
+ "metadata": state_dict_split.metadata,
588
+ "weight_map": state_dict_split.tensor_to_filename,
589
+ }
590
+ save_index_file = "model.safetensors.index.json" if safe_serialization else "pytorch_model.bin.index.json"
591
+ save_index_file = os.path.join(output_dir, save_index_file)
592
+ with open(save_index_file, "w", encoding="utf-8") as f:
593
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
594
+ f.write(content)
595
+
596
+
597
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
598
+ """
599
+ 1. Put the provided model to cpu
600
+ 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
601
+ 3. Load it into the provided model
602
+
603
+ Args:
604
+ - ``model``: the model object to update
605
+ - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
606
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
607
+
608
+ Returns:
609
+ - ``model`: modified model
610
+
611
+ Make sure you have plenty of CPU memory available before you call this function. If you don't
612
+ have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
613
+ conveniently placed for you in the checkpoint folder.
614
+
615
+ A typical usage might be ::
616
+
617
+ from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
618
+ model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
619
+ # submit to model hub or save the model to share with others
620
+
621
+ Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
622
+ of the same application. i.e. you will need to re-initialize the deepspeed engine, since
623
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
624
+
625
+ """
626
+ logger.info(f"Extracting fp32 weights")
627
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
628
+
629
+ logger.info(f"Overwriting model with fp32 weights")
630
+ model = model.cpu()
631
+ model.load_state_dict(state_dict, strict=False)
632
+
633
+ return model
634
+
635
+
636
+ if __name__ == "__main__":
637
+ parser = argparse.ArgumentParser()
638
+ parser.add_argument("checkpoint_dir",
639
+ type=str,
640
+ help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
641
+ parser.add_argument("output_dir",
642
+ type=str,
643
+ help="directory to the pytorch fp32 state_dict output files"
644
+ "(e.g. path/checkpoint-12-output/)")
645
+ parser.add_argument(
646
+ "--max_shard_size",
647
+ type=str,
648
+ default="5GB",
649
+ help="The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size"
650
+ "lower than this size. If expressed as a string, needs to be digits followed by a unit (like `5MB`"
651
+ "We default it to 5GB in order for models to be able to run easily on free-tier google colab instances"
652
+ "without CPU OOM issues.")
653
+ parser.add_argument(
654
+ "--safe_serialization",
655
+ default=False,
656
+ action='store_true',
657
+ help="Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).")
658
+ parser.add_argument("-t",
659
+ "--tag",
660
+ type=str,
661
+ default=None,
662
+ help="checkpoint tag used as a unique identifier for checkpoint. e.g., global_step1")
663
+ parser.add_argument("--exclude_frozen_parameters", action='store_true', help="exclude frozen parameters")
664
+ parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
665
+ args = parser.parse_args()
666
+
667
+ debug = args.debug
668
+
669
+ convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir,
670
+ args.output_dir,
671
+ max_shard_size=args.max_shard_size,
672
+ safe_serialization=args.safe_serialization,
673
+ tag=args.tag,
674
+ exclude_frozen_parameters=args.exclude_frozen_parameters)