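# TextPeriod_Summarization / tune_thresholds.py
# Commit 1c00c7b (verified): "Create tune_thresholds.py" — author: DelaliScratchwerk
# Scores the validation split with the text-period SetFit model and prints the
# distribution of top-1 confidence and top-1/top-2 margin, plus p25-based
# starting values for UNCERTAINTY_THRESHOLD and MARGIN_THRESHOLD.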
import numpy as np, json
from datasets import load_dataset
from setfit import SetFitModel
# Class names in training order. Position i is assumed to correspond to
# column i of the model's predict_proba output — TODO confirm against training.
with open("labels.json", encoding="utf-8") as fh:
    LABELS = json.load(fh)  # same list you train with
name2id = {n: i for i, n in enumerate(LABELS)}
ds = load_dataset("json", data_files={"val": "val.jsonl"})["val"]
# Raises KeyError if val.jsonl contains a label that is not in labels.json.
# NOTE(review): y_true is not used further in this script; kept as a label
# sanity check and for any accuracy-vs-threshold analysis added later.
y_true = np.array([name2id[x["label"]] for x in ds])
m = SetFitModel.from_pretrained("DelaliScratchwerk/text-period-setfit")
# Score the whole split in one batched call rather than one forward pass per
# example. as_numpy=True asks SetFit for a CPU ndarray even when the head
# would otherwise return a (possibly GPU) torch tensor.
texts = [x["text"] for x in ds]
probs = np.asarray(m.predict_proba(texts, as_numpy=True))
top1 = probs.max(axis=1)
# Sort each row descending: column 0 is the winning probability and
# column 1 the runner-up, so their difference is the decision margin.
sorted_probs = np.sort(probs, axis=1)[:, ::-1]
margin = sorted_probs[:, 0] - sorted_probs[:, 1]
def describe(arr, name):
    """Print selected percentiles of ``arr`` prefixed by ``name``.

    The 10th, 25th, 50th, 75th and 90th percentiles (``np.quantile`` defaults)
    are each rounded to three decimals and printed as a dict keyed
    ``p10`` ... ``p90``.
    """
    percentile_levels = {"p10": 0.1, "p25": 0.25, "p50": 0.5, "p75": 0.75, "p90": 0.9}
    cut_points = np.quantile(arr, list(percentile_levels.values()))
    summary = {}
    for key, cut in zip(percentile_levels, cut_points):
        summary[key] = round(float(cut), 3)
    print(name, summary)
describe(top1, "top1_conf")
describe(margin, "margin")

# Starting point only: take the 25th percentile of each validation-set
# distribution (the "p25" entries printed above) as the suggested threshold.
conf_cut = round(float(np.quantile(top1, 0.25)), 3)
margin_cut = round(float(np.quantile(margin, 0.25)), 3)
print("Suggested UNCERTAINTY_THRESHOLD≈", conf_cut)
print("Suggested MARGIN_THRESHOLD≈", margin_cut)