Deepu1965 commited on
Commit
3ecb23d
·
verified ·
1 Parent(s): 9dbdc5e

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +9 -3
  2. history.csv +4 -4
  3. metrics.json +183 -86
  4. model.pt +2 -2
  5. tokenizer/tokenizer.json +1 -10
README.md CHANGED
@@ -1,7 +1,13 @@
1
- # Week 2 MoE (hash routing)
2
 
3
- * Best validation accuracy: 0.8050
4
  * Top-k: 1
5
  * Aux loss coef: 0.0
6
 
7
- Artifacts include the trained state dict (`model.pt`), metrics (`metrics.json`), per-epoch history (`history.csv`), and tokenizer files.
 
 
 
 
 
 
 
1
+ # Week 2 MoE Seq2Seq (hash routing)
2
 
3
+ * Best validation loss: 5.6076
4
  * Top-k: 1
5
  * Aux loss coef: 0.0
6
 
7
+ Artifacts include the trained state dict (`model.pt`), metrics (`metrics.json`), per-epoch history (`history.csv`), and tokenizer files.
8
+
9
+ ## Architecture
10
+ - Encoder-Decoder Transformer with Sparse MoE layers
11
+ - Hash-based routing (deterministic) or Token-choice top-k routing (learned)
12
+ - Load balancing auxiliary loss for top-k routing
13
+ - Trained from scratch on XSum for abstractive summarization
history.csv CHANGED
@@ -1,4 +1,4 @@
1
- epoch,train_loss,train_aux_loss,train_accuracy,val_loss,val_aux_loss,val_accuracy
2
- 1,0.9436405922412873,0.0,0.6104,0.6622284393310547,0.0,0.754
3
- 2,0.5059148602962494,0.0,0.8156,0.5826493395864963,0.0,0.805
4
- 3,0.3382958103120327,0.0,0.8844,0.6135995112359524,0.0,0.803
 
1
+ epoch,train_loss,train_aux_loss,train_perplexity,val_loss,val_aux_loss,val_perplexity
2
+ 1,6.781526548633617,0.0,881.413217772467,6.13985468708606,0.0,463.9861427818742
3
+ 2,5.734937044167949,0.0,309.49348572595704,5.784560862128246,0.0,325.23918390511113
4
+ 3,5.265651165721379,0.0,193.57231541377683,5.607645124527238,0.0,272.5017740962985
metrics.json CHANGED
@@ -2,128 +2,225 @@
2
  "history": [
3
  {
4
  "epoch": 1,
5
- "train_loss": 0.9436405922412873,
6
  "train_aux_loss": 0.0,
7
- "train_accuracy": 0.6104,
8
- "val_loss": 0.6622284393310547,
9
  "val_aux_loss": 0.0,
10
- "val_accuracy": 0.754
11
  },
12
  {
13
  "epoch": 2,
14
- "train_loss": 0.5059148602962494,
15
  "train_aux_loss": 0.0,
16
- "train_accuracy": 0.8156,
17
- "val_loss": 0.5826493395864963,
18
  "val_aux_loss": 0.0,
19
- "val_accuracy": 0.805
20
  },
21
  {
22
  "epoch": 3,
23
- "train_loss": 0.3382958103120327,
24
  "train_aux_loss": 0.0,
25
- "train_accuracy": 0.8844,
26
- "val_loss": 0.6135995112359524,
27
  "val_aux_loss": 0.0,
28
- "val_accuracy": 0.803
29
  }
30
  ],
31
  "train_expert_usage": [
32
- [
33
- [
34
- 22.28339958190918,
35
- 40.70140075683594,
36
- 14.662799835205078,
37
- 50.35239791870117
 
 
 
 
 
 
 
 
38
  ],
39
- [
40
- 58.98899841308594,
41
- 38.73780059814453,
42
- 12.463199615478516,
43
- 17.809999465942383
 
 
 
 
 
 
 
 
44
  ]
45
- ],
46
- [
47
- [
48
- 23.274999618530273,
49
- 41.05299758911133,
50
- 14.551199913024902,
51
- 49.12080001831055
 
 
 
 
 
 
 
 
52
  ],
53
- [
54
- 57.284000396728516,
55
- 32.5275993347168,
56
- 11.510799407958984,
57
- 26.67759895324707
 
 
 
 
 
 
 
 
58
  ]
59
- ],
60
- [
61
- [
62
- 22.86240005493164,
63
- 43.082000732421875,
64
- 14.503599166870117,
65
- 47.551998138427734
 
 
 
 
 
 
 
 
66
  ],
67
- [
68
- 55.36159896850586,
69
- 30.51259994506836,
70
- 11.39799976348877,
71
- 30.727798461914062
 
 
 
 
 
 
 
 
72
  ]
73
- ]
74
  ],
