File size: 2,228 Bytes
75c005d
 
5b67bcb
 
10a61c2
131f8ea
 
639b99e
 
 
 
 
75c005d
ede725f
639b99e
10a61c2
 
75c005d
 
 
 
bb01739
131f8ea
ede725f
75c005d
131f8ea
bb01739
131f8ea
75c005d
131f8ea
75c005d
 
 
5b67bcb
131f8ea
 
5b67bcb
 
131f8ea
5b67bcb
131f8ea
 
75c005d
 
bb01739
75c005d
ede725f
75c005d
 
 
 
 
 
bb01739
131f8ea
5b67bcb
 
 
75c005d
 
bb01739
131f8ea
 
75c005d
 
bb01739
10a61c2
 
 
 
bb01739
639b99e
 
 
 
 
 
bb01739
ede725f
639b99e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#!/usr/bin/env python3

import multiprocessing

import tensorflow as tf
from datasets import load_dataset
from tensorflow.keras.optimizers import Adam
from transformers import (
    AutoTokenizer,
    PushToHubCallback,
    TFAutoModelForSequenceClassification,
)

# Fine-tuning configuration.
base_model = "distilbert-base-uncased"
output_dir = "out/model"
checkpoint_path = "out/cp.ckpt"

# Label columns expected in train.csv; list order fixes each label's index.
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
# Use `idx` rather than `id` in the comprehensions to avoid shadowing the builtin.
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}

# Load csv file of tweets in the script's directory
dataset = load_dataset("csv", data_files="train.csv")
tokenizer = AutoTokenizer.from_pretrained(base_model)


# Row-wise preprocessing, run in parallel worker processes via dataset.map.
def process_data(row):
    """Tokenize one CSV row and attach its multi-label target vector.

    Args:
        row: Mapping of column name -> value for one dataset row; must
            contain "comment_text" and (normally) the six label columns.

    Returns:
        The tokenizer encoding dict with an added "labels" list whose
        positions follow the module-level `labels` order.
    """
    encoding = tokenizer(row["comment_text"], padding="max_length", truncation=True)

    # Fill the target vector indexed through label2id. The previous code
    # used the enumeration position of the row's dict keys, which silently
    # misassigns labels whenever the CSV column order differs from `labels`.
    label_arr = [0] * len(labels)
    for label in labels:
        if label in row:
            label_arr[label2id[label]] = row[label]

    encoding["labels"] = label_arr

    return encoding


# Instantiate the pretrained base model with a classification head
# configured for multi-label classification over our label set.
model = TFAutoModelForSequenceClassification.from_pretrained(
    base_model,
    problem_type="multi_label_classification",
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
)

# Tokenize/encode every row in parallel across all available cores;
# cpu_count() already returns an int, so no cast is needed.
encoded = dataset.map(
    process_data,
    remove_columns=["id", "comment_text"],
    num_proc=multiprocessing.cpu_count(),
)

# Convert the encoded "train" split into a batched, shuffled tf.data.Dataset,
# with the tokenizer handling per-batch collation.
tf_dataset = model.prepare_tf_dataset(
    encoded["train"], batch_size=16, shuffle=True, tokenizer=tokenizer
)

# Checkpoint callback: persist model weights after each epoch.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1
)

# Callback to push the trained model (and tokenizer) to the Hugging Face Hub.
push_to_hub_callback = PushToHubCallback(
    output_dir=output_dir,
    tokenizer=tokenizer,
    hub_model_id="ogtega/tweet-toxicity-classifier",
)

# Compile and train the model. The model outputs raw logits, so the loss
# must be constructed with from_logits=True; the previous string form
# ("BinaryCrossentropy") deserialized to a loss with from_logits=False,
# which applies cross-entropy to unsquashed logits and corrupts training.
model.compile(
    optimizer=Adam(3e-5),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])