| """ | |
| Title: Large-scale multi-label text classification | |
| Author: [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345) | |
| Date created: 2020/09/25 | |
| Last modified: 2025/02/27 | |
| Description: Implementing a large-scale multi-label text classification model. | |
| Accelerator: GPU | |
| Converted to keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234) | |
| """ | |
| """ | |
| ## Introduction | |
| In this example, we will build a multi-label text classifier to predict the subject areas | |
| of arXiv papers from their abstract bodies. This type of classifier can be useful for | |
| conference submission portals like [OpenReview](https://openreview.net/). Given a paper | |
| abstract, the portal could provide suggestions for which areas the paper would | |
| best belong to. | |
| The dataset was collected using the | |
| [`arXiv` Python library](https://github.com/lukasschwab/arxiv.py) | |
| that provides a wrapper around the | |
| [original arXiv API](http://arxiv.org/help/api/index). | |
| To learn more about the data collection process, please refer to | |
| [this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb). | |
| Additionally, you can also find the dataset on | |
| [Kaggle](https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts). | |
| """ | |
| """ | |
| ## Imports | |
| """ | |
| import os | |
| os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch | |
| import keras | |
| from keras import layers, ops | |
| from sklearn.model_selection import train_test_split | |
| from ast import literal_eval | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import numpy as np | |
| """ | |
| ## Perform exploratory data analysis | |
| In this section, we first load the dataset into a `pandas` dataframe and then perform | |
| some basic exploratory data analysis (EDA). | |
| """ | |
| arxiv_data = pd.read_csv( | |
| "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv" | |
| ) | |
| arxiv_data.head() | |
| """ | |
| Our text features are present in the `summaries` column and their corresponding labels | |
| are in `terms`. As you can notice, there are multiple categories associated with a | |
| particular entry. | |
| """ | |
| print(f"There are {len(arxiv_data)} rows in the dataset.") | |
| """ | |
| Real-world data is noisy. One of the most commonly observed source of noise is data | |
| duplication. Here we notice that our initial dataset has got about 13k duplicate entries. | |
| """ | |
| total_duplicate_titles = sum(arxiv_data["titles"].duplicated()) | |
| print(f"There are {total_duplicate_titles} duplicate titles.") | |
| """ | |
| Before proceeding further, we drop these entries. | |
| """ | |
| arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()] | |
| print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.") | |
| # There are some terms with occurrence as low as 1. | |
| print(sum(arxiv_data["terms"].value_counts() == 1)) | |
| # How many unique terms? | |
| print(arxiv_data["terms"].nunique()) | |
| """ | |
| As observed above, out of 3,157 unique combinations of `terms`, 2,321 entries have the | |
| lowest occurrence. To prepare our train, validation, and test sets with | |
| [stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop | |
| these terms. | |
| """ | |
| # Filtering the rare terms. | |
| arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1) | |
| arxiv_data_filtered.shape | |
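The `groupby(...).filter(...)` pattern keeps only the groups for which the predicate holds. A minimal sketch on a toy frame (hypothetical data, not from the arXiv dataset):

```python
import pandas as pd

toy = pd.DataFrame({"terms": ["a", "a", "b"], "x": [1, 2, 3]})
# Keep only term groups that occur more than once, mirroring the filter above.
kept = toy.groupby("terms").filter(lambda g: len(g) > 1)
print(kept)  # only the two "a" rows survive
```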
| """ | |
| ## Convert the string labels to lists of strings | |
| The initial labels are represented as raw strings. Here we make them `List[str]` for a | |
| more compact representation. | |
| """ | |
| arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply( | |
| lambda x: literal_eval(x) | |
| ) | |
| arxiv_data_filtered["terms"].values[:5] | |
| """ | |
| ## Use stratified splits because of class imbalance | |
| The dataset has a | |
| [class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset). | |
| So, to have a fair evaluation result, we need to ensure the datasets are sampled with | |
| stratification. To know more about different strategies to deal with the class imbalance | |
| problem, you can follow | |
| [this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data). | |
| For an end-to-end demonstration of classification with imbablanced data, refer to | |
| [Imbalanced classification: credit card fraud detection](https://keras.io/examples/structured_data/imbalanced_classification/). | |
| """ | |
test_split = 0.1

# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)

# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)

print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")
| """ | |
| ## Multi-label binarization | |
| Now we preprocess our labels using the | |
| [`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup) | |
| layer. | |
| """ | |
| # For RaggedTensor | |
| import tensorflow as tf | |
| terms = tf.ragged.constant(train_df["terms"].values) | |
| lookup = layers.StringLookup(output_mode="multi_hot") | |
| lookup.adapt(terms) | |
| vocab = lookup.get_vocabulary() | |
| def invert_multi_hot(encoded_labels): | |
| """Reverse a single multi-hot encoded label to a tuple of vocab terms.""" | |
| hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0] | |
| return np.take(vocab, hot_indices) | |
| print("Vocabulary:\n") | |
| print(vocab) | |
| """ | |
| Here we are separating the individual unique classes available from the label | |
| pool and then using this information to represent a given label set with 0's and 1's. | |
| Below is an example. | |
| """ | |
sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")

label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")
| """ | |
| ## Data preprocessing and `tf.data.Dataset` objects | |
| We first get percentile estimates of the sequence lengths. The purpose will be clear in a | |
| moment. | |
| """ | |
| train_df["summaries"].apply(lambda x: len(x.split(" "))).describe() | |
| """ | |
| Notice that 50% of the abstracts have a length of 154 (you may get a different number | |
| based on the split). So, any number close to that value is a good enough approximate for the | |
| maximum sequence length. | |
| Now, we implement utilities to prepare our datasets. | |
| """ | |
max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE


def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)
| """ | |
| Now we can prepare the `tf.data.Dataset` objects. | |
| """ | |
| train_dataset = make_dataset(train_df, is_train=True) | |
| validation_dataset = make_dataset(val_df, is_train=False) | |
| test_dataset = make_dataset(test_df, is_train=False) | |
| """ | |
| ## Dataset preview | |
| """ | |
| text_batch, label_batch = next(iter(train_dataset)) | |
| for i, text in enumerate(text_batch[:5]): | |
| label = label_batch[i].numpy()[None, ...] | |
| print(f"Abstract: {text}") | |
| print(f"Label(s): {invert_multi_hot(label[0])}") | |
| print(" ") | |
| """ | |
| ## Vectorization | |
| Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). | |
| For that purpose, we will use the | |
| [`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization). | |
| It can operate as a part of your main model so that the model is excluded from the core | |
| preprocessing logic. This greatly reduces the chances of training / serving skew during inference. | |
| We first calculate the number of unique words present in the abstracts. | |
| """ | |
# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)
| """ | |
| We now create our vectorization layer and `map()` to the `tf.data.Dataset`s created | |
| earlier. | |
| """ | |
| text_vectorizer = layers.TextVectorization( | |
| max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf" | |
| ) | |
| # `TextVectorization` layer needs to be adapted as per the vocabulary from our | |
| # training set. | |
| with tf.device("/CPU:0"): | |
| text_vectorizer.adapt(train_dataset.map(lambda text, label: text)) | |
| train_dataset = train_dataset.map( | |
| lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto | |
| ).prefetch(auto) | |
| validation_dataset = validation_dataset.map( | |
| lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto | |
| ).prefetch(auto) | |
| test_dataset = test_dataset.map( | |
| lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto | |
| ).prefetch(auto) | |
| """ | |
| A batch of raw text will first go through the `TextVectorization` layer and it will | |
| generate their integer representations. Internally, the `TextVectorization` layer will | |
| first create bi-grams out of the sequences and then represent them using | |
| [TF-IDF](https://wikipedia.org/wiki/Tf%E2%80%93idf). The output representations will then | |
| be passed to the shallow model responsible for text classification. | |
| To learn more about other possible configurations with `TextVectorizer`, please consult | |
| the | |
| [official documentation](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization). | |
| **Note**: Setting the `max_tokens` argument to a pre-calculated vocabulary size is | |
| not a requirement. | |
| """ | |
| """ | |
| ## Create a text classification model | |
| We will keep our model simple -- it will be a small stack of fully-connected layers with | |
| ReLU as the non-linearity. | |
| """ | |
| def make_model(): | |
| shallow_mlp_model = keras.Sequential( | |
| [ | |
| layers.Dense(512, activation="relu"), | |
| layers.Dense(256, activation="relu"), | |
| layers.Dense(lookup.vocabulary_size(), activation="sigmoid"), | |
| ] # More on why "sigmoid" has been used here in a moment. | |
| ) | |
| return shallow_mlp_model | |
| """ | |
| ## Train the model | |
| We will train our model using the binary crossentropy loss. This is because the labels | |
| are not disjoint. For a given abstract, we may have multiple categories. So, we will | |
| divide the prediction task into a series of multiple binary classification problems. This | |
| is also why we kept the activation function of the classification layer in our model to | |
| sigmoid. Researchers have used other combinations of loss function and activation | |
| function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932), | |
| Mahajan et al. used the softmax activation function and cross-entropy loss to train | |
| their models. | |
| There are several options of metrics that can be used in multi-label classification. | |
| To keep this code example narrow we decided to use the | |
| [binary accuracy metric](https://keras.io/api/metrics/accuracy_metrics/#binaryaccuracy-class). | |
| To see the explanation why this metric is used we refer to this | |
| [pull-request](https://github.com/keras-team/keras-io/pull/1133#issuecomment-1322736860). | |
| There are also other suitable metrics for multi-label classification, like | |
| [F1 Score](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score) or | |
| [Hamming loss](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss). | |
| """ | |
epochs = 20

shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)

history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)


def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_result("loss")
plot_result("binary_accuracy")
| """ | |
| While training, we notice an initial sharp fall in the loss followed by a gradual decay. | |
| """ | |
| """ | |
| ### Evaluate the model | |
| """ | |
| _, binary_acc = shallow_mlp_model.evaluate(test_dataset) | |
| print(f"Categorical accuracy on the test set: {round(binary_acc * 100, 2)}%.") | |
| """ | |
| The trained model gives us an evaluation accuracy of ~99%. | |
| """ | |
| """ | |
| ## Inference | |
| An important feature of the | |
| [preprocessing layers provided by Keras](https://keras.io/api/layers/preprocessing_layers/) | |
| is that they can be included inside a `tf.keras.Model`. We will export an inference model | |
| by including the `text_vectorization` layer on top of `shallow_mlp_model`. This will | |
| allow our inference model to directly operate on raw strings. | |
| **Note** that during training it is always preferable to use these preprocessing | |
| layers as a part of the data input pipeline rather than the model to avoid | |
| surfacing bottlenecks for the hardware accelerators. This also allows for | |
| asynchronous data processing. | |
| """ | |
| # We create a custom Model to override the predict method so | |
| # that it first vectorizes text data | |
| class ModelEndtoEnd(keras.Model): | |
| def predict(self, inputs): | |
| indices = text_vectorizer(inputs) | |
| return super().predict(indices) | |
| def get_inference_model(model): | |
| inputs = shallow_mlp_model.inputs | |
| outputs = shallow_mlp_model.outputs | |
| end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model") | |
| end_to_end_model.compile( | |
| optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"] | |
| ) | |
| return end_to_end_model | |
model_for_inference = get_inference_model(shallow_mlp_model)

# Create a small dataset just for demonstrating inference.
inference_dataset = make_dataset(test_df.sample(2), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)

# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join(top_3_labels)})")
    print(" ")
| """ | |
| The prediction results are not that great but not below the par for a simple model like | |
| ours. We can improve this performance with models that consider word order like LSTM or | |
| even those that use Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762)). | |
| """ | |
| """ | |
| ## Acknowledgements | |
| We would like to thank [Matt Watson](https://github.com/mattdangerw) for helping us | |
| tackle the multi-label binarization part and inverse-transforming the processed labels | |
| to the original form. | |
| Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending this code example by introducing binary accuracy as the evaluation metric. | |
| """ | |