Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import multiprocessing | |
| import tensorflow as tf | |
| from datasets import load_dataset | |
| from tensorflow.keras.optimizers import Adam | |
| from transformers import ( | |
| AutoTokenizer, | |
| PushToHubCallback, | |
| TFAutoModelForSequenceClassification, | |
| ) | |
# Pretrained backbone checkpoint and local output locations.
base_model = "distilbert-base-uncased"
output_dir = "out/model"
checkpoint_path = "out/cp.ckpt"

# Jigsaw toxic-comment label columns; list order defines the model's
# output index for each label.
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]
# Use `i` as the index name — the original comprehensions shadowed the
# builtin `id` inside the comprehension scope.
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# Load the CSV of comments from the script's directory; `load_dataset`
# exposes it as the "train" split of a DatasetDict.
dataset = load_dataset("csv", data_files="train.csv")
tokenizer = AutoTokenizer.from_pretrained(base_model)
| # Asynchronous processing of data rows | |
def process_data(row):
    """Tokenize one dataset row and attach its multi-hot label vector.

    Args:
        row: a mapping of CSV column names to values; must contain
            "comment_text" plus one 0/1 column per entry in ``labels``
            (assumes all label columns are present — TODO confirm
            against train.csv).

    Returns:
        The tokenizer encoding dict with an added "labels" key: a list
        of 0/1 values aligned with the module-level ``labels`` order.
    """
    encoding = tokenizer(row["comment_text"], padding="max_length", truncation=True)
    # Index directly by the canonical `labels` order. The original code
    # enumerated the filtered row dict, so the position of each value
    # depended on the CSV column order — any reordered/extra column
    # would silently scramble the label vector. It also shadowed the
    # builtin `id` as the loop index.
    encoding["labels"] = [row[label] for label in labels]
    return encoding
# Instantiate the pretrained backbone with a fresh classification head:
# one output unit per toxicity label, configured for multi-label use.
head_config = dict(
    problem_type="multi_label_classification",
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
)
model = TFAutoModelForSequenceClassification.from_pretrained(base_model, **head_config)
# Tokenize every row in parallel — one worker process per CPU core —
# and drop the raw columns the model does not consume.
encoded = dataset.map(
    process_data,
    remove_columns=["id", "comment_text"],
    num_proc=multiprocessing.cpu_count(),
)
# Wrap the encoded "train" split in a shuffled, batched tf.data pipeline
# whose tensors match the model's expected input names.
tf_dataset = model.prepare_tf_dataset(
    encoded["train"],
    batch_size=16,
    shuffle=True,
    tokenizer=tokenizer,
)
# Checkpoint callback: save model weights after each epoch so training
# can resume from out/cp.ckpt. (The original comment mislabeled this as
# a "compile completion" callback.)
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path, save_weights_only=True, verbose=1
)
# Callback that pushes the trained model and tokenizer to the Hugging
# Face Hub under the given repo id.
push_to_hub_callback = PushToHubCallback(
    output_dir=output_dir,
    tokenizer=tokenizer,
    hub_model_id="ogtega/tweet-toxicity-classifier",
)
# The HF TF sequence-classification head outputs raw logits, so binary
# cross-entropy must be computed with from_logits=True. The original
# string loss "BinaryCrossentropy" instantiates the class with its
# default from_logits=False, treating logits as probabilities and
# silently mis-training the model.
model.compile(
    optimizer=Adam(3e-5),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
)
model.fit(tf_dataset, callbacks=[cp_callback, push_to_hub_callback])