lestienne commited on
Commit
386cc98
·
verified ·
1 Parent(s): 5ea870c

Add files using upload-large-folder tool

Browse files
Files changed (49) hide show
  1. lists/banking77/size=32/seed=0/0.0-0.3.txt +93 -0
  2. lists/banking77/size=32/seed=0/0.0-0.7.txt +216 -0
  3. lists/banking77/size=32/seed=0/0.0-1.0.txt +308 -0
  4. lists/banking77/size=32/seed=0/0.7-1.0.txt +92 -0
  5. lists/banking77/size=32/seed=1/0.0-0.3.txt +93 -0
  6. lists/banking77/size=32/seed=1/0.0-0.7.txt +216 -0
  7. lists/banking77/size=32/seed=1/0.0-1.0.txt +308 -0
  8. lists/banking77/size=32/seed=1/0.7-1.0.txt +92 -0
  9. lists/banking77/size=32/seed=5/0.0-0.3.txt +93 -0
  10. lists/banking77/size=32/seed=5/0.0-0.7.txt +216 -0
  11. lists/banking77/size=32/seed=5/0.7-1.0.txt +92 -0
  12. lists/banking77/size=32/seed=7/0.0-0.3.txt +93 -0
  13. lists/banking77/size=32/seed=7/0.0-0.7.txt +216 -0
  14. lists/banking77/size=32/seed=7/0.0-1.0.txt +308 -0
  15. lists/banking77/size=32/seed=7/0.7-1.0.txt +92 -0
  16. prompts/basic_20newsgroups.yaml +22 -0
  17. prompts/basic_agnews.yaml +6 -0
  18. prompts/basic_banking77.yaml +79 -0
  19. prompts/basic_dbpedia.yaml +16 -0
  20. prompts/basic_sst2.yaml +4 -0
  21. src/llmcal/__init__.py +0 -0
  22. src/llmcal/scripts/__init__.py +0 -0
  23. src/llmcal/scripts/affine_calibration.old.py +219 -0
  24. src/llmcal/scripts/affine_calibration.py +203 -0
  25. src/llmcal/scripts/affine_prediction.py +36 -0
  26. src/llmcal/scripts/compare_models.py +110 -0
  27. src/llmcal/scripts/compute_matched_results.py +309 -0
  28. src/llmcal/scripts/create_lists_new.py +54 -0
  29. src/llmcal/scripts/evals.py +42 -0
  30. src/llmcal/scripts/prepare_data.py +64 -0
  31. src/llmcal/scripts/results_bars.py +184 -0
  32. src/llmcal/scripts/results_table.py +167 -0
  33. src/llmcal/scripts/results_vs_samples.py +181 -0
  34. src/llmcal/scripts/run_posteriors.py +193 -0
  35. src/llmcal/scripts/train_lora.py +418 -0
  36. src/llmcal/src/__init__.py +0 -0
  37. src/llmcal/src/evaluation/calibration.py +84 -0
  38. src/llmcal/src/evaluation/metrics.py +86 -0
  39. src/llmcal/src/loggers.py +41 -0
  40. src/llmcal/src/prompts/__init__.py +6 -0
  41. src/llmcal/src/prompts/gemma.py +39 -0
  42. src/llmcal/src/prompts/llama3.py +38 -0
  43. src/llmcal/src/prompts/phi.py +38 -0
  44. src/llmcal/src/prompts/pythia.py +38 -0
  45. src/llmcal/src/prompts/qwen.py +37 -0
  46. src/llmcal/src/prompts/tinyllama.py +40 -0
  47. src/llmcal/src/utils.py +93 -0
  48. src/llmcal/tests/__init__.py +0 -0
  49. src/llmcal/tests/check_lists.py +64 -0