75
  "val_expert_usage": [
76
- [
77
- [
78
- 13.781001091003418,
79
- 31.424001693725586,
80
- 14.86400032043457,
81
- 67.93099975585938
 
 
 
 
 
 
 
 
82
  ],
83
- [
84
- 71.9800033569336,
85
- 18.58700180053711,
86
- 12.019001007080078,
87
- 25.41400146484375
 
 
 
 
 
 
 
 
88
  ]
89
- ],
90
- [
91
- [
92
- 14.10200023651123,
93
- 34.134002685546875,
94
- 14.655000686645508,
95
- 65.10900115966797
 
 
 
 
 
 
 
 
96
  ],
97
- [
98
- 57.595001220703125,
99
- 23.998001098632812,
100
- 11.643000602722168,
101
- 34.763999938964844
 
 
 
 
 
 
 
 
102
  ]
103
- ],
104
- [
105
- [
106
- 13.87600040435791,
107
- 41.534000396728516,
108
- 15.024001121520996,
109
- 57.566001892089844
 
 
 
 
 
 
 
 
110
  ],
111
- [
112
- 43.11000061035156,
113
- 32.54600143432617,
114
- 10.670000076293945,
115
- 41.67400360107422
 
 
 
 
 
 
 
 
116
  ]
117
- ]
118
  ],
