"""
Title: Large-scale multi-label text classification
Author: [Sayak Paul](https://twitter.com/RisingSayak), [Soumik Rakshit](https://github.com/soumik12345)
Date created: 2020/09/25
Last modified: 2025/02/27
Description: Implementing a large-scale multi-label text classification model.
Accelerator: GPU
Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""
"""
## Introduction
In this example, we will build a multi-label text classifier to predict the subject areas
of arXiv papers from their abstract bodies. This type of classifier can be useful for
conference submission portals like [OpenReview](https://openreview.net/). Given a paper
abstract, the portal could provide suggestions for which areas the paper would
best belong to.
The dataset was collected using the
[`arXiv` Python library](https://github.com/lukasschwab/arxiv.py)
that provides a wrapper around the
[original arXiv API](http://arxiv.org/help/api/index).
To learn more about the data collection process, please refer to
[this notebook](https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb).
Additionally, you can also find the dataset on
[Kaggle](https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts).
"""
"""
## Imports
"""
import os
os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch
import keras
from keras import layers, ops
from sklearn.model_selection import train_test_split
from ast import literal_eval
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
"""
## Perform exploratory data analysis
In this section, we first load the dataset into a `pandas` dataframe and then perform
some basic exploratory data analysis (EDA).
"""
arxiv_data = pd.read_csv(
    "https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv"
)
arxiv_data.head()
"""
Our text features are present in the `summaries` column and their corresponding labels
are in `terms`. As you can notice, there are multiple categories associated with a
particular entry.
"""
print(f"There are {len(arxiv_data)} rows in the dataset.")
"""
Real-world data is noisy. One of the most commonly observed sources of noise is data
duplication. Here we notice that our initial dataset has about 13k duplicate entries.
"""
total_duplicate_titles = sum(arxiv_data["titles"].duplicated())
print(f"There are {total_duplicate_titles} duplicate titles.")
"""
Before proceeding further, we drop these entries.
"""
arxiv_data = arxiv_data[~arxiv_data["titles"].duplicated()]
print(f"There are {len(arxiv_data)} rows in the deduplicated dataset.")
# There are some terms with occurrence as low as 1.
print(sum(arxiv_data["terms"].value_counts() == 1))
# How many unique terms?
print(arxiv_data["terms"].nunique())
"""
As observed above, out of 3,157 unique combinations of `terms`, 2,321 occur only once.
To prepare our train, validation, and test sets with
[stratification](https://en.wikipedia.org/wiki/Stratified_sampling), we need to drop
these entries, because a stratified split needs at least two samples of every class.
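The filtering rule itself can be sketched in plain Python. The toy label strings below
are hypothetical stand-ins for the `terms` column, not rows from the dataset:

```python
from collections import Counter

# Hypothetical label combinations, kept as raw strings like the `terms` column.
toy_terms = [
    "['cs.CV']",
    "['cs.CV']",
    "['cs.LG', 'stat.ML']",
    "['cs.LG', 'stat.ML']",
    "['cs.AI']",
]

counts = Counter(toy_terms)
# Combinations that occur exactly once cannot be split across two sets.
singletons = [term for term, count in counts.items() if count == 1]
# Keep only the rows whose combination occurs more than once.
kept = [term for term in toy_terms if counts[term] > 1]

print(singletons)
print(len(kept))
```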
"""
# Filtering the rare terms.
arxiv_data_filtered = arxiv_data.groupby("terms").filter(lambda x: len(x) > 1)
arxiv_data_filtered.shape
"""
## Convert the string labels to lists of strings
The initial labels are represented as raw strings. Here we make them `List[str]` for a
more compact representation.
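As a minimal sketch of what the conversion does, here is `literal_eval` applied to one
hypothetical raw label string:

```python
from ast import literal_eval

# A hypothetical raw label, stored as the string form of a Python list.
raw_label = "['cs.CV', 'cs.LG']"
parsed_label = literal_eval(raw_label)

print(type(raw_label).__name__, type(parsed_label).__name__)
print(parsed_label)
```

`literal_eval` only accepts Python literals, so unlike `eval` it will not execute
arbitrary code that happens to be stored in the CSV.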
"""
arxiv_data_filtered["terms"] = arxiv_data_filtered["terms"].apply(
    lambda x: literal_eval(x)
)
arxiv_data_filtered["terms"].values[:5]
"""
## Use stratified splits because of class imbalance
The dataset has a
[class imbalance problem](https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset).
So, to have a fair evaluation result, we need to ensure the datasets are sampled with
stratification. To know more about different strategies to deal with the class imbalance
problem, you can follow
[this tutorial](https://www.tensorflow.org/tutorials/structured_data/imbalanced_data).
For an end-to-end demonstration of classification with imbalanced data, refer to
[Imbalanced classification: credit card fraud detection](https://keras.io/examples/structured_data/imbalanced_classification/).
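To make the idea of stratification concrete, here is a simplified sketch in plain Python.
It is not scikit-learn's implementation; the function name `stratified_split` and the toy
labels are hypothetical:

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, test_fraction, seed=0):
    # Group row indices by label so each label is sampled separately.
    groups = defaultdict(list)
    for index, label in enumerate(labels):
        groups[label].append(index)

    rng = random.Random(seed)
    train_indices, test_indices = [], []
    for indices in groups.values():
        rng.shuffle(indices)
        # Take roughly `test_fraction` of every group for the test set.
        n_test = round(len(indices) * test_fraction)
        test_indices.extend(indices[:n_test])
        train_indices.extend(indices[n_test:])
    return train_indices, test_indices


# Hypothetical imbalanced labels: 20 rows of "a" and 10 rows of "b".
toy_labels = ["a"] * 20 + ["b"] * 10
train_idx, test_idx = stratified_split(toy_labels, test_fraction=0.1)
test_counts = Counter(toy_labels[i] for i in test_idx)
print(test_counts)
```

Both labels keep their 2:1 ratio in the test set, which a purely random split of 30 rows
would not guarantee.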
"""
test_split = 0.1
# Initial train and test split.
train_df, test_df = train_test_split(
    arxiv_data_filtered,
    test_size=test_split,
    stratify=arxiv_data_filtered["terms"].values,
)
# Splitting the test set further into validation
# and new test sets.
val_df = test_df.sample(frac=0.5)
test_df.drop(val_df.index, inplace=True)
print(f"Number of rows in training set: {len(train_df)}")
print(f"Number of rows in validation set: {len(val_df)}")
print(f"Number of rows in test set: {len(test_df)}")
"""
## Multi-label binarization
Now we preprocess our labels using the
[`StringLookup`](https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup)
layer.
"""
# For RaggedTensor
import tensorflow as tf
terms = tf.ragged.constant(train_df["terms"].values)
lookup = layers.StringLookup(output_mode="multi_hot")
lookup.adapt(terms)
vocab = lookup.get_vocabulary()
def invert_multi_hot(encoded_labels):
    """Reverse a single multi-hot encoded label to a tuple of vocab terms."""
    hot_indices = np.argwhere(encoded_labels == 1.0)[..., 0]
    return np.take(vocab, hot_indices)
print("Vocabulary:\n")
print(vocab)
"""
Here we are separating the individual unique classes available from the label
pool and then using this information to represent a given label set with 0's and 1's.
Below is an example.
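The encoding and its inverse can also be sketched without Keras. The vocabulary below is
hypothetical, and reserving index 0 for unknown terms is an assumption made for this
illustration:

```python
# Hypothetical vocabulary; index 0 plays the role of an out-of-vocabulary slot.
toy_vocab = ["[UNK]", "cs.CV", "cs.LG", "stat.ML"]
term_to_index = {term: index for index, term in enumerate(toy_vocab)}


def to_multi_hot(label_set):
    encoded = [0.0] * len(toy_vocab)
    for term in label_set:
        # Unknown terms fall back to the out-of-vocabulary slot.
        encoded[term_to_index.get(term, 0)] = 1.0
    return encoded


def from_multi_hot(encoded):
    # Collect the vocabulary terms whose slot is switched on.
    return [toy_vocab[i] for i, value in enumerate(encoded) if value == 1.0]


encoded = to_multi_hot(["cs.LG", "stat.ML"])
print(encoded)
print(from_multi_hot(encoded))
```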
"""
sample_label = train_df["terms"].iloc[0]
print(f"Original label: {sample_label}")
label_binarized = lookup([sample_label])
print(f"Label-binarized representation: {label_binarized}")
"""
## Data preprocessing and `tf.data.Dataset` objects
We first get percentile estimates of the sequence lengths. The purpose will be clear in a
moment.
"""
train_df["summaries"].apply(lambda x: len(x.split(" "))).describe()
"""
Notice that 50% of the abstracts are at most 154 words long (you may get a different number
based on the split). So, any number close to that value is a good enough approximation for the
maximum sequence length.
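The 50% row reported by `describe()` is the median of the per-abstract word counts. A
minimal sketch with hypothetical abstracts (real ones are much longer):

```python
import statistics

# Hypothetical abstracts used only to illustrate the computation.
toy_abstracts = [
    "graph neural networks",
    "a survey of image segmentation methods",
    "we study robust training under label noise today",
]
# Same word-count rule as the pandas call: split on single spaces.
word_counts = [len(text.split(" ")) for text in toy_abstracts]
median_length = statistics.median(word_counts)
print(word_counts, median_length)
```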
Now, we implement utilities to prepare our datasets.
"""
max_seqlen = 150
batch_size = 128
padding_token = "<pad>"
auto = tf.data.AUTOTUNE
def make_dataset(dataframe, is_train=True):
    labels = tf.ragged.constant(dataframe["terms"].values)
    label_binarized = lookup(labels).numpy()
    dataset = tf.data.Dataset.from_tensor_slices(
        (dataframe["summaries"].values, label_binarized)
    )
    dataset = dataset.shuffle(batch_size * 10) if is_train else dataset
    return dataset.batch(batch_size)
"""
Now we can prepare the `tf.data.Dataset` objects.
"""
train_dataset = make_dataset(train_df, is_train=True)
validation_dataset = make_dataset(val_df, is_train=False)
test_dataset = make_dataset(test_df, is_train=False)
"""
## Dataset preview
"""
text_batch, label_batch = next(iter(train_dataset))
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    print(" ")
"""
## Vectorization
Before we feed the data to our model, we need to vectorize it (represent it in a numerical form).
For that purpose, we will use the
[`TextVectorization` layer](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).
It can operate as a part of your main model so that the model is excluded from the core
preprocessing logic. This greatly reduces the chances of training / serving skew during inference.
We first calculate the number of unique words present in the abstracts.
"""
# Source: https://stackoverflow.com/a/18937309/7636462
vocabulary = set()
train_df["summaries"].str.lower().str.split().apply(vocabulary.update)
vocabulary_size = len(vocabulary)
print(vocabulary_size)
"""
We now create our vectorization layer and `map()` it over the `tf.data.Dataset`s created
earlier.
"""
text_vectorizer = layers.TextVectorization(
    max_tokens=vocabulary_size, ngrams=2, output_mode="tf_idf"
)
# `TextVectorization` layer needs to be adapted as per the vocabulary from our
# training set.
with tf.device("/CPU:0"):
    text_vectorizer.adapt(train_dataset.map(lambda text, label: text))
train_dataset = train_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
validation_dataset = validation_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
test_dataset = test_dataset.map(
    lambda text, label: (text_vectorizer(text), label), num_parallel_calls=auto
).prefetch(auto)
"""
A batch of raw text will first go through the `TextVectorization` layer, which will
generate its numerical representation. Internally, with `ngrams=2`, the layer will
first create uni-grams and bi-grams out of the sequences and then represent them using
[TF-IDF](https://wikipedia.org/wiki/Tf%E2%80%93idf). The output representations will then
be passed to the shallow model responsible for text classification.
To learn more about other possible configurations with `TextVectorization`, please consult
the
[official documentation](https://keras.io/api/layers/preprocessing_layers/text/text_vectorization).
**Note**: Setting the `max_tokens` argument to a pre-calculated vocabulary size is
not a requirement.
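For intuition, here is a rough plain-Python sketch of n-gram extraction followed by
TF-IDF weighting. It assumes that `ngrams=2` yields both uni-grams and bi-grams, and the
smoothed IDF formula is an assumption chosen for illustration, so the numbers may differ
from what the layer computes. The corpus and function names are hypothetical:

```python
import math
from collections import Counter


def ngrams_up_to_two(text):
    # Lowercase, split on whitespace, then emit uni-grams followed by bi-grams.
    tokens = text.lower().split()
    bigrams = [" ".join(pair) for pair in zip(tokens, tokens[1:])]
    return tokens + bigrams


# Hypothetical corpus of three tiny documents.
corpus = ["graph neural networks", "neural networks", "graph kernels"]
documents = [ngrams_up_to_two(text) for text in corpus]

# Document frequency: in how many documents each n-gram appears.
document_frequency = Counter(gram for doc in documents for gram in set(doc))


def tf_idf(doc):
    # Assumed smoothing for illustration: idf = log(1 + N / (1 + df)).
    n_docs = len(documents)
    term_counts = Counter(doc)
    return {
        gram: count * math.log(1 + n_docs / (1 + document_frequency[gram]))
        for gram, count in term_counts.items()
    }


weights = tf_idf(documents[0])
print(documents[0])
print(weights)
```

N-grams that appear in fewer documents, such as "graph neural" here, receive a larger
weight than n-grams shared across documents. The real layer additionally places these
weights in a fixed-size vector indexed by its vocabulary.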
"""
"""
## Create a text classification model
We will keep our model simple -- it will be a small stack of fully-connected layers with
ReLU as the non-linearity.
"""
def make_model():
    shallow_mlp_model = keras.Sequential(
        [
            layers.Dense(512, activation="relu"),
            layers.Dense(256, activation="relu"),
            layers.Dense(lookup.vocabulary_size(), activation="sigmoid"),
        ]  # More on why "sigmoid" has been used here in a moment.
    )
    return shallow_mlp_model
"""
## Train the model
We will train our model using the binary crossentropy loss. This is because the labels
are not disjoint. For a given abstract, we may have multiple categories. So, we will
divide the prediction task into a series of multiple binary classification problems. This
is also why we set the activation function of the classification layer in our model to
sigmoid. Researchers have used other combinations of loss function and activation
function as well. For example, in [Exploring the Limits of Weakly Supervised Pretraining](https://arxiv.org/abs/1805.00932),
Mahajan et al. used the softmax activation function and cross-entropy loss to train
their models.
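The pairing of sigmoid with binary crossentropy can be sketched in plain Python. The
logits and targets are hypothetical, and the clipping constant `eps` is an assumption
added to avoid taking the logarithm of zero:

```python
import math


def sigmoid(logit):
    return 1.0 / (1.0 + math.exp(-logit))


def binary_crossentropy(y_true, y_prob, eps=1e-7):
    # Average the per-label log loss; clip probabilities to avoid log(0).
    total = 0.0
    for target, prob in zip(y_true, y_prob):
        prob = min(max(prob, eps), 1.0 - eps)
        total += -(target * math.log(prob) + (1.0 - target) * math.log(1.0 - prob))
    return total / len(y_true)


# Hypothetical logits for three labels; each label gets its own probability.
logits = [2.0, -1.0, 0.0]
probabilities = [sigmoid(z) for z in logits]
targets = [1.0, 0.0, 1.0]
print(probabilities)
print(binary_crossentropy(targets, probabilities))
```

Because every label is scored independently, the probabilities do not have to sum to one,
which is what lets several labels be active for the same abstract.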
There are several options of metrics that can be used in multi-label classification.
To keep this code example narrow, we use the
[binary accuracy metric](https://keras.io/api/metrics/accuracy_metrics/#binaryaccuracy-class).
For an explanation of why this metric is used, see this
[pull-request](https://github.com/keras-team/keras-io/pull/1133#issuecomment-1322736860).
There are also other suitable metrics for multi-label classification, like
[F1 Score](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score) or
[Hamming loss](https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss).
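As a rough sketch of how two of these metrics relate, the functions below compute binary
accuracy and Hamming loss over individual label slots. The 0.5 threshold and the toy
targets and probabilities are assumptions for illustration:

```python
def binary_accuracy(y_true, y_prob, threshold=0.5):
    # Fraction of individual label slots predicted correctly.
    correct = 0
    total = 0
    for true_row, prob_row in zip(y_true, y_prob):
        for target, prob in zip(true_row, prob_row):
            predicted = 1 if prob > threshold else 0
            correct += int(predicted == target)
            total += 1
    return correct / total


def hamming_loss(y_true, y_prob, threshold=0.5):
    # Fraction of label slots predicted incorrectly.
    return 1.0 - binary_accuracy(y_true, y_prob, threshold)


# Hypothetical targets and predicted probabilities: two samples, four labels.
y_true = [[1, 0, 0, 1], [0, 1, 0, 0]]
y_prob = [[0.9, 0.2, 0.6, 0.4], [0.1, 0.8, 0.3, 0.2]]
print(binary_accuracy(y_true, y_prob))
print(hamming_loss(y_true, y_prob))
```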
"""
epochs = 20
shallow_mlp_model = make_model()
shallow_mlp_model.compile(
    loss="binary_crossentropy", optimizer="adam", metrics=["binary_accuracy"]
)
history = shallow_mlp_model.fit(
    train_dataset, validation_data=validation_dataset, epochs=epochs
)
def plot_result(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()
plot_result("loss")
plot_result("binary_accuracy")
"""
While training, we notice an initial sharp fall in the loss followed by a gradual decay.
"""
"""
### Evaluate the model
"""
_, binary_acc = shallow_mlp_model.evaluate(test_dataset)
print(f"Binary accuracy on the test set: {round(binary_acc * 100, 2)}%.")
"""
The trained model gives us an evaluation accuracy of ~99%.
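Binary accuracy counts every label slot, so it tends to look high when each sample
carries only a few positive labels. A back-of-the-envelope sketch with hypothetical
numbers (not measured on this dataset):

```python
# Hypothetical setup: 100 possible labels, each sample has exactly two of them.
num_labels = 100
labels_per_sample = 2

# A predictor that always outputs "no label" is wrong only on the positive slots.
correct_slots = num_labels - labels_per_sample
all_zeros_binary_accuracy = correct_slots / num_labels
print(all_zeros_binary_accuracy)
```

It is therefore worth reading the score alongside the actual predicted labels.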
"""
"""
## Inference
An important feature of the
[preprocessing layers provided by Keras](https://keras.io/api/layers/preprocessing_layers/)
is that they can be included inside a `keras.Model`. We will export an inference model
by applying the `text_vectorizer` layer in front of `shallow_mlp_model`. This will
allow our inference model to directly operate on raw strings.
**Note** that during training it is always preferable to use these preprocessing
layers as a part of the data input pipeline rather than the model to avoid
surfacing bottlenecks for the hardware accelerators. This also allows for
asynchronous data processing.
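The pattern used below, running preprocessing inside `predict`, can be sketched without
Keras. Every name in this block is a hypothetical stand-in, not part of the example:

```python
class EndToEnd:
    # Wraps a vectorizer and a classifier so `predict` accepts raw strings.
    def __init__(self, vectorize, classify):
        self.vectorize = vectorize
        self.classify = classify

    def predict(self, texts):
        features = [self.vectorize(text) for text in texts]
        return [self.classify(feature) for feature in features]


# Hypothetical stand-ins: count two keywords, then score one label.
def toy_vectorize(text):
    tokens = text.lower().split()
    return [tokens.count("image"), tokens.count("graph")]


def toy_classify(features):
    return 1.0 if features[0] > features[1] else 0.0


pipeline = EndToEnd(toy_vectorize, toy_classify)
print(pipeline.predict(["Image segmentation of image data", "Graph kernels"]))
```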
"""
# We create a custom Model to override the predict method so
# that it first vectorizes text data
class ModelEndtoEnd(keras.Model):
    def predict(self, inputs):
        indices = text_vectorizer(inputs)
        return super().predict(indices)
def get_inference_model(model):
    # Reuse the trained model's graph; only `predict` gains a vectorization step.
    inputs = model.inputs
    outputs = model.outputs
    end_to_end_model = ModelEndtoEnd(inputs, outputs, name="end_to_end_model")
    end_to_end_model.compile(
        optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
    )
    return end_to_end_model
model_for_inference = get_inference_model(shallow_mlp_model)
# Create a small dataset just for demonstrating inference.
inference_dataset = make_dataset(test_df.sample(2), is_train=False)
text_batch, label_batch = next(iter(inference_dataset))
predicted_probabilities = model_for_inference.predict(text_batch)
# Perform inference.
for i, text in enumerate(text_batch[:5]):
    label = label_batch[i].numpy()[None, ...]
    print(f"Abstract: {text}")
    print(f"Label(s): {invert_multi_hot(label[0])}")
    predicted_proba = [proba for proba in predicted_probabilities[i]]
    top_3_labels = [
        x
        for _, x in sorted(
            zip(predicted_probabilities[i], lookup.get_vocabulary()),
            key=lambda pair: pair[0],
            reverse=True,
        )
    ][:3]
    print(f"Predicted Label(s): ({', '.join([label for label in top_3_labels])})")
    print(" ")
"""
The prediction results are not that great, but they are not below par for a simple model like
ours. We can improve this performance with models that consider word order like LSTM or
even those that use Transformers ([Vaswani et al.](https://arxiv.org/abs/1706.03762)).
"""
"""
## Acknowledgements
We would like to thank [Matt Watson](https://github.com/mattdangerw) for helping us
tackle the multi-label binarization part and inverse-transforming the processed labels
to the original form.
Thanks to [Cingis Kratochvil](https://github.com/cumbalik) for suggesting and extending this code example by introducing binary accuracy as the evaluation metric.
"""