lists/banking77/size=32/seed=0/0.0-0.3.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 8200
2
+ 723
3
+ 832
4
+ 7702
5
+ 10558
6
+ 2724
7
+ 3991
8
+ 8087
9
+ 4580
10
+ 5829
11
+ 1223
12
+ 11974
13
+ 8991
14
+ 10115
15
+ 11593
16
+ 8587
17
+ 1287
18
+ 8882
19
+ 3913
20
+ 6998
21
+ 12163
22
+ 1284
23
+ 12023
24
+ 12975
25
+ 11480
26
+ 4042
27
+ 4016
28
+ 9031
29
+ 11034
30
+ 10159
31
+ 10039
32
+ 6372
33
+ 8787
34
+ 12040
35
+ 8743
36
+ 7326
37
+ 7556
38
+ 8780
39
+ 9351
40
+ 12522
41
+ 380
42
+ 6307
43
+ 9173
44
+ 4480
45
+ 11113
46
+ 3453
47
+ 5507
48
+ 12060
49
+ 3120
50
+ 2596
51
+ 3557
52
+ 3806
53
+ 9793
54
+ 802
55
+ 6929
56
+ 11922
57
+ 3416
58
+ 6763
59
+ 4630
60
+ 532
61
+ 1651
62
+ 7104
63
+ 6631
64
+ 12496
65
+ 9678
66
+ 12991
67
+ 10606
68
+ 5669
69
+ 11729
70
+ 6276
71
+ 720
72
+ 4933
73
+ 732
74
+ 5432
75
+ 12399
76
+ 703
77
+ 12141
78
+ 6878
79
+ 1120
80
+ 3015
81
+ 5379
82
+ 4540
83
+ 6314
84
+ 8604
85
+ 8896
86
+ 3014
87
+ 7668
88
+ 1800
89
+ 8216
90
+ 3996
91
+ 10239
92
+ 8844
93
+ 9544
lists/banking77/size=32/seed=0/0.0-0.7.txt ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 8200
2
+ 723
3
+ 832
4
+ 7702
5
+ 10558
6
+ 2724
7
+ 3991
8
+ 8087
9
+ 4580
10
+ 5829
11
+ 1223
12
+ 11974
13
+ 8991
14
+ 10115
15
+ 11593
16
+ 8587
17
+ 1287
18
+ 8882
19
+ 3913
20
+ 6998
21
+ 12163
22
+ 1284
23
+ 12023
24
+ 12975
25
+ 11480
26
+ 4042
27
+ 4016
28
+ 9031
29
+ 11034
30
+ 10159
31
+ 10039
32
+ 6372
33
+ 8787
34
+ 12040
35
+ 8743
36
+ 7326
37
+ 7556
38
+ 8780
39
+ 9351
40
+ 12522
41
+ 380
42
+ 6307
43
+ 9173
44
+ 4480
45
+ 11113
46
+ 3453
47
+ 5507
48
+ 12060
49
+ 3120
50
+ 2596
51
+ 3557
52
+ 3806
53
+ 9793
54
+ 802
55
+ 6929
56
+ 11922
57
+ 3416
58
+ 6763
59
+ 4630
60
+ 532
61
+ 1651
62
+ 7104
63
+ 6631
64
+ 12496
65
+ 9678
66
+ 12991
67
+ 10606
68
+ 5669
69
+ 11729
70
+ 6276
71
+ 720
72
+ 4933
73
+ 732
74
+ 5432
75
+ 12399
76
+ 703
77
+ 12141
78
+ 6878
79
+ 1120
80
+ 3015
81
+ 5379
82
+ 4540
83
+ 6314
84
+ 8604
85
+ 8896
86
+ 3014
87
+ 7668
88
+ 1800
89
+ 8216
90
+ 3996
91
+ 10239
92
+ 8844
93
+ 9544
94
+ 4185
95
+ 8509
96
+ 9747
97
+ 5551
98
+ 6861
99
+ 2508
100
+ 8876
101
+ 4049
102
+ 8649
103
+ 2897
104
+ 10500
105
+ 7933
106
+ 12190
107
+ 9294
108
+ 4577
109
+ 11197
110
+ 5034
111
+ 5740
112
+ 9929
113
+ 1076
114
+ 10826
115
+ 9482
116
+ 7239
117
+ 11458
118
+ 8430
119
+ 3317
120
+ 3801
121
+ 3892
122
+ 4620
123
+ 958
124
+ 9871
125
+ 6561
126
+ 7308
127
+ 6105
128
+ 12094
129
+ 9134
130
+ 7103
131
+ 1677
132
+ 81
133
+ 3702
134
+ 2250
135
+ 6531
136
+ 6483
137
+ 10062
138
+ 3815
139
+ 2512
140
+ 4387
141
+ 12747
142
+ 12246
143
+ 7676
144
+ 9474
145
+ 178
146
+ 6299
147
+ 354
148
+ 4428
149
+ 11401
150
+ 8852
151
+ 4722
152
+ 8793
153
+ 9823
154
+ 896
155
+ 13073
156
+ 11185
157
+ 11257
158
+ 7631
159
+ 3027
160
+ 2042
161
+ 6576
162
+ 1960
163
+ 3903
164
+ 10724
165
+ 10663
166
+ 141
167
+ 6986
168
+ 3985
169
+ 4417
170
+ 8987
171
+ 1236
172
+ 1082
173
+ 9826
174
+ 11057
175
+ 4219
176
+ 2678
177
+ 1323
178
+ 12575
179
+ 9321
180
+ 10018
181
+ 12394
182
+ 1636
183
+ 9990
184
+ 10476
185
+ 1691
186
+ 11576
187
+ 8843
188
+ 4986
189
+ 3414
190
+ 10343
191
+ 9605
192
+ 1929
193
+ 8469
194
+ 7339
195
+ 6385
196
+ 4176
197
+ 2444
198
+ 4996
199
+ 3766
200
+ 8447
201
+ 4314
202
+ 4582
203
+ 4971
204
+ 1221
205
+ 123
206
+ 4172
207
+ 7916
208
+ 10882
209
+ 7519
210
+ 9802
211
+ 4262
212
+ 4697
213
+ 10498
214
+ 4836
215
+ 7568
216
+ 8406
lists/banking77/size=32/seed=0/0.0-1.0.txt ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 8200
2
+ 723
3
+ 832
4
+ 7702
5
+ 10558
6
+ 2724
7
+ 3991
8
+ 8087
9
+ 4580
10
+ 5829
11
+ 1223
12
+ 11974
13
+ 8991
14
+ 10115
15
+ 11593
16
+ 8587
17
+ 1287
18
+ 8882
19
+ 3913
20
+ 6998
21
+ 12163
22
+ 1284
23
+ 12023
24
+ 12975
25
+ 11480
26
+ 4042
27
+ 4016
28
+ 9031
29
+ 11034
30
+ 10159
31
+ 10039
32
+ 6372
33
+ 8787
34
+ 12040
35
+ 8743
36
+ 7326
37
+ 7556
38
+ 8780
39
+ 9351
40
+ 12522
41
+ 380
42
+ 6307
43
+ 9173
44
+ 4480
45
+ 11113
46
+ 3453
47
+ 5507
48
+ 12060
49
+ 3120
50
+ 2596
51
+ 3557
52
+ 3806
53
+ 9793
54
+ 802
55
+ 6929
56
+ 11922
57
+ 3416
58
+ 6763
59
+ 4630
60
+ 532
61
+ 1651
62
+ 7104
63
+ 6631
64
+ 12496
65
+ 9678
66
+ 12991
67
+ 10606
68
+ 5669
69
+ 11729
70
+ 6276
71
+ 720
72
+ 4933
73
+ 732
74
+ 5432
75
+ 12399
76
+ 703
77
+ 12141
78
+ 6878
79
+ 1120
80
+ 3015
81
+ 5379
82
+ 4540
83
+ 6314
84
+ 8604
85
+ 8896
86
+ 3014
87
+ 7668
88
+ 1800
89
+ 8216
90
+ 3996
91
+ 10239
92
+ 8844
93
+ 9544
94
+ 4185
95
+ 8509
96
+ 9747
97
+ 5551
98
+ 6861
99
+ 2508
100
+ 8876
101
+ 4049
102
+ 8649
103
+ 2897
104
+ 10500
105
+ 7933
106
+ 12190
107
+ 9294
108
+ 4577
109
+ 11197
110
+ 5034
111
+ 5740
112
+ 9929
113
+ 1076
114
+ 10826
115
+ 9482
116
+ 7239
117
+ 11458
118
+ 8430
119
+ 3317
120
+ 3801
121
+ 3892
122
+ 4620
123
+ 958
124
+ 9871
125
+ 6561
126
+ 7308
127
+ 6105
128
+ 12094
129
+ 9134
130
+ 7103
131
+ 1677
132
+ 81
133
+ 3702
134
+ 2250
135
+ 6531
136
+ 6483
137
+ 10062
138
+ 3815
139
+ 2512
140
+ 4387
141
+ 12747
142
+ 12246
143
+ 7676
144
+ 9474
145
+ 178
146
+ 6299
147
+ 354
148
+ 4428
149
+ 11401
150
+ 8852
151
+ 4722
152
+ 8793
153
+ 9823
154
+ 896
155
+ 13073
156
+ 11185
157
+ 11257
158
+ 7631
159
+ 3027
160
+ 2042
161
+ 6576
162
+ 1960
163
+ 3903
164
+ 10724
165
+ 10663
166
+ 141
167
+ 6986
168
+ 3985
169
+ 4417
170
+ 8987
171
+ 1236
172
+ 1082
173
+ 9826
174
+ 11057
175
+ 4219
176
+ 2678
177
+ 1323
178
+ 12575
179
+ 9321
180
+ 10018
181
+ 12394
182
+ 1636
183
+ 9990
184
+ 10476
185
+ 1691
186
+ 11576
187
+ 8843
188
+ 4986
189
+ 3414
190
+ 10343
191
+ 9605
192
+ 1929
193
+ 8469
194
+ 7339
195
+ 6385
196
+ 4176
197
+ 2444
198
+ 4996
199
+ 3766
200
+ 8447
201
+ 4314
202
+ 4582
203
+ 4971
204
+ 1221
205
+ 123
206
+ 4172
207
+ 7916
208
+ 10882
209
+ 7519
210
+ 9802
211
+ 4262
212
+ 4697
213
+ 10498
214
+ 4836
215
+ 7568
216
+ 8406
217
+ 4673
218
+ 5543
219
+ 6448
220
+ 6818
221
+ 7075
222
+ 10106
223
+ 4743
224
+ 5779
225
+ 9052
226
+ 9100
227
+ 11452
228
+ 9203
229
+ 232
230
+ 129
231
+ 11705
232
+ 3924
233
+ 4110
234
+ 2775
235
+ 10424
236
+ 4254
237
+ 153
238
+ 12729
239
+ 11522
240
+ 12384
241
+ 3645
242
+ 8064
243
+ 1817
244
+ 7204
245
+ 2115
246
+ 9209
247
+ 6685
248
+ 3951
249
+ 5877
250
+ 9518
251
+ 7617
252
+ 11777
253
+ 5418
254
+ 401
255
+ 2447
256
+ 10679
257
+ 961
258
+ 10032
259
+ 263
260
+ 5308
261
+ 12198
262
+ 6637
263
+ 5749
264
+ 4684
265
+ 11457
266
+ 4403
267
+ 10389
268
+ 2016
269
+ 10707
270
+ 12159
271
+ 11978
272
+ 2990
273
+ 1993
274
+ 12951
275
+ 1978
276
+ 8162
277
+ 3548
278
+ 7764
279
+ 3678
280
+ 10087
281
+ 1759
282
+ 865
283
+ 3791
284
+ 4399
285
+ 12408
286
+ 6950
287
+ 11981
288
+ 6749
289
+ 4764
290
+ 6274
291
+ 3184
292
+ 6574
293
+ 2128
294
+ 4103
295
+ 11835
296
+ 9963
297
+ 7927
298
+ 833
299
+ 7475
300
+ 11897
301
+ 12568
302
+ 9281
303
+ 8815
304
+ 729
305
+ 10501
306
+ 7441
307
+ 146
308
+ 6673
lists/banking77/size=32/seed=0/0.7-1.0.txt ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 4673
2
+ 5543
3
+ 6448
4
+ 6818
5
+ 7075
6
+ 10106
7
+ 4743
8
+ 5779
9
+ 9052
10
+ 9100
11
+ 11452
12
+ 9203
13
+ 232
14
+ 129
15
+ 11705
16
+ 3924
17
+ 4110
18
+ 2775
19
+ 10424
20
+ 4254
21
+ 153
22
+ 12729
23
+ 11522
24
+ 12384
25
+ 3645
26
+ 8064
27
+ 1817
28
+ 7204
29
+ 2115
30
+ 9209
31
+ 6685
32
+ 3951
33
+ 5877
34
+ 9518
35
+ 7617
36
+ 11777
37
+ 5418
38
+ 401
39
+ 2447
40
+ 10679
41
+ 961
42
+ 10032
43
+ 263
44
+ 5308
45
+ 12198
46
+ 6637
47
+ 5749
48
+ 4684
49
+ 11457
50
+ 4403
51
+ 10389
52
+ 2016
53
+ 10707
54
+ 12159
55
+ 11978
56
+ 2990
57
+ 1993
58
+ 12951
59
+ 1978
60
+ 8162
61
+ 3548
62
+ 7764
63
+ 3678
64
+ 10087
65
+ 1759
66
+ 865
67
+ 3791
68
+ 4399
69
+ 12408
70
+ 6950
71
+ 11981
72
+ 6749
73
+ 4764
74
+ 6274
75
+ 3184
76
+ 6574
77
+ 2128
78
+ 4103
79
+ 11835
80
+ 9963
81
+ 7927
82
+ 833
83
+ 7475
84
+ 11897
85
+ 12568
86
+ 9281
87
+ 8815
88
+ 729
89
+ 10501
90
+ 7441
91
+ 146
92
+ 6673
lists/banking77/size=32/seed=1/0.0-0.3.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 7371
2
+ 7722
3
+ 237
4
+ 1508
5
+ 8320
6
+ 10115
7
+ 10484
8
+ 7562
9
+ 11005
10
+ 11117
11
+ 12370
12
+ 8585
13
+ 10533
14
+ 1363
15
+ 7018
16
+ 6432
17
+ 12928
18
+ 12218
19
+ 3576
20
+ 9175
21
+ 3738
22
+ 9515
23
+ 9613
24
+ 6629
25
+ 5388
26
+ 3175
27
+ 5751
28
+ 10649
29
+ 1211
30
+ 10064
31
+ 3991
32
+ 656
33
+ 240
34
+ 8779
35
+ 9704
36
+ 3440
37
+ 4614
38
+ 6730
39
+ 12625
40
+ 8543
41
+ 8183
42
+ 1709
43
+ 8358
44
+ 2282
45
+ 12485
46
+ 10770
47
+ 12682
48
+ 12421
49
+ 10565
50
+ 6029
51
+ 5919
52
+ 4254
53
+ 325
54
+ 7110
55
+ 1783
56
+ 4549
57
+ 7609
58
+ 5403
59
+ 12647
60
+ 1573
61
+ 10490
62
+ 11056
63
+ 1163
64
+ 5285
65
+ 4996
66
+ 1065
67
+ 164
68
+ 9242
69
+ 11766
70
+ 4236
71
+ 8946
72
+ 445
73
+ 2060
74
+ 1272
75
+ 490
76
+ 11155
77
+ 870
78
+ 9137
79
+ 10214
80
+ 5465
81
+ 3050
82
+ 3521
83
+ 12526
84
+ 3693
85
+ 1604
86
+ 926
87
+ 12884
88
+ 7293
89
+ 8235
90
+ 11592
91
+ 9853
92
+ 12160
93
+ 3358
lists/banking77/size=32/seed=1/0.0-0.7.txt ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 7371
2
+ 7722
3
+ 237
4
+ 1508
5
+ 8320
6
+ 10115
7
+ 10484
8
+ 7562
9
+ 11005
10
+ 11117
11
+ 12370
12
+ 8585
13
+ 10533
14
+ 1363
15
+ 7018
16
+ 6432
17
+ 12928
18
+ 12218
19
+ 3576
20
+ 9175
21
+ 3738
22
+ 9515
23
+ 9613
24
+ 6629
25
+ 5388
26
+ 3175
27
+ 5751
28
+ 10649
29
+ 1211
30
+ 10064
31
+ 3991
32
+ 656
33
+ 240
34
+ 8779
35
+ 9704
36
+ 3440
37
+ 4614
38
+ 6730
39
+ 12625
40
+ 8543
41
+ 8183
42
+ 1709
43
+ 8358
44
+ 2282
45
+ 12485
46
+ 10770
47
+ 12682
48
+ 12421
49
+ 10565
50
+ 6029
51
+ 5919
52
+ 4254
53
+ 325
54
+ 7110
55
+ 1783
56
+ 4549
57
+ 7609
58
+ 5403
59
+ 12647
60
+ 1573
61
+ 10490
62
+ 11056
63
+ 1163
64
+ 5285
65
+ 4996
66
+ 1065
67
+ 164
68
+ 9242
69
+ 11766
70
+ 4236
71
+ 8946
72
+ 445
73
+ 2060
74
+ 1272
75
+ 490
76
+ 11155
77
+ 870
78
+ 9137
79
+ 10214
80
+ 5465
81
+ 3050
82
+ 3521
83
+ 12526
84
+ 3693
85
+ 1604
86
+ 926
87
+ 12884
88
+ 7293
89
+ 8235
90
+ 11592
91
+ 9853
92
+ 12160
93
+ 3358
94
+ 10800
95
+ 736
96
+ 6101
97
+ 6314
98
+ 213
99
+ 2236
100
+ 6328
101
+ 7958
102
+ 1900
103
+ 8852
104
+ 8099
105
+ 825
106
+ 9076
107
+ 2657
108
+ 379
109
+ 467
110
+ 5597
111
+ 1295
112
+ 12639
113
+ 11607
114
+ 6065
115
+ 1862
116
+ 2118
117
+ 7983
118
+ 259
119
+ 1082
120
+ 2765
121
+ 3786
122
+ 7856
123
+ 3212
124
+ 4324
125
+ 48
126
+ 7946
127
+ 6557
128
+ 6112
129
+ 12501
130
+ 11824
131
+ 113
132
+ 12887
133
+ 11921
134
+ 7881
135
+ 127
136
+ 1918
137
+ 12086
138
+ 13011
139
+ 12008
140
+ 10473
141
+ 12614
142
+ 1764
143
+ 12801
144
+ 6706
145
+ 12657
146
+ 4177
147
+ 12220
148
+ 5716
149
+ 6404
150
+ 10950
151
+ 5222
152
+ 8941
153
+ 3074
154
+ 9228
155
+ 4702
156
+ 1070
157
+ 5390
158
+ 9078
159
+ 7906
160
+ 648
161
+ 10184
162
+ 6956
163
+ 3729
164
+ 4558
165
+ 9058
166
+ 11021
167
+ 10202
168
+ 2161
169
+ 661
170
+ 8950
171
+ 1964
172
+ 6264
173
+ 7259
174
+ 2816
175
+ 2238
176
+ 4006
177
+ 12610
178
+ 6687
179
+ 6103
180
+ 6295
181
+ 1580
182
+ 1434
183
+ 2363
184
+ 10470
185
+ 7450
186
+ 770
187
+ 12124
188
+ 3476
189
+ 4998
190
+ 3026
191
+ 7244
192
+ 3125
193
+ 9718
194
+ 979
195
+ 3438
196
+ 1879
197
+ 7735
198
+ 1673
199
+ 8734
200
+ 7242
201
+ 5291
202
+ 10405
203
+ 3962
204
+ 5945
205
+ 8776
206
+ 1471
207
+ 5522
208
+ 12474
209
+ 1279
210
+ 6535
211
+ 12479
212
+ 1574
213
+ 11172
214
+ 2583
215
+ 6136
216
+ 2731
lists/banking77/size=32/seed=1/0.0-1.0.txt ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 7371
2
+ 7722
3
+ 237
4
+ 1508
5
+ 8320
6
+ 10115
7
+ 10484
8
+ 7562
9
+ 11005
10
+ 11117
11
+ 12370
12
+ 8585
13
+ 10533
14
+ 1363
15
+ 7018
16
+ 6432
17
+ 12928
18
+ 12218
19
+ 3576
20
+ 9175
21
+ 3738
22
+ 9515
23
+ 9613
24
+ 6629
25
+ 5388
26
+ 3175
27
+ 5751
28
+ 10649
29
+ 1211
30
+ 10064
31
+ 3991
32
+ 656
33
+ 240
34
+ 8779
35
+ 9704
36
+ 3440
37
+ 4614
38
+ 6730
39
+ 12625
40
+ 8543
41
+ 8183
42
+ 1709
43
+ 8358
44
+ 2282
45
+ 12485
46
+ 10770
47
+ 12682
48
+ 12421
49
+ 10565
50
+ 6029
51
+ 5919
52
+ 4254
53
+ 325
54
+ 7110
55
+ 1783
56
+ 4549
57
+ 7609
58
+ 5403
59
+ 12647
60
+ 1573
61
+ 10490
62
+ 11056
63
+ 1163
64
+ 5285
65
+ 4996
66
+ 1065
67
+ 164
68
+ 9242
69
+ 11766
70
+ 4236
71
+ 8946
72
+ 445
73
+ 2060
74
+ 1272
75
+ 490
76
+ 11155
77
+ 870
78
+ 9137
79
+ 10214
80
+ 5465
81
+ 3050
82
+ 3521
83
+ 12526
84
+ 3693
85
+ 1604
86
+ 926
87
+ 12884
88
+ 7293
89
+ 8235
90
+ 11592
91
+ 9853
92
+ 12160
93
+ 3358
94
+ 10800
95
+ 736
96
+ 6101
97
+ 6314
98
+ 213
99
+ 2236
100
+ 6328
101
+ 7958
102
+ 1900
103
+ 8852
104
+ 8099
105
+ 825
106
+ 9076
107
+ 2657
108
+ 379
109
+ 467
110
+ 5597
111
+ 1295
112
+ 12639
113
+ 11607
114
+ 6065
115
+ 1862
116
+ 2118
117
+ 7983
118
+ 259
119
+ 1082
120
+ 2765
121
+ 3786
122
+ 7856
123
+ 3212
124
+ 4324
125
+ 48
126
+ 7946
127
+ 6557
128
+ 6112
129
+ 12501
130
+ 11824
131
+ 113
132
+ 12887
133
+ 11921
134
+ 7881
135
+ 127
136
+ 1918
137
+ 12086
138
+ 13011
139
+ 12008
140
+ 10473
141
+ 12614
142
+ 1764
143
+ 12801
144
+ 6706
145
+ 12657
146
+ 4177
147
+ 12220
148
+ 5716
149
+ 6404
150
+ 10950
151
+ 5222
152
+ 8941
153
+ 3074
154
+ 9228
155
+ 4702
156
+ 1070
157
+ 5390
158
+ 9078
159
+ 7906
160
+ 648
161
+ 10184
162
+ 6956
163
+ 3729
164
+ 4558
165
+ 9058
166
+ 11021
167
+ 10202
168
+ 2161
169
+ 661
170
+ 8950
171
+ 1964
172
+ 6264
173
+ 7259
174
+ 2816
175
+ 2238
176
+ 4006
177
+ 12610
178
+ 6687
179
+ 6103
180
+ 6295
181
+ 1580
182
+ 1434
183
+ 2363
184
+ 10470
185
+ 7450
186
+ 770
187
+ 12124
188
+ 3476
189
+ 4998
190
+ 3026
191
+ 7244
192
+ 3125
193
+ 9718
194
+ 979
195
+ 3438
196
+ 1879
197
+ 7735
198
+ 1673
199
+ 8734
200
+ 7242
201
+ 5291
202
+ 10405
203
+ 3962
204
+ 5945
205
+ 8776
206
+ 1471
207
+ 5522
208
+ 12474
209
+ 1279
210
+ 6535
211
+ 12479
212
+ 1574
213
+ 11172
214
+ 2583
215
+ 6136
216
+ 2731
217
+ 10635
218
+ 2969
219
+ 1296
220
+ 4234
221
+ 4315
222
+ 8738
223
+ 1136
224
+ 10197
225
+ 11782
226
+ 5045
227
+ 10508
228
+ 4615
229
+ 1168
230
+ 7283
231
+ 3084
232
+ 7959
233
+ 6021
234
+ 12930
235
+ 11934
236
+ 4737
237
+ 6411
238
+ 2343
239
+ 10562
240
+ 1985
241
+ 8863
242
+ 10839
243
+ 5481
244
+ 3007
245
+ 9785
246
+ 3434
247
+ 4022
248
+ 2037
249
+ 11609
250
+ 348
251
+ 3069
252
+ 11783
253
+ 4367
254
+ 6096
255
+ 12665
256
+ 3573
257
+ 9385
258
+ 12224
259
+ 5476
260
+ 188
261
+ 7511
262
+ 10482
263
+ 2503
264
+ 982
265
+ 3357
266
+ 10371
267
+ 3766
268
+ 3403
269
+ 7971
270
+ 10859
271
+ 10099
272
+ 5980
273
+ 9315
274
+ 5394
275
+ 1005
276
+ 1572
277
+ 8014
278
+ 3843
279
+ 8243
280
+ 7102
281
+ 2266
282
+ 9030
283
+ 8820
284
+ 1634
285
+ 1287
286
+ 11855
287
+ 5641
288
+ 4943
289
+ 1978
290
+ 9467
291
+ 7116
292
+ 2100
293
+ 205
294
+ 9279
295
+ 12274
296
+ 12322
297
+ 9549
298
+ 10845
299
+ 1217
300
+ 12860
301
+ 4173
302
+ 8858
303
+ 1624
304
+ 11055
305
+ 292
306
+ 4825
307
+ 1246
308
+ 7758
lists/banking77/size=32/seed=1/0.7-1.0.txt ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 10635
2
+ 2969
3
+ 1296
4
+ 4234
5
+ 4315
6
+ 8738
7
+ 1136
8
+ 10197
9
+ 11782
10
+ 5045
11
+ 10508
12
+ 4615
13
+ 1168
14
+ 7283
15
+ 3084
16
+ 7959
17
+ 6021
18
+ 12930
19
+ 11934
20
+ 4737
21
+ 6411
22
+ 2343
23
+ 10562
24
+ 1985
25
+ 8863
26
+ 10839
27
+ 5481
28
+ 3007
29
+ 9785
30
+ 3434
31
+ 4022
32
+ 2037
33
+ 11609
34
+ 348
35
+ 3069
36
+ 11783
37
+ 4367
38
+ 6096
39
+ 12665
40
+ 3573
41
+ 9385
42
+ 12224
43
+ 5476
44
+ 188
45
+ 7511
46
+ 10482
47
+ 2503
48
+ 982
49
+ 3357
50
+ 10371
51
+ 3766
52
+ 3403
53
+ 7971
54
+ 10859
55
+ 10099
56
+ 5980
57
+ 9315
58
+ 5394
59
+ 1005
60
+ 1572
61
+ 8014
62
+ 3843
63
+ 8243
64
+ 7102
65
+ 2266
66
+ 9030
67
+ 8820
68
+ 1634
69
+ 1287
70
+ 11855
71
+ 5641
72
+ 4943
73
+ 1978
74
+ 9467
75
+ 7116
76
+ 2100
77
+ 205
78
+ 9279
79
+ 12274
80
+ 12322
81
+ 9549
82
+ 10845
83
+ 1217
84
+ 12860
85
+ 4173
86
+ 8858
87
+ 1624
88
+ 11055
89
+ 292
90
+ 4825
91
+ 1246
92
+ 7758
lists/banking77/size=32/seed=5/0.0-0.3.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 9151
2
+ 4098
3
+ 12941
4
+ 8514
5
+ 2177
6
+ 3841
7
+ 3254
8
+ 11480
9
+ 9917
10
+ 6270
11
+ 9678
12
+ 2985
13
+ 8279
14
+ 11508
15
+ 9748
16
+ 10502
17
+ 6084
18
+ 10441
19
+ 8371
20
+ 4804
21
+ 12406
22
+ 8361
23
+ 12765
24
+ 2900
25
+ 1513
26
+ 11047
27
+ 100
28
+ 6154
29
+ 9317
30
+ 1092
31
+ 9808
32
+ 6627
33
+ 4111
34
+ 8865
35
+ 3155
36
+ 3062
37
+ 6220
38
+ 10334
39
+ 6618
40
+ 8101
41
+ 10724
42
+ 4806
43
+ 6714
44
+ 6002
45
+ 12015
46
+ 1501
47
+ 2733
48
+ 8538
49
+ 3792
50
+ 633
51
+ 3498
52
+ 3557
53
+ 2328
54
+ 9128
55
+ 9832
56
+ 8211
57
+ 10844
58
+ 13041
59
+ 5174
60
+ 3227
61
+ 10922
62
+ 1765
63
+ 11687
64
+ 5293
65
+ 6805
66
+ 4557
67
+ 9247
68
+ 12375
69
+ 3946
70
+ 9754
71
+ 2673
72
+ 9204
73
+ 7378
74
+ 7888
75
+ 8893
76
+ 9805
77
+ 11982
78
+ 5847
79
+ 5242
80
+ 11006
81
+ 954
82
+ 8193
83
+ 8644
84
+ 3074
85
+ 12140
86
+ 9429
87
+ 2137
88
+ 8592
89
+ 5615
90
+ 8239
91
+ 10454
92
+ 1637
93
+ 12218
lists/banking77/size=32/seed=5/0.0-0.7.txt ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 9151
2
+ 4098
3
+ 12941
4
+ 8514
5
+ 2177
6
+ 3841
7
+ 3254
8
+ 11480
9
+ 9917
10
+ 6270
11
+ 9678
12
+ 2985
13
+ 8279
14
+ 11508
15
+ 9748
16
+ 10502
17
+ 6084
18
+ 10441
19
+ 8371
20
+ 4804
21
+ 12406
22
+ 8361
23
+ 12765
24
+ 2900
25
+ 1513
26
+ 11047
27
+ 100
28
+ 6154
29
+ 9317
30
+ 1092
31
+ 9808
32
+ 6627
33
+ 4111
34
+ 8865
35
+ 3155
36
+ 3062
37
+ 6220
38
+ 10334
39
+ 6618
40
+ 8101
41
+ 10724
42
+ 4806
43
+ 6714
44
+ 6002
45
+ 12015
46
+ 1501
47
+ 2733
48
+ 8538
49
+ 3792
50
+ 633
51
+ 3498
52
+ 3557
53
+ 2328
54
+ 9128
55
+ 9832
56
+ 8211
57
+ 10844
58
+ 13041
59
+ 5174
60
+ 3227
61
+ 10922
62
+ 1765
63
+ 11687
64
+ 5293
65
+ 6805
66
+ 4557
67
+ 9247
68
+ 12375
69
+ 3946
70
+ 9754
71
+ 2673
72
+ 9204
73
+ 7378
74
+ 7888
75
+ 8893
76
+ 9805
77
+ 11982
78
+ 5847
79
+ 5242
80
+ 11006
81
+ 954
82
+ 8193
83
+ 8644
84
+ 3074
85
+ 12140
86
+ 9429
87
+ 2137
88
+ 8592
89
+ 5615
90
+ 8239
91
+ 10454
92
+ 1637
93
+ 12218
94
+ 7735
95
+ 9755
96
+ 7741
97
+ 2512
98
+ 4198
99
+ 12283
100
+ 2710
101
+ 4701
102
+ 7681
103
+ 10950
104
+ 8175
105
+ 864
106
+ 5084
107
+ 5228
108
+ 11493
109
+ 8104
110
+ 5961
111
+ 3719
112
+ 9848
113
+ 5783
114
+ 5543
115
+ 5856
116
+ 3459
117
+ 12433
118
+ 8383
119
+ 1565
120
+ 3611
121
+ 5234
122
+ 997
123
+ 8520
124
+ 5371
125
+ 9315
126
+ 1886
127
+ 3141
128
+ 9721
129
+ 3008
130
+ 5945
131
+ 9011
132
+ 12214
133
+ 8162
134
+ 11645
135
+ 5735
136
+ 12287
137
+ 12118
138
+ 4269
139
+ 5197
140
+ 10857
141
+ 7711
142
+ 10184
143
+ 4645
144
+ 12987
145
+ 8465
146
+ 8599
147
+ 12509
148
+ 4488
149
+ 1249
150
+ 6350
151
+ 4856
152
+ 4026
153
+ 7495
154
+ 11101
155
+ 108
156
+ 5663
157
+ 6254
158
+ 12611
159
+ 10597
160
+ 13065
161
+ 2965
162
+ 8648
163
+ 6900
164
+ 11077
165
+ 8585
166
+ 2450
167
+ 2749
168
+ 2738
169
+ 5239
170
+ 8460
171
+ 3725
172
+ 12718
173
+ 10914
174
+ 4596
175
+ 8745
176
+ 3406
177
+ 5527
178
+ 2
179
+ 3099
180
+ 3675
181
+ 8331
182
+ 9685
183
+ 12838
184
+ 5506
185
+ 5075
186
+ 10994
187
+ 1966
188
+ 3036
189
+ 34
190
+ 8000
191
+ 11625
192
+ 9500
193
+ 5456
194
+ 12236
195
+ 2497
196
+ 10170
197
+ 7396
198
+ 5751
199
+ 2984
200
+ 6809
201
+ 1186
202
+ 2322
203
+ 6489
204
+ 7885
205
+ 11127
206
+ 6145
207
+ 2858
208
+ 5625
209
+ 6044
210
+ 10496
211
+ 4756
212
+ 11825
213
+ 5291
214
+ 7115
215
+ 9199
216
+ 8223
lists/banking77/size=32/seed=5/0.7-1.0.txt ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 3920
2
+ 9749
3
+ 9549
4
+ 12269
5
+ 1941
6
+ 9574
7
+ 9130
8
+ 12726
9
+ 9993
10
+ 6748
11
+ 12666
12
+ 5323
13
+ 10057
14
+ 12294
15
+ 7938
16
+ 9448
17
+ 10814
18
+ 11019
19
+ 5905
20
+ 8576
21
+ 4598
22
+ 5675
23
+ 6820
24
+ 816
25
+ 3649
26
+ 7154
27
+ 11695
28
+ 2652
29
+ 5629
30
+ 10227
31
+ 6224
32
+ 1862
33
+ 2682
34
+ 12522
35
+ 7695
36
+ 4294
37
+ 11061
38
+ 11293
39
+ 9359
40
+ 8580
41
+ 151
42
+ 11078
43
+ 9792
44
+ 12324
45
+ 11690
46
+ 6309
47
+ 11451
48
+ 1346
49
+ 12969
50
+ 4896
51
+ 1653
52
+ 6142
53
+ 4962
54
+ 9261
55
+ 230
56
+ 5327
57
+ 7936
58
+ 7578
59
+ 6236
60
+ 2864
61
+ 7403
62
+ 12023
63
+ 1702
64
+ 5954
65
+ 10963
66
+ 8016
67
+ 4350
68
+ 12573
69
+ 6610
70
+ 3876
71
+ 1709
72
+ 3865
73
+ 88
74
+ 8993
75
+ 577
76
+ 10642
77
+ 1949
78
+ 2516
79
+ 3246
80
+ 8212
81
+ 10051
82
+ 5230
83
+ 2702
84
+ 12963
85
+ 5984
86
+ 4736
87
+ 3007
88
+ 2795
89
+ 7801
90
+ 7548
91
+ 9273
92
+ 4867
lists/banking77/size=32/seed=7/0.0-0.3.txt ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 13000
2
+ 7386
3
+ 6913
4
+ 6923
5
+ 4771
6
+ 8791
7
+ 7496
8
+ 6009
9
+ 9291
10
+ 8127
11
+ 8364
12
+ 2058
13
+ 5769
14
+ 263
15
+ 497
16
+ 7121
17
+ 5646
18
+ 5452
19
+ 8807
20
+ 11304
21
+ 10972
22
+ 11111
23
+ 11112
24
+ 1455
25
+ 12932
26
+ 73
27
+ 11694
28
+ 37
29
+ 2071
30
+ 6805
31
+ 9312
32
+ 1363
33
+ 1070
34
+ 5918
35
+ 4126
36
+ 4370
37
+ 460
38
+ 188
39
+ 861
40
+ 5737
41
+ 9981
42
+ 11911
43
+ 7569
44
+ 8347
45
+ 4018
46
+ 1714
47
+ 9266
48
+ 11128
49
+ 4921
50
+ 9072
51
+ 6213
52
+ 9223
53
+ 8050
54
+ 4225
55
+ 11824
56
+ 6687
57
+ 9959
58
+ 5900
59
+ 3696
60
+ 5504
61
+ 11545
62
+ 5731
63
+ 9731
64
+ 11566
65
+ 9910
66
+ 8338
67
+ 12669
68
+ 1439
69
+ 9584
70
+ 6315
71
+ 11492
72
+ 5969
73
+ 10499
74
+ 10993
75
+ 9847
76
+ 6115
77
+ 4718
78
+ 62
79
+ 9150
80
+ 8352
81
+ 11740
82
+ 4990
83
+ 9997
84
+ 405
85
+ 8467
86
+ 783
87
+ 2879
88
+ 4409
89
+ 11641
90
+ 1719
91
+ 7403
92
+ 9560
93
+ 10739
lists/banking77/size=32/seed=7/0.0-0.7.txt ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 13000
2
+ 7386
3
+ 6913
4
+ 6923
5
+ 4771
6
+ 8791
7
+ 7496
8
+ 6009
9
+ 9291
10
+ 8127
11
+ 8364
12
+ 2058
13
+ 5769
14
+ 263
15
+ 497
16
+ 7121
17
+ 5646
18
+ 5452
19
+ 8807
20
+ 11304
21
+ 10972
22
+ 11111
23
+ 11112
24
+ 1455
25
+ 12932
26
+ 73
27
+ 11694
28
+ 37
29
+ 2071
30
+ 6805
31
+ 9312
32
+ 1363
33
+ 1070
34
+ 5918
35
+ 4126
36
+ 4370
37
+ 460
38
+ 188
39
+ 861
40
+ 5737
41
+ 9981
42
+ 11911
43
+ 7569
44
+ 8347
45
+ 4018
46
+ 1714
47
+ 9266
48
+ 11128
49
+ 4921
50
+ 9072
51
+ 6213
52
+ 9223
53
+ 8050
54
+ 4225
55
+ 11824
56
+ 6687
57
+ 9959
58
+ 5900
59
+ 3696
60
+ 5504
61
+ 11545
62
+ 5731
63
+ 9731
64
+ 11566
65
+ 9910
66
+ 8338
67
+ 12669
68
+ 1439
69
+ 9584
70
+ 6315
71
+ 11492
72
+ 5969
73
+ 10499
74
+ 10993
75
+ 9847
76
+ 6115
77
+ 4718
78
+ 62
79
+ 9150
80
+ 8352
81
+ 11740
82
+ 4990
83
+ 9997
84
+ 405
85
+ 8467
86
+ 783
87
+ 2879
88
+ 4409
89
+ 11641
90
+ 1719
91
+ 7403
92
+ 9560
93
+ 10739
94
+ 7914
95
+ 2178
96
+ 6281
97
+ 3317
98
+ 8013
99
+ 10323
100
+ 7974
101
+ 6993
102
+ 1124
103
+ 9135
104
+ 2727
105
+ 3173
106
+ 3187
107
+ 3997
108
+ 12941
109
+ 10080
110
+ 3557
111
+ 1845
112
+ 390
113
+ 6406
114
+ 1058
115
+ 3439
116
+ 6828
117
+ 2593
118
+ 8350
119
+ 6862
120
+ 11809
121
+ 10470
122
+ 3086
123
+ 2048
124
+ 4366
125
+ 6729
126
+ 12244
127
+ 8945
128
+ 6469
129
+ 2143
130
+ 8790
131
+ 1252
132
+ 12153
133
+ 10093
134
+ 5914
135
+ 6056
136
+ 8720
137
+ 1809
138
+ 11414
139
+ 139
140
+ 4808
141
+ 3016
142
+ 8704
143
+ 11306
144
+ 9157
145
+ 5233
146
+ 1459
147
+ 98
148
+ 2449
149
+ 11750
150
+ 4541
151
+ 1272
152
+ 7637
153
+ 8616
154
+ 7205
155
+ 8599
156
+ 12872
157
+ 4083
158
+ 8591
159
+ 6337
160
+ 5711
161
+ 5771
162
+ 9057
163
+ 11667
164
+ 9548
165
+ 10941
166
+ 11294
167
+ 9670
168
+ 6073
169
+ 925
170
+ 4463
171
+ 2425
172
+ 11915
173
+ 2232
174
+ 6041
175
+ 2282
176
+ 12767
177
+ 2191
178
+ 6649
179
+ 11067
180
+ 10988
181
+ 4690
182
+ 10717
183
+ 288
184
+ 5403
185
+ 2116
186
+ 10815
187
+ 2249
188
+ 6329
189
+ 7290
190
+ 10531
191
+ 12888
192
+ 13071
193
+ 10318
194
+ 8373
195
+ 4462
196
+ 6876
197
+ 7204
198
+ 7362
199
+ 2835
200
+ 8353
201
+ 4432
202
+ 11354
203
+ 8852
204
+ 4629
205
+ 12266
206
+ 8970
207
+ 10152
208
+ 56
209
+ 3277
210
+ 4593
211
+ 13077
212
+ 6348
213
+ 9217
214
+ 2934
215
+ 9546
216
+ 2161
lists/banking77/size=32/seed=7/0.0-1.0.txt ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 13000
2
+ 7386
3
+ 6913
4
+ 6923
5
+ 4771
6
+ 8791
7
+ 7496
8
+ 6009
9
+ 9291
10
+ 8127
11
+ 8364
12
+ 2058
13
+ 5769
14
+ 263
15
+ 497
16
+ 7121
17
+ 5646
18
+ 5452
19
+ 8807
20
+ 11304
21
+ 10972
22
+ 11111
23
+ 11112
24
+ 1455
25
+ 12932
26
+ 73
27
+ 11694
28
+ 37
29
+ 2071
30
+ 6805
31
+ 9312
32
+ 1363
33
+ 1070
34
+ 5918
35
+ 4126
36
+ 4370
37
+ 460
38
+ 188
39
+ 861
40
+ 5737
41
+ 9981
42
+ 11911
43
+ 7569
44
+ 8347
45
+ 4018
46
+ 1714
47
+ 9266
48
+ 11128
49
+ 4921
50
+ 9072
51
+ 6213
52
+ 9223
53
+ 8050
54
+ 4225
55
+ 11824
56
+ 6687
57
+ 9959
58
+ 5900
59
+ 3696
60
+ 5504
61
+ 11545
62
+ 5731
63
+ 9731
64
+ 11566
65
+ 9910
66
+ 8338
67
+ 12669
68
+ 1439
69
+ 9584
70
+ 6315
71
+ 11492
72
+ 5969
73
+ 10499
74
+ 10993
75
+ 9847
76
+ 6115
77
+ 4718
78
+ 62
79
+ 9150
80
+ 8352
81
+ 11740
82
+ 4990
83
+ 9997
84
+ 405
85
+ 8467
86
+ 783
87
+ 2879
88
+ 4409
89
+ 11641
90
+ 1719
91
+ 7403
92
+ 9560
93
+ 10739
94
+ 7914
95
+ 2178
96
+ 6281
97
+ 3317
98
+ 8013
99
+ 10323
100
+ 7974
101
+ 6993
102
+ 1124
103
+ 9135
104
+ 2727
105
+ 3173
106
+ 3187
107
+ 3997
108
+ 12941
109
+ 10080
110
+ 3557
111
+ 1845
112
+ 390
113
+ 6406
114
+ 1058
115
+ 3439
116
+ 6828
117
+ 2593
118
+ 8350
119
+ 6862
120
+ 11809
121
+ 10470
122
+ 3086
123
+ 2048
124
+ 4366
125
+ 6729
126
+ 12244
127
+ 8945
128
+ 6469
129
+ 2143
130
+ 8790
131
+ 1252
132
+ 12153
133
+ 10093
134
+ 5914
135
+ 6056
136
+ 8720
137
+ 1809
138
+ 11414
139
+ 139
140
+ 4808
141
+ 3016
142
+ 8704
143
+ 11306
144
+ 9157
145
+ 5233
146
+ 1459
147
+ 98
148
+ 2449
149
+ 11750
150
+ 4541
151
+ 1272
152
+ 7637
153
+ 8616
154
+ 7205
155
+ 8599
156
+ 12872
157
+ 4083
158
+ 8591
159
+ 6337
160
+ 5711
161
+ 5771
162
+ 9057
163
+ 11667
164
+ 9548
165
+ 10941
166
+ 11294
167
+ 9670
168
+ 6073
169
+ 925
170
+ 4463
171
+ 2425
172
+ 11915
173
+ 2232
174
+ 6041
175
+ 2282
176
+ 12767
177
+ 2191
178
+ 6649
179
+ 11067
180
+ 10988
181
+ 4690
182
+ 10717
183
+ 288
184
+ 5403
185
+ 2116
186
+ 10815
187
+ 2249
188
+ 6329
189
+ 7290
190
+ 10531
191
+ 12888
192
+ 13071
193
+ 10318
194
+ 8373
195
+ 4462
196
+ 6876
197
+ 7204
198
+ 7362
199
+ 2835
200
+ 8353
201
+ 4432
202
+ 11354
203
+ 8852
204
+ 4629
205
+ 12266
206
+ 8970
207
+ 10152
208
+ 56
209
+ 3277
210
+ 4593
211
+ 13077
212
+ 6348
213
+ 9217
214
+ 2934
215
+ 9546
216
+ 2161
217
+ 3302
218
+ 11311
219
+ 6134
220
+ 10786
221
+ 4451
222
+ 3519
223
+ 10932
224
+ 6309
225
+ 4710
226
+ 5751
227
+ 231
228
+ 8558
229
+ 1275
230
+ 154
231
+ 11966
232
+ 12113
233
+ 6060
234
+ 7269
235
+ 2979
236
+ 7270
237
+ 11919
238
+ 5222
239
+ 88
240
+ 1592
241
+ 8725
242
+ 6583
243
+ 4792
244
+ 2713
245
+ 9258
246
+ 11816
247
+ 2268
248
+ 7014
249
+ 10837
250
+ 9493
251
+ 219
252
+ 10660
253
+ 11781
254
+ 7854
255
+ 3742
256
+ 7040
257
+ 11961
258
+ 39
259
+ 12412
260
+ 6119
261
+ 12132
262
+ 2897
263
+ 12583
264
+ 7671
265
+ 5126
266
+ 11689
267
+ 1107
268
+ 5472
269
+ 10630
270
+ 7562
271
+ 8901
272
+ 179
273
+ 8693
274
+ 3908
275
+ 9583
276
+ 8069
277
+ 1847
278
+ 902
279
+ 421
280
+ 7544
281
+ 8953
282
+ 9438
283
+ 11537
284
+ 8004
285
+ 11547
286
+ 12557
287
+ 8439
288
+ 349
289
+ 8924
290
+ 4111
291
+ 228
292
+ 10192
293
+ 323
294
+ 10135
295
+ 12743
296
+ 2137
297
+ 10546
298
+ 10814
299
+ 1490
300
+ 7723
301
+ 6345
302
+ 6475
303
+ 3069
304
+ 9827
305
+ 1064
306
+ 1532
307
+ 4926
308
+ 4797
lists/banking77/size=32/seed=7/0.7-1.0.txt ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 3302
2
+ 11311
3
+ 6134
4
+ 10786
5
+ 4451
6
+ 3519
7
+ 10932
8
+ 6309
9
+ 4710
10
+ 5751
11
+ 231
12
+ 8558
13
+ 1275
14
+ 154
15
+ 11966
16
+ 12113
17
+ 6060
18
+ 7269
19
+ 2979
20
+ 7270
21
+ 11919
22
+ 5222
23
+ 88
24
+ 1592
25
+ 8725
26
+ 6583
27
+ 4792
28
+ 2713
29
+ 9258
30
+ 11816
31
+ 2268
32
+ 7014
33
+ 10837
34
+ 9493
35
+ 219
36
+ 10660
37
+ 11781
38
+ 7854
39
+ 3742
40
+ 7040
41
+ 11961
42
+ 39
43
+ 12412
44
+ 6119
45
+ 12132
46
+ 2897
47
+ 12583
48
+ 7671
49
+ 5126
50
+ 11689
51
+ 1107
52
+ 5472
53
+ 10630
54
+ 7562
55
+ 8901
56
+ 179
57
+ 8693
58
+ 3908
59
+ 9583
60
+ 8069
61
+ 1847
62
+ 902
63
+ 421
64
+ 7544
65
+ 8953
66
+ 9438
67
+ 11537
68
+ 8004
69
+ 11547
70
+ 12557
71
+ 8439
72
+ 349
73
+ 8924
74
+ 4111
75
+ 228
76
+ 10192
77
+ 323
78
+ 10135
79
+ 12743
80
+ 2137
81
+ 10546
82
+ 10814
83
+ 1490
84
+ 7723
85
+ 6345
86
+ 6475
87
+ 3069
88
+ 9827
89
+ 1064
90
+ 1532
91
+ 4926
92
+ 4797
prompts/basic_20newsgroups.yaml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompt_template: "Determine the category of the posted document given by the user."
2
+ answers_templates:
3
+ - "Atheism"
4
+ - "Graphics"
5
+ - "Microsoft"
6
+ - "IBM Hardware"
7
+ - "Mac Hardware"
8
+ - "X Window System"
9
+ - "Sales"
10
+ - "Cars"
11
+ - "Motorcycles"
12
+ - "Baseball"
13
+ - "Hockey"
14
+ - "Cryptography"
15
+ - "Electronics"
16
+ - "Medicine"
17
+ - "Space"
18
+ - "Christianity"
19
+ - "Guns"
20
+ - "Middle East"
21
+ - "Politics"
22
+ - "Religion"
prompts/basic_agnews.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ prompt_template: "Determine the category of the news article given by the user."
2
+ answers_templates:
3
+ - "World"
4
+ - "Sports"
5
+ - "Business"
6
+ - "Science and Technology"
prompts/basic_banking77.yaml ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompt_template: "Classify the intent of the question input by the user."
2
+ answers_templates:
3
+ - "Active my card"
4
+ - "Age limit"
5
+ - "Apple pay or google pay"
6
+ - "ATM support"
7
+ - "Automatic top up"
8
+ - "Balance not updated after bank transfer"
9
+ - "Balance not updated after cheque or cash deposit"
10
+ - "Beneficiary not allowed"
11
+ - "Cancel transfer"
12
+ - "Card about to expire"
13
+ - "Card acceptance"
14
+ - "Card arrival"
15
+ - "Card delivery estimate"
16
+ - "Card linking"
17
+ - "Card not working"
18
+ - "Card payment fee charged"
19
+ - "Card payment not recognised"
20
+ - "Card payment wrong exchange rate"
21
+ - "Card swallowed"
22
+ - "Cash withdrawal charge"
23
+ - "Cash withdrawal not recognised"
24
+ - "Change pin"
25
+ - "Compromised card"
26
+ - "Contactless not working"
27
+ - "Country support"
28
+ - "Declined card payment"
29
+ - "Declined cash withdrawal"
30
+ - "Declined transfer"
31
+ - "Direct debit payment not recognised"
32
+ - "Disposable card limits"
33
+ - "Edit personal details"
34
+ - "Exchange charge"
35
+ - "Exchange rate"
36
+ - "Exchange via app"
37
+ - "Extra charge on statement"
38
+ - "Failed transfer"
39
+ - "Fiat currency support"
40
+ - "Get disposable virtual card"
41
+ - "Get physical card"
42
+ - "Getting spare card"
43
+ - "Getting virtual card"
44
+ - "Lost or stolen card"
45
+ - "Lost or stolen phone"
46
+ - "Order physical card"
47
+ - "Passcode forgotten"
48
+ - "Pending card payment"
49
+ - "Pending cash withdrawal"
50
+ - "Pending top up"
51
+ - "Pending transfer"
52
+ - "Pin blocked"
53
+ - "Receiving money"
54
+ - "Refund not showing up"
55
+ - "Request refund"
56
+ - "Reverted card payment?"
57
+ - "Supported cards and currencies"
58
+ - "Terminate account"
59
+ - "Top up by bank transfer charge"
60
+ - "Top up by card charge"
61
+ - "Top up by cash or cheque"
62
+ - "Top up failed"
63
+ - "Top up limits"
64
+ - "Top up reverted"
65
+ - "Topping up by card"
66
+ - "Transaction charged twice"
67
+ - "Transfer fee charged"
68
+ - "Transfer into account"
69
+ - "Transfer not received by recipient"
70
+ - "Transfer timing"
71
+ - "Unable to verify identity"
72
+ - "Verify my identity"
73
+ - "Verify source of funds"
74
+ - "Verify top up"
75
+ - "Virtual card not working"
76
+ - "Visa or mastercard"
77
+ - "Why verify identity"
78
+ - "Wrong amount of cash received"
79
+ - "Wrong exchange rate for cash withdrawal"
prompts/basic_dbpedia.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ prompt_template: "Determine the category of the article given by the user."
2
+ answers_templates:
3
+ - "Company"
4
+ - "Educational Institution"
5
+ - "Artist"
6
+ - "Athlete"
7
+ - "Office Holder"
8
+ - "Mean Of Transportation"
9
+ - "Building"
10
+ - "Natural Place"
11
+ - "Village"
12
+ - "Animal"
13
+ - "Plant"
14
+ - "Album"
15
+ - "Film"
16
+ - "Written Work"
prompts/basic_sst2.yaml ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ prompt_template: "Determine if the following review is positive or negative, based on the input given by the user."
2
+ answers_templates:
3
+ - "Negative"
4
+ - "Positive"
src/llmcal/__init__.py ADDED
File without changes
src/llmcal/scripts/__init__.py ADDED
File without changes
src/llmcal/scripts/affine_calibration.old.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import warnings
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+ from torch.optim.lbfgs import LBFGS
9
+ import torch.nn.functional as F
10
+ from typing import Literal
11
+
12
+ from ..src.loggers import TBLogger, CSVLogger
13
+ from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold, GroupKFold, KFold
14
+
15
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*Experiment logs directory outputs*")
16
+
17
+
18
+
19
+ class AffineCalibrator(torch.nn.Module):
20
+
21
+ def __init__(self, method: str, num_classes: int):
22
+ super().__init__()
23
+ self.method = method
24
+ self.num_classes = num_classes
25
+ self._init_params(method)
26
+
27
+ def _init_params(self, method):
28
+ if method == "dp_calibration":
29
+ self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=True)
30
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
31
+ elif method == "vector_scaling":
32
+ self.alpha = torch.nn.Parameter(torch.ones(self.num_classes), requires_grad=True)
33
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
34
+ elif method == "temp_scaling":
35
+ self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=True)
36
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False)
37
+ elif method == "bias_shift":
38
+ self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=False)
39
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
40
+ else:
41
+ raise ValueError(f"Invalid method: {method}")
42
+
43
+ def forward(self, logits):
44
+ return logits * self.alpha + self.beta
45
+
46
+
47
+
48
+ def main(
49
+ output_dir: str = 'output',
50
+ log_dir: str = 'output/logs',
51
+ checkpoint_dir: str = 'output/checkpoints',
52
+ train_logits: str = 'logits.csv',
53
+ train_labels: str = 'labels.csv',
54
+ predict_logits: str = 'logits.csv',
55
+ predict_labels: str = 'labels.csv',
56
+ method: Literal["dp_calibration", "temp_scaling", "bias_only"] = "dp_calibration",
57
+ learning_rate: float = 1e-3,
58
+ tolerance: float = 1e-4,
59
+ max_ls: int = 100,
60
+ seed: int = 0,
61
+ ):
62
+ torch.set_float32_matmul_precision("high")
63
+ output_dir = Path(output_dir)
64
+ checkpoint_dir = Path(checkpoint_dir)
65
+
66
+ # Load train data
67
+ train_logits = torch.log_softmax(torch.from_numpy(pd.read_csv(train_logits, index_col=0, header=None).values).float(), dim=1)
68
+ train_labels = torch.from_numpy(pd.read_csv(train_labels, index_col=0, header=None).values.flatten()).long()
69
+
70
+ # Load predict data
71
+ df_predict_logits = pd.read_csv(predict_logits, index_col=0, header=None)
72
+ predict_logits = torch.log_softmax(torch.from_numpy(df_predict_logits.values).float(), dim=1)
73
+ df_predict_labels = pd.read_csv(predict_labels, index_col=0, header=None)
74
+ predict_labels = torch.from_numpy(df_predict_labels.values.flatten()).long()
75
+
76
+ state = fit(method, train_logits, train_labels, log_dir, tolerance, train_logits.shape[1], learning_rate, max_ls, seed)
77
+
78
+ # Predict
79
+ model = AffineCalibrator(method=method, num_classes=train_logits.shape[1])
80
+ model.load_state_dict(state['model'])
81
+ cal_logits = predict(model, predict_logits)
82
+
83
+ # Save results
84
+ pd.DataFrame(cal_logits, index=df_predict_logits.index).to_csv(output_dir / 'logits.csv', index=True, header=False)
85
+ df_predict_labels.to_csv(output_dir / 'labels.csv', index=True, header=False)
86
+ torch.save(state, checkpoint_dir / 'last.ckpt')
87
+
88
+
89
+ def fit(method, logits, labels, log_dir, tolerance, num_classes, learning_rate, max_ls, seed):
90
+
91
+ # Create folds
92
+ steps = []
93
+ rs = torch.Generator().manual_seed(seed)
94
+ for i in range(5):
95
+ ids = torch.randperm(logits.shape[0], generator=rs)
96
+ trni = ids[:int(0.7*len(ids))]
97
+ tsti = ids[int(0.7*len(ids)):]
98
+
99
+ # Train model
100
+ model = AffineCalibrator(method=method, num_classes=num_classes)
101
+ optimizer = LBFGS(
102
+ params=(param for param in model.parameters() if param.requires_grad),
103
+ lr=learning_rate,
104
+ max_iter=max_ls,
105
+ tolerance_change=tolerance,
106
+ )
107
+ train_dataset = TensorDataset(logits[trni], labels[trni])
108
+ train_loader = DataLoader(
109
+ train_dataset,
110
+ batch_size=len(train_dataset),
111
+ shuffle=False,
112
+ )
113
+ val_dataset = TensorDataset(logits[tsti], labels[tsti])
114
+ val_loader = DataLoader(
115
+ val_dataset,
116
+ batch_size=len(val_dataset),
117
+ shuffle=False,
118
+ )
119
+ state = _fit_to_fold(model, optimizer, train_loader, val_loader, os.path.join(log_dir,f"fold_{i}"), float('inf'), tolerance, patience=10)
120
+ steps.append(state['step_count'])
121
+
122
+ print(f"Fitting final model with {max(steps)} steps. All steps: {steps}")
123
+ model = AffineCalibrator(method=method, num_classes=num_classes)
124
+ optimizer = LBFGS(
125
+ params=(param for param in model.parameters() if param.requires_grad),
126
+ lr=learning_rate,
127
+ max_iter=max_ls,
128
+ tolerance_change=tolerance,
129
+ )
130
+ train_dataset = TensorDataset(logits[trni], labels[trni])
131
+ train_loader = DataLoader(
132
+ train_dataset,
133
+ batch_size=len(train_dataset),
134
+ shuffle=False,
135
+ )
136
+ state = _fit_to_fold(model, optimizer, train_loader, None, os.path.join(log_dir,'final'), max(steps), tolerance, patience=None)
137
+ return state
138
+
139
+ @torch.no_grad()
140
+ def validate(model, val_loader):
141
+ logits, labels = next(iter(val_loader))
142
+ cal_logits = model(logits)
143
+ loss = F.cross_entropy(cal_logits, labels)
144
+ er = (cal_logits.argmax(dim=1) != labels).float().mean().item()
145
+ return loss.item(), er
146
+
147
+ def _fit_to_fold(model, optimizer, train_loader, val_loader, log_dir, max_step_count, tolerance=1e-4, patience=10):
148
+ if val_loader is None:
149
+ val_loader = train_loader
150
+
151
+ model.train()
152
+ loggers = [
153
+ TBLogger(log_dir),
154
+ CSVLogger(log_dir),
155
+ ]
156
+ logits, labels = next(iter(train_loader))
157
+ priors = torch.bincount(labels, minlength=logits.shape[1]).float() / len(labels)
158
+ priors_ce = -torch.log(priors[labels]).mean().item()
159
+ if priors_ce == 0:
160
+ priors_ce = 1.
161
+ priors_er = (priors.argmax() != labels).float().mean().item()
162
+ if priors_er == 0:
163
+ priors_er = 1.
164
+
165
+ state = {
166
+ 'model': model.state_dict(),
167
+ 'best_val_loss': float('inf'),
168
+ 'step_count': 0,
169
+ 'best_step_count': 0,
170
+ 'patience': 0,
171
+ }
172
+ while state['step_count'] < max_step_count:
173
+
174
+ logits, labels = next(iter(train_loader))
175
+ def closure():
176
+ optimizer.zero_grad()
177
+ cal_logits = model(logits)
178
+ loss = F.cross_entropy(cal_logits, labels)
179
+ er = (cal_logits.argmax(dim=1) != labels).float().mean().item()
180
+ for logger in loggers:
181
+ logger.log_metrics({
182
+ "train/NCE": loss.item() / priors_ce,
183
+ "train/NER": er / priors_er,
184
+ }, step=state['step_count'])
185
+ loss.backward()
186
+ state['step_count'] += 1
187
+ return loss
188
+
189
+ optimizer.step(closure)
190
+
191
+ val_loss, val_er = validate(model, val_loader)
192
+ norm_val_loss = val_loss / priors_ce
193
+ for logger in loggers:
194
+ logger.log_metrics({
195
+ "val/NCE": norm_val_loss,
196
+ "val/NER": val_er / priors_er,
197
+ }, step=state['step_count'])
198
+
199
+ if abs(state['best_val_loss'] - norm_val_loss) <= tolerance and patience is not None:
200
+ if state['patience'] >= patience:
201
+ break
202
+ state['patience'] += 1
203
+ else:
204
+ state['model'] = model.state_dict()
205
+ state['best_val_loss'] = norm_val_loss
206
+ state['best_step_count'] = state['step_count']
207
+ return state
208
+
209
+ @torch.no_grad()
210
+ def predict(model, logits):
211
+ model.eval()
212
+ cal_logits = model(logits)
213
+ cal_logits = torch.log_softmax(cal_logits, dim=1).numpy()
214
+ return cal_logits
215
+
216
+
217
+ if __name__ == '__main__':
218
+ from fire import Fire
219
+ Fire(main)
src/llmcal/scripts/affine_calibration.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import warnings
4
+ from pathlib import Path
5
+ import pandas as pd
6
+ import torch
7
+ from torch.utils.data import DataLoader, TensorDataset
8
+ from torch.optim.lbfgs import LBFGS
9
+ import torch.nn.functional as F
10
+ from typing import Literal
11
+
12
+ from ..src.loggers import TBLogger, CSVLogger
13
+
14
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*Experiment logs directory outputs*")
15
+
16
+
17
+
18
+ class AffineCalibrator(torch.nn.Module):
19
+
20
+ def __init__(self, method: str, num_classes: int):
21
+ super().__init__()
22
+ self.method = method
23
+ self.num_classes = num_classes
24
+ self._init_params(method)
25
+
26
+ def _init_params(self, method):
27
+ if method == "dp_calibration":
28
+ self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=True)
29
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
30
+ elif method == "vector_scaling":
31
+ self.alpha = torch.nn.Parameter(torch.ones(self.num_classes), requires_grad=True)
32
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
33
+ elif method == "temp_scaling":
34
+ self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=True)
35
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=False)
36
+ elif method == "bias_shift":
37
+ self.alpha = torch.nn.Parameter(torch.ones(1), requires_grad=False)
38
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
39
+ elif method == "matrix_scaling":
40
+ self.alpha = torch.nn.Parameter(torch.eye(self.num_classes), requires_grad=True)
41
+ self.beta = torch.nn.Parameter(torch.zeros(self.num_classes), requires_grad=True)
42
+ else:
43
+ raise ValueError(f"Invalid method: {method}")
44
+
45
+ def forward(self, logits):
46
+ if self.method != "matrix_scaling":
47
+ return logits * self.alpha + self.beta
48
+ return logits @ self.alpha.T + self.beta
49
+
50
+
51
+
52
+ def main(
53
+ output_dir: str = 'output',
54
+ log_dir: str = 'output/logs',
55
+ checkpoint_dir: str = 'output/checkpoints',
56
+ train_logits: str = 'logits.csv',
57
+ train_labels: str = 'labels.csv',
58
+ predict_logits: str = 'logits.csv',
59
+ predict_labels: str = 'labels.csv',
60
+ method: Literal["dp_calibration", "temp_scaling", "bias_only"] = "dp_calibration",
61
+ learning_rate: float = 1e-3,
62
+ tolerance: float = 1e-4,
63
+ max_ls: int = 100,
64
+ seed: int = 0,
65
+ ):
66
+ torch.set_float32_matmul_precision("high")
67
+ output_dir = Path(output_dir)
68
+ checkpoint_dir = Path(checkpoint_dir)
69
+
70
+ # Load train data
71
+ train_logits = torch.log_softmax(torch.from_numpy(pd.read_csv(train_logits, index_col=0, header=None).values).float(), dim=1)
72
+ train_labels = torch.from_numpy(pd.read_csv(train_labels, index_col=0, header=None).values.flatten()).long()
73
+
74
+ # Load predict data
75
+ df_predict_logits = pd.read_csv(predict_logits, index_col=0, header=None)
76
+ predict_logits = torch.log_softmax(torch.from_numpy(df_predict_logits.values).float(), dim=1)
77
+ df_predict_labels = pd.read_csv(predict_labels, index_col=0, header=None)
78
+ predict_labels = torch.from_numpy(df_predict_labels.values.flatten()).long()
79
+
80
+ num_classes = train_logits.shape[1]
81
+ model = AffineCalibrator(method=method, num_classes=num_classes)
82
+ state = fit(model, train_logits, train_labels, log_dir, tolerance, learning_rate, max_ls)
83
+ torch.save(state, checkpoint_dir / 'state.ckpt')
84
+ model.load_state_dict(state['best_model'])
85
+
86
+ # Predict
87
+ cal_logits = predict(model, predict_logits)
88
+
89
+ # Save results
90
+ pd.DataFrame(cal_logits, index=df_predict_logits.index).to_csv(output_dir / 'logits.csv', index=True, header=False)
91
+ df_predict_labels.to_csv(output_dir / 'labels.csv', index=True, header=False)
92
+
93
+
94
+ def fit(model, logits, labels, log_dir, tolerance, learning_rate, max_ls):
95
+
96
+ # Train model
97
+ optimizer = LBFGS(
98
+ params=(param for param in model.parameters() if param.requires_grad),
99
+ lr=learning_rate,
100
+ max_iter=max_ls,
101
+ tolerance_change=tolerance,
102
+ )
103
+ train_dataset = TensorDataset(logits, labels)
104
+ train_loader = DataLoader(
105
+ train_dataset,
106
+ batch_size=len(train_dataset),
107
+ shuffle=False,
108
+ )
109
+ val_dataset = TensorDataset(logits, labels)
110
+ val_loader = DataLoader(
111
+ val_dataset,
112
+ batch_size=len(val_dataset),
113
+ shuffle=False,
114
+ )
115
+ state = _fit(model, optimizer, train_loader, val_loader, log_dir, float('inf'), tolerance, 10)
116
+ return state
117
+
118
+ @torch.no_grad()
119
+ def validate(model, val_loader):
120
+ logits, labels = next(iter(val_loader))
121
+ cal_logits = model(logits)
122
+ loss = F.cross_entropy(cal_logits, labels)
123
+ er = (cal_logits.argmax(dim=1) != labels).float().mean().item()
124
+ return loss.item(), er
125
+
126
+ def _fit(model, optimizer, train_loader, val_loader, log_dir, max_step_count, tolerance=1e-4, patience=10):
127
+
128
+ model.train()
129
+ loggers = [
130
+ TBLogger(log_dir),
131
+ CSVLogger(log_dir),
132
+ ]
133
+ logits, labels = next(iter(train_loader))
134
+ priors = torch.bincount(labels, minlength=logits.shape[1]).float() / len(labels)
135
+ priors_ce = -torch.log(priors[labels]).mean().item()
136
+ if priors_ce == 0:
137
+ priors_ce = 1.
138
+ priors_er = (priors.argmax() != labels).float().mean().item()
139
+ if priors_er == 0:
140
+ priors_er = 1.
141
+
142
+ state = {
143
+ 'last_model': model.state_dict(),
144
+ 'best_model': model.state_dict(),
145
+ 'best_val_loss': float('inf'),
146
+ 'step_count': 0,
147
+ 'best_step_count': 0,
148
+ 'patience': 0,
149
+ }
150
+ should_stop = False
151
+ while not should_stop:
152
+
153
+ logits, labels = next(iter(train_loader))
154
+ def closure():
155
+ optimizer.zero_grad()
156
+ cal_logits = model(logits)
157
+ loss = F.cross_entropy(cal_logits, labels)
158
+ er = (cal_logits.argmax(dim=1) != labels).float().mean().item()
159
+ for logger in loggers:
160
+ logger.log_metrics({
161
+ "train/NCE": loss.item() / priors_ce,
162
+ "train/NER": er / priors_er,
163
+ }, step=state['step_count'])
164
+ loss.backward()
165
+ state['step_count'] += 1
166
+ return loss
167
+
168
+ optimizer.step(closure)
169
+
170
+ val_loss, val_er = validate(model, val_loader)
171
+ norm_val_loss = val_loss / priors_ce
172
+ for logger in loggers:
173
+ logger.log_metrics({
174
+ "val/NCE": norm_val_loss,
175
+ "val/NER": val_er / priors_er,
176
+ }, step=state['step_count'])
177
+
178
+ if (state['best_val_loss'] - norm_val_loss) / norm_val_loss <= tolerance:
179
+ if patience is not None:
180
+ if state['patience'] >= patience:
181
+ should_stop = True
182
+ state['patience'] += 1
183
+ else:
184
+ state['best_model'] = model.state_dict()
185
+ state['best_val_loss'] = norm_val_loss
186
+ state['best_step_count'] = state['step_count']
187
+
188
+ state['last_model'] = model.state_dict()
189
+ should_stop = should_stop or state['step_count'] >= max_step_count
190
+
191
+ return state
192
+
193
+ @torch.no_grad()
194
+ def predict(model, logits):
195
+ model.eval()
196
+ cal_logits = model(logits)
197
+ cal_logits = torch.log_softmax(cal_logits, dim=1).numpy()
198
+ return cal_logits
199
+
200
+
201
+ if __name__ == '__main__':
202
+ from fire import Fire
203
+ Fire(main)
src/llmcal/scripts/affine_prediction.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ from .affine_calibration import AffineCalibrator, predict
4
+
5
+ import pandas as pd
6
+ import torch
7
+
8
+ def main(
9
+ checkpoint_path: str,
10
+ method: str,
11
+ predict_logits: str,
12
+ predict_labels: str,
13
+ output_dir: str = 'output',
14
+ ):
15
+ # Load logits
16
+ predict_logits = torch.log_softmax(torch.from_numpy(pd.read_csv(predict_logits, index_col=0, header=None).values).float(), dim=1)
17
+ df_predict_labels = pd.read_csv(predict_labels, index_col=0, header=None)
18
+
19
+ # Load model
20
+ model = AffineCalibrator(method=method, num_classes=predict_logits.shape[1])
21
+ state = torch.load(checkpoint_path, weights_only=False)
22
+ model.load_state_dict(state['best_model'])
23
+
24
+ # Predict
25
+ cal_logits = predict(model, predict_logits)
26
+
27
+ # Save
28
+ output_dir = Path(output_dir)
29
+ pd.DataFrame(cal_logits, index=df_predict_labels.index).to_csv(output_dir / 'logits.csv', index=True, header=False)
30
+ df_predict_labels.to_csv(output_dir / 'labels.csv', index=True, header=False)
31
+
32
+
33
+
34
+ if __name__ == "__main__":
35
+ from fire import Fire
36
+ Fire(main)
src/llmcal/scripts/compare_models.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+
4
+ from matplotlib import pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from ..src.utils import load_yaml
9
+ from .results_vs_samples import DATASETS, metric2name, process_data, plot_metric_vs_samples
10
+
11
+ method2style = {
12
+ "no_adaptation": "-",
13
+ "lora_1.0_no_es": "-",
14
+ "lora_1.0_no_es_plus_tempscaling": "--",
15
+ "dp_calibration": "-.",
16
+ "temp_scaling": "-.",
17
+ "vector_scaling": "-.",
18
+ }
19
+
20
+ model2style = {
21
+ "llama3.2-1b-instruct": "-",
22
+ "qwen2.5-7b-instruct": "--",
23
+ }
24
+
25
+ model2name = {
26
+ "llama3.2-1b-instruct": "LLama3.2-1B",
27
+ "qwen2.5-7b-instruct": "Qwen2.5-7B",
28
+ }
29
+
30
+ def main(
31
+ datasets,
32
+ metrics,
33
+ sizes,
34
+ methods_config,
35
+ output_path,
36
+ models,
37
+ results_dirs,
38
+ intervals,
39
+ methods,
40
+ ):
41
+ datasets = list(map(str, datasets.split()))
42
+ sizes = list(map(int, sizes.split()))
43
+ models = list(map(str, models.split()))
44
+ methods = list(map(str, methods.split()))
45
+ methods_config = load_yaml(methods_config)
46
+ metrics = list(map(str, metrics.split()))
47
+ output_path = Path(output_path)
48
+ output_dir = output_path.parent
49
+ results_dirs = list(map(Path, results_dirs.split()))
50
+
51
+ fig, axs = plt.subplots(len(metrics), len(datasets), figsize=(6 * len(datasets), 12))
52
+ processed_data = {}
53
+ custom_handles = []
54
+ all_data = []
55
+ for i, (model, results_dir) in enumerate(zip(models, results_dirs)):
56
+ for method in methods:
57
+ # methods_config[method]["color"] = f"C{i}"
58
+ methods_config[method]["linestyle"] = model2style[model]
59
+ for ax, metric in zip(axs,metrics):
60
+ data = pd.read_json(results_dir / f"{metric}.jsonl", orient='records', lines=True)
61
+ processed_data[metric] = data
62
+ data = process_data(data, datasets, sizes, methods)
63
+ plot_metric_vs_samples(ax, data, methods, methods_config, datasets, sizes, intervals=intervals, pos=i/10, no_adaptation="text", modelname_noa=model2name[model], fontsize_noa=16)
64
+ data["model"] = model
65
+ data["metric"] = metric
66
+ all_data.append(data)
67
+ data.to_csv(output_dir / f"{metric}.csv", index=False)
68
+ ax[0].set_ylabel(f"{metric2name[metric]}", fontsize=22)
69
+
70
+ all_data = pd.concat(all_data)
71
+ for j, dataset in enumerate(datasets):
72
+ axs[0,j].set_title(DATASETS[dataset]["name"], fontsize=22)
73
+ for ax, metric in zip(axs,metrics):
74
+ min_y = all_data.loc[
75
+ (all_data["dataset"] == dataset) & \
76
+ (all_data["metric"] == metric) & \
77
+ (all_data["method"].isin(set(methods) - {"no_adaptation"})),"median"].min()
78
+ max_y = all_data.loc[
79
+ (all_data["dataset"] == dataset) & \
80
+ (all_data["metric"] == metric) & \
81
+ (all_data["method"].isin(set(methods) - {"no_adaptation"})),"median"].max()
82
+ ax[j].set_ylim(min_y*0.99, max_y*1.2)
83
+ ax[j].set_yticks(np.round(ax[j].get_yticks(),3))
84
+ ax[j].set_yticklabels(ax[j].get_yticks(), fontsize=16)
85
+ ax[j].grid(axis="y")
86
+
87
+
88
+ fig.text(0.5, 0.04, 'Number of train samples', ha='center', fontsize=22)
89
+
90
+ # Gather handles and labels from all axes
91
+ custom_handles = []
92
+ for i, model in enumerate(models):
93
+ custom_handles.append(
94
+ plt.Line2D([0], [0], color="black", linestyle=model2style[model], label=model2name[model])
95
+ )
96
+ for method in methods:
97
+ if method == "no_adaptation":
98
+ continue
99
+ custom_handles.append(
100
+ plt.Line2D([0], [0], color=methods_config[method]["color"], linestyle="none", marker="o", markersize=10, label=methods_config[method]["label"])
101
+ )
102
+ # fig.legend(handles=custom_handles, loc='upper right', bbox_to_anchor=(1.08, .95), title_fontsize=24, fontsize=22)
103
+ fig.legend(handles=custom_handles, loc='lower center', bbox_to_anchor=(0.5, -0.1), fontsize=24, ncol=4)
104
+
105
+ plt.savefig(output_path, bbox_inches="tight", dpi=300)
106
+ plt.close(fig)
107
+
108
+ if __name__ == "__main__":
109
+ from fire import Fire
110
+ Fire(main)
src/llmcal/scripts/compute_matched_results.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+
4
+ import pandas as pd
5
+ from tqdm import tqdm
6
+
7
+ from ..src.evaluation.metrics import compute_psr_with_mincal
8
+
9
+
10
+ METHODS = [
11
+ "no_adaptation",
12
+ "dpcal",
13
+ "tempscaling",
14
+ "vectorscaling",
15
+ "biasshift",
16
+ "finetunne_lora",
17
+ "lora_plus_dpcal",
18
+ "lora_plus_tempscaling",
19
+ "lora_plus_biasshfit",
20
+ "lora_plus_vectorscaling",
21
+ ]
22
+
23
+
24
+ def read_finetuning_results(root_results_dir: Path):
25
+ data = []
26
+ for dataset_dir in root_results_dir.iterdir():
27
+ train_dataset = dataset_dir.name
28
+ for size_dir in dataset_dir.iterdir():
29
+ size = int(size_dir.name.split("=")[1])
30
+ for seed_dir in size_dir.iterdir():
31
+ seed = int(seed_dir.name.split("=")[1])
32
+ for method in seed_dir.iterdir():
33
+ method_name = method.name
34
+ for train_lst in method.iterdir():
35
+ train_lst_name = train_lst.name
36
+ for val_lst in train_lst.iterdir():
37
+ val_lst_name = val_lst.name
38
+ for test_dataset_dir in val_lst.iterdir():
39
+ if not test_dataset_dir.name.startswith("test="):
40
+ continue
41
+ test_dataset = test_dataset_dir.name.split("=")[1]
42
+ for test_lst in test_dataset_dir.iterdir():
43
+ if not test_lst.name.startswith("list=test"):
44
+ continue
45
+ test_lst_name = test_lst.name.split("=")[1]
46
+ if not (logits_path := test_lst / "logits.csv").exists():
47
+ continue
48
+ if not (labels_path := test_lst / "labels.csv").exists():
49
+ continue
50
+ data.append({
51
+ "train_dataset": train_dataset,
52
+ "size": size,
53
+ "seed": seed,
54
+ "method": method_name,
55
+ "train_lst": train_lst_name,
56
+ "val_lst": val_lst_name,
57
+ "cal_lst": None,
58
+ "test_dataset": test_dataset,
59
+ "test_lst": test_lst_name,
60
+ "logits": logits_path,
61
+ "labels": labels_path,
62
+ })
63
+
64
+ return pd.DataFrame(data)
65
+
66
+ def read_lora_plus_calibration_results(root_results_dir: Path):
67
+ data = []
68
+ for dataset_dir in root_results_dir.iterdir():
69
+ train_dataset = dataset_dir.name
70
+ for size_dir in dataset_dir.iterdir():
71
+ size = int(size_dir.name.split("=")[1])
72
+ for seed_dir in size_dir.iterdir():
73
+ seed = int(seed_dir.name.split("=")[1])
74
+ for method in seed_dir.iterdir():
75
+ method_name = method.name
76
+ for train_lst in method.iterdir():
77
+ train_lst_name = train_lst.name
78
+ for val_lst in train_lst.iterdir():
79
+ val_lst_name = val_lst.name
80
+ for cal_lst in val_lst.iterdir():
81
+ cal_lst_name = cal_lst.name
82
+ for test_dataset_dir in cal_lst.iterdir():
83
+ if not test_dataset_dir.name.startswith("test="):
84
+ continue
85
+ test_dataset = test_dataset_dir.name.split("=")[1]
86
+ for test_lst in test_dataset_dir.iterdir():
87
+ if not test_lst.name.startswith("list=test"):
88
+ continue
89
+ test_lst_name = test_lst.name.split("=")[1]
90
+ if not (logits_path := test_lst / "logits.csv").exists():
91
+ continue
92
+ if not (labels_path := test_lst / "labels.csv").exists():
93
+ continue
94
+ data.append({
95
+ "train_dataset": train_dataset,
96
+ "size": size,
97
+ "seed": seed,
98
+ "method": method_name,
99
+ "train_lst": train_lst_name,
100
+ "val_lst": val_lst_name,
101
+ "cal_lst": cal_lst_name,
102
+ "test_dataset": test_dataset,
103
+ "test_lst": test_lst_name,
104
+ "logits": logits_path,
105
+ "labels": labels_path,
106
+ })
107
+
108
+ return pd.DataFrame(data)
109
+
110
+
111
+
112
def read_calibration_results(root_results_dir: Path):
    """Walk a calibration-only results tree and collect test predictions.

    Same layout as the LoRA+calibration tree but without the calibration
    list level; the `cal_lst` column is therefore always None.
    """
    rows = []
    for ds_dir in root_results_dir.iterdir():
        for size_dir in ds_dir.iterdir():
            for seed_dir in size_dir.iterdir():
                for method_dir in seed_dir.iterdir():
                    for train_dir in method_dir.iterdir():
                        for val_dir in train_dir.iterdir():
                            for test_ds_dir in val_dir.iterdir():
                                if not test_ds_dir.name.startswith("test="):
                                    continue
                                for lst_dir in test_ds_dir.iterdir():
                                    if not lst_dir.name.startswith("list=test"):
                                        continue
                                    logits_path = lst_dir / "logits.csv"
                                    labels_path = lst_dir / "labels.csv"
                                    # Skip incomplete runs (missing predictions).
                                    if not (logits_path.exists() and labels_path.exists()):
                                        continue
                                    rows.append({
                                        "train_dataset": ds_dir.name,
                                        "size": int(size_dir.name.split("=")[1]),
                                        "seed": int(seed_dir.name.split("=")[1]),
                                        "method": method_dir.name,
                                        "train_lst": train_dir.name,
                                        "val_lst": val_dir.name,
                                        "cal_lst": None,
                                        "test_dataset": test_ds_dir.name.split("=")[1],
                                        "test_lst": lst_dir.name.split("=")[1],
                                        "logits": logits_path,
                                        "labels": labels_path,
                                    })
    return pd.DataFrame(rows)
