DelaliScratchwerk commited on
Commit
1c00c7b
·
verified ·
1 Parent(s): c2e4dfb

Create tune_thresholds.py

Browse files
Files changed (1) hide show
  1. tune_thresholds.py +26 -0
tune_thresholds.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np, json
2
+ from datasets import load_dataset
3
+ from setfit import SetFitModel
4
+
5
+ LABELS = json.load(open("labels.json")) # same list you train with
6
+ name2id = {n:i for i,n in enumerate(LABELS)}
7
+
8
+ ds = load_dataset("json", data_files={"val":"val.jsonl"})["val"]
9
+ y_true = np.array([name2id[x["label"]] for x in ds])
10
+
11
+ m = SetFitModel.from_pretrained("DelaliScratchwerk/text-period-setfit")
12
+ probs = np.stack([m.predict_proba([x["text"]])[0] for x in ds])
13
+ top1 = probs.max(axis=1)
14
+ sorted_probs = np.sort(probs, axis=1)[:, ::-1]
15
+ margin = sorted_probs[:,0] - sorted_probs[:,1]
16
+
17
+ def describe(arr, name):
18
+ q = np.quantile(arr, [0.1,0.25,0.5,0.75,0.9])
19
+ print(name, dict(zip(["p10","p25","p50","p75","p90"], map(lambda x: round(float(x),3), q))))
20
+
21
+ describe(top1, "top1_conf")
22
+ describe(margin, "margin")
23
+
24
+ # quick suggestion: set thresholds around p25
25
+ print("Suggested UNCERTAINTY_THRESHOLD≈", round(float(np.quantile(top1, 0.25)),3))
26
+ print("Suggested MARGIN_THRESHOLD≈", round(float(np.quantile(margin, 0.25)),3))