Update README.md
Browse files
README.md
CHANGED
|
@@ -14,191 +14,38 @@ should probably proofread and complete it, then remove this comment. -->
|
|
| 14 |
|
| 15 |
# VersaPRM-Small-Subset
|
| 16 |
|
| 17 |
-
This model is a fine-tuned version of [UW-Madison-Lee-Lab/Llama-PRM800K](https://huggingface.co/UW-Madison-Lee-Lab/Llama-PRM800K) on
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
- num_devices: 8
|
| 53 |
-
- gradient_accumulation_steps: 2
|
| 54 |
-
- total_train_batch_size: 32
|
| 55 |
-
- total_eval_batch_size: 32
|
| 56 |
-
- optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
|
| 57 |
-
- lr_scheduler_type: cosine
|
| 58 |
-
- lr_scheduler_warmup_ratio: 0.1
|
| 59 |
-
- num_epochs: 3
|
| 60 |
-
|
| 61 |
-
### Training results
|
| 62 |
-
|
| 63 |
-
| Training Loss | Epoch | Step | Validation Loss | Prm accuracy | Prm precision | Prm recall | Prm specificty | Prm npv | Prm f1 | Prm f1 neg | Prm f1 auc | Prm f1 auc (fixed) |
|
| 64 |
-
|:-------------:|:------:|:----:|:---------------:|:------------:|:-------------:|:----------:|:--------------:|:-------:|:------:|:----------:|:----------:|:------------------:|
|
| 65 |
-
| No log | 0 | 0 | 0.3535 | 0.8333 | 0.8772 | 0.9346 | 0.2632 | 0.4167 | 0.9050 | 0.3226 | 0.5989 | 0.8195 |
|
| 66 |
-
| 0.3784 | 0.0229 | 5 | 0.3555 | 0.8333 | 0.8772 | 0.9346 | 0.2632 | 0.4167 | 0.9050 | 0.3226 | 0.5989 | 0.8175 |
|
| 67 |
-
| 0.406 | 0.0459 | 10 | 0.3487 | 0.8413 | 0.8718 | 0.9533 | 0.2105 | 0.4444 | 0.9107 | 0.2857 | 0.5819 | 0.8224 |
|
| 68 |
-
| 0.2865 | 0.0688 | 15 | 0.3669 | 0.8571 | 0.8678 | 0.9813 | 0.1579 | 0.6 | 0.9211 | 0.25 | 0.5696 | 0.8286 |
|
| 69 |
-
| 0.1769 | 0.0917 | 20 | 0.4144 | 0.8571 | 0.856 | 1.0 | 0.0526 | 1.0 | 0.9224 | 0.1 | 0.5263 | 0.8505 |
|
| 70 |
-
| 0.3069 | 0.1147 | 25 | 0.3531 | 0.8413 | 0.8537 | 0.9813 | 0.0526 | 0.3333 | 0.9130 | 0.0909 | 0.5170 | 0.8546 |
|
| 71 |
-
| 0.2295 | 0.1376 | 30 | 0.3140 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8569 |
|
| 72 |
-
| 0.2301 | 0.1606 | 35 | 0.3168 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8706 |
|
| 73 |
-
| 0.351 | 0.1835 | 40 | 0.3087 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8751 |
|
| 74 |
-
| 0.2607 | 0.2064 | 45 | 0.2788 | 0.8730 | 0.8889 | 0.9720 | 0.3158 | 0.6667 | 0.9286 | 0.4286 | 0.6439 | 0.8721 |
|
| 75 |
-
| 0.288 | 0.2294 | 50 | 0.2909 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8741 |
|
| 76 |
-
| 0.2185 | 0.2523 | 55 | 0.2831 | 0.8571 | 0.8803 | 0.9626 | 0.2632 | 0.5556 | 0.9196 | 0.3571 | 0.6129 | 0.8650 |
|
| 77 |
-
| 0.2307 | 0.2752 | 60 | 0.2987 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8667 |
|
| 78 |
-
| 0.1842 | 0.2982 | 65 | 0.3128 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8706 |
|
| 79 |
-
| 0.1831 | 0.3211 | 70 | 0.2936 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8667 |
|
| 80 |
-
| 0.1947 | 0.3440 | 75 | 0.3333 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8679 |
|
| 81 |
-
| 0.127 | 0.3670 | 80 | 0.2984 | 0.8492 | 0.8793 | 0.9533 | 0.2632 | 0.5 | 0.9148 | 0.3448 | 0.6082 | 0.8763 |
|
| 82 |
-
| 0.1929 | 0.3899 | 85 | 0.3202 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8785 |
|
| 83 |
-
| 0.2229 | 0.4128 | 90 | 0.2970 | 0.8492 | 0.8793 | 0.9533 | 0.2632 | 0.5 | 0.9148 | 0.3448 | 0.6082 | 0.8805 |
|
| 84 |
-
| 0.2483 | 0.4358 | 95 | 0.2941 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8758 |
|
| 85 |
-
| 0.1739 | 0.4587 | 100 | 0.2689 | 0.8571 | 0.9083 | 0.9252 | 0.4737 | 0.5294 | 0.9167 | 0.5 | 0.6995 | 0.8741 |
|
| 86 |
-
| 0.1255 | 0.4817 | 105 | 0.2814 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8792 |
|
| 87 |
-
| 0.2923 | 0.5046 | 110 | 0.2896 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8765 |
|
| 88 |
-
| 0.1591 | 0.5275 | 115 | 0.2788 | 0.8492 | 0.8793 | 0.9533 | 0.2632 | 0.5 | 0.9148 | 0.3448 | 0.6082 | 0.8674 |
|
| 89 |
-
| 0.1779 | 0.5505 | 120 | 0.3477 | 0.8571 | 0.8678 | 0.9813 | 0.1579 | 0.6 | 0.9211 | 0.25 | 0.5696 | 0.8657 |
|
| 90 |
-
| 0.2207 | 0.5734 | 125 | 0.3152 | 0.8492 | 0.8667 | 0.9720 | 0.1579 | 0.5 | 0.9163 | 0.24 | 0.5649 | 0.8667 |
|
| 91 |
-
| 0.1996 | 0.5963 | 130 | 0.2898 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8647 |
|
| 92 |
-
| 0.1385 | 0.6193 | 135 | 0.2960 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8684 |
|
| 93 |
-
| 0.1334 | 0.6422 | 140 | 0.3235 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8694 |
|
| 94 |
-
| 0.2133 | 0.6651 | 145 | 0.3182 | 0.8810 | 0.8898 | 0.9813 | 0.3158 | 0.75 | 0.9333 | 0.4444 | 0.6485 | 0.8765 |
|
| 95 |
-
| 0.1889 | 0.6881 | 150 | 0.3074 | 0.8651 | 0.8879 | 0.9626 | 0.3158 | 0.6 | 0.9238 | 0.4138 | 0.6392 | 0.8751 |
|
| 96 |
-
| 0.1451 | 0.7110 | 155 | 0.3549 | 0.8492 | 0.8667 | 0.9720 | 0.1579 | 0.5 | 0.9163 | 0.24 | 0.5649 | 0.8706 |
|
| 97 |
-
| 0.205 | 0.7339 | 160 | 0.3393 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8736 |
|
| 98 |
-
| 0.1852 | 0.7569 | 165 | 0.2970 | 0.8651 | 0.8879 | 0.9626 | 0.3158 | 0.6 | 0.9238 | 0.4138 | 0.6392 | 0.8768 |
|
| 99 |
-
| 0.1195 | 0.7798 | 170 | 0.3259 | 0.8730 | 0.8824 | 0.9813 | 0.2632 | 0.7143 | 0.9292 | 0.3846 | 0.6222 | 0.8785 |
|
| 100 |
-
| 0.2319 | 0.8028 | 175 | 0.3726 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8788 |
|
| 101 |
-
| 0.1131 | 0.8257 | 180 | 0.2880 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8795 |
|
| 102 |
-
| 0.2147 | 0.8486 | 185 | 0.2731 | 0.8889 | 0.8908 | 0.9907 | 0.3158 | 0.8571 | 0.9381 | 0.4615 | 0.6532 | 0.8802 |
|
| 103 |
-
| 0.1673 | 0.8716 | 190 | 0.3121 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8822 |
|
| 104 |
-
| 0.0806 | 0.8945 | 195 | 0.3096 | 0.8810 | 0.8833 | 0.9907 | 0.2632 | 0.8333 | 0.9339 | 0.4 | 0.6269 | 0.8815 |
|
| 105 |
-
| 0.1573 | 0.9174 | 200 | 0.2795 | 0.8730 | 0.8957 | 0.9626 | 0.3684 | 0.6364 | 0.9279 | 0.4667 | 0.6655 | 0.8812 |
|
| 106 |
-
| 0.1062 | 0.9404 | 205 | 0.3145 | 0.8810 | 0.8833 | 0.9907 | 0.2632 | 0.8333 | 0.9339 | 0.4 | 0.6269 | 0.8760 |
|
| 107 |
-
| 0.1036 | 0.9633 | 210 | 0.3353 | 0.8810 | 0.8833 | 0.9907 | 0.2632 | 0.8333 | 0.9339 | 0.4 | 0.6269 | 0.8829 |
|
| 108 |
-
| 0.1586 | 0.9862 | 215 | 0.3197 | 0.8810 | 0.8833 | 0.9907 | 0.2632 | 0.8333 | 0.9339 | 0.4 | 0.6269 | 0.8842 |
|
| 109 |
-
| 0.1319 | 1.0092 | 220 | 0.2869 | 0.8810 | 0.8833 | 0.9907 | 0.2632 | 0.8333 | 0.9339 | 0.4 | 0.6269 | 0.8790 |
|
| 110 |
-
| 0.1444 | 1.0321 | 225 | 0.3127 | 0.8730 | 0.8824 | 0.9813 | 0.2632 | 0.7143 | 0.9292 | 0.3846 | 0.6222 | 0.8812 |
|
| 111 |
-
| 0.1346 | 1.0550 | 230 | 0.4127 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8824 |
|
| 112 |
-
| 0.1412 | 1.0780 | 235 | 0.3096 | 0.8730 | 0.8957 | 0.9626 | 0.3684 | 0.6364 | 0.9279 | 0.4667 | 0.6655 | 0.8847 |
|
| 113 |
-
| 0.0782 | 1.1009 | 240 | 0.3728 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8775 |
|
| 114 |
-
| 0.1134 | 1.1239 | 245 | 0.4411 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8709 |
|
| 115 |
-
| 0.0536 | 1.1468 | 250 | 0.3540 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8733 |
|
| 116 |
-
| 0.0687 | 1.1697 | 255 | 0.3990 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8733 |
|
| 117 |
-
| 0.0868 | 1.1927 | 260 | 0.4859 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8684 |
|
| 118 |
-
| 0.0858 | 1.2156 | 265 | 0.4038 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8679 |
|
| 119 |
-
| 0.1594 | 1.2385 | 270 | 0.3725 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8603 |
|
| 120 |
-
| 0.0289 | 1.2615 | 275 | 0.4295 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8544 |
|
| 121 |
-
| 0.121 | 1.2844 | 280 | 0.4086 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8556 |
|
| 122 |
-
| 0.133 | 1.3073 | 285 | 0.3791 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8620 |
|
| 123 |
-
| 0.0599 | 1.3303 | 290 | 0.3711 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8596 |
|
| 124 |
-
| 0.0826 | 1.3532 | 295 | 0.4366 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8571 |
|
| 125 |
-
| 0.0724 | 1.3761 | 300 | 0.4015 | 0.8730 | 0.8760 | 0.9907 | 0.2105 | 0.8 | 0.9298 | 0.3333 | 0.6006 | 0.8529 |
|
| 126 |
-
| 0.1192 | 1.3991 | 305 | 0.3474 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8608 |
|
| 127 |
-
| 0.1415 | 1.4220 | 310 | 0.3613 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8650 |
|
| 128 |
-
| 0.1567 | 1.4450 | 315 | 0.4130 | 0.8571 | 0.8678 | 0.9813 | 0.1579 | 0.6 | 0.9211 | 0.25 | 0.5696 | 0.8645 |
|
| 129 |
-
| 0.0669 | 1.4679 | 320 | 0.4484 | 0.8571 | 0.8678 | 0.9813 | 0.1579 | 0.6 | 0.9211 | 0.25 | 0.5696 | 0.8660 |
|
| 130 |
-
| 0.0824 | 1.4908 | 325 | 0.4422 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8687 |
|
| 131 |
-
| 0.0809 | 1.5138 | 330 | 0.4073 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8756 |
|
| 132 |
-
| 0.1214 | 1.5367 | 335 | 0.4339 | 0.8492 | 0.8667 | 0.9720 | 0.1579 | 0.5 | 0.9163 | 0.24 | 0.5649 | 0.8679 |
|
| 133 |
-
| 0.0728 | 1.5596 | 340 | 0.4479 | 0.8492 | 0.8667 | 0.9720 | 0.1579 | 0.5 | 0.9163 | 0.24 | 0.5649 | 0.8694 |
|
| 134 |
-
| 0.0982 | 1.5826 | 345 | 0.4658 | 0.8492 | 0.8607 | 0.9813 | 0.1053 | 0.5 | 0.9170 | 0.1739 | 0.5433 | 0.8660 |
|
| 135 |
-
| 0.0643 | 1.6055 | 350 | 0.4538 | 0.8571 | 0.8618 | 0.9907 | 0.1053 | 0.6667 | 0.9217 | 0.1818 | 0.5480 | 0.8637 |
|
| 136 |
-
| 0.0898 | 1.6284 | 355 | 0.3773 | 0.8413 | 0.8595 | 0.9720 | 0.1053 | 0.4 | 0.9123 | 0.1667 | 0.5386 | 0.8660 |
|
| 137 |
-
| 0.1687 | 1.6514 | 360 | 0.3235 | 0.8571 | 0.8803 | 0.9626 | 0.2632 | 0.5556 | 0.9196 | 0.3571 | 0.6129 | 0.8714 |
|
| 138 |
-
| 0.0977 | 1.6743 | 365 | 0.3467 | 0.8651 | 0.8814 | 0.9720 | 0.2632 | 0.625 | 0.9244 | 0.3704 | 0.6176 | 0.8719 |
|
| 139 |
-
| 0.069 | 1.6972 | 370 | 0.4018 | 0.8571 | 0.8618 | 0.9907 | 0.1053 | 0.6667 | 0.9217 | 0.1818 | 0.5480 | 0.8728 |
|
| 140 |
-
| 0.1326 | 1.7202 | 375 | 0.4035 | 0.8571 | 0.8618 | 0.9907 | 0.1053 | 0.6667 | 0.9217 | 0.1818 | 0.5480 | 0.8751 |
|
| 141 |
-
| 0.0738 | 1.7431 | 380 | 0.3621 | 0.8651 | 0.8689 | 0.9907 | 0.1579 | 0.75 | 0.9258 | 0.2609 | 0.5743 | 0.8714 |
|
| 142 |
-
| 0.0988 | 1.7661 | 385 | 0.3494 | 0.8730 | 0.8824 | 0.9813 | 0.2632 | 0.7143 | 0.9292 | 0.3846 | 0.6222 | 0.8719 |
|
| 143 |
-
| 0.0759 | 1.7890 | 390 | 0.3620 | 0.8571 | 0.8803 | 0.9626 | 0.2632 | 0.5556 | 0.9196 | 0.3571 | 0.6129 | 0.8694 |
|
| 144 |
-
| 0.1142 | 1.8119 | 395 | 0.3973 | 0.8571 | 0.8803 | 0.9626 | 0.2632 | 0.5556 | 0.9196 | 0.3571 | 0.6129 | 0.8726 |
|
| 145 |
-
| 0.1812 | 1.8349 | 400 | 0.4123 | 0.8571 | 0.8803 | 0.9626 | 0.2632 | 0.5556 | 0.9196 | 0.3571 | 0.6129 | 0.8719 |
|
| 146 |
-
| 0.0437 | 1.8578 | 405 | 0.3928 | 0.8571 | 0.8870 | 0.9533 | 0.3158 | 0.5455 | 0.9189 | 0.4 | 0.6345 | 0.8704 |
|
| 147 |
-
| 0.1087 | 1.8807 | 410 | 0.4132 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8699 |
|
| 148 |
-
| 0.094 | 1.9037 | 415 | 0.3848 | 0.8413 | 0.8718 | 0.9533 | 0.2105 | 0.4444 | 0.9107 | 0.2857 | 0.5819 | 0.8699 |
|
| 149 |
-
| 0.082 | 1.9266 | 420 | 0.3774 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8731 |
|
| 150 |
-
| 0.0896 | 1.9495 | 425 | 0.3949 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8724 |
|
| 151 |
-
| 0.1096 | 1.9725 | 430 | 0.4198 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8728 |
|
| 152 |
-
| 0.1142 | 1.9954 | 435 | 0.4113 | 0.8651 | 0.875 | 0.9813 | 0.2105 | 0.6667 | 0.9251 | 0.32 | 0.5959 | 0.8736 |
|
| 153 |
-
| 0.0418 | 2.0183 | 440 | 0.4037 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8709 |
|
| 154 |
-
| 0.0231 | 2.0413 | 445 | 0.4156 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8687 |
|
| 155 |
-
| 0.0357 | 2.0642 | 450 | 0.4368 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8689 |
|
| 156 |
-
| 0.0396 | 2.0872 | 455 | 0.4785 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8679 |
|
| 157 |
-
| 0.0458 | 2.1101 | 460 | 0.5241 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8645 |
|
| 158 |
-
| 0.018 | 2.1330 | 465 | 0.5647 | 0.8571 | 0.8678 | 0.9813 | 0.1579 | 0.6 | 0.9211 | 0.25 | 0.5696 | 0.8628 |
|
| 159 |
-
| 0.0294 | 2.1560 | 470 | 0.6041 | 0.8571 | 0.8678 | 0.9813 | 0.1579 | 0.6 | 0.9211 | 0.25 | 0.5696 | 0.8618 |
|
| 160 |
-
| 0.064 | 2.1789 | 475 | 0.5872 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8660 |
|
| 161 |
-
| 0.0708 | 2.2018 | 480 | 0.5229 | 0.8492 | 0.8793 | 0.9533 | 0.2632 | 0.5 | 0.9148 | 0.3448 | 0.6082 | 0.8645 |
|
| 162 |
-
| 0.0344 | 2.2248 | 485 | 0.4986 | 0.8571 | 0.8870 | 0.9533 | 0.3158 | 0.5455 | 0.9189 | 0.4 | 0.6345 | 0.8657 |
|
| 163 |
-
| 0.0184 | 2.2477 | 490 | 0.5377 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8655 |
|
| 164 |
-
| 0.0316 | 2.2706 | 495 | 0.5832 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8628 |
|
| 165 |
-
| 0.0133 | 2.2936 | 500 | 0.5912 | 0.8492 | 0.8667 | 0.9720 | 0.1579 | 0.5 | 0.9163 | 0.24 | 0.5649 | 0.8657 |
|
| 166 |
-
| 0.0315 | 2.3165 | 505 | 0.5803 | 0.8492 | 0.8667 | 0.9720 | 0.1579 | 0.5 | 0.9163 | 0.24 | 0.5649 | 0.8667 |
|
| 167 |
-
| 0.0599 | 2.3394 | 510 | 0.5893 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8637 |
|
| 168 |
-
| 0.0245 | 2.3624 | 515 | 0.5885 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8660 |
|
| 169 |
-
| 0.0091 | 2.3853 | 520 | 0.5829 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8662 |
|
| 170 |
-
| 0.0446 | 2.4083 | 525 | 0.5867 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8652 |
|
| 171 |
-
| 0.0185 | 2.4312 | 530 | 0.5704 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8652 |
|
| 172 |
-
| 0.02 | 2.4541 | 535 | 0.5542 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8655 |
|
| 173 |
-
| 0.0248 | 2.4771 | 540 | 0.5494 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8665 |
|
| 174 |
-
| 0.0178 | 2.5 | 545 | 0.5424 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8672 |
|
| 175 |
-
| 0.0084 | 2.5229 | 550 | 0.5434 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8650 |
|
| 176 |
-
| 0.0307 | 2.5459 | 555 | 0.5538 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8655 |
|
| 177 |
-
| 0.0414 | 2.5688 | 560 | 0.5469 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8652 |
|
| 178 |
-
| 0.0089 | 2.5917 | 565 | 0.5447 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8645 |
|
| 179 |
-
| 0.027 | 2.6147 | 570 | 0.5398 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8650 |
|
| 180 |
-
| 0.0087 | 2.6376 | 575 | 0.5417 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8660 |
|
| 181 |
-
| 0.0438 | 2.6606 | 580 | 0.5484 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8642 |
|
| 182 |
-
| 0.0238 | 2.6835 | 585 | 0.5475 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8657 |
|
| 183 |
-
| 0.0515 | 2.7064 | 590 | 0.5416 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8635 |
|
| 184 |
-
| 0.1527 | 2.7294 | 595 | 0.5277 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8650 |
|
| 185 |
-
| 0.0587 | 2.7523 | 600 | 0.5274 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8660 |
|
| 186 |
-
| 0.0062 | 2.7752 | 605 | 0.5361 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8655 |
|
| 187 |
-
| 0.0279 | 2.7982 | 610 | 0.5427 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8637 |
|
| 188 |
-
| 0.0046 | 2.8211 | 615 | 0.5462 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8635 |
|
| 189 |
-
| 0.0156 | 2.8440 | 620 | 0.5533 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8625 |
|
| 190 |
-
| 0.0077 | 2.8670 | 625 | 0.5456 | 0.8492 | 0.8729 | 0.9626 | 0.2105 | 0.5 | 0.9156 | 0.2963 | 0.5866 | 0.8650 |
|
| 191 |
-
| 0.0085 | 2.8899 | 630 | 0.5465 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8637 |
|
| 192 |
-
| 0.0179 | 2.9128 | 635 | 0.5492 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8630 |
|
| 193 |
-
| 0.0571 | 2.9358 | 640 | 0.5421 | 0.8571 | 0.8739 | 0.9720 | 0.2105 | 0.5714 | 0.9204 | 0.3077 | 0.5912 | 0.8640 |
|
| 194 |
-
| 0.0239 | 2.9587 | 645 | 0.5471 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8665 |
|
| 195 |
-
| 0.0317 | 2.9817 | 650 | 0.5531 | 0.8413 | 0.8655 | 0.9626 | 0.1579 | 0.4286 | 0.9115 | 0.2308 | 0.5603 | 0.8635 |
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
### Framework versions
|
| 199 |
-
|
| 200 |
-
- PEFT 0.12.0
|
| 201 |
-
- Transformers 4.46.0
|
| 202 |
-
- Pytorch 2.4.0+cu118
|
| 203 |
-
- Datasets 3.0.0
|
| 204 |
-
- Tokenizers 0.20.1
|
|
|
|
| 14 |
|
| 15 |
# VersaPRM-Small-Subset
|
| 16 |
|
| 17 |
+
This model is a fine-tuned version of [UW-Madison-Lee-Lab/Llama-PRM800K](https://huggingface.co/UW-Madison-Lee-Lab/Llama-PRM800K) on a uniformly sampled random subset of [UW-Madison-Lee-Lab/MMLU-Pro-CoT-Train-Labeled](https://huggingface.co/datasets/UW-Madison-Lee-Lab/MMLU-Pro-CoT-Train-Labeled).
|
| 18 |
+
|
| 19 |
+
## Get rewards
|
| 20 |
+
```python
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 23 |
+
|
| 24 |
+
def get_tokenizer(model_id):
|
| 25 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
| 26 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 27 |
+
tokenizer.padding_side = 'left'
|
| 28 |
+
tokenizer.truncation_side = 'left'
|
| 29 |
+
return tokenizer
|
| 30 |
+
|
| 31 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 32 |
+
tokenizer = get_tokenizer('UW-Madison-Lee-Lab/VersaPRM-Small-Subset')
|
| 33 |
+
model = AutoModelForCausalLM.from_pretrained('UW-Madison-Lee-Lab/VersaPRM-Small-Subset')
|
| 34 |
+
candidate_tokens = [12, 10]
|
| 35 |
+
model.to(device)
|
| 36 |
+
|
| 37 |
+
question = 'Question: In Python 3, which of the following function convert a string to an int in python?\nA. short(x)\nB. float(x)\nC. integer(x [,base])\nD. double(x)\nE. int(x [,base])\nF. long(x [,base] )\nG. num(x)\nH. str(x)\nI. char(x)\nJ. digit(x [,base])'
|
| 38 |
+
solution = ["To convert a string to an integer in Python 3, we use the built-in function int().",
|
| 39 |
+
"The int() function takes two arguments: the string to be converted and an optional base (default is 10, which is for decimal).",
|
| 40 |
+
"For example: int(\"123\", 10) converts the string \"123\" to the integer 123.",
|
| 41 |
+
"Looking at the options, we can see that the correct function is option E: int(x [,base]).",
|
| 42 |
+
"The answer is (E)."]
|
| 43 |
+
input_text = question + ' \n\n' + ' \n\n\n\n'.join(solution) + ' \n\n\n\n' # solution steps are separated by ' \n\n\n\n'
|
| 44 |
+
input_id = torch.tensor([tokenizer.encode(input_text)]).to(device)
|
| 45 |
+
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
logits = model(input_id).logits[:,:,candidate_tokens]
|
| 48 |
+
scores = logits.softmax(dim=-1)[:,:,1]
|
| 49 |
+
step_scores = scores[input_id == 23535]
|
| 50 |
+
step_probs = step_scores.tolist()
|
| 51 |
+
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|