153
+
154
+
155
def read_no_adaptation_results(root_results_dir: Path):
    """Collect baseline (no-adaptation) test predictions.

    Layout: <root>/<dataset>/size=all/seed=all/test=<ds>/list=test*/...
    Size and seed are reported as the literal string "all".
    """
    rows = []
    for ds_dir in root_results_dir.iterdir():
        for test_ds_dir in (ds_dir / "size=all/seed=all").iterdir():
            if not test_ds_dir.name.startswith("test="):
                continue
            for lst_dir in test_ds_dir.iterdir():
                if not lst_dir.name.startswith("list=test"):
                    continue
                logits_path = lst_dir / "logits.csv"
                labels_path = lst_dir / "labels.csv"
                # Skip incomplete runs (missing predictions).
                if not (logits_path.exists() and labels_path.exists()):
                    continue
                rows.append({
                    "train_dataset": ds_dir.name,
                    "size": "all",
                    "seed": "all",
                    "method": "no_adaptation",
                    "train_lst": None,
                    "val_lst": None,
                    "cal_lst": None,
                    "test_dataset": test_ds_dir.name.split("=")[1],
                    "test_lst": lst_dir.name.split("=")[1],
                    "logits": logits_path,
                    "labels": labels_path,
                })
    return pd.DataFrame(rows)
