thomashk2001 commited on
Commit
f73b95d
·
verified ·
1 Parent(s): 9569a13

Fine-tuned ViT on Tom & Jerry dataset

Browse files
README.md ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ license: apache-2.0
4
+ base_model: google/vit-base-patch16-224-in21k
5
+ tags:
6
+ - generated_from_trainer
7
+ metrics:
8
+ - accuracy
9
+ - precision
10
+ - recall
11
+ - f1
12
+ model-index:
13
+ - name: tom_and_jerry_vit_model
14
+ results: []
15
+ ---
16
+
17
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
18
+ should probably proofread and complete it, then remove this comment. -->
19
+
20
+ # tom_and_jerry_vit_model
21
+
22
 + This model is a fine-tuned version of [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k) on the Tom & Jerry image classification dataset.
23
+ It achieves the following results on the evaluation set:
24
+ - Loss: 0.1530
25
+ - Accuracy: 0.9562
26
+ - Precision: 0.9526
27
+ - Recall: 0.9587
28
+ - F1: 0.9553
29
+
30
+ ## Model description
31
+
32
 + A ViT (Vision Transformer) image classifier fine-tuned to label frames from the Tom & Jerry cartoon into four classes: `Jerry`, `None`, `Tom`, and `Tom_and_Jerry`.
33
+
34
+ ## Intended uses & limitations
35
+
36
+ More information needed
37
+
38
+ ## Training and evaluation data
39
+
40
+ More information needed
41
+
42
+ ## Training procedure
43
+
44
+ ### Training hyperparameters
45
+
46
+ The following hyperparameters were used during training:
47
+ - learning_rate: 0.0002
48
+ - train_batch_size: 64
49
+ - eval_batch_size: 64
50
+ - seed: 42
51
 + - optimizer: AdamW (torch fused) with betas=(0.9, 0.999) and epsilon=1e-08; no additional optimizer arguments
