Create tune_thresholds.py
Browse files- tune_thresholds.py +26 -0
tune_thresholds.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np, json
|
| 2 |
+
from datasets import load_dataset
|
| 3 |
+
from setfit import SetFitModel
|
| 4 |
+
|
| 5 |
+
LABELS = json.load(open("labels.json")) # same list you train with
|
| 6 |
+
name2id = {n:i for i,n in enumerate(LABELS)}
|
| 7 |
+
|
| 8 |
+
ds = load_dataset("json", data_files={"val":"val.jsonl"})["val"]
|
| 9 |
+
y_true = np.array([name2id[x["label"]] for x in ds])
|
| 10 |
+
|
| 11 |
+
m = SetFitModel.from_pretrained("DelaliScratchwerk/text-period-setfit")
|
| 12 |
+
probs = np.stack([m.predict_proba([x["text"]])[0] for x in ds])
|
| 13 |
+
top1 = probs.max(axis=1)
|
| 14 |
+
sorted_probs = np.sort(probs, axis=1)[:, ::-1]
|
| 15 |
+
margin = sorted_probs[:,0] - sorted_probs[:,1]
|
| 16 |
+
|
| 17 |
+
def describe(arr, name):
|
| 18 |
+
q = np.quantile(arr, [0.1,0.25,0.5,0.75,0.9])
|
| 19 |
+
print(name, dict(zip(["p10","p25","p50","p75","p90"], map(lambda x: round(float(x),3), q))))
|
| 20 |
+
|
| 21 |
+
describe(top1, "top1_conf")
|
| 22 |
+
describe(margin, "margin")
|
| 23 |
+
|
| 24 |
+
# quick suggestion: set thresholds around p25
|
| 25 |
+
print("Suggested UNCERTAINTY_THRESHOLD≈", round(float(np.quantile(top1, 0.25)),3))
|
| 26 |
+
print("Suggested MARGIN_THRESHOLD≈", round(float(np.quantile(margin, 0.25)),3))
|