#!/usr/bin/env python3
"""Fine-tune DistilBERT as a multi-label tweet-toxicity classifier.

Reads the Jigsaw-style ``train.csv`` (columns: id, comment_text, plus one
0/1 column per label), tokenizes it in parallel, trains with TensorFlow,
checkpoints weights each epoch, and pushes the result to the Hugging Face
Hub.
"""
import multiprocessing

import tensorflow as tf
from datasets import load_dataset
from tensorflow.keras.optimizers import Adam
from transformers import (
    AutoTokenizer,
    PushToHubCallback,
    TFAutoModelForSequenceClassification,
)

base_model = "distilbert-base-uncased"
output_dir = "out/model"
checkpoint_path = "out/cp.ckpt"

# Canonical label order — label2id/id2label and every label vector below
# must agree with this list.
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}

# Load csv file of tweets in the script's directory.
dataset = load_dataset("csv", data_files="train.csv")

tokenizer = AutoTokenizer.from_pretrained(base_model)


def process_data(row):
    """Tokenize one csv row and attach its multi-hot label vector.

    Args:
        row: a dataset row dict containing ``comment_text`` and one 0/1
            entry per name in ``labels``.

    Returns:
        The tokenizer encoding dict with a ``labels`` key holding a float
        multi-hot vector in canonical ``labels`` order.
    """
    encoding = tokenizer(row["comment_text"], padding="max_length", truncation=True)
    # Build the label vector explicitly in `labels` order. The original
    # enumerated the row dict's keys, which tied the vector layout to the
    # csv column order instead of label2id — a silent mis-assignment hazard
    # if columns are ever reordered. Cast to float: BinaryCrossentropy
    # expects float targets.
    encoding["labels"] = [float(row[label]) for label in labels]
    return encoding


# Initiate the model with a multi-label head sized to our label set.
model = TFAutoModelForSequenceClassification.from_pretrained(
    base_model,
    problem_type="multi_label_classification",
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
)

# Tokenize in parallel, one worker process per available core.
encoded = dataset.map(
    process_data,
    remove_columns=["id", "comment_text"],
    num_proc=multiprocessing.cpu_count(),
)

# Convert the encoded split into a batched, shuffled tf.data.Dataset.
tf_dataset = model.prepare_tf_dataset(
    encoded["train"], batch_size=16, shuffle=True, tokenizer=tokenizer
)

# Save model weights after each epoch.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1
)

# Callback to submit our model (and tokenizer) to the Hub during training.
push_to_hub_callback = PushToHubCallback(
    output_dir=output_dir,
    tokenizer=tokenizer,
    hub_model_id="ogtega/tweet-toxicity-classifier",
)

# The HF sequence-classification head emits raw logits, so the loss must be
# built with from_logits=True. The original string "BinaryCrossentropy"
# resolved to the class default (from_logits=False), which treats logits as
# probabilities and corrupts training.
model.compile(
    optimizer=Adam(3e-5),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])