File size: 1,022 Bytes
1c00c7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np, json
from datasets import load_dataset
from setfit import SetFitModel

# Label vocabulary: the same ordered list used for training, so that the
# name -> index mapping below lines up with the model's class order.
# Opened via a context manager so the handle is closed deterministically, and
# with an explicit encoding so parsing does not depend on the platform locale.
with open("labels.json", encoding="utf-8") as labels_file:
    LABELS = json.load(labels_file)  # same list you train with
name2id = {n: i for i, n in enumerate(LABELS)}

# Validation split, read from a local JSON-lines file; each row is accessed
# through its "text" and "label" fields.
ds = load_dataset("json", data_files={"val": "val.jsonl"})["val"]
# Integer-encoded gold labels. Building this raises KeyError early if the
# validation file contains a label that is missing from labels.json.
y_true = np.array([name2id[x["label"]] for x in ds])

m = SetFitModel.from_pretrained("DelaliScratchwerk/text-period-setfit")
# Score the whole split in one call so the model can batch internally instead
# of encoding a single-item list per row. as_numpy=True requests a NumPy array
# directly, so no implicit tensor -> array conversion is needed afterwards.
probs = np.asarray(m.predict_proba(list(ds["text"]), as_numpy=True))
# Highest class probability per example.
top1 = probs.max(axis=1)
# Per-row probabilities in descending order; the gap between the two best
# classes is the margin (requires at least two classes).
sorted_probs = np.sort(probs, axis=1)[:, ::-1]
margin = sorted_probs[:, 0] - sorted_probs[:, 1]

def describe(arr, name):
    """Print a five-point percentile summary of ``arr`` prefixed by ``name``.

    The 10th, 25th, 50th, 75th and 90th percentiles are computed with
    ``np.quantile``, rounded to three decimals and printed as a dict keyed
    ``p10`` .. ``p90``. Nothing is returned.
    """
    levels = {"p10": 0.1, "p25": 0.25, "p50": 0.5, "p75": 0.75, "p90": 0.9}
    values = np.quantile(arr, list(levels.values()))
    # Convert NumPy scalars to plain floats so the printed dict stays compact.
    summary = {key: round(float(v), 3) for key, v in zip(levels, values)}
    print(name, summary)

# Percentile summaries for the two confidence signals computed above.
describe(top1, "top1_conf")
describe(margin, "margin")

# Heuristic starting point: use the 25th percentile of each signal as its
# threshold (the same p25 value that describe() reports, recomputed here).
uncertainty_p25 = round(float(np.quantile(top1, 0.25)), 3)
print("Suggested UNCERTAINTY_THRESHOLD≈", uncertainty_p25)
margin_p25 = round(float(np.quantile(margin, 0.25)), 3)
print("Suggested MARGIN_THRESHOLD≈", margin_p25)