186
+
187
def compute_metrics(data, metric):
    """Evaluate `metric` on every row's stored logits/labels CSV files.

    Adds `result` and `min_result` (post-calibration minimum) columns
    and drops the file-path columns afterwards.
    """
    out = data.copy()
    for idx, row in tqdm(data.iterrows(), total=len(data)):
        logits = pd.read_csv(row["logits"], index_col=0, header=None).values.astype(float)
        labels = pd.read_csv(row["labels"], index_col=0, header=None).values.flatten().astype(int)
        result, min_result = compute_psr_with_mincal(logits, labels, metric, "none")
        out.loc[idx, "result"] = result
        out.loc[idx, "min_result"] = min_result
    return out.drop(columns=["logits", "labels"])
197
+
198
+
199
def extract_method(row):
    """Map (method_type, method, train_lst) to a canonical method name.

    For LoRA variants the train-list range "s-e" encodes the fraction of
    data used for training (e - s), and "_no_es" marks runs without
    early stopping.
    """
    mtype = row["method_type"]
    if mtype == "no_adaptation":
        return mtype
    if mtype == "calibration":
        return row["method"]

    lora_plus_types = (
        "lora_plus_dpcal", "lora_plus_tempscaling", "lora_plus_biasshift",
        "lora_plus_vectorscaling", "lora_plus_dpcal_trainontest",
        "lora_plus_tempscaling_trainontest", "lora_plus_dpcal_naive",
        "lora_plus_tempscaling_naive",
    )
    if mtype == "finetune_lora" or mtype in lora_plus_types:
        start, end = map(float, row["train_lst"].split("-"))
        name = f"lora_{end - start:.1f}"
        if row["method"] == "lora_ans_no_es":
            name += "_no_es"
        if mtype in lora_plus_types:
            name += "_plus_" + mtype.split("_plus_")[1]
        return name

    raise ValueError(f"Unknown method: {row['method_type']}, {row['method']}")
223
+
224
+
225
+
226
def process_data(data, reduced = False):
    """Filter matched-domain rows, resolve method names, and tidy columns.

    Keeps only runs evaluated on their own training dataset, selects the
    relevant test list(s), expands the method identifier via
    `extract_method`, and returns the canonical column order.
    """
    # Keep only runs evaluated on their own training dataset.
    matched = data[data["train_dataset"] == data["test_dataset"]]
    matched = matched.drop(columns=["train_dataset"])
    matched = matched.rename(columns={"test_dataset": "dataset"})

    # Reduced runs use the "test_*" subset lists; otherwise the single
    # full "test" list.
    if reduced:
        matched = matched[matched["test_lst"].str.startswith("test_")]
    else:
        matched = matched[matched["test_lst"] == "test"]
    matched = matched.drop(columns=["test_lst"])

    # Expand the method identifier into its full description.
    matched["method"] = matched.apply(extract_method, axis=1)
    matched = matched.drop(columns=["method_type", "train_lst", "val_lst", "cal_lst"])

    # Canonical column order.
    return matched.loc[:, ["dataset", "method", "size", "seed", "result", "min_result"]]
248
+
249
+
250
+
251
def main(
    metric: str,
    finetuning_root_results_dirs: str = None,
    lora_plus_cal_root_results_dirs: str = None,
    lora_plus_cal_naive_root_results_dirs: str = None,
    cal_root_results_dirs: str = None,
    trainontest_root_results_dirs: str = None,
    no_adaptation_root_results_dirs: str = None,
    output_path: str = "outputs",
    reduced: bool = False,
):
    """Collect every result tree, score it with `metric`, save as JSONL.

    Each `*_root_results_dirs` argument is a comma-separated list of
    result roots; the method type of a root is taken from its parent
    directory name.
    """

    def _to_paths(dirs):
        # Comma-separated CLI argument -> list of Paths (empty when unset).
        return [Path(d) for d in dirs.split(",")] if dirs is not None else []

    all_data = []
    for root in _to_paths(finetuning_root_results_dirs):
        df = read_finetuning_results(root)
        df["method_type"] = str(root).split("/")[-2]
        all_data.append(df)
    for root in _to_paths(cal_root_results_dirs):
        df = read_calibration_results(root)
        df["method_type"] = str(root).split("/")[-2]
        all_data.append(df)
    for root in _to_paths(lora_plus_cal_root_results_dirs):
        df = read_lora_plus_calibration_results(root)
        df["method_type"] = str(root).split("/")[-2]
        all_data.append(df)
    for root in _to_paths(lora_plus_cal_naive_root_results_dirs):
        df = read_lora_plus_calibration_results(root)
        df["method_type"] = str(root).split("/")[-2]
        all_data.append(df)
    for root in _to_paths(trainontest_root_results_dirs):
        df = read_lora_plus_calibration_results(root)
        df["method_type"] = str(root).split("/")[-2]
        all_data.append(df)
    for root in _to_paths(no_adaptation_root_results_dirs):
        df = read_no_adaptation_results(root)
        df["method_type"] = "no_adaptation"
        all_data.append(df)

    data = pd.concat(all_data, ignore_index=True)

    # Score every run, then normalize to the final output schema.
    data_with_metrics = compute_metrics(data, metric)
    data_with_metrics = process_data(data_with_metrics, reduced)
    data_with_metrics.to_json(output_path, orient="records", lines=True)


if __name__ == "__main__":
    from fire import Fire
    Fire(main)
src/llmcal/scripts/create_lists_new.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ from pathlib import Path
4
+ import numpy as np
5
+ import yaml
6
+ from tqdm import tqdm
7
+
8
# Dataset name -> number of classes.
DATASETS = {"sst2": 2, "agnews": 4, "dbpedia": 14, "20newsgroups": 20, "banking77": 77}
# Dataset name -> size of the reduced test list copied below.
TEST_SAMPLES = {"sst2": 400, "agnews": 400, "dbpedia": 700, "20newsgroups": 800, "banking77": 1000}
# Number of random list splits generated per dataset/size.
N_SEEDS = 10
# Nominal adaptation sizes; converted to per-dataset sample counts in main().
FACTORS = [8, 16, 32, 64, 128, 256, 512]
12
+
13
def main():
    """Regenerate train/val/calibration index lists per dataset/size/seed.

    Existing llmcal2 splits are reused when present so lists stay
    comparable across projects; otherwise a fresh seeded permutation of
    the full train list is drawn.
    """
    rs = np.random.RandomState(8364)  # NOTE(review): unused; kept for parity with original
    for dataset in tqdm(DATASETS):
        n_classes = DATASETS[dataset]
        for factor in FACTORS:
            # Map the nominal size to a sample count: round
            # factor/log2(C) to the nearest power of two, times C.
            per_class = 2 ** np.round(np.log2(factor / np.log2(n_classes)))
            n_samples = int(per_class * n_classes)

            for seed in range(N_SEEDS):
                out_dir = f"lists/{dataset}/size={factor}/seed={seed}"
                os.makedirs(out_dir, exist_ok=True)

                full_trainlist = np.loadtxt(f"../llmcal2/lists/{dataset}/train.txt", dtype=int)
                train_path = Path(f"../llmcal2/lists/{dataset}/train_{n_samples}_0.3_{seed}.txt")
                val_path = Path(f"../llmcal2/lists/{dataset}/val_{n_samples}_0.3_{seed}.txt")
                if train_path.exists() and val_path.exists():
                    # Reuse the existing train/val split.
                    samples = np.hstack([
                        np.loadtxt(train_path, dtype=int),
                        np.loadtxt(val_path, dtype=int),
                    ])
                else:
                    seed_rs = np.random.RandomState(2834 + seed)
                    samples = seed_rs.permutation(full_trainlist)[:n_samples]

                # Split points: last 30% is the held-out chunk, first 30%
                # the reduced training chunk (int truncation kept as-is).
                cut70 = n_samples - int(n_samples * 0.3)
                cut30 = n_samples - int(n_samples * 0.7)
                np.savetxt(f"lists/{dataset}/size={factor}/seed={seed}/0.0-0.7.txt", samples[:cut70], fmt="%d")
                np.savetxt(f"lists/{dataset}/size={factor}/seed={seed}/0.7-1.0.txt", samples[cut70:], fmt="%d")
                np.savetxt(f"lists/{dataset}/size={factor}/seed={seed}/0.0-0.3.txt", samples[:cut30], fmt="%d")
                np.savetxt(f"lists/{dataset}/size={factor}/seed={seed}/0.0-1.0.txt", samples, fmt="%d")

    # Copy the canonical train/test lists verbatim.
    for dataset in tqdm(DATASETS):
        full_train_list = np.loadtxt(f"../llmcal2/lists/{dataset}/train.txt", dtype=int)
        np.savetxt(f"lists/{dataset}/train.txt", full_train_list, fmt="%d")
        full_test_list = np.loadtxt(f"../llmcal2/lists/{dataset}/test.txt", dtype=int)
        np.savetxt(f"lists/{dataset}/test.txt", full_test_list, fmt="%d")
        partial_test_list = np.loadtxt(f"../llmcal2/lists/{dataset}/test_{TEST_SAMPLES[dataset]}.txt", dtype=int)
        np.savetxt(f"lists/{dataset}/test_{TEST_SAMPLES[dataset]}.txt", partial_test_list, fmt="%d")