52
+ - lr_scheduler_type: linear
53
+ - num_epochs: 5
54
+
55
+ ### Training results
56
+
57
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy | Precision | Recall | F1 |
58
+ |:-------------:|:------:|:----:|:---------------:|:--------:|:---------:|:------:|:------:|
59
+ | 0.8223 | 0.4167 | 25 | 0.4506 | 0.8893 | 0.8939 | 0.8653 | 0.8742 |
60
+ | 0.2676 | 0.8333 | 50 | 0.2195 | 0.9392 | 0.9343 | 0.9376 | 0.9356 |
61
+ | 0.1896 | 1.25 | 75 | 0.1816 | 0.9526 | 0.9490 | 0.9504 | 0.9493 |
62
+ | 0.1085 | 1.6667 | 100 | 0.1940 | 0.9380 | 0.9316 | 0.9381 | 0.9344 |
63
+ | 0.1618 | 2.0833 | 125 | 0.1806 | 0.9477 | 0.9390 | 0.9493 | 0.9434 |
64
+ | 0.0784 | 2.5 | 150 | 0.1582 | 0.9574 | 0.9524 | 0.9570 | 0.9546 |
65
+ | 0.071 | 2.9167 | 175 | 0.1803 | 0.9416 | 0.9364 | 0.9413 | 0.9386 |
66
+ | 0.0533 | 3.3333 | 200 | 0.1539 | 0.9611 | 0.9623 | 0.9600 | 0.9605 |
67
+ | 0.0383 | 3.75 | 225 | 0.1446 | 0.9647 | 0.9654 | 0.9642 | 0.9646 |
68
+ | 0.0264 | 4.1667 | 250 | 0.1619 | 0.9513 | 0.9447 | 0.9546 | 0.9488 |
69
+ | 0.0227 | 4.5833 | 275 | 0.1524 | 0.9550 | 0.9498 | 0.9579 | 0.9531 |
70
+ | 0.0343 | 5.0 | 300 | 0.1530 | 0.9562 | 0.9526 | 0.9587 | 0.9553 |
71
+
72
+
73
+ ### Framework versions
74
+
75
+ - Transformers 4.55.2
76
+ - Pytorch 2.8.0+cu129
77
+ - Datasets 4.0.0
78
+ - Tokenizers 0.21.4
all_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 5.0,
3
+ "total_flos": 1.48554806942933e+18,
4
+ "train_loss": 0.17210483262936274,
5
+ "train_runtime": 149.2973,
6
+ "train_samples_per_second": 128.402,
7
+ "train_steps_per_second": 2.009
8
+ }
config.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ViTForImageClassification"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "encoder_stride": 16,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.0,
9
+ "hidden_size": 768,
10
+ "id2label": {
11
+ "0": "Jerry",
12
+ "1": "None",
13
+ "2": "Tom",
14
+ "3": "Tom_and_Jerry"
15
+ },
16
+ "image_size": 224,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 3072,
19
+ "label2id": {
20
+ "Jerry": "0",
21
+ "None": "1",
22
+ "Tom": "2",
23
+ "Tom_and_Jerry": "3"
24
+ },
25
+ "layer_norm_eps": 1e-12,
26
+ "model_type": "vit",
27
+ "num_attention_heads": 12,
28
+ "num_channels": 3,
29
+ "num_hidden_layers": 12,
30
+ "patch_size": 16,
31
+ "pooler_act": "tanh",
32
+ "pooler_output_size": 768,
33
+ "problem_type": "single_label_classification",
34
+ "qkv_bias": true,
35
+ "torch_dtype": "float32",
36
+ "transformers_version": "4.55.2"
37
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9861c7090b1151ebb471f9af44c36f7bc41515e3f27877c2865f8c4358cf9c63
3
+ size 343230128
train_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "epoch": 5.0,
3
+ "total_flos": 1.48554806942933e+18,
4
+ "train_loss": 0.17210483262936274,
5
+ "train_runtime": 149.2973,
6
+ "train_samples_per_second": 128.402,
7
+ "train_steps_per_second": 2.009
8
+ }
trainer_state.json ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "best_global_step": 225,
3
+ "best_metric": 0.9645740739756267,
4
+ "best_model_checkpoint": "./vit_tom_jerry_mdl/checkpoint-200",
5
+ "epoch": 5.0,
6
+ "eval_steps": 25,
7
+ "global_step": 300,
8
+ "is_hyper_param_search": false,
9
+ "is_local_process_zero": true,
10
+ "is_world_process_zero": true,
11
+ "log_history": [
12
+ {
13
+ "epoch": 0.16666666666666666,
14
+ "grad_norm": 1.179008960723877,
15
+ "learning_rate": 0.000194,
16
+ "loss": 1.2461,
17
+ "step": 10
18
+ },
19
+ {
20
+ "epoch": 0.3333333333333333,
21
+ "grad_norm": 1.6004935503005981,
22
+ "learning_rate": 0.00018733333333333335,
23
+ "loss": 0.8223,
24
+ "step": 20
25
+ },
26
+ {
27
+ "epoch": 0.4166666666666667,
28
+ "eval_accuracy": 0.889294403892944,
29
+ "eval_f1": 0.8742397689620669,
30
+ "eval_loss": 0.450603187084198,
31
+ "eval_precision": 0.8939126626626628,
32
+ "eval_recall": 0.8652641370683252,
33
+ "eval_runtime": 3.461,
34
+ "eval_samples_per_second": 237.503,
35
+ "eval_steps_per_second": 3.756,
36
+ "step": 25
37
+ },
38
+ {
39
+ "epoch": 0.5,
40
+ "grad_norm": 1.335172414779663,
41
+ "learning_rate": 0.00018066666666666668,
42
+ "loss": 0.4994,
43
+ "step": 30
44
+ },
45
+ {
46
+ "epoch": 0.6666666666666666,
47
+ "grad_norm": 1.348803162574768,
48
+ "learning_rate": 0.000174,
49
+ "loss": 0.3169,
50
+ "step": 40
51
+ },
52
+ {
53
+ "epoch": 0.8333333333333334,
54
+ "grad_norm": 1.2741419076919556,
55
+ "learning_rate": 0.00016733333333333335,
56
+ "loss": 0.2676,
57
+ "step": 50
58
+ },
59
+ {
60
+ "epoch": 0.8333333333333334,
61
+ "eval_accuracy": 0.9391727493917275,
62
+ "eval_f1": 0.9356370716759485,
63
+ "eval_loss": 0.21945379674434662,
64
+ "eval_precision": 0.9342880020745046,
65
+ "eval_recall": 0.9375607233423536,
66
+ "eval_runtime": 3.409,
67
+ "eval_samples_per_second": 241.126,
68
+ "eval_steps_per_second": 3.813,
69
+ "step": 50
70
+ },
71
+ {
72
+ "epoch": 1.0,
73
+ "grad_norm": 1.302258849143982,
74
+ "learning_rate": 0.00016066666666666668,
75
+ "loss": 0.2316,
76
+ "step": 60
77
+ },
78
+ {
79
+ "epoch": 1.1666666666666667,
80
+ "grad_norm": 1.284261703491211,
81
+ "learning_rate": 0.000154,
82
+ "loss": 0.1896,
83
+ "step": 70
84
+ },
85
+ {
86
+ "epoch": 1.25,
87
+ "eval_accuracy": 0.9525547445255474,
88
+ "eval_f1": 0.9493014762192153,
89
+ "eval_loss": 0.1816127747297287,
90
+ "eval_precision": 0.9489785971260434,
91
+ "eval_recall": 0.9503527850213396,
92
+ "eval_runtime": 3.4233,
93
+ "eval_samples_per_second": 240.119,
94
+ "eval_steps_per_second": 3.798,
95
+ "step": 75
96
+ },
97
+ {
98
+ "epoch": 1.3333333333333333,
99
+ "grad_norm": 1.1670647859573364,
100
+ "learning_rate": 0.00014733333333333335,
101
+ "loss": 0.1451,
102
+ "step": 80
103
+ },
104
+ {
105
+ "epoch": 1.5,
106
+ "grad_norm": 0.9231269359588623,
107
+ "learning_rate": 0.00014066666666666668,
108
+ "loss": 0.1246,
109
+ "step": 90
110
+ },
111
+ {
112
+ "epoch": 1.6666666666666665,
113
+ "grad_norm": 1.229040503501892,
114
+ "learning_rate": 0.000134,
115
+ "loss": 0.1085,
116
+ "step": 100
117
+ },
118
+ {
119
+ "epoch": 1.6666666666666665,
120
+ "eval_accuracy": 0.9379562043795621,
121
+ "eval_f1": 0.9344160814351121,
122
+ "eval_loss": 0.19404640793800354,
123
+ "eval_precision": 0.9315772279233091,
124
+ "eval_recall": 0.9380771013895641,
125
+ "eval_runtime": 3.4388,
126
+ "eval_samples_per_second": 239.035,
127
+ "eval_steps_per_second": 3.78,
128
+ "step": 100
129
+ },
130
+ {
131
+ "epoch": 1.8333333333333335,
132
+ "grad_norm": 1.5356346368789673,
133
+ "learning_rate": 0.00012733333333333336,
134
+ "loss": 0.1379,
135
+ "step": 110
136
+ },
137
+ {
138
+ "epoch": 2.0,
139
+ "grad_norm": 0.9422320127487183,
140
+ "learning_rate": 0.00012066666666666668,
141
+ "loss": 0.1618,
142
+ "step": 120
143
+ },
144
+ {
145
+ "epoch": 2.0833333333333335,
146
+ "eval_accuracy": 0.9476885644768857,
147
+ "eval_f1": 0.9433613445349764,
148
+ "eval_loss": 0.1806175410747528,
149
+ "eval_precision": 0.9389737449498523,
150
+ "eval_recall": 0.94927038162268,
151
+ "eval_runtime": 3.417,
152
+ "eval_samples_per_second": 240.564,
153
+ "eval_steps_per_second": 3.805,
154
+ "step": 125
155
+ },
156
+ {
157
+ "epoch": 2.1666666666666665,
158
+ "grad_norm": 0.40647202730178833,
159
+ "learning_rate": 0.00011399999999999999,
160
+ "loss": 0.0831,
161
+ "step": 130
162
+ },
163
+ {
164
+ "epoch": 2.3333333333333335,
165
+ "grad_norm": 0.5886309146881104,
166
+ "learning_rate": 0.00010733333333333333,
167
+ "loss": 0.0784,
168
+ "step": 140
169
+ },
170
+ {
171
+ "epoch": 2.5,
172
+ "grad_norm": 0.3857206404209137,
173
+ "learning_rate": 0.00010066666666666667,
174
+ "loss": 0.0784,
175
+ "step": 150
176
+ },
177
+ {
178
+ "epoch": 2.5,
179
+ "eval_accuracy": 0.9574209245742092,
180
+ "eval_f1": 0.954610584452178,
181
+ "eval_loss": 0.1581779420375824,
182
+ "eval_precision": 0.9523814232432715,
183
+ "eval_recall": 0.957049310025567,
184
+ "eval_runtime": 3.4015,
185
+ "eval_samples_per_second": 241.656,
186
+ "eval_steps_per_second": 3.822,
187
+ "step": 150
188
+ },
189
+ {
190
+ "epoch": 2.6666666666666665,
191
+ "grad_norm": 0.7096158266067505,
192
+ "learning_rate": 9.4e-05,
193
+ "loss": 0.0726,
194
+ "step": 160
195
+ },
196
+ {
197
+ "epoch": 2.8333333333333335,
198
+ "grad_norm": 2.3420352935791016,
199
+ "learning_rate": 8.733333333333333e-05,
200
+ "loss": 0.071,
201
+ "step": 170
202
+ },
203
+ {
204
+ "epoch": 2.9166666666666665,
205
+ "eval_accuracy": 0.9416058394160584,
206
+ "eval_f1": 0.9385804706284082,
207
+ "eval_loss": 0.18030095100402832,
208
+ "eval_precision": 0.9364433325083628,
209
+ "eval_recall": 0.9413294616863843,
210
+ "eval_runtime": 3.4018,
211
+ "eval_samples_per_second": 241.637,
212
+ "eval_steps_per_second": 3.822,
213
+ "step": 175
214
+ },
215
+ {
216
+ "epoch": 3.0,
217
+ "grad_norm": 2.018068552017212,
218
+ "learning_rate": 8.066666666666667e-05,
219
+ "loss": 0.0739,
220
+ "step": 180
221
+ },
222
+ {
223
+ "epoch": 3.1666666666666665,
224
+ "grad_norm": 0.5890735387802124,
225
+ "learning_rate": 7.4e-05,
226
+ "loss": 0.0586,
227
+ "step": 190
228
+ },
229
+ {
230
+ "epoch": 3.3333333333333335,
231
+ "grad_norm": 0.2291577160358429,
232
+ "learning_rate": 6.733333333333333e-05,
233
+ "loss": 0.0533,
234
+ "step": 200
235
+ },
236
+ {
237
+ "epoch": 3.3333333333333335,
238
+ "eval_accuracy": 0.9610705596107056,
239
+ "eval_f1": 0.9604590740790521,
240
+ "eval_loss": 0.15387743711471558,
241
+ "eval_precision": 0.9622977817571188,
242
+ "eval_recall": 0.9600399492969355,
243
+ "eval_runtime": 3.3928,
244
+ "eval_samples_per_second": 242.28,
245
+ "eval_steps_per_second": 3.832,
246
+ "step": 200
247
+ },
248
+ {
249
+ "epoch": 3.5,
250
+ "grad_norm": 0.08786796778440475,
251
+ "learning_rate": 6.066666666666667e-05,
252
+ "loss": 0.0463,
253
+ "step": 210
254
+ },
255
+ {
256
+ "epoch": 3.6666666666666665,
257
+ "grad_norm": 0.22060526907444,
258
+ "learning_rate": 5.4000000000000005e-05,
259
+ "loss": 0.0383,
260
+ "step": 220
261
+ },
262
+ {
263
+ "epoch": 3.75,
264
+ "eval_accuracy": 0.9647201946472019,
265
+ "eval_f1": 0.9645740739756267,
266
+ "eval_loss": 0.14463135600090027,
267
+ "eval_precision": 0.9653599801235093,
268
+ "eval_recall": 0.9641853092636594,
269
+ "eval_runtime": 3.3707,
270
+ "eval_samples_per_second": 243.868,
271
+ "eval_steps_per_second": 3.857,
272
+ "step": 225
273
+ },
274
+ {
275
+ "epoch": 3.8333333333333335,
276
+ "grad_norm": 0.07651939243078232,
277
+ "learning_rate": 4.7333333333333336e-05,
278
+ "loss": 0.0439,
279
+ "step": 230
280
+ },
281
+ {
282
+ "epoch": 4.0,
283
+ "grad_norm": 1.2403360605239868,
284
+ "learning_rate": 4.066666666666667e-05,
285
+ "loss": 0.0557,
286
+ "step": 240
287
+ },
288
+ {
289
+ "epoch": 4.166666666666667,
290
+ "grad_norm": 0.15227381885051727,
291
+ "learning_rate": 3.4000000000000007e-05,
292
+ "loss": 0.0264,
293
+ "step": 250
294
+ },
295
+ {
296
+ "epoch": 4.166666666666667,
297
+ "eval_accuracy": 0.9513381995133819,
298
+ "eval_f1": 0.9487974655404795,
299
+ "eval_loss": 0.16188818216323853,
300
+ "eval_precision": 0.9447150174035173,
301
+ "eval_recall": 0.9546121976142474,
302
+ "eval_runtime": 3.414,
303
+ "eval_samples_per_second": 240.77,
304
+ "eval_steps_per_second": 3.808,
305
+ "step": 250
306
+ },
307
+ {
308
+ "epoch": 4.333333333333333,
309
+ "grad_norm": 0.1498788446187973,
310
+ "learning_rate": 2.733333333333333e-05,
311
+ "loss": 0.0274,
312
+ "step": 260
313
+ },
314
+ {
315
+ "epoch": 4.5,
316
+ "grad_norm": 0.07677994668483734,
317
+ "learning_rate": 2.0666666666666666e-05,
318
+ "loss": 0.0227,
319
+ "step": 270
320
+ },
321
+ {
322
+ "epoch": 4.583333333333333,
323
+ "eval_accuracy": 0.9549878345498783,
324
+ "eval_f1": 0.9531194945497546,
325
+ "eval_loss": 0.15235914289951324,
326
+ "eval_precision": 0.9497723508669295,
327
+ "eval_recall": 0.9579100556580387,
328
+ "eval_runtime": 3.4275,
329
+ "eval_samples_per_second": 239.824,
330
+ "eval_steps_per_second": 3.793,
331
+ "step": 275
332
+ },
333
+ {
334
+ "epoch": 4.666666666666667,
335
+ "grad_norm": 0.1796388328075409,
336
+ "learning_rate": 1.4000000000000001e-05,
337
+ "loss": 0.023,
338
+ "step": 280
339
+ },
340
+ {
341
+ "epoch": 4.833333333333333,
342
+ "grad_norm": 0.5997887849807739,
343
+ "learning_rate": 7.333333333333334e-06,
344
+ "loss": 0.0244,
345
+ "step": 290
346
+ },
347
+ {
348
+ "epoch": 5.0,
349
+ "grad_norm": 0.06375733017921448,
350
+ "learning_rate": 6.666666666666667e-07,
351
+ "loss": 0.0343,
352
+ "step": 300
353
+ },
354
+ {
355
+ "epoch": 5.0,
356
+ "eval_accuracy": 0.9562043795620438,
357
+ "eval_f1": 0.9552724576163357,
358
+ "eval_loss": 0.1529521346092224,
359
+ "eval_precision": 0.9525861346195577,
360
+ "eval_recall": 0.9587138585290341,
361
+ "eval_runtime": 3.421,
362
+ "eval_samples_per_second": 240.277,
363
+ "eval_steps_per_second": 3.8,
364
+ "step": 300
365
+ },
366
+ {
367
+ "epoch": 5.0,
368
+ "step": 300,
369
+ "total_flos": 1.48554806942933e+18,
370
+ "train_loss": 0.17210483262936274,
371
+ "train_runtime": 149.2973,
372
+ "train_samples_per_second": 128.402,
373
+ "train_steps_per_second": 2.009
374
+ }
375
+ ],
376
+ "logging_steps": 10,
377
+ "max_steps": 300,
378
+ "num_input_tokens_seen": 0,
379
+ "num_train_epochs": 5,
380
+ "save_steps": 100,
381
+ "stateful_callbacks": {
382
+ "EarlyStoppingCallback": {
383
+ "args": {
384
+ "early_stopping_patience": 3,
385
+ "early_stopping_threshold": 0.0
386
+ },
387
+ "attributes": {
388
+ "early_stopping_patience_counter": 3
389
+ }
390
+ },
391
+ "TrainerControl": {
392
+ "args": {
393
+ "should_epoch_stop": false,
394
+ "should_evaluate": false,
395
+ "should_log": false,
396
+ "should_save": true,
397
+ "should_training_stop": true
398
+ },
399
+ "attributes": {}
400
+ }
401
+ },
402
+ "total_flos": 1.48554806942933e+18,
403
+ "train_batch_size": 64,
404
+ "trial_name": null,
405
+ "trial_params": null
406
+ }
training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30afe6103a86791f9b1677c356bee415a0089cb6bc837bedacf8d86bb6bceb26
3
+ size 5713