Spaces:
Sleeping
Sleeping
Update results
Browse files- evaluate.py +4 -4
- results.json +301 -4
evaluate.py
CHANGED
|
@@ -8,7 +8,7 @@ from huggingface_hub import list_repo_refs
|
|
| 8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
|
| 10 |
device = "cuda"
|
| 11 |
-
test_indices_length =
|
| 12 |
|
| 13 |
models = ["distributed/optimized-gpt2-250m", "distributed/optimized-gpt2-250m-v0.1.1", "distributed/gpt2-94m"]
|
| 14 |
|
|
@@ -28,9 +28,9 @@ for model_name in models:
|
|
| 28 |
refs = list_repo_refs(model_name, repo_type="model")
|
| 29 |
global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
|
| 30 |
|
| 31 |
-
for epoch in range(0,global_epoch
|
| 32 |
|
| 33 |
-
if str(epoch) in results[model_name].keys():
|
| 34 |
continue
|
| 35 |
|
| 36 |
model = AutoModelForCausalLM.from_pretrained(model_name, revision=str(epoch), trust_remote_code=True)
|
|
@@ -80,7 +80,7 @@ for model_name in models:
|
|
| 80 |
model.zero_grad()
|
| 81 |
|
| 82 |
average_loss = total_loss / (index+1)
|
| 83 |
-
results[model_name][str(epoch)] = [average_loss]
|
| 84 |
print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")
|
| 85 |
|
| 86 |
with open("results.json", "w") as outfile:
|
|
|
|
| 8 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
|
| 10 |
device = "cuda"
|
| 11 |
+
test_indices_length = 10
|
| 12 |
|
| 13 |
models = ["distributed/optimized-gpt2-250m", "distributed/optimized-gpt2-250m-v0.1.1", "distributed/gpt2-94m"]
|
| 14 |
|
|
|
|
| 28 |
refs = list_repo_refs(model_name, repo_type="model")
|
| 29 |
global_epoch = max([int(tag.name) for tag in refs.tags]) if refs.tags else None
|
| 30 |
|
| 31 |
+
for epoch in range(0,global_epoch):
|
| 32 |
|
| 33 |
+
if str(epoch) in results[model_name]['main-net'].keys():
|
| 34 |
continue
|
| 35 |
|
| 36 |
model = AutoModelForCausalLM.from_pretrained(model_name, revision=str(epoch), trust_remote_code=True)
|
|
|
|
| 80 |
model.zero_grad()
|
| 81 |
|
| 82 |
average_loss = total_loss / (index+1)
|
| 83 |
+
results[model_name]['main-net'][str(epoch)] = [average_loss]
|
| 84 |
print(f"Epoch: {epoch} Average Loss: {average_loss:.2f}")
|
| 85 |
|
| 86 |
with open("results.json", "w") as outfile:
|
results.json
CHANGED
|
@@ -3285,9 +3285,306 @@
|
|
| 3285 |
],
|
| 3286 |
"1094": [
|
| 3287 |
5.5514771938323975
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3288 |
]
|
| 3289 |
},
|
| 3290 |
-
"baseline":{
|
| 3291 |
"0": [
|
| 3292 |
10.93135
|
| 3293 |
],
|
|
@@ -3795,7 +4092,7 @@
|
|
| 3795 |
}
|
| 3796 |
},
|
| 3797 |
"distributed/optimized-gpt2-250m-v0.1.1": {
|
| 3798 |
-
"main-net":{
|
| 3799 |
"0": [
|
| 3800 |
11.042416954040528
|
| 3801 |
],
|
|
@@ -4727,7 +5024,7 @@
|
|
| 4727 |
6.409368515014648
|
| 4728 |
]
|
| 4729 |
},
|
| 4730 |
-
"baseline":{
|
| 4731 |
"0": [
|
| 4732 |
10.93135
|
| 4733 |
],
|
|
@@ -5235,7 +5532,7 @@
|
|
| 5235 |
}
|
| 5236 |
},
|
| 5237 |
"distributed/gpt2-94m": {
|
| 5238 |
-
"main-net":{
|
| 5239 |
"0": [
|
| 5240 |
10.942681312561035
|
| 5241 |
],
|
|
|
|
| 3285 |
],
|
| 3286 |
"1094": [
|
| 3287 |
5.5514771938323975
|
| 3288 |
+
],
|
| 3289 |
+
"1095": [
|
| 3290 |
+
5.654173533121745
|
| 3291 |
+
],
|
| 3292 |
+
"1096": [
|
| 3293 |
+
5.783674240112305
|
| 3294 |
+
],
|
| 3295 |
+
"1097": [
|
| 3296 |
+
5.732811212539673
|
| 3297 |
+
],
|
| 3298 |
+
"1098": [
|
| 3299 |
+
5.725842118263245
|
| 3300 |
+
],
|
| 3301 |
+
"1099": [
|
| 3302 |
+
6.016797780990601
|
| 3303 |
+
],
|
| 3304 |
+
"1100": [
|
| 3305 |
+
5.492693265279134
|
| 3306 |
+
],
|
| 3307 |
+
"1101": [
|
| 3308 |
+
5.746817111968994
|
| 3309 |
+
],
|
| 3310 |
+
"1102": [
|
| 3311 |
+
5.732641816139221
|
| 3312 |
+
],
|
| 3313 |
+
"1103": [
|
| 3314 |
+
5.6667522430419925
|
| 3315 |
+
],
|
| 3316 |
+
"1104": [
|
| 3317 |
+
6.042284965515137
|
| 3318 |
+
],
|
| 3319 |
+
"1105": [
|
| 3320 |
+
5.957233905792236
|
| 3321 |
+
],
|
| 3322 |
+
"1106": [
|
| 3323 |
+
6.250933527946472
|
| 3324 |
+
],
|
| 3325 |
+
"1107": [
|
| 3326 |
+
6.189672231674194
|
| 3327 |
+
],
|
| 3328 |
+
"1108": [
|
| 3329 |
+
5.723267146519253
|
| 3330 |
+
],
|
| 3331 |
+
"1109": [
|
| 3332 |
+
5.82790470123291
|
| 3333 |
+
],
|
| 3334 |
+
"1110": [
|
| 3335 |
+
5.603849093119304
|
| 3336 |
+
],
|
| 3337 |
+
"1111": [
|
| 3338 |
+
5.782259106636047
|
| 3339 |
+
],
|
| 3340 |
+
"1112": [
|
| 3341 |
+
5.6029471556345625
|
| 3342 |
+
],
|
| 3343 |
+
"1113": [
|
| 3344 |
+
5.546664714813232
|
| 3345 |
+
],
|
| 3346 |
+
"1114": [
|
| 3347 |
+
5.836468458175659
|
| 3348 |
+
],
|
| 3349 |
+
"1115": [
|
| 3350 |
+
5.762608170509338
|
| 3351 |
+
],
|
| 3352 |
+
"1116": [
|
| 3353 |
+
6.046388339996338
|
| 3354 |
+
],
|
| 3355 |
+
"1117": [
|
| 3356 |
+
6.0027131080627445
|
| 3357 |
+
],
|
| 3358 |
+
"1118": [
|
| 3359 |
+
6.388125038146972
|
| 3360 |
+
],
|
| 3361 |
+
"1119": [
|
| 3362 |
+
6.11626410484314
|
| 3363 |
+
],
|
| 3364 |
+
"1120": [
|
| 3365 |
+
6.112424373626709
|
| 3366 |
+
],
|
| 3367 |
+
"1121": [
|
| 3368 |
+
6.001961326599121
|
| 3369 |
+
],
|
| 3370 |
+
"1122": [
|
| 3371 |
+
5.683912754058838
|
| 3372 |
+
],
|
| 3373 |
+
"1123": [
|
| 3374 |
+
5.743198204040527
|
| 3375 |
+
],
|
| 3376 |
+
"1124": [
|
| 3377 |
+
5.722175757090251
|
| 3378 |
+
],
|
| 3379 |
+
"1125": [
|
| 3380 |
+
5.8825499534606935
|
| 3381 |
+
],
|
| 3382 |
+
"1126": [
|
| 3383 |
+
6.028543186187744
|
| 3384 |
+
],
|
| 3385 |
+
"1127": [
|
| 3386 |
+
5.9720460891723635
|
| 3387 |
+
],
|
| 3388 |
+
"1128": [
|
| 3389 |
+
6.058712800343831
|
| 3390 |
+
],
|
| 3391 |
+
"1129": [
|
| 3392 |
+
5.475493907928467
|
| 3393 |
+
],
|
| 3394 |
+
"1130": [
|
| 3395 |
+
5.970499634742737
|
| 3396 |
+
],
|
| 3397 |
+
"1131": [
|
| 3398 |
+
5.9493067264556885
|
| 3399 |
+
],
|
| 3400 |
+
"1132": [
|
| 3401 |
+
5.458620548248291
|
| 3402 |
+
],
|
| 3403 |
+
"1133": [
|
| 3404 |
+
5.992060820261638
|
| 3405 |
+
],
|
| 3406 |
+
"1134": [
|
| 3407 |
+
5.951226472854614
|
| 3408 |
+
],
|
| 3409 |
+
"1135": [
|
| 3410 |
+
5.877881646156311
|
| 3411 |
+
],
|
| 3412 |
+
"1136": [
|
| 3413 |
+
5.603206443786621
|
| 3414 |
+
],
|
| 3415 |
+
"1137": [
|
| 3416 |
+
5.8340943336486815
|
| 3417 |
+
],
|
| 3418 |
+
"1138": [
|
| 3419 |
+
5.788412570953369
|
| 3420 |
+
],
|
| 3421 |
+
"1139": [
|
| 3422 |
+
5.737103462219238
|
| 3423 |
+
],
|
| 3424 |
+
"1140": [
|
| 3425 |
+
5.636613210042317
|
| 3426 |
+
],
|
| 3427 |
+
"1141": [
|
| 3428 |
+
5.949309587478638
|
| 3429 |
+
],
|
| 3430 |
+
"1142": [
|
| 3431 |
+
5.854878067970276
|
| 3432 |
+
],
|
| 3433 |
+
"1143": [
|
| 3434 |
+
5.924749374389648
|
| 3435 |
+
],
|
| 3436 |
+
"1144": [
|
| 3437 |
+
6.321739387512207
|
| 3438 |
+
],
|
| 3439 |
+
"1145": [
|
| 3440 |
+
5.9811422030131025
|
| 3441 |
+
],
|
| 3442 |
+
"1146": [
|
| 3443 |
+
5.701364517211914
|
| 3444 |
+
],
|
| 3445 |
+
"1147": [
|
| 3446 |
+
5.503353691101074
|
| 3447 |
+
],
|
| 3448 |
+
"1148": [
|
| 3449 |
+
5.773120641708374
|
| 3450 |
+
],
|
| 3451 |
+
"1149": [
|
| 3452 |
+
6.042929470539093
|
| 3453 |
+
],
|
| 3454 |
+
"1150": [
|
| 3455 |
+
5.8076521555582685
|
| 3456 |
+
],
|
| 3457 |
+
"1151": [
|
| 3458 |
+
5.682760079701741
|
| 3459 |
+
],
|
| 3460 |
+
"1152": [
|
| 3461 |
+
5.757667303085327
|
| 3462 |
+
],
|
| 3463 |
+
"1153": [
|
| 3464 |
+
5.896499156951904
|
| 3465 |
+
],
|
| 3466 |
+
"1154": [
|
| 3467 |
+
6.025218367576599
|
| 3468 |
+
],
|
| 3469 |
+
"1155": [
|
| 3470 |
+
5.879011154174805
|
| 3471 |
+
],
|
| 3472 |
+
"1156": [
|
| 3473 |
+
5.868439674377441
|
| 3474 |
+
],
|
| 3475 |
+
"1157": [
|
| 3476 |
+
6.418252754211426
|
| 3477 |
+
],
|
| 3478 |
+
"1158": [
|
| 3479 |
+
6.2828675508499146
|
| 3480 |
+
],
|
| 3481 |
+
"1159": [
|
| 3482 |
+
6.36786642074585
|
| 3483 |
+
],
|
| 3484 |
+
"1160": [
|
| 3485 |
+
6.58310022354126
|
| 3486 |
+
],
|
| 3487 |
+
"1161": [
|
| 3488 |
+
6.19826873143514
|
| 3489 |
+
],
|
| 3490 |
+
"1162": [
|
| 3491 |
+
6.289691209793091
|
| 3492 |
+
],
|
| 3493 |
+
"1163": [
|
| 3494 |
+
5.9907801151275635
|
| 3495 |
+
],
|
| 3496 |
+
"1164": [
|
| 3497 |
+
6.041745066642761
|
| 3498 |
+
],
|
| 3499 |
+
"1165": [
|
| 3500 |
+
6.02010326385498
|
| 3501 |
+
],
|
| 3502 |
+
"1166": [
|
| 3503 |
+
5.7515941460927325
|
| 3504 |
+
],
|
| 3505 |
+
"1167": [
|
| 3506 |
+
5.48467755317688
|
| 3507 |
+
],
|
| 3508 |
+
"1168": [
|
| 3509 |
+
6.096215724945068
|
| 3510 |
+
],
|
| 3511 |
+
"1169": [
|
| 3512 |
+
5.959380865097046
|
| 3513 |
+
],
|
| 3514 |
+
"1170": [
|
| 3515 |
+
5.851028124491374
|
| 3516 |
+
],
|
| 3517 |
+
"1171": [
|
| 3518 |
+
5.8480740785598755
|
| 3519 |
+
],
|
| 3520 |
+
"1172": [
|
| 3521 |
+
5.9064167737960815
|
| 3522 |
+
],
|
| 3523 |
+
"1173": [
|
| 3524 |
+
5.956684430440267
|
| 3525 |
+
],
|
| 3526 |
+
"1174": [
|
| 3527 |
+
6.00377357006073
|
| 3528 |
+
],
|
| 3529 |
+
"1175": [
|
| 3530 |
+
6.077920118967692
|
| 3531 |
+
],
|
| 3532 |
+
"1176": [
|
| 3533 |
+
5.967975934346517
|
| 3534 |
+
],
|
| 3535 |
+
"1177": [
|
| 3536 |
+
6.253712558746338
|
| 3537 |
+
],
|
| 3538 |
+
"1178": [
|
| 3539 |
+
5.780354785919189
|
| 3540 |
+
],
|
| 3541 |
+
"1179": [
|
| 3542 |
+
5.4884843826293945
|
| 3543 |
+
],
|
| 3544 |
+
"1180": [
|
| 3545 |
+
5.482951482137044
|
| 3546 |
+
],
|
| 3547 |
+
"1181": [
|
| 3548 |
+
5.966793219248454
|
| 3549 |
+
],
|
| 3550 |
+
"1182": [
|
| 3551 |
+
5.51493239402771
|
| 3552 |
+
],
|
| 3553 |
+
"1183": [
|
| 3554 |
+
5.4840850830078125
|
| 3555 |
+
],
|
| 3556 |
+
"1184": [
|
| 3557 |
+
5.834247946739197
|
| 3558 |
+
],
|
| 3559 |
+
"1185": [
|
| 3560 |
+
5.770521521568298
|
| 3561 |
+
],
|
| 3562 |
+
"1186": [
|
| 3563 |
+
5.671548962593079
|
| 3564 |
+
],
|
| 3565 |
+
"1187": [
|
| 3566 |
+
5.491109371185303
|
| 3567 |
+
],
|
| 3568 |
+
"1188": [
|
| 3569 |
+
5.561888694763184
|
| 3570 |
+
],
|
| 3571 |
+
"1189": [
|
| 3572 |
+
5.711345076560974
|
| 3573 |
+
],
|
| 3574 |
+
"1190": [
|
| 3575 |
+
5.628474712371826
|
| 3576 |
+
],
|
| 3577 |
+
"1191": [
|
| 3578 |
+
5.514147567749023
|
| 3579 |
+
],
|
| 3580 |
+
"1192": [
|
| 3581 |
+
5.556583046913147
|
| 3582 |
+
],
|
| 3583 |
+
"1193": [
|
| 3584 |
+
5.653698126475017
|
| 3585 |
]
|
| 3586 |
},
|
| 3587 |
+
"baseline": {
|
| 3588 |
"0": [
|
| 3589 |
10.93135
|
| 3590 |
],
|
|
|
|
| 4092 |
}
|
| 4093 |
},
|
| 4094 |
"distributed/optimized-gpt2-250m-v0.1.1": {
|
| 4095 |
+
"main-net": {
|
| 4096 |
"0": [
|
| 4097 |
11.042416954040528
|
| 4098 |
],
|
|
|
|
| 5024 |
6.409368515014648
|
| 5025 |
]
|
| 5026 |
},
|
| 5027 |
+
"baseline": {
|
| 5028 |
"0": [
|
| 5029 |
10.93135
|
| 5030 |
],
|
|
|
|
| 5532 |
}
|
| 5533 |
},
|
| 5534 |
"distributed/gpt2-94m": {
|
| 5535 |
+
"main-net": {
|
| 5536 |
"0": [
|
| 5537 |
10.942681312561035
|
| 5538 |
],
|