Token Classification
Safetensors
English
deberta-v2
shawnrushefsky commited on
Commit
58bec3f
·
1 Parent(s): e926f5a

use best model

Browse files
last-checkpoint/added_tokens.json DELETED
@@ -1,3 +0,0 @@
1
- {
2
- "[MASK]": 128000
3
- }
 
 
 
 
last-checkpoint/config.json DELETED
@@ -1,71 +0,0 @@
1
- {
2
- "architectures": [
3
- "DebertaV2ForTokenClassification"
4
- ],
5
- "attention_probs_dropout_prob": 0.14,
6
- "bos_token_id": 1,
7
- "dtype": "float32",
8
- "eos_token_id": 2,
9
- "hidden_act": "gelu",
10
- "hidden_dropout_prob": 0.14,
11
- "hidden_size": 768,
12
- "id2label": {
13
- "0": "O",
14
- "1": "B-CHA",
15
- "2": "I-CHA",
16
- "3": "B-LOC",
17
- "4": "I-LOC",
18
- "5": "B-FAC",
19
- "6": "I-FAC",
20
- "7": "B-OBJ",
21
- "8": "I-OBJ",
22
- "9": "B-EVT",
23
- "10": "I-EVT",
24
- "11": "B-ORG",
25
- "12": "I-ORG",
26
- "13": "B-MISC",
27
- "14": "I-MISC"
28
- },
29
- "initializer_range": 0.02,
30
- "intermediate_size": 3072,
31
- "label2id": {
32
- "B-CHA": 1,
33
- "B-EVT": 9,
34
- "B-FAC": 5,
35
- "B-LOC": 3,
36
- "B-MISC": 13,
37
- "B-OBJ": 7,
38
- "B-ORG": 11,
39
- "I-CHA": 2,
40
- "I-EVT": 10,
41
- "I-FAC": 6,
42
- "I-LOC": 4,
43
- "I-MISC": 14,
44
- "I-OBJ": 8,
45
- "I-ORG": 12,
46
- "O": 0
47
- },
48
- "layer_norm_eps": 1e-07,
49
- "legacy": true,
50
- "max_position_embeddings": 512,
51
- "max_relative_positions": -1,
52
- "model_type": "deberta-v2",
53
- "norm_rel_ebd": "layer_norm",
54
- "num_attention_heads": 12,
55
- "num_hidden_layers": 12,
56
- "pad_token_id": 0,
57
- "pooler_dropout": 0,
58
- "pooler_hidden_act": "gelu",
59
- "pooler_hidden_size": 768,
60
- "pos_att_type": [
61
- "p2c",
62
- "c2p"
63
- ],
64
- "position_biased_input": false,
65
- "position_buckets": 256,
66
- "relative_attention": true,
67
- "share_att_key": true,
68
- "transformers_version": "4.56.0",
69
- "type_vocab_size": 0,
70
- "vocab_size": 128100
71
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
last-checkpoint/model.safetensors DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c1afa714bdd56bfbbb1efbf628f4c15f0b6ae266654356a88e0048e0cc7982eb
3
- size 735396724
 
 
 
 
last-checkpoint/optimizer.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a5d4e5065ff82276872b5c3d5a9b5a2c4dac0bfedf089f07966a780f6764b8dd
3
- size 1470915147
 
 
 
 
last-checkpoint/rng_state_0.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8eaef9a54b77e4410eb73091cca1813561231aab2270b6ed20afa38e56d957f0
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_1.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:961966abaa0cb4be309c2dbb6bdbb184ae4138fcea22f730fdf31fa9583dd8d9
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_2.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:116bac6566cf0d4caaea0074b84a22a40b2bd48fc5625a8d91c71f896a73b639
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_3.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:897cd381e5d0e4f14794bfbc22ff307365b08a5ee1d7a2f7ca8224735e21d7e4
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_4.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:aa521eed41a9fa3a1caf2b5b93bb2d46a8d473401226946c74853145f0fa0bbc
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_5.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1985ee7e4ad215de1b263bc720f678bd0365e11ad07a2b9a440683734aa5e894
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_6.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:884f53d156a859a4328175084fdf5db1ceaedae01554df966e9595888d0f4139
3
- size 16325
 
 
 
 
last-checkpoint/rng_state_7.pth DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cfb6ded3154a8a3c9fead3a4146b442c6599a2709d9507537d87653ee9014dc0
3
- size 16325
 
 
 
 
last-checkpoint/scheduler.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f59d63c52755d6dcb913edf584d7b233e8a521a9d79564a1c979c864c86896da
3
- size 1465
 
 
 
 
last-checkpoint/special_tokens_map.json DELETED
@@ -1,15 +0,0 @@
1
- {
2
- "bos_token": "[CLS]",
3
- "cls_token": "[CLS]",
4
- "eos_token": "[SEP]",
5
- "mask_token": "[MASK]",
6
- "pad_token": "[PAD]",
7
- "sep_token": "[SEP]",
8
- "unk_token": {
9
- "content": "[UNK]",
10
- "lstrip": false,
11
- "normalized": true,
12
- "rstrip": false,
13
- "single_word": false
14
- }
15
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
last-checkpoint/spm.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
3
- size 2464616
 
 
 
 
last-checkpoint/tokenizer.json DELETED
The diff for this file is too large to render. See raw diff
 
last-checkpoint/tokenizer_config.json DELETED
@@ -1,59 +0,0 @@
1
- {
2
- "added_tokens_decoder": {
3
- "0": {
4
- "content": "[PAD]",
5
- "lstrip": false,
6
- "normalized": false,
7
- "rstrip": false,
8
- "single_word": false,
9
- "special": true
10
- },
11
- "1": {
12
- "content": "[CLS]",
13
- "lstrip": false,
14
- "normalized": false,
15
- "rstrip": false,
16
- "single_word": false,
17
- "special": true
18
- },
19
- "2": {
20
- "content": "[SEP]",
21
- "lstrip": false,
22
- "normalized": false,
23
- "rstrip": false,
24
- "single_word": false,
25
- "special": true
26
- },
27
- "3": {
28
- "content": "[UNK]",
29
- "lstrip": false,
30
- "normalized": true,
31
- "rstrip": false,
32
- "single_word": false,
33
- "special": true
34
- },
35
- "128000": {
36
- "content": "[MASK]",
37
- "lstrip": false,
38
- "normalized": false,
39
- "rstrip": false,
40
- "single_word": false,
41
- "special": true
42
- }
43
- },
44
- "bos_token": "[CLS]",
45
- "clean_up_tokenization_spaces": false,
46
- "cls_token": "[CLS]",
47
- "do_lower_case": false,
48
- "eos_token": "[SEP]",
49
- "extra_special_tokens": {},
50
- "mask_token": "[MASK]",
51
- "model_max_length": 1000000000000000019884624838656,
52
- "pad_token": "[PAD]",
53
- "sep_token": "[SEP]",
54
- "sp_model_kwargs": {},
55
- "split_by_punct": false,
56
- "tokenizer_class": "DebertaV2Tokenizer",
57
- "unk_token": "[UNK]",
58
- "vocab_type": "spm"
59
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
last-checkpoint/trainer_state.json DELETED
@@ -1,897 +0,0 @@
1
- {
2
- "best_global_step": 4395,
3
- "best_metric": 0.7449346267936984,
4
- "best_model_checkpoint": "model/checkpoint-4395",
5
- "epoch": 1.0,
6
- "eval_steps": 1465,
7
- "global_step": 5860,
8
- "is_hyper_param_search": false,
9
- "is_local_process_zero": true,
10
- "is_world_process_zero": true,
11
- "log_history": [
12
- {
13
- "epoch": 0.008532423208191127,
14
- "grad_norm": 1.5641576051712036,
15
- "learning_rate": 6.960227272727272e-06,
16
- "loss": 1.6624,
17
- "step": 50
18
- },
19
- {
20
- "epoch": 0.017064846416382253,
21
- "grad_norm": 0.4061637222766876,
22
- "learning_rate": 1.4062500000000001e-05,
23
- "loss": 0.3169,
24
- "step": 100
25
- },
26
- {
27
- "epoch": 0.025597269624573378,
28
- "grad_norm": 0.4451664984226227,
29
- "learning_rate": 2.116477272727273e-05,
30
- "loss": 0.1786,
31
- "step": 150
32
- },
33
- {
34
- "epoch": 0.034129692832764506,
35
- "grad_norm": 0.3457334637641907,
36
- "learning_rate": 2.499898999982817e-05,
37
- "loss": 0.1447,
38
- "step": 200
39
- },
40
- {
41
- "epoch": 0.042662116040955635,
42
- "grad_norm": 0.45870351791381836,
43
- "learning_rate": 2.4989826780227188e-05,
44
- "loss": 0.1324,
45
- "step": 250
46
- },
47
- {
48
- "epoch": 0.051194539249146756,
49
- "grad_norm": 0.3399343490600586,
50
- "learning_rate": 2.4971125493142457e-05,
51
- "loss": 0.1289,
52
- "step": 300
53
- },
54
- {
55
- "epoch": 0.059726962457337884,
56
- "grad_norm": 0.47232744097709656,
57
- "learning_rate": 2.4942900420128184e-05,
58
- "loss": 0.1235,
59
- "step": 350
60
- },
61
- {
62
- "epoch": 0.06825938566552901,
63
- "grad_norm": 0.31067341566085815,
64
- "learning_rate": 2.49051731157388e-05,
65
- "loss": 0.1194,
66
- "step": 400
67
- },
68
- {
69
- "epoch": 0.07679180887372014,
70
- "grad_norm": 0.3713296353816986,
71
- "learning_rate": 2.485797239106845e-05,
72
- "loss": 0.121,
73
- "step": 450
74
- },
75
- {
76
- "epoch": 0.08532423208191127,
77
- "grad_norm": 0.3094009459018707,
78
- "learning_rate": 2.4801334291748917e-05,
79
- "loss": 0.1173,
80
- "step": 500
81
- },
82
- {
83
- "epoch": 0.09385665529010238,
84
- "grad_norm": 0.283159464597702,
85
- "learning_rate": 2.473530207042278e-05,
86
- "loss": 0.1208,
87
- "step": 550
88
- },
89
- {
90
- "epoch": 0.10238907849829351,
91
- "grad_norm": 0.30564266443252563,
92
- "learning_rate": 2.4659926153712765e-05,
93
- "loss": 0.1116,
94
- "step": 600
95
- },
96
- {
97
- "epoch": 0.11092150170648464,
98
- "grad_norm": 0.2683025002479553,
99
- "learning_rate": 2.4575264103712642e-05,
100
- "loss": 0.1104,
101
- "step": 650
102
- },
103
- {
104
- "epoch": 0.11945392491467577,
105
- "grad_norm": 0.3415144085884094,
106
- "learning_rate": 2.4481380574028934e-05,
107
- "loss": 0.1119,
108
- "step": 700
109
- },
110
- {
111
- "epoch": 0.12798634812286688,
112
- "grad_norm": 0.250636488199234,
113
- "learning_rate": 2.437834726040711e-05,
114
- "loss": 0.108,
115
- "step": 750
116
- },
117
- {
118
- "epoch": 0.13651877133105803,
119
- "grad_norm": 0.2625179886817932,
120
- "learning_rate": 2.4266242845979902e-05,
121
- "loss": 0.107,
122
- "step": 800
123
- },
124
- {
125
- "epoch": 0.14505119453924914,
126
- "grad_norm": 0.31926393508911133,
127
- "learning_rate": 2.4145152941179615e-05,
128
- "loss": 0.1075,
129
- "step": 850
130
- },
131
- {
132
- "epoch": 0.15358361774744028,
133
- "grad_norm": 0.2925693988800049,
134
- "learning_rate": 2.401517001836026e-05,
135
- "loss": 0.1087,
136
- "step": 900
137
- },
138
- {
139
- "epoch": 0.1621160409556314,
140
- "grad_norm": 0.2825545370578766,
141
- "learning_rate": 2.3876393341179486e-05,
142
- "loss": 0.1044,
143
- "step": 950
144
- },
145
- {
146
- "epoch": 0.17064846416382254,
147
- "grad_norm": 0.33026885986328125,
148
- "learning_rate": 2.3728928888794205e-05,
149
- "loss": 0.1052,
150
- "step": 1000
151
- },
152
- {
153
- "epoch": 0.17918088737201365,
154
- "grad_norm": 0.32080453634262085,
155
- "learning_rate": 2.3572889274927805e-05,
156
- "loss": 0.1072,
157
- "step": 1050
158
- },
159
- {
160
- "epoch": 0.18771331058020477,
161
- "grad_norm": 0.24436074495315552,
162
- "learning_rate": 2.3408393661870808e-05,
163
- "loss": 0.1031,
164
- "step": 1100
165
- },
166
- {
167
- "epoch": 0.1962457337883959,
168
- "grad_norm": 0.22855418920516968,
169
- "learning_rate": 2.3235567669480528e-05,
170
- "loss": 0.104,
171
- "step": 1150
172
- },
173
- {
174
- "epoch": 0.20477815699658702,
175
- "grad_norm": 0.2983805239200592,
176
- "learning_rate": 2.3054543279249373e-05,
177
- "loss": 0.1048,
178
- "step": 1200
179
- },
180
- {
181
- "epoch": 0.21331058020477817,
182
- "grad_norm": 0.22958669066429138,
183
- "learning_rate": 2.286545873351494e-05,
184
- "loss": 0.1042,
185
- "step": 1250
186
- },
187
- {
188
- "epoch": 0.22184300341296928,
189
- "grad_norm": 0.2709587514400482,
190
- "learning_rate": 2.2668458429888906e-05,
191
- "loss": 0.1028,
192
- "step": 1300
193
- },
194
- {
195
- "epoch": 0.23037542662116042,
196
- "grad_norm": 0.3249916136264801,
197
- "learning_rate": 2.2463692810985354e-05,
198
- "loss": 0.104,
199
- "step": 1350
200
- },
201
- {
202
- "epoch": 0.23890784982935154,
203
- "grad_norm": 0.33006545901298523,
204
- "learning_rate": 2.225131824953274e-05,
205
- "loss": 0.1011,
206
- "step": 1400
207
- },
208
- {
209
- "epoch": 0.24744027303754265,
210
- "grad_norm": 0.27488061785697937,
211
- "learning_rate": 2.203149692895718e-05,
212
- "loss": 0.1015,
213
- "step": 1450
214
- },
215
- {
216
- "epoch": 0.25,
217
- "eval_entity_f1": 0.7333329297742454,
218
- "eval_entity_precision": 0.6570234291058094,
219
- "eval_entity_recall": 0.876211987659762,
220
- "eval_loss": 0.09791089594364166,
221
- "eval_runtime": 794.2111,
222
- "eval_samples_per_second": 1510.933,
223
- "eval_steps_per_second": 23.608,
224
- "step": 1465
225
- },
226
- {
227
- "epoch": 0.25597269624573377,
228
- "grad_norm": 0.21379445493221283,
229
- "learning_rate": 2.180439671952838e-05,
230
- "loss": 0.1029,
231
- "step": 1500
232
- },
233
- {
234
- "epoch": 0.2645051194539249,
235
- "grad_norm": 0.24124079942703247,
236
- "learning_rate": 2.157019105016262e-05,
237
- "loss": 0.1021,
238
- "step": 1550
239
- },
240
- {
241
- "epoch": 0.27303754266211605,
242
- "grad_norm": 0.2577572464942932,
243
- "learning_rate": 2.1329058775980853e-05,
244
- "loss": 0.1009,
245
- "step": 1600
246
- },
247
- {
248
- "epoch": 0.2815699658703072,
249
- "grad_norm": 0.25605666637420654,
250
- "learning_rate": 2.1081184041722966e-05,
251
- "loss": 0.0994,
252
- "step": 1650
253
- },
254
- {
255
- "epoch": 0.2901023890784983,
256
- "grad_norm": 0.20153528451919556,
257
- "learning_rate": 2.0826756141122535e-05,
258
- "loss": 0.101,
259
- "step": 1700
260
- },
261
- {
262
- "epoch": 0.2986348122866894,
263
- "grad_norm": 0.24819990992546082,
264
- "learning_rate": 2.0565969372349447e-05,
265
- "loss": 0.1005,
266
- "step": 1750
267
- },
268
- {
269
- "epoch": 0.30716723549488056,
270
- "grad_norm": 0.282105416059494,
271
- "learning_rate": 2.0299022889630834e-05,
272
- "loss": 0.0989,
273
- "step": 1800
274
- },
275
- {
276
- "epoch": 0.31569965870307165,
277
- "grad_norm": 0.24825932085514069,
278
- "learning_rate": 2.0026120551163576e-05,
279
- "loss": 0.1015,
280
- "step": 1850
281
- },
282
- {
283
- "epoch": 0.3242320819112628,
284
- "grad_norm": 0.19880250096321106,
285
- "learning_rate": 1.9747470763434527e-05,
286
- "loss": 0.0981,
287
- "step": 1900
288
- },
289
- {
290
- "epoch": 0.33276450511945393,
291
- "grad_norm": 0.25610214471817017,
292
- "learning_rate": 1.9463286322067397e-05,
293
- "loss": 0.0993,
294
- "step": 1950
295
- },
296
- {
297
- "epoch": 0.3412969283276451,
298
- "grad_norm": 0.28560763597488403,
299
- "learning_rate": 1.9173784249317774e-05,
300
- "loss": 0.097,
301
- "step": 2000
302
- },
303
- {
304
- "epoch": 0.34982935153583616,
305
- "grad_norm": 0.36660489439964294,
306
- "learning_rate": 1.8879185628340366e-05,
307
- "loss": 0.0965,
308
- "step": 2050
309
- },
310
- {
311
- "epoch": 0.3583617747440273,
312
- "grad_norm": 0.26472488045692444,
313
- "learning_rate": 1.8579715434355174e-05,
314
- "loss": 0.0988,
315
- "step": 2100
316
- },
317
- {
318
- "epoch": 0.36689419795221845,
319
- "grad_norm": 0.25718653202056885,
320
- "learning_rate": 1.8275602362841312e-05,
321
- "loss": 0.0989,
322
- "step": 2150
323
- },
324
- {
325
- "epoch": 0.37542662116040953,
326
- "grad_norm": 0.2548709213733673,
327
- "learning_rate": 1.7967078654889858e-05,
328
- "loss": 0.0974,
329
- "step": 2200
330
- },
331
- {
332
- "epoch": 0.3839590443686007,
333
- "grad_norm": 0.235497385263443,
334
- "learning_rate": 1.7654379919849003e-05,
335
- "loss": 0.0943,
336
- "step": 2250
337
- },
338
- {
339
- "epoch": 0.3924914675767918,
340
- "grad_norm": 0.31954771280288696,
341
- "learning_rate": 1.7337744955397012e-05,
342
- "loss": 0.0965,
343
- "step": 2300
344
- },
345
- {
346
- "epoch": 0.40102389078498296,
347
- "grad_norm": 0.2830738425254822,
348
- "learning_rate": 1.7017415565180293e-05,
349
- "loss": 0.0964,
350
- "step": 2350
351
- },
352
- {
353
- "epoch": 0.40955631399317405,
354
- "grad_norm": 0.23557038605213165,
355
- "learning_rate": 1.669363637415601e-05,
356
- "loss": 0.096,
357
- "step": 2400
358
- },
359
- {
360
- "epoch": 0.4180887372013652,
361
- "grad_norm": 0.3127359449863434,
362
- "learning_rate": 1.636665464178004e-05,
363
- "loss": 0.0951,
364
- "step": 2450
365
- },
366
- {
367
- "epoch": 0.42662116040955633,
368
- "grad_norm": 0.2751990556716919,
369
- "learning_rate": 1.603672007318316e-05,
370
- "loss": 0.0962,
371
- "step": 2500
372
- },
373
- {
374
- "epoch": 0.4351535836177474,
375
- "grad_norm": 0.23224587738513947,
376
- "learning_rate": 1.5704084628479443e-05,
377
- "loss": 0.0975,
378
- "step": 2550
379
- },
380
- {
381
- "epoch": 0.44368600682593856,
382
- "grad_norm": 0.2621734142303467,
383
- "learning_rate": 1.536900233035271e-05,
384
- "loss": 0.0947,
385
- "step": 2600
386
- },
387
- {
388
- "epoch": 0.4522184300341297,
389
- "grad_norm": 0.2019677609205246,
390
- "learning_rate": 1.5031729070067773e-05,
391
- "loss": 0.0967,
392
- "step": 2650
393
- },
394
- {
395
- "epoch": 0.46075085324232085,
396
- "grad_norm": 0.2038186639547348,
397
- "learning_rate": 1.4692522412054772e-05,
398
- "loss": 0.095,
399
- "step": 2700
400
- },
401
- {
402
- "epoch": 0.46928327645051193,
403
- "grad_norm": 0.25815144181251526,
404
- "learning_rate": 1.4351641397215703e-05,
405
- "loss": 0.0935,
406
- "step": 2750
407
- },
408
- {
409
- "epoch": 0.4778156996587031,
410
- "grad_norm": 0.2345559149980545,
411
- "learning_rate": 1.4009346345103494e-05,
412
- "loss": 0.0947,
413
- "step": 2800
414
- },
415
- {
416
- "epoch": 0.4863481228668942,
417
- "grad_norm": 0.20084676146507263,
418
- "learning_rate": 1.366589865512454e-05,
419
- "loss": 0.0946,
420
- "step": 2850
421
- },
422
- {
423
- "epoch": 0.4948805460750853,
424
- "grad_norm": 0.29759591817855835,
425
- "learning_rate": 1.3321560606916652e-05,
426
- "loss": 0.0951,
427
- "step": 2900
428
- },
429
- {
430
- "epoch": 0.5,
431
- "eval_entity_f1": 0.7423178801283339,
432
- "eval_entity_precision": 0.6650263167961575,
433
- "eval_entity_recall": 0.8804226585921899,
434
- "eval_loss": 0.08911468833684921,
435
- "eval_runtime": 790.8047,
436
- "eval_samples_per_second": 1517.442,
437
- "eval_steps_per_second": 23.71,
438
- "step": 2930
439
- },
440
- {
441
- "epoch": 0.5034129692832765,
442
- "grad_norm": 0.2829442322254181,
443
- "learning_rate": 1.2976595160054744e-05,
444
- "loss": 0.0956,
445
- "step": 2950
446
- },
447
- {
448
- "epoch": 0.5119453924914675,
449
- "grad_norm": 0.3120681345462799,
450
- "learning_rate": 1.263126575323735e-05,
451
- "loss": 0.0922,
452
- "step": 3000
453
- },
454
- {
455
- "epoch": 0.5204778156996587,
456
- "grad_norm": 0.2506762444972992,
457
- "learning_rate": 1.228583610310716e-05,
458
- "loss": 0.0943,
459
- "step": 3050
460
- },
461
- {
462
- "epoch": 0.5290102389078498,
463
- "grad_norm": 0.19913755357265472,
464
- "learning_rate": 1.1940570002859372e-05,
465
- "loss": 0.0944,
466
- "step": 3100
467
- },
468
- {
469
- "epoch": 0.537542662116041,
470
- "grad_norm": 0.2714909613132477,
471
- "learning_rate": 1.1595731120791551e-05,
472
- "loss": 0.0924,
473
- "step": 3150
474
- },
475
- {
476
- "epoch": 0.5460750853242321,
477
- "grad_norm": 0.22429101169109344,
478
- "learning_rate": 1.1251582798948877e-05,
479
- "loss": 0.0924,
480
- "step": 3200
481
- },
482
- {
483
- "epoch": 0.5546075085324232,
484
- "grad_norm": 0.18535013496875763,
485
- "learning_rate": 1.0908387852018519e-05,
486
- "loss": 0.0943,
487
- "step": 3250
488
- },
489
- {
490
- "epoch": 0.5631399317406144,
491
- "grad_norm": 0.24078741669654846,
492
- "learning_rate": 1.0566408366626783e-05,
493
- "loss": 0.0955,
494
- "step": 3300
495
- },
496
- {
497
- "epoch": 0.5716723549488054,
498
- "grad_norm": 0.21610242128372192,
499
- "learning_rate": 1.0225905501192207e-05,
500
- "loss": 0.0929,
501
- "step": 3350
502
- },
503
- {
504
- "epoch": 0.5802047781569966,
505
- "grad_norm": 0.21978144347667694,
506
- "learning_rate": 9.887139286487521e-06,
507
- "loss": 0.0949,
508
- "step": 3400
509
- },
510
- {
511
- "epoch": 0.5887372013651877,
512
- "grad_norm": 0.25699079036712646,
513
- "learning_rate": 9.550368427062745e-06,
514
- "loss": 0.0914,
515
- "step": 3450
516
- },
517
- {
518
- "epoch": 0.5972696245733788,
519
- "grad_norm": 0.21956154704093933,
520
- "learning_rate": 9.215850103681096e-06,
521
- "loss": 0.0946,
522
- "step": 3500
523
- },
524
- {
525
- "epoch": 0.60580204778157,
526
- "grad_norm": 0.23602405190467834,
527
- "learning_rate": 8.883839776918538e-06,
528
- "loss": 0.0935,
529
- "step": 3550
530
- },
531
- {
532
- "epoch": 0.6143344709897611,
533
- "grad_norm": 0.24697446823120117,
534
- "learning_rate": 8.554590992077e-06,
535
- "loss": 0.092,
536
- "step": 3600
537
- },
538
- {
539
- "epoch": 0.6228668941979523,
540
- "grad_norm": 0.24195240437984467,
541
- "learning_rate": 8.228355185560196e-06,
542
- "loss": 0.0927,
543
- "step": 3650
544
- },
545
- {
546
- "epoch": 0.6313993174061433,
547
- "grad_norm": 0.25726425647735596,
548
- "learning_rate": 7.905381492859997e-06,
549
- "loss": 0.0942,
550
- "step": 3700
551
- },
552
- {
553
- "epoch": 0.6399317406143344,
554
- "grad_norm": 0.230339914560318,
555
- "learning_rate": 7.5859165582998655e-06,
556
- "loss": 0.0947,
557
- "step": 3750
558
- },
559
- {
560
- "epoch": 0.6484641638225256,
561
- "grad_norm": 0.34315890073776245,
562
- "learning_rate": 7.270204346680777e-06,
563
- "loss": 0.0924,
564
- "step": 3800
565
- },
566
- {
567
- "epoch": 0.6569965870307167,
568
- "grad_norm": 0.21445219218730927,
569
- "learning_rate": 6.958485956973332e-06,
570
- "loss": 0.0965,
571
- "step": 3850
572
- },
573
- {
574
- "epoch": 0.6655290102389079,
575
- "grad_norm": 0.2074640691280365,
576
- "learning_rate": 6.650999438198499e-06,
577
- "loss": 0.093,
578
- "step": 3900
579
- },
580
- {
581
- "epoch": 0.674061433447099,
582
- "grad_norm": 0.1953040212392807,
583
- "learning_rate": 6.347979607637408e-06,
584
- "loss": 0.0923,
585
- "step": 3950
586
- },
587
- {
588
- "epoch": 0.6825938566552902,
589
- "grad_norm": 0.19871976971626282,
590
- "learning_rate": 6.049657871509198e-06,
591
- "loss": 0.0925,
592
- "step": 4000
593
- },
594
- {
595
- "epoch": 0.6911262798634812,
596
- "grad_norm": 0.2851618230342865,
597
- "learning_rate": 5.756262048253709e-06,
598
- "loss": 0.0936,
599
- "step": 4050
600
- },
601
- {
602
- "epoch": 0.6996587030716723,
603
- "grad_norm": 0.25152096152305603,
604
- "learning_rate": 5.468016194554112e-06,
605
- "loss": 0.0904,
606
- "step": 4100
607
- },
608
- {
609
- "epoch": 0.7081911262798635,
610
- "grad_norm": 0.29341500997543335,
611
- "learning_rate": 5.185140434232203e-06,
612
- "loss": 0.0915,
613
- "step": 4150
614
- },
615
- {
616
- "epoch": 0.7167235494880546,
617
- "grad_norm": 0.19771753251552582,
618
- "learning_rate": 4.907850790147146e-06,
619
- "loss": 0.0921,
620
- "step": 4200
621
- },
622
- {
623
- "epoch": 0.7252559726962458,
624
- "grad_norm": 0.2772028148174286,
625
- "learning_rate": 4.636359019225947e-06,
626
- "loss": 0.0926,
627
- "step": 4250
628
- },
629
- {
630
- "epoch": 0.7337883959044369,
631
- "grad_norm": 0.27628397941589355,
632
- "learning_rate": 4.370872450751694e-06,
633
- "loss": 0.0896,
634
- "step": 4300
635
- },
636
- {
637
- "epoch": 0.742320819112628,
638
- "grad_norm": 0.23001554608345032,
639
- "learning_rate": 4.111593828033067e-06,
640
- "loss": 0.0877,
641
- "step": 4350
642
- },
643
- {
644
- "epoch": 0.75,
645
- "eval_entity_f1": 0.7449346267936984,
646
- "eval_entity_precision": 0.668926309530101,
647
- "eval_entity_recall": 0.878515580366616,
648
- "eval_loss": 0.08664915710687637,
649
- "eval_runtime": 792.2218,
650
- "eval_samples_per_second": 1514.727,
651
- "eval_steps_per_second": 23.668,
652
- "step": 4395
653
- },
654
- {
655
- "epoch": 0.7508532423208191,
656
- "grad_norm": 0.23986265063285828,
657
- "learning_rate": 3.858721153575945e-06,
658
- "loss": 0.0898,
659
- "step": 4400
660
- },
661
- {
662
- "epoch": 0.7593856655290102,
663
- "grad_norm": 0.17509053647518158,
664
- "learning_rate": 3.6124475378754783e-06,
665
- "loss": 0.0927,
666
- "step": 4450
667
- },
668
- {
669
- "epoch": 0.7679180887372014,
670
- "grad_norm": 0.23464491963386536,
671
- "learning_rate": 3.3729610519439585e-06,
672
- "loss": 0.0911,
673
- "step": 4500
674
- },
675
- {
676
- "epoch": 0.7764505119453925,
677
- "grad_norm": 0.2226947546005249,
678
- "learning_rate": 3.140444583687245e-06,
679
- "loss": 0.0934,
680
- "step": 4550
681
- },
682
- {
683
- "epoch": 0.7849829351535836,
684
- "grad_norm": 0.21338465809822083,
685
- "learning_rate": 2.915075698239285e-06,
686
- "loss": 0.0901,
687
- "step": 4600
688
- },
689
- {
690
- "epoch": 0.7935153583617748,
691
- "grad_norm": 0.1981608271598816,
692
- "learning_rate": 2.6970265023615297e-06,
693
- "loss": 0.0904,
694
- "step": 4650
695
- },
696
- {
697
- "epoch": 0.8020477815699659,
698
- "grad_norm": 0.19228731095790863,
699
- "learning_rate": 2.4864635130106645e-06,
700
- "loss": 0.0893,
701
- "step": 4700
702
- },
703
- {
704
- "epoch": 0.810580204778157,
705
- "grad_norm": 0.309893935918808,
706
- "learning_rate": 2.283547530175148e-06,
707
- "loss": 0.0896,
708
- "step": 4750
709
- },
710
- {
711
- "epoch": 0.8191126279863481,
712
- "grad_norm": 0.20562194287776947,
713
- "learning_rate": 2.0884335140775522e-06,
714
- "loss": 0.0922,
715
- "step": 4800
716
- },
717
- {
718
- "epoch": 0.8276450511945392,
719
- "grad_norm": 0.2413642853498459,
720
- "learning_rate": 1.901270466836584e-06,
721
- "loss": 0.0886,
722
- "step": 4850
723
- },
724
- {
725
- "epoch": 0.8361774744027304,
726
- "grad_norm": 0.20934799313545227,
727
- "learning_rate": 1.7222013186790995e-06,
728
- "loss": 0.0912,
729
- "step": 4900
730
- },
731
- {
732
- "epoch": 0.8447098976109215,
733
- "grad_norm": 0.20717120170593262,
734
- "learning_rate": 1.5513628187890136e-06,
735
- "loss": 0.0929,
736
- "step": 4950
737
- },
738
- {
739
- "epoch": 0.8532423208191127,
740
- "grad_norm": 0.24309539794921875,
741
- "learning_rate": 1.3888854308764631e-06,
742
- "loss": 0.0907,
743
- "step": 5000
744
- },
745
- {
746
- "epoch": 0.8617747440273038,
747
- "grad_norm": 0.23167067766189575,
748
- "learning_rate": 1.2348932335469992e-06,
749
- "loss": 0.092,
750
- "step": 5050
751
- },
752
- {
753
- "epoch": 0.8703071672354948,
754
- "grad_norm": 0.24426911771297455,
755
- "learning_rate": 1.0895038255468643e-06,
756
- "loss": 0.0913,
757
- "step": 5100
758
- },
759
- {
760
- "epoch": 0.878839590443686,
761
- "grad_norm": 0.21875974535942078,
762
- "learning_rate": 9.528282359567153e-07,
763
- "loss": 0.0919,
764
- "step": 5150
765
- },
766
- {
767
- "epoch": 0.8873720136518771,
768
- "grad_norm": 0.15874198079109192,
769
- "learning_rate": 8.249708394023767e-07,
770
- "loss": 0.0911,
771
- "step": 5200
772
- },
773
- {
774
- "epoch": 0.8959044368600683,
775
- "grad_norm": 0.26125800609588623,
776
- "learning_rate": 7.060292763474142e-07,
777
- "loss": 0.0908,
778
- "step": 5250
779
- },
780
- {
781
- "epoch": 0.9044368600682594,
782
- "grad_norm": 0.25548022985458374,
783
- "learning_rate": 5.960943785283293e-07,
784
- "loss": 0.0907,
785
- "step": 5300
786
- },
787
- {
788
- "epoch": 0.9129692832764505,
789
- "grad_norm": 0.18535113334655762,
790
- "learning_rate": 4.9525009958937e-07,
791
- "loss": 0.0883,
792
- "step": 5350
793
- },
794
- {
795
- "epoch": 0.9215017064846417,
796
- "grad_norm": 0.29517892003059387,
797
- "learning_rate": 4.03573450969906e-07,
798
- "loss": 0.0902,
799
- "step": 5400
800
- },
801
- {
802
- "epoch": 0.9300341296928327,
803
- "grad_norm": 0.21196790039539337,
804
- "learning_rate": 3.211344430933516e-07,
805
- "loss": 0.0899,
806
- "step": 5450
807
- },
808
- {
809
- "epoch": 0.9385665529010239,
810
- "grad_norm": 0.21812868118286133,
811
- "learning_rate": 2.479960319025129e-07,
812
- "loss": 0.0903,
813
- "step": 5500
814
- },
815
- {
816
- "epoch": 0.947098976109215,
817
- "grad_norm": 0.21973644196987152,
818
- "learning_rate": 1.8421407078221404e-07,
819
- "loss": 0.0906,
820
- "step": 5550
821
- },
822
- {
823
- "epoch": 0.9556313993174061,
824
- "grad_norm": 0.24090933799743652,
825
- "learning_rate": 1.2983726790592592e-07,
826
- "loss": 0.0896,
827
- "step": 5600
828
- },
829
- {
830
- "epoch": 0.9641638225255973,
831
- "grad_norm": 0.21915237605571747,
832
- "learning_rate": 8.490714903894025e-08,
833
- "loss": 0.091,
834
- "step": 5650
835
- },
836
- {
837
- "epoch": 0.9726962457337884,
838
- "grad_norm": 0.1731707900762558,
839
- "learning_rate": 4.94580258265126e-08,
840
- "loss": 0.0893,
841
- "step": 5700
842
- },
843
- {
844
- "epoch": 0.9812286689419796,
845
- "grad_norm": 0.2428259402513504,
846
- "learning_rate": 2.3516969591198813e-08,
847
- "loss": 0.0926,
848
- "step": 5750
849
- },
850
- {
851
- "epoch": 0.9897610921501706,
852
- "grad_norm": 0.23712676763534546,
853
- "learning_rate": 7.103790659380993e-09,
854
- "loss": 0.0948,
855
- "step": 5800
856
- },
857
- {
858
- "epoch": 0.9982935153583617,
859
- "grad_norm": 0.24450626969337463,
860
- "learning_rate": 2.3102323277596205e-10,
861
- "loss": 0.0888,
862
- "step": 5850
863
- },
864
- {
865
- "epoch": 1.0,
866
- "eval_entity_f1": 0.744815399729913,
867
- "eval_entity_precision": 0.6683184682739546,
868
- "eval_entity_recall": 0.8811592369510689,
869
- "eval_loss": 0.08635299652814865,
870
- "eval_runtime": 790.7469,
871
- "eval_samples_per_second": 1517.552,
872
- "eval_steps_per_second": 23.712,
873
- "step": 5860
874
- }
875
- ],
876
- "logging_steps": 50,
877
- "max_steps": 5860,
878
- "num_input_tokens_seen": 0,
879
- "num_train_epochs": 1,
880
- "save_steps": 1465,
881
- "stateful_callbacks": {
882
- "TrainerControl": {
883
- "args": {
884
- "should_epoch_stop": false,
885
- "should_evaluate": false,
886
- "should_log": false,
887
- "should_save": true,
888
- "should_training_stop": true
889
- },
890
- "attributes": {}
891
- }
892
- },
893
- "total_flos": 7.412859979369021e+17,
894
- "train_batch_size": 256,
895
- "trial_name": null,
896
- "trial_params": null
897
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
last-checkpoint/training_args.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5cd0cadb15b38d5d62eb7ba1c8d8cf6d8ff7a7453651ed77c16de5239c1b3221
3
- size 5905
 
 
 
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c1afa714bdd56bfbbb1efbf628f4c15f0b6ae266654356a88e0048e0cc7982eb
3
  size 735396724
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ade3a987bda6b550c43dc13485050e50f44bad5a2c59f156a6c24f534e2b131d
3
  size 735396724
spm.model DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c679fbf93643d19aab7ee10c0b99e460bdbc02fedf34b92b05af343b4af586fd
3
- size 2464616
 
 
 
 
train.ipynb ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "b00e4cd9",
7
+ "metadata": {
8
+ "scrolled": true
9
+ },
10
+ "outputs": [],
11
+ "source": [
12
+ "!hf download SaladTechnologies/fiction-ner-750m --quiet --repo-type=dataset --local-dir .\n",
13
+ "!unzip -q data.zip"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "id": "b1be4895",
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "import string\n",
24
+ "import random\n",
25
+ "\n",
26
+ "def get_random_string(length=8):\n",
27
+ " \"\"\"Generate a random string of fixed length.\"\"\"\n",
28
+ " letters = string.ascii_letters\n",
29
+ " return ''.join(random.choice(letters) for i in range(length))\n",
30
+ "\n",
31
+ "run_name = f\"ner-{get_random_string(8)}\""
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "id": "f21e8995",
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "from accelerate import notebook_launcher\n",
42
+ "import os\n",
43
+ "\n",
44
+ "\n",
45
+ "cuda_visible_devices = os.getenv(\"CUDA_VISIBLE_DEVICES\", \"0\")\n",
46
+ "num_devices = len(cuda_visible_devices.split(\",\"))\n",
47
+ "\n",
48
+ "\n",
49
+ "def train_fn():\n",
50
+ " global num_processes\n",
51
+ " from datasets import Dataset, concatenate_datasets\n",
52
+ " import pandas as pd\n",
53
+ " from pathlib import Path\n",
54
+ " import random\n",
55
+ " from transformers import AutoTokenizer\n",
56
+ " import torch\n",
57
+ " import numpy as np\n",
58
+ " from transformers import AutoModelForTokenClassification\n",
59
+ " from transformers.data.data_collator import DataCollatorForTokenClassification\n",
60
+ " from transformers.training_args import TrainingArguments\n",
61
+ " from transformers.trainer import Trainer\n",
62
+ " from transformers.trainer_callback import TrainerCallback\n",
63
+ " import numpy as np\n",
64
+ " from sklearn.metrics import precision_recall_fscore_support\n",
65
+ " import os\n",
66
+ " import wandb\n",
67
+ "\n",
68
+ " num_epochs = int(os.getenv(\"NUM_EPOCHS\", 3))\n",
69
+ " output_dir = os.getenv(\"OUTPUT_DIR\", \"./model\")\n",
70
+ " seed = int(os.getenv(\"RANDOM_SEED\", 42))\n",
71
+ " model_id = os.getenv(\"MODEL_ID\")\n",
72
+ " hub_token = os.getenv(\"HF_TOKEN\")\n",
73
+ " save_steps = float(os.getenv(\"SAVE_STEPS\", 100))\n",
74
+ " if save_steps.is_integer():\n",
75
+ " save_steps = int(save_steps)\n",
76
+ " train_size = float(os.getenv(\"TRAIN_SIZE\", 4_000_000))\n",
77
+ " test_size = float(os.getenv(\"TEST_SIZE\", 400_000))\n",
78
+ " if train_size.is_integer():\n",
79
+ " train_size = int(train_size)\n",
80
+ " if test_size.is_integer():\n",
81
+ " test_size = int(test_size)\n",
82
+ " hidden_dropout_prob = float(os.getenv(\"HIDDEN_DROPOUT_PROB\", 0.14))\n",
83
+ " attention_probs_dropout_prob = float(os.getenv(\"ATTENTION_PROBS_DROPOUT_PROB\", 0.14))\n",
84
+ " frequency_exponent = float(os.getenv(\"FREQUENCY_EXPONENT\", 0.35))\n",
85
+ " gamma = float(os.getenv(\"GAMMA\", 2.1))\n",
86
+ " learning_rate = float(os.getenv(\"LEARNING_RATE\", 2.5e-5))\n",
87
+ " lr_scheduler_type = os.getenv(\"LR_SCHEDULER_TYPE\", \"cosine\")\n",
88
+ " weight_decay = float(os.getenv(\"WEIGHT_DECAY\", 0.007))\n",
89
+ " warmup_ratio = float(os.getenv(\"WARMUP_RATIO\", 0.03))\n",
90
+ " per_device_train_batch_size = int(os.getenv(\"PER_DEVICE_TRAIN_BATCH_SIZE\", 256))\n",
91
+ " max_saved_checkpoints = int(os.getenv(\"MAX_SAVED_CHECKPOINTS\", 8))\n",
92
+ " patience = max_saved_checkpoints - 1\n",
93
+ " n_eval_samples = int(os.getenv(\"N_EVAL_SAMPLES\", 5)) # Number of samples to show\n",
94
+ " log_predictions_to_wandb = os.getenv(\"LOG_PREDICTIONS_TO_WANDB\", \"true\").lower() == \"true\"\n",
95
+ " log_predictions_to_console = os.getenv(\"LOG_PREDICTIONS_TO_CONSOLE\", \"false\").lower() == \"true\"\n",
96
+ "\n",
97
+ " num_processes = torch.cuda.device_count()\n",
98
+ " \n",
99
+ " tokenizer = AutoTokenizer.from_pretrained(\"microsoft/deberta-v3-base\")\n",
100
+ " \n",
101
+ " data_dir = Path(\"data\")\n",
102
+ " output = Path(output_dir)\n",
103
+ " random.seed(seed)\n",
104
+ " torch.manual_seed(seed)\n",
105
+ " np.random.seed(seed)\n",
106
+ "\n",
107
+ " \n",
108
+ " label_list = [\n",
109
+ " \"O\",\n",
110
+ " \"B-CHA\",\n",
111
+ " \"I-CHA\",\n",
112
+ " \"B-LOC\",\n",
113
+ " \"I-LOC\",\n",
114
+ " \"B-FAC\",\n",
115
+ " \"I-FAC\",\n",
116
+ " \"B-OBJ\",\n",
117
+ " \"I-OBJ\",\n",
118
+ " \"B-EVT\",\n",
119
+ " \"I-EVT\",\n",
120
+ " \"B-ORG\",\n",
121
+ " \"I-ORG\",\n",
122
+ " \"B-MISC\",\n",
123
+ " \"I-MISC\"\n",
124
+ " ]\n",
125
+ " label_to_id = {label: i for i, label in enumerate(label_list)}\n",
126
+ " id_to_label = {i: label for i, label in enumerate(label_list)}\n",
127
+ "\n",
128
+ " datasets = []\n",
129
+ " for parquet_file in sorted(data_dir.glob(\"*.parquet\")):\n",
130
+ " ds = Dataset.from_parquet(str(parquet_file))\n",
131
+ " datasets.append(ds)\n",
132
+ "\n",
133
+ " full_ds = concatenate_datasets(datasets)\n",
134
+ " splits = full_ds.train_test_split(train_size=train_size, test_size=test_size, seed=seed)\n",
135
+ "\n",
136
+ " train_ds = splits['train']\n",
137
+ " eval_ds = splits['test']\n",
138
+ "\n",
139
+ " stats_file = \"label_counts.csv\"\n",
140
+ " stats_df = pd.read_csv(stats_file)\n",
141
+ " stats_df.head()\n",
142
+ "\n",
143
+ " total_count = stats_df[\"total\"].sum()\n",
144
+ " label_frequencies = {\n",
145
+ " label: stats_df[label].sum() / total_count for label in label_list\n",
146
+ " }\n",
147
+ " \n",
148
+ " label_weights = {}\n",
149
+ " for label, freq in label_frequencies.items():\n",
150
+ " label_weights[label] = 1.0 / freq ** frequency_exponent\n",
151
+ "\n",
152
+ " weight_tensor = torch.tensor([label_weights[label] for label in label_list], dtype=torch.float32)\n",
153
+ "\n",
154
+ " model = AutoModelForTokenClassification.from_pretrained(\n",
155
+ " \"microsoft/deberta-v3-base\",\n",
156
+ " num_labels=len(label_list),\n",
157
+ " id2label=id_to_label,\n",
158
+ " label2id=label_to_id,\n",
159
+ " ignore_mismatched_sizes=True,\n",
160
+ " hidden_dropout_prob=hidden_dropout_prob,\n",
161
+ " attention_probs_dropout_prob=attention_probs_dropout_prob\n",
162
+ " )\n",
163
+ " \n",
164
+ " data_collator = DataCollatorForTokenClassification(\n",
165
+ " tokenizer=tokenizer,\n",
166
+ " padding=True\n",
167
+ " )\n",
168
+ "\n",
169
+ "\n",
170
+ " def create_compute_metrics_fn(eval_dataset):\n",
171
+ " \"\"\"\n",
172
+ " Factory function that creates a compute_metrics function with access to eval_dataset.\n",
173
+ " \"\"\"\n",
174
+ " def compute_metrics(eval_pred):\n",
175
+ " predictions, labels = eval_pred\n",
176
+ " predictions_raw = predictions # Keep raw predictions for logging\n",
177
+ " predictions = np.argmax(predictions, axis=2)\n",
178
+ " \n",
179
+ " # Remove ignored indices\n",
180
+ " true_predictions = [\n",
181
+ " [id_to_label[p] for (p, l) in zip(pred, label) if l != -100]\n",
182
+ " for pred, label in zip(predictions, labels)\n",
183
+ " ]\n",
184
+ " true_labels = [\n",
185
+ " [id_to_label[l] for (p, l) in zip(pred, label) if l != -100]\n",
186
+ " for pred, label in zip(predictions, labels)\n",
187
+ " ]\n",
188
+ " \n",
189
+ " # Flatten\n",
190
+ " all_predictions = [item for sublist in true_predictions for item in sublist]\n",
191
+ " all_labels = [item for sublist in true_labels for item in sublist]\n",
192
+ " \n",
193
+ " # Calculate metrics excluding 'O' class\n",
194
+ " entity_labels = [l for l in label_list if l != 'O']\n",
195
+ " \n",
196
+ " precision, recall, f1, support = precision_recall_fscore_support(\n",
197
+ " all_labels,\n",
198
+ " all_predictions,\n",
199
+ " labels=entity_labels,\n",
200
+ " average='weighted',\n",
201
+ " zero_division=0\n",
202
+ " )\n",
203
+ "\n",
204
+ " return {\n",
205
+ " 'entity_precision': precision,\n",
206
+ " 'entity_recall': recall,\n",
207
+ " 'entity_f1': f1,\n",
208
+ " }\n",
209
+ " \n",
210
+ " return compute_metrics\n",
211
+ "\n",
212
+ " # Create the compute_metrics function with access to eval_ds\n",
213
+ " compute_metrics = create_compute_metrics_fn(eval_ds)\n",
214
+ "\n",
215
+ " class FocalLoss(torch.nn.Module):\n",
216
+ " def __init__(self, alpha=None, gamma=2.0, reduction='mean', ignore_index=-100):\n",
217
+ " \"\"\"\n",
218
+ " alpha: class weights tensor\n",
219
+ " gamma: focusing parameter (higher = more focus on hard examples)\n",
220
+ " ignore_index: label to ignore (for padding tokens)\n",
221
+ " \"\"\"\n",
222
+ " super().__init__()\n",
223
+ " self.alpha = alpha\n",
224
+ " self.gamma = gamma\n",
225
+ " self.reduction = reduction\n",
226
+ " self.ignore_index = ignore_index\n",
227
+ " \n",
228
+ " def forward(self, logits, labels):\n",
229
+ " # logits shape: (batch_size, seq_len, num_classes)\n",
230
+ " # labels shape: (batch_size, seq_len)\n",
231
+ " \n",
232
+ " # Reshape for loss calculation\n",
233
+ " logits_flat = logits.view(-1, logits.size(-1)) # (batch*seq_len, num_classes)\n",
234
+ " labels_flat = labels.view(-1) # (batch*seq_len)\n",
235
+ " \n",
236
+ " # Calculate cross entropy (without reduction)\n",
237
+ " ce_loss = torch.nn.functional.cross_entropy(\n",
238
+ " logits_flat, \n",
239
+ " labels_flat, \n",
240
+ " reduction='none',\n",
241
+ " ignore_index=self.ignore_index\n",
242
+ " )\n",
243
+ " \n",
244
+ " # Get the probabilities for the correct class\n",
245
+ " p = torch.exp(-ce_loss)\n",
246
+ " \n",
247
+ " # Calculate focal term: (1 - p)^gamma\n",
248
+ " focal_term = (1 - p) ** self.gamma\n",
249
+ " \n",
250
+ " # Apply focal term to loss\n",
251
+ " focal_loss = focal_term * ce_loss\n",
252
+ " \n",
253
+ " # Apply class weights if provided\n",
254
+ " if self.alpha is not None:\n",
255
+ " # Create a mask for valid (non-ignored) tokens\n",
256
+ " valid_mask = labels_flat != self.ignore_index\n",
257
+ " \n",
258
+ " # Gather the weights for each sample's true class\n",
259
+ " # Only for valid labels to avoid index errors\n",
260
+ " valid_labels = labels_flat.clone()\n",
261
+ " valid_labels[~valid_mask] = 0 # Set ignored labels to 0 to avoid index errors\n",
262
+ " \n",
263
+ " alpha_t = self.alpha.gather(0, valid_labels)\n",
264
+ " # Apply mask to weights\n",
265
+ " alpha_t = alpha_t * valid_mask.float()\n",
266
+ " \n",
267
+ " focal_loss = alpha_t * focal_loss\n",
268
+ " \n",
269
+ " # Apply reduction\n",
270
+ " if self.reduction == 'mean':\n",
271
+ " # Only average over non-ignored tokens\n",
272
+ " valid_tokens = (labels_flat != self.ignore_index).sum()\n",
273
+ " return focal_loss.sum() / valid_tokens.clamp(min=1)\n",
274
+ " elif self.reduction == 'sum':\n",
275
+ " return focal_loss.sum()\n",
276
+ " else:\n",
277
+ " return focal_loss\n",
278
+ " \n",
279
+ " class FocalLossTrainer(Trainer):\n",
280
+ " def __init__(self, *args, class_weights=None, gamma=2.0, **kwargs):\n",
281
+ " super().__init__(*args, **kwargs)\n",
282
+ " self.class_weights = class_weights\n",
283
+ " self.gamma = gamma\n",
284
+ " \n",
285
+ " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
286
+ " \"\"\"\n",
287
+ " Override compute_loss to use focal loss.\n",
288
+ " num_items_in_batch parameter added for compatibility with newer transformers versions.\n",
289
+ " \"\"\"\n",
290
+ " labels = inputs.get(\"labels\")\n",
291
+ " outputs = model(**inputs)\n",
292
+ " logits = outputs.get(\"logits\")\n",
293
+ " \n",
294
+ " # Move weights to the same device as logits\n",
295
+ " if self.class_weights is not None:\n",
296
+ " weights = self.class_weights.to(logits.device)\n",
297
+ " else:\n",
298
+ " weights = None\n",
299
+ " \n",
300
+ " # Initialize focal loss\n",
301
+ " loss_fct = FocalLoss(\n",
302
+ " alpha=weights,\n",
303
+ " gamma=self.gamma,\n",
304
+ " ignore_index=-100\n",
305
+ " )\n",
306
+ " \n",
307
+ " # Calculate loss\n",
308
+ " loss = loss_fct(logits, labels)\n",
309
+ " \n",
310
+ " return (loss, outputs) if return_outputs else loss\n",
311
+ "\n",
312
+ " \n",
313
+ "\n",
314
+ " training_args = TrainingArguments(\n",
315
+ " output_dir=str(output),\n",
316
+ " learning_rate=learning_rate,\n",
317
+ " lr_scheduler_type=lr_scheduler_type,\n",
318
+ " per_device_train_batch_size=per_device_train_batch_size,\n",
319
+ " weight_decay=weight_decay,\n",
320
+ " warmup_ratio=warmup_ratio,\n",
321
+ " gradient_accumulation_steps=1,\n",
322
+ " logging_steps=50,\n",
323
+ " num_train_epochs=num_epochs,\n",
324
+ " save_strategy=\"steps\",\n",
325
+ " save_steps=save_steps,\n",
326
+ " save_total_limit=3,\n",
327
+ " eval_strategy=\"steps\",\n",
328
+ " eval_steps=save_steps,\n",
329
+ " load_best_model_at_end=True,\n",
330
+ " metric_for_best_model=\"eval_entity_f1\",\n",
331
+ " greater_is_better=True,\n",
332
+ " bf16=True,\n",
333
+ " tf32=True,\n",
334
+ " report_to='wandb',\n",
335
+ " run_name=run_name,\n",
336
+ " push_to_hub=True,\n",
337
+ " hub_strategy=\"checkpoint\",\n",
338
+ " hub_token=hub_token,\n",
339
+ " dataloader_persistent_workers=True,\n",
340
+ " dataloader_num_workers=2,\n",
341
+ " dataloader_pin_memory=True,\n",
342
+ " ddp_find_unused_parameters=False,\n",
343
+ " gradient_checkpointing=False,\n",
344
+ " hub_model_id=model_id,\n",
345
+ " hub_private_repo=True\n",
346
+ " )\n",
347
+ "\n",
348
+ " class CustomEarlyStoppingCallback(TrainerCallback):\n",
349
+ " def __init__(self, patience=2, threshold=0.001):\n",
350
+ " self.patience = patience\n",
351
+ " self.threshold = threshold\n",
352
+ " self.best_metric = None\n",
353
+ " self.wait = 0\n",
354
+ " \n",
355
+ " def on_evaluate(self, args, state, control, metrics=None, **kwargs):\n",
356
+ " if metrics is None or \"eval_entity_f1\" not in metrics:\n",
357
+ " return control\n",
358
+ " metric_value = metrics.get(\"eval_entity_f1\")\n",
359
+ " \n",
360
+ " if self.best_metric is None:\n",
361
+ " self.best_metric = metric_value\n",
362
+ " elif metric_value > self.best_metric + self.threshold:\n",
363
+ " self.best_metric = metric_value\n",
364
+ " self.wait = 0\n",
365
+ " else:\n",
366
+ " self.wait += 1\n",
367
+ " if self.wait >= self.patience:\n",
368
+ " control.should_training_stop = True\n",
369
+ " print(f\"Early stopping triggered. Best F1: {self.best_metric:.4f}\")\n",
370
+ " \n",
371
+ " return control\n",
372
+ " \n",
373
+ "\n",
374
+ " trainer = FocalLossTrainer(\n",
375
+ " model=model,\n",
376
+ " args=training_args,\n",
377
+ " train_dataset=train_ds,\n",
378
+ " eval_dataset=eval_ds,\n",
379
+ " processing_class=tokenizer,\n",
380
+ " data_collator=data_collator,\n",
381
+ " compute_metrics=compute_metrics,\n",
382
+ " class_weights=weight_tensor,\n",
383
+ " gamma=gamma,\n",
384
+ " callbacks=[CustomEarlyStoppingCallback(patience=patience, threshold=0.0001)]\n",
385
+ " )\n",
386
+ " \n",
387
+ " if wandb.run is not None:\n",
388
+ " # Add custom config values\n",
389
+ " wandb.config.update({\n",
390
+ " # Data configuration\n",
391
+ " \"train_samples\": len(train_ds),\n",
392
+ " \"eval_samples\": len(eval_ds),\n",
393
+ " \"train_size_requested\": train_size,\n",
394
+ " \"test_size_requested\": test_size,\n",
395
+ " \"actual_train_size\": len(train_ds),\n",
396
+ " \"actual_eval_size\": len(eval_ds),\n",
397
+ "\n",
398
+ " # Model architecture details\n",
399
+ " \"model_architecture\": \"deberta-v3-base\",\n",
400
+ " \"num_labels\": len(label_list),\n",
401
+ " \"label_list\": label_list,\n",
402
+ "\n",
403
+ " # Loss function configuration\n",
404
+ " \"loss_function\": \"focal_loss\",\n",
405
+ " \"focal_gamma\": gamma,\n",
406
+ " \"focal_alpha\": \"weighted\",\n",
407
+ " \"frequency_exponent\": frequency_exponent,\n",
408
+ "\n",
409
+ " # Dropout configuration\n",
410
+ " \"hidden_dropout_prob\": hidden_dropout_prob,\n",
411
+ " \"attention_probs_dropout_prob\": attention_probs_dropout_prob,\n",
412
+ "\n",
413
+ " # Training configuration not in TrainingArguments\n",
414
+ " \"max_saved_checkpoints\": max_saved_checkpoints,\n",
415
+ " \"early_stopping_patience\": patience,\n",
416
+ " \"early_stopping_threshold\": 0.001,\n",
417
+ "\n",
418
+ " # Environment info\n",
419
+ " \"cuda_devices\": cuda_visible_devices,\n",
420
+ " \"num_gpus\": num_devices,\n",
421
+ "\n",
422
+ " # Data processing\n",
423
+ " \"tokenizer\": \"microsoft/deberta-v3-base\"\n",
424
+ "\n",
425
+ " # Experiment metadata\n",
426
+ " \"experiment_type\": \"ner_fiction\",\n",
427
+ " \"data_source\": \"gutenberg_ao3_mixed\",\n",
428
+ " \"random_seed\": seed,\n",
429
+ "\n",
430
+ " # Logging configuration\n",
431
+ " \"n_eval_samples\": n_eval_samples,\n",
432
+ " \"log_predictions_to_wandb\": log_predictions_to_wandb,\n",
433
+ " })\n",
434
+ "\n",
435
+ " has_checkpoints = bool([f for f in os.scandir(output_dir) if f.is_dir() and \"checkpoint\" in f.name])\n",
436
+ " if has_checkpoints:\n",
437
+ " trainer.train(resume_from_checkpoint=True)\n",
438
+ " else:\n",
439
+ " trainer.train()\n",
440
+ "\n",
441
+ "notebook_launcher(train_fn, num_processes=num_devices)"
442
+ ]
443
+ }
444
+ ],
445
+ "metadata": {
446
+ "kernelspec": {
447
+ "display_name": "Python 3 (ipykernel)",
448
+ "language": "python",
449
+ "name": "python3"
450
+ },
451
+ "language_info": {
452
+ "codemirror_mode": {
453
+ "name": "ipython",
454
+ "version": 3
455
+ },
456
+ "file_extension": ".py",
457
+ "mimetype": "text/x-python",
458
+ "name": "python",
459
+ "nbconvert_exporter": "python",
460
+ "pygments_lexer": "ipython3",
461
+ "version": "3.12.3"
462
+ }
463
+ },
464
+ "nbformat": 4,
465
+ "nbformat_minor": 5
466
+ }
training_args.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5cd0cadb15b38d5d62eb7ba1c8d8cf6d8ff7a7453651ed77c16de5239c1b3221
3
- size 5905