if __name__ == "__main__":
    main()
src/llmcal/scripts/evals.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import load_dataset
3
+ import numpy as np
4
+ from scipy.special import log_softmax
5
+
6
def compute_nce(scores, labels):
    """Return (cross-entropy, normalized cross-entropy) of scores vs labels.

    NCE divides the model's CE by the CE of the empirical label prior, so
    a value of 1.0 means "no better than predicting class frequencies".
    """
    log_posteriors = log_softmax(scores, axis=1)
    ce = -np.mean(log_posteriors[np.arange(len(labels)), labels])
    prior = np.bincount(labels) / len(labels)
    prior_ce = -np.log(prior[labels]).mean()
    return ce, ce / prior_ce
13
+
14
def compute_ner(scores, labels):
    """Return (error rate, normalized error rate) of argmax predictions.

    NER divides the classifier's error rate by the error rate of the
    trivial majority-class predictor, so 1.0 means "no better than always
    predicting the most frequent label".
    """
    preds = np.argmax(scores, axis=1)
    er = np.mean(preds != labels)
    majority = np.bincount(labels).argmax()
    # Fix: the naive baseline's error is how often the TRUE labels differ
    # from the majority class. The original compared the model's
    # predictions to the majority class, which is not a baseline error
    # rate and mis-scales NER whenever preds != labels.
    naive_er = np.mean(labels != majority)
    ner = er / naive_er
    return er, ner
20
+
21
+
22
+
23
def main():
    """Score the Llama MMLU eval logs with normalized CE / error metrics."""
    data = load_dataset("meta-llama/Llama-3.2-1B-evals", "Llama-3.2-1B-evals__mmlu__details", split="latest")
    # data = load_dataset("meta-llama/Llama-3.1-405B-evals", "Llama-3.1-405B-evals__mmlu__details", split="latest")
    choices = ["A", "B", "C", "D"]
    # Keep only the per-choice scores and the correct-answer column.
    df = (
        data.select_columns(["output_choice_negative_log_likelihoods", "input_correct_responses"])
        .to_pandas()
        .rename(columns={"output_choice_negative_log_likelihoods": "score", "input_correct_responses": "label"})
    )
    # NOTE(review): scores are negative log-likelihoods, but compute_nce /
    # compute_ner treat larger as better — confirm the sign convention.
    scores = np.vstack(df["score"].apply(lambda x: np.array(x["raw"])))
    # Labels presumably look like "Answer: X"; keep the letter after the space.
    labels = df["label"].apply(lambda x: choices.index(x[0].split(" ")[1])).astype(int).values.flatten()
    ce, nce = compute_nce(scores, labels)
    er, ner = compute_ner(scores, labels)
    goodness = nce * ner
    print(f"NCE: {nce}")
    print(f"CE: {ce}")
    print(f"NER: {ner}")
    print(f"ER: {er}")
    print(f"Goodness: {goodness}")

if __name__ == "__main__":
    main()
src/llmcal/scripts/prepare_data.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import pandas as pd
5
+ from ..src.utils import load_yaml
6
+ from ..src.prompts import *
7
+
8
def load_dataset(dataset_path):
    """Read a dataset CSV (first column is the index) into a DataFrame."""
    return pd.read_csv(dataset_path, header=0, index_col=0)
10
+
11
def load_shots(dataset, shots_list, answers):
    """Load few-shot examples for the rows listed in `shots_list`.

    Returns [] when no list file is given; otherwise the selected rows as
    record dicts, with numeric labels replaced by their answer strings.
    """
    if shots_list is None:
        return []
    indices = np.loadtxt(shots_list, dtype=int)
    selected = dataset.loc[indices]
    selected["label"] = selected["label"].apply(lambda lbl: answers[lbl])
    return selected.to_dict(orient="records")
18
+
19
+
20
def select_prompt(model):
    """Return the prompt class for the model family named in `model`.

    Matches by substring; raises ValueError for unrecognized names.
    """
    if "llama3" in model:
        return Llama3Prompt
    if "gemma" in model:
        return GemmaPrompt
    if "qwen" in model:
        return QwenPrompt
    if "tinyllama" in model:
        return TinyLlamaPrompt
    if "phi3" in model:
        return Phi3Prompt
    if "pythia" in model:
        return PythiaPrompt
    raise ValueError(f"Unknown model: {model}")
35
+
36
+
37
def main(dataset_path, prompt_template, model, output_path, shots_list=None, max_characters=400):
    """Render prompts for every dataset row and dump them as JSONL.

    Output records have fields idx, prompt, answer (the full answer-template
    list) and label.
    """
    dataset_path = Path(dataset_path)
    prompt_template = Path(prompt_template)
    output_path = Path(output_path)

    dataset = load_dataset(dataset_path)

    # Build the prompt object: template + few-shot examples + model format.
    config = load_yaml(prompt_template)
    template = config["prompt_template"]
    answers = config["answers_templates"]
    shots = load_shots(dataset, shots_list, answers)
    prompt = select_prompt(model)(max_characters=max_characters)
    prompt.fit(template, shots)

    # Every row carries the whole answer list plus its rendered prompt.
    dataset["answer"] = [answers for _ in range(len(dataset))]
    dataset["prompt"] = prompt.apply(dataset["text"])
    dataset = (
        dataset.loc[:, ["prompt", "answer", "label"]]
        .reset_index(drop=False)
        .rename(columns={"index": "idx"})
    )

    dataset.to_json(output_path, orient="records", lines=True)


if __name__ == "__main__":
    from fire import Fire
    Fire(main)
src/llmcal/scripts/results_bars.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+ from typing import List
4
+ import numpy as np
5
+ import pandas as pd
6
+ import matplotlib.pyplot as plt
7
+ from ..src.utils import load_yaml
8
+ from .results_vs_samples import compute_num_samples
9
+
10
+
11
# Display name and class count for each supported dataset (keyed by the
# identifier used in result files).
DATASETS = {
    "sst2": {"name": "SST-2", "num_classes": 2},
    "agnews": {"name": "AGNews", "num_classes": 4},
    "dbpedia": {"name": "DBPedia", "num_classes": 14},
    "20newsgroups": {"name": "20 Newsgroups", "num_classes": 20},
    "banking77": {"name": "Banking77", "num_classes": 77},
}

# Metric identifier -> axis/display label.
metric2name = {
    "nce": "NCE",
    "ner": "NER",
}
23
+
24
def read_data(results_dir: Path, metrics: List[str]):
    """Load one JSONL results file per metric and merge them.

    Each file has columns [dataset, method, size, seed, result, min_result];
    the merged frame has one named column per metric instead of `result`.
    """
    frames = []
    for metric in metrics:
        frame = (
            pd.read_json(results_dir / f"{metric}.jsonl", orient='records', lines=True)
            .rename(columns={"result": f"{metric}"})
            .drop(columns=["min_result"])
        )
        frames.append(frame)

    # Outer-merge on the run identity so partially missing metrics survive.
    merged = frames[0]
    for frame in frames[1:]:
        merged = merged.merge(frame, on=["dataset", "method", "size", "seed"], how="outer")
    return merged
37
+
38
def plot_bars(data, methods_config, output_path, datasets, metrics, methods, sizes, no_adaptation="auto", fontsize_noa=22, pos=0):
    """Grouped bar chart of per-method metric medians with IQR error bars.

    One row of axes per metric, one column per dataset; within each panel
    one bar group per size and one bar per adapted method. The
    no-adaptation baseline is drawn as a horizontal line when its value
    fits inside the panel's y-range, otherwise printed as text
    (`no_adaptation` in {"plot", "text", "auto"} controls this). The
    figure is written to `output_path`.
    """
    fig, ax = plt.subplots(len(metrics), len(datasets), figsize=(6 * len(datasets), 12), sharex=False)

    # Separate the baseline rows from the adapted-method rows.
    data_adapted = data.loc[data["method"] != "no_adaptation"]
    adapted_methods = [m for m in methods if m != "no_adaptation"]
    data_no_adapt = data.loc[data["method"] == "no_adaptation"]

    # (i_ax, j_ax) remembers the last panel that drew the baseline line so
    # the legend below can pull its handles from a populated axes object.
    i_ax, j_ax = 0, 0
    n_methods = len(methods) + 1
    medians = {}  # metric -> all plotted medians, used later for y-limits
    for i, metric in enumerate(metrics):
        medians[metric] = []
        for j, dataset in enumerate(datasets):
            # plot bar groups. One group per size. Each bar is a method
            for k, method in enumerate(adapted_methods):
                for s, size in enumerate(sizes):
                    method_data = data_adapted.loc[data_adapted["method"] == method]
                    method_data = method_data.loc[method_data["dataset"] == dataset]
                    method_data = method_data.loc[method_data["size"] == size]
                    # Median / quartiles across seeds for this (method, size).
                    median = method_data.groupby("size")[metric].median().values[0]
                    medians[metric].append(median.max())
                    q1 = method_data.groupby("size")[metric].quantile(0.25).values[0]
                    q3 = method_data.groupby("size")[metric].quantile(0.75).values[0]
                    # alpha = 0.5 if "SFT+PHC" in methods_config[method]["label"] else 1.0
                    # alpha = 0.7 if "SFT " in methods_config[method]["label"] else 1.0
                    # if "SFT " in methods_config[method]["label"]:
                    #     hatch = "x"
                    # elif "SFT+PHC" in methods_config[method]["label"]:
                    #     hatch = "/"
                    # else:
                    #     hatch = None
                    alpha = None
                    hatch = None
                    # Asymmetric error bars from the interquartile range.
                    ax[i, j].bar(s + k / n_methods, median, yerr=[[median - q1], [q3 - median]], label=methods_config[method]["label"], width=0.8 / n_methods, color=methods_config[method]["color"], alpha=alpha, hatch=hatch)
                    # ax[i,j].set_xticks(range(len(sizes)))

    # plot no adaptation
    for i, metric in enumerate(metrics):
        # Cap the panel height at 1.4 or 5% above the tallest bar.
        y_max = np.round(min(1.4, 1.05 * max(medians[metric])),1)
        for j, dataset in enumerate(datasets):
            dataset_data = data_no_adapt.loc[data_no_adapt["dataset"] == dataset]
            method_data = dataset_data.loc[dataset_data["method"] == "no_adaptation"]
            # min_q1 = data_adapted[data_adapted["dataset"] == dataset].groupby("size")[metric].quantile(0.25).min()
            # max_median = data_adapted[data_adapted["dataset"] == dataset].groupby("size")[metric].median().max()
            if method_data.loc[:, metric].item() < y_max and no_adaptation in ["plot", "auto"]:
                # Baseline fits inside the panel: draw a horizontal line
                # spanning the full bar-group range.
                num_samples = [-1/n_methods, len(sizes) - 1 + len(adapted_methods)/n_methods]
                noa_medians = [method_data.loc[:, metric].item()] * len(num_samples)
                ax[i,j].plot(num_samples, noa_medians, label=methods_config["no_adaptation"]["label"], color=methods_config["no_adaptation"]["color"], linestyle=methods_config["no_adaptation"]["linestyle"])
                i_ax = i
                j_ax = j
                # print(metric, dataset)
            elif no_adaptation in ["text", "auto"]:
                # Baseline off-scale: annotate its value in the corner instead.
                text = f"{methods_config['no_adaptation']['label']}"
                ax[i,j].text(.95, .95-pos,
                    f"{text} = {method_data.loc[:, metric].item():.2f}",
                    fontsize=fontsize_noa, ha="right", va="top", transform=ax[i,j].transAxes, color=methods_config["no_adaptation"]["color"]
                )

            # no_adapt_value = data_no_adapt.loc[data_no_adapt["dataset"] == dataset,metric].values[0]
            xlims = [-1/n_methods, len(sizes) - 1 + len(adapted_methods)/n_methods]
            # ax[i,j].plot(xlims, [no_adapt_value] * len(sizes), label=methods_config["no_adaptation"]["label"], color=methods_config["no_adaptation"]["color"], linestyle="--")
            ax[i,j].set_xlim(xlims)
            # ax[i,j].set_xticks(range(len(sizes)))
            # ax[i,j].set_xticklabels(ax[i,j].get_xticklabels(), fontsize=26)

            # Only the leftmost column gets a y-axis; others are blanked.
            ax[i,j].set_ylim(0, y_max)
            if j == 0:
                ax[i,j].set_ylabel(f"{metric2name[metric]}", fontsize=30)
                ax[i,j].set_yticks(np.arange(0,int(y_max*10+1),2)/10)
                ax[i,j].set_yticklabels([f"{d:.1f}" for d in np.arange(0,int(y_max*10+1),2)/10], fontsize=24)
                # ax[i,j].set_yticks(ax[i,j].get_yticks())
                # ax[i,j].set_yticklabels([f"{d:.1f}" for d in ax[i,j].get_yticks()], fontsize=24)
            else:
                # ax[i,j].sharey(ax[i,0])
                ax[i,j].set_yticks([])
                ax[i,j].set_yticklabels([])

            # ax[i,j].set_yticks(ax[i,j].get_yticks())
            # ax[i,j].set_yticklabels(ax[i,j].get_yticklabels(), fontsize=24)
            # ax[i,j].grid(axis="y")

    #YES
    # Titles on the top row; per-dataset sample counts on the bottom row.
    for i, dataset in enumerate(datasets):
        ax[0, i].set_title(DATASETS[dataset]["name"], fontsize=30)
        ax[0, i].set_xticks([])
        ax[-1, i].set_xticks(range(len(sizes)))
        ax[-1, i].set_xticklabels([f"{' '*15}N = {size}" for size in compute_num_samples(sizes, dataset)], fontsize=26)

    # for j, metric in enumerate(metrics):
    #     ax[j, 0].set_ylabel(f"{metric2name[metric]}", fontsize=30)
    #     ax[j, 0].set_yticks(ax[j, 0].get_yticks())
    #     ax[j, 0].set_yticklabels(ax[j, 0].get_yticklabels(), fontsize=24)

    fig.text(0.5, -0.05, 'Adaptation sizes', ha='center', fontsize=26)

    # Gather handles and labels from all axes
    # Re-order legend handles to match the requested `methods` order.
    labels = [methods_config[method]["label"] for method in methods]
    handles = []
    hs, ls = ax[i_ax,j_ax].get_legend_handles_labels()
    for l in labels:
        i = 0
        while ls[i] != l:
            i += 1
        handles.append(hs[i])
    # Pad with invisible entries so the legend grid stays a multiple of 4.
    remaining = len(handles) % 4
    for i in range(remaining):
        handles.append(plt.Line2D([], [], color='none', label=''))
        labels.append('')

    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(0.72, -0.3), title="Method", ncol=4, title_fontsize=30, fontsize=28)

    fig.tight_layout(pad=1.0, h_pad=1.0, w_pad=-8.0)
    plt.savefig(output_path, bbox_inches="tight", dpi=300)
    plt.close(fig)
159
+
160
+
161
+
162
def main(
    datasets,
    metrics,
    sizes,
    methods_config,
    results_dir,
    output_path,
    methods,
):
    """CLI entry point: parse space-separated arguments and render bars."""
    metrics = metrics.split()
    datasets = datasets.split()
    methods = methods.split()
    sizes = [int(s) for s in sizes.split()]
    # Only the smallest and largest adaptation sizes are plotted.
    sizes = [sizes[0], sizes[-1]]
    results_dir = Path(results_dir)
    methods_config = load_yaml(methods_config)

    data = read_data(results_dir, metrics)
    plot_bars(data, methods_config, output_path, datasets, metrics, methods, sizes)

if __name__ == "__main__":
    from fire import Fire
    Fire(main)
src/llmcal/scripts/results_table.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+
4
+ from matplotlib import pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from ..src.utils import load_yaml
9
+ from .results_vs_samples import DATASETS, metric2name, process_data
10
+
11
def highlight_local(x):
    """Bold a result (LaTeX) unless it is in the no-adaptation ('noa') group."""
    result = x['result']
    if x['group'] == "noa":
        return str(result)
    return f"$\\mathbf{{{result}}}$"
15
+
16
def highlight_global(x):
    """Underline a result (LaTeX) to mark the global best in its column."""
    value = x['result']
    return f"\\underline{{{value}}}"
18
+
19
def get_all_mins(x):
    """Return the indices of all entries tied for the minimum after 2-dp rounding.

    Fix: removed a leftover `import pdb; pdb.set_trace()` debugging
    breakpoint that halted every caller.
    """
    rounded = x.round(2)
    return x[rounded == rounded.min()].index.to_numpy()
22
+
23
def method2group(method):
    """Bucket a method name into a table group.

    Groups: "noa" (no adaptation), "phc" (post-hoc calibration),
    "sft" (plain LoRA), "sft+phc" (LoRA plus calibration).
    Returns None for names outside these four groups.
    """
    if method == "no_adaptation":
        return "noa"
    if method in ("temp_scaling", "dp_calibration", "bias_shift", "vector_scaling"):
        return "phc"
    if method.startswith("lora"):
        return "sft+phc" if "plus" in method else "sft"
    return None
32
+
33
def create_table(results_dir, methods, metrics, methods_config, datasets, sizes):
    """Build the LaTeX results table as a DataFrame of formatted strings.

    For the smallest and largest sizes, each (dataset, metric) cell holds
    the median over seeds formatted to two decimals, bolded when best in
    its method group and underlined when best overall. A no-adaptation
    baseline row is prepended. Relies on `process_data` (imported from
    results_vs_samples) plus the sibling helpers `method2group`,
    `highlight_local` and `highlight_global`.
    """

    methods = [m for m in methods if m != "no_adaptation"]
    # data_all = pd.DataFrame([
    #     {"dataset": dataset, "size": size, "method": method, "group": method2group(method), "result": ""}
    #     for size in [sizes[0], sizes[-1]]
    #     for method in methods
    #     for dataset in datasets
    # ]).pivot(index=["size","group", "method"], columns="dataset", values="result")
    # index = [(size, method2group(method), method) for size in [sizes[0],sizes[-1]] for method in methods]
    # data_all = data_all.reindex(index)

    data_all = []
    for metric in metrics:
        data = pd.read_json(results_dir / f"{metric}.jsonl", orient='records', lines=True)
        data = process_data(data, datasets, [sizes[0],sizes[-1]], methods)
        # no_adaptation = data[data["method"] == "no_adaptation"]
        # min_size_no_adaptation = no_adaptation.copy()
        # min_size_no_adaptation["size"] = sizes[0]
        # max_size_no_adaptation = no_adaptation.copy()
        # max_size_no_adaptation["size"] = sizes[-1]
        data = data[data["method"] != "no_adaptation"]
        # data = pd.concat([data, min_size_no_adaptation, max_size_no_adaptation], ignore_index=True)
        data = data[data["size"].isin([sizes[0],sizes[-1]])].reset_index(drop=True)
        data["group"] = data["method"].apply(method2group)

        data["result"] = data["median"].apply(lambda x: f"{x:.2f}")
        # Bold the best method(s) within each (dataset, size, group),
        # with ties decided after rounding to two decimals.
        best_idx = []
        for (dataset, size, group), group_data in data.groupby(["dataset","size","group"]):
            med = group_data["median"]
            best_idx.extend(med[med.round(2) == med.round(2).min()].index.to_list())
        # best_idx = data.groupby(["dataset","size","group"])["median"].idxmin().values
        # best_idx = data[data["median"].isin(data.loc[best_idx,"median"])].index.to_numpy()
        data.loc[best_idx,"result"] = data.loc[best_idx,:].apply(highlight_local, axis=1)
        # Underline the overall best method(s) per (dataset, size).
        best_idx = []
        for (dataset, size), group_data in data.groupby(["dataset","size"]):
            med = group_data["median"]
            best_idx.extend(med[med.round(2) == med.round(2).min()].index.to_list())
        # best_idx = data.groupby(["dataset","size"])["median"].idxmin().values
        # best_idx = data[data["median"].isin(data.loc[best_idx,"median"])].index.to_numpy()
        data.loc[best_idx,"result"] = data.loc[best_idx,:].apply(highlight_global, axis=1)
        # Rows: (size, group, method); columns: dataset. Missing cells -> "N/A".
        data = data.pivot(index=["size","group","method"], columns="dataset", values="result")
        index = [(size, method2group(method), method) for size in [sizes[0],sizes[-1]] for method in methods]
        data = data.reindex(index)
        data = data.fillna("N/A")
        data_all.append(data)

    # for size in [sizes[0],sizes[-1]]:
    #     for method in methods:
    #         for dataset in datasets:
    #             if data_all.loc[(size, method2group(method), method), dataset] == "":
    #                 data_all.loc[(size, method2group(method), method), dataset] += data.loc[(size, method2group(method), method), dataset]
    #             else:
    #                 data_all.loc[(size, method2group(method), method), dataset] += " / " + data.loc[(size, method2group(method), method), dataset]

    # Interleave metrics under each dataset: columns become (dataset, metric).
    data_all = pd.concat(data_all, axis=1, keys=metrics)
    data_all.columns = data_all.columns.swaplevel(0,1)
    data_all = data_all.loc[:,[(dataset,metric) for dataset in datasets for metric in metrics]]
    # data_all.index = data.index.map(lambda x: ({sizes[0]: f"min (N = {int(np.log2(sizes[0]))})", sizes[-1]: f"max (N = {int(np.log2(sizes[-1]))})"}[x[0]], methods_config[x[1]]["label"].replace("%","\\%").replace("\n", " &" * len(datasets) + " \\\\\n & ") ))
    # Replace numeric sizes with rotated LaTeX "T' = log2(size)" labels and
    # method keys with their display labels.
    smallest = f"$T' = {int(np.log2(sizes[0]))}$"
    largest = f"$T' = {int(np.log2(sizes[-1]))}$"
    data_all.index = data.index.map(lambda x: ({sizes[0]: "\\rotatebox[origin=c]{{90}}{smallest}".format(smallest="{" + smallest + "}"), sizes[-1]: "\\rotatebox[origin=c]{{90}}{largest}".format(largest="{" + largest + "}")}[x[0]], method2group(x[2]), methods_config[x[2]]["label"] ))
    data_all = data_all.reset_index(level=1,drop=True)
    data_all = data_all.loc[:,datasets]
    data_all.columns = data_all.columns.map(lambda x: (DATASETS[x[0]]["name"], metric2name[x[1]]))
    data_all.columns.name = None
    data_all.index.names = [None, None]

    # noa_data = pd.DataFrame({
    #     "dataset": datasets,
    #     "method": [methods_config["no_adaptation"]["label"]] * len(datasets),
    #     "size": [""] * len(datasets),
    #     "result": [""] * len(datasets),
    #     "metric": [""] * len(datasets),
    # }).pivot(index=["size","method"], columns=["dataset","metric"], values="result").loc[:,datasets]
    # noa_data.columns.name = None
    # noa_data.index.names = [None, None]

    # Build the no-adaptation baseline row with the same column layout.
    noa_data = []
    for metric in metrics:
        data = pd.read_json(results_dir / f"{metric}.jsonl", orient='records', lines=True)
        data = process_data(data, datasets, [sizes[0],sizes[-1]], ["no_adaptation"])
        data["metric"] = metric
        noa_data.append(data)
    # for dataset in datasets:
    #     m = data[(data["dataset"] == dataset)]["median"].values[0]
    #     noa_data.loc[:, dataset] += f"{m:.2f} \ "
    # noa_data.iloc[0,:] = noa_data.iloc[0,:].apply(lambda x: x[:-2])
    # noa_data.columns = noa_data.columns.map(lambda x: (DATASETS[x[0]]["name"], metric2name[x[1]]))
    noa_data = pd.concat(noa_data, axis=0)
    noa_data = noa_data.pivot(index="method",columns=["dataset","metric"],values=["median"])
    noa_data.columns = noa_data.columns.droplevel(0)
    noa_data = noa_data.loc[:,[(dataset,metric) for dataset in datasets for metric in metrics]]
    noa_data.columns = noa_data.columns.map(lambda x: (DATASETS[x[0]]["name"], metric2name[x[1]]))
    noa_data.columns.names = [None, None]
    noa_data.index = noa_data.index.map(lambda x: ("",methods_config[x]["label"]))
    noa_data.index.names = ["size", "method"]
    noa_data = noa_data.apply(lambda x: x.apply(lambda y: f"{y:.2f}"), axis=1)
    data_all = pd.concat([noa_data, data_all], axis=0)

    return data_all
