"""
Title: Multimodal entailment
Author: [Sayak Paul](https://twitter.com/RisingSayak)
Date created: 2021/08/08
Last modified: 2025/01/03
Description: Training a multimodal model for predicting entailment.
Accelerator: GPU
Converted to Keras 3 and made backend-agnostic by: [Humbulani Ndou](https://github.com/Humbulani1234)
"""
"""
## Introduction
In this example, we will build and train a model for predicting multimodal entailment. We will be
using the
[multimodal entailment dataset](https://github.com/google-research-datasets/recognizing-multimodal-entailment)
recently introduced by Google Research.
### What is multimodal entailment?
On social media platforms, to audit and moderate content
we may want to find answers to the
following questions in near real-time:
* Does a given piece of information contradict the other?
* Does a given piece of information imply the other?
In NLP, this task is called analyzing _textual entailment_. However, that's only
when the information comes from text content.
In practice, it's often the case that the information available comes not just
from text content, but from a multimodal combination of text, images, audio, video, etc.
_Multimodal entailment_ is simply the extension of textual entailment to a variety
of new input modalities.
### Requirements
This example requires Keras 3 and KerasHub, which provides the BERT model
([Devlin et al.](https://arxiv.org/abs/1810.04805)) and its preprocessor.
TensorFlow Text can be installed
using the following command:
"""
"""shell
pip install -q tensorflow_text
"""
"""
## Imports
"""
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import random
import math
from skimage.io import imread
from skimage.transform import resize
from PIL import Image
import os
os.environ["KERAS_BACKEND"] = "jax" # or tensorflow, or torch
import keras
import keras_hub
from keras.utils import PyDataset
"""
## Define a label map
"""
label_map = {"Contradictory": 0, "Implies": 1, "NoEntailment": 2}
"""
## Collect the dataset
The original dataset is available
[here](https://github.com/google-research-datasets/recognizing-multimodal-entailment).
It comes with URLs of images which are hosted on Twitter's photo storage system called
the
[Photo Blob Storage (PBS for short)](https://blog.twitter.com/engineering/en_us/a/2012/blobstore-twitter-s-in-house-photo-storage-system).
We will be working with the downloaded images along with additional data that comes with
the original dataset. Thanks to
[Nilabhra Roy Chowdhury](https://de.linkedin.com/in/nilabhraroychowdhury) who worked on
preparing the image data.
"""
image_base_path = keras.utils.get_file(
    "tweet_images",
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz",
    untar=True,
)
"""
## Read the dataset and apply basic preprocessing
"""
# Keep only the first 1,000 rows to conserve resources; this is an
# example, not an attempt at state-of-the-art results.
df = pd.read_csv(
    "https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv"
).iloc[0:1000]
df.sample(10)
"""
The columns we are interested in are the following:
* `text_1`
* `image_1`
* `text_2`
* `image_2`
* `label`
The entailment task is formulated as the following:
***Given the pairs of (`text_1`, `image_1`) and (`text_2`, `image_2`) do they entail (or
not entail or contradict) each other?***
We have the images already downloaded. `image_1` is saved with `id_1` as its filename
and `image_2` is saved with `id_2` as its filename. In the next step, we will add two
more columns to `df`: the file paths of `image_1`s and `image_2`s.
"""
images_one_paths = []
images_two_paths = []

for idx in range(len(df)):
    current_row = df.iloc[idx]
    id_1 = current_row["id_1"]
    id_2 = current_row["id_2"]
    # Reuse the file extension of the original image URL.
    extension_one = current_row["image_1"].split(".")[-1]
    extension_two = current_row["image_2"].split(".")[-1]

    image_one_path = os.path.join(image_base_path, str(id_1) + f".{extension_one}")
    image_two_path = os.path.join(image_base_path, str(id_2) + f".{extension_two}")

    images_one_paths.append(image_one_path)
    images_two_paths.append(image_two_path)

df["image_1_path"] = images_one_paths
df["image_2_path"] = images_two_paths

# Create another column containing the integer ids of
# the string labels.
df["label_idx"] = df["label"].apply(lambda x: label_map[x])
"""
## Dataset visualization
"""
def visualize(idx):
    current_row = df.iloc[idx]
    image_1 = plt.imread(current_row["image_1_path"])
    image_2 = plt.imread(current_row["image_2_path"])
    text_1 = current_row["text_1"]
    text_2 = current_row["text_2"]
    label = current_row["label"]

    plt.subplot(1, 2, 1)
    plt.imshow(image_1)
    plt.axis("off")
    plt.title("Image One")
    plt.subplot(1, 2, 2)
    # Show the second image of the pair in the second panel.
    plt.imshow(image_2)
    plt.axis("off")
    plt.title("Image Two")
    plt.show()

    print(f"Text one: {text_1}")
    print(f"Text two: {text_2}")
    print(f"Label: {label}")


random_idx = random.choice(range(len(df)))
visualize(random_idx)

random_idx = random.choice(range(len(df)))
visualize(random_idx)
"""
## Train/test split
The dataset suffers from a
[class imbalance problem](https://developers.google.com/machine-learning/glossary#class-imbalanced-dataset).
We can confirm that in the following cell.
"""
df["label"].value_counts()
"""
To account for that we will go for a stratified split.
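To make the idea concrete, here is a minimal NumPy sketch of what stratification does: it
splits each class separately, so the class ratios carry over to both subsets. The helper
name `stratified_indices` and the toy labels are assumptions for illustration; this is
not scikit-learn's implementation.

```python
import numpy as np


def stratified_indices(labels, test_size, seed=42):
    # Hypothetical helper: split indices class by class so that both
    # subsets keep roughly the same class proportions as `labels`.
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    train_idx, test_idx = [], []
    for cls in np.unique(labels):
        cls_idx = rng.permutation(np.flatnonzero(labels == cls))
        n_test = int(round(len(cls_idx) * test_size))
        test_idx.extend(cls_idx[:n_test])
        train_idx.extend(cls_idx[n_test:])
    return np.array(train_idx), np.array(test_idx)


# Toy, imbalanced labels: 80% class 2, 10% class 1, 10% class 0.
toy_labels = np.array([2] * 80 + [1] * 10 + [0] * 10)
toy_train, toy_test = stratified_indices(toy_labels, test_size=0.1)
print(np.bincount(toy_labels[toy_test]))  # per-class counts in the test split
```

With a plain random split, a class that makes up only 10% of the data could easily be
missing from a small test set; splitting per class avoids that.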
"""
# 10% for test
train_df, test_df = train_test_split(
    df, test_size=0.1, stratify=df["label"].values, random_state=42
)
# 5% of the remaining training data for validation
train_df, val_df = train_test_split(
    train_df, test_size=0.05, stratify=train_df["label"].values, random_state=42
)

print(f"Total training examples: {len(train_df)}")
print(f"Total validation examples: {len(val_df)}")
print(f"Total test examples: {len(test_df)}")
"""
## Data input pipeline
KerasHub provides a
[variety of BERT-family models](https://keras.io/keras_hub/presets/).
Each of those models comes with a
corresponding preprocessing layer. You can learn more about these models and their
preprocessing layers from
[this resource](https://www.kaggle.com/models/keras/bert/keras/bert_base_en_uncased/2).
To keep the runtime of this example relatively short, we will use the
`bert_base_en_uncased` variant of the original BERT model.
"""
"""
Text preprocessing using KerasHub
"""
text_preprocessor = keras_hub.models.BertTextClassifierPreprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=128,
)
"""
### Run the preprocessor on a sample input
"""
idx = random.choice(range(len(train_df)))
row = train_df.iloc[idx]
sample_text_1, sample_text_2 = row["text_1"], row["text_2"]
print(f"Text 1: {sample_text_1}")
print(f"Text 2: {sample_text_2}")
test_text = [sample_text_1, sample_text_2]
text_preprocessed = text_preprocessor(test_text)
print("Keys : ", list(text_preprocessed.keys()))
print("Shape Token Ids : ", text_preprocessed["token_ids"].shape)
print("Token Ids : ", text_preprocessed["token_ids"][0, :16])
print("Shape Padding Mask : ", text_preprocessed["padding_mask"].shape)
print("Padding Mask : ", text_preprocessed["padding_mask"][0, :16])
print("Shape Segment Ids : ", text_preprocessed["segment_ids"].shape)
print("Segment Ids : ", text_preprocessed["segment_ids"][0, :16])
"""
We will now create `keras.utils.PyDataset` objects from the dataframes.
Note that the text inputs will be preprocessed as a part of the data input pipeline. But
the preprocessing modules can also be a part of their corresponding BERT models. This
helps reduce the training/serving skew and lets our models operate with raw text inputs.
Follow [this tutorial](https://www.tensorflow.org/text/tutorials/classify_text_with_bert)
to learn more about how to incorporate the preprocessing modules directly inside the
models.
"""
def dataframe_to_dataset(dataframe):
    # `UnifiedPyDataset` (defined below) reads the columns
    # "image_1_path", "image_2_path", "text_1", "text_2" and "label_idx".
    ds = UnifiedPyDataset(
        dataframe,
        batch_size=32,
        workers=4,
    )
    return ds
"""
### Preprocessing utilities
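`preprocess_text()` below passes the two texts to the preprocessor as a batch of two,
which appears to yield arrays of shape `(2, 128)` per feature, and then flattens each of
them into a single vector of length 256. That is why the text encoder later declares
inputs of shape `(256,)`. A small NumPy sketch with placeholder token IDs (the values are
made up) shows the effect of the reshape:

```python
import numpy as np

sequence_length = 128  # matches the preprocessor's `sequence_length`

# Placeholder stand-ins for the token IDs of `text_1` and `text_2`.
token_ids = np.stack(
    [
        np.full(sequence_length, 1),  # pretend tokens of text_1
        np.full(sequence_length, 2),  # pretend tokens of text_2
    ]
)
flat = token_ids.reshape(-1)

print(token_ids.shape)  # (2, 128)
print(flat.shape)  # (256,)
# The first half holds text_1, the second half holds text_2.
print(flat[:sequence_length].max(), flat[sequence_length:].min())  # 1 2
```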
"""
bert_input_features = ["padding_mask", "segment_ids", "token_ids"]


def preprocess_text(text_1, text_2):
    output = text_preprocessor([text_1, text_2])
    # Flatten each feature of the two texts into a single 1D vector.
    output = {
        feature: keras.ops.reshape(output[feature], [-1])
        for feature in bert_input_features
    }
    return output
"""
### Create the final datasets
The dataset class is adapted from the `PyDataset` docstring.
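The batching logic in the class below boils down to simple index arithmetic: batch
`index` covers rows `index * batch_size` up to, but not including,
`min((index + 1) * batch_size, n)`, and there are `ceil(n / batch_size)` batches in
total. A standalone sketch (the helper name `batch_bounds` is an assumption for
illustration):

```python
import math


def batch_bounds(index, batch_size, n):
    # Hypothetical helper mirroring the slicing done in `__getitem__`.
    low = index * batch_size
    high = min(low + batch_size, n)
    return low, high


n, batch_size = 100, 32
num_batches = math.ceil(n / batch_size)
print(num_batches)  # 4
print([batch_bounds(i, batch_size, n) for i in range(num_batches)])
# [(0, 32), (32, 64), (64, 96), (96, 100)]
```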
"""
class UnifiedPyDataset(PyDataset):
    """A Keras-compatible dataset that processes a DataFrame for TensorFlow, JAX, and PyTorch."""

    def __init__(
        self,
        df,
        batch_size=32,
        workers=4,
        use_multiprocessing=False,
        max_queue_size=10,
        **kwargs,
    ):
        """
        Args:
            df: pandas DataFrame with data
            batch_size: Batch size for dataset
            workers: Number of workers to use for parallel loading (Keras)
            use_multiprocessing: Whether to use multiprocessing
            max_queue_size: Maximum size of the data queue for parallel loading
        """
        super().__init__(**kwargs)
        self.dataframe = df

        # image files
        self.image_x_1 = self.dataframe["image_1_path"]
        self.image_x_2 = self.dataframe["image_2_path"]
        self.image_y = self.dataframe["label_idx"]

        # text files
        self.text_x_1 = self.dataframe["text_1"]
        self.text_x_2 = self.dataframe["text_2"]
        self.text_y = self.dataframe["label_idx"]

        # general
        self.batch_size = batch_size
        self.workers = workers
        self.use_multiprocessing = use_multiprocessing
        self.max_queue_size = max_queue_size
    def __getitem__(self, index):
        """
        Fetches a batch of data from the dataset at the given index.
        """
        # Return x, y for batch idx.
        low = index * self.batch_size
        # Cap upper bound at array length; the last batch may be smaller
        # if the total number of items is not a multiple of batch size.
        # image files
        high_image_1 = min(low + self.batch_size, len(self.image_x_1))
        high_image_2 = min(low + self.batch_size, len(self.image_x_2))
        # text
        high_text_1 = min(low + self.batch_size, len(self.text_x_1))
        high_text_2 = min(low + self.batch_size, len(self.text_x_2))

        # image files
        batch_image_x_1 = self.image_x_1[low:high_image_1]
        batch_image_y_1 = self.image_y[low:high_image_1]
        batch_image_x_2 = self.image_x_2[low:high_image_2]

        # text files
        batch_text_x_1 = self.text_x_1[low:high_text_1]
        batch_text_x_2 = self.text_x_2[low:high_text_2]

        # Image number 1 inputs
        image_1 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_1
        ]
        # Some images carry an alpha channel (RGBA); keep only the RGB
        # channels so that every image in the batch has the same shape.
        image_1 = [img[..., :3] if img.shape[2] == 4 else img for img in image_1]
        image_1 = np.array(image_1)

        # Both text inputs to the model, returned as a dict of inputs for BertBackbone.
        # Preprocess each text pair once, then regroup the results by feature.
        encoded = [
            preprocess_text(text_1, text_2)
            for text_1, text_2 in zip(batch_text_x_1, batch_text_x_2)
        ]
        text = {
            key: np.array([d[key] for d in encoded])
            for key in ["padding_mask", "token_ids", "segment_ids"]
        }

        # Image number 2 inputs
        image_2 = [
            resize(imread(file_name), (128, 128)) for file_name in batch_image_x_2
        ]
        # Drop the alpha channel here as well.
        image_2 = [img[..., :3] if img.shape[2] == 4 else img for img in image_2]
        # Stack the list of images into a single ndarray.
        image_2 = np.array(image_2)

        return (
            {
                "image_1": image_1,
                "image_2": image_2,
                "padding_mask": text["padding_mask"],
                "segment_ids": text["segment_ids"],
                "token_ids": text["token_ids"],
            },
            # Target labels
            np.array(batch_image_y_1),
        )
    def __len__(self):
        """
        Returns the number of batches in the dataset.
        """
        return math.ceil(len(self.dataframe) / self.batch_size)
"""
Create train, validation and test datasets
"""
def prepare_dataset(dataframe):
    ds = dataframe_to_dataset(dataframe)
    return ds


train_ds = prepare_dataset(train_df)
validation_ds = prepare_dataset(val_df)
test_ds = prepare_dataset(test_df)
"""
## Model building utilities
Our final model will accept two images along with their text counterparts. While the
images will be fed directly to the model, the text inputs will first be preprocessed and
then passed to the model. Below is a visual illustration of this approach:

The model consists of the following elements:
* A standalone encoder for the images. We will use a
[ResNet50V2](https://arxiv.org/abs/1603.05027) pre-trained on the ImageNet-1k dataset for
this.
* A standalone encoder for the text. A pre-trained BERT will be used for this.
After extracting the individual embeddings, they will be projected into an identical space.
Finally, their projections will be concatenated and fed to the final classification
layer.
This is a multi-class classification problem involving the following classes:
* NoEntailment
* Implies
* Contradictory
`project_embeddings()`, `create_vision_encoder()`, and `create_text_encoder()` utilities
are adapted from [this example](https://keras.io/examples/nlp/nl_image_search/).
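The shape flow described above can be sketched in plain NumPy. Random matrices stand in
for the encoders and projection heads, and the embedding sizes (4096 for the two
concatenated ResNet50V2 embeddings, 768 for the BERT pooled output) are assumptions based
on the usual output sizes of those models; nothing here is trained.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, projection_dims, num_classes = 4, 256, 3

# Stand-ins for the encoder outputs.
vision_embeddings = rng.normal(size=(batch_size, 4096))
text_embeddings = rng.normal(size=(batch_size, 768))

# Project both modalities to the same dimensionality.
vision_proj = vision_embeddings @ rng.normal(size=(4096, projection_dims))
text_proj = text_embeddings @ rng.normal(size=(768, projection_dims))

# Concatenate the projections and classify with a softmax layer.
concatenated = np.concatenate([vision_proj, text_proj], axis=-1)
logits = (concatenated @ rng.normal(size=(2 * projection_dims, num_classes))) * 0.01
logits -= logits.max(axis=-1, keepdims=True)  # numerical stability
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)

print(concatenated.shape)  # (4, 512)
print(probs.shape)  # (4, 3)
```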
"""
"""
Projection utilities
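
Each projection block applies GELU, a dense layer, and dropout, then adds the result back
to its input and layer-normalizes the sum. Below is a NumPy sketch of one such block at
inference time. It uses random weights, leaves out biases and the learned normalization
parameters, and all names in it are illustrative assumptions.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of GELU.
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def layer_norm(x, eps=1e-3):
    # Normalize each row to zero mean and unit variance (no learned scale/offset).
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
embeddings = rng.normal(size=(2, 8))
w_in = rng.normal(size=(8, 4))
w_block = rng.normal(size=(4, 4))

projected = embeddings @ w_in  # initial Dense projection
x = gelu(projected) @ w_block  # GELU, then Dense (dropout is a no-op at inference)
projected = layer_norm(projected + x)  # residual connection + normalization

print(projected.shape)  # (2, 4)
```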
"""
def project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
):
    projected_embeddings = keras.layers.Dense(units=projection_dims)(embeddings)
    for _ in range(num_projection_layers):
        x = keras.ops.nn.gelu(projected_embeddings)
        x = keras.layers.Dense(projection_dims)(x)
        x = keras.layers.Dropout(dropout_rate)(x)
        x = keras.layers.Add()([projected_embeddings, x])
        projected_embeddings = keras.layers.LayerNormalization()(x)
    return projected_embeddings
"""
Vision encoder utilities
"""
def create_vision_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained ResNet50V2 model to be used as the base encoder.
    resnet_v2 = keras.applications.ResNet50V2(
        include_top=False, weights="imagenet", pooling="avg"
    )
    # Set the trainability of the base encoder.
    for layer in resnet_v2.layers:
        layer.trainable = trainable

    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Preprocess the input images.
    preprocessed_1 = keras.applications.resnet_v2.preprocess_input(image_1)
    preprocessed_2 = keras.applications.resnet_v2.preprocess_input(image_2)

    # Generate the embeddings for the images using the resnet_v2 model
    # and concatenate them.
    embeddings_1 = resnet_v2(preprocessed_1)
    embeddings_2 = resnet_v2(preprocessed_2)
    embeddings = keras.layers.Concatenate()([embeddings_1, embeddings_2])

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the vision encoder model.
    return keras.Model([image_1, image_2], outputs, name="vision_encoder")
"""
Text encoder utilities
"""
def create_text_encoder(
    num_projection_layers, projection_dims, dropout_rate, trainable=False
):
    # Load the pre-trained BERT backbone using KerasHub. The backbone has no
    # classification head, so no number of classes is passed here.
    bert = keras_hub.models.BertBackbone.from_preset("bert_base_en_uncased")

    # Set the trainability of the base encoder.
    bert.trainable = trainable

    # Receive the text as inputs. Each feature has length 256 because the
    # two 128-token sequences are flattened into one vector.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }

    # Generate embeddings for the preprocessed text using the BERT model.
    embeddings = bert(inputs)["pooled_output"]

    # Project the embeddings produced by the model.
    outputs = project_embeddings(
        embeddings, num_projection_layers, projection_dims, dropout_rate
    )
    # Create the text encoder model.
    return keras.Model(inputs, outputs, name="text_encoder")
"""
Multimodal model utilities
"""
def create_multimodal_model(
    num_projection_layers=1,
    projection_dims=256,
    dropout_rate=0.1,
    vision_trainable=False,
    text_trainable=False,
):
    # Receive the images as inputs.
    image_1 = keras.Input(shape=(128, 128, 3), name="image_1")
    image_2 = keras.Input(shape=(128, 128, 3), name="image_2")

    # Receive the text as inputs.
    bert_input_features = ["padding_mask", "segment_ids", "token_ids"]
    text_inputs = {
        feature: keras.Input(shape=(256,), dtype="int32", name=feature)
        for feature in bert_input_features
    }
    text_inputs = list(text_inputs.values())

    # Create the encoders.
    vision_encoder = create_vision_encoder(
        num_projection_layers, projection_dims, dropout_rate, vision_trainable
    )
    text_encoder = create_text_encoder(
        num_projection_layers, projection_dims, dropout_rate, text_trainable
    )

    # Fetch the embedding projections.
    vision_projections = vision_encoder([image_1, image_2])
    text_projections = text_encoder(text_inputs)

    # Concatenate the projections and pass through the classification layer.
    concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
    outputs = keras.layers.Dense(3, activation="softmax")(concatenated)
    return keras.Model([image_1, image_2, *text_inputs], outputs)


multimodal_model = create_multimodal_model()
keras.utils.plot_model(multimodal_model, show_shapes=True)
"""
You can inspect the structure of the individual encoders as well by setting the
`expand_nested` argument of `plot_model()` to `True`. You are encouraged
to play with the different hyperparameters involved in building this model and
observe how the final performance is affected.
"""
"""
## Compile and train the model
"""
multimodal_model.compile(
    optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)

history = multimodal_model.fit(train_ds, validation_data=validation_ds, epochs=1)
"""
## Evaluate the model
"""
_, acc = multimodal_model.evaluate(test_ds)
print(f"Accuracy on the test set: {round(acc * 100, 2)}%.")
"""
## Additional notes regarding training
**Incorporating regularization**:
The training logs suggest that the model is starting to overfit and may have benefitted
from regularization. Dropout ([Srivastava et al.](https://jmlr.org/papers/v15/srivastava14a.html))
is a simple yet powerful regularization technique that we can use in our model.
But how should we apply it here?
We could always introduce Dropout (`keras.layers.Dropout`) in between different layers of the model.
But here is another recipe. Our model expects inputs from two different data modalities.
What if either of the modalities is not present during inference? To account for this,
we can introduce Dropout to the individual projections just before they get concatenated:
```python
vision_projections = keras.layers.Dropout(rate)(vision_projections)
text_projections = keras.layers.Dropout(rate)(text_projections)
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
```
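For intuition, this is roughly what dropout does to a projection vector during training:
a random subset of entries is zeroed and the rest are scaled by `1 / (1 - rate)` so that
the expected value is unchanged. The NumPy sketch is illustrative only; `rate = 0.5` and
the helper name `dropout` are assumptions.

```python
import numpy as np


def dropout(x, rate, rng):
    # Inverted dropout: zero entries with probability `rate`, rescale the rest.
    keep_mask = rng.random(x.shape) >= rate
    return x * keep_mask / (1.0 - rate)


rng = np.random.default_rng(0)
vision_projections = np.ones((2, 8))
dropped = dropout(vision_projections, rate=0.5, rng=rng)
print(dropped)  # each entry is either 0.0 or 2.0
```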
**Attending to what matters**:
Do all parts of the images correspond equally to their textual counterparts? That is likely
not the case. To make our model focus only on the most important parts of the images that
relate well to their corresponding textual parts, we can use "cross-attention":
```python
# Embeddings.
vision_projections = vision_encoder([image_1, image_2])
text_projections = text_encoder(text_inputs)
# Cross-attention (Luong-style).
query_value_attention_seq = keras.layers.Attention(use_scale=True, dropout=0.2)(
[vision_projections, text_projections]
)
# Concatenate.
concatenated = keras.layers.Concatenate()([vision_projections, text_projections])
contextual = keras.layers.Concatenate()([concatenated, query_value_attention_seq])
```
To see this in action, refer to
[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment_attn.ipynb).
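As a rough illustration of the dot-product (Luong-style) scoring involved, the NumPy
sketch below lets a set of query vectors attend over a set of value vectors. It leaves
out the learned scale and the dropout used in the Keras layer, and the shapes are
arbitrary assumptions.

```python
import numpy as np


def dot_product_attention(query, value):
    # Scores: similarity of every query vector with every value vector.
    scores = query @ value.T  # (num_queries, num_values)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)  # softmax over values
    # Each output row is a weighted average of the value vectors.
    return weights @ value, weights


rng = np.random.default_rng(0)
query = rng.normal(size=(3, 16))  # e.g. vision features
value = rng.normal(size=(5, 16))  # e.g. text features
attended, weights = dot_product_attention(query, value)
print(attended.shape)  # (3, 16)
print(weights.sum(axis=-1))  # each row sums to 1
```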
**Handling class imbalance**:
The dataset suffers from class imbalance. Investigating the confusion matrix of the
above model reveals that it performs poorly on the minority classes. If we had used a
weighted loss then the training would have been more guided. You can check out
[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment.ipynb)
that takes class-imbalance into account during model training.
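One common way to build such a weighted loss is to weight each class inversely to its
frequency. The sketch below computes weights of the form `n_samples / (n_classes * count)`
from toy label counts; the counts are made up and the linked notebook may use a different
scheme.

```python
import numpy as np

# Toy label indices with a strong imbalance (not the real dataset counts).
labels = np.array([2] * 80 + [1] * 15 + [0] * 5)

counts = np.bincount(labels, minlength=3)
class_weight = {
    cls: float(len(labels) / (len(counts) * count))
    for cls, count in enumerate(counts)
}
print(class_weight)  # rarer classes receive larger weights
```

Keras' `fit()` accepts a dictionary like this through its `class_weight` argument.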
**Using only text inputs**:
Also, what if we had only incorporated text inputs for the entailment task? Because of
the nature of the text inputs encountered on social media platforms, text inputs alone
would have hurt the final performance. Under a similar training setup, by only using
text inputs we get to 67.14% top-1 accuracy on the same test set. Refer to
[this notebook](https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/text_entailment.ipynb)
for details.
Finally, here is a table comparing different approaches taken for the entailment task:
| Type | Standard<br>Cross-entropy | Loss-weighted<br>Cross-entropy | Focal Loss |
|:---: |:---: |:---: |:---: |
| Multimodal | 77.86% | 67.86% | 86.43% |
| Only text | 67.14% | 11.43% | 37.86% |
You can check out [this repository](https://git.io/JR0HU) to learn more about how the
experiments were conducted to obtain these numbers.
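For reference, the focal loss in the table down-weights examples that the model already
classifies confidently: for the true-class probability `p_t` it is
`-(1 - p_t)^gamma * log(p_t)`, which reduces to cross-entropy at `gamma = 0`. A minimal
NumPy sketch, assuming `gamma = 2` (the value used in the linked experiments is not
stated here):

```python
import numpy as np


def focal_loss(probs, labels, gamma=2.0):
    # probs: (batch, num_classes) softmax outputs; labels: integer class ids.
    p_t = probs[np.arange(len(labels)), labels]
    return -((1.0 - p_t) ** gamma) * np.log(p_t)


probs = np.array([[0.9, 0.05, 0.05], [0.2, 0.5, 0.3]])
labels = np.array([0, 2])
print(focal_loss(probs, labels))  # the confident first example gets a far smaller loss
print(focal_loss(probs, labels, gamma=0.0))  # plain cross-entropy
```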
"""
"""
## Final remarks
* The architecture we used in this example is too large for the number of data points
available for training. It's going to benefit from more data.
* We used a smaller variant of the original BERT model. Chances are high that with a
larger variant, this performance will be improved. TensorFlow Hub
[provides](https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub)
a number of different BERT models that you can experiment with.
* We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment
task could have resulted in better performance.
* We built a simple baseline model for the multimodal entailment task. There are various
approaches that have been proposed to tackle the entailment problem.
[This presentation deck](https://docs.google.com/presentation/d/1mAB31BCmqzfedreNZYn4hsKPFmgHA9Kxz219DzyRY3c/edit?usp=sharing)
from the
[Recognizing Multimodal Entailment](https://multimodal-entailment.github.io/)
tutorial provides a comprehensive overview.
You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/multimodal-entailment)
and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/multimodal_entailment).
"""