119
- "best_val_accuracy": 0.805,
120
  "config": {
121
  "tokenizer": "bert-base-uncased",
122
  "max_seq_len": 128,
123
  "hidden_dim": 256,
124
  "ffn_dim": 512,
125
  "num_heads": 4,
126
- "num_layers": 2,
 
127
  "num_experts": 4,
128
  "router_type": "hash",
129
  "top_k": 1,
 
2
  "history": [
3
  {
4
  "epoch": 1,
5
+ "train_loss": 6.781526548633617,
6
  "train_aux_loss": 0.0,
7
+ "train_perplexity": 881.413217772467,
8
+ "val_loss": 6.13985468708606,
9
  "val_aux_loss": 0.0,
10
+ "val_perplexity": 463.9861427818742
11
  },
12
  {
13
  "epoch": 2,
14
+ "train_loss": 5.734937044167949,
15
  "train_aux_loss": 0.0,
16
+ "train_perplexity": 309.49348572595704,
17
+ "val_loss": 5.784560862128246,
18
  "val_aux_loss": 0.0,
19
+ "val_perplexity": 325.23918390511113
20
  },
21
  {
22
  "epoch": 3,
23
+ "train_loss": 5.265651165721379,
24
  "train_aux_loss": 0.0,
25
+ "train_perplexity": 193.57231541377683,
26
+ "val_loss": 5.607645124527238,
27
  "val_aux_loss": 0.0,
28
+ "val_perplexity": 272.5017740962985
29
  }
30
  ],
31
  "train_expert_usage": [
32
+ {
33
+ "encoder": [
34
+ [
35
+ 1.1995620727539062,
36
+ 1.0382475852966309,
37
+ 1.1820646524429321,
38
+ 1.1452935934066772
39
+ ],
40
+ [
41
+ 1.0309076309204102,
42
+ 1.3545637130737305,
43
+ 1.255271315574646,
44
+ 0.9244250655174255
45
+ ]
46
  ],
47
+ "decoder": [
48
+ [
49
+ 0.767426073551178,
50
+ 0.17747089266777039,
51
+ 0.3735591173171997,
52
+ 0.1704733520746231
53
+ ],
54
+ [
55
+ 0.13853856921195984,
56
+ 0.5398667454719543,
57
+ 0.7335297465324402,
58
+ 0.0769944041967392
59
+ ]
60
  ]
61
+ },
62
+ {
63
+ "encoder": [
64
+ [
65
+ 1.220947027206421,
66
+ 1.0275551080703735,
67
+ 1.2033497095108032,
68
+ 1.1133160591125488
69
+ ],
70
+ [
71
+ 1.0806751251220703,
72
+ 1.291507363319397,
73
+ 1.1101988554000854,
74
+ 1.0827864408493042
75
+ ]
76
  ],
77
+ "decoder": [
78
+ [
79
+ 0.5099791884422302,
80
+ 0.2548719048500061,
81
+ 0.5015906691551208,
82
+ 0.2213464379310608
83
+ ],
84
+ [
85
+ 0.2474392205476761,
86
+ 0.3361818194389343,
87
+ 0.8026777505874634,
88
+ 0.10148938745260239
89
+ ]
90
  ]
91
+ },
92
+ {
93
+ "encoder": [
94
+ [
95
+ 1.1927927732467651,
96
+ 1.0216560363769531,
97
+ 1.242702841758728,
98
+ 1.1080161333084106
99
+ ],
100
+ [
101
+ 1.1292370557785034,
102
+ 1.267861247062683,
103
+ 1.047805905342102,
104
+ 1.120263695716858
105
+ ]
106
  ],
107
+ "decoder": [
108
+ [
109
+ 0.48166799545288086,
110
+ 0.27755507826805115,
111
+ 0.4995720088481903,
112
+ 0.2349848747253418
113
+ ],
114
+ [
115
+ 0.246026873588562,
116
+ 0.32668769359588623,
117
+ 0.8019145131111145,
118
+ 0.11915087699890137
119
+ ]
120
  ]
121
+ }
122
  ],
123
  "val_expert_usage": [
124
+ {
125
+ "encoder": [
126
+ [
127
+ 1.3101885318756104,
128
+ 1.0473449230194092,
129
+ 1.1509929895401,
130
+ 1.0964527130126953
131
+ ],
132
+ [
133
+ 0.9937760829925537,
134
+ 1.3266657590866089,
135
+ 1.1915743350982666,
136
+ 1.0929629802703857
137
+ ]
138
  ],
139
+ "decoder": [
140
+ [
141
+ 0.7106418013572693,
142
+ 0.2959778308868408,
143
+ 0.3988703489303589,
144
+ 0.20999424159526825
145
+ ],
146
+ [
147
+ 0.3808821439743042,
148
+ 0.277557909488678,
149
+ 0.8908476233482361,
150
+ 0.06619657576084137
151
+ ]
152
  ]
153
+ },
154
+ {
155
+ "encoder": [
156
+ [
157
+ 1.2734206914901733,
158
+ 1.0350050926208496,
159
+ 1.2253562211990356,
160
+ 1.071197271347046
161
+ ],
162
+ [
163
+ 1.047668695449829,
164
+ 1.4553172588348389,
165
+ 1.0208303928375244,
166
+ 1.0811628103256226
167
+ ]
168
  ],
169
+ "decoder": [
170
+ [
171
+ 0.5223413705825806,
172
+ 0.2563318610191345,
173
+ 0.5943660736083984,
174
+ 0.2424449622631073
175
+ ],
176
+ [
177
+ 0.22053532302379608,
178
+ 0.38325658440589905,
179
+ 0.9159231781959534,
180
+ 0.09576917439699173
181
+ ]
182
  ]
183
+ },
184
+ {
185
+ "encoder": [
186
+ [
187
+ 1.2540652751922607,
188
+ 1.0315872430801392,
189
+ 1.253417730331421,
190
+ 1.0659087896347046
191
+ ],
192
+ [
193
+ 1.1852425336837769,
194
+ 1.2326953411102295,
195
+ 1.0134551525115967,
196
+ 1.173586130142212
197
+ ]
198
  ],
199
+ "decoder": [
200
+ [
201
+ 0.4280831813812256,
202
+ 0.3081738352775574,
203
+ 0.660526692867279,
204
+ 0.21870052814483643
205
+ ],
206
+ [
207
+ 0.16326090693473816,
208
+ 0.34843143820762634,
209
+ 0.9616851210594177,
210
+ 0.1421067714691162
211
+ ]
212
  ]
213
+ }
214
  ],
215
+ "best_val_loss": 5.607645124527238,
216
  "config": {
217
  "tokenizer": "bert-base-uncased",
218
  "max_seq_len": 128,
219
  "hidden_dim": 256,
220
  "ffn_dim": 512,
221
  "num_heads": 4,
222
+ "num_encoder_layers": 2,
223
+ "num_decoder_layers": 2,
224
  "num_experts": 4,
225
  "router_type": "hash",
226
  "top_k": 1,
model.pt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e5bd57ce382a8aaccd173f2d58c001818432fe316cf0301d70bc367b53394d87
3
- size 41943990
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d8ca2ca01ce3311dfab7a8fe4b42dc60cac000c7d46108f12115879ff15f2b4
3
+ size 85979282
tokenizer/tokenizer.json CHANGED
@@ -6,16 +6,7 @@
6
  "strategy": "LongestFirst",
7
  "stride": 0
8
  },
9
- "padding": {
10
- "strategy": {
11
- "Fixed": 128
12
- },
13
- "direction": "Right",
14
- "pad_to_multiple_of": null,
15
- "pad_id": 0,
16
- "pad_type_id": 0,
17
- "pad_token": "[PAD]"
18
- },
19
  "added_tokens": [
20
  {
21
  "id": 0,
 
6
  "strategy": "LongestFirst",
7
  "stride": 0
8
  },
9
+ "padding": null,
 
 
 
 
 
 
 
 
 
10
  "added_tokens": [
11
  {
12
  "id": 0,