134
+
135
def main(
    datasets,
    sizes,
    metrics,
    methods,
    methods_config,
    results_dir,
    output_path
):
    """CLI entry point: build the LaTeX results table and write it to disk."""
    datasets = datasets.split()
    sizes = [int(s) for s in sizes.split()]
    methods = methods.split()
    methods_config = load_yaml(methods_config)
    metrics = metrics.split()
    output_path = Path(output_path)
    output_dir = output_path.parent
    results_dir = Path(results_dir)

    table = create_table(results_dir, methods, metrics, methods_config, datasets, sizes)
    table_str = table.to_latex(escape=False, column_format="ll" + "||c|c" * len(datasets))
    # Post-process pandas' LaTeX output: center multirow cells and keep the
    # double rule under multicolumn headers.
    table_str = table_str.replace("multirow[t]", "multirow[c]")
    table_str = table_str.replace("multicolumn{2}{r}", "multicolumn{2}{c||}")
    with open(output_path, "w") as f:
        f.write(table_str)


if __name__ == '__main__':
    from fire import Fire
    Fire(main)
src/llmcal/scripts/results_vs_samples.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pathlib import Path
3
+
4
+ from matplotlib import pyplot as plt
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from ..src.utils import load_yaml
9
+
10
# Display name and number of classes for each evaluated dataset, keyed by the
# identifier used in the results files and CLI arguments.
DATASETS = {
    "sst2": {"name": "SST-2", "num_classes": 2},
    "agnews": {"name": "AGNews", "num_classes": 4},
    "dbpedia": {"name": "DBPedia", "num_classes": 14},
    "20newsgroups": {"name": "20 Newsgroups", "num_classes": 20},
    "banking77": {"name": "Banking77", "num_classes": 77},
}

# Pretty y-axis labels for the metrics plotted by this script.
metric2name = {
    "nce": "NCE",
    "ner": "NER",
}
22
+
23
+
24
+
25
+
26
def process_data(data, datasets, sizes, methods):
    """Filter raw results to the selected datasets/sizes/methods and summarize.

    Rows with ``size == 'all'`` (the unadapted baseline) are always kept.
    Returns one row per (dataset, method, size) with the median and the
    inter-quartile bounds (q1/q3) of the "result" column.
    """
    keep = (
        data["dataset"].isin(datasets)
        & (data["size"].isin(sizes) | (data["size"] == 'all'))
        & data["method"].isin(methods)
    )
    subset = data[keep]

    summary = (
        subset
        .groupby(["dataset", "method", "size"])["result"]
        .agg(
            median=lambda s: s.median(),
            q1=lambda s: s.quantile(0.25),
            q3=lambda s: s.quantile(0.75),
            count=lambda s: s.count(),
        )
        .reset_index()
    )

    # Every surviving group must contain at least one measurement.
    assert (summary["count"] > 0).all()
    return summary.drop(columns=["count"])
49
+
50
+
51
def compute_num_samples(sizes, dataset):
    """Translate per-class budget factors into actual train-set sizes.

    Each entry of ``sizes`` is rescaled by log2(num_classes), snapped to the
    nearest power of two, and multiplied back by the number of classes of
    ``dataset`` (looked up in the module-level ``DATASETS`` table).
    """
    num_classes = DATASETS[dataset]["num_classes"]
    per_class = 2 ** np.round(np.log2(sizes / np.log2(num_classes)))
    return (per_class * num_classes).astype(int)
57
+
58
+
59
+
60
def plot_metric_vs_samples(ax, data, all_methods, methods_config, datasets, sizes, intervals=False, pos=0, no_adaptation="plot", modelname_noa=None, fontsize_noa=18):
    """Plot metric-vs-number-of-train-samples curves, one subplot per dataset.

    Args:
        ax: sequence of matplotlib Axes, one per entry of ``datasets``.
        data: summary frame from ``process_data`` (columns: dataset, method,
            size, median, q1, q3).
        all_methods: ordered list of method names to consider.
        methods_config: per-method plot kwargs (must at least contain
            "color" and "label").
        datasets: dataset identifiers (keys of ``DATASETS``).
        sizes: per-class size factors forming the x grid.
        intervals: if True, shade the q1-q3 band around each curve.
        pos: vertical offset for the "no_adaptation" text annotation.
        no_adaptation: "plot", "text", "auto" or "skip" — how to render the
            unadapted baseline (flat line, text box, whichever fits, nothing).
        modelname_noa: optional model name appended to the baseline label.
        fontsize_noa: font size of the baseline text annotation.

    Returns:
        dict mapping each dataset name to its filtered summary frame.
    """
    datasets_data = {}
    for i, dataset in enumerate(datasets):

        # All methods present for this dataset, in the caller-given order.
        dataset_data = data[data["dataset"] == dataset]
        methods = [m for m in all_methods if m in dataset_data["method"].unique()]

        for j, method in enumerate(methods):
            # Get data for method, indexed by train-set size.
            method_data = dataset_data[dataset_data["method"] == method].set_index("size").drop(columns=["dataset", "method"])

            # Fill missing sizes with NaN so every curve spans the same x grid
            # (the baseline only has the single "all" row, so skip it).
            missing_sizes = set(sizes) - set(method_data.index) if method != "no_adaptation" else set()
            for size in missing_sizes:
                method_data.loc[size] = [np.nan, np.nan, np.nan]

            # Sort by size
            method_data = method_data.sort_index()

            # Plot
            if method == "no_adaptation":
                # Draw the baseline as a flat line only when its median falls
                # inside the y-range spanned by the adapted methods; otherwise
                # ("auto") fall back to a text annotation.
                if dataset_data.loc[:, "q1"].min() < method_data.loc["all", "median"] < dataset_data.loc[:, "median"].max() and no_adaptation in ["plot", "auto"]:
                    num_samples = compute_num_samples(sizes, dataset)
                    medians = [method_data.loc["all", "median"]] * len(num_samples)
                    q1 = [method_data.loc["all", "q1"]] * len(num_samples)
                    q3 = [method_data.loc["all", "q3"]] * len(num_samples)
                    kwargs = methods_config[method]
                    ax[i].plot(num_samples, medians, **kwargs)
                elif no_adaptation in ["text", "auto"]:
                    if modelname_noa is not None:
                        text = f"{methods_config['no_adaptation']['label']} ({modelname_noa})"
                    else:
                        text = f"{methods_config['no_adaptation']['label']}"
                    ax[i].text(.95, .95-pos,
                        f"{text} = {method_data.loc['all', 'median']:.2f}",
                        fontsize=fontsize_noa, ha="right", va="top", transform=ax[i].transAxes, color=methods_config[method]["color"]
                    )
                elif no_adaptation == "skip":
                    pass

            else:
                num_samples = compute_num_samples(method_data.index.astype(int), dataset)
                medians = method_data["median"]
                q1 = method_data["q1"]
                q3 = method_data["q3"]
                kwargs = methods_config[method]
                ax[i].plot(num_samples, medians, **kwargs)
                if intervals:
                    ax[i].fill_between(num_samples, q1, q3, alpha=0.3, color=kwargs["color"])

        # NOTE(review): ``num_samples`` leaks out of the method loop — if a
        # dataset had no plotted method this raises NameError; presumably
        # callers always supply at least one adapted method. Verify.
        ax[i].set_xscale("log")
        ax[i].set_xticks(num_samples)
        ax[i].set_xticklabels(num_samples, fontsize=18)
        ax[i].set_xlim([min(num_samples)*0.9, max(num_samples)*1.1])

        datasets_data[dataset] = dataset_data

    return datasets_data
122
+
123
+
124
def main(
    datasets,
    sizes,
    metrics,
    methods,
    methods_config,
    results_dir,
    output_path,
    intervals = False,
):
    """Render the metric-vs-samples figure (one row per metric, one column per
    dataset) and save it to ``output_path``.

    Reads ``results_dir/<metric>.jsonl`` per metric, writes the processed
    summary next to the figure as ``<metric>.csv``. List-like arguments arrive
    as whitespace-separated strings (python-fire CLI).
    """
    datasets = list(map(str, datasets.split()))
    sizes = list(map(int, sizes.split()))
    methods = list(map(str, methods.split()))
    methods_config = load_yaml(methods_config)
    metrics = list(map(str, metrics.split()))
    output_path = Path(output_path)
    output_dir = output_path.parent
    results_dir = Path(results_dir)

    fig, axs = plt.subplots(len(metrics), len(datasets), figsize=(6 * len(datasets), 12))
    processed_data = {}
    for ax, metric in zip(axs,metrics):
        data = pd.read_json(results_dir / f"{metric}.jsonl", orient='records', lines=True)
        processed_data[metric] = data  # NOTE(review): stores the *raw* frame; never read afterwards
        data = process_data(data, datasets, sizes, methods)
        datasets_data = plot_metric_vs_samples(ax, data, methods, methods_config, datasets, sizes, intervals=intervals, no_adaptation="auto")
        for i, dataset in enumerate(datasets):
            # y-limits span only the adapted methods (baseline excluded).
            min_y, max_y = datasets_data[dataset].loc[datasets_data[dataset]["method"].isin(set(methods) - {"no_adaptation"}),"median"].min(), datasets_data[dataset].loc[datasets_data[dataset]["method"].isin(set(methods) - {"no_adaptation"}),"median"].max()
            ax[i].set_ylim(min_y*0.99, max_y*1.01)
            ax[i].set_yticks(np.round(ax[i].get_yticks(),3))
            ax[i].set_yticklabels(ax[i].get_yticks(), fontsize=16)
            # ax[i].grid(axis="y")

        data.to_csv(output_dir / f"{metric}.csv", index=False)
        ax[0].set_ylabel(f"{metric2name[metric]}", fontsize=22)
    for j, dataset in enumerate(datasets):
        axs[0,j].set_title(DATASETS[dataset]["name"], fontsize=22)

    fig.text(0.5, 0.04, 'Number of train samples', ha='center', fontsize=22)
    # axs[0,-1].legend(loc="upper right", bbox_to_anchor=(2.4, 1), title="Method", title_fontsize=24, fontsize=22)

    # Gather handles and labels from all axes, de-duplicated by label, so the
    # shared figure legend lists each method exactly once.
    handles, labels = [], []
    for ax in axs.flat:
        hs, ls = ax.get_legend_handles_labels()
        for h, l in zip(hs, ls):
            if l not in labels:
                handles.append(h)
                labels.append(l)
    fig.legend(handles, labels, loc='lower center', bbox_to_anchor=(.5, -0.1), ncols=5, title="Method", title_fontsize=28, fontsize=26)

    plt.savefig(output_path, bbox_inches="tight", dpi=300)
    plt.close(fig)


if __name__ == '__main__':
    from fire import Fire
    Fire(main)
src/llmcal/scripts/run_posteriors.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ import torch
5
+ from pathlib import Path
6
+ from typing import Optional, Union, Literal
7
+ import warnings
8
+
9
+ from litgpt.tokenizer import Tokenizer
10
+ from litgpt.utils import (
11
+ check_valid_checkpoint_dir,
12
+ get_default_supported_precision,
13
+ check_nvlink_connectivity,
14
+ load_checkpoint
15
+ )
16
+ import lightning as L
17
+ from lightning_utilities.core.imports import RequirementCache
18
+ from lightning.fabric.plugins import BitsandbytesPrecision
19
+ from lightning.fabric.strategies import FSDPStrategy
20
+ from tqdm import tqdm
21
+
22
+ from ..src.utils import get_dataloader
23
+
24
+
25
def setup(
    base_checkpoint_dir: str,
    checkpoint_dir,
    data_path: str,
    output_dir: str,
    prediction_lists: Optional[str] = None,
    peft: Union[Literal["lora", "adapter"], None] = None,
    precision: Optional[str] = None,
    devices: Union[int, str] = 1,
    num_nodes: int = 1,
    batch_size: int = 1,
    max_seq_length: int = 1024,
    **peft_kwargs,
):
    """CLI entry point: configure Fabric and launch posterior computation.

    Args:
        base_checkpoint_dir: directory with the base ``lit_model.pth``.
        checkpoint_dir: directory with ``model_config.yaml`` and, for PEFT,
            the adapter/LoRA weights.
        data_path: dataset directory passed to ``get_dataloader``.
        output_dir: where ``logits.csv``/``labels.csv`` will be written.
        prediction_lists: comma-separated paths to index files; concatenated
            into one prediction list.
        peft: None (full model), "lora" or "adapter".
        peft_kwargs: forwarded to the PEFT ``Config``.
    """
    # Basic setup
    torch.set_float32_matmul_precision("high")
    data_path = Path(data_path)
    output_dir = Path(output_dir)
    base_checkpoint_dir = Path(base_checkpoint_dir)
    checkpoint_dir = Path(checkpoint_dir)
    # NOTE(review): default prediction_lists=None crashes here on .split(",")
    # — the argument is effectively required; confirm and drop the Optional.
    prediction_list = np.hstack([np.loadtxt(prediction_list, dtype=int) for prediction_list in prediction_lists.split(",")])

    # Load config file; which Config/Block flavor depends on the PEFT type.
    check_valid_checkpoint_dir(base_checkpoint_dir)
    if peft is None:
        from litgpt.config import Config
        from litgpt.model import Block
    elif peft == "lora":
        from litgpt.lora import Config, Block
    elif peft == "adapter":
        from litgpt.adapter import Config, Block
    else:
        raise ValueError(f"Unknown peft type: {peft}")
    config = Config.from_file(checkpoint_dir / "model_config.yaml", **peft_kwargs)

    # Precision
    precision = precision or get_default_supported_precision(training=True)

    # Strategy: plain DDP for multi-device inference, let Fabric decide otherwise.
    if devices * num_nodes > 1:
        strategy = "ddp"
    else:
        strategy = "auto"

    # Init fabric
    fabric = L.Fabric(
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
    )
    if torch.cuda.is_available() and devices > 1:
        check_nvlink_connectivity(fabric)

    # Launch the per-rank main() with all resolved arguments.
    fabric.launch(main, peft, config, base_checkpoint_dir, checkpoint_dir, data_path, output_dir, prediction_list, batch_size, max_seq_length)
81
+
82
+
83
def main(
    fabric: L.Fabric,
    peft,
    config,
    base_checkpoint_dir,
    checkpoint_dir,
    data_path,
    output_dir,
    prediction_list,
    batch_size,
    max_seq_length,
):
    """Per-rank worker: build the model, run prediction, dump CSVs.

    Writes ``logits.csv`` (per-class answer log-likelihoods) and
    ``labels.csv`` (reference labels), both indexed by sample id, on rank 0.
    """

    # Seed everything (fixed seed: predictions should not vary across runs)
    fabric.seed_everything(92837)

    # Pick the GPT implementation matching the PEFT type of the checkpoint.
    if peft is None:
        from litgpt.model import GPT
    elif peft == "lora":
        from litgpt.lora import GPT
    elif peft == "adapter":
        from litgpt.adapter import GPT
    else:
        raise ValueError(f"Unknown peft type: {peft}")
    checkpoint_path = base_checkpoint_dir / "lit_model.pth"
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
        model.max_seq_length = max_seq_length
        # KV cache is required because predict_step feeds the prompt and the
        # answers as separate forward passes with explicit input_pos.
        model.set_kv_cache(batch_size=batch_size, max_seq_length=max_seq_length)
    model = fabric.setup_module(model)
    load_checkpoint(fabric, model, checkpoint_path, strict=False)

    # Overlay PEFT weights on top of the base checkpoint.
    if peft == "lora":
        from litgpt.lora import merge_lora_weights
        lora_checkpoint_path = checkpoint_dir / "lit_model.pth.lora"
        load_checkpoint(fabric, model, lora_checkpoint_path, strict=False)
        merge_lora_weights(model)
    elif peft == "adapter":
        adapter_checkpoint_path = checkpoint_dir / "lit_model.pth.adapter"
        load_checkpoint(fabric, model, adapter_checkpoint_path, strict=False)

    # Load tokenizer
    tokenizer = Tokenizer(checkpoint_dir)

    # Predict over the requested sample indices only.
    dataloader = get_dataloader([data_path], [prediction_list], tokenizer, batch_size, pad_token_id=0, max_seq_length=model.max_seq_length, shuffle = False)
    dataloader = fabric.setup_dataloaders(dataloader)
    predictions = predict(fabric, model, dataloader)
    if fabric.global_rank == 0:
        pd.DataFrame(predictions["logits"], index=predictions["idx"]).to_csv(output_dir / f"logits.csv", index=True, header=False)
        pd.DataFrame(predictions["label"], index=predictions["idx"]).to_csv(output_dir / f"labels.csv", index=True, header=False)
135
+
136
+
137
def predict_step(fabric, model, indices, prompt_ids, prompt_mask, answers_ids, labels):
    """Score every candidate answer for each prompt in the micro-batch.

    The prompt is run once (filling the KV cache via ``input_pos``); each
    candidate answer is then scored by summing the log-probabilities of its
    tokens given the cached prompt. The per-sample "logits" are these summed
    answer log-likelihoods, one per class.
    """
    logits = []
    for input_ids, attention_mask, answers in zip(prompt_ids, prompt_mask, answers_ids):
        # Drop padding; keep only the real prompt tokens.
        input_ids = input_ids[attention_mask == 1].unsqueeze(0)
        T = torch.sum(attention_mask)
        with fabric.init_tensor():
            input_pos = torch.arange(0, T)
        output = model(idx=input_ids, input_pos=input_pos)
        answers_logits = []
        for answer in answers:
            answer = answer.unsqueeze(0)
            # Positions continue right after the cached prompt; each answer
            # reuses the same prompt cache (assumes positions >= T are simply
            # overwritten between answers — confirm with litgpt's KV cache).
            input_pos = torch.arange(T, answer.shape[1] + T, device=answer.device, dtype=answer.dtype)
            ans_out = model(idx=answer, input_pos=input_pos)
            # Shift by one: the first answer token is predicted from the last
            # prompt position, the rest from the preceding answer positions.
            logprobs = torch.cat([output[:,-1:,:], ans_out[:,:-1,:]], dim=1).log_softmax(dim=2)
            index = answer.unsqueeze(2)
            gather_probs = torch.gather(logprobs, -1, index).squeeze(2)
            ans_logit = gather_probs.sum()
            answers_logits.append(ans_logit)
        logits.append(torch.stack(answers_logits, dim=0))
    logits = torch.stack(logits, dim=0)
    return {"idx": indices, "logits": logits, "label": labels}
