# File size: 2,963 Bytes
# 6f6a301
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import warnings

import numpy as np

from datasets import load_dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, average_precision_score
from scipy.special import softmax
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from txtai import Embeddings
from txtai.pipeline import HFTrainer


def batchlabel(rows):
    """
    Maps a batch of label names to their integer label ids.

    Args:
        rows: batch of dataset rows with a "label" column of label names

    Returns:
        dict with a single "label" key holding the id of each input label, in input order

    Raises:
        KeyError: if a label name has no entry in config.label2id
    """

    # config is the module-level model configuration; it must be set before this runs
    lookup = config.label2id
    return {"label": [lookup[name] for name in rows["label"]]}

def batchtext(rows):
    """
    Looks up the text for each id in a batch using the module-level embeddings database.

    Args:
        rows: batch of dataset rows with an "id" column

    Returns:
        dict with a single "text" key holding the text of each input id, in input order

    Raises:
        IndexError: if the query returns no rows for an id
    """

    query = "SELECT text FROM txtai WHERE id=:id"

    def lookup(uid):
        # One parameterized query per id, limited to a single result
        return embeddings.search(query, 1, parameters={"id": uid})[0]["text"]

    return {"text": [lookup(uid) for uid in rows["id"]]}

def metrics(pred):
    """
    Computes evaluation metrics for a set of classification predictions.

    Args:
        pred: (logits, label ids) pair; logits is indexed as (samples, classes) here
              and label ids are used as integer indexes into the class axis

    Returns:
        dict with "accuracy", "precision", "recall", "f1" and "prauc" scores;
        all scores except accuracy are weighted averages across classes
    """

    logits, labelids = pred
    predicted = logits.argmax(-1)

    # Settings shared by the weighted scores; zero_division=0 scores classes with no
    # predicted or true samples as 0
    weighted = {"average": "weighted", "zero_division": 0}

    scores = {
        "accuracy": accuracy_score(labelids, predicted),
        "precision": precision_score(labelids, predicted, **weighted),
        "recall": recall_score(labelids, predicted, **weighted),
        "f1": f1_score(labelids, predicted, **weighted),
    }

    # PR AUC compares one-hot encoded true labels against per-class probabilities
    onehot = np.eye(logits.shape[1])[labelids]
    probs = softmax(logits, axis=-1)

    # average_precision_score has no zero_division parameter, so the warning raised
    # for classes absent from the true labels is suppressed here instead
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="No positive class found in y_true")
        scores["prauc"] = average_precision_score(onehot, probs, average="weighted")

    return scores

# Embeddings database: loads the neuml/txtai-wikipedia-slim index from the Hugging Face Hub.
# batchtext() reads this module-level object to look up text by id.
embeddings = Embeddings()
embeddings.load(provider="huggingface-hub", container="neuml/txtai-wikipedia-slim")

# Training dataset: rows of labels.csv. The "id" and "label" columns are read by
# batchtext() and batchlabel() below.
# NOTE(review): keep_default_na=False is presumably forwarded to the pandas CSV reader so
# values such as "NA" stay literal strings instead of becoming missing values - confirm.
ds = load_dataset("csv", data_files="labels.csv", split="train", keep_default_na=False)
# Integer id -> label name, with ids assigned in sorted label order
labels = dict(enumerate(sorted(ds.unique("label"))))
print(labels)

# Base model
path = "jhu-clsp/ettin-encoder-32m"

# Model configuration: start from the base model's config and attach the label mappings.
# batchlabel() reads config.label2id, so this must be set before the label mapping step.
config = AutoConfig.from_pretrained(path)
config.num_labels = len(labels)
config.id2label = labels
config.label2id = {label: uid for uid, label in labels.items()}

# Map label ids: replaces each label name in the "label" column with its integer id
ds = ds.map(batchlabel, batched=True)

# Map text: adds a "text" column with the text stored in the embeddings database for each id
ds = ds.map(batchtext, batched=True)

# Split into train and test: 5% of rows are held out, with a fixed seed for a repeatable split
ds = ds.train_test_split(test_size=0.05, seed=42)
training, test = ds["train"], ds["test"]

# Model to train: base model with a sequence classification head sized by config.num_labels
model = AutoModelForSequenceClassification.from_pretrained(path, config=config)
tokenizer = AutoTokenizer.from_pretrained(path)

# Train the classifier, scoring the held-out split with metrics() every 500 steps.
# Output is written to the "domain-labeler" directory.
# NOTE(review): keyword arguments other than metrics/maxlength/tokenizers are presumably
# passed through to the Hugging Face training arguments - confirm against the txtai docs.
train = HFTrainer()
train(
    (model, tokenizer), training, test, metrics=metrics, maxlength=512, bf16=True,
    learning_rate=5e-5, per_device_train_batch_size=64, num_train_epochs=3,
    warmup_ratio=0.1, lr_scheduler_type="cosine",
    eval_strategy="steps", eval_steps=500, logging_steps=500,
    tokenizers=True, dataloader_num_workers=20,
    output_dir="domain-labeler"
)