Shengkun committed
Commit 6ecb966 · verified · 1 Parent(s): 31d9d26

Upload LlamaForCausalLM

Files changed (2):
  1. config.json +630 -32
  2. modeling_darwinlm.py +4 -4
config.json CHANGED
@@ -46,38 +46,636 @@
   "eos_token_id": 128001,
   "head_dim": 128,
   "heads_each_attn": {
-    "0.self_attn.o_proj": 11,
-    "1.self_attn.o_proj": 14,
-    "10.self_attn.o_proj": 14,
-    "11.self_attn.o_proj": 14,
-    "12.self_attn.o_proj": 7,
-    "13.self_attn.o_proj": 14,
-    "14.self_attn.o_proj": 26,
-    "15.self_attn.o_proj": 14,
-    "16.self_attn.o_proj": 7,
-    "17.self_attn.o_proj": 18,
-    "18.self_attn.o_proj": 18,
-    "19.self_attn.o_proj": 18,
-    "2.self_attn.o_proj": 22,
-    "20.self_attn.o_proj": 31,
-    "21.self_attn.o_proj": 28,
-    "22.self_attn.o_proj": 18,
-    "23.self_attn.o_proj": 18,
-    "24.self_attn.o_proj": 14,
-    "25.self_attn.o_proj": 14,
-    "26.self_attn.o_proj": 32,
-    "27.self_attn.o_proj": 22,
-    "28.self_attn.o_proj": 22,
-    "29.self_attn.o_proj": 18,
-    "3.self_attn.o_proj": 7,
-    "30.self_attn.o_proj": 31,
-    "31.self_attn.o_proj": 28,
-    "4.self_attn.o_proj": 14,
-    "5.self_attn.o_proj": 22,
-    "6.self_attn.o_proj": 14,
-    "7.self_attn.o_proj": 14,
-    "8.self_attn.o_proj": 18,
-    "9.self_attn.o_proj": 4
+    "0.self_attn.o_proj": [13, 15, 18, 19, 20, 22, 25, 27, 28, 30, 31],
+    "1.self_attn.o_proj": [0, 2, 8, 9, 11, 12, 13, 14, 15, 19, 22, 25, 27, 29],
+    "10.self_attn.o_proj": [2, 3, 4, 5, 6, 7, 8, 12, 13, 15, 20, 24, 25, 31],
+    "11.self_attn.o_proj": [1, 3, 7, 11, 13, 14, 15, 18, 19, 20, 21, 22, 28, 29],
+    "12.self_attn.o_proj": [2, 7, 10, 11, 13, 14, 27],
+    "13.self_attn.o_proj": [0, 1, 2, 5, 6, 10, 12, 13, 16, 17, 20, 23, 26, 29],
+    "14.self_attn.o_proj": [0, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 13, 14, 16, 17, 18, 20, 21, 23, 24, 25, 27, 28, 29, 30, 31],
+    "15.self_attn.o_proj": [1, 2, 3, 9, 10, 11, 13, 15, 17, 18, 22, 23, 28, 29],
+    "16.self_attn.o_proj": [1, 2, 5, 7, 21, 23, 28],
+    "17.self_attn.o_proj": [1, 2, 5, 6, 7, 9, 10, 14, 15, 16, 17, 19, 21, 22, 23, 26, 28, 31],
+    "18.self_attn.o_proj": [1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 15, 16, 18, 19, 21, 24, 31],
+    "19.self_attn.o_proj": [0, 1, 4, 5, 6, 12, 13, 14, 15, 16, 19, 21, 22, 24, 25, 28, 29, 31],
+    "2.self_attn.o_proj": [1, 2, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 20, 21, 23, 24, 25, 27],
+    "20.self_attn.o_proj": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 30, 31],
+    "21.self_attn.o_proj": [0, 1, 2, 3, 4, 5, 6, 8, 9, 11, 12, 13, 15, 16, 17, 18, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+    "22.self_attn.o_proj": [0, 1, 2, 3, 4, 5, 6, 12, 13, 15, 16, 18, 20, 21, 23, 25, 26, 31],
+    "23.self_attn.o_proj": [0, 2, 3, 5, 7, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 24, 26, 27],
+    "24.self_attn.o_proj": [0, 4, 5, 6, 7, 8, 9, 11, 13, 14, 23, 25, 26, 31],
+    "25.self_attn.o_proj": [0, 3, 8, 9, 10, 16, 17, 19, 25, 26, 27, 28, 29, 30],
+    "26.self_attn.o_proj": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+    "27.self_attn.o_proj": [0, 2, 3, 4, 5, 10, 11, 13, 15, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31],
+    "28.self_attn.o_proj": [1, 2, 3, 4, 5, 6, 7, 8, 11, 13, 14, 15, 16, 17, 21, 22, 24, 25, 27, 28, 29, 31],
+    "29.self_attn.o_proj": [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 17, 18, 19, 23, 29, 30],
+    "3.self_attn.o_proj": [1, 4, 9, 11, 13, 17, 27],
+    "30.self_attn.o_proj": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
+    "31.self_attn.o_proj": [1, 2, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 15, 16, 17, 18, 19, 20, 21, 22, 24, 25, 26, 27, 28, 29, 30, 31],
+    "4.self_attn.o_proj": [0, 3, 9, 10, 11, 14, 15, 18, 19, 20, 21, 28, 29, 30],
+    "5.self_attn.o_proj": [0, 1, 2, 5, 6, 7, 9, 10, 11, 12, 14, 15, 18, 20, 21, 22, 23, 24, 26, 27, 28, 30],
+    "6.self_attn.o_proj": [2, 3, 8, 9, 11, 14, 15, 17, 23, 24, 26, 27, 28, 31],
+    "7.self_attn.o_proj": [3, 6, 7, 8, 11, 13, 16, 18, 20, 22, 24, 25, 26, 28],
+    "8.self_attn.o_proj": [0, 1, 2, 3, 6, 7, 8, 9, 10, 15, 17, 20, 22, 24, 26, 27, 28, 30],
+    "9.self_attn.o_proj": [12, 20, 22, 27]
   },
   "hidden_act": "silu",
   "hidden_size": 4096,
modeling_darwinlm.py CHANGED
@@ -1061,14 +1061,14 @@ class LlamaModel(LlamaPreTrainedModel):
 
 
     def prune_model(self, heads_each_attn, dim_each_mlp, kv_ignore):
-        for name, heads_num in heads_each_attn.items():
+        for name, heads_idx in heads_each_attn.items():
             layer_idx = int(name.split(".")[0])
             attn = self.layers[layer_idx].self_attn
-            if heads_num == 32:
+            if len(heads_idx) == 32:
                 self.layers[layer_idx].self_attn = NoAttention()
                 continue
-            heads = [i for i in range(heads_num)]
-            attn.prune_heads(heads, kv_ignore)
+            # heads = [i for i in range(heads_num)]
+            attn.prune_heads(heads_idx, kv_ignore)
 
         for name, dim in dim_each_mlp.items():
             layer_idx = int(name.split(".")[0])
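
Taken together, prune_model now consumes the index lists from config.json verbatim: the old code pruned the first heads_num heads of each layer, while the new code prunes exactly the listed heads, still swapping in NoAttention when all 32 are pruned. A hedged usage sketch follows, assuming the checkpoint loads via transformers with remote code enabled, that prune_model is reached through the inner LlamaModel (per the hunk header above), and that the dim_each_mlp key and the kv_ignore value shown are illustrative placeholders:

import json
from transformers import AutoModelForCausalLM

# Usage sketch under stated assumptions; not a documented entry point.
ckpt = "path/to/darwinlm-checkpoint"  # placeholder path
model = AutoModelForCausalLM.from_pretrained(ckpt, trust_remote_code=True)

with open(f"{ckpt}/config.json") as f:
    cfg = json.load(f)

model.model.prune_model(  # prune_model is defined on LlamaModel in the diff
    heads_each_attn=cfg["heads_each_attn"],  # explicit per-layer head indices
    dim_each_mlp=cfg.get("dim_each_mlp", {}),  # assumed config key
    kv_ignore=True,  # assumed flag value for illustration
)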