158
+
159
+
160
@torch.no_grad()
def predict(fabric, model, dataloader):
    """Run ``predict_step`` over the whole dataloader and gather results.

    Returns (on rank 0) a dict of numpy arrays: sample indices, per-class
    logits and reference labels. Non-zero ranks return the dict with its
    lists left empty.
    """
    predict_outputs = {"idx": [], "logits": [], "label": []}
    model.eval()
    for i, batch in enumerate(dataloader):
        # Progress line roughly every 2% of the epoch.
        if i % max(len(dataloader) // 50,1) == 0:
            fabric.print(f"Predicting batch {i+1}/{len(dataloader)}")
        outputs = predict_step(fabric, model, batch["idx"], batch["prompt_ids"], batch["prompt_mask"], batch["answers_ids"], batch["label"])
        fabric.barrier()
        gathered_outputs = fabric.all_gather(outputs)
        if fabric.global_rank == 0:
            for k, v in gathered_outputs.items():
                if fabric.world_size > 1:
                    # all_gather adds a leading world-size dimension; flatten
                    # it so all ranks' samples are interleaved in one axis.
                    v = v.view(-1, *v.shape[2:]).cpu()
                else:
                    v = v.cpu()
                if k in ["idx", "label"]:
                    predict_outputs[k].append(v.long())
                else:
                    predict_outputs[k].append(v.float())

    if fabric.global_rank == 0:
        for k, v in predict_outputs.items():
            predict_outputs[k] = torch.cat(v, dim=0).numpy()

    return predict_outputs


if __name__ == "__main__":
    from fire import Fire
    Fire(setup)
src/llmcal/scripts/train_lora.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import numpy as np
3
+ import torch
4
+ from pathlib import Path
5
+ from typing import Literal, Optional, Union
6
+ import warnings
7
+
8
+ from litgpt.tokenizer import Tokenizer
9
+ from litgpt.utils import (
10
+ check_valid_checkpoint_dir,
11
+ get_default_supported_precision,
12
+ check_nvlink_connectivity,
13
+ load_checkpoint,
14
+ CycleIterator,
15
+ )
16
+ from litgpt.lora import Config, Block, GPT, mark_only_lora_as_trainable, lora_filter
17
+ import lightning as L
18
+ from lightning_utilities.core.imports import RequirementCache
19
+ from lightning.fabric.utilities.load import _lazy_load as lazy_load
20
+ from lightning.fabric.plugins import BitsandbytesPrecision
21
+ from lightning.fabric.strategies import FSDPStrategy
22
+ from tqdm import tqdm
23
+
24
+ from ..src.utils import get_dataloader, save_yaml
25
+ from ..src.loggers import TBLogger, CSVLogger
26
+
27
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*Experiment logs directory outputs*")
28
+
29
def setup(
    base_checkpoint_dir: str,
    lora_checkpoint_dir: Optional[str] = None,
    data_paths: str = None,
    train_lists: str = None,
    val_lists: str = None,
    output_dir: str = None,
    output_checkpoint_dir: str = None,
    log_dir: str = None,
    precision: Optional[str] = None,
    quantize: Optional[Literal["bnb.nf4", "bnb.nf4-dq", "bnb.fp4", "bnb.fp4-dq", "bnb.int8-training"]] = None,
    devices: Union[int, str] = 1,
    num_nodes: int = 1,
    global_batch_size: int = 16,
    micro_batch_size: int = 1,
    val_check_interval = 16,
    learning_rate = 0.0001,
    optimizer: Literal["sgd", "adamw"] = "adamw",
    weight_decay = 0.0,
    loss: Literal["fs", "ans", "norm"] = "fs",
    patience: int = 10,
    max_steps: int = -1,
    seed = 0,
    max_seq_length: int = 1024,
    **lora_kwargs,
):
    """CLI entry point: configure Fabric and launch LoRA fine-tuning.

    BUGFIX: ``optimizer`` was previously declared as
    ``optimizer = Literal["sgd", "adamw"]`` — i.e. its *default value* was the
    typing object itself, so any invocation that omitted ``--optimizer``
    crashed later in ``main`` with "Unknown optimizer". It is now a proper
    annotation with a usable default ("adamw", matching the AdamW-scale
    learning-rate default); callers that pass the flag explicitly are
    unaffected.

    Args:
        base_checkpoint_dir: directory with the base ``lit_model.pth``.
        lora_checkpoint_dir: optional directory with LoRA weights to resume from.
        data_paths / train_lists / val_lists: comma-separated paths; when
            ``val_lists`` is omitted a small validation subset is sampled from
            each train list.
        loss: "fs" (full sentence), "ans" (answer only), "norm" or "norm-K"
            (discriminative over K sampled candidate answers).
        lora_kwargs: forwarded to the LoRA ``Config``.
    """

    # Basic setup
    torch.set_float32_matmul_precision("high")
    data_paths = [Path(data_path) for data_path in data_paths.split(",")]
    output_dir = Path(output_dir)
    output_checkpoint_dir = Path(output_checkpoint_dir)
    base_checkpoint_dir = Path(base_checkpoint_dir)
    lora_checkpoint_dir = Path(lora_checkpoint_dir) if lora_checkpoint_dir is not None else None
    train_lists = [np.loadtxt(train_list, dtype=int) for train_list in train_lists.split(",")]

    if val_lists is None:
        # No explicit validation split: sample up to 10 train indices per list.
        rs = np.random.RandomState(seed)
        val_lists = [rs.choice(train_list, min(len(train_list) // 10, 10), replace=False) for train_list in train_lists]
    else:
        val_lists = [np.loadtxt(val_list, dtype=int) for val_list in val_lists.split(",")]

    # Load config file
    check_valid_checkpoint_dir(base_checkpoint_dir)
    config = Config.from_file(base_checkpoint_dir / "model_config.yaml", **lora_kwargs)

    # Precision and quantization
    precision = precision or get_default_supported_precision(training=True)
    plugins = None
    if quantize is not None and quantize.startswith("bnb."):
        if "mixed" in precision:
            raise ValueError("Quantization and mixed precision is not supported.")
        if RequirementCache("bitsandbytes != 0.42.0"):
            warnings.warn(
                "LitGPT only supports bitsandbytes v0.42.0. "
                "This may result in errors when using quantization."
            )
        dtype = {"16-true": torch.float16, "bf16-true": torch.bfloat16, "32-true": torch.float32}[precision]
        plugins = BitsandbytesPrecision(quantize[4:], dtype)
        precision = None

    if devices * num_nodes > 1:
        if quantize:
            raise NotImplementedError(
                "Quantization is currently not supported for multi-GPU training. Please set devices=1 and num_nodes=1"
                " when using the --quantize flag."
            )
        strategy = FSDPStrategy(
            auto_wrap_policy={Block},
            activation_checkpointing_policy={Block},
            state_dict_type="full",
            limit_all_gathers=True,
            cpu_offload=False,
        )
    else:
        strategy = "auto"

    # Init fabric
    fabric = L.Fabric(
        devices=devices,
        num_nodes=num_nodes,
        strategy=strategy,
        precision=precision,
        plugins=plugins,
        loggers=[
            TBLogger(save_dir=log_dir),
            CSVLogger(save_dir=log_dir),
        ]
    )
    if torch.cuda.is_available() and devices > 1:
        check_nvlink_connectivity(fabric)

    # Launch: "norm-K" is normalized to loss="norm" plus its K.
    train_args = {
        "loss": "norm" if loss.startswith("norm-") else loss,
        "K": int(loss.split("-")[-1]) if loss.startswith("norm-") else None,
        "global_batch_size": global_batch_size,
        "micro_batch_size": micro_batch_size,
        "val_check_interval": val_check_interval,
        "learning_rate": learning_rate,
        "optimizer_name": optimizer,
        "weight_decay": weight_decay,
        "patience": patience,
        "max_steps": max_steps,
    }
    fabric.launch(main, config, base_checkpoint_dir, lora_checkpoint_dir, data_paths, output_dir, output_checkpoint_dir, train_lists, val_lists, train_args, devices, seed, max_seq_length)
135
+
136
+
137
def main(
    fabric: L.Fabric,
    config: Config,
    base_checkpoint_dir: Path,
    lora_checkpoint_dir: Optional[Path],
    data_paths: Path,
    output_dir: Path,
    output_checkpoint_dir: Path,
    train_lists: np.ndarray,
    val_lists: np.ndarray,
    train_args: dict,
    devices: int,
    seed: int,
    max_seq_length: int,
):
    """Per-rank worker: build data/model/optimizer, run ``fit``, save LoRA weights.

    On completion writes ``train_args.yaml`` (rank 0) and the trained LoRA
    parameters to ``output_checkpoint_dir/lit_model.pth.lora``.
    """
    fabric.seed_everything(seed)

    # Init dataloaders
    tokenizer = Tokenizer(base_checkpoint_dir)
    train_dataloader = get_dataloader(data_paths, train_lists, tokenizer, train_args["micro_batch_size"], 0, max_seq_length, shuffle = True, seed = seed)
    val_dataloader = get_dataloader(data_paths, val_lists, tokenizer, train_args["micro_batch_size"], 0, max_seq_length, shuffle = False, seed = seed)
    train_dataloader, val_dataloader = fabric.setup_dataloaders(train_dataloader, val_dataloader)

    # Init Base model; only LoRA parameters stay trainable.
    base_checkpoint_path = base_checkpoint_dir / "lit_model.pth"
    with fabric.init_module(empty_init=(fabric.world_size > 1)):
        model = GPT(config)
        model.max_seq_length = max_seq_length
        mark_only_lora_as_trainable(model)
    model = fabric.setup_module(model)

    # Init optimizer over the trainable (LoRA) parameters only.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if train_args["optimizer_name"] == "sgd":
        optimizer = torch.optim.SGD(trainable_params, lr=train_args["learning_rate"], weight_decay=train_args["weight_decay"])
    elif train_args["optimizer_name"] == "adamw":
        optimizer = torch.optim.AdamW(trainable_params, lr=train_args["learning_rate"], weight_decay=train_args["weight_decay"])
    else:
        raise ValueError(f"Unknown optimizer: {train_args['optimizer_name']}")
    optimizer = fabric.setup_optimizers(optimizer)

    # Load weights (base checkpoint, then optional LoRA warm start).
    load_checkpoint(fabric, model, base_checkpoint_path, strict=False)
    if lora_checkpoint_dir is not None:
        lora_checkpoint_path = lora_checkpoint_dir / "lit_model.pth.lora"
        load_checkpoint(fabric, model, lora_checkpoint_path, strict=False)

    # Train
    fit(fabric, model, optimizer, train_dataloader, val_dataloader, devices, output_dir, seed, **train_args)
    if fabric.device.type == "cuda":
        fabric.print(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
    fabric.print("Training finished.")

    if fabric.global_rank == 0:
        save_yaml(train_args, output_dir / "train_args.yaml")

    # Persist only the LoRA parameters.
    fabric.save(output_checkpoint_dir / "lit_model.pth.lora", {k: v for k, v in model.state_dict().items() if lora_filter(k,v)})
195
+
196
+
197
def fit(fabric, model, optimizer, train_dataloader, val_dataloader, devices, output_dir, seed, **train_args):
    """Gradient-accumulated training loop with early stopping and resumption.

    Resumes from ``output_dir/last.ckpt`` when present (replaying the data
    iterator up to the saved step). Saves ``best.ckpt`` whenever validation
    loss improves and ``last.ckpt`` after every optimizer step. Stops after
    ``max_steps`` steps if positive, otherwise when validation loss has not
    improved for ``patience`` checks.
    """

    if (state_dict_path := output_dir / "last.ckpt").exists():
        state = lazy_load(state_dict_path)
        model.load_state_dict(state["model"], strict=False) # only lora params are saved
    else:
        state = {
            "model": {k: v for k, v in model.state_dict().items() if lora_filter(k,v)},
            "step_count": 0,
            "iter_num": 0,
            "best_val_loss": float("inf"),
            "last_val_loss": float("inf"),
            "patience_count": 0,
            "cum_train_loss": 0,
            "cum_train_num_tokens": 0,
            "start_time": time.perf_counter(),
            "end_time": None,
        }

    train_iterator = CycleIterator(train_dataloader)
    rs = np.random.RandomState(seed)  # NOTE(review): unused — candidate for removal
    gradient_accumulation_iters = (train_args["global_batch_size"] // devices) // train_args["micro_batch_size"]

    # Define the loss
    if train_args["loss"] == "fs":
        loss_fn = FullSentenceLoss()
    elif train_args["loss"] == "ans":
        loss_fn = LossOnAnswer()
    elif train_args["loss"] == "norm":
        loss_fn = LossNormByAnswers(len(train_dataloader.dataset[0]["answers_ids"]), train_args["K"], seed)
    else:
        raise ValueError(f"Unknown loss: {train_args['loss']}")

    # Advance until state: replay the (seeded) data stream — and, for the
    # "norm" loss, its candidate-sampling RNG — so a resumed run sees the
    # same batches it would have seen without interruption.
    # NOTE(review): use_ids() is called with the whole batch label tensor
    # here, while forward() calls it per-sample — equivalent only for
    # micro_batch_size == 1; confirm.
    step_count = 0
    iter_num = 0
    while step_count < state["step_count"]:
        iter_num += 1
        batch = next(train_iterator)
        if train_args["loss"] == "norm":
            loss_fn.use_ids(batch["label"])
        is_accumulating = iter_num % gradient_accumulation_iters != 0
        if not is_accumulating:
            step_count += 1

    # Continue training
    model.train()
    stop_training = False
    while not stop_training:
        state["iter_num"] += 1
        batch = next(train_iterator)
        iter_t0 = time.perf_counter()

        # Perform forward and backward pass; gradients sync only on the
        # final micro-batch of each accumulation window.
        is_accumulating = state["iter_num"] % gradient_accumulation_iters != 0
        with fabric.no_backward_sync(model, enabled=is_accumulating):
            loss, num_tokens = loss_fn(model, batch["prompt_ids"], batch["prompt_mask"], batch["answers_ids"], batch["label"])
            fabric.backward(loss / num_tokens / gradient_accumulation_iters)

        # Accumulate loss for logging
        state["cum_train_loss"] += loss.item()
        state["cum_train_num_tokens"] += num_tokens

        # Perform optimizer step
        if not is_accumulating:
            optimizer.step()
            optimizer.zero_grad()
            state["step_count"] += 1

        # Log train loss (per-token average since the last checkpointed step)
        if not is_accumulating or state["iter_num"] == 1:
            t1 = time.perf_counter()
            metrics = {
                "train/loss": state["cum_train_loss"] / state["cum_train_num_tokens"],
                "iter": state["iter_num"],
                "step": state["step_count"],
                "epoch": train_iterator.epoch,
                "iter_time": t1 - iter_t0,
            }
            fabric.print(
                f"Epoch {metrics['epoch']+1} | iter {metrics['iter']} step {metrics['step']} |"
                f" train loss: {metrics['train/loss']:.3f},"
                f" val loss: {state['last_val_loss']:.3f} |"
                f" best val loss: {state['best_val_loss']:.3f} |"
                f" patience: {state['patience_count']} |"
                f" iter time: {metrics['iter_time'] * 1000:.2f} ms"
            )
            fabric.log_dict(metrics, step=state["step_count"])

        # Validate every val_check_interval optimizer steps.
        if not is_accumulating and state["step_count"] % train_args["val_check_interval"] == 0:
            val_loss, val_num_tokens = validate(fabric, model, val_dataloader, train_args, seed)
            state["last_val_loss"] = val_loss.item() / val_num_tokens
            fabric.log_dict({
                "val/loss": state["last_val_loss"],
            }, step=state["step_count"])
            if state["last_val_loss"] < state["best_val_loss"]:
                # New best: snapshot LoRA weights into the checkpoint state.
                state.update({
                    "model": {k: v for k, v in model.state_dict().items() if lora_filter(k,v)},
                    "end_time": time.perf_counter(),
                    "best_val_loss": state["last_val_loss"],
                    "patience_count": 0,
                })
                fabric.save(output_dir / "best.ckpt", state)
            else:
                state["patience_count"] += 1
            fabric.barrier()

        # Save last checkpoint after every optimizer step (enables resume).
        if not is_accumulating:
            state["model"] = {k: v for k, v in model.state_dict().items() if lora_filter(k,v)}
            state["end_time"] = time.perf_counter()
            fabric.save(output_dir / "last.ckpt", state)
            state["cum_train_loss"] = 0
            state["cum_train_num_tokens"] = 0

        # Check if training should stop
        if train_args["max_steps"] > 0:
            stop_training = state["step_count"] >= train_args["max_steps"]
        else:
            stop_training = state["patience_count"] >= train_args["patience"]
318
+
319
+
320
@torch.no_grad()
def validate(fabric, model, val_dataloader, train_args, seed):
    """Compute total validation loss and token count over the whole loader.

    Uses the same loss flavor as training; restores ``model.train()`` before
    returning. Returns (total_loss tensor, total token count) — the caller
    normalizes.
    """
    if train_args["loss"] == "fs":
        loss_fn = FullSentenceLoss()
    elif train_args["loss"] == "ans":
        loss_fn = LossOnAnswer()
    elif train_args["loss"] == "norm":
        loss_fn = LossNormByAnswers(len(val_dataloader.dataset[0]["answers_ids"]), train_args["K"], seed)
    else:
        # Mirror fit(): fail loudly instead of hitting a NameError on loss_fn
        # below (the original had no else-branch here).
        raise ValueError(f"Unknown loss: {train_args['loss']}")

    total_loss = 0
    total_num_tokens = 0
    model.eval()
    fabric.print("Validating...")
    for batch in val_dataloader:
        loss, num_tokens = loss_fn(model, batch["prompt_ids"], batch["prompt_mask"], batch["answers_ids"], batch["label"])
        total_loss += loss
        total_num_tokens += num_tokens
    model.train()
    return total_loss, total_num_tokens
339
+
340
class FullSentenceLoss(torch.nn.Module):
    """Negative log-likelihood over the whole sequence (prompt + gold answer).

    ``forward`` returns the summed NLL and the number of scored tokens so the
    caller can normalize.
    """

    def __init__(self):
        super().__init__()

    def forward(self, model, prompt_ids, prompt_mask, answers_ids, labels):
        total_nll = 0
        token_count = 0
        for ids, mask, answers, label in zip(prompt_ids, prompt_mask, answers_ids, labels):
            # Strip padding, then append the gold answer to the prompt.
            prompt = ids[mask == 1].unsqueeze(0)
            gold = answers[label.item()].unsqueeze(0)
            sequence = torch.cat([prompt, gold], dim=1)
            # Next-token log-probabilities for every position but the last.
            logprobs = model(sequence, None)[:, :-1, :].log_softmax(dim=2)
            targets = sequence[:, 1:].unsqueeze(2)
            token_logprobs = torch.gather(logprobs, -1, targets).squeeze(2)
            total_nll = total_nll - token_logprobs.sum()
            token_count = token_count + targets.size(1)
        return total_nll, token_count
356
+
357
class LossOnAnswer(torch.nn.Module):
    """Negative log-likelihood restricted to the gold answer tokens.

    Like ``FullSentenceLoss`` but the prompt tokens contribute no loss: only
    positions predicting answer tokens are scored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, model, prompt_ids, prompt_mask, answers_ids, labels):
        total_nll = 0
        token_count = 0
        for ids, mask, answers, label in zip(prompt_ids, prompt_mask, answers_ids, labels):
            prompt = ids[mask == 1].unsqueeze(0)
            gold = answers[label.item()].unsqueeze(0)
            sequence = torch.cat([prompt, gold], dim=1)
            prompt_len = prompt.shape[1]
            # Keep only the positions that predict answer tokens (the last
            # prompt position predicts the first answer token).
            logprobs = model(sequence, None)[:, prompt_len - 1:-1, :].log_softmax(dim=2)
            targets = sequence[:, prompt_len:].unsqueeze(2)
            token_logprobs = torch.gather(logprobs, -1, targets).squeeze(2)
            total_nll = total_nll - token_logprobs.sum()
            token_count = token_count + targets.size(1)
        return total_nll, token_count
373
+
374
+
375
class LossNormByAnswers(torch.nn.Module):
    """Discriminative loss: softmax over candidate-answer log-likelihoods.

    Each candidate answer is scored by the summed log-probability of its
    tokens given the prompt; a cross-entropy over these scores rewards the
    gold answer. With ``K`` set, only K-1 random negatives plus the gold
    answer are scored (others get -inf); with ``K=None`` all answers are used.
    """

    def __init__(self, total_num_answers, K = 5, seed = None):
        super().__init__()
        self.total_num_answers = total_num_answers
        self.K = K
        self._rs = np.random.RandomState(seed)

    def use_ids(self, label):
        """Return the candidate indices to score for one sample."""
        if self.K is None:
            return np.arange(self.total_num_answers)
        negatives = self._rs.choice(
            [i for i in range(self.total_num_answers) if i != label],
            min(self.K - 1, self.total_num_answers - 1),
            replace=False,
        )
        return np.hstack((negatives, [label.item()]))

    def forward(self, model, prompt_ids, prompt_mask, answers_ids, labels):
        total_loss = 0
        token_count = 0
        for ids, mask, answers, label in zip(prompt_ids, prompt_mask, answers_ids, labels):
            prompt = ids[mask == 1].unsqueeze(0)
            prompt_len = prompt.shape[1]
            selected = self.use_ids(label)
            candidate_scores = []
            for i, ans_ids in enumerate(answers):
                if i not in selected:
                    # Unselected candidates are masked out of the softmax.
                    candidate_scores.append(torch.tensor(-float("inf"), device=prompt.device))
                    continue
                sequence = torch.cat([prompt, ans_ids.unsqueeze(0)], dim=1)
                logprobs = model(sequence, None)[:, prompt_len - 1:-1, :].log_softmax(dim=2)
                targets = sequence[:, prompt_len:].unsqueeze(2)
                candidate_scores.append(torch.gather(logprobs, -1, targets).squeeze(2).sum())
            logits = torch.stack(candidate_scores, dim=0)
            token_count = token_count + answers[label.item()].size(0)
            total_loss = total_loss + torch.nn.functional.cross_entropy(logits.unsqueeze(0), label.unsqueeze(0), reduction="sum")
        return total_loss, token_count
414
+
415
+
416
if __name__ == '__main__':
    # CLI entry point: expose setup() via python-fire.
    from fire import Fire
    Fire(setup)
src/llmcal/src/__init__.py ADDED
File without changes
src/llmcal/src/evaluation/calibration.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import torch
4
+ from torch import nn
5
+ import torch.nn.functional as F
6
+ from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold, GroupKFold, KFold
7
+
8
+
9
class DPCalibrator(nn.Module):
    """Affine calibrator: a single scalar scale `alpha` (temperature-like)
    plus a per-class bias `beta`, applied to log-probabilities."""

    def __init__(self, n_classes):
        super().__init__()
        self.n_classes = n_classes
        # Shared scalar scale, initialized to identity (1.0).
        self.alpha = nn.Parameter(torch.tensor(1.0))
        # Per-class additive bias, initialized to zero.
        self.beta = nn.Parameter(torch.zeros(n_classes))

    def forward(self, x):
        # Affine map in score space; renormalization happens in calibrate().
        return self.alpha * x + self.beta

    def calibrate(self, logprobs):
        """Apply the fitted affine map and renormalize to log-probabilities."""
        self.eval()
        with torch.no_grad():
            cal_logprobs = torch.log_softmax(self(logprobs), dim=1)
        return cal_logprobs

    def fit(self, logprobs, labels):
        """Fit alpha/beta with LBFGS until the normalized CE plateaus.

        Convergence: stop when the change in cross-entropy normalized by the
        prior cross-entropy is below 1e-5. Returns self (fluent API).
        """
        self.train()
        optimizer = torch.optim.LBFGS(self.parameters(), lr=1e-1, max_iter=40)

        # Cross-entropy of the empirical class priors, used as the normalizer.
        priors = torch.bincount(labels, minlength=logprobs.shape[1]).float() / len(labels)
        priors_ce = -torch.log(priors[labels]).mean().item()

        last_nce = float("inf")
        while True:

            def closure():
                # LBFGS re-evaluates the loss/gradient via this closure.
                optimizer.zero_grad()
                cal_logits = self(logprobs)
                loss = F.cross_entropy(cal_logits, labels)
                loss.backward()
                return loss

            loss = optimizer.step(closure)

            nce = loss.item() / priors_ce
            if abs(last_nce - nce) < 1e-5:
                break
            last_nce = nce

        return self
51
+
52
+
53
def train_cal_on_test(logits, labels):
    """Calibrate `logits` with a DPCalibrator fitted on the same data.

    This is the "oracle" / train-on-test calibration used to estimate the
    best achievable calibrated loss.

    Args:
        logits: (N, C) numpy array of uncalibrated scores.
        labels: (N,) numpy array of integer class labels.
    Returns:
        (N, C) numpy array of calibrated log-probabilities.
    """
    logprobs = torch.log_softmax(torch.from_numpy(logits).float(), dim=1)
    targets = torch.from_numpy(labels).long()
    calibrator = DPCalibrator(n_classes=logits.shape[1]).fit(logprobs, targets)
    return calibrator.calibrate(logprobs).numpy()
60
+
61
+
62
def calibrate_xval(logits, targets, seed=0, condition_ids=None, stratified=True, nfolds=5):
    """Calibrate `logits` via k-fold cross-validation.

    Each fold's samples are calibrated by a DPCalibrator fitted on the
    remaining folds, so no sample is calibrated by a model trained on it.

    Args:
        logits: (N, C) numpy array of uncalibrated scores.
        targets: (N,) numpy array of integer labels.
        seed: RNG seed for fold shuffling.
        condition_ids: optional group ids; if given, a group-aware splitter
            keeps all samples of a group in the same fold.
        stratified: preserve label proportions across folds.
        nfolds: number of folds.
    Returns:
        (N, C) torch tensor of calibrated log-probabilities.
    """
    logprobs = torch.log_softmax(torch.from_numpy(logits).float(), dim=1)
    targets = torch.from_numpy(targets).long()
    logprobscal = torch.zeros(logprobs.size())

    # Pick the sklearn splitter matching the stratification/grouping request.
    if stratified:
        if condition_ids is not None:
            skf = StratifiedGroupKFold(n_splits=nfolds, shuffle=True, random_state=seed)
        else:
            skf = StratifiedKFold(n_splits=nfolds, shuffle=True, random_state=seed)
    else:
        if condition_ids is not None:
            # NOTE(review): GroupKFold ignores `seed` (it has no shuffle) —
            # confirm this asymmetry with the other branches is intended.
            skf = GroupKFold(n_splits=nfolds)
        else:
            skf = KFold(n_splits=nfolds, shuffle=True, random_state=seed)

    for trni, tsti in skf.split(logprobs, targets, condition_ids):
        # Fit on the training folds, calibrate the held-out fold in place.
        model = DPCalibrator(n_classes=logprobs.shape[1])
        model.fit(logprobs[trni], targets[trni])
        with torch.no_grad():
            logprobscal[tsti] = torch.log_softmax(model.forward(logprobs[tsti]), dim=1)

    return logprobscal
src/llmcal/src/evaluation/metrics.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ from scipy.special import softmax, log_softmax
4
+ from .calibration import train_cal_on_test, calibrate_xval
5
+
6
def compute_ner(logits, labels):
    """Normalized error rate: argmax error rate divided by the error rate of
    the majority-class ("naive") predictor."""
    predictions = logits.argmax(axis=1)
    error_rate = np.mean(predictions != labels)
    majority_class = np.bincount(labels, minlength=logits.shape[1]).argmax()
    baseline_error = np.mean(labels != majority_class)
    return error_rate / baseline_error
11
+
12
def compute_nce(logits, labels):
    """Normalized cross-entropy: CE of the (softmaxed) logits divided by the
    CE of the empirical class priors."""
    logprobs = log_softmax(logits, axis=1)
    ce = -logprobs[np.arange(labels.size), labels].mean()
    priors = np.bincount(labels, minlength=logits.shape[1]) / labels.size
    baseline_ce = -np.log(priors[labels]).mean()
    return ce / baseline_ce
17
+
18
def compute_nbrier(logits, labels):
    """Normalized Brier score: Brier of the softmaxed logits divided by the
    Brier score of always predicting the empirical class priors."""
    n = len(labels)
    one_hot = np.zeros(logits.shape)
    one_hot[np.arange(n), labels] = 1
    probs = softmax(logits, axis=1)
    brier = np.mean((one_hot - probs) ** 2)
    priors = np.bincount(labels, minlength=logits.shape[1]) / n
    baseline_brier = np.mean((one_hot - priors) ** 2)
    return brier / baseline_brier
25
+
26
def compute_cal_loss(logits, labels, mode="trainontest", metric="nce"):
    """Relative improvement in `metric` obtained by post-hoc calibration.

    Args:
        logits: (N, C) numpy array of scores.
        labels: (N,) numpy array of labels.
        mode: "trainontest" (oracle calibration) or "xval" (cross-validated).
        metric: name understood by compute_metric (e.g. "nce").
    Returns:
        (raw - calibrated) / raw, i.e. the fraction of the loss removable by
        calibration.
    Raises:
        ValueError: if `mode` is not recognized.
    """
    if mode == "trainontest":
        cal_logprobs = train_cal_on_test(logits, labels)
    elif mode == "xval":
        cal_logprobs = calibrate_xval(logits, labels, seed=1234, condition_ids=None, stratified=True, nfolds=5)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    raw_loss = compute_metric(logits, labels, metric)
    calibrated_loss = compute_metric(cal_logprobs, labels, metric)
    return (raw_loss - calibrated_loss) / raw_loss
36
+
37
def compute_ece(logits, labels):
    """Expected Calibration Error with 10 equal-width confidence bins.

    ECE = sum over bins of |mean confidence - accuracy| weighted by the
    fraction of samples in the bin.

    Fix: bins are half-open intervals (lower, upper]. The original used two
    strict inequalities, so a sample whose confidence sits exactly on a bin
    boundary — in particular confidence exactly 1.0 — was never counted in
    any bin and silently dropped from the ECE.

    Args:
        logits: (N, C) numpy array of scores.
        labels: (N,) numpy array of integer labels.
    Returns:
        float ECE in [0, 1].
    """
    n_bins = 10
    bin_boundaries = np.linspace(0, 1, n_bins + 1)
    bin_lowers = bin_boundaries[:-1]
    bin_uppers = bin_boundaries[1:]

    softmaxes = softmax(logits, axis=1)
    confidences = softmaxes.max(axis=1)
    predictions = softmaxes.argmax(axis=1)
    accuracies = predictions == labels

    ece = 0
    for bin_lower, bin_upper in zip(bin_lowers, bin_uppers):
        # Half-open bin (lower, upper] so boundary confidences are binned.
        in_bin = (confidences > bin_lower) & (confidences <= bin_upper)
        prop_in_bin = in_bin.mean()
        if prop_in_bin > 0:
            accuracy_in_bin = accuracies[in_bin].mean()
            avg_confidence_in_bin = confidences[in_bin].mean()
            ece += np.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin
    return ece
57
+
58
def compute_metric(logits, labels, metric):
    """Dispatch to the metric implementation named by `metric`.

    Plain metrics: "ner", "nce", "nbrier", "ece". Calibration-loss metrics
    are encoded as "calloss_<psr>_<mode>" (e.g. "calloss_nce_xval").

    Raises:
        ValueError: if `metric` is not recognized.
    """
    if "calloss" in metric:
        _, base_metric, mode = metric.split("_")
        return compute_cal_loss(logits, labels, mode, base_metric)
    dispatch = {
        "ner": compute_ner,
        "nce": compute_nce,
        "nbrier": compute_nbrier,
        "ece": compute_ece,
    }
    if metric not in dispatch:
        raise ValueError(f"Unknown metric: {metric}")
    return dispatch[metric](logits, labels)
72
+
73
+
74
def compute_psr_with_mincal(logits, labels, psr, mode):
    """Compute a proper scoring rule before and after post-hoc calibration.

    Args:
        logits: (N, C) numpy array of scores.
        labels: (N,) numpy array of labels.
        psr: metric name understood by compute_metric (e.g. "nce").
        mode: "trainontest", "xval", or "none" (no calibration; the raw
            scores are reused as-is).
    Returns:
        (loss, cal_loss) tuple of the raw and calibrated metric values.
    Raises:
        ValueError: if `mode` is not recognized.
    """
    if mode == "none":
        cal_logprobs = logits
    elif mode == "trainontest":
        cal_logprobs = train_cal_on_test(logits, labels)
    elif mode == "xval":
        cal_logprobs = calibrate_xval(logits, labels, seed=1234, condition_ids=None, stratified=True, nfolds=5)
    else:
        raise ValueError(f"Unknown mode: {mode}")
    loss = compute_metric(logits, labels, psr)
    cal_loss = compute_metric(cal_logprobs, labels, psr)
    return loss, cal_loss
src/llmcal/src/loggers.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import os
4
+ from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger as _CSVLogger
5
+ import pandas as pd
6
+ from .utils import save_yaml
7
+
8
class TBLogger(TensorBoardLogger):
    """TensorBoard logger addressed by a single run directory.

    Lightning's TensorBoardLogger wants (save_dir, name, version); this
    wrapper splits a full path "<parent>/<leaf>" into those pieces so logs
    land exactly in `save_dir`, and additionally persists hyperparams to a
    YAML file in the run directory.
    """

    def __init__(self, save_dir):
        # Split the path into parent dir and final component (the "version").
        _save_dir = "/".join(save_dir.split("/")[:-1])
        _version = save_dir.split("/")[-1]
        super().__init__(
            save_dir=_save_dir,
            name="",
            version=_version,
            log_graph=False,
            default_hp_metric=False,
            prefix="",
            sub_dir=None,
        )

    def log_hyperparams(self, hyperparams, metrics = None):
        # Log to TensorBoard as usual, then also dump a human-readable YAML
        # copy next to the event files.
        super().log_hyperparams(hyperparams, metrics)
        save_yaml(hyperparams, os.path.join(self.log_dir, "hyperparams.yaml"))
26
+
27
+
28
class CSVLogger(_CSVLogger):
    """CSV logger addressed by a single run directory, resume-friendly.

    Splits "<parent>/<leaf>" into Lightning's (save_dir, version) pair, and
    if a metrics.csv already exists in the run directory, preloads it so a
    resumed run appends to (instead of overwriting) previous metrics.
    """

    def __init__(self, save_dir):
        # Split the path into parent dir and final component (the "version").
        _save_dir = "/".join(save_dir.split("/")[:-1])
        _version = save_dir.split("/")[-1]
        super().__init__(
            save_dir=_save_dir,
            name="",
            version=_version,
            prefix="",
            flush_logs_every_n_steps=1,
        )
        # Resume support: seed the experiment's in-memory metrics with any
        # rows already on disk so they are not lost on the next flush.
        if os.path.exists(os.path.join(self.log_dir, "metrics.csv")):
            self.experiment.metrics = pd.read_csv(os.path.join(self.log_dir, "metrics.csv")).to_dict(orient="records")
src/llmcal/src/prompts/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .llama3 import Llama3Prompt
2
+ from .phi import Phi3Prompt
3
+ from .tinyllama import TinyLlamaPrompt
4
+ from .pythia import PythiaPrompt
5
+ from .gemma import GemmaPrompt
6
+ from .qwen import QwenPrompt
src/llmcal/src/prompts/gemma.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
class GemmaPrompt:
    """Builds Gemma chat-formatted prompts, optionally few-shot.

    `fit` assembles a template containing an "{inpt}" placeholder; `apply`
    fills it with (truncated) input texts.
    """

    def __init__(self, max_characters=400):
        # Input and shot texts are truncated to this many characters.
        self.max_characters = max_characters
        # Template with the "{inpt}" placeholder; set by fit().
        self.prompt = None

    def apply(self, text):
        """Return one filled prompt per input text."""
        limit = self.max_characters
        return [self.prompt.replace("{inpt}", t[:limit]) for t in text]

    def fit(self, prompt_template, shots):
        """Assemble the template from the instruction and few-shot examples."""
        preface = (
            "<start_of_turn>user\n"
            f"{prompt_template}\n\n"
        )
        output_preface = (
            "{inpt}<end_of_turn>\n"
            "<start_of_turn>model\n"
        )

        if len(shots) == 0:
            self.prompt = preface + output_preface
            return self

        shot_template = (
            "{shot_inpt}<end_of_turn>\n"
            "<start_of_turn>model\n{shot_label}<end_of_turn>\n<start_of_turn>user\n"
        )
        rendered_shots = "".join(
            shot_template.format(shot_inpt=shot["text"][:self.max_characters], shot_label=shot["label"])
            for shot in shots
        )
        self.prompt = preface + rendered_shots + output_preface
        return self
src/llmcal/src/prompts/llama3.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
class Llama3Prompt:
    """Builds Llama-3 chat-formatted prompts, optionally few-shot.

    `fit` assembles a template containing an "{inpt}" placeholder; `apply`
    fills it with (truncated) input texts.
    """

    def __init__(self, max_characters=400):
        # Input and shot texts are truncated to this many characters.
        self.max_characters = max_characters
        # Template with the "{inpt}" placeholder; set by fit().
        self.prompt = None

    def apply(self, text):
        """Return one filled prompt per input text."""
        limit = self.max_characters
        return [self.prompt.replace("{inpt}", t[:limit]) for t in text]

    def fit(self, prompt_template, shots):
        """Assemble the template from the instruction and few-shot examples."""
        preface = (
            "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
            f"{prompt_template}<|eot_id|>"  # No newline
        )
        output_preface = (
            "<|start_header_id|>user<|end_header_id|>\n\n{inpt}<|eot_id|>"  # No newline
            "<|start_header_id|>assistant<|end_header_id|>\n\n"
        )

        if len(shots) == 0:
            self.prompt = preface + output_preface
            return self

        shot_template = (
            "<|start_header_id|>user<|end_header_id|>\n\n{shot_inpt}<|eot_id|>"  # No newline
            "<|start_header_id|>assistant<|end_header_id|>\n\n{shot_label}<|eot_id|>"  # No newline
        )
        rendered_shots = "".join(
            shot_template.format(shot_inpt=shot["text"][:self.max_characters], shot_label=shot["label"])
            for shot in shots
        )
        self.prompt = preface + rendered_shots + output_preface
        return self
src/llmcal/src/prompts/phi.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
class Phi3Prompt:
    """Builds Phi-3 chat-formatted prompts, optionally few-shot.

    Fix: `apply` substitutes the "{inpt}" placeholder with `str.replace`
    instead of `str.format`. The old `.format` call raised (KeyError /
    IndexError / ValueError) whenever the input text — or a few-shot text or
    label already inserted into the template by `fit` — contained literal
    braces. This also matches the sibling Gemma/Llama3/Qwen prompt classes.
    """

    def __init__(self, max_characters=400):
        # Input and shot texts are truncated to this many characters.
        self.max_characters = max_characters
        # Template with the "{inpt}" placeholder; set by fit().
        self.prompt = None

    def apply(self, text):
        """Return one filled prompt per input text."""
        filled_prompts = []
        for t in text:
            # replace, not format: robust to braces in the surrounding text.
            filled_prompts.append(self.prompt.replace("{inpt}", t[:self.max_characters]))
        return filled_prompts

    def fit(self, prompt_template, shots):
        """Assemble the template from the instruction and few-shot examples."""
        preface = (
            f'<|system|>\n{prompt_template}<|end|>\n'
        )
        output_preface = (
            "<|user|>\n{inpt}<|end|>\n"
            "<|assistant|>\n"
        )

        if len(shots) == 0:
            self.prompt = preface + output_preface
            return self

        shot_template = (
            "<|user|>\n{shot_inpt}<|end|>\n"
            "<|assistant|>\n{shot_label}<|end|>\n"
        )
        shots_prompt = ""
        for shot in shots:
            shots_prompt += shot_template.format(shot_inpt=shot["text"][:self.max_characters], shot_label=shot["label"])
        self.prompt = preface + shots_prompt + output_preface

        return self
src/llmcal/src/prompts/pythia.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
class PythiaPrompt:
    """Builds chat-formatted prompts for Pythia, optionally few-shot.

    NOTE(review): this template reuses the Phi-style <|system|>/<|user|>/
    <|assistant|>/<|end|> tokens — confirm that is intentional for Pythia.

    Fix: `apply` substitutes the "{inpt}" placeholder with `str.replace`
    instead of `str.format`. The old `.format` call raised whenever the
    input text — or a few-shot text or label already inserted into the
    template by `fit` — contained literal braces. This also matches the
    sibling Gemma/Llama3/Qwen prompt classes.
    """

    def __init__(self, max_characters=400):
        # Input and shot texts are truncated to this many characters.
        self.max_characters = max_characters
        # Template with the "{inpt}" placeholder; set by fit().
        self.prompt = None

    def apply(self, text):
        """Return one filled prompt per input text."""
        filled_prompts = []
        for t in text:
            # replace, not format: robust to braces in the surrounding text.
            filled_prompts.append(self.prompt.replace("{inpt}", t[:self.max_characters]))
        return filled_prompts

    def fit(self, prompt_template, shots):
        """Assemble the template from the instruction and few-shot examples."""
        preface = (
            f'<|system|>\n{prompt_template}<|end|>\n'
        )
        output_preface = (
            "<|user|>\n{inpt}<|end|>\n"
            "<|assistant|>\n"
        )

        if len(shots) == 0:
            self.prompt = preface + output_preface
            return self

        shot_template = (
            "<|user|>\n{shot_inpt}<|end|>\n"
            "<|assistant|>\n{shot_label}<|end|>\n"
        )
        shots_prompt = ""
        for shot in shots:
            shots_prompt += shot_template.format(shot_inpt=shot["text"][:self.max_characters], shot_label=shot["label"])
        self.prompt = preface + shots_prompt + output_preface

        return self
src/llmcal/src/prompts/qwen.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
class QwenPrompt:
    """Builds Qwen (ChatML) formatted prompts, optionally few-shot.

    `fit` assembles a template containing an "{inpt}" placeholder; `apply`
    fills it with (truncated) input texts.
    """

    def __init__(self, max_characters=400):
        # Input and shot texts are truncated to this many characters.
        self.max_characters = max_characters
        # Template with the "{inpt}" placeholder; set by fit().
        self.prompt = None

    def apply(self, text):
        """Return one filled prompt per input text."""
        limit = self.max_characters
        return [self.prompt.replace("{inpt}", t[:limit]) for t in text]

    def fit(self, prompt_template, shots):
        """Assemble the template from the instruction and few-shot examples."""
        preface = (
            f"<|im_start|>system\n{prompt_template}<|im_end|>\n"
        )
        output_preface = (
            "<|im_start|>user\n{inpt}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )

        if len(shots) == 0:
            self.prompt = preface + output_preface
            return self

        shot_template = (
            "<|im_start|>user\n{shot_inpt}<|im_end|>\n"
            "<|im_start|>assistant\n{shot_label}<|im_end|>\n"
        )
        rendered_shots = "".join(
            shot_template.format(shot_inpt=shot["text"][:self.max_characters], shot_label=shot["label"])
            for shot in shots
        )
        self.prompt = preface + rendered_shots + output_preface
        return self
src/llmcal/src/prompts/tinyllama.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+
4
class TinyLlamaPrompt:
    """Builds TinyLlama (Zephyr-style) chat prompts, optionally few-shot.

    Fix: `apply` substitutes the "{inpt}" placeholder with `str.replace`
    instead of `str.format`. The old `.format` call raised whenever the
    input text — or a few-shot text or label already inserted into the
    template by `fit` — contained literal braces. This also matches the
    sibling Gemma/Llama3/Qwen prompt classes.
    """

    def __init__(self, max_characters=400):
        # Input and shot texts are truncated to this many characters.
        self.max_characters = max_characters
        # Template with the "{inpt}" placeholder; set by fit().
        self.prompt = None

    def apply(self, text):
        """Return one filled prompt per input text."""
        filled_prompts = []
        for t in text:
            # replace, not format: robust to braces in the surrounding text.
            filled_prompts.append(self.prompt.replace("{inpt}", t[:self.max_characters]))
        return filled_prompts

    def fit(self, prompt_template, shots):
        """Assemble the template from the instruction and few-shot examples."""
        preface = (
            "<|system|>\n"
            f"{prompt_template}</s>\n"
        )
        output_preface = (
            "<|user|>\n{inpt}</s>\n"
            "<|assistant|>\n"
        )

        if len(shots) == 0:
            self.prompt = preface + output_preface
            return self

        shot_template = (
            "<|user|>\n{shot_inpt}</s>\n"
            "<|assistant|>\n{shot_label}</s>\n"
        )
        shots_prompt = ""
        for shot in shots:
            shots_prompt += shot_template.format(shot_inpt=shot["text"][:self.max_characters], shot_label=shot["label"])
        self.prompt = preface + shots_prompt + output_preface

        return self
40
+
src/llmcal/src/utils.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from functools import partial
3
+ from typing import Dict, List
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.utils.data import Dataset, DataLoader
9
+ import pandas as pd
10
+
11
+ import yaml
12
+
13
+ from litgpt import Tokenizer
14
+
15
def load_yaml(path):
    """Load a YAML file; an empty document yields {} instead of None."""
    with open(path, "r") as f:
        content = yaml.safe_load(f)
    return {} if content is None else content
21
+
22
def save_yaml(data: dict, path) -> None:
    """Serialize `data` as YAML to the file at `path` (overwrites)."""
    with open(path, "w") as handle:
        yaml.dump(data, handle)
25
+
26
class JSONDataset(Dataset):
    """Dataset of pre-tokenized prompts and candidate answers from JSONL files.

    Each path in `paths` is a JSON-lines file whose rows have at least
    `idx`, `prompt`, `answer` (a list of candidate answer strings) and
    `label`; `lsts` gives, per file, which `idx` values to keep and in what
    order. Rows from the i-th file are tagged task_id=i.
    """

    def __init__(self, paths, lsts, tokenizer):
        self.lsts = lsts
        self.paths = paths
        self.tokenizer = tokenizer  # litgpt Tokenizer
        data = []
        for i, (path, lst) in enumerate(zip(paths, lsts)):
            # Select and order rows according to the requested idx list.
            d = pd.read_json(path, lines=True).set_index("idx").loc[lst].reset_index(drop=False)
            d["task_id"] = i
            d = d.apply(self._transform, axis=1)
            data.append(d)
        self.data = pd.concat(data, ignore_index=False)

    def _transform(self, sample):
        # Tokenize the prompt with BOS; tokenize each candidate answer with
        # BOS then drop it ([1:]) so answers concatenate cleanly after the
        # prompt at scoring time.
        idx = torch.tensor(sample["idx"], dtype=torch.long)
        prompt_ids = self.tokenizer.encode(sample["prompt"], bos=True).long()
        answers_ids = [self.tokenizer.encode(ans, bos=True)[1:].long() for ans in sample["answer"]]
        label = torch.tensor(sample["label"], dtype=torch.long)
        task_id = torch.tensor(sample["task_id"], dtype=torch.long)
        return pd.Series({"idx": idx, "prompt_ids": prompt_ids, "answers_ids": answers_ids, "label": label, "task_id": task_id})

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Positional access; returns a dict of tensors (see _transform).
        return self.data.iloc[idx].to_dict()
53
+
54
class Collator:
    """Batch collator that LEFT-pads prompt token sequences.

    The common prompt length is capped so that the longest candidate answer
    in the batch still fits within `max_seq_len`; prompts longer than the
    cap are truncated from the left (keeping their most recent tokens).
    Answers are passed through as ragged lists of tensors.
    """

    def __init__(self, pad_token_id, max_seq_len):
        self.pad_token_id = pad_token_id
        self.max_seq_len = max_seq_len

    def __call__(self, batch):
        longest_answer = max(max(ans.shape[0] for ans in sample["answers_ids"]) for sample in batch)
        longest_prompt = max(sample["prompt_ids"].shape[0] for sample in batch)
        # Leave room for the longest answer within the sequence budget.
        prompt_len = min(self.max_seq_len - longest_answer, longest_prompt)

        padded_prompts, masks, answers = [], [], []
        for sample in batch:
            seq = sample["prompt_ids"][-prompt_len:]  # keep the rightmost tokens
            n_pad = prompt_len - seq.shape[0]
            padded_prompts.append(torch.cat([torch.full((n_pad,), self.pad_token_id, dtype=torch.long), seq]))
            masks.append(torch.cat([torch.zeros(n_pad, dtype=torch.long), torch.ones(seq.shape[0], dtype=torch.long)]))
            answers.append(sample["answers_ids"])

        return {
            "idx": torch.stack([sample["idx"] for sample in batch]),
            "prompt_ids": torch.stack(padded_prompts),
            "prompt_mask": torch.stack(masks),
            "answers_ids": answers,
            "task_id": torch.stack([sample["task_id"] for sample in batch]),
            "label": torch.stack([sample["label"] for sample in batch]),
        }
81
+
82
+
83
def get_dataloader(data_paths, lsts, tokenizer, batch_size = 1, pad_token_id = 0, max_seq_length = 2048, shuffle = False, seed = 42):
    """Build a DataLoader over a JSONDataset with left-padding collation.

    Args:
        data_paths: JSONL files to read.
        lsts: per-file lists of `idx` values to keep (see JSONDataset).
        tokenizer: litgpt tokenizer used to encode prompts/answers.
        batch_size, pad_token_id, max_seq_length, shuffle: usual knobs.
        seed: seeds the shuffling generator for reproducibility.
    """
    collate = Collator(pad_token_id=pad_token_id, max_seq_len=max_seq_length)
    return DataLoader(
        JSONDataset(data_paths, lsts, tokenizer),
        batch_size=batch_size,
        collate_fn=collate,
        shuffle=shuffle,
        generator=torch.Generator().manual_seed(seed),
    )
src/llmcal/tests/__init__.py ADDED
File without changes
src/llmcal/tests/check_lists.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import numpy as np
3
+ import pandas as pd
4
+ from pathlib import Path
5
+ import yaml
6
+ from tqdm import tqdm
7
+
8
+ DATASETS = {"sst2": 2, "agnews": 4, "dbpedia": 14, "20newsgroups": 20, "banking77": 77}
9
+ N_SHOTS = [0, 1, 2, 4, 8, 16, 32, 64]
10
+ N_SEEDS = 5
11
+ FACTORS = [8, 32, 64, 128, 256, 512]
12
+ VAL_PROPS = [0.0, 0.3]
13
+ TEST_SAMPLES = {"sst2": 400, "agnews": 400, "dbpedia": 700, "20newsgroups": 800, "banking77": 1000}
14
+
15
def main():
    """Sanity-check the index lists under lists/<dataset>/ for every dataset.

    Verifies that the train/test/test_N lists reference existing rows with no
    duplicates and no train/test overlap, and that the per-(size, valprop,
    seed) matched/mismatched YAML lists have the expected sizes and
    disjointness. Fails with AssertionError on the first inconsistency.
    """
    for dataset in tqdm(DATASETS):
        num_classes = DATASETS[dataset]
        for factor in FACTORS:
            # Convert the "factor" budget into a sample count: samples per
            # class rounded to the nearest power of two, times num_classes.
            scale = factor / np.log2(num_classes)
            nearest_power_of_2 = 2 ** np.round(np.log2(scale)) # round to nearest power of 2
            num_samples = int(nearest_power_of_2 * num_classes)

            # Read data
            data = pd.read_csv(f"data/{dataset}/all.csv")

            # check train, test and test_nsamples lists are ok:
            # every listed index must exist in the data, with no duplicates.
            train_list = np.loadtxt(f"lists/{dataset}/train.txt")
            assert data.index.isin(train_list).sum() == len(train_list) and np.unique(train_list).size == len(train_list)

            test_list = np.loadtxt(f"lists/{dataset}/test.txt")
            assert data.index.isin(test_list).sum() == len(test_list) and np.unique(test_list).size == len(test_list)

            test_nsamples_list = np.loadtxt(f"lists/{dataset}/test_{TEST_SAMPLES[dataset]}.txt")
            assert data.index.isin(test_nsamples_list).sum() == len(test_nsamples_list) and np.unique(test_nsamples_list).size == len(test_nsamples_list)

            # Check no overlap between train and test, and train and test_nsamples
            assert len(np.intersect1d(train_list, test_list)) == 0
            assert len(np.intersect1d(train_list, test_nsamples_list)) == 0

            for valprop in VAL_PROPS:
                for seed in range(N_SEEDS):
                    with open(f"lists/{dataset}/size={factor}/valprop={valprop}/seed={seed}/matched.yaml", 'r') as file:
                        matched = yaml.load(file, Loader=yaml.FullLoader)

                    # Matched split: sizes follow valprop; the val split must
                    # be disjoint from train and test. When valprop == 0 the
                    # val list is expected to be a copy of the train list.
                    val_size = int(valprop * num_samples)
                    train_size = num_samples - val_size
                    assert len(matched["train"][dataset]) == train_size
                    if val_size > 0:
                        assert len(matched["val"][dataset]) == val_size
                        assert not np.isin(matched["val"][dataset], matched["train"][dataset]).any()
                        assert not np.isin(matched["val"][dataset], test_list).any()
                    else:
                        assert np.isin(matched["val"][dataset], matched["train"][dataset]).all()

                    with open(f"lists/{dataset}/size={factor}/valprop={valprop}/seed={seed}/mismatched.yaml", 'r') as file:
                        mismatched = yaml.load(file, Loader=yaml.FullLoader)

                    # Mismatched split must never train on the eval dataset.
                    assert all([train_dataset != dataset for train_dataset in mismatched["train"]])


    print("All lists are ok!")

if __name__ == '__main__':
    main()