|
|
import numpy as np, json |
|
|
from datasets import load_dataset |
|
|
from setfit import SetFitModel |
|
|
|
|
|
# Class names read from labels.json. The position of each name is used as its
# integer class id below, so the file is expected to be a JSON array of
# names -- TODO confirm against the file that produced the model's labels.
# The context manager closes the file handle deterministically instead of
# leaving it to the garbage collector; JSON text is UTF-8 by specification.
with open("labels.json", encoding="utf-8") as _labels_file:
    LABELS = json.load(_labels_file)

# Reverse lookup: class name -> integer class id (index within LABELS).
name2id = {label_name: label_id for label_id, label_name in enumerate(LABELS)}
|
|
|
|
|
# Validation split, loaded from the JSON-lines file "val.jsonl".
ds = load_dataset("json", data_files={"val": "val.jsonl"})["val"]

# Gold labels as integer class ids, in dataset order. Reading the "label"
# column directly avoids materialising a full row dict for every example.
y_true = np.array([name2id[label_name] for label_name in ds["label"]])
|
|
|
|
|
# Load the SetFit classifier identified by "DelaliScratchwerk/text-period-setfit".
m = SetFitModel.from_pretrained("DelaliScratchwerk/text-period-setfit")

# Score the whole validation split with a single predict_proba call rather
# than one call per example, so the model can batch the texts internally.
# The result is indexed below as (n_examples, n_classes).
_raw_probs = m.predict_proba(list(ds["text"]))
if hasattr(_raw_probs, "detach"):
    # A torch.Tensor result must be detached and moved to host memory
    # before NumPy can wrap it.
    _raw_probs = _raw_probs.detach().cpu().numpy()
probs = np.asarray(_raw_probs)
|
|
# Highest class probability for each example (one value per row of probs).
top1 = np.max(probs, axis=1)

# Each row's probabilities ordered from largest to smallest.
sorted_probs = np.flip(np.sort(probs, axis=1), axis=1)

# Gap between the best and second-best class probability for each example;
# a small gap means the classifier was torn between two classes.
best_prob, runner_up_prob = sorted_probs[:, 0], sorted_probs[:, 1]
margin = best_prob - runner_up_prob
|
|
|
|
|
def describe(arr, name) -> None:
    """Print the 10th, 25th, 50th, 75th and 90th percentiles of ``arr``.

    Args:
        arr: Array-like of numeric values, passed straight to ``np.quantile``.
        name: Label printed in front of the summary.

    The output is ``name`` followed by a dict that maps ``"p10"``, ``"p25"``,
    ``"p50"``, ``"p75"`` and ``"p90"`` to the corresponding quantile, each
    converted to a plain ``float`` and rounded to three decimals.
    """
    percentile_keys = ("p10", "p25", "p50", "p75", "p90")
    quantile_values = np.quantile(arr, [0.1, 0.25, 0.5, 0.75, 0.9])
    # float() strips the NumPy scalar type so the dict prints as plain numbers.
    summary = {
        key: round(float(value), 3)
        for key, value in zip(percentile_keys, quantile_values)
    }
    print(name, summary)
|
|
|
|
|
# Percentile summaries of the top-1 confidence and of the top-1/top-2 margin.
for series_name, series in (("top1_conf", top1), ("margin", margin)):
    describe(series, series_name)

# Propose each threshold as the lower quartile (25th percentile) of the
# corresponding validation distribution, rounded to three decimals.
for banner, series in (
    ("Suggested UNCERTAINTY_THRESHOLD≈", top1),
    ("Suggested MARGIN_THRESHOLD≈", margin),
):
    lower_quartile = float(np.quantile(series, 0.25))
    print(banner, round(lower_quartile, 3